Skip to content

Commit 4a36fa8

Browse files
committed
Fix remaining review items: performance, style, cleanup
polygraphdiscrepancy.py:
- P3: Vectorize the `_is_constant` sparse check (compare per-column min/max instead of looping over rows)

Reproducibility scripts:
- L10: Remove duplicate runtime fields (keep only the `*_perf_seconds` variants)
- L11: Remove pointless `_fmt_pgs`/`_best_two` aliases in the format scripts
- M7: Remove section-separator comments across all scripts
1 parent a3b9698 commit 4a36fa8

18 files changed

Lines changed: 25 additions & 191 deletions

File tree

polygraph/metrics/base/polygraphdiscrepancy.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,9 @@ def _is_constant(X) -> bool:
147147
if issparse(X):
148148
if X.shape[0] <= 1:
149149
return True
150-
first = X[0]
151-
for i in range(1, X.shape[0]):
152-
diff = X[i] - first
153-
if diff.nnz > 0:
154-
return False
155-
return True
150+
col_min = X.min(axis=0).toarray().ravel()
151+
col_max = X.max(axis=0).toarray().ravel()
152+
return bool(np.array_equal(col_min, col_max))
156153
return bool(np.all(X == X[0]))
157154

158155

reproducibility/01_subsampling/compute_mmd.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@
4141
)
4242
from polygraph.utils.kernels import AdaptiveRBFKernel
4343

44-
# ---------------------------------------------------------------------------
45-
# Paths (resolved before Hydra touches CWD; we disable chdir in the config)
46-
# ---------------------------------------------------------------------------
4744
REPO_ROOT = here()
4845
DATA_DIR = REPO_ROOT / "data"
4946
EXPERIMENT_RESULTS_DIR = (
@@ -65,9 +62,6 @@
6562
]
6663

6764

68-
# ---------------------------------------------------------------------------
69-
# Graph loading helpers
70-
# ---------------------------------------------------------------------------
7165
def load_graphs(model: str, dataset: str) -> List[nx.Graph]:
7266
"""Load model-generated graphs from ``data/{model}/{dataset}.pkl``."""
7367
pkl_path = DATA_DIR / model / f"{dataset}.pkl"
@@ -115,9 +109,6 @@ def get_reference_dataset(
115109
return list(classes[dataset](split=split, num_graphs=num_graphs).to_nx())
116110

117111

118-
# ---------------------------------------------------------------------------
119-
# Descriptor factory
120-
# ---------------------------------------------------------------------------
121112
def make_descriptor(name: str, reference_graphs: List[nx.Graph]):
122113
"""Instantiate a descriptor by name, matching the original experiment."""
123114
factories = {
@@ -137,9 +128,6 @@ def make_descriptor(name: str, reference_graphs: List[nx.Graph]):
137128
return factories[name]()
138129

139130

140-
# ---------------------------------------------------------------------------
141-
# Main
142-
# ---------------------------------------------------------------------------
143131
@hydra.main(
144132
config_path="../configs",
145133
config_name="01_subsampling_mmd",
@@ -278,7 +266,6 @@ def main(cfg: DictConfig) -> None:
278266
"mmd_std": float(result.std),
279267
"mmd_low": float(result.low) if result.low is not None else None,
280268
"mmd_high": float(result.high) if result.high is not None else None,
281-
"mmd_runtime_seconds": mmd_runtime_perf_seconds,
282269
"mmd_runtime_perf_seconds": mmd_runtime_perf_seconds,
283270
}
284271

@@ -301,7 +288,6 @@ def main(cfg: DictConfig) -> None:
301288
"status": "ok",
302289
"output_path": str(out_path),
303290
"result": output,
304-
"mmd_runtime_seconds": mmd_runtime_perf_seconds,
305291
"mmd_runtime_perf_seconds": mmd_runtime_perf_seconds,
306292
}
307293
)
@@ -329,7 +315,6 @@ def main(cfg: DictConfig) -> None:
329315
"variant": variant,
330316
"status": "error",
331317
"error": str(e),
332-
"mmd_runtime_seconds": metric_runtime_perf_seconds,
333318
"mmd_runtime_perf_seconds": metric_runtime_perf_seconds,
334319
}
335320
)

reproducibility/01_subsampling/compute_pgd.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,13 @@ def _make_tabpfn_classifier(weights_version: str):
5353
)
5454

5555

56-
# ---------------------------------------------------------------------------
57-
# Paths (resolved before Hydra touches CWD; we disable chdir in the config)
58-
# ---------------------------------------------------------------------------
5956
REPO_ROOT = here()
6057
DATA_DIR = REPO_ROOT / "data"
6158
EXPERIMENT_RESULTS_DIR = (
6259
REPO_ROOT / "reproducibility" / "figures" / "01_subsampling" / "results"
6360
)
6461

6562

