Skip to content

Commit 91e3076

Browse files
committed
Deduplicate utilities and add kernel size guard

- L4: load_graphs/get_reference_dataset in 01-03 now delegate to utils/data.py instead of local copies
- L6: load_results extracted to utils/formatting.py, removed from 4 format scripts
- P1: Warn when kernel matrix exceeds 10k samples in kernel_lr.py
1 parent 58d0385 commit 91e3076

9 files changed

Lines changed: 75 additions & 198 deletions

File tree

polygraph/metrics/base/kernel_lr.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,14 @@ def fit(
156156
self.X_train_ = X
157157
self.y_train_ = y
158158

159+
n_samples = len(X) if isinstance(X, list) else int(X.shape[0]) # pyright: ignore[reportOptionalSubscript]
160+
if n_samples > 10_000:
161+
warnings.warn(
162+
f"Kernel matrix will require ~{n_samples**2 * 8 / 1e9:.1f} GB "
163+
f"of memory for {n_samples} samples. Consider reducing the "
164+
f"dataset size."
165+
)
166+
159167
K = self._compute_kernel_matrix(X)
160168
if self.normalize_kernel:
161169
self.train_diag_ = np.diag(K).astype(np.float64)

reproducibility/01_subsampling/compute_mmd.py

Lines changed: 6 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@
1414
"""
1515

1616
import json
17-
import pickle
17+
import sys
1818
import time
1919
from pathlib import Path
2020
from typing import List, Literal, cast
2121

2222
import hydra
2323
import networkx as nx
2424
import numpy as np
25-
import torch
2625
from loguru import logger
2726
from omegaconf import DictConfig
2827
from pyprojroot import here
@@ -41,6 +40,10 @@
4140
)
4241
from polygraph.utils.kernels import AdaptiveRBFKernel
4342

43+
sys.path.insert(0, str(here() / "reproducibility"))
44+
from utils.data import get_reference_dataset
45+
from utils.data import load_graphs as _load
46+
4447
REPO_ROOT = here()
4548
DATA_DIR = REPO_ROOT / "data"
4649
EXPERIMENT_RESULTS_DIR = (
@@ -64,49 +67,7 @@
6467

6568
def load_graphs(model: str, dataset: str) -> List[nx.Graph]:
6669
"""Load model-generated graphs from ``data/{model}/{dataset}.pkl``."""
67-
pkl_path = DATA_DIR / model / f"{dataset}.pkl"
68-
if not pkl_path.exists():
69-
raise FileNotFoundError(f"{pkl_path} not found")
70-
with open(pkl_path, "rb") as f:
71-
graphs = pickle.load(f)
72-
73-
cleaned: List[nx.Graph] = []
74-
for g in graphs:
75-
if isinstance(g, nx.Graph):
76-
simple = nx.Graph(g)
77-
elif isinstance(g, (list, tuple)) and len(g) == 2:
78-
try:
79-
_node_feat, adj = g
80-
if isinstance(adj, torch.Tensor):
81-
adj = adj.numpy()
82-
simple = nx.from_numpy_array(adj)
83-
except Exception as e:
84-
logger.warning("Could not convert graph: {}", e)
85-
continue
86-
else:
87-
logger.warning("Unknown graph format: {}", type(g))
88-
continue
89-
simple.remove_edges_from(nx.selfloop_edges(simple))
90-
cleaned.append(simple)
91-
return cleaned
92-
93-
94-
def get_reference_dataset(
95-
dataset: str, split: str = "train", num_graphs: int = 4096
96-
) -> List[nx.Graph]:
97-
"""Get reference dataset from polygraph procedural generators."""
98-
from polygraph.datasets.lobster import ProceduralLobsterGraphDataset
99-
from polygraph.datasets.planar import ProceduralPlanarGraphDataset
100-
from polygraph.datasets.sbm import ProceduralSBMGraphDataset
101-
102-
classes = {
103-
"planar": ProceduralPlanarGraphDataset,
104-
"lobster": ProceduralLobsterGraphDataset,
105-
"sbm": ProceduralSBMGraphDataset,
106-
}
107-
if dataset not in classes:
108-
raise ValueError(f"Unknown dataset: {dataset}")
109-
return list(classes[dataset](split=split, num_graphs=num_graphs).to_nx())
70+
return _load(DATA_DIR, model, dataset)
11071

11172

11273
def make_descriptor(name: str, reference_graphs: List[nx.Graph]):

reproducibility/01_subsampling/compute_pgd.py

Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"""
1515

1616
import json
17-
import pickle
1817
import sys
1918
import time
2019
from importlib.metadata import version as pkg_version
@@ -23,7 +22,6 @@
2322

2423
import hydra
2524
import networkx as nx
26-
import torch
2725
from loguru import logger
2826
from omegaconf import DictConfig
2927
from pyprojroot import here
@@ -34,6 +32,8 @@
3432
)
3533

