From 89d1da611ee9d1beca9f870121c97f65f4190d4e Mon Sep 17 00:00:00 2001
From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com>
Date: Wed, 21 Aug 2024 11:18:54 +0200
Subject: [PATCH] feat: enable `GroupedTransformer.set_output(..)` (#697)

* feat: GroupedTransformer allow set_output

* skip polars for pre 1.4.0 version
---
Review notes (text between the "---" line and the first "diff --git" is not applied):
* get_feature_names_out() accepts an optional, ignored `input_features` argument so that
  scikit-learn callers which pass it positionally (e.g. Pipeline.get_feature_names_out)
  do not raise TypeError.
* test_set_output compares parsed (major, minor) integers instead of version strings;
  as strings, "1.10.0" sorts before "1.4.0".
* Blob ids on the `index` lines are those of the original commit.

 sklego/meta/_grouped_utils.py               |  2 +-
 sklego/meta/grouped_transformer.py          | 19 ++++++++++++++++---
 tests/test_meta/test_grouped_transformer.py | 16 ++++++++++++++++
 3 files changed, 33 insertions(+), 4 deletions(-)

diff --git a/sklego/meta/_grouped_utils.py b/sklego/meta/_grouped_utils.py
index 47e71df4..32bec84d 100644
--- a/sklego/meta/_grouped_utils.py
+++ b/sklego/meta/_grouped_utils.py
@@ -25,7 +25,7 @@ def parse_X_y(X, y, groups, check_X=True, **kwargs) -> nw.DataFrame:
         X = nw.from_native(pd.DataFrame(X))
 
     # Check groups and feaures values
-    if groups is not None:
+    if groups:
         _validate_groups_values(X, groups)
 
     if check_X:
diff --git a/sklego/meta/grouped_transformer.py b/sklego/meta/grouped_transformer.py
index fe737ec6..b603bd13 100644
--- a/sklego/meta/grouped_transformer.py
+++ b/sklego/meta/grouped_transformer.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import List, Union
 
 import narwhals.stable.v1 as nw
 import numpy as np
@@ -108,14 +108,19 @@ def fit(self, X, y=None):
         """
         self.__check_transformer()
         self.fallback_ = None
-        self.groups_ = as_list(self.groups) if self.groups is not None else None
+        self.groups_ = as_list(self.groups) if self.groups is not None else []
 
         X = nw.from_native(X, strict=False, eager_only=True)
-        if not isinstance(X, nw.DataFrame) and self.groups_ is not None:
+
+        if isinstance(X, nw.DataFrame):
+            self.feature_names_out_ = [c for c in X.columns if c not in self.groups_]
+
+        else:
             # Accounts for negative indices if X is an array
             self.groups_ = [
                 X.shape[1] + group if isinstance(group, int) and group < 0 else group for group in self.groups_
             ]
+            self.feature_names_out_ = [f"x{i}" for i in range(X.shape[1] - len(self.groups_))]
 
         frame = parse_X_y(X, y, self.groups_, check_X=self.check_X, **self._check_kwargs)
 
@@ -203,3 +208,11 @@ def transform(self, X):
 
     def _more_tags(self):
         return {"allow_nan": True}
+
+    def get_feature_names_out(self, input_features=None) -> List[str]:
+        """Return the `feature_names_out_` attribute defined during fit.
+
+        `input_features` is ignored; it is accepted so that scikit-learn callers that pass it positionally,
+        such as `Pipeline.get_feature_names_out`, keep working.
+        """
+        return self.feature_names_out_
diff --git a/tests/test_meta/test_grouped_transformer.py b/tests/test_meta/test_grouped_transformer.py
index fe1d3e70..875bf604 100644
--- a/tests/test_meta/test_grouped_transformer.py
+++ b/tests/test_meta/test_grouped_transformer.py
@@ -4,6 +4,7 @@
 import pandas as pd
 import polars as pl
 import pytest
+import sklearn
 from sklearn import clone
 from sklearn.linear_model import LinearRegression
 from sklearn.pipeline import Pipeline
@@ -25,6 +26,7 @@ def test_sklearn_compatible_estimator(estimator, check):
         "check_dtype_object",  # custom message
         "check_estimators_empty_data_messages",  # custom message
         "check_estimators_pickle",  # Fails if input contains nan
+        "check_fit1d",
     }:
         pytest.skip()
 
@@ -389,3 +391,17 @@ def test_transform_with_y(transformer):
     )
 
     assert np.allclose(X_naive.to_numpy(), X_grouped)
+
+
+@pytest.mark.parametrize(("frame_func", "transform_output"), [(pd.DataFrame, "pandas"), (pl.DataFrame, "polars")])
+def test_set_output(penguins_df, frame_func, transform_output):
+    # Compare parsed version numbers: as plain strings, "1.10.0" would sort before "1.4.0".
+    if transform_output == "polars" and tuple(int(p) for p in sklearn.__version__.split(".")[:2]) < (1, 4):
+        pytest.skip()
+
+    X = frame_func(penguins_df.drop(columns=["sex"]))
+    y = penguins_df["sex"]
+
+    meta = GroupedTransformer(StandardScaler(), groups="island").set_output(transform=transform_output)
+    transformed = meta.fit_transform(X, y)
+    assert isinstance(transformed, frame_func)