Skip to content

Commit d7536f3

Browse files
authored
[bug] fix policy changes for spmd (#2554)
* Update dal.cpp * Update policy.cpp
1 parent b742d86 commit d7536f3

2 files changed

Lines changed: 5 additions & 2 deletions

File tree

onedal/common/policy.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,14 @@ py::object get_policy(py::object obj) {
147147
};
148148

149149
ONEDAL_PY_INIT_MODULE(policy) {
150+
#ifdef ONEDAL_DATA_PARALLEL_SPMD
151+
instantiate_spmd_policy(m);
152+
#else
150153
instantiate_host_policy(m);
151154
instantiate_default_host_policy(m);
152155
#ifdef ONEDAL_DATA_PARALLEL
153156
instantiate_data_parallel_policy(m);
154157
#endif // ONEDAL_DATA_PARALLEL
155-
#ifdef ONEDAL_DATA_PARALLEL_SPMD
156-
instantiate_spmd_policy(m);
157158
#endif // ONEDAL_DATA_PARALLEL_SPMD
158159
m.def("get_policy", &get_policy, py::arg("queue") = py::none());
159160
}

onedal/dal.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace oneapi::dal::python {
2323

2424
/* common */
2525
#ifdef ONEDAL_DATA_PARALLEL_SPMD
26+
ONEDAL_PY_INIT_MODULE(policy);
2627
/* algorithms */
2728
ONEDAL_PY_INIT_MODULE(covariance);
2829
ONEDAL_PY_INIT_MODULE(dbscan);
@@ -83,6 +84,7 @@ ONEDAL_PY_INIT_MODULE(finiteness_checker);
8384

8485
#ifdef ONEDAL_DATA_PARALLEL_SPMD
8586
PYBIND11_MODULE(_onedal_py_spmd_dpc, m) {
87+
init_policy(m);
8688
init_covariance(m);
8789
init_dbscan(m);
8890
init_decomposition(m);

0 commit comments

Comments
 (0)