3634
sys.path.insert(0, str(here() / "reproducibility"))
35+
from utils.data import get_reference_dataset
36+
from utils.data import load_graphs as _load
3737
from utils.data import make_tabpfn_classifier
3838

3939

@@ -46,49 +46,7 @@
4646

4747
def load_graphs(model: str, dataset: str) -> List[nx.Graph]:
4848
"""Load model-generated graphs from ``data/{model}/{dataset}.pkl``."""
49-
pkl_path = DATA_DIR / model / f"{dataset}.pkl"
50-
if not pkl_path.exists():
51-
raise FileNotFoundError(f"{pkl_path} not found")
52-
with open(pkl_path, "rb") as f:
53-
graphs = pickle.load(f)
54-
55-
cleaned: List[nx.Graph] = []
56-
for g in graphs:
57-
if isinstance(g, nx.Graph):
58-
simple = nx.Graph(g)
59-
elif isinstance(g, (list, tuple)) and len(g) == 2:
60-
try:
61-
_node_feat, adj = g
62-
if isinstance(adj, torch.Tensor):
63-
adj = adj.numpy()
64-
simple = nx.from_numpy_array(adj)
65-
except Exception as e:
66-
logger.warning("Could not convert graph: {}", e)
67-
continue
68-
else:
69-
logger.warning("Unknown graph format: {}", type(g))
70-
continue
71-
simple.remove_edges_from(nx.selfloop_edges(simple))
72-
cleaned.append(simple)
73-
return cleaned
74-
75-
76-
def get_reference_dataset(
77-
dataset: str, split: str = "train", num_graphs: int = 4096
78-
) -> List[nx.Graph]:
79-
"""Get reference dataset from polygraph procedural generators."""
80-
from polygraph.datasets.lobster import ProceduralLobsterGraphDataset
81-
from polygraph.datasets.planar import ProceduralPlanarGraphDataset
82-
from polygraph.datasets.sbm import ProceduralSBMGraphDataset
83-
84-
classes = {
85-
"planar": ProceduralPlanarGraphDataset,
86-
"lobster": ProceduralLobsterGraphDataset,
87-
"sbm": ProceduralSBMGraphDataset,
88-
}
89-
if dataset not in classes:
90-
raise ValueError(f"Unknown dataset: {dataset}")
91-
return list(classes[dataset](split=split, num_graphs=num_graphs).to_nx())
49+
return _load(DATA_DIR, model, dataset)
9250

9351

9452
@hydra.main(

reproducibility/03_model_quality/compute.py

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
"""
1313

1414
import json
15-
import pickle
1615
import sys
1716
from importlib.metadata import version as pkg_version
1817
from pathlib import Path
19-
from typing import Any, List, Literal, cast
18+
from typing import Any, List, Literal, Tuple, cast
2019

2120
import hydra
2221
import networkx as nx
@@ -30,6 +29,7 @@
3029
)
3130

3231
sys.path.insert(0, str(here() / "reproducibility"))
32+
from utils.data import load_graphs as _load
3333
from utils.data import make_tabpfn_classifier
3434

3535
REPO_ROOT = here()
@@ -40,46 +40,37 @@
4040

4141

4242
def load_graphs(path: Path) -> List[nx.Graph]:
43-
"""Load graphs from pickle file and convert to networkx."""
44-
if not path.exists():
45-
logger.warning("{} not found", path)
46-
return []
47-
with open(path, "rb") as f:
48-
data = pickle.load(f)
49-
graphs = []
50-
for item in data:
51-
if isinstance(item, nx.Graph):
52-
graphs.append(item)
53-
elif isinstance(item, (tuple, list)) and len(item) >= 2:
54-
adj = item[1]
55-
if hasattr(adj, "numpy"):
56-
adj = adj.numpy()
57-
graphs.append(nx.from_numpy_array(adj))
58-
else:
59-
graphs.append(nx.from_numpy_array(np.array(item)))
60-
return graphs
43+
"""Load graphs from a single pickle file.
44+
45+
Delegates to ``utils.data.load_graphs`` by extracting the parent
46+
directory and stem so that the caller-facing ``(path)`` signature
47+
is preserved.
48+
"""
49+
return _load(path.parent, "", path.stem)
6150

6251

6352
def get_reference_dataset(
6453
dataset: str,
6554
split: Literal["train", "val", "test"] = "train",
6655
num_graphs: int = 2048,
67-
):
68-
"""Get reference dataset from polygraph library."""
69-
if dataset == "planar":
70-
from polygraph.datasets.planar import ProceduralPlanarGraphDataset
71-
72-
ds = ProceduralPlanarGraphDataset(split=split, num_graphs=num_graphs)
73-
elif dataset == "sbm":
74-
from polygraph.datasets.sbm import ProceduralSBMGraphDataset
75-
76-
ds = ProceduralSBMGraphDataset(split=split, num_graphs=num_graphs)
77-
elif dataset == "lobster":
78-
from polygraph.datasets.lobster import ProceduralLobsterGraphDataset
79-
80-
ds = ProceduralLobsterGraphDataset(split=split, num_graphs=num_graphs)
81-
else:
56+
) -> Tuple[Any, List[nx.Graph]]:
57+
"""Get reference dataset from polygraph library.
58+
59+
Returns ``(dataset_object, graphs)`` so callers can also call
60+
``dataset_object.is_valid()``.
61+
"""
62+
from polygraph.datasets.lobster import ProceduralLobsterGraphDataset
63+
from polygraph.datasets.planar import ProceduralPlanarGraphDataset
64+
from polygraph.datasets.sbm import ProceduralSBMGraphDataset
65+
66+
procedural = {
67+
"planar": ProceduralPlanarGraphDataset,
68+
"lobster": ProceduralLobsterGraphDataset,
69+
"sbm": ProceduralSBMGraphDataset,
70+
}
71+
if dataset not in procedural:
8272
raise ValueError(f"Unknown dataset: {dataset}")
73+
ds = procedural[dataset](split=split, num_graphs=num_graphs)
8374
return ds, list(ds.to_nx())
8475

8576

reproducibility/05_benchmark/format.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
python format.py
88
"""
99

10-
import json
1110
import sys
12-
from pathlib import Path
13-
from typing import Dict, List
11+
from typing import Dict
1412

1513
import pandas as pd
1614
import typer
@@ -25,6 +23,7 @@
2523
MODELS,
2624
best_two,
2725
fmt_pgs,
26+
load_results,
2827
)
2928

