Skip to content

Commit

Permalink
Merge pull request #20 from basf/fix-scorer-sign-error
Browse files Browse the repository at this point in the history
Fix scorer sign error
  • Loading branch information
JenniferHem committed Jun 13, 2024
2 parents e920f4e + 5f26c0b commit eed5362
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 1 deletion.
2 changes: 2 additions & 0 deletions molpipeline/metrics/ignore_error_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def ignored_value_scorer(
score_func = scorer._score_func # pylint: disable=protected-access
response_method = scorer._response_method # pylint: disable=protected-access
scorer_kwargs = scorer._kwargs # pylint: disable=protected-access
if scorer._sign < 0: # pylint: disable=protected-access
scorer_kwargs["greater_is_better"] = False

def newscore(
y_true: npt.NDArray[np.float_ | np.int_],
Expand Down
1 change: 0 additions & 1 deletion molpipeline/pipeline/_skl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union


try:
from typing import Self # type: ignore[attr-defined]
except ImportError:
Expand Down
46 changes: 46 additions & 0 deletions tests/test_metrics/test_ignore_error_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import unittest

import numpy as np
from sklearn import linear_model
from sklearn.metrics import get_scorer

from molpipeline.metrics import ignored_value_scorer

Expand Down Expand Up @@ -45,3 +47,47 @@ def test_filter_none_with_nan(self) -> None:
ba_score._score_func(y_true, y_pred), # pylint: disable=protected-access
1.0,
)

def test_correct_init_mse(self) -> None:
"""Test that initialization is correct as we access via protected vars."""
x_train = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]).reshape(
-1, 1
)
y_train = np.array([0.1, 0.3, 0.3, 0.4, 0.5, 0.5, 0.7, 0.88, 0.9, 1])
regr = linear_model.LinearRegression()
regr.fit(x_train, y_train)
cix_scorer = ignored_value_scorer("neg_mean_squared_error", None)
scikit_scorer = get_scorer("neg_mean_squared_error")
self.assertEqual(
cix_scorer(regr, x_train, y_train), scikit_scorer(regr, x_train, y_train)
)

def test_correct_init_rmse(self) -> None:
"""Test that initialization is correct as we access via protected vars."""
x_train = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]).reshape(
-1, 1
)
y_train = np.array([0.1, 0.3, 0.3, 0.4, 0.5, 0.5, 0.7, 0.88, 0.9, 1])
regr = linear_model.LinearRegression()
regr.fit(x_train, y_train)
cix_scorer = ignored_value_scorer("neg_root_mean_squared_error", None)
scikit_scorer = get_scorer("neg_root_mean_squared_error")
self.assertEqual(
cix_scorer(regr, x_train, y_train), scikit_scorer(regr, x_train, y_train)
)

def test_correct_init_inheritance(self) -> None:
"""Test that initialization is correct if we pass an initialized scorer."""
x_train = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]).reshape(
-1, 1
)
y_train = np.array([0.1, 0.3, 0.3, 0.4, 0.5, 0.5, 0.7, 0.88, 0.9, 1])
regr = linear_model.LinearRegression()
regr.fit(x_train, y_train)
scikit_scorer = get_scorer("neg_root_mean_squared_error")
cix_scorer = ignored_value_scorer(
get_scorer("neg_root_mean_squared_error"), None
)
self.assertEqual(
cix_scorer(regr, x_train, y_train), scikit_scorer(regr, x_train, y_train)
)

0 comments on commit eed5362

Please sign in to comment.