Skip to content

Commit 5af9e80

Browse files
Remove raw inputs (#3095)
* Remove raw inputs * formatting * isort * remove more instances of use_raw_inputs --------- Co-authored-by: david-cortes-intel <david.cortes@intel.com>
1 parent 5309887 commit 5af9e80

34 files changed

Lines changed: 318 additions & 572 deletions

onedal/_config.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,11 @@
3636
If True, allows to fallback computation to sklearn after onedal
3737
backend in case of runtime error on onedal backend computations.
3838
Global default: True.
39-
use_raw_input:
40-
If True, uses the raw input data in some SPMD onedal backend computations
41-
without any checks on data consistency or validity.
42-
Note: This option is not recommended for general use.
43-
Global default: False.
4439
"""
4540
_default_global_config = {
4641
"target_offload": "auto",
4742
"allow_fallback_to_host": False,
4843
"allow_sklearn_after_onedal": True,
49-
"use_raw_input": False,
5044
}
5145

5246
_threadlocal = threading.local()

onedal/_device_offload.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,17 @@
1515
# ==============================================================================
1616

1717
import inspect
18-
import logging
1918
from functools import wraps
2019
from operator import xor
2120

2221
import numpy as np
2322
from sklearn import get_config
2423

25-
from ._config import _get_config
2624
from .datatypes import copy_to_dpnp, dlpack_to_numpy
2725
from .utils import _sycl_queue_manager as QM
2826
from .utils._array_api import _asarray, _get_sycl_namespace, _is_numpy_namespace
2927
from .utils._third_party import is_dpnp_ndarray
3028

31-
logger = logging.getLogger("sklearnex")
32-
3329

3430
def supports_queue(func):
3531
"""Decorator that updates the global queue before function evaluation.
@@ -126,26 +122,7 @@ def wrapper_impl(*args, **kwargs):
126122
else:
127123
self = None
128124

129-
# KNeighbors*.fit can not be used with raw inputs, ignore `use_raw_input=True`
130-
override_raw_input = (
131-
self
132-
and self.__class__.__name__ in ("KNeighborsClassifier", "KNeighborsRegressor")
133-
and func.__name__ == "fit"
134-
and _get_config()["use_raw_input"] is True
135-
)
136-
if override_raw_input:
137-
pretty_name = f"{self.__class__.__name__}.{func.__name__}"
138-
logger.warning(
139-
f"Using raw inputs is not supported for {pretty_name}. Ignoring `use_raw_input=True` setting."
140-
)
141-
if _get_config()["use_raw_input"] is True and not override_raw_input:
142-
if "queue" not in kwargs:
143-
if usm_iface := getattr(args[0], "__sycl_usm_array_interface__", None):
144-
kwargs["queue"] = usm_iface["syclobj"]
145-
else:
146-
kwargs["queue"] = None
147-
return invoke_func(self, *args, **kwargs)
148-
elif len(args) == 0 and len(kwargs) == 0:
125+
if len(args) == 0 and len(kwargs) == 0:
149126
# no arguments, there's nothing we can deduce from them -> just call the function
150127
return invoke_func(self, *args, **kwargs)
151128

sklearnex/_config.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# ==============================================================================
1616

1717
import sys
18-
import warnings
1918
from contextlib import contextmanager
2019

2120
from sklearn import get_config as skl_get_config
@@ -54,26 +53,10 @@
5453
{tab}
5554
{tab} Global default: ``True``.
5655
{tab}
57-
{tab}use_raw_input : bool or None
58-
{tab} If ``True``, uses the raw input data in some SPMD onedal backend computations
59-
{tab} without any checks on data consistency or validity. Note that this can be
60-
{tab} better achieved through usage of :ref:`array API classes <array_api>` without
61-
{tab} ``target_offload``. Not recommended for general use.
62-
{tab}
63-
{tab} Global default: ``False``.
64-
{tab}
65-
{tab} .. deprecated:: 2026.0
66-
{tab}
6756
{tab}sklearn_configs : kwargs
6857
{tab} Other settings accepted by scikit-learn. See :obj:`sklearn.set_config` for
6958
{tab} details.
7059
{tab}
71-
{tab}Warnings
72-
{tab}--------
73-
{tab}Using ``use_raw_input=True`` is not recommended for general use as it
74-
{tab}bypasses data consistency checks, which may lead to unexpected behavior. It is
75-
{tab}recommended to use the newer :ref:`array API <array_api>` instead.
76-
{tab}
7760
{tab}Note
7861
{tab}----
7962
{tab}Usage of ``target_offload`` requires additional dependencies - see
@@ -102,7 +85,6 @@ def set_config(
10285
target_offload=None,
10386
allow_fallback_to_host=None,
10487
allow_sklearn_after_onedal=None,
105-
use_raw_input=None,
10688
**sklearn_configs,
10789
): # numpydoc ignore=PR01,PR07
10890
"""Set global configuration.
@@ -125,15 +107,6 @@ def set_config(
125107
local_config["allow_fallback_to_host"] = allow_fallback_to_host
126108
if allow_sklearn_after_onedal is not None:
127109
local_config["allow_sklearn_after_onedal"] = allow_sklearn_after_onedal
128-
if use_raw_input is not None:
129-
if use_raw_input:
130-
warnings.warn(
131-
"The 'use_raw_input' parameter is deprecated and will be removed in version 2026.0. "
132-
"On-device input validation can now be achieved by setting 'array_api_dispatch' to True.",
133-
FutureWarning,
134-
stacklevel=2,
135-
)
136-
local_config["use_raw_input"] = use_raw_input
137110

138111

139112
set_config.__doc__ = set_config.__doc__.replace(

sklearnex/_device_offload.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def dispatch(
8080
8181
Depending on support conditions, oneDAL will be called, otherwise it will
8282
fall back to calling scikit-learn. Dispatching to oneDAL can be influenced
83-
by the 'use_raw_input' or 'allow_fallback_to_host' config parameters.
83+
by the 'allow_fallback_to_host' config parameter.
8484
8585
Parameters
8686
----------
@@ -112,10 +112,6 @@ def dispatch(
112112
object types should match for the sklearn and onedal object methods.
113113
"""
114114

115-
if get_config()["use_raw_input"]:
116-
with QM.manage_global_queue(None, *args) as queue:
117-
return branches["onedal"](obj, *args, **kwargs, queue=queue)
118-
119115
# Determine if array_api dispatching is enabled, and if estimator is capable
120116
onedal_array_api = _array_api_offload() and get_tags(obj).onedal_array_api
121117
sklearn_array_api = _array_api_offload() and get_tags(obj).array_api_support
@@ -215,16 +211,8 @@ def wrapper(self, *args, **kwargs) -> Any:
215211
):
216212
_, (result,) = _transfer_to_host(result)
217213
return result
218-
# Remove check for result __sycl_usm_array_interface__ on deprecation of use_raw_inputs
219-
if (
220-
usm_iface := getattr(data, "__sycl_usm_array_interface__", None)
221-
) and not hasattr(result, "__sycl_usm_array_interface__"):
222-
# Skip if result elements are already SYCL arrays
223-
# (e.g. kneighbors tuple from from_table(like=X))
224-
if isinstance(result, (tuple, list)) and all(
225-
hasattr(r, "__sycl_usm_array_interface__") for r in result
226-
):
227-
return result
214+
215+
if usm_iface := getattr(data, "__sycl_usm_array_interface__", None):
228216
queue = usm_iface["syclobj"]
229217
return copy_to_dpnp(queue, result)
230218

