Skip to content

Commit

Permalink
explainability: adapting to Christians comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Jun 18, 2024
1 parent a9245cd commit 1fe0d02
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 94 deletions.
213 changes: 145 additions & 68 deletions molpipeline/explainability/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,42 @@
import numpy as np
import numpy.typing as npt
import shap
from scipy.sparse import issparse
from scipy.sparse import issparse, spmatrix

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
from molpipeline.abstract_pipeline_elements.core import InvalidInstance, OptionalMol
from molpipeline.explainability.explanation import Explanation
from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights
from molpipeline.mol2any import MolToMorganFP
from molpipeline.utils.subpipeline import SubpipelineExtractor
from molpipeline.utils.value_checks import get_length


# pylint: disable=C0103,W0613
def _to_dense(
    feature_matrix: npt.NDArray[Any] | spmatrix,
) -> npt.NDArray[Any]:
    """Return the feature matrix as a dense NumPy array.

    Used to mitigate feature incompatibility with SHAP objects by converting
    sparse input to a dense representation.

    Parameters
    ----------
    feature_matrix : npt.NDArray[Any] | spmatrix
        The input features.

    Returns
    -------
    npt.NDArray[Any]
        A dense array with the content of ``feature_matrix`` if it is sparse;
        otherwise ``feature_matrix`` itself, unchanged.
    """
    if issparse(feature_matrix):
        # ``toarray`` yields a plain ``numpy.ndarray``; ``todense`` would
        # return a ``numpy.matrix``, which contradicts the declared return
        # type and keeps matrix semantics (always 2-D, ``*`` is matmul).
        return feature_matrix.toarray()
    return feature_matrix


# This function might also be put at a more central position in the lib.
def _get_predictions(pipeline: Pipeline, X: Any) -> npt.NDArray[np.float_]:
def _get_predictions(
pipeline: Pipeline, feature_matrix: npt.NDArray[Any] | spmatrix
) -> npt.NDArray[np.float_]:
"""Get the predictions of a model.
Raises if no adequate method is found.
Expand All @@ -47,7 +52,7 @@ def _get_predictions(pipeline: Pipeline, X: Any) -> npt.NDArray[np.float_]:
----------
pipeline : Pipeline
The pipeline containing the model.
X : Any
feature_matrix : npt.NDArray[Any] | spmatrix
The input data.
Returns
Expand All @@ -56,11 +61,11 @@ def _get_predictions(pipeline: Pipeline, X: Any) -> npt.NDArray[np.float_]:
The predictions.
"""
if hasattr(pipeline, "predict_proba"):
return pipeline.predict_proba(X)
return pipeline.predict_proba(feature_matrix)
if hasattr(pipeline, "decision_function"):
return pipeline.decision_function(X)
return pipeline.decision_function(feature_matrix)
if hasattr(pipeline, "predict"):
return pipeline.predict(X)
return pipeline.predict(feature_matrix)
raise ValueError("Could not determine the model output predictions")


Expand Down Expand Up @@ -163,51 +168,112 @@ def __init__(self, pipeline: Pipeline, **kwargs: Any) -> None:
**kwargs,
)

# extract the molecule reader subpipeline
self.molecule_reader_subpipeline = (
pipeline_extractor.get_molecule_reader_subpipeline()
)
if self.molecule_reader_subpipeline is None:
raise ValueError("Could not determine the molecule reader subpipeline.")

# extract the featurization subpipeline
self.featurization_subpipeline = (
pipeline_extractor.get_featurization_subpipeline()
)
if self.featurization_subpipeline is None:
raise ValueError("Could not determine the featurization subpipeline.")

# extract fill values for checking error handling
self.fill_values = pipeline_extractor.get_all_filter_reinserter_fill_values()
self.fill_values_contain_nan = np.isnan(self.fill_values).any()

