Skip to content

Commit

Permalink
add test for full coverage of multiclass chemprop
Browse files Browse the repository at this point in the history
  • Loading branch information
JenniferHem committed Aug 28, 2024
1 parent 1b84120 commit 3dbcff8
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def get_classification_pipeline() -> Pipeline:
return model_pipeline


def get_multiclass_classification_pipeline() -> Pipeline:
def get_multiclass_classification_pipeline(n_classes: int) -> Pipeline:
"""Get the Chemprop model pipeline for classification.
Returns
Expand All @@ -155,7 +155,7 @@ def get_multiclass_classification_pipeline() -> Pipeline:
error_filter, fill_value=np.nan
)
chemprop_model = ChempropMulticlassClassifier(
n_classes=3, lightning_trainer=DEFAULT_TRAINER
n_classes=n_classes, lightning_trainer=DEFAULT_TRAINER
)
model_pipeline = Pipeline(
steps=[
Expand Down Expand Up @@ -348,7 +348,7 @@ def test_prediction(self) -> None:
)
print(test_data_df.head())
print(test_data_df.columns)
classification_model = get_multiclass_classification_pipeline()
classification_model = get_multiclass_classification_pipeline(n_classes=3)
mols = test_data_df["Molecule"].tolist()
classification_model.fit(
mols,
Expand Down Expand Up @@ -376,3 +376,9 @@ def test_prediction(self) -> None:
mols,
test_data_df["Label"].add(1).to_numpy(),
)
with self.assertRaises(ValueError):
classification_model = get_multiclass_classification_pipeline(n_classes=2)
classification_model.fit(
mols,
test_data_df["Label"].to_numpy(),
)

0 comments on commit 3dbcff8

Please sign in to comment.