66-
# ---------------------------------------------------------------------------
67-
# Graph loading helpers
68-
# ---------------------------------------------------------------------------
6963
def load_graphs(model: str, dataset: str) -> List[nx.Graph]:
7064
"""Load model-generated graphs from ``data/{model}/{dataset}.pkl``."""
7165
pkl_path = DATA_DIR / model / f"{dataset}.pkl"
@@ -113,9 +107,6 @@ def get_reference_dataset(
113107
return list(classes[dataset](split=split, num_graphs=num_graphs).to_nx())
114108

115109

116-
# ---------------------------------------------------------------------------
117-
# Main
118-
# ---------------------------------------------------------------------------
119110
@hydra.main(
120111
config_path="../configs",
121112
config_name="01_subsampling_pgd",
@@ -238,7 +229,6 @@ def main(cfg: DictConfig) -> None:
238229
"num_bootstrap": num_bootstrap,
239230
"pgd_mean": float(result["pgd"].mean),
240231
"pgd_std": float(result["pgd"].std),
241-
"pgd_runtime_seconds": pgd_runtime_perf_seconds,
242232
"pgd_runtime_perf_seconds": pgd_runtime_perf_seconds,
243233
"tabpfn_package_version": pkg_version("tabpfn"),
244234
"tabpfn_weights_version": tabpfn_weights_version,
@@ -271,7 +261,6 @@ def main(cfg: DictConfig) -> None:
271261
"status": "ok",
272262
"output_path": str(out_path),
273263
"result": output,
274-
"pgd_runtime_seconds": pgd_runtime_perf_seconds,
275264
"pgd_runtime_perf_seconds": pgd_runtime_perf_seconds,
276265
}
277266
)
@@ -295,7 +284,6 @@ def main(cfg: DictConfig) -> None:
295284
"subsample_size": subsample_size,
296285
"status": "error",
297286
"error": str(e),
298-
"pgd_runtime_seconds": metric_runtime_perf_seconds,
299287
"pgd_runtime_perf_seconds": metric_runtime_perf_seconds,
300288
}
301289
)

reproducibility/02_perturbation/compute.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,6 @@
8282
_RBF_BW = np.array([0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0])
8383

8484

85-
# ---------------------------------------------------------------------------
86-
# Perturbation functions (inline, matching original library implementations)
87-
# ---------------------------------------------------------------------------
88-
89-
9085
def edge_rewiring(graph: nx.Graph, noise_level: float) -> nx.Graph:
9186
"""Rewire edges: each selected with P(noise_level), one endpoint reconnected."""
9287
if not (0 <= noise_level <= 1):
@@ -252,11 +247,6 @@ def edge_addition(graph: nx.Graph, noise_level: float) -> nx.Graph:
252247
}
253248

254249

