Skip to content

Commit

Permalink
use input check to prevent confusing message by torch of the class la…
Browse files Browse the repository at this point in the history
…bels do not match requirements
  • Loading branch information
JenniferHem committed Aug 26, 2024
1 parent 737db7f commit 8692bc2
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions molpipeline/estimators/chemprop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,4 +413,48 @@ def get_params(self, deep: bool = False) -> dict[str, Any]:
params = super().get_params(deep)
return params

def fit(
self,
X: MoleculeDataset,
y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64],
) -> Self:
"""Fit the model to the data.
Parameters
----------
X : MoleculeDataset
The input data.
y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
The target data.
Returns
-------
Self
The fitted model.
"""
self._check_correct_input(y)
return super().fit(X, y)

def _check_correct_input(self,y) -> None:
"""Checks if the input for the multi-class classifier is correct.
Parameters
----------
y : _type_
Indended classes for the dataset
Raises
------
ValueError
if the classes found in y are not matching n_classes or if the class labels do not start from 0 to n_classes-1
"""
unique_y = np.unique(y)
log = []
if self.n_classes != len(unique_y):
log.append(f"Given number of classes in init (n_classes) does not match the number of unique classes (found {unique_y}) in the target data.")
if sorted(unique_y) != list(range(self.n_classes)):
err = f"Classes need to be in the range from 0 to {self.n_classes-1}. Found {unique_y}. Please correct the input data accordingly."
print(err)
log.append(err)
if log:
raise ValueError("\n".join(log))

0 comments on commit 8692bc2

Please sign in to comment.