
Improve array API dispatch: device validation and sklearn fallback #2995

Draft

yuejiaointel wants to merge 3 commits into uxlfoundation:main from yuejiaointel:fix/array-api-dispatch-device-consistency

Conversation

@yuejiaointel
Contributor

Description


Checklist:

Completeness and readability

  • I have commented my code, particularly in hard-to-understand areas.
  • I have updated the documentation to reflect the changes or created a separate PR with updates and provided its number in the description, if necessary.
  • Git commit message contains an appropriate signed-off-by string (see CONTRIBUTING.md for details).
  • I have resolved any merge conflicts that might occur with the base branch.

Testing

  • I have run it locally and tested the changes extensively.
  • All CI jobs are green or I have provided justification why they aren't.
  • I have extended testing suite if new functionality was introduced in this PR.

Performance

  • I have measured performance for affected algorithms using scikit-learn_bench and provided at least a summary table with measured data, if performance change is expected.
  • I have provided justification why performance and/or quality metrics have changed or why changes are not expected.
  • I have extended the benchmarking suite and provided a corresponding scikit-learn_bench PR if new measurable functionality was introduced in this PR.

@codecov

codecov Bot commented Mar 5, 2026

Codecov Report

❌ Patch coverage is 56.00000% with 11 lines in your changes missing coverage. Please review.

| Files with missing lines     | Patch % | Lines                       |
|------------------------------|---------|-----------------------------|
| sklearnex/_device_offload.py | 56.00%  | 8 Missing and 3 partials ⚠️ |

| Flag   | Coverage Δ                  |
|--------|-----------------------------|
| azure  | 79.78% <56.00%> (-0.16%) ⬇️ |
| github | ?                           |

Flags with carried forward coverage won't be shown.

| Files with missing lines     | Coverage Δ                  |
|------------------------------|-----------------------------|
| sklearnex/_device_offload.py | 77.66% <56.00%> (-6.95%) ⬇️ |

... and 12 files with indirect coverage changes


@yuejiaointel yuejiaointel marked this pull request as ready for review March 5, 2026 07:45
Copilot AI review requested due to automatic review settings March 5, 2026 07:45
Contributor

Copilot AI left a comment


Pull request overview

This PR adjusts sklearnex’s device offload dispatch logic when array_api_dispatch is enabled, adding device-consistency checks and an explicit fallback to scikit-learn for estimators that support Array API in sklearn but not in oneDAL.

Changes:

  • Add array-API device validation to catch mixed-device inputs and fit/inference device mismatches.
  • Add an early dispatch path that prefers the sklearn implementation when sklearn supports Array API but oneDAL does not, avoiding host transfers.
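
The early-fallback decision described above reduces to a small predicate over the two estimator capability tags. A minimal standalone sketch (the function and argument names here are illustrative, not the actual sklearnex API; the real logic lives in sklearnex/_device_offload.py and reads `get_tags(obj)`):

```python
# Hypothetical sketch of the branch choice described above; names are
# illustrative, not the actual sklearnex API.
def choose_branch(sklearn_array_api: bool, onedal_array_api: bool) -> str:
    """Pick an implementation branch under array_api_dispatch.

    Prefer sklearn when it supports the Array API but oneDAL does not,
    avoiding a device-to-host transfer; otherwise keep the oneDAL path.
    """
    if sklearn_array_api and not onedal_array_api:
        return "sklearn"
    return "onedal"


print(choose_branch(sklearn_array_api=True, onedal_array_api=False))  # sklearn
```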

Comment thread: sklearnex/_device_offload.py (Outdated)
Comment on lines +169 to +173

```python
# If sklearn supports array API for this estimator but oneDAL doesn't,
# fall back to sklearn without transferring to host. This avoids a
# CPU roundtrip for operations that sklearn can already run on GPU
# natively (e.g. pairwise_distances).
if sklearn_array_api and not onedal_array_api:
```

Copilot AI Mar 5, 2026


The early `if sklearn_array_api and not onedal_array_api: return branches["sklearn"](...)` branch runs whenever `array_api_dispatch` is enabled, regardless of the actual input type. This can disable oneDAL acceleration for normal NumPy inputs (where oneDAL CPU/GPU paths may still be available), and it also bypasses the existing USM handling in this function (see the later `has_usm_data` / "dpnp fallback is not handled properly yet" logic). Consider gating this fallback on actual array-API non-NumPy inputs and/or on the absence of USM data; otherwise keep the existing host-transfer + backend selection flow.

Suggested change (replacing the block quoted above):

```python
# Inspect inputs to detect non-NumPy array API data and dpnp/USM data.
has_non_numpy_array_api_input = False
has_dpnp_input = False
for _val in list(args) + list(kwargs.values()):
    if is_dpnp_ndarray(_val):
        has_dpnp_input = True
        # Do not treat dpnp as generic array-API input for the early fallback.
        continue
    try:
        xp = _asarray(_val).__array_namespace__()  # type: ignore[assignment]
    except Exception:
        continue
    if not _is_numpy_namespace(xp):
        has_non_numpy_array_api_input = True
        break
# If sklearn supports array API for this estimator but oneDAL doesn't,
# and we have actual non-NumPy array API inputs (on non-CPU devices),
# fall back to sklearn without transferring to host. This avoids a
# CPU roundtrip for operations that sklearn can already run on GPU
# natively (e.g. pairwise_distances). For dpnp/USM inputs, keep the
# existing host-transfer + backend-selection flow, as dpnp fallback
# is not handled properly yet.
if (
    sklearn_array_api
    and not onedal_array_api
    and has_non_numpy_array_api_input
    and not has_dpnp_input
):
```

Contributor

@avolkov-intel avolkov-intel Mar 5, 2026


That is a valid concern, but the proposed solution is probably not the best. We definitely need a separate function that checks whether inputs can be handled without array API support. This should include numpy, pandas, dpnp, dpctl (not sure if I missed something).

Contributor


Also, there should be an exception for LogReg: it would have array API support on GPU but not on CPU, so we either add this check here or do it in the estimator; I'm not sure which is best.

Comment on lines +87 to +100

```python
devices = []
for a in args:
    if a is not None and hasattr(a, "device"):
        devices.append(a.device)

# Check mixed devices across input arguments (e.g. fit(X_gpu, y_cpu))
if len(devices) > 1:
    all_cpu = all(_is_cpu_device(d) for d in devices)
    if not all_cpu and len(set(str(d) for d in devices)) > 1:
        raise ValueError(
            f"Input arrays use different devices: "
            f"{', '.join(str(d) for d in devices)}. "
            f"All inputs must be on the same device."
        )
```

Copilot AI Mar 5, 2026


`_validate_array_api_devices` only considers arguments that have a `.device` attribute. Array-API NumPy inputs (and any other CPU arrays without `.device`) will be ignored, so mixed-device cases like X on GPU (has `.device`) and y as NumPy (no `.device`) won't be detected here; this is especially relevant now that some code paths return to sklearn before `QM.manage_global_queue` runs. Consider treating arrays without `.device` as CPU (e.g., based on `__array_namespace__` being NumPy) and/or including `**kwargs` values in the validation.

Comment on lines +160 to +175

```python
if _array_api_offload() and args:
    # Check for mixed device inputs (e.g. X on GPU, y on CPU)
    # and for device mismatch between fitted model and inference input.
    _validate_array_api_devices(obj, method_name, *args)

# Determine if array_api dispatching is enabled, and if estimator is capable
onedal_array_api = _array_api_offload() and get_tags(obj).onedal_array_api
sklearn_array_api = _array_api_offload() and get_tags(obj).array_api_support

# If sklearn supports array API for this estimator but oneDAL doesn't,
# fall back to sklearn without transferring to host. This avoids a
# CPU roundtrip for operations that sklearn can already run on GPU
# natively (e.g. pairwise_distances).
if sklearn_array_api and not onedal_array_api:
    return branches["sklearn"](obj, *args, **kwargs)
```

Copilot AI Mar 5, 2026


This change introduces new user-facing behavior (device-consistency validation plus an early sklearn fallback path under `array_api_dispatch`), but there are no targeted tests covering the new branches or error messages. It would be good to add unit tests around dispatch (e.g., in `sklearnex/tests/test_config.py`, which already exercises dispatch) to assert that: (1) mixed-device inputs raise consistently, and (2) the sklearn fallback path doesn't regress USM/dpnp handling or CPU offload behavior.
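
A sketch of the first suggested test, written against a local re-implementation of the validation shown in the diff so it runs without a GPU; a real test would instead go through estimator dispatch under `array_api_dispatch`, and the helper name here is hypothetical:

```python
def _validate_devices_sketch(devices):
    # Mirrors the ValueError from the diff above (message wording copied
    # from the quoted hunk; this is not the actual sklearnex function).
    if len(devices) > 1 and len({str(d) for d in devices}) > 1:
        raise ValueError(
            f"Input arrays use different devices: "
            f"{', '.join(str(d) for d in devices)}. "
            f"All inputs must be on the same device."
        )


def test_mixed_devices_raise():
    try:
        _validate_devices_sketch(["gpu:0", "cpu"])
    except ValueError as e:
        assert "different devices" in str(e)
    else:
        raise AssertionError("expected ValueError for mixed devices")
    _validate_devices_sketch(["cpu", "cpu"])  # same device: no error
```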

```python
_array_api_offload = lambda: False


def _is_cpu_device(device):
```
Contributor


This sort of thing is standardized in the array API specification:

```python
cpu_dlpack_device = (backend.kDLCPU, 0)
```
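
For CPU arrays this can be checked through the standard `__dlpack_device__` protocol; for example with NumPy, where DLPack's kDLCPU device type is 1:

```python
import numpy as np

# Per the Array API standard, __dlpack_device__ returns a
# (device_type, device_id) tuple; DLPack defines kDLCPU == 1.
kDLCPU = 1
cpu_dlpack_device = (kDLCPU, 0)

x = np.arange(3)
assert tuple(x.__dlpack_device__()) == cpu_dlpack_device
```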

```python
if a is not None and hasattr(a, "device"):
    devices.append(a.device)

# Check mixed devices across input arguments (e.g. fit(X_gpu, y_cpu))
```
Contributor


Let's wait until scikit-learn merges their PR with these changes to see how exactly they do it.

```python
if method_name not in ("fit",) and devices:
    fit_X = getattr(obj, "_fit_X", None)
    if fit_X is not None and hasattr(fit_X, "device"):
        both_cpu = _is_cpu_device(fit_X.device) and _is_cpu_device(devices[0])
```
Contributor


There can also be cases of different non-CPU devices.

@avolkov-intel
Contributor

It can't be merged in its current state, as it disables default acceleration with numpy when onedal_array_api is not enabled.

@yuejiaointel yuejiaointel marked this pull request as draft March 5, 2026 22:23
@yuejiaointel yuejiaointel force-pushed the fix/array-api-dispatch-device-consistency branch from 1c85b69 to a72fcfa Compare March 10, 2026 00:57
Only fall back to sklearn for non-numpy array API inputs (torch,
array_api_strict). Numpy and dpnp/dpctl inputs continue to oneDAL
path, fixing the regression where numpy lost oneDAL acceleration.
