@@ -152,10 +152,48 @@ def validate_data(
152152 return out
153153
154154
155- def _check_sample_weight (
156- sample_weight , X , dtype = None , copy = False , ensure_non_negative = False
157- ):
155+ if sklearn_check_version ("1.9" ):
156+
157+ def _check_sample_weight (
158+ sample_weight ,
159+ X ,
160+ dtype = None ,
161+ copy = False ,
162+ ensure_non_negative = False ,
163+ allow_all_zero_weights = False ,
164+ ):
165+ return _check_sample_weight_internal (
166+ sample_weight ,
167+ X ,
168+ dtype = dtype ,
169+ copy = copy ,
170+ ensure_non_negative = ensure_non_negative ,
171+ allow_all_zero_weights = allow_all_zero_weights ,
172+ )
173+
174+ else :
158175
176+ def _check_sample_weight (
177+ sample_weight , X , dtype = None , copy = False , ensure_non_negative = False
178+ ):
179+ return _check_sample_weight_internal (
180+ sample_weight ,
181+ X ,
182+ dtype = dtype ,
183+ copy = copy ,
184+ ensure_non_negative = ensure_non_negative ,
185+ allow_all_zero_weights = True ,
186+ )
187+
188+
189+ def _check_sample_weight_internal (
190+ sample_weight ,
191+ X ,
192+ dtype = None ,
193+ copy = False ,
194+ ensure_non_negative = False ,
195+ allow_all_zero_weights = False ,
196+ ):
159197 n_samples = _num_samples (X )
160198 xp , _ = get_namespace (X )
161199
@@ -202,6 +240,10 @@ def _check_sample_weight(
202240 )
203241 )
204242
243+ if not allow_all_zero_weights :
244+ if xp .all (sample_weight == 0 ):
245+ raise ValueError ("Sample weights must contain at least one non-zero number." )
246+
205247 if ensure_non_negative :
206248 check_non_negative (sample_weight , "`sample_weight`" )
207249
0 commit comments