Skip to content

Commit

Permalink
add a public API to register data types dynamically (#532)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] committed Sep 2, 2023
1 parent e239d0b commit 6ed3c44
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions dpdata/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,17 @@ def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None):
idx = pick_by_amber_mask(parm, maskstr)
return self.pick_atom_idx(idx, nopbc=nopbc)

@classmethod
def register_data_type(cls, *data_type: Tuple[DataType]):
"""Register data type.
Parameters
----------
*data_type : tuple[DataType]
data type to be regiestered
"""
cls.DTYPES = cls.DTYPES + tuple(data_type)


def get_cell_perturb_matrix(cell_pert_fraction):
if cell_pert_fraction < 0:
Expand Down Expand Up @@ -1599,9 +1610,9 @@ def to_format(self, *args, **kwargs):
setattr(MultiSystems, method, get_func(formatcls))

# at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized
System.DTYPES = System.DTYPES + get_data_types(labeled=False)
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False)
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True)
System.register_data_type(*get_data_types(labeled=False))
LabeledSystem.register_data_type(*get_data_types(labeled=False))
LabeledSystem.register_data_type(*get_data_types(labeled=True))


add_format_methods()

0 comments on commit 6ed3c44

Please sign in to comment.