Skip to content

Commit 58d0385

Browse files
committed
Deduplicate TabPFN factory and VUN logic, add sparse eigenvalue path

- H6: 05_benchmark/compute_vun.py now imports compute_vun_parallel from utils.vun instead of duplicating ~120 lines of VUN logic
- L5: Extracted make_tabpfn_classifier to utils/data.py, replacing six local copies across the compute scripts
- P4: EigenvalueHistogram now uses scipy.sparse.linalg.eigsh for graphs with >500 nodes, avoiding dense conversion of large Laplacians
1 parent 4a36fa8 commit 58d0385

9 files changed

Lines changed: 58 additions & 284 deletions

File tree

polygraph/utils/descriptors/generic_descriptors.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import orbit_count
1919
import torch
2020
from scipy.sparse import csgraph, csr_array
21+
from scipy.sparse.linalg import eigsh
2122
from sklearn.decomposition import TruncatedSVD
2223
from sklearn.preprocessing import StandardScaler
2324
from torch_geometric.data import Batch
@@ -225,14 +226,24 @@ def __init__(self, n_bins: int = 200, sparse: bool = False):
225226
else:
226227
self._bins = None
227228

229+
_SPARSE_THRESHOLD = 500
230+
228231
def __call__(
229232
self, graphs: Iterable[nx.Graph]
230233
) -> Union[np.ndarray, csr_array]:
231234
all_eigs = []
232235
for g in graphs:
233-
eigs = np.linalg.eigvalsh(
234-
nx.normalized_laplacian_matrix(g).todense()
235-
)
236+
n = g.number_of_nodes()
237+
laplacian = nx.normalized_laplacian_matrix(g)
238+
if n > self._SPARSE_THRESHOLD:
239+
k = min(n - 2, self._n_bins)
240+
eigs = eigsh(
241+
laplacian.astype(np.float64),
242+
k=k,
243+
return_eigenvectors=False,
244+
)
245+
else:
246+
eigs = np.linalg.eigvalsh(laplacian.todense())
236247
all_eigs.append(eigs)
237248

238249
if self._sparse:

reproducibility/01_subsampling/compute_pgd.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import json
1717
import pickle
18+
import sys
1819
import time
1920
from importlib.metadata import version as pkg_version
2021
from pathlib import Path
@@ -32,25 +33,8 @@
3233
maybe_append_jsonl,
3334
)
3435

35-
36-
def _make_tabpfn_classifier(weights_version: str):
37-
"""Create a TabPFN classifier for the given weights version."""
38-
from tabpfn import TabPFNClassifier
39-
from tabpfn.classifier import ModelVersion
40-
41-
version_map = {
42-
"v2": ModelVersion.V2,
43-
"v2.5": ModelVersion.V2_5,
44-
}
45-
if weights_version not in version_map:
46-
raise ValueError(
47-
f"Unknown weights_version: {weights_version!r}. Must be one of {list(version_map)}"
48-
)
49-
return TabPFNClassifier.create_default_for_version(
50-
version_map[weights_version],
51-
device="auto",
52-
n_estimators=4,
53-
)
36+
sys.path.insert(0, str(here() / "reproducibility"))
37+
from utils.data import make_tabpfn_classifier
5438

5539

5640
REPO_ROOT = here()
@@ -122,7 +106,7 @@ def main(cfg: DictConfig) -> None:
122106
model: str = cfg.model
123107
subsample_size: int = cfg.subsample_size
124108
num_bootstrap: int = 3 if cfg.subset else cfg.num_bootstrap
125-
classifier = _make_tabpfn_classifier(tabpfn_weights_version)
109+
classifier = make_tabpfn_classifier(tabpfn_weights_version)
126110

127111
logger.info(
128112
"PGD subsampling: dataset={}, model={}, n={}, bootstraps={}",

reproducibility/02_perturbation/compute.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import gc
2727
import json
2828
import random
29+
import sys
2930
from importlib.metadata import version as pkg_version
3031
from itertools import product
3132
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, cast
@@ -43,6 +44,9 @@
4344
maybe_append_jsonl,
4445
)
4546