sklearnex/basic_statistics/basic_statistics.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from onedal.basic_statistics import BasicStatistics as onedal_BasicStatistics
2222
from onedal.utils.validation import _is_csr
2323

24-
from .._config import get_config
2524
from .._device_offload import dispatch
2625
from .._utils import PatchingConditionsChain
2726
from ..base import oneDALEstimator
@@ -157,20 +156,19 @@ def _onedal_gpu_supported(self, method_name, *data):
157156
return patching_status
158157

159158
def _onedal_fit(self, X, sample_weight=None, queue=None):
160-
if not get_config()["use_raw_input"]:
161-
xp, _ = get_namespace(X, sample_weight)
162-
X = validate_data(
163-
self,
164-
X,
165-
dtype=[xp.float64, xp.float32],
166-
ensure_2d=False,
167-
accept_sparse="csr",
168-
)
159+
xp, _ = get_namespace(X, sample_weight)
160+
X = validate_data(
161+
self,
162+
X,
163+
dtype=[xp.float64, xp.float32],
164+
ensure_2d=False,
165+
accept_sparse="csr",
166+
)
169167

170-
if sample_weight is not None:
171-
sample_weight = _check_sample_weight(
172-
sample_weight, X, dtype=[xp.float64, xp.float32]
173-
)
168+
if sample_weight is not None:
169+
sample_weight = _check_sample_weight(
170+
sample_weight, X, dtype=[xp.float64, xp.float32]
171+
)
174172

