Skip to content

Commit

Permalink
feat: enable GroupedTransformer.set_output(..) (#697)
Browse files Browse the repository at this point in the history
* feat: GroupedTransformer allow set_output

* skip polars for pre 1.4.0 version
  • Loading branch information
FBruzzesi committed Aug 21, 2024
1 parent a4c8e25 commit 89d1da6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sklego/meta/_grouped_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def parse_X_y(X, y, groups, check_X=True, **kwargs) -> nw.DataFrame:
X = nw.from_native(pd.DataFrame(X))

# Check groups and feaures values
if groups is not None:
if groups:
_validate_groups_values(X, groups)

if check_X:
Expand Down
15 changes: 12 additions & 3 deletions sklego/meta/grouped_transformer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import List, Union

import narwhals.stable.v1 as nw
import numpy as np
Expand Down Expand Up @@ -108,14 +108,19 @@ def fit(self, X, y=None):
"""
self.__check_transformer()
self.fallback_ = None
self.groups_ = as_list(self.groups) if self.groups is not None else None
self.groups_ = as_list(self.groups) if self.groups is not None else []

X = nw.from_native(X, strict=False, eager_only=True)
if not isinstance(X, nw.DataFrame) and self.groups_ is not None:

if isinstance(X, nw.DataFrame):
self.feature_names_out_ = [c for c in X.columns if c not in self.groups_]

else:
# Accounts for negative indices if X is an array
self.groups_ = [
X.shape[1] + group if isinstance(group, int) and group < 0 else group for group in self.groups_
]
self.feature_names_out_ = [f"x{i}" for i in range(X.shape[1] - len(self.groups_))]

frame = parse_X_y(X, y, self.groups_, check_X=self.check_X, **self._check_kwargs)

Expand Down Expand Up @@ -203,3 +208,7 @@ def transform(self, X):

def _more_tags(self):
return {"allow_nan": True}

def get_feature_names_out(self) -> List[str]:
"Alias for the `feature_names_out_` attribute defined during fit."
return self.feature_names_out_
15 changes: 15 additions & 0 deletions tests/test_meta/test_grouped_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import polars as pl
import pytest
import sklearn
from sklearn import clone
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
Expand All @@ -25,6 +26,7 @@ def test_sklearn_compatible_estimator(estimator, check):
"check_dtype_object", # custom message
"check_estimators_empty_data_messages", # custom message
"check_estimators_pickle", # Fails if input contains nan
"check_fit1d",
}:
pytest.skip()

Expand Down Expand Up @@ -389,3 +391,16 @@ def test_transform_with_y(transformer):
)

assert np.allclose(X_naive.to_numpy(), X_grouped)


@pytest.mark.parametrize(("frame_func", "transform_output"), [(pd.DataFrame, "pandas"), (pl.DataFrame, "polars")])
def test_set_output(penguins_df, frame_func, transform_output):
if transform_output == "polars" and sklearn.__version__ < "1.4.0":
pytest.skip()

X = frame_func(penguins_df.drop(columns=["sex"]))
y = penguins_df["sex"]

meta = GroupedTransformer(StandardScaler(), groups="island").set_output(transform=transform_output)
transformed = meta.fit_transform(X, y)
assert isinstance(transformed, frame_func)

0 comments on commit 89d1da6

Please sign in to comment.