def _prediction_is_valid(self, prediction: Any) -> bool:
"""Check if the prediction is valid using some heuristics.
Can be used to catch inputs that failed the pipeline for some reason.
Parameters
----------
prediction : Any
The prediction.
Returns
-------
bool
Whether the prediction is valid.
"""
# if no prediction could be obtained (length is 0); the prediction guaranteed failed.
if len(prediction) == 0:
return False
# create subpipelines for extracting intermediate results for explanations
(
self.molecule_reader_subpipeline,
self.featurization_subpipeline,
self.model_subpipeline,
) = self._extract_subpipelines(model, pipeline, pipeline_extractor)

# if a value in the prediction is a fill-value, we - assume - the explanation has failed.
if np.isin(prediction, self.fill_values).any():
return False
if self.fill_values_contain_nan and np.isnan(prediction).any():
# the extra nan check is necessary because np.isin does not work with nan
return False
if len(self.featurization_subpipeline.steps) > 1:
raise AssertionError(
"The featurization subpipeline should only contain one element. Multiple elements are not supported."
)

return True
def _extract_subpipelines(
    self, model: Any, pipeline: Pipeline, pipeline_extractor: SubpipelineExtractor
) -> tuple[Pipeline, Pipeline, Pipeline]:
    """Split the pipeline into reader, featurization and model subpipelines.

    The three subpipelines are consecutive slices of ``pipeline`` that
    together cover the whole pipeline:

    1. Molecule reader subpipeline: every step before the featurization
       element.
    2. Featurization subpipeline: from the featurization element up to, but
       excluding, the model element.
    3. Model subpipeline: from the model element to the end of the pipeline.

    Parameters
    ----------
    model : Any
        The model element. It is located in ``pipeline.steps`` by object
        identity.
    pipeline : Pipeline
        The pipeline to split.
    pipeline_extractor : SubpipelineExtractor
        The extractor used to obtain the featurization element.

    Returns
    -------
    tuple[Pipeline, Pipeline, Pipeline]
        The molecule reader, featurization, and model subpipelines.

    Raises
    ------
    ValueError
        If the featurization element cannot be determined, if the
        featurization or model element is not a step of ``pipeline``, or if
        the featurization element does not precede the model element.
    """
    featurization_element = pipeline_extractor.get_featurization_element()
    if featurization_element is None:
        raise ValueError("Could not determine the featurization element.")

    def get_index(element: Any) -> int | None:
        """Return the position of ``element`` in ``pipeline.steps``, else None."""
        # Compare by identity: equal-but-distinct estimators must not match.
        return next(
            (
                idx
                for idx, step in enumerate(pipeline.steps)
                if step[1] is element
            ),
            None,
        )

    featurization_element_idx = get_index(featurization_element)
    if featurization_element_idx is None:
        raise ValueError(
            "Could not determine the index of the featurization element."
        )
    model_element_idx = get_index(model)
    if model_element_idx is None:
        raise ValueError("Could not determine the index of the model element.")
    if model_element_idx <= featurization_element_idx:
        # Otherwise the featurization slice below would be empty or reversed.
        raise ValueError(
            "The featurization element must precede the model element "
            "in the pipeline."
        )

    # Slice the same pipeline the indices were computed on, so the indices
    # and the slices can never refer to different objects.
    reader_subpipeline = pipeline[:featurization_element_idx]
    featurization_subpipeline = pipeline[
        featurization_element_idx:model_element_idx
    ]
    model_subpipeline = pipeline[model_element_idx:]
    return reader_subpipeline, featurization_subpipeline, model_subpipeline

