Skip to content

Commit 9643e46

Browse files
remove wrapper for support_input_format where not needed (#3027)
1 parent 7584d8f commit 9643e46

2 files changed

Lines changed: 3 additions & 2 deletions

File tree

sklearnex/preview/covariance/covariance.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
# This is a temporary workaround for issues with sklearnex._device_offload._get_host_inputs
4242
# passing kwargs with sycl queues with other host data will cause failures
43+
# Note: this wrapper could potentially be removed later if sklearn implements
44+
# array API support for this metric in their pairwise_distances class.
4345
_mahalanobis = support_input_format(partial(pairwise_distances, metric="mahalanobis"))
4446

4547

sklearnex/preview/linear_model/logistic_regression.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from daal4py.sklearn.linear_model.logistic_path import (
2222
LogisticRegressionCV as _daal4py_LogisticRegressionCV,
2323
)
24-
from onedal._device_offload import support_input_format
2524

2625
from ...linear_model.logistic_regression import (
2726
LogisticRegression as _sklearnex_LogisticRegression,
@@ -34,7 +33,7 @@
3433
class LogisticRegressionCV(
3534
_daal4py_LogisticRegressionCV, _sklearnex_LogisticRegression
3635
):
37-
fit = support_input_format(_daal4py_LogisticRegressionCV.fit)
36+
fit = _daal4py_LogisticRegressionCV.fit
3837
predict_proba = _sklearnex_LogisticRegression.predict_proba
3938
predict_log_proba = _sklearnex_LogisticRegression.predict_log_proba
4039
decision_function = _sklearnex_LogisticRegression.decision_function

0 commit comments

Comments
 (0)