3029
app = typer.Typer()
@@ -43,26 +42,6 @@
4342
]
4443

4544

46-
def load_results(results_dir: Path) -> List[Dict]:
47-
results = []
48-
for f in sorted(results_dir.glob("*.json")):
49-
with open(f) as fh:
50-
results.append(json.load(fh))
51-
return results
52-
53-
54-
def _reshape(result_list: List[Dict]) -> Dict[str, Dict]:
55-
all_results: Dict[str, Dict] = {}
56-
for r in result_list:
57-
r = r.copy()
58-
ds = r.pop("dataset", None)
59-
model = r.pop("model", None)
60-
r.pop("error", None)
61-
if ds and model:
62-
all_results.setdefault(ds, {})[model] = r
63-
return all_results
64-
65-
6645
def generate_benchmark_table(all_results: Dict) -> str:
6746
lines = []
6847
lines.append(
@@ -167,14 +146,12 @@ def generate_benchmark_table(all_results: Dict) -> str:
167146
def main():
168147
"""Generate LaTeX tables from pre-computed JSON results."""
169148
results_dir = OUTPUT_DIR / "results" / "benchmark"
170-
result_list = load_results(results_dir)
171-
if not result_list:
149+
all_results = load_results(results_dir)
150+
if not all_results:
172151
logger.error(
173152
"No results found in {}. Run compute.py first.", results_dir
174153
)
175154
return
176-
177-
all_results = _reshape(result_list)
178155
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
179156

180157
table = generate_benchmark_table(all_results)

reproducibility/06_mmd/format.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
python format.py
1111
"""
1212

13-
import json
1413
import sys
15-
from pathlib import Path
1614
from typing import Dict, List
1715

1816
import pandas as pd
@@ -28,6 +26,7 @@
2826
MODELS,
2927
best_two,
3028
fmt_sci,
29+
load_results,
3130
)
3231

3332
app = typer.Typer()
@@ -38,19 +37,6 @@
3837
BENCHMARK_RESULTS_DIR = OUTPUT_DIR / "results" / "benchmark"
3938

4039

41-
def load_results(results_dir: Path) -> Dict[str, Dict]:
42-
all_r: Dict[str, Dict] = {}
43-
if not results_dir.exists():
44-
return all_r
45-
for f in sorted(results_dir.glob("*.json")):
46-
with open(f) as fh:
47-
r = json.load(fh)
48-
ds, model = r.get("dataset"), r.get("model")
49-
if ds and model:
50-
all_r.setdefault(ds, {})[model] = r
51-
return all_r
52-
53-
5440
def _fmt_pgs(mean: float, std: float, is_best=False, is_second=False) -> str:
5541
"""MMD tables use 3 decimal places for PGD scores (not 1 like benchmark)."""
5642
if pd.isna(mean):

0 commit comments

Comments (0)