|
12 | 12 | """ |
13 | 13 |
|
14 | 14 | import json |
15 | | -import pickle |
16 | 15 | import sys |
17 | 16 | from importlib.metadata import version as pkg_version |
18 | 17 | from pathlib import Path |
19 | | -from typing import Any, List, Literal, cast |
| 18 | +from typing import Any, List, Literal, Tuple, cast |
20 | 19 |
|
21 | 20 | import hydra |
22 | 21 | import networkx as nx |
|
30 | 29 | ) |
31 | 30 |
|
32 | 31 | sys.path.insert(0, str(here() / "reproducibility")) |
| 32 | +from utils.data import load_graphs as _load |
33 | 33 | from utils.data import make_tabpfn_classifier |
34 | 34 |
|
35 | 35 | REPO_ROOT = here() |
|
40 | 40 |
|
41 | 41 |
|
def load_graphs(path: Path) -> List[nx.Graph]:
    """Load graphs from a single pickle file.

    Keeps the caller-facing ``(path)`` signature while delegating the
    actual loading to ``utils.data.load_graphs`` (imported as ``_load``),
    which takes a directory, a prefix, and a file stem rather than one
    full path.
    """
    # Split the full path into the pieces the shared loader expects:
    # containing directory, empty prefix, and the filename stem.
    directory = path.parent
    stem = path.stem
    return _load(directory, "", stem)
61 | 50 |
|
62 | 51 |
|
def get_reference_dataset(
    dataset: str,
    split: Literal["train", "val", "test"] = "train",
    num_graphs: int = 2048,
) -> Tuple[Any, List[nx.Graph]]:
    """Get reference dataset from polygraph library.

    Returns ``(dataset_object, graphs)`` so callers can also call
    ``dataset_object.is_valid()``.

    Raises:
        ValueError: if ``dataset`` is not one of ``planar``, ``sbm``,
            or ``lobster``.
    """
    # polygraph is imported lazily, at call time rather than module load.
    from polygraph.datasets.lobster import ProceduralLobsterGraphDataset
    from polygraph.datasets.planar import ProceduralPlanarGraphDataset
    from polygraph.datasets.sbm import ProceduralSBMGraphDataset

    # Dataset-name -> procedural dataset class lookup table.
    registry = {
        "planar": ProceduralPlanarGraphDataset,
        "lobster": ProceduralLobsterGraphDataset,
        "sbm": ProceduralSBMGraphDataset,
    }
    dataset_cls = registry.get(dataset)
    if dataset_cls is None:
        raise ValueError(f"Unknown dataset: {dataset}")
    ds = dataset_cls(split=split, num_graphs=num_graphs)
    return ds, list(ds.to_nx())
84 | 75 |
|
85 | 76 |
|
|
0 commit comments