Skip to content

Commit

Permalink
check_input default to False
Browse files Browse the repository at this point in the history
  • Loading branch information
mtrbean committed Oct 16, 2023
1 parent 512b2b1 commit 935d50a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
7 changes: 5 additions & 2 deletions sklego/meta/estimator_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sklearn.utils.validation import (
check_is_fitted,
check_X_y,
FLOAT_DTYPES
)


Expand All @@ -19,14 +20,16 @@ class EstimatorTransformer(TransformerMixin, MetaEstimatorMixin, BaseEstimator):
:param check_input: Whether to check the input data for NaNs, Infs and non-numeric values
"""

def __init__(self, estimator, predict_func="predict", check_input=True):
def __init__(self, estimator, predict_func="predict", check_input=False):
self.estimator = estimator
self.predict_func = predict_func
self.check_input = check_input

def fit(self, X, y, **kwargs):
"""Fits the estimator"""
X, y = check_X_y(X, y, estimator=self, force_all_finite=self.check_input, dtype=None, multi_output=True)

if self.check_input:
X, y = check_X_y(X, y, estimator=self, dtype=FLOAT_DTYPES, multi_output=True)

self.multi_output_ = len(y.shape) > 1
self.estimator_ = clone(self.estimator)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_meta/test_estimatortransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"test_fn", flatten([transformer_checks, general_checks])
)
def test_estimator_checks(test_fn):
trf = EstimatorTransformer(LinearRegression())
trf = EstimatorTransformer(LinearRegression(), check_input=True)
test_fn(EstimatorTransformer.__name__, trf)


Expand Down Expand Up @@ -53,7 +53,7 @@ def test_get_params():
"estimator": clf,
"estimator__strategy": "most_frequent",
"predict_func": "predict",
"check_input": True,
"check_input": False,
}


Expand Down

0 comments on commit 935d50a

Please sign in to comment.