@@ -193,29 +193,28 @@ def _onedal_fit(self, X, _, sample_weight, queue=None):
193193
194194 xp , _ = get_namespace (X )
195195
196- if not get_config ()["use_raw_input" ]:
197- if _is_arraylike_not_scalar (self .init ):
198- init = validate_data (
199- self ,
200- self .init ,
201- dtype = [xp .float64 , xp .float32 ],
202- accept_sparse = "csr" ,
203- copy = True ,
204- order = "C" ,
205- reset = False ,
206- )
207- self ._validate_center_shape (X , init )
208- self .init = init
209-
210- X = validate_data (
196+ if _is_arraylike_not_scalar (self .init ):
197+ init = validate_data (
211198 self ,
212- X ,
213- accept_sparse = "csr" ,
199+ self .init ,
214200 dtype = [xp .float64 , xp .float32 ],
201+ accept_sparse = "csr" ,
202+ copy = True ,
215203 order = "C" ,
216- copy = self .copy_x ,
217- accept_large_sparse = False ,
204+ reset = False ,
218205 )
206+ self ._validate_center_shape (X , init )
207+ self .init = init
208+
209+ X = validate_data (
210+ self ,
211+ X ,
212+ accept_sparse = "csr" ,
213+ dtype = [xp .float64 , xp .float32 ],
214+ order = "C" ,
215+ copy = self .copy_x ,
216+ accept_large_sparse = False ,
217+ )
219218
220219 # Validate critical parameters to match sklearn's _check_params
221220 # behavior, which we bypass in the oneDAL path. This is needed
@@ -386,16 +385,13 @@ def predict(
386385 def _onedal_predict (self , X , sample_weight = None , queue = None ):
387386
388387 xp , _ = get_namespace (X )
389-
390- if not get_config ()["use_raw_input" ]:
391- X = validate_data (
392- self ,
393- X ,
394- accept_sparse = "csr" ,
395- reset = False ,
396- dtype = [xp .float64 , xp .float32 ],
397- )
398-
388+ X = validate_data (
389+ self ,
390+ X ,
391+ accept_sparse = "csr" ,
392+ reset = False ,
393+ dtype = [xp .float64 , xp .float32 ],
394+ )
399395 return self ._onedal_estimator .predict (X , queue = queue )
400396
401397 def _onedal_supported (self , method_name , * data ):
@@ -456,17 +452,15 @@ def transform(self, X):
456452 def _onedal_transform (self , X , queue = None ):
457453
458454 xp , is_array_api = get_namespace (X )
459-
460- if not get_config ()["use_raw_input" ]:
461- X = validate_data (
462- self ,
463- X ,
464- accept_sparse = "csr" ,
465- reset = False ,
466- dtype = [xp .float64 , xp .float32 ],
467- order = "C" ,
468- accept_large_sparse = False ,
469- )
455+ X = validate_data (
456+ self ,
457+ X ,
458+ accept_sparse = "csr" ,
459+ reset = False ,
460+ dtype = [xp .float64 , xp .float32 ],
461+ order = "C" ,
462+ accept_large_sparse = False ,
463+ )
470464
471465 if is_array_api :
472466 centers = xp .asarray (self .cluster_centers_ )
@@ -500,15 +494,13 @@ def score(self, X, y=None, sample_weight=None):
500494 def _onedal_score (self , X , y = None , sample_weight = None , queue = None ):
501495
502496 xp , _ = get_namespace (X )
503-
504- if not get_config ()["use_raw_input" ]:
505- X = validate_data (
506- self ,
507- X ,
508- accept_sparse = "csr" ,
509- reset = False ,
510- dtype = [xp .float64 , xp .float32 ],
511- )
497+ X = validate_data (
498+ self ,
499+ X ,
500+ accept_sparse = "csr" ,
501+ reset = False ,
502+ dtype = [xp .float64 , xp .float32 ],
503+ )
512504
513505 if not sklearn_check_version ("1.5" ) and sklearn_check_version ("1.3" ):
514506 if isinstance (sample_weight , str ) and sample_weight == "deprecated" :
0 commit comments