Skip to content

Commit e074be5

Browse files
authored
[enhancement] set specific failure for host builds with "target_offload" (#2388)
* Update sycl.cpp * Update _sycl_queue_manager.py * Update test_config.py * Update sycl.cpp * Update sycl.cpp * Update sycl.cpp * Update sycl.cpp * Update dal.cpp * Update dal.cpp * fix logic * formatting * convert into an object * Update test_config.py * Update test_config.py * Update test_config.py * Update test_config.py * Update test_config.py * Update sycl.cpp * Update test_config.py * Update _device_offload.py * Update _sycl_queue_manager.py * Update test_config.py * Update ci.yml * Update ci.yml * Update sycl.cpp * Update sycl.cpp * Update test_config.py * Update sycl.cpp * Update _sycl_queue_manager.py * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update build-and-test-lnx.yml * Update sycl.cpp * Update test_config.py * Update test_config.py * Update test_config.py * Update sycl.cpp * Update _config.py
1 parent afd9399 commit e074be5

7 files changed

Lines changed: 57 additions & 18 deletions

File tree

.ci/pipeline/build-and-test-lnx.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ steps:
6060
# dpep installation is set to pypi to avoid conflict of numpy versions from pip and conda
6161
# py312 is disabled due to segfault on exit of program with usage of dpctl
6262
if [ $(echo $(PYTHON_VERSION) | grep '3.9\|3.10\|3.11') ] && [ $(SKLEARN_VERSION) != "1.0" ] && [ -z ${NO_DPC} ]; then pip install dpctl==0.18.* dpnp==0.16.*; fi
63+
# issues exist with conda-forge dpcpp-cpp-rt=2025.1.1, which is needed to use the dpc build, so install it from pip
64+
if [ -z "${NO_DPC}" ]; then pip install dpcpp-cpp-rt==2025.1.*; fi
6365
pip list
6466
env:
6567
NO_DPC: ${{ variables.NO_DPC }}

onedal/_device_offload.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@
3030
if dpctl_available:
3131
from dpctl.memory import MemoryUSMDevice, as_usm_memory
3232
from dpctl.tensor import usm_ndarray
33-
else:
34-
from onedal import _dpc_backend
35-
36-
SyclQueue = getattr(_dpc_backend, "SyclQueue", None)
3733

3834
logger = logging.getLogger("sklearnex")
3935

onedal/common/sycl.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,12 @@ namespace py = pybind11;
2121

2222
namespace oneapi::dal::python {
2323

24-
#ifdef ONEDAL_DATA_PARALLEL
25-
2624
void instantiate_sycl_interfaces(py::module& m) {
2725
// These classes mirror a subset of functionality of the dpctl python
2826
// package's `SyclQueue` and `SyclDevice` objects. In the case that dpctl
2927
// is not installed, these classes will enable scikit-learn-intelex to still
3028
// properly offload to other devices when built with the dpc backend.
29+
#ifdef ONEDAL_DATA_PARALLEL
3130
py::class_<sycl::queue> syclqueue(m, "SyclQueue");
3231
syclqueue.def(py::init<const sycl::device&>())
3332
.def(py::init([](const std::string& filter) {
@@ -81,11 +80,31 @@ void instantiate_sycl_interfaces(py::module& m) {
8180
})
8281
.def_property_readonly("is_cpu", &sycl::device::is_cpu)
8382
.def_property_readonly("is_gpu", &sycl::device::is_gpu);
83+
#else
84+
struct syclqueue {};
85+
py::class_<syclqueue> syclqueue(m, "SyclQueue");
86+
// inspired by pybind11 PR#4698, which turns init into a no-op
87+
syclqueue
88+
.def(py::init([]() {
89+
return nullptr;
90+
}))
91+
.def_static("__new__", [](const py::object& cls, const py::object& obj) {
92+
// this object is defined for the host build, where SYCL support is not available.
93+
// This class acts as the failure point for target_offload: it throws an
94+
// error for any value other than the default value ("auto"), including
95+
// strings starting with "cpu". The returned "queue" is None. It must be a
96+
// class to work with isinstance.
97+
if (!py::isinstance<py::str>(obj) || obj.cast<std::string>() != "auto") {
98+
throw std::invalid_argument(
99+
"device use via `target_offload` is only supported with the DPC++ backend");
100+
}
101+
return py::none();
102+
});
103+
#endif
84104
}
85105

86106
ONEDAL_PY_INIT_MODULE(sycl) {
87107
instantiate_sycl_interfaces(m);
88108
}
89-
#endif
90109

91110
} // namespace oneapi::dal::python

onedal/dal.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ ONEDAL_PY_INIT_MODULE(neighbors);
4343
ONEDAL_PY_INIT_MODULE(logistic_regression);
4444
#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20240001
4545
#else // ONEDAL_DATA_PARALLEL_SPMD
46-
#ifdef ONEDAL_DATA_PARALLEL
4746
ONEDAL_PY_INIT_MODULE(sycl);
48-
#endif // ONEDAL_DATA_PARALLEL
4947

5048
ONEDAL_PY_INIT_MODULE(policy);
5149
/* datatypes*/
@@ -106,10 +104,10 @@ PYBIND11_MODULE(_onedal_py_spmd_dpc, m) {
106104
#else
107105
#ifdef ONEDAL_DATA_PARALLEL
108106
PYBIND11_MODULE(_onedal_py_dpc, m) {
109-
init_sycl(m);
110107
#else
111108
PYBIND11_MODULE(_onedal_py_host, m) {
112109
#endif
110+
init_sycl(m);
113111
init_policy(m);
114112
init_table(m);
115113

onedal/utils/_sycl_queue_manager.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@
2222
if dpctl_available:
2323
from dpctl import SyclQueue
2424
else:
25-
from onedal import _dpc_backend
25+
from onedal import _default_backend
2626

27-
SyclQueue = getattr(_dpc_backend, "SyclQueue", None)
27+
# Use the internal SyclQueue defined in onedal/common/sycl.cpp
28+
# The host backend SyclQueue only accepts "auto" and returns
29+
# None; it acts as a function via `__new__`. No SyclDevice is defined.
30+
SyclQueue = _default_backend.SyclQueue
2831

2932
# This special object signifies that the queue system should be
3033
# disabled. It will force computation to host. This occurs when the
@@ -36,12 +39,7 @@
3639

3740

3841
def __create_sycl_queue(target):
39-
if SyclQueue is None:
40-
# we don't have SyclQueue support
41-
return None
42-
if target is None:
43-
return None
44-
if isinstance(target, SyclQueue):
42+
if isinstance(target, SyclQueue) or target is None:
4543
return target
4644
if isinstance(target, (str, int)):
4745
return SyclQueue(target)

sklearnex/_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ def set_config(
9090
--------
9191
Using ``use_raw_input=True`` is not recommended for general use as it
9292
bypasses data consistency checks, which may lead to unexpected behavior.
93+
94+
Use of ``target_offload`` requires the DPC++ backend. Setting a
95+
non-default value (e.g. ``cpu`` or ``gpu``) without this backend active
96+
will raise an error.
9397
"""
9498

9599
array_api_dispatch = sklearn_configs.get("array_api_dispatch", False)

sklearnex/tests/test_config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import logging
1818

19+
import numpy as np
1920
import pytest
2021
import sklearn
2122

@@ -127,6 +128,27 @@ def test_config_context_works():
127128
assert onedal_default_config_after_cc[param] == onedal_default_config[param]
128129

129130

131+
@pytest.mark.skipif(
132+
onedal._default_backend.is_dpc, reason="requires host default backend"
133+
)
134+
@pytest.mark.parametrize("target", ["auto", "cpu", "cpu:0", "gpu", 3])
135+
def test_host_backend_target_offload(target):
136+
from sklearnex.neighbors import NearestNeighbors
137+
138+
err_msg = (
139+
r"device use via \`target_offload\` is only supported with the DPC\+\+ backend"
140+
)
141+
142+
est = NearestNeighbors()
143+
if target != "auto":
144+
with pytest.raises(ValueError, match=err_msg):
145+
with sklearnex.config_context(target_offload=target):
146+
est.fit(np.eye(5, 8))
147+
else:
148+
with sklearnex.config_context(target_offload=target):
149+
est.fit(np.eye(5, 8))
150+
151+
130152
@pytest.mark.skipif(
131153
not is_dpctl_device_available(["gpu"]), reason="Requires a gpu for fallback testing"
132154
)

0 commit comments

Comments
 (0)