diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 3f4efa5d..52b64a56 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -17,6 +17,7 @@ from molpipeline.utils.value_conversions import ( FloatCountRange, IntCountRange, + IntOrIntCountRange, count_value_to_tuple, ) @@ -241,11 +242,11 @@ class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): - mode = "all" & keep_matches = False: Must not match all filter elements. """ - _patterns: dict[str, tuple[Optional[int], Optional[int]]] + _patterns: dict[str, IntCountRange] def __init__( self, - patterns: Union[list[str], dict[str, IntCountRange]], + patterns: Union[list[str], dict[str, IntOrIntCountRange]], keep_matches: bool = True, mode: FilterModeType = "any", name: Optional[str] = None, @@ -256,7 +257,7 @@ def __init__( Parameters ---------- - patterns: Union[list[str], dict[str, CountRange]] + patterns: Union[list[str], dict[str, IntOrIntCountRange]] List of patterns to allow in molecules. Alternatively, a dictionary can be passed with patterns as keys and an int for exact count or a tuple of minimum and maximum. @@ -278,20 +279,20 @@ def __init__( self.patterns = patterns # type: ignore @property - def patterns(self) -> dict[str, tuple[Optional[int], Optional[int]]]: + def patterns(self) -> dict[str, IntCountRange]: """Get allowed patterns as dict.""" return self._patterns @patterns.setter def patterns( self, - patterns: Union[list[str], dict[str, IntCountRange]], + patterns: Union[list[str], dict[str, IntOrIntCountRange]], ) -> None: """Set allowed patterns as dict. Parameters ---------- - patterns: Union[list[str], dict[str, CountRange]] + patterns: Union[list[str], dict[str, IntOrIntCountRange]] List of patterns. """ if isinstance(patterns, (list, set)): @@ -334,7 +335,7 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: """ @property - def filter_elements(self) -> Mapping[str, tuple[Optional[int], Optional[int]]]: + def filter_elements(self) -> Mapping[str, IntCountRange]: """Get filter elements as dict.""" return self.patterns diff --git a/molpipeline/mol2mol/__init__.py b/molpipeline/mol2mol/__init__.py index 4fa3bd95..7f6ed1ae 100644 --- a/molpipeline/mol2mol/__init__.py +++ b/molpipeline/mol2mol/__init__.py @@ -1,11 +1,11 @@ """Init the module for mol2mol pipeline elements.""" from molpipeline.mol2mol.filter import ( + ComplexFilter, ElementFilter, EmptyMoleculeFilter, InorganicsFilter, MixtureFilter, - ComplexFilter, RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 5b8b9b37..5e46f7e1 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -27,6 +27,7 @@ from molpipeline.utils.value_conversions import ( FloatCountRange, IntCountRange, + IntOrIntCountRange, count_value_to_tuple, ) @@ -57,7 +58,7 @@ class ElementFilter(_MolToMolPipelineElement): def __init__( self, allowed_element_numbers: Optional[ - Union[list[int], dict[int, IntCountRange]] + Union[list[int], dict[int, IntOrIntCountRange]] ] = None, name: str = "ElementFilter", n_jobs: int = 1, @@ -67,7 +68,7 @@ def __init__( Parameters ---------- - allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]]] + allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]]] List of atomic numbers of elements to allowed in molecules. Per default allowed elements are: H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum @@ -82,23 +83,25 @@ def __init__( self.allowed_element_numbers = allowed_element_numbers # type: ignore @property - def allowed_element_numbers(self) -> dict[int, tuple[Optional[int], Optional[int]]]: + def allowed_element_numbers(self) -> dict[int, IntCountRange]: """Get allowed element numbers as dict.""" return self._allowed_element_numbers @allowed_element_numbers.setter def allowed_element_numbers( self, - allowed_element_numbers: Optional[Union[list[int], dict[int, IntCountRange]]], + allowed_element_numbers: Optional[ + Union[list[int], dict[int, IntOrIntCountRange]] + ], ) -> None: """Set allowed element numbers as dict. Parameters ---------- - allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]] + allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]] List of atomic numbers of elements to allowed in molecules. """ - self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]] + self._allowed_element_numbers: dict[int, IntCountRange] if allowed_element_numbers is None: allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS if isinstance(allowed_element_numbers, (list, set)): diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py index fb508276..df595a84 100644 --- a/molpipeline/utils/value_conversions.py +++ b/molpipeline/utils/value_conversions.py @@ -2,26 +2,28 @@ from typing import Optional, Sequence, TypeAlias, Union -# IntCountRange for Typing of count ranges +FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] + +IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] + +# IntOrIntCountRange for Typing of count ranges # - a single int for an exact value match # - a range given as a tuple with a lower and upper bound # - both limits are optional -IntCountRange: TypeAlias = Union[int, tuple[Optional[int], Optional[int]]] - -FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] +IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] -def count_value_to_tuple(count: IntCountRange) -> tuple[Optional[int], Optional[int]]: +def count_value_to_tuple(count: IntOrIntCountRange) -> IntCountRange: """Convert a count value to a tuple. Parameters ---------- - count: Union[int, tuple[Optional[int], Optional[int]]] + count: Union[int, IntCountRange] Count value. Can be a single int or a tuple of two values. Returns ------- - tuple[Optional[int], Optional[int]] + IntCountRange Tuple of count values. """ if isinstance(count, int): diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 7d274ed3..f4bc8df4 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -6,15 +6,15 @@ from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import MolToSmiles from molpipeline.mol2mol import ( + ComplexFilter, ElementFilter, InorganicsFilter, MixtureFilter, - ComplexFilter, RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, ) -from molpipeline.utils.value_conversions import FloatCountRange, IntCountRange +from molpipeline.utils.value_conversions import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -73,7 +73,7 @@ def test_element_filter(self) -> None: filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) - def test_multi_element_filter(self) -> None: + def test_complex_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" element_filter_1 = ElementFilter({6: 6, 1: 6}) element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) @@ -98,13 +98,13 @@ def test_multi_element_filter(self) -> None: def test_smarts_smiles_filter(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns.""" - smarts_pats: dict[str, IntCountRange] = { + smarts_pats: dict[str, IntOrIntCountRange] = { "c": (4, None), "Cl": 1, } smarts_filter = SmartsFilter(smarts_pats) - smiles_pats: dict[str, IntCountRange] = { + smiles_pats: dict[str, IntOrIntCountRange] = { "c1ccccc1": (1, None), "Cl": 1, } @@ -151,7 +151,7 @@ def test_smarts_smiles_filter(self) -> None: def test_smarts_filter_parallel(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" - smarts_pats: dict[str, IntCountRange] = { + smarts_pats: dict[str, IntOrIntCountRange] = { "c": (4, None), "Cl": 1, "cc": (1, None), @@ -213,6 +213,14 @@ def test_descriptor_filter(self) -> None: DescriptorsFilter__mode="any", DescriptorsFilter__keep_matches=True ) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (2.00, 4), + } + ) + result_lower_exact = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_lower_exact, [SMILES_CL_BR]) + pipeline.set_params( DescriptorsFilter__descriptors={ "NumHAcceptors": (1.99, 4), @@ -229,6 +237,14 @@ def test_descriptor_filter(self) -> None: result_lower_out_bound = pipeline.fit_transform(SMILES_LIST) self.assertEqual(result_lower_out_bound, []) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (1, 2.00), + } + ) + result_upper_exact = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_upper_exact, [SMILES_CL_BR]) + pipeline.set_params( DescriptorsFilter__descriptors={ "NumHAcceptors": (1, 2.01),