# def _extract_subpipelines(
# self, model: Any, pipeline: Pipeline, pipeline_extractor: SubpipelineExtractor
# ) -> tuple[Pipeline, Pipeline, Pipeline]:
# """Extract the subpipelines from the pipeline extractor.
#
# We extract three subpipelines. Each subpipeline is an interval of the original pipeline.
# 1. The first subpipeline reads the input into molecules. The resulting molecules are ready
# for featurization, e.g. they went through standardization steps.
# 2. The second subpipeline featurizes the molecules to a machine learning ready format.
# 3. The third subpipeline executes the machine learning inference step, including post-processing.
#
#
# Parameters
# ----------
# model : Any
# The model element.
# pipeline : Pipeline
# The pipeline.
# pipeline_extractor : SubpipelineExtractor
# The pipeline extractor.
#
# Returns
# -------
# tuple[Pipeline, Pipeline, Pipeline]
# The molecule reader, featurization, and prediction subpipelines.
# """
#
# # The pipeline is split into consecutive intervals covering the whole pipeline.
# # The intervals are defined as:
# # 1. Molecule reading subpipeline: from the beginning to the position before the featurization element.
# # 2. Featurization subpipeline: from the featurization element to the position before the model element.
# # 3. Model subpipeline: from the position after the featurization element to the end of the pipeline.
# # This heuristic process needs only to find the featurization element and model element to infer
# # the subpipelines.
#
# # first extract elements we need the output from
# featurization_element = pipeline_extractor.get_featurization_element()
# if featurization_element is None:
# raise ValueError("Could not determine the featurization element.")
#
# # reader subpipeline is from step 0 to one before the featurization element
# reader_subpipeline = pipeline_extractor.get_subpipeline(
# pipeline.steps[0][1], featurization_element, second_offset=-1
# )
# if reader_subpipeline is None:
# raise ValueError("Could not determine the molecule reader subpipeline.")
#
# # the featurization subpipeline is from the featurization element one element before the model element.
# featurization_subpipeline = pipeline_extractor.get_subpipeline(
# featurization_element, model, second_offset=-1
# )
# if featurization_subpipeline is None:
# raise ValueError("Could not determine the featurization subpipeline.")
#
# # the model subpipeline is from the first element after the featurization element until the end of the pipeline
# model_subpipeline = pipeline_extractor.get_subpipeline(
# featurization_element, pipeline.steps[-1][1], first_offset=1
# )
# if model_subpipeline is None:
# raise ValueError("Could not determine the model subpipeline.")
#
# return reader_subpipeline, featurization_subpipeline, model_subpipeline

# pylint: disable=C0103,W0613
def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
Expand All @@ -221,35 +287,46 @@ def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
X : Any
The input data to explain.
kwargs : Any
Additional keyword arguments for SHAP's TreeExplainer.shap_values .
Additional keyword arguments for SHAP's TreeExplainer.shap_values.
Returns
-------
list[Explanation]
List of explanations corresponding to the input data.
"""
featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore
featurization_element = self.featurization_subpipeline.steps[0][1] # type: ignore

explanation_results = []
for input_sample in X:

input_sample = [input_sample]

# get the molecule
molecule_list = self.molecule_reader_subpipeline.transform(input_sample) # type: ignore
if len(molecule_list) == 0 or isinstance(molecule_list[0], InvalidInstance):
explanation_results.append(Explanation())
continue

feature_vector = self.featurization_subpipeline.transform(molecule_list) # type: ignore
if get_length(feature_vector) == 0 or isinstance(
feature_vector[0], InvalidInstance
):
explanation_results.append(Explanation())
continue

# get predictions
prediction = _get_predictions(self.pipeline, input_sample)
if not self._prediction_is_valid(prediction):
# we use the prediction to check if the input is valid. If not, we cannot explain it.
prediction = _get_predictions(self.model_subpipeline, feature_vector)
if len(prediction) == 0 or isinstance(prediction[0], InvalidInstance):
explanation_results.append(Explanation())
continue

# todo fill values?

if prediction.ndim > 1:
prediction = prediction.squeeze()

# get the molecule
molecule = self.molecule_reader_subpipeline.transform(input_sample)[0] # type: ignore

# get feature vectors
feature_vector = self.featurization_subpipeline.transform(input_sample) # type: ignore
feature_vector = _mitigate_feature_incompatibility_with_shap(feature_vector)
# reshape feature vector for SHAP and output
feature_vector = _to_dense(feature_vector)
feature_vector = np.asarray(feature_vector).squeeze()

# Feature names should also be extracted from the Pipeline.
Expand All @@ -268,7 +345,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
# for Morgan fingerprint, we can map the shap values to atom weights
atom_weights = _convert_shap_feature_weights_to_atom_weights(
feature_weights,
molecule,
molecule_list[0],
featurization_element,
feature_vector,
)
Expand All @@ -277,7 +354,7 @@ def explain(self, X: Any, **kwargs: Any) -> list[Explanation]:
Explanation(
feature_vector=feature_vector,
feature_names=feature_names,
molecule=molecule,
molecule=molecule_list[0],
prediction=prediction,
feature_weights=feature_weights,
atom_weights=atom_weights,
Expand Down
9 changes: 0 additions & 9 deletions molpipeline/explainability/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,6 @@ def rdkit_gaussplot(
contour_params=cps,
sigma_f=0.4,
)
# from rdkit.Chem.Draw import SimilarityMaps
# drawer = SimilarityMaps.GetSimilarityMapFromWeights(
# mol,
# weights,
# contour_lines=n_contour_lines,
# draw2d=drawer,
# contour_params=cps,
# sigma_f=0.4,
# )
drawer.FinishDrawing()
return drawer

Expand Down
32 changes: 16 additions & 16 deletions tests/test_explainability/test_shap_tree_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ def test_explanations_fingerprint_pipeline(self) -> None:
"""Test SHAP's TreeExplainer wrapper on MolPipeline's pipelines with fingerprints."""

estimators = [
RandomForestClassifier(random_state=_RANDOM_STATE),
RandomForestRegressor(random_state=_RANDOM_STATE),
GradientBoostingClassifier(random_state=_RANDOM_STATE),
GradientBoostingRegressor(random_state=_RANDOM_STATE),
RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
]
n_bits = 64

Expand Down Expand Up @@ -172,10 +172,10 @@ def test_explanations_pipeline_with_invalid_inputs(self) -> None:
"""Test SHAP's TreeExplainer wrapper with invalid inputs."""

