Skip to content

Commit b742d86

Browse files
authored
fix: minor basic stats quality fixes (#2521)
* fix: minor basic stats quality fixes * blacked * vary seed by rank instead of size
1 parent c15f011 commit b742d86

2 files changed

Lines changed: 2 additions & 4 deletions

File tree

examples/sklearnex/basic_statistics_spmd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def generate_data(par, size, seed=777):
4848

4949
params_spmd = {"ns": 19, "nf": 31}
5050

51-
data, weights = generate_data(params_spmd, size)
51+
data, weights = generate_data(params_spmd, size, seed=rank)
5252
weighted_data = np.diag(weights) @ data
5353

5454
dpt_data = dpt.asarray(data, usm_type="device", sycl_queue=q)

onedal/basic_statistics/basic_statistics.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ def fit(self, data, sample_weight=None, queue=None):
157157
data_table, weights_table = to_table(data, sample_weight, queue=queue)
158158

159159
dtype = data_table.dtype
160-
raw_result = raw_result = self._compute_raw(
161-
data_table, weights_table, dtype, is_csr
162-
)
160+
raw_result = self._compute_raw(data_table, weights_table, dtype, is_csr)
163161
for opt, raw_value in raw_result.items():
164162
value = from_table(raw_value).ravel()
165163
if is_single_dim:

0 commit comments

Comments
 (0)