2222 _convert_to_dataframe ,
2323 get_dataframes_and_queues ,
2424)
25+ from sklearnex import config_context
2526from sklearnex .tests .utils .spmd import (
27+ _as_numpy ,
2628 _assert_kmeans_labels_allclose ,
2729 _assert_unordered_allclose ,
2830 _generate_clustering_data ,
@@ -108,9 +110,10 @@ def test_kmeans_spmd_gold(dataframe, queue):
108110 get_dataframes_and_queues (dataframe_filter_ = "dpnp" , device_filter_ = "gpu" ),
109111)
110112@pytest .mark .parametrize ("dtype" , [np .float32 , np .float64 ])
113+ @pytest .mark .parametrize ("array_api_dispatch" , [True , False ])
111114@pytest .mark .mpi
112115def test_kmeans_spmd_synthetic (
113- n_samples , n_features , n_clusters , dataframe , queue , dtype
116+ n_samples , n_features , n_clusters , dataframe , queue , dtype , array_api_dispatch
114117):
115118 # Import spmd and batch algo
116119 from sklearnex .cluster import KMeans as KMeans_Batch
@@ -129,9 +132,11 @@ def test_kmeans_spmd_synthetic(
129132 )
130133
131134 # Validate KMeans init
132- spmd_model_init = KMeans_SPMD (n_clusters = n_clusters , max_iter = 1 , random_state = 0 ).fit (
133- local_dpt_X_train
134- )
135+ # Configure array_api_dispatch for spmd estimator
136+ with config_context (array_api_dispatch = array_api_dispatch ):
137+ spmd_model_init = KMeans_SPMD (
138+ n_clusters = n_clusters , max_iter = 1 , random_state = 0
139+ ).fit (local_dpt_X_train )
135140 batch_model_init = KMeans_Batch (
136141 n_clusters = n_clusters , max_iter = 1 , random_state = 0
137142 ).fit (X_train )
@@ -142,9 +147,13 @@ def test_kmeans_spmd_synthetic(
142147 spmd_model = KMeans_SPMD (
143148 n_clusters = n_clusters , init = spmd_model_init .cluster_centers_ , random_state = 0
144149 )
145- spmd_model .fit (local_dpt_X_train )
150+ # Configure array_api_dispatch for spmd estimator
151+ with config_context (array_api_dispatch = array_api_dispatch ):
152+ spmd_model .fit (local_dpt_X_train )
146153 batch_model = KMeans_Batch (
147- n_clusters = n_clusters , init = spmd_model_init .cluster_centers_ , random_state = 0
154+ n_clusters = n_clusters ,
155+ init = _as_numpy (spmd_model_init .cluster_centers_ ),
156+ random_state = 0 ,
148157 ).fit (X_train )
149158
150159 atol = 1e-5 if dtype == np .float32 else 1e-7
@@ -162,7 +171,9 @@ def test_kmeans_spmd_synthetic(
162171 # assert_allclose(spmd_model.n_iter_, batch_model.n_iter_, atol=1)
163172
164173 # Ensure predictions of batch algo match spmd
165- spmd_result = spmd_model .predict (local_dpt_X_test )
174+ # Configure array_api_dispatch for spmd estimator
175+ with config_context (array_api_dispatch = array_api_dispatch ):
176+ spmd_result = spmd_model .predict (local_dpt_X_test )
166177 batch_result = batch_model .predict (X_test )
167178
168179 _assert_kmeans_labels_allclose (
0 commit comments