Skip to content

Commit 87bdb63

Browse files
committed
Enhance dataset loading: add shuffle and seed parameters for diverse sampling
1 parent bd9e948 commit 87bdb63

3 files changed

Lines changed: 28 additions & 5 deletions

File tree

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ AGENTS.md
7272
.github/agents/
7373
.github/instructions/
7474
.github/hooks/
75+
.github/skills/
76+
.agents/
7577
.specstory/
7678

7779
# Generated version file (hatch-vcs)

scripts/demo_hyperbolic_clip.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def main() -> None:
2222
image_key=HF_IMAGE_KEY,
2323
label_key=HF_LABEL_KEY,
2424
max_samples=NUM_SAMPLES,
25+
shuffle=True,
2526
)
2627
print(f"Loaded {len(dataset)} samples")
2728

src/hyperview/core/dataset.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,8 @@ def add_from_huggingface(
197197
label_key: str | None = "fine_label",
198198
label_names_key: str | None = None,
199199
max_samples: int | None = None,
200+
shuffle: bool = False,
201+
seed: int = 42,
200202
show_progress: bool = True,
201203
skip_existing: bool = True,
202204
image_format: str = "auto",
@@ -214,6 +216,8 @@ def add_from_huggingface(
214216
label_key: Key for the label column (can be None).
215217
label_names_key: Key for label names in dataset info.
216218
max_samples: Maximum number of samples to load.
219+
shuffle: If True, shuffle the dataset before sampling (ensures diverse classes).
220+
seed: Random seed for shuffling (default: 42).
217221
show_progress: Whether to print progress.
218222
skip_existing: If True (default), skip samples that already exist in storage.
219223
image_format: Image format to save: "auto" (detect from source, fallback PNG),
@@ -240,6 +244,22 @@ def add_from_huggingface(
240244
except Exception:
241245
ds = cast(Any, load_dataset(dataset_name, split=split))
242246

247+
source_fingerprint = ds._fingerprint if hasattr(ds, "_fingerprint") else None
248+
249+
dataset_size = len(ds)
250+
total = dataset_size if max_samples is None else min(dataset_size, max_samples)
251+
252+
# Select source row indices explicitly so sampled subsets are clear and
253+
# sample IDs remain stable for the same underlying row.
254+
selected_indices: list[int] | None = None
255+
if shuffle:
256+
rng = np.random.default_rng(seed)
257+
selected_indices = rng.permutation(dataset_size)[:total].tolist()
258+
ds = ds.select(selected_indices)
259+
elif max_samples is not None:
260+
selected_indices = list(range(total))
261+
ds = ds.select(selected_indices)
262+
243263
# Get label names if available
244264
label_names = None
245265
if label_key and label_names_key:
@@ -251,15 +271,14 @@ def add_from_huggingface(
251271

252272
# Extract dataset metadata for robust sample IDs
253273
config_name = getattr(ds.info, "config_name", None) or "default"
254-
fingerprint = ds._fingerprint[:8] if hasattr(ds, "_fingerprint") and ds._fingerprint else "unknown"
274+
fingerprint = source_fingerprint[:8] if source_fingerprint else "unknown"
255275
version = str(ds.info.version) if ds.info.version else None
256276

257277
# Get media directory for this dataset
258278
config = StorageConfig.default()
259279
media_dir = config.get_huggingface_media_dir(dataset_name, split)
260280

261281
samples = []
262-
total = len(ds) if max_samples is None else min(len(ds), max_samples)
263282

264283
if show_progress:
265284
print(f"Loading {total} samples from {dataset_name}...")
@@ -268,6 +287,7 @@ def add_from_huggingface(
268287

269288
for i in iterator:
270289
item = ds[i]
290+
source_index = selected_indices[i] if selected_indices is not None else i
271291
image = item[image_key]
272292

273293
# Handle PIL Image or numpy array
@@ -287,7 +307,7 @@ def add_from_huggingface(
287307

288308
# Generate robust sample ID with config and fingerprint
289309
safe_name = dataset_name.replace("/", "_")
290-
sample_id = f"{safe_name}_{config_name}_{fingerprint}_{split}_{i}"
310+
sample_id = f"{safe_name}_{config_name}_{fingerprint}_{split}_{source_index}"
291311

292312
# Determine image format and extension
293313
if image_format == "auto":
@@ -311,8 +331,8 @@ def add_from_huggingface(
311331
"source": dataset_name,
312332
"config": config_name,
313333
"split": split,
314-
"index": i,
315-
"fingerprint": ds._fingerprint if hasattr(ds, "_fingerprint") else None,
334+
"index": source_index,
335+
"fingerprint": source_fingerprint,
316336
"version": version,
317337
}
318338

0 commit comments

Comments
 (0)