Skip to content

Commit 7384361

Browse files
update signature for check_sample_weight (#2967)
1 parent 7574bac commit 7384361

1 file changed

Lines changed: 45 additions & 3 deletions

File tree

sklearnex/utils/validation.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,48 @@ def validate_data(
152152
return out
153153

154154

155-
def _check_sample_weight(
156-
sample_weight, X, dtype=None, copy=False, ensure_non_negative=False
157-
):
155+
if sklearn_check_version("1.9"):
156+
157+
def _check_sample_weight(
158+
sample_weight,
159+
X,
160+
dtype=None,
161+
copy=False,
162+
ensure_non_negative=False,
163+
allow_all_zero_weights=False,
164+
):
165+
return _check_sample_weight_internal(
166+
sample_weight,
167+
X,
168+
dtype=dtype,
169+
copy=copy,
170+
ensure_non_negative=ensure_non_negative,
171+
allow_all_zero_weights=allow_all_zero_weights,
172+
)
173+
174+
else:
158175

176+
def _check_sample_weight(
177+
sample_weight, X, dtype=None, copy=False, ensure_non_negative=False
178+
):
179+
return _check_sample_weight_internal(
180+
sample_weight,
181+
X,
182+
dtype=dtype,
183+
copy=copy,
184+
ensure_non_negative=ensure_non_negative,
185+
allow_all_zero_weights=True,
186+
)
187+
188+
189+
def _check_sample_weight_internal(
190+
sample_weight,
191+
X,
192+
dtype=None,
193+
copy=False,
194+
ensure_non_negative=False,
195+
allow_all_zero_weights=False,
196+
):
159197
n_samples = _num_samples(X)
160198
xp, _ = get_namespace(X)
161199

@@ -202,6 +240,10 @@ def _check_sample_weight(
202240
)
203241
)
204242

243+
if not allow_all_zero_weights:
244+
if xp.all(sample_weight == 0):
245+
raise ValueError("Sample weights must contain at least one non-zero number.")
246+
205247
if ensure_non_negative:
206248
check_non_negative(sample_weight, "`sample_weight`")
207249

0 commit comments

Comments
 (0)