255-
# ---------------------------------------------------------------------------
256-
# Dataset loading
257-
# ---------------------------------------------------------------------------
258-
259-
260250
def load_dataset(
261251
dataset: str, num_graphs: int, seed: int
262252
) -> Tuple[List[nx.Graph], List[nx.Graph]]:
@@ -335,11 +325,6 @@ def load_dataset(
335325
return reference_graphs, perturbed_graphs
336326

337327

338-
# ---------------------------------------------------------------------------
339-
# Metric initialization
340-
# ---------------------------------------------------------------------------
341-
342-
343328
def _make_classifier(name: str, tabpfn_weights_version: str = "v2.5"):
344329
"""Build a classifier by name. For TabPFN, respects weights version."""
345330
if name == "tabpfn":
@@ -457,11 +442,6 @@ def build_metrics(
457442
return metrics
458443

459444

460-
# ---------------------------------------------------------------------------
461-
# Evaluation
462-
# ---------------------------------------------------------------------------
463-
464-
465445
def evaluate_metrics(
466446
perturbed_graphs: List[nx.Graph],
467447
metrics: Dict[str, Any],
@@ -491,11 +471,6 @@ def evaluate_metrics(
491471
return result
492472

493473

494-
# ---------------------------------------------------------------------------
495-
# Main
496-
# ---------------------------------------------------------------------------
497-
498-
499474
@hydra.main(
500475
config_path="../configs", config_name="02_perturbation", version_base=None
501476
)

reproducibility/02_perturbation/plot.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,6 @@ def _compute_spearman(series: pd.Series, noise: pd.Series) -> float:
226226
return float(rho) # type: ignore[arg-type]
227227

228228

229-
# ---------------------------------------------------------------------------
230-
# Figure 1 & 2: Correlation bar plots
231-
# ---------------------------------------------------------------------------
232-
233-
234229
def plot_correlation_bars(
235230
all_data: Dict,
236231
classifier: str,
@@ -329,11 +324,6 @@ def plot_correlation_bars(
329324
logger.success("Saved: {}", out)
330325

331326

332-
# ---------------------------------------------------------------------------
333-
# Figures 3-6: Metrics vs noise level (faceted grid)
334-
# ---------------------------------------------------------------------------
335-
336-
337327
def plot_metrics_vs_noise(
338328
all_data: Dict,
339329
classifier: str,
@@ -485,11 +475,6 @@ def plot_metrics_vs_noise(
485475
logger.success("Saved: {}", out)
486476

487477

488-
# ---------------------------------------------------------------------------
489-
# Figure 7: LR vs TabPFN comparison
490-
# ---------------------------------------------------------------------------
491-
492-
493478
def plot_lr_vs_tabpfn(
494479
all_data: Dict,
495480
variant: str,
@@ -586,11 +571,6 @@ def plot_lr_vs_tabpfn(
586571
logger.success("Saved: {}", out)
587572

588573

589-
# ---------------------------------------------------------------------------
590-
# Single-dataset perturbation figures (e.g. SBM-only)
591-
# ---------------------------------------------------------------------------
592-
593-
594574
def plot_single_dataset_perturbation(
595575
all_data: Dict,
596576
classifier: str,
@@ -743,11 +723,6 @@ def plot_single_dataset_perturbation(
743723
logger.success("Saved: {}", out)
744724

745725

746-
# ---------------------------------------------------------------------------
747-
# CLI
748-
# ---------------------------------------------------------------------------
749-
750-
751726
def _load_results_dir(results_dir: Path) -> Dict[Tuple[str, str], dict]:
752727
"""Load all perturbation JSON results from a directory."""
753728
data = {}

reproducibility/03_model_quality/compute_vun.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@
4040
from utils.vun import compute_vun_parallel # noqa: E402
4141

4242

43-
# ---------------------------------------------------------------------------
44-
# Graph loading (mirrors compute.py)
45-
# ---------------------------------------------------------------------------
46-
47-
4843
def load_graphs(path: Path) -> List[nx.Graph]:
4944
"""Load graphs from pickle file and convert to networkx."""
5045
if not path.exists():
@@ -89,11 +84,6 @@ def get_reference_dataset(
8984
return ds, list(ds.to_nx())
9085

9186

92-
# ---------------------------------------------------------------------------
93-
# Main
94-
# ---------------------------------------------------------------------------
95-
96-
9787
@hydra.main(
9888
config_path="../configs",
9989
config_name="03_model_quality_vun",

reproducibility/03_model_quality/format.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,6 @@ def _neg_pearson(x, y) -> float:
9696
return -float(r) # type: ignore[arg-type]
9797

9898

99-
# ---------------------------------------------------------------------------
100-
# Table 1 & 3: Pearson correlation of validity with other metrics
101-
# ---------------------------------------------------------------------------
102-
103-
10499
def _format_row_with_ranking(
105100
values: list[float], fmt: str = "{:.2f}"
106101
) -> list[str]:
@@ -186,11 +181,6 @@ def generate_pearson_correlation_table(variant: str) -> str:
186181
return "\n".join(lines)
187182

188183

189-
# ---------------------------------------------------------------------------
190-
# Table 2 & 4: Spearman correlation with training iterations
191-
# ---------------------------------------------------------------------------
192-
193-
194184
def generate_spearman_training_table(variant: str) -> str:
195185
"""Generate Spearman correlation table of metrics with training steps.
196186
@@ -249,11 +239,6 @@ def generate_spearman_training_table(variant: str) -> str:
249239
return "\n".join(lines)
250240

251241

252-
# ---------------------------------------------------------------------------
253-
# Table 5: Denoising iterations MMD values
254-
# ---------------------------------------------------------------------------
255-
256-
257242
def generate_denoising_mmd_table() -> str:
258243
"""Generate table of MMD values per denoising step."""
259244
df = load_results("denoising", "planar", "jsd")
@@ -289,11 +274,6 @@ def generate_denoising_mmd_table() -> str:
289274
return "\n".join(lines)
290275

291276

292-
# ---------------------------------------------------------------------------
293-
# Table 6: Denoising iterations PGS values
294-
# ---------------------------------------------------------------------------
295-
296-
297277
def generate_denoising_pgs_table(variant: str = "jsd") -> str:
298278
"""Generate table of PGS values per denoising step, with optional VUN column."""
299279
df = load_results("denoising", "planar", variant)

reproducibility/05_benchmark/compute.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626
from utils.data import get_reference_dataset as _get_ref
2727
from utils.data import load_graphs as _load
2828

29-
# ---------------------------------------------------------------------------
30-
# Paths (resolved before Hydra touches CWD)
31-
# ---------------------------------------------------------------------------
3229
REPO_ROOT = here()
3330
DATA_DIR = REPO_ROOT / "data"
3431
_RESULTS_DIR_BASE = REPO_ROOT / "reproducibility" / "tables" / "results"

reproducibility/05_benchmark/compute_vun.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,6 @@
3939
app = typer.Typer()
4040

4141

42-
# ---------------------------------------------------------------------------
43-
# Per-pair isomorphism with SIGALRM timeout
44-
# ---------------------------------------------------------------------------
45-
46-
4742
class _TimeoutError(Exception):
4843
pass
4944

@@ -97,11 +92,6 @@ def __contains__(self, g: nx.Graph) -> bool:
9792
return False
9893

9994

100-
# ---------------------------------------------------------------------------
101-
# Parallel novelty worker
102-
# ---------------------------------------------------------------------------
103-
104-
10595
def _check_novel_worker(
10696
gen_graph_json: str,
10797
train_graphs_json: List[str],

0 commit comments

Comments (0)