@@ -197,6 +197,8 @@ def add_from_huggingface(
197197 label_key : str | None = "fine_label" ,
198198 label_names_key : str | None = None ,
199199 max_samples : int | None = None ,
200+ shuffle : bool = False ,
201+ seed : int = 42 ,
200202 show_progress : bool = True ,
201203 skip_existing : bool = True ,
202204 image_format : str = "auto" ,
@@ -214,6 +216,8 @@ def add_from_huggingface(
214216 label_key: Key for the label column (can be None).
215217 label_names_key: Key for label names in dataset info.
216218 max_samples: Maximum number of samples to load.
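219+ shuffle: If True, randomly permute the dataset rows before taking the first max_samples, so the subset is not biased by source row order (class diversity becomes likely, not guaranteed).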
220+ seed: Random seed for shuffling (default: 42).
217221 show_progress: Whether to print progress.
218222 skip_existing: If True (default), skip samples that already exist in storage.
219223 image_format: Image format to save: "auto" (detect from source, fallback PNG),
@@ -240,6 +244,22 @@ def add_from_huggingface(
240244 except Exception :
241245 ds = cast (Any , load_dataset (dataset_name , split = split ))
242246
247+ source_fingerprint = ds ._fingerprint if hasattr (ds , "_fingerprint" ) else None
248+
249+ dataset_size = len (ds )
250+ total = dataset_size if max_samples is None else min (dataset_size , max_samples )
251+
252+ # Select source row indices explicitly so sampled subsets are clear and
253+ # sample IDs remain stable for the same underlying row.
254+ selected_indices : list [int ] | None = None
255+ if shuffle :
256+ rng = np .random .default_rng (seed )
257+ selected_indices = rng .permutation (dataset_size )[:total ].tolist ()
258+ ds = ds .select (selected_indices )
259+ elif max_samples is not None :
260+ selected_indices = list (range (total ))
261+ ds = ds .select (selected_indices )
262+
243263 # Get label names if available
244264 label_names = None
245265 if label_key and label_names_key :
@@ -251,15 +271,14 @@ def add_from_huggingface(
251271
252272 # Extract dataset metadata for robust sample IDs
253273 config_name = getattr (ds .info , "config_name" , None ) or "default"
254- fingerprint = ds . _fingerprint [:8 ] if hasattr ( ds , "_fingerprint" ) and ds . _fingerprint else "unknown"
274+ fingerprint = source_fingerprint [:8 ] if source_fingerprint else "unknown"
255275 version = str (ds .info .version ) if ds .info .version else None
256276
257277 # Get media directory for this dataset
258278 config = StorageConfig .default ()
259279 media_dir = config .get_huggingface_media_dir (dataset_name , split )
260280
261281 samples = []
262- total = len (ds ) if max_samples is None else min (len (ds ), max_samples )
263282
264283 if show_progress :
265284 print (f"Loading { total } samples from { dataset_name } ..." )
@@ -268,6 +287,7 @@ def add_from_huggingface(
268287
269288 for i in iterator :
270289 item = ds [i ]
290+ source_index = selected_indices [i ] if selected_indices is not None else i
271291 image = item [image_key ]
272292
273293 # Handle PIL Image or numpy array
@@ -287,7 +307,7 @@ def add_from_huggingface(
287307
288308 # Generate robust sample ID with config and fingerprint
289309 safe_name = dataset_name .replace ("/" , "_" )
290- sample_id = f"{ safe_name } _{ config_name } _{ fingerprint } _{ split } _{ i } "
310+ sample_id = f"{ safe_name } _{ config_name } _{ fingerprint } _{ split } _{ source_index } "
291311
292312 # Determine image format and extension
293313 if image_format == "auto" :
@@ -311,8 +331,8 @@ def add_from_huggingface(
311331 "source" : dataset_name ,
312332 "config" : config_name ,
313333 "split" : split ,
314- "index" : i ,
315- "fingerprint" : ds . _fingerprint if hasattr ( ds , "_fingerprint" ) else None ,
334+ "index" : source_index ,
335+ "fingerprint" : source_fingerprint ,
316336 "version" : version ,
317337 }
318338
0 commit comments