Skip to content

Commit

Permalink
feat: support data type dumped to a different name (#727)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
  - Introduced an optional `deepmd_name` parameter in the `DataType` class
    for enhanced naming flexibility.
  - Updated data type declarations in the `System` and `LabeledSystem`
    classes for better integration with DeePMD-kit.

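- **Bug Fixes**
  - Removed the hard-coded handling of energy, force, and virial data;
    these are now loaded and dumped through the generic data-type loop
    under their `deepmd_name`, which simplifies data processing and storage.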

- **Documentation**
  - Updated documentation for the `DataType` class to clarify the new
    `deepmd_name` parameter.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed Sep 20, 2024
1 parent 480242e commit a2fbdd8
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 157 deletions.
4 changes: 4 additions & 0 deletions dpdata/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class DataType:
represents numbers
required : bool, default=True
whether this data is required
deepmd_name : str, optional
DeePMD-kit data type name. When not given, it is the same as `name`.
"""

def __init__(
Expand All @@ -54,11 +56,13 @@ def __init__(
dtype: type,
shape: tuple[int | Axis, ...] | None = None,
required: bool = True,
deepmd_name: str | None = None,
) -> None:
self.name = name
self.dtype = dtype
self.shape = shape
self.required = required
self.deepmd_name = name if deepmd_name is None else deepmd_name

def real_shape(self, system: System) -> tuple[int]:
"""Returns expected real shape of a system."""
Expand Down
62 changes: 5 additions & 57 deletions dpdata/deepmd/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ def _load_set(folder, nopbc: bool):
cells = np.zeros((coords.shape[0], 3, 3))
else:
cells = np.load(os.path.join(folder, "box.npy"))
eners = _cond_load_data(os.path.join(folder, "energy.npy"))
forces = _cond_load_data(os.path.join(folder, "force.npy"))
virs = _cond_load_data(os.path.join(folder, "virial.npy"))
return cells, coords, eners, forces, virs
return cells, coords


def to_system_data(folder, type_map=None, labels=True):
Expand All @@ -41,31 +38,13 @@ def to_system_data(folder, type_map=None, labels=True):
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
all_cells = []
all_coords = []
all_eners = []
all_forces = []
all_virs = []
for ii in sets:
cells, coords, eners, forces, virs = _load_set(ii, data.get("nopbc", False))
cells, coords = _load_set(ii, data.get("nopbc", False))
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
if eners is not None:
eners = np.reshape(eners, [nframes])
if labels:
if eners is not None and eners.size > 0:
all_eners.append(np.reshape(eners, [nframes]))
if forces is not None and forces.size > 0:
all_forces.append(np.reshape(forces, [nframes, -1, 3]))
if virs is not None and virs.size > 0:
all_virs.append(np.reshape(virs, [nframes, 3, 3]))
data["cells"] = np.concatenate(all_cells, axis=0)
data["coords"] = np.concatenate(all_coords, axis=0)
if len(all_eners) > 0:
data["energies"] = np.concatenate(all_eners, axis=0)
if len(all_forces) > 0:
data["forces"] = np.concatenate(all_forces, axis=0)
if len(all_virs) > 0:
data["virials"] = np.concatenate(all_virs, axis=0)
# allow custom dtypes
if labels:
dtypes = dpdata.system.LabeledSystem.DTYPES
Expand All @@ -82,9 +61,6 @@ def to_system_data(folder, type_map=None, labels=True):
"coords",
"real_atom_names",
"nopbc",
"energies",
"forces",
"virials",
):
# skip as these data contains specific rules
continue
Expand All @@ -93,13 +69,13 @@ def to_system_data(folder, type_map=None, labels=True):
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/npy format."
)
continue
natoms = data["coords"].shape[1]
natoms = data["atom_types"].shape[0]
shape = [
natoms if xx == dpdata.system.Axis.NATOMS else xx for xx in dtype.shape[1:]
]
all_data = []
for ii in sets:
tmp = _cond_load_data(os.path.join(ii, dtype.name + ".npy"))
tmp = _cond_load_data(os.path.join(ii, dtype.deepmd_name + ".npy"))
if tmp is not None:
all_data.append(np.reshape(tmp, [tmp.shape[0], *shape]))
if len(all_data) > 0:
Expand Down Expand Up @@ -136,19 +112,6 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"])
# reshape frame properties and convert prec
nframes = data["cells"].shape[0]
cells = np.reshape(data["cells"], [nframes, 9]).astype(comp_prec)
coords = np.reshape(data["coords"], [nframes, -1]).astype(comp_prec)
eners = None
forces = None
virials = None
if "energies" in data:
eners = np.reshape(data["energies"], [nframes]).astype(comp_prec)
if "forces" in data:
forces = np.reshape(data["forces"], [nframes, -1]).astype(comp_prec)
if "virials" in data:
virials = np.reshape(data["virials"], [nframes, 9]).astype(comp_prec)
if "atom_pref" in data:
atom_pref = np.reshape(data["atom_pref"], [nframes, -1]).astype(comp_prec)
# dump frame properties: cell, coord, energy, force and virial
nsets = nframes // set_size
if set_size * nsets < nframes:
Expand All @@ -158,16 +121,6 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
set_end = (ii + 1) * set_size
set_folder = os.path.join(folder, "set.%03d" % ii)
os.makedirs(set_folder)
np.save(os.path.join(set_folder, "box"), cells[set_stt:set_end])
np.save(os.path.join(set_folder, "coord"), coords[set_stt:set_end])
if eners is not None:
np.save(os.path.join(set_folder, "energy"), eners[set_stt:set_end])
if forces is not None:
np.save(os.path.join(set_folder, "force"), forces[set_stt:set_end])
if virials is not None:
np.save(os.path.join(set_folder, "virial"), virials[set_stt:set_end])
if "atom_pref" in data:
np.save(os.path.join(set_folder, "atom_pref"), atom_pref[set_stt:set_end])
try:
os.remove(os.path.join(folder, "nopbc"))
except OSError:
Expand All @@ -187,13 +140,8 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
"atom_names",
"atom_types",
"orig",
"cells",
"coords",
"real_atom_names",
"nopbc",
"energies",
"forces",
"virials",
):
# skip as these data contains specific rules
continue
Expand All @@ -211,4 +159,4 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
set_stt = ii * set_size
set_end = (ii + 1) * set_size
set_folder = os.path.join(folder, "set.%03d" % ii)
np.save(os.path.join(set_folder, dtype.name), ddata[set_stt:set_end])
np.save(os.path.join(set_folder, dtype.deepmd_name), ddata[set_stt:set_end])
61 changes: 7 additions & 54 deletions dpdata/deepmd/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,38 +69,7 @@ def to_system_data(
data["nopbc"] = True
sets = globfilter(g.keys(), "set.*")

data_types = {
"cells": {
"fn": "box",
"labeled": False,
"shape": (3, 3),
"required": "nopbc" not in data,
},
"coords": {
"fn": "coord",
"labeled": False,
"shape": (natoms, 3),
"required": True,
},
"energies": {
"fn": "energy",
"labeled": True,
"shape": tuple(),
"required": False,
},
"forces": {
"fn": "force",
"labeled": True,
"shape": (natoms, 3),
"required": False,
},
"virials": {
"fn": "virial",
"labeled": True,
"shape": (3, 3),
"required": False,
},
}
data_types = {}
# allow custom dtypes
if labels:
dtypes = dpdata.system.LabeledSystem.DTYPES
Expand All @@ -112,14 +81,9 @@ def to_system_data(
"atom_names",
"atom_types",
"orig",
"cells",
"coords",
"real_atom_types",
"real_atom_names",
"nopbc",
"energies",
"forces",
"virials",
):
# skip as these data contains specific rules
continue
Expand All @@ -133,10 +97,10 @@ def to_system_data(
]

data_types[dtype.name] = {
"fn": dtype.name,
"labeled": True,
"fn": dtype.deepmd_name,
"shape": shape,
"required": False,
"required": dtype.required
and not (dtype.name == "cells" and data.get("nopbc", False)),
}

for dt, prop in data_types.items():
Expand Down Expand Up @@ -206,13 +170,7 @@ def dump(
nopbc = data.get("nopbc", False)
reshaped_data = {}

data_types = {
"cells": {"fn": "box", "shape": (nframes, 9), "dump": not nopbc},
"coords": {"fn": "coord", "shape": (nframes, -1), "dump": True},
"energies": {"fn": "energy", "shape": (nframes,), "dump": True},
"forces": {"fn": "force", "shape": (nframes, -1), "dump": True},
"virials": {"fn": "virial", "shape": (nframes, 9), "dump": True},
}
data_types = {}

labels = "energies" in data
if labels:
Expand All @@ -226,14 +184,9 @@ def dump(
"atom_names",
"atom_types",
"orig",
"cells",
"coords",
"real_atom_types",
"real_atom_names",
"nopbc",
"energies",
"forces",
"virials",
):
# skip as these data contains specific rules
continue
Expand All @@ -244,9 +197,9 @@ def dump(
continue

data_types[dtype.name] = {
"fn": dtype.name,
"fn": dtype.deepmd_name,
"shape": (nframes, -1),
"dump": True,
"dump": not (dtype.name == "cells" and nopbc),
}

for dt, prop in data_types.items():
Expand Down
45 changes: 4 additions & 41 deletions dpdata/deepmd/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,6 @@ def to_system_data(folder, type_map=None, labels=True):
data["cells"] = np.loadtxt(os.path.join(folder, "box.raw"), ndmin=2)
data["cells"] = np.reshape(data["cells"], [nframes, 3, 3])
data["coords"] = np.reshape(data["coords"], [nframes, -1, 3])
if labels:
if os.path.exists(os.path.join(folder, "energy.raw")):
data["energies"] = np.loadtxt(os.path.join(folder, "energy.raw"))
data["energies"] = np.reshape(data["energies"], [nframes])
if os.path.exists(os.path.join(folder, "force.raw")):
data["forces"] = np.loadtxt(os.path.join(folder, "force.raw"))
data["forces"] = np.reshape(data["forces"], [nframes, -1, 3])
if os.path.exists(os.path.join(folder, "virial.raw")):
data["virials"] = np.loadtxt(os.path.join(folder, "virial.raw"))
data["virials"] = np.reshape(data["virials"], [nframes, 3, 3])
if os.path.isfile(os.path.join(folder, "nopbc")):
data["nopbc"] = True
# allow custom dtypes
Expand All @@ -77,9 +67,6 @@ def to_system_data(folder, type_map=None, labels=True):
"real_atom_types",
"real_atom_names",
"nopbc",
"energies",
"forces",
"virials",
):
# skip as these data contains specific rules
continue
Expand All @@ -88,14 +75,14 @@ def to_system_data(folder, type_map=None, labels=True):
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted from deepmd/raw format."
)
continue
natoms = data["coords"].shape[1]
natoms = data["atom_types"].shape[0]
shape = [
natoms if xx == dpdata.system.Axis.NATOMS else xx
for xx in dtype.shape[1:]
]
if os.path.exists(os.path.join(folder, f"{dtype.name}.raw")):
if os.path.exists(os.path.join(folder, f"{dtype.deepmd_name}.raw")):
data[dtype.name] = np.reshape(
np.loadtxt(os.path.join(folder, f"{dtype.name}.raw")),
np.loadtxt(os.path.join(folder, f"{dtype.deepmd_name}.raw")),
[nframes, *shape],
)
return data
Expand All @@ -108,10 +95,6 @@ def dump(folder, data):
nframes = data["cells"].shape[0]
np.savetxt(os.path.join(folder, "type.raw"), data["atom_types"], fmt="%d")
np.savetxt(os.path.join(folder, "type_map.raw"), data["atom_names"], fmt="%s")
np.savetxt(os.path.join(folder, "box.raw"), np.reshape(data["cells"], [nframes, 9]))
np.savetxt(
os.path.join(folder, "coord.raw"), np.reshape(data["coords"], [nframes, -1])
)
# BondOrder System
if "bonds" in data:
np.savetxt(
Expand All @@ -121,21 +104,6 @@ def dump(folder, data):
)
if "formal_charges" in data:
np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"])
# Labeled System
if "energies" in data:
np.savetxt(
os.path.join(folder, "energy.raw"),
np.reshape(data["energies"], [nframes, 1]),
)
if "forces" in data:
np.savetxt(
os.path.join(folder, "force.raw"), np.reshape(data["forces"], [nframes, -1])
)
if "virials" in data:
np.savetxt(
os.path.join(folder, "virial.raw"),
np.reshape(data["virials"], [nframes, 9]),
)
try:
os.remove(os.path.join(folder, "nopbc"))
except OSError:
Expand All @@ -155,14 +123,9 @@ def dump(folder, data):
"atom_names",
"atom_types",
"orig",
"cells",
"coords",
"real_atom_types",
"real_atom_names",
"nopbc",
"energies",
"forces",
"virials",
):
# skip as these data contains specific rules
continue
Expand All @@ -174,4 +137,4 @@ def dump(folder, data):
)
continue
ddata = np.reshape(data[dtype.name], [nframes, -1])
np.savetxt(os.path.join(folder, f"{dtype.name}.raw"), ddata)
np.savetxt(os.path.join(folder, f"{dtype.deepmd_name}.raw"), ddata)
20 changes: 15 additions & 5 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ class System:
DataType("atom_names", list, (Axis.NTYPES,)),
DataType("atom_types", np.ndarray, (Axis.NATOMS,)),
DataType("orig", np.ndarray, (3,)),
DataType("cells", np.ndarray, (Axis.NFRAMES, 3, 3)),
DataType("coords", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3)),
DataType("cells", np.ndarray, (Axis.NFRAMES, 3, 3), deepmd_name="box"),
DataType(
"coords", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="coord"
),
DataType(
"real_atom_types", np.ndarray, (Axis.NFRAMES, Axis.NATOMS), required=False
),
Expand Down Expand Up @@ -1204,9 +1206,17 @@ class LabeledSystem(System):
"""

DTYPES: tuple[DataType, ...] = System.DTYPES + (
DataType("energies", np.ndarray, (Axis.NFRAMES,)),
DataType("forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3)),
DataType("virials", np.ndarray, (Axis.NFRAMES, 3, 3), required=False),
DataType("energies", np.ndarray, (Axis.NFRAMES,), deepmd_name="energy"),
DataType(
"forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="force"
),
DataType(
"virials",
np.ndarray,
(Axis.NFRAMES, 3, 3),
required=False,
deepmd_name="virial",
),
DataType("atom_pref", np.ndarray, (Axis.NFRAMES, Axis.NATOMS), required=False),
)

Expand Down

0 comments on commit a2fbdd8

Please sign in to comment.