Skip to content

Commit

Permalink
typing, tests, complex filter naming
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-sandfort1 committed Sep 12, 2024
1 parent cfdfd83 commit b843657
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 27 deletions.
15 changes: 8 additions & 7 deletions molpipeline/abstract_pipeline_elements/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from molpipeline.utils.value_conversions import (
FloatCountRange,
IntCountRange,
IntOrIntCountRange,
count_value_to_tuple,
)

Expand Down Expand Up @@ -241,11 +242,11 @@ class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC):
- mode = "all" & keep_matches = False: Must not match all filter elements.
"""

_patterns: dict[str, tuple[Optional[int], Optional[int]]]
_patterns: dict[str, IntCountRange]

def __init__(
self,
patterns: Union[list[str], dict[str, IntCountRange]],
patterns: Union[list[str], dict[str, IntOrIntCountRange]],
keep_matches: bool = True,
mode: FilterModeType = "any",
name: Optional[str] = None,
Expand All @@ -256,7 +257,7 @@ def __init__(
Parameters
----------
patterns: Union[list[str], dict[str, CountRange]]
patterns: Union[list[str], dict[str, IntOrIntCountRange]]
List of patterns to allow in molecules.
Alternatively, a dictionary can be passed with patterns as keys
and an int for exact count or a tuple of minimum and maximum.
Expand All @@ -278,20 +279,20 @@ def __init__(
self.patterns = patterns # type: ignore

@property
def patterns(self) -> dict[str, tuple[Optional[int], Optional[int]]]:
def patterns(self) -> dict[str, IntCountRange]:
"""Get allowed patterns as dict."""
return self._patterns

@patterns.setter
def patterns(
self,
patterns: Union[list[str], dict[str, IntCountRange]],
patterns: Union[list[str], dict[str, IntOrIntCountRange]],
) -> None:
"""Set allowed patterns as dict.
Parameters
----------
patterns: Union[list[str], dict[str, CountRange]]
patterns: Union[list[str], dict[str, IntOrIntCountRange]]
List of patterns.
"""
if isinstance(patterns, (list, set)):
Expand Down Expand Up @@ -334,7 +335,7 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol:
"""

@property
def filter_elements(self) -> Mapping[str, tuple[Optional[int], Optional[int]]]:
def filter_elements(self) -> Mapping[str, IntCountRange]:
"""Get filter elements as dict."""
return self.patterns

