diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 3c1935b4..11ecde9c 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -41,6 +41,10 @@ def get_params(self: _LossFunction, deep: bool = True) -> dict[str, Any]: deep : bool, optional Not used, only present to match the sklearn API. + Returns + ------- + dict[str, Any] + The parameters of the loss function. """ return {"task_weights": self._original_task_weights}