Improve array API dispatch: device validation and sklearn fallback #2995
yuejiaointel wants to merge 3 commits into uxlfoundation:main
Conversation
Pull request overview
This PR adjusts sklearnex’s device offload dispatch logic when array_api_dispatch is enabled, adding device-consistency checks and an explicit fallback to scikit-learn for estimators that support Array API in sklearn but not in oneDAL.
Changes:
- Add array-API device validation to catch mixed-device inputs and fit/inference device mismatches.
- Add an early dispatch path that prefers the sklearn implementation when sklearn supports Array API but oneDAL does not, avoiding host transfers.
```python
# If sklearn supports array API for this estimator but oneDAL doesn't,
# fall back to sklearn without transferring to host. This avoids a
# CPU roundtrip for operations that sklearn can already run on GPU
# natively (e.g. pairwise_distances).
if sklearn_array_api and not onedal_array_api:
```
The early `if sklearn_array_api and not onedal_array_api: return branches["sklearn"](...)` branch runs whenever array_api_dispatch is enabled, regardless of the actual input type. This can disable oneDAL acceleration for normal NumPy inputs (where oneDAL CPU/GPU paths may still be available), and it also bypasses the existing USM handling in this function (see the later `has_usm_data` / "dpnp fallback is not handled properly yet" logic). Consider gating this fallback on actual non-NumPy array API inputs and/or on the absence of USM data; otherwise keep the existing host-transfer + backend selection flow.
Suggested change (calling `__array_namespace__` directly on the value; the original `_asarray(_val).__array_namespace__()` would raise inside the broad `except` and silently detect nothing):

```python
# Inspect inputs to detect non-NumPy array API data and dpnp/USM data.
has_non_numpy_array_api_input = False
has_dpnp_input = False
for _val in list(args) + list(kwargs.values()):
    if is_dpnp_ndarray(_val):
        has_dpnp_input = True
        # Do not treat dpnp as generic array API input for the early fallback.
        continue
    try:
        xp = _val.__array_namespace__()
    except AttributeError:
        continue
    if not _is_numpy_namespace(xp):
        has_non_numpy_array_api_input = True
        break

# If sklearn supports array API for this estimator but oneDAL doesn't,
# and we have actual non-NumPy array API inputs (on non-CPU devices),
# fall back to sklearn without transferring to host. This avoids a
# CPU roundtrip for operations that sklearn can already run on GPU
# natively (e.g. pairwise_distances). For dpnp/USM inputs, keep the
# existing host-transfer + backend-selection flow, as dpnp fallback
# is not handled properly yet.
if (
    sklearn_array_api
    and not onedal_array_api
    and has_non_numpy_array_api_input
    and not has_dpnp_input
):
```
That is a valid concern, but the solution is probably not the best one. We definitely need a separate function that checks whether inputs can be handled without array API support. This should include numpy, pandas, dpnp, and dpctl (not sure if I missed something).
Also, for LogReg there should be an exception: it would have array API support on GPU but not on CPU, so either we add this check here or do it in the estimator; not sure which is best.
```python
devices = []
for a in args:
    if a is not None and hasattr(a, "device"):
        devices.append(a.device)

# Check mixed devices across input arguments (e.g. fit(X_gpu, y_cpu))
if len(devices) > 1:
    all_cpu = all(_is_cpu_device(d) for d in devices)
    if not all_cpu and len(set(str(d) for d in devices)) > 1:
        raise ValueError(
            f"Input arrays use different devices: "
            f"{', '.join(str(d) for d in devices)}. "
            f"All inputs must be on the same device."
        )
```
`_validate_array_api_devices` only considers arguments that have a `.device` attribute. NumPy inputs (and any other CPU arrays without `.device`) will be ignored, so mixed-device cases like X on GPU (has `.device`) and y as NumPy (no `.device`) won't be detected here. That is especially relevant now that some code paths return to sklearn before `QM.manage_global_queue` runs. Consider treating arrays without `.device` as CPU (e.g., based on `__array_namespace__` being NumPy) and/or including `**kwargs` values in the validation.
```python
if _array_api_offload() and args:
    # Check for mixed device inputs (e.g. X on GPU, y on CPU)
    # and for device mismatch between fitted model and inference input.
    _validate_array_api_devices(obj, method_name, *args)

# Determine if array_api dispatching is enabled, and if estimator is capable
onedal_array_api = _array_api_offload() and get_tags(obj).onedal_array_api
sklearn_array_api = _array_api_offload() and get_tags(obj).array_api_support

# If sklearn supports array API for this estimator but oneDAL doesn't,
# fall back to sklearn without transferring to host. This avoids a
# CPU roundtrip for operations that sklearn can already run on GPU
# natively (e.g. pairwise_distances).
if sklearn_array_api and not onedal_array_api:
    return branches["sklearn"](obj, *args, **kwargs)
```
This change introduces new user-facing behavior (device-consistency validation + an early sklearn fallback path under array_api_dispatch) but there are no targeted tests covering the new branches/error messages. It would be good to add unit tests around dispatch (e.g., in sklearnex/tests/test_config.py which already exercises dispatch) to assert: (1) mixed-device inputs raise consistently, and (2) the sklearn fallback path doesn’t regress USM/dpnp handling or CPU offload behavior.
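A self-contained sketch of the assertions such a test could make. `fake_dispatch` below is a stand-in for the dispatch logic under review, not the real sklearnex `dispatch`:

```python
import numpy as np


def fake_dispatch(obj, branches, *args):
    # Stand-in: validate device consistency, then pick a branch, falling
    # back to sklearn only for non-NumPy array API inputs.
    devices = {str(getattr(a, "device", "cpu")) for a in args if a is not None}
    if len(devices) > 1:
        raise ValueError(f"Input arrays use different devices: {sorted(devices)}")
    has_non_numpy_array_api = any(
        hasattr(a, "__array_namespace__") and not isinstance(a, np.ndarray)
        for a in args
    )
    if obj.sklearn_array_api and not obj.onedal_array_api and has_non_numpy_array_api:
        return branches["sklearn"](obj, *args)
    return branches["onedal"](obj, *args)


class Est:
    sklearn_array_api = True
    onedal_array_api = False


class FakeGpuArray:
    device = "gpu:0"

    def __array_namespace__(self):
        return None


branches = {
    "sklearn": lambda obj, *a: "sklearn",
    "onedal": lambda obj, *a: "onedal",
}
```

A real test in `sklearnex/tests/test_config.py` would exercise the actual `dispatch` with `config_context(array_api_dispatch=True)`, asserting (1) mixed-device inputs raise, (2) NumPy-only inputs keep the oneDAL path, and (3) non-NumPy array API inputs take the sklearn fallback.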
```python
_array_api_offload = lambda: False


def _is_cpu_device(device):
```
This sort of thing is standardized in the array API specification:
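For illustration: the array API standard requires conforming arrays to expose a `.device` attribute (NumPy >= 2.0 implements this for ndarrays), so a spec-based check needs no library-specific special-casing. The `same_device` helper is a hypothetical sketch:

```python
import numpy as np

# Per the array API standard, every conforming array exposes `.device`
# and `.to_device(device)`, and `__array_namespace__()` returns its
# namespace.
def same_device(a, b):
    # Fall back to "cpu" for pre-2.0 NumPy arrays, which predate the spec.
    return str(getattr(a, "device", "cpu")) == str(getattr(b, "device", "cpu"))
```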
```python
if a is not None and hasattr(a, "device"):
    devices.append(a.device)

# Check mixed devices across input arguments (e.g. fit(X_gpu, y_cpu))
```
Let's wait until scikit-learn merges their PR with these changes to see how exactly they do it.
```python
if method_name not in ("fit",) and devices:
    fit_X = getattr(obj, "_fit_X", None)
    if fit_X is not None and hasattr(fit_X, "device"):
        both_cpu = _is_cpu_device(fit_X.device) and _is_cpu_device(devices[0])
```
There can also be cases of different non-CPU devices.
It can't be merged in the current state, as it disables default acceleration with numpy if onedal_array_api is not enabled.
Force-pushed from 1c85b69 to a72fcfa.
Only fall back to sklearn for non-numpy array API inputs (torch, array_api_strict). Numpy and dpnp/dpctl inputs continue on the oneDAL path, fixing the regression where numpy lost oneDAL acceleration.