Commit 6502875
[ArrayAPI] Refactor KMeans estimator to follow oneDAL estimator design pattern (#2654)
* ArrayAPI update.
* Updated imports.
* Updated default parameter for oneAPI.
* Updated format.
* Updated format.
* Fixed to get namespace from function.
* Updated format.
* Fixed for csr_matrix.
* Fixed formatting.
* Updated formatting.
* Fixed formatting.
* Fixed formatting.
* Changed to use getattr.
* Updated imports.
* Formatted.
* Fixes for get_namespace.
* Updated formatting.
* Updated format.
* Updated the format.
* Fixed format.
* Don't need array api stuff for X_tables.
* Removed table import.
* Updated format.
* Fix what is being passed into namespace.
* Updated format.
* Updated init.
* Updated format.
* fixed init.
* Updated name.
* merged other changes.
* Formatted.
* Fixed imports.
* fixed _get_onedal_params call error.
* test: logging
* fix: logging
* fix: refactoring
* fix: refactoring
* fix: n_init try to get from sklearnex first due to refactoring
* fix: revert the get namespace
* fix: n_init
* fix: n_init
* fix: init
* fix: validate data
* fix: add enable array api decorator
* Add array API zero-copy support for KMeans on GPU
- Add from_table(like=X) to fit/predict/cluster_centers_ for correct
output array type (dpnp/dpctl) instead of always returning numpy
- Remove np.asarray in _init_centroids_onedal and cluster_centers_
setter to avoid crash on GPU arrays
- Store _input_type for deferred from_table in cluster_centers_ property
- Split _onedal_gpu_supported to reject callable init on GPU (falls
back to sklearn instead of crashing in numpy-only code path)
* Remove dead n_init='warn' handling and simplify defaults
- Remove n_init=='warn' branch from onedal fit() and sklearnex _resolve_n_init()
- Simplify n_init default to 'auto' (sklearn >=1.4)
- Simplify algorithm default to 'lloyd' (sklearn >=1.1)
- CI only tests sklearn 1.7.2/1.8.0, so these version-conditional defaults were dead code
* fix: torch
* fix: array API compatibility for KMeans transform/predict/fit_transform
- Replace .ravel() with [:, 0] in onedal KMeans for array API compatibility
(array API strict Array objects lack ravel/reshape methods)
- Add transform dispatch through dispatch() following PCA pattern, fixing
dpnp transform failures (dpnp arrays weren't transferred to host)
- Override fit_transform to use dispatched fit+transform
- Remove support_input_format wrapping (caused namespace mismatch)
- Remove _check_params/_check_params_vs_input (rejects 'lloyd' on sklearn<1.1,
np.var() fails on GPU arrays; onedal computes tolerance internally)
- Update _onedal_supported to handle 'transform' method
Fixes 99 patching test failures (55 array_api + 44 dpnp transform/fit_transform).
* refactor: move n_init resolution to sklearnex, keep onedal fallback
- Pass resolved _n_init integer from sklearnex to onedal estimator
- Keep n_init resolution in onedal as fallback for SPMD and direct usage
- Guard onedal resolution with isinstance check to skip when already resolved
* fix: restore version-conditional defaults for n_init and algorithm
* fix: add parameter validation for sklearn < 1.2 in KMeans _onedal_fit
* fix: address review feedback - memory leak, validation, and robustness
- Store lightweight callable via return_type_constructor(X) instead of
full training array in _input_type to prevent memory retention
- Remove ensure_all_finite=False to match sklearn's KMeans validation
- Use getattr fallback for _n_init in _initialize_onedal_estimator to
handle deserialization/manual-construction edge cases
* fix: restore fit_predict and fit_transform on onedal KMeans
* Add sklearnex SPMD wrapper for KMeans
* Add KMeans to array_api.rst supported classes
* fix: address review - clean up __init__, document array-like init path
- Remove misleading comment and private attr initializations from
onedal _BaseKMeans.__init__ (sklearn convention: only store params)
- Use hasattr checks for runtime attrs (_tol, _cluster_centers_, model_)
- Add comment explaining _is_arraylike_not_scalar check is reachable
(user can explicitly set n_init > 1 with array-like init)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* fix: pass tolerance as local variable instead of storing as self._tol
set_params(tol=...) now correctly affects subsequent operations since
_get_onedal_params falls back to self.tol for predict/score paths.
* Revert "fix: pass tolerance as local variable instead of storing as self._tol"
This reverts commit d3306dc.
* fix: use array API operations for transform with non-numpy arrays
* fix: support transform_output=polars with GPU arrays, align with neighbors branch
* fix: rename euclidean_distances import for clarity
* fix: use xp.clip instead of xp.where for clamping negative distances
* fix: support set_output(transform=polars) with GPU arrays
* test: add transform_output tests for torch CPU with polars/pandas
* fix: address review - move import to top, fix test torch skip and polars/pandas imports
* fix: use xp.asarray instead of _asarray in wrap_output_data for torch compatibility
* fix: remove KMeans cluster_centers_ skip from patching test, fix xp.clip dtype
* fix: skip xp.asarray for scalar results in wrap_output_data
* fix: use xp.asarray instead of _asarray for torch compatibility, skip scalars
* fix: remove neighbor entries from patching test, keep only KMeans change
* fix: move get_config import to top, add use_raw_input guard for init validation
* fix: skip output conversion when transform_output is non-default
* fix: use xp.all instead of np.all in _validate_sample_weight
* fix: handle tuple results in wrap_output_data for kneighbors
* fix: transfer to host when transform_output is non-default
* fix: also check _sklearn_output_config for set_output polars support
* fix: only apply transform_output guard for transform/fit_transform methods
* fix: use queue parameter directly instead of QM.get_global_queue in predict/score
* fix: move unique_values check to sklearnex layer to avoid host copy
* test: add GPU transform_output and dpnp array_api_dispatch tests
* fix: move imports to top of file
* fix: fallback to np.unique when xp.unique_values not available (sklearn < 1.2)
* fix: handle transform_output in array API early return path, filter test inputs
* fix: use xp ops for sample_weight validation, remove _check_sample_weight
* fix: add comments explaining transform_output host transfer
* fix: use getattr for device to handle arrays without device attribute
* fix: use get_dataframes_and_queues filter instead of list comprehension
* fix: add comment explaining test input type selection
* fix: skip array_api_dispatch tests for sklearn < 1.2
* fix: add sklearn source reference to _resolve_n_init docstring
* fix: remove duplicate _validate_center_shape, use inherited sklearn method
* fix: add comment explaining why _check_sample_weight was removed
* fix: remove unnecessary section comments
* fix: revert _init_centroids_sklearn signature to match main
* fix: revert cosmetic changes to match main (variable names, formatting)
* fix: handle non-array sample_weight types in _validate_sample_weight
* fix: black formatting
* fix: validate sample_weight shape matches X in _validate_sample_weight
* fix: also check sample_weight is 1-d in _validate_sample_weight
* fix: use _check_sample_weight for numpy, xp ops for GPU arrays
* fix: array API compatible _validate_sample_weight with proper shape checks
* fix: pass is_csr to _predict_backend for sparse predict support
* fix no error when zero weights
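Two of the array API changes listed above (indexing with [:, 0] in place of .ravel(), and xp.clip for clamping negative distances) can be illustrated with a minimal sketch. It assumes NumPy stands in for the array namespace xp; the function name clamped_distances and the expansion-based distance formula are illustrative assumptions, not the code from this commit.

```python
import numpy as np


def clamped_distances(xp, X, centers):
    """Euclidean distances from rows of X to rows of centers (illustrative only)."""
    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2; rounding can leave tiny negatives.
    x_sq = xp.sum(X * X, axis=1, keepdims=True)   # shape (n_samples, 1)
    c_sq = xp.sum(centers * centers, axis=1)      # shape (n_clusters,)
    d2 = x_sq - 2.0 * (X @ centers.T) + c_sq      # shape (n_samples, n_clusters)
    # Clamp negative values to zero before taking the square root.
    d2 = xp.clip(d2, 0.0, None)
    return xp.sqrt(d2)


X = np.asarray([[0.0, 0.0], [3.0, 4.0]])
centers = np.asarray([[0.0, 0.0]])
dist = clamped_distances(np, X, centers)          # shape (2, 1)
# Select the single column by indexing rather than calling .ravel(),
# since strict array API objects may not provide ravel/reshape methods.
first_col = dist[:, 0]                            # shape (2,)
```

Passing the namespace in explicitly keeps the sketch independent of any particular get_namespace helper; a real estimator would obtain xp from its input arrays.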
---------
Co-authored-by: Mcgrievy, Kathleen <kathleen.mcgrievy@intel.com>
Co-authored-by: yuejiaointel <yue.jiao@intel.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: david-cortes-intel <david.cortes@intel.com>
1 parent 2079f27 commit 6502875
7 files changed
Lines changed: 456 additions & 367 deletions
File tree
- doc/sources
- onedal/cluster
- sklearnex
- cluster
- tests
- spmd/cluster
- tests
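The n_init handling described in the commit message (resolve the value once, then skip resolution when an integer is already present) can be sketched as follows. This is a hedged illustration: resolve_n_init is a hypothetical name, and the rules follow scikit-learn's documented behaviour for n_init='auto' (1 run for 'k-means++' or an array-like init, 10 runs for 'random' or a callable init). It omits the warning path for an explicit n_init > 1 combined with an array-like init.

```python
import numbers


def resolve_n_init(n_init, init, default_n_init=10):
    """Return an integer number of initializations (hypothetical helper)."""
    # Guard: an already-resolved integer passes through unchanged.
    if isinstance(n_init, numbers.Integral):
        return int(n_init)
    if n_init != "auto":
        raise ValueError(f"n_init must be an int or 'auto', got {n_init!r}")
    if isinstance(init, str) and init == "k-means++":
        return 1
    if (isinstance(init, str) and init == "random") or callable(init):
        return default_n_init
    # Remaining case: array-like initial centers, which need only one run.
    return 1


resolved = resolve_n_init("auto", "k-means++")
# Calling the helper again on its own output is a no-op, which is what
# makes a second, fallback resolution step safe.
resolved_again = resolve_n_init(resolved, "k-means++")
```

The isinstance check is what lets the same function serve both as the primary resolver and as a fallback for callers that bypass it.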