175173
onedal_params = {
176174
"result_options": self.result_options,

sklearnex/basic_statistics/incremental_basic_statistics.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
IncrementalBasicStatistics as onedal_IncrementalBasicStatistics,
2424
)
2525

26-
from .._config import get_config
2726
from .._device_offload import dispatch
2827
from .._utils import PatchingConditionsChain, _add_inc_serialization_note
2928
from ..base import oneDALEstimator
@@ -174,8 +173,7 @@ def _onedal_finalize_fit(self, queue=None):
174173
def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=True):
175174
first_pass = not hasattr(self, "n_samples_seen_") or self.n_samples_seen_ == 0
176175

177-
# never check input when using raw input
178-
if check_input and not get_config()["use_raw_input"]:
176+
if check_input:
179177
xp, _ = get_namespace(X)
180178
X = validate_data(
181179
self,
@@ -204,14 +202,13 @@ def _onedal_partial_fit(self, X, sample_weight=None, queue=None, check_input=Tru
204202
self._need_to_finalize = True
205203

206204
def _onedal_fit(self, X, sample_weight=None, queue=None):
207-
if not get_config()["use_raw_input"]:
208-
xp, _ = get_namespace(X, sample_weight)
209-
X = validate_data(self, X, dtype=[xp.float64, xp.float32])
205+
xp, _ = get_namespace(X, sample_weight)
206+
X = validate_data(self, X, dtype=[xp.float64, xp.float32])
210207

211-
if sample_weight is not None:
212-
sample_weight = _check_sample_weight(
213-
sample_weight, X, dtype=[xp.float64, xp.float32]
214-
)
208+
if sample_weight is not None:
209+
sample_weight = _check_sample_weight(
210+
sample_weight, X, dtype=[xp.float64, xp.float32]
211+
)
215212

216213
_, n_features = X.shape
217214
if self.batch_size is None:

sklearnex/cluster/dbscan.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from onedal.cluster import DBSCAN as onedal_DBSCAN
2222
from onedal.utils._array_api import _is_numpy_namespace
2323

24-
from .._config import get_config
2524
from .._device_offload import dispatch
2625
from .._utils import PatchingConditionsChain
2726
from ..base import oneDALEstimator
@@ -77,14 +76,11 @@ def __init__(
7776

7877
def _onedal_fit(self, X, y, sample_weight=None, queue=None):
7978
xp, _ = get_namespace(X, y, sample_weight)
80-
if not get_config()["use_raw_input"]:
81-
X = validate_data(
82-
self, X, accept_sparse="csr", dtype=[xp.float64, xp.float32]
79+
X = validate_data(self, X, accept_sparse="csr", dtype=[xp.float64, xp.float32])
80+
if sample_weight is not None:
81+
sample_weight = _check_sample_weight(
82+
sample_weight, X, dtype=[xp.float64, xp.float32]
8383
)
84-
if sample_weight is not None:
85-
sample_weight = _check_sample_weight(
86-
sample_weight, X, dtype=[xp.float64, xp.float32]
87-
)
8884

8985
onedal_params = {
9086
"eps": self.eps,

sklearnex/cluster/k_means.py

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -193,29 +193,28 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
193193

194194
xp, _ = get_namespace(X)
195195

196-
if not get_config()["use_raw_input"]:
197-
if _is_arraylike_not_scalar(self.init):
198-
init = validate_data(
199-
self,
200-
self.init,
201-
dtype=[xp.float64, xp.float32],
202-
accept_sparse="csr",
203-
copy=True,
204-
order="C",
205-
reset=False,
206-
)
207-
self._validate_center_shape(X, init)
208-
self.init = init
209-
210-
X = validate_data(
196+
if _is_arraylike_not_scalar(self.init):
197+
init = validate_data(
211198
self,
212-
X,
213-
accept_sparse="csr",
199+
self.init,
214200
dtype=[xp.float64, xp.float32],
201+
accept_sparse="csr",
202+
copy=True,
215203
order="C",
216-
copy=self.copy_x,
217-
accept_large_sparse=False,
204+
reset=False,
218205
)
206+
self._validate_center_shape(X, init)
207+
self.init = init
208+
209+
X = validate_data(
210+
self,
211+
X,
212+
accept_sparse="csr",
213+
dtype=[xp.float64, xp.float32],
214+
order="C",
215+
copy=self.copy_x,
216+
accept_large_sparse=False,
217+
)
219218

220219
# Validate critical parameters to match sklearn's _check_params
221220
# behavior, which we bypass in the oneDAL path. This is needed
@@ -386,16 +385,13 @@ def predict(
386385
def _onedal_predict(self, X, sample_weight=None, queue=None):
387386

388387
xp, _ = get_namespace(X)
389-
390-
if not get_config()["use_raw_input"]:
391-
X = validate_data(
392-
self,
393-
X,
394-
accept_sparse="csr",
395-
reset=False,
396-
dtype=[xp.float64, xp.float32],
397-
)
398-
388+
X = validate_data(
389+
self,
390+
X,
391+
accept_sparse="csr",
392+
reset=False,
393+
dtype=[xp.float64, xp.float32],
394+
)
399395
return self._onedal_estimator.predict(X, queue=queue)
400396

401397
def _onedal_supported(self, method_name, *data):
@@ -456,17 +452,15 @@ def transform(self, X):
456452
def _onedal_transform(self, X, queue=None):
457453

458454
xp, is_array_api = get_namespace(X)
459-
460-
if not get_config()["use_raw_input"]:
461-
X = validate_data(
462-
self,
463-
X,
464-
accept_sparse="csr",
465-
reset=False,
466-
dtype=[xp.float64, xp.float32],
467-
order="C",
468-
accept_large_sparse=False,
469-
)
455+
X = validate_data(
456+
self,
457+
X,
458+
accept_sparse="csr",
459+
reset=False,
460+
dtype=[xp.float64, xp.float32],
461+
order="C",
462+
accept_large_sparse=False,
463+
)
470464

471465
if is_array_api:
472466
centers = xp.asarray(self.cluster_centers_)
@@ -500,15 +494,13 @@ def score(self, X, y=None, sample_weight=None):
500494
def _onedal_score(self, X, y=None, sample_weight=None, queue=None):
501495

502496
xp, _ = get_namespace(X)
503-
504-
if not get_config()["use_raw_input"]:
505-
X = validate_data(
506-
self,
507-
X,
508-
accept_sparse="csr",
509-
reset=False,
510-
dtype=[xp.float64, xp.float32],
511-
)
497+
X = validate_data(
498+
self,
499+
X,
500+
accept_sparse="csr",
501+
reset=False,
502+
dtype=[xp.float64, xp.float32],
503+
)
512504

513505
if not sklearn_check_version("1.5") and sklearn_check_version("1.3"):
514506
if isinstance(sample_weight, str) and sample_weight == "deprecated":

0 commit comments

Comments
 (0)