Skip to content

Commit 6502875

Browse files
KateBlueSkyMcgrievy, Kathleenyuejiaointelclaudedavid-cortes-intel
authored
[ArrayAPI] Refactor KMeans estimator to follow oneDAL estimator design pattern (#2654)
* ArrayAPI update. * Updated imports. * Updated default parameter for oneAPI. * Updated format. * Updated format. * Fixed to get namespace from function. * Updated format. * Fixed for csr_matrix. * Fixed formatting. * Updated formatting. * Fixed formatting. * Fixed formatting. * Changed to use getattr. * Updated imports. * Formatted. * Fixes for get_namespace. * Updated formatting. * Updated format. * Updated the format. * Fixed format. * Don't need array api stuff for X_tables. * Removed table import. * Updated format. * Fix what is being passed into namespace. * Updated format. * Updated init. * Updated format. * fixed init. * Updated name. * merged other changes. * Formated. * Fixed imports. * fixed _get_onedal_params call error. * test: logging * fix: logging * fix: refactoring * fix: refacotring * fix: n_init try to get from skleranex fist due to refacotring * fix: revert the get namespace * fix: n_init * fix: n_init * fix: init * fix: validate data * fix: add enable array api decorator * Add array API zero-copy support for KMeans on GPU - Add from_table(like=X) to fit/predict/cluster_centers_ for correct output array type (dpnp/dpctl) instead of always returning numpy - Remove np.asarray in _init_centroids_onedal and cluster_centers_ setter to avoid crash on GPU arrays - Store _input_type for deferred from_table in cluster_centers_ property - Split _onedal_gpu_supported to reject callable init on GPU (falls back to sklearn instead of crashing in numpy-only code path) * Remove dead n_init='warn' handling and simplify defaults - Remove n_init=='warn' branch from onedal fit() and sklearnex _resolve_n_init() - Simplify n_init default to 'auto' (sklearn >=1.4) - Simplify algorithm default to 'lloyd' (sklearn >=1.1) - CI only tests sklearn 1.7.2/1.8.0, these version-conditional defaults were dead code * fix: torch * fix: array API compatibility for KMeans transform/predict/fit_transform - Replace .ravel() with [:, 0] in onedal KMeans for array API compatibility (array API strict Array objects lack ravel/reshape methods) - Add transform dispatch through dispatch() following PCA pattern, fixing dpnp transform failures (dpnp arrays weren't transferred to host) - Override fit_transform to use dispatched fit+transform - Remove support_input_format wrapping (caused namespace mismatch) - Remove _check_params/_check_params_vs_input (rejects 'lloyd' on sklearn<1.1, np.var() fails on GPU arrays; onedal computes tolerance internally) - Update _onedal_supported to handle 'transform' method Fixes 99 patching test failures (55 array_api + 44 dpnp transform/fit_transform). * refactor: move n_init resolution to sklearnex, keep onedal fallback - Pass resolved _n_init integer from sklearnex to onedal estimator - Keep n_init resolution in onedal as fallback for SPMD and direct usage - Guard onedal resolution with isinstance check to skip when already resolved * fix: restore version-conditional defaults for n_init and algorithm * fix: add parameter validation for sklearn < 1.2 in KMeans _onedal_fit * fix: address review feedback - memory leak, validation, and robustness - Store lightweight callable via return_type_constructor(X) instead of full training array in _input_type to prevent memory retention - Remove ensure_all_finite=False to match sklearn's KMeans validation - Use getattr fallback for _n_init in _initialize_onedal_estimator to handle deserialization/manual-construction edge cases * fix: restore fit_predict and fit_transform on onedal KMeans * Add sklearnex SPMD wrapper for KMeans * Add KMeans to array_api.rst supported classes * fix: address review - clean up __init__, document array-like init path - Remove misleading comment and private attr initializations from onedal _BaseKMeans.__init__ (sklearn convention: only store params) - Use hasattr checks for runtime attrs (_tol, _cluster_centers_, model_) - Add comment explaining _is_arraylike_not_scalar check is reachable (user can explicitly set n_init > 1 with array-like init) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: pass tolerance as local variable instead of storing as self._tol set_params(tol=...) now correctly affects subsequent operations since _get_onedal_params falls back to self.tol for predict/score paths. * Revert "fix: pass tolerance as local variable instead of storing as self._tol" This reverts commit d3306dc. * fix: use array API operations for transform with non-numpy arrays * fix: support transform_output=polars with GPU arrays, align with neighbors branch * fix: rename euclidean_distances import for clarity * fix: use xp.clip instead of xp.where for clamping negative distances * fix: support set_output(transform=polars) with GPU arrays * test: add transform_output tests for torch CPU with polars/pandas * fix: address review - move import to top, fix test torch skip and polars/pandas imports * fix: use xp.asarray instead of _asarray in wrap_output_data for torch compatibility * fix: remove KMeans cluster_centers_ skip from patching test, fix xp.clip dtype * fix: skip xp.asarray for scalar results in wrap_output_data * fix: use xp.asarray instead of _asarray for torch compatibility, skip scalars * fix: remove neighbor entries from patching test, keep only KMeans change * fix: move get_config import to top, add use_raw_input guard for init validation * fix: skip output conversion when transform_output is non-default * fix: use xp.all instead of np.all in _validate_sample_weight * fix: handle tuple results in wrap_output_data for kneighbors * fix: transfer to host when transform_output is non-default * fix: also check _sklearn_output_config for set_output polars support * fix: only apply transform_output guard for transform/fit_transform methods * fix: use queue parameter directly instead of QM.get_global_queue in predict/score * fix: move unique_values check to sklearnex layer to avoid host copy * test: add GPU transform_output and dpnp array_api_dispatch tests * fix: move imports to top of file * fix: fallback to np.unique when xp.unique_values not available (sklearn < 1.2) * fix: handle transform_output in array API early return path, filter test inputs * fix: use xp ops for sample_weight validation, remove _check_sample_weight * fix: add comments explaining transform_output host transfer * fix: use getattr for device to handle arrays without device attribute * fix: use get_dataframes_and_queues filter instead of list comprehension * fix: add comment explaining test input type selection * fix: skip array_api_dispatch tests for sklearn < 1.2 * fix: add sklearn source reference to _resolve_n_init docstring * fix: remove duplicate _validate_center_shape, use inherited sklearn method * fix: add comment explaining why _check_sample_weight was removed * fix: remove unnecessary section comments * fix: revert _init_centroids_sklearn signature to match main * fix: revert cosmetic changes to match main (variable names, formatting) * fix: handle non-array sample_weight types in _validate_sample_weight * fix: black formatting * fix: validate sample_weight shape matches X in _validate_sample_weight * fix: also check sample_weight is 1-d in _validate_sample_weight * fix: use _check_sample_weight for numpy, xp ops for GPU arrays * fix: array API compatible _validate_sample_weight with proper shape checks * fix: pass is_csr to _predict_backend for sparse predict support * fix no error when zero weights --------- Co-authored-by: Mcgrievy, Kathleen <kathleen.mcgrievy@intel.com> Co-authored-by: yuejiaointel <yue.jiao@intel.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: david-cortes-intel <david.cortes@intel.com>
1 parent 2079f27 commit 6502875

7 files changed

Lines changed: 456 additions & 367 deletions

File tree

doc/sources/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ The following patched classes have support for array API inputs:
9494
- :obj:`sklearnex.basic_statistics.BasicStatistics`
9595
- :obj:`sklearnex.basic_statistics.IncrementalBasicStatistics`
9696
- :obj:`sklearn.cluster.DBSCAN`
97+
- :obj:`sklearn.cluster.KMeans`
9798
- :obj:`sklearn.covariance.EmpiricalCovariance`
9899
- :obj:`sklearnex.covariance.IncrementalEmpiricalCovariance`
99100
- :obj:`sklearn.decomposition.PCA`

0 commit comments

Comments
 (0)