Skip to content

Commit

Permalink
explainability: more linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Aug 14, 2024
1 parent d7f2af3 commit f4f075f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
2 changes: 2 additions & 0 deletions molpipeline/explainability/visualization/heatmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def grid_field_lim(
) -> tuple[tuple[float, float], tuple[float, float]]:
"""Get x and y coordinates for the upper left and lower right position of specified pixel.
Parameters
----------
x_idx: int
cell-index along x-axis.
y_idx: int
Expand Down
72 changes: 70 additions & 2 deletions molpipeline/explainability/visualization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,25 @@ def _make_grid(
grid_resolution: Sequence[int],
padding: Sequence[float],
) -> ValueGrid:
# Setting up the grid
"""Create a grid for the molecule.
Parameters
----------
mol: Chem.Mol
RDKit molecule object.
canvas: rdMolDraw2D.MolDraw2D
RDKit canvas.
grid_resolution: Sequence[int]
Resolution of the grid.
padding: Sequence[float]
Padding of the grid.
Returns
-------
ValueGrid
ValueGrid object.
"""

xl, yl = [list(lim) for lim in get_mol_lims(mol)] # Limit of molecule

# Extent of the canvas is approximated by size of molecule scaled by ratio of canvas height and width.
Expand Down Expand Up @@ -152,7 +170,26 @@ def _add_gaussians_for_atoms(
atom_weights: npt.NDArray[np.float64],
atom_width: float,
) -> ValueGrid:
# Adding Gauss-functions centered at atoms
"""Add Gauss-functions centered at atoms to the grid.
Parameters
----------
mol: Chem.Mol
RDKit molecule object.
conf: Chem.Conformer
Conformation of the molecule.
v_map: ValueGrid
ValueGrid object to which the functions are added.
atom_weights: npt.NDArray[np.float64]
Array of weights for atoms.
atom_width: float
Width of the displayed atom weights.
Returns
-------
ValueGrid
ValueGrid object with added functions.
"""
for i, _ in enumerate(mol.GetAtoms()):
if atom_weights[i] == 0:
continue
Expand All @@ -169,6 +206,7 @@ def _add_gaussians_for_atoms(
return v_map


# pylint: disable=too-many-locals
def _add_gaussians_for_bonds(
mol: Chem.Mol,
conf: Chem.Conformer,
Expand All @@ -177,6 +215,29 @@ def _add_gaussians_for_bonds(
bond_width: float,
bond_length: float,
) -> ValueGrid:
"""Add Gauss-functions centered at bonds to the grid.
Parameters
----------
mol: Chem.Mol
RDKit molecule object.
conf: Chem.Conformer
Conformation of the molecule.
v_map: ValueGrid
ValueGrid object to which the functions are added.
bond_weights: npt.NDArray[np.float64]
Array of weights for bonds.
bond_width: float
Width of the displayed bond weights (perpendicular to bond-axis).
bond_length: float
Length of the displayed bond weights (along the bond-axis).
Returns
-------
ValueGrid
ValueGrid object with added functions.
"""

# Adding Gauss-functions centered at bonds (position between the two bonded-atoms)
for i, b in enumerate(mol.GetBonds()): # type: Chem.Bond
if bond_weights[i] == 0:
Expand Down Expand Up @@ -258,6 +319,13 @@ def mapvalues2mol(
rdMolDraw2D.MolDraw2D
Drawing of molecule and corresponding heatmap.
"""

if not isinstance(atom_weights, np.ndarray):
atom_weights = np.array(atom_weights)

if not isinstance(bond_weights, np.ndarray):
atom_weights = np.array(atom_weights)

# assign default values
if atom_weights is None:
atom_weights = np.zeros(len(mol.GetAtoms()))
Expand Down

0 comments on commit f4f075f

Please sign in to comment.