Skip to content

Commit 5d3e592

Browse files
yuejiaointelclaude
andauthored
fix: fallback for take_along_axis in kd_tree sorting for array API < … (#3083)
* fix: fallback for take_along_axis in kd_tree sorting for array API < 2024.12 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * use np.from_dlpack instead of np.asarray for array conversion Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * use try/except and from_dlpack for take_along_axis fallback * add device=cpu to np.from_dlpack calls * revert deselection of test_label_propagation (#3082) to confirm fix * restore deselection of test_label_propagation * revert deselection of test_label_propagation to confirm fix * fix kneighbors_graph: use csr_matrix for sklearn < 1.9 --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9393460 commit 5d3e592

2 files changed

Lines changed: 24 additions & 6 deletions

File tree

deselected_tests.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -458,9 +458,6 @@ deselected_tests:
458458
# CI jobs in sklearnex compile scikit-learn from source, not necessarily with the same toolkits as sklearn's CIs
459459
- preprocessing/tests/test_polynomial.py::test_sizeof_LARGEST_INT_t
460460

461-
# ValueError: operands could not be broadcast together, possibly from neighbors array api update
462-
- semi_supervised/tests/test_label_propagation.py
463-
464461
# --------------------------------------------------------
465462
# No need to test daal4py patching
466463
reduced_tests:

sklearnex/neighbors/common.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,26 @@ def _kneighbors_postprocess(
350350
# on read-only arrays (e.g. array-api-strict).
351351
if self._fit_method == "kd_tree":
352352
seq = xp.argsort(distances, axis=1)
353-
indices = xp.take_along_axis(indices, seq, axis=1)
354-
distances = xp.take_along_axis(distances, seq, axis=1)
353+
try:
354+
indices = xp.take_along_axis(indices, seq, axis=1)
355+
distances = xp.take_along_axis(distances, seq, axis=1)
356+
except RuntimeError:
357+
# Fallback for array API < 2024.12 (e.g. array-api-strict
358+
# with api_version='2023.12' has the function but raises)
359+
indices = xp.from_dlpack(
360+
np.take_along_axis(
361+
np.from_dlpack(indices, device="cpu"),
362+
np.from_dlpack(seq, device="cpu"),
363+
axis=1,
364+
)
365+
)
366+
distances = xp.from_dlpack(
367+
np.take_along_axis(
368+
np.from_dlpack(distances, device="cpu"),
369+
np.from_dlpack(seq, device="cpu"),
370+
axis=1,
371+
)
372+
)
355373

356374
if not query_is_train:
357375
if return_distance:
@@ -675,7 +693,10 @@ def kneighbors_graph(self, X=None, n_neighbors=None, mode="connectivity"):
675693
n_nonzero = n_queries * n_neighbors
676694
A_indptr = xp.arange(0, n_nonzero + 1, n_neighbors)
677695

678-
_csr_container = sp.csr_array if hasattr(sp, "csr_array") else sp.csr_matrix
696+
if sklearn_check_version("1.9"):
697+
_csr_container = sp.csr_array if hasattr(sp, "csr_array") else sp.csr_matrix
698+
else:
699+
_csr_container = sp.csr_matrix
679700
kneighbors_graph = _csr_container(
680701
(A_data, xp.reshape(A_ind, (-1,)), A_indptr), shape=(n_queries, n_samples_fit)
681702
)

0 commit comments

Comments
 (0)