47+
sys.path.insert(0, str(here() / "reproducibility"))
48+
from utils.data import make_tabpfn_classifier
49+
4650
from polygraph.datasets.ego import EgoGraphDataset
4751
from polygraph.datasets.lobster import ProceduralLobsterGraphDataset
4852
from polygraph.datasets.planar import ProceduralPlanarGraphDataset
@@ -328,22 +332,7 @@ def load_dataset(
328332
def _make_classifier(name: str, tabpfn_weights_version: str = "v2.5"):
329333
"""Build a classifier by name. For TabPFN, respects weights version."""
330334
if name == "tabpfn":
331-
from tabpfn import TabPFNClassifier
332-
from tabpfn.classifier import ModelVersion
333-
334-
version_map = {
335-
"v2": ModelVersion.V2,
336-
"v2.5": ModelVersion.V2_5,
337-
}
338-
if tabpfn_weights_version not in version_map:
339-
raise ValueError(
340-
f"Unknown tabpfn_weights_version: {tabpfn_weights_version!r}. Must be one of {list(version_map)}"
341-
)
342-
return TabPFNClassifier.create_default_for_version(
343-
version_map[tabpfn_weights_version],
344-
device="auto",
345-
n_estimators=4,
346-
)
335+
return make_tabpfn_classifier(tabpfn_weights_version)
347336
elif name == "lr":
348337
return LogisticRegression(max_iter=1000)
349338
else:

reproducibility/03_model_quality/compute.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import json
1515
import pickle
16+
import sys
1617
from importlib.metadata import version as pkg_version
1718
from pathlib import Path
1819
from typing import Any, List, Literal, cast
@@ -28,6 +29,9 @@
2829
maybe_append_jsonl,
2930
)
3031

32+
sys.path.insert(0, str(here() / "reproducibility"))
33+
from utils.data import make_tabpfn_classifier
34+
3135
REPO_ROOT = here()
3236
DATA_DIR = REPO_ROOT / "data"
3337
_RESULTS_DIR_BASE = (
@@ -192,22 +196,7 @@ def _parse_steps(p: Path) -> int:
192196
if subset:
193197
ref = ref[:30]
194198

195-
from tabpfn import TabPFNClassifier
196-
from tabpfn.classifier import ModelVersion
197-
198-
version_map = {
199-
"v2": ModelVersion.V2,
200-
"v2.5": ModelVersion.V2_5,
201-
}
202-
if tabpfn_weights_version not in version_map:
203-
raise ValueError(
204-
f"Unknown tabpfn_weights_version: {tabpfn_weights_version!r}. Must be one of {list(version_map)}"
205-
)
206-
classifier = TabPFNClassifier.create_default_for_version(
207-
version_map[tabpfn_weights_version],
208-
device="auto",
209-
n_estimators=4,
210-
)
199+
classifier = make_tabpfn_classifier(tabpfn_weights_version)
211200

212201
pgd_metric = StandardPGD(
213202
reference_graphs=ref,

reproducibility/05_benchmark/compute.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
sys.path.insert(0, str(here() / "reproducibility"))
2626
from utils.data import get_reference_dataset as _get_ref
2727
from utils.data import load_graphs as _load
28+
from utils.data import make_tabpfn_classifier
2829

2930
REPO_ROOT = here()
3031
DATA_DIR = REPO_ROOT / "data"
@@ -129,22 +130,7 @@ def main(cfg: DictConfig) -> None:
129130
subset = cfg.subset
130131
skip_vun = cfg.get("skip_vun", False)
131132

132-
from tabpfn import TabPFNClassifier
133-
from tabpfn.classifier import ModelVersion
134-
135-
version_map = {
136-
"v2": ModelVersion.V2,
137-
"v2.5": ModelVersion.V2_5,
138-
}
139-
if tabpfn_weights_version not in version_map:
140-
raise ValueError(
141-
f"Unknown tabpfn_weights_version: {tabpfn_weights_version!r}. Must be one of {list(version_map)}"
142-
)
143-
classifier = TabPFNClassifier.create_default_for_version(
144-
version_map[tabpfn_weights_version],
145-
device="auto",
146-
n_estimators=4,
147-
)
133+
classifier = make_tabpfn_classifier(tabpfn_weights_version)
148134

149135
logger.info("Computing benchmark for {}/{}", model, dataset)
150136

0 commit comments

Comments (0)