estimators = [
RandomForestClassifier(random_state=_RANDOM_STATE),
RandomForestRegressor(random_state=_RANDOM_STATE),
GradientBoostingClassifier(random_state=_RANDOM_STATE),
GradientBoostingRegressor(random_state=_RANDOM_STATE),
RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
]

n_bits = 64
Expand Down Expand Up @@ -244,10 +244,10 @@ def test_explanations_pipeline_with_physchem(self) -> None:
"""Test SHAP's TreeExplainer wrapper on physchem feature vector."""

estimators = [
RandomForestClassifier(random_state=_RANDOM_STATE),
RandomForestRegressor(random_state=_RANDOM_STATE),
GradientBoostingClassifier(random_state=_RANDOM_STATE),
GradientBoostingRegressor(random_state=_RANDOM_STATE),
RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
]

# test explanations with different estimators
Expand Down Expand Up @@ -286,10 +286,10 @@ def test_explanations_pipeline_with_concatenated_features(self) -> None:
"""Test SHAP's TreeExplainer wrapper on concatenated feature vector."""

estimators = [
RandomForestClassifier(random_state=_RANDOM_STATE),
RandomForestRegressor(random_state=_RANDOM_STATE),
GradientBoostingClassifier(random_state=_RANDOM_STATE),
GradientBoostingRegressor(random_state=_RANDOM_STATE),
RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingClassifier(n_estimators=2, random_state=_RANDOM_STATE),
GradientBoostingRegressor(n_estimators=2, random_state=_RANDOM_STATE),
]

n_bits = 64
Expand Down
5 changes: 4 additions & 1 deletion tests/test_explainability/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ def test_test_fingerprint_based_atom_coloring(self) -> None:
[
("smi2mol", SmilesToMol()),
("morgan", MolToMorganFP(radius=1, n_bits=1024)),
("model", RandomForestClassifier(random_state=_RANDOM_STATE)),
(
"model",
RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
),
]
)
pipeline.fit(TEST_SMILES, CONTAINS_OX)
Expand Down

0 comments on commit 1fe0d02

Please sign in to comment.