Expand Down
2 changes: 1 addition & 1 deletion molpipeline/mol2mol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Init the module for mol2mol pipeline elements."""

from molpipeline.mol2mol.filter import (
ComplexFilter,
ElementFilter,
EmptyMoleculeFilter,
InorganicsFilter,
MixtureFilter,
ComplexFilter,
RDKitDescriptorsFilter,
SmartsFilter,
SmilesFilter,
Expand Down
15 changes: 9 additions & 6 deletions molpipeline/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from molpipeline.utils.value_conversions import (
FloatCountRange,
IntCountRange,
IntOrIntCountRange,
count_value_to_tuple,
)

Expand Down Expand Up @@ -57,7 +58,7 @@ class ElementFilter(_MolToMolPipelineElement):
def __init__(
self,
allowed_element_numbers: Optional[
Union[list[int], dict[int, IntCountRange]]
Union[list[int], dict[int, IntOrIntCountRange]]
] = None,
name: str = "ElementFilter",
n_jobs: int = 1,
Expand All @@ -67,7 +68,7 @@ def __init__(
Parameters
----------
allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]]]
allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]]]
List of atomic numbers of elements to allowed in molecules. Per default allowed elements are:
H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I.
Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum
Expand All @@ -82,23 +83,25 @@ def __init__(
self.allowed_element_numbers = allowed_element_numbers # type: ignore

@property
def allowed_element_numbers(self) -> dict[int, tuple[Optional[int], Optional[int]]]:
def allowed_element_numbers(self) -> dict[int, IntCountRange]:
"""Get allowed element numbers as dict."""
return self._allowed_element_numbers

@allowed_element_numbers.setter
def allowed_element_numbers(
self,
allowed_element_numbers: Optional[Union[list[int], dict[int, IntCountRange]]],
allowed_element_numbers: Optional[
Union[list[int], dict[int, IntOrIntCountRange]]
],
) -> None:
"""Set allowed element numbers as dict.
Parameters
----------
allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]]
allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]]
List of atomic numbers of elements to allowed in molecules.
"""
self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]]
self._allowed_element_numbers: dict[int, IntCountRange]
if allowed_element_numbers is None:
allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS
if isinstance(allowed_element_numbers, (list, set)):
Expand Down
16 changes: 9 additions & 7 deletions molpipeline/utils/value_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,28 @@

from typing import Optional, Sequence, TypeAlias, Union

# IntCountRange for Typing of count ranges
FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]]

IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]]

# IntOrIntCountRange for Typing of count ranges
# - a single int for an exact value match
# - a range given as a tuple with a lower and upper bound
# - both limits are optional
IntCountRange: TypeAlias = Union[int, tuple[Optional[int], Optional[int]]]

FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]]
IntOrIntCountRange: TypeAlias = Union[int, IntCountRange]


def count_value_to_tuple(count: IntCountRange) -> tuple[Optional[int], Optional[int]]:
def count_value_to_tuple(count: IntOrIntCountRange) -> IntCountRange:
"""Convert a count value to a tuple.
Parameters
----------
count: Union[int, tuple[Optional[int], Optional[int]]]
count: Union[int, IntCountRange]
Count value. Can be a single int or a tuple of two values.
Returns
-------
tuple[Optional[int], Optional[int]]
IntCountRange
Tuple of count values.
"""
if isinstance(count, int):
Expand Down
28 changes: 22 additions & 6 deletions tests/test_elements/test_mol2mol/test_mol2mol_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from molpipeline.any2mol import SmilesToMol
from molpipeline.mol2any import MolToSmiles
from molpipeline.mol2mol import (
ComplexFilter,
ElementFilter,
InorganicsFilter,
MixtureFilter,
ComplexFilter,
RDKitDescriptorsFilter,
SmartsFilter,
SmilesFilter,
)
from molpipeline.utils.value_conversions import FloatCountRange, IntCountRange
from molpipeline.utils.value_conversions import FloatCountRange, IntOrIntCountRange

# pylint: disable=duplicate-code # test case molecules are allowed to be duplicated
SMILES_ANTIMONY = "[SbH6+3]"
Expand Down Expand Up @@ -73,7 +73,7 @@ def test_element_filter(self) -> None:
filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE])

def test_multi_element_filter(self) -> None:
def test_complex_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed chemical elements."""
element_filter_1 = ElementFilter({6: 6, 1: 6})
element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1})
Expand All @@ -98,13 +98,13 @@ def test_multi_element_filter(self) -> None:

def test_smarts_smiles_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed SMARTS patterns."""
smarts_pats: dict[str, IntCountRange] = {
smarts_pats: dict[str, IntOrIntCountRange] = {
"c": (4, None),
"Cl": 1,
}
smarts_filter = SmartsFilter(smarts_pats)

smiles_pats: dict[str, IntCountRange] = {
smiles_pats: dict[str, IntOrIntCountRange] = {
"c1ccccc1": (1, None),
"Cl": 1,
}
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_smarts_smiles_filter(self) -> None:

def test_smarts_filter_parallel(self) -> None:
"""Test if molecules are filtered correctly by allowed SMARTS patterns in parallel."""
smarts_pats: dict[str, IntCountRange] = {
smarts_pats: dict[str, IntOrIntCountRange] = {
"c": (4, None),
"Cl": 1,
"cc": (1, None),
Expand Down Expand Up @@ -213,6 +213,14 @@ def test_descriptor_filter(self) -> None:
DescriptorsFilter__mode="any", DescriptorsFilter__keep_matches=True
)

pipeline.set_params(
DescriptorsFilter__descriptors={
"NumHAcceptors": (2.00, 4),
}
)
result_lower_exact = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(result_lower_exact, [SMILES_CL_BR])

pipeline.set_params(
DescriptorsFilter__descriptors={
"NumHAcceptors": (1.99, 4),
Expand All @@ -229,6 +237,14 @@ def test_descriptor_filter(self) -> None:
result_lower_out_bound = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(result_lower_out_bound, [])

pipeline.set_params(
DescriptorsFilter__descriptors={
"NumHAcceptors": (1, 2.00),
}
)
result_upper_exact = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(result_upper_exact, [SMILES_CL_BR])

pipeline.set_params(
DescriptorsFilter__descriptors={
"NumHAcceptors": (1, 2.01),
Expand Down

0 comments on commit b843657

Please sign in to comment.