Skip to content

Commit 520d2bb

Browse files
committed
Clean up dead code, deduplicate constants, and fix style issues
Remove unused utilities (to_list, mol2smiles, BOND_STEREO_TYPES, MetricInterval.__getitem__), extract shared constants (_DEFAULT_RBF_BANDWIDTHS, _molecule_descriptors, _standard_descriptors), fix f-string bug in polygraphdiscrepancy, use NamedTuple classes over namedtuple calls, modernize super() calls, replace assert False with proper exceptions, use sqeuclidean metric directly, and move tqdm to core dependencies.
1 parent d01682b commit 520d2bb

14 files changed

Lines changed: 89 additions & 149 deletions

File tree

pixi.lock

Lines changed: 6 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

polygraph/datasets/base/caching.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import shutil
44
import urllib.request
5-
from typing import Any, Optional, Sequence
5+
from typing import Optional
66

77
import filelock
88
import torch
@@ -87,10 +87,3 @@ def load_from_cache(
8787
logger.debug(f"Loading data from {path}")
8888
data = torch.load(path, weights_only=True, mmap=mmap)
8989
return GraphStorage(**data)
90-
91-
92-
def to_list(value: Any) -> Sequence:
93-
if isinstance(value, Sequence) and not isinstance(value, str):
94-
return value
95-
else:
96-
return [value]

polygraph/datasets/base/molecules.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,6 @@
2828
Chem.rdchem.BondType.ZERO,
2929
]
3030

31-
BOND_STEREO_TYPES = [
32-
Chem.rdchem.BondStereo.STEREONONE,
33-
Chem.rdchem.BondStereo.STEREOZ,
34-
Chem.rdchem.BondStereo.STEREOE,
35-
Chem.rdchem.BondStereo.STEREOCIS,
36-
Chem.rdchem.BondStereo.STEREOTRANS,
37-
Chem.rdchem.BondStereo.STEREOANY,
38-
Chem.rdchem.BondStereo.STEREOATROPCCW,
39-
Chem.rdchem.BondStereo.STEREOATROPCW,
40-
]
4131

4232
# Generalized atom vocabulary for all molecules
4333
N_UNIQUE_ATOMS = 119
@@ -74,14 +64,6 @@ def are_smiles_equivalent(smiles1, smiles2):
7464
return canonical_smiles1 == canonical_smiles2
7565

7666

77-
def mol2smiles(mol, canonical: bool = False):
78-
try:
79-
Chem.SanitizeMol(mol)
80-
except ValueError as e:
81-
print(e, mol)
82-
return None
83-
return Chem.MolToSmiles(mol, canonical=canonical)
84-
8567

8668
def smiles_with_explicit_hydrogens(smiles: str, canonical: bool = True) -> str:
8769
"""Convert a SMILES string to a SMILES string with all hydrogens made explicit.

polygraph/metrics/base/frechet_distance.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from collections import namedtuple
2-
from typing import Callable, Collection, Generic
1+
from typing import Callable, Collection, Generic, NamedTuple
32

43
import numpy as np
54
import scipy
@@ -10,7 +9,10 @@
109

1110
__all__ = ["FittedFrechetDistance", "FrechetDistance"]
1211

13-
GaussianParameters = namedtuple("GaussianParameters", ["mean", "covariance"])
12+
13+
class GaussianParameters(NamedTuple):
14+
mean: np.ndarray
15+
covariance: np.ndarray
1416

1517

1618
def compute_wasserstein_distance(
@@ -112,7 +114,6 @@ def __init__(
112114
):
113115
self._reference_gaussian = fitted_gaussian
114116
self._descriptor_fn = descriptor_fn
115-
self._dim = None
116117

117118
def compute(self, generated_graphs: Collection[GraphType]) -> float:
118119
"""Computes Frechet distance between reference and generated graphs.

polygraph/metrics/base/metric_interval.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,6 @@ def from_samples(
6464

6565
return cls(mean=mean, std=std, low=low, high=high, coverage=coverage)
6666

67-
def __getitem__(self, key: str) -> Optional[float]:
68-
if key == "mean":
69-
return self.mean
70-
elif key == "std":
71-
return self.std
72-
elif key == "low":
73-
return self.low
74-
elif key == "high":
75-
return self.high
76-
else:
77-
raise ValueError(f"Invalid key: {key}")
78-
7967
def __repr__(self):
8068
if self.coverage is not None:
8169
return f"MetricInterval(mean={self.mean}, std={self.std}, low={self.low}, high={self.high}, coverage={self.coverage})"

polygraph/metrics/base/polygraphdiscrepancy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _scores_to_informedness_and_threshold(
171171
)
172172
if ref_scores.ndim != 1:
173173
raise RuntimeError(
174-
"ref_scores must be 1-dimensional, got shape {ref_scores.shape}. This should not happen, please file a bug report."
174+
f"ref_scores must be 1-dimensional, got shape {ref_scores.shape}. This should not happen, please file a bug report."
175175
)
176176

177177
assert ref_scores.ndim == 1 and gen_scores.ndim == 1
@@ -587,9 +587,9 @@ def compute(
587587
588588
Returns:
589589
Typed dictionary of scores.
590-
The key `"polygraphscore"` specifies the PolyGraphDiscrepancy, giving the estimated tightest lower-bound on the probability metric.
591-
The key `"polygraphscore_descriptor"` specifies the descriptor that achieves this bound.
592-
All descritor-wise scores are returned in the key `"subscores"`.
590+
The key `"pgd"` specifies the PolyGraphDiscrepancy, giving the estimated tightest lower-bound on the probability metric.
591+
The key `"pgd_descriptor"` specifies the descriptor that achieves this bound.
592+
All descriptor-wise scores are returned in the key `"subscores"`.
593593
"""
594594
all_metrics = {
595595
name: metric.compute(generated_graphs)
@@ -665,7 +665,7 @@ def compute(
665665
Typed dictionary of scores.
666666
The key `"pgd"` specifies the PolyGraphDiscrepancy, giving mean and standard deviation as [`MetricInterval`][polygraph.metrics.base.metric_interval.MetricInterval] objects.
667667
The key `"pgd_descriptor"` describes which descriptors achieve this score. This is a dictionary mapping descriptor names to the ratio of samples in which the descriptor was chosen.
668-
All descritor-wise scores are returned in the key `"subscores"`. These are [`MetricInterval`][polygraph.metrics.base.metric_interval.MetricInterval] objects.
668+
All descriptor-wise scores are returned in the key `"subscores"`. These are [`MetricInterval`][polygraph.metrics.base.metric_interval.MetricInterval] objects.
669669
"""
670670
if len(generated_graphs) < 2 * self._subsample_size:
671671
raise ValueError(

polygraph/metrics/molecule_pgd.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,18 @@
6363
]
6464

6565

66+
def _molecule_descriptors():
67+
return {
68+
"topochemical": TopoChemicalDescriptor(),
69+
"morgan_fingerprint": FingerprintDescriptor(
70+
algorithm="morgan", dim=128
71+
),
72+
"chemnet": ChemNetDescriptor(dim=128),
73+
"molclr": MolCLRDescriptor(dim=128),
74+
"lipinski": LipinskiDescriptor(),
75+
}
76+
77+
6678
class MoleculePGD(PolyGraphDiscrepancy[rdkit.Chem.Mol]):
6779
"""MoleculePGD to compare molecule distributions, combining different molecule descriptors.
6880
@@ -73,15 +85,7 @@ class MoleculePGD(PolyGraphDiscrepancy[rdkit.Chem.Mol]):
7385
def __init__(self, reference_molecules: Collection[rdkit.Chem.Mol]):
7486
super().__init__(
7587
reference_graphs=reference_molecules,
76-
descriptors={
77-
"topochemical": TopoChemicalDescriptor(),
78-
"morgan_fingerprint": FingerprintDescriptor(
79-
algorithm="morgan", dim=128
80-
),
81-
"chemnet": ChemNetDescriptor(dim=128),
82-
"molclr": MolCLRDescriptor(dim=128),
83-
"lipinski": LipinskiDescriptor(),
84-
},
88+
descriptors=_molecule_descriptors(),
8589
variant="jsd",
8690
classifier=None,
8791
)
@@ -106,15 +110,7 @@ def __init__(
106110
):
107111
super().__init__(
108112
reference_graphs=reference_molecules,
109-
descriptors={
110-
"topochemical": TopoChemicalDescriptor(),
111-
"morgan_fingerprint": FingerprintDescriptor(
112-
algorithm="morgan", dim=128
113-
),
114-
"chemnet": ChemNetDescriptor(dim=128),
115-
"molclr": MolCLRDescriptor(dim=128),
116-
"lipinski": LipinskiDescriptor(),
117-
},
113+
descriptors=_molecule_descriptors(),
118114
subsample_size=subsample_size,
119115
num_samples=num_samples,
120116
variant="jsd",

polygraph/metrics/rbf_mmd.py

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@
5757
from polygraph.utils.kernels import AdaptiveRBFKernel
5858
from polygraph.metrics.base import MetricCollection
5959

60+
_DEFAULT_RBF_BANDWIDTHS = np.array(
61+
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
62+
)
63+
6064
__all__ = [
6165
"RBFMMD2Benchmark",
6266
"RBFMMD2BenchmarkInterval",
@@ -139,9 +143,7 @@ def __init__(self, reference_graphs: Collection[nx.Graph]):
139143
reference_graphs=reference_graphs,
140144
kernel=AdaptiveRBFKernel(
141145
descriptor_fn=OrbitCounts(),
142-
bw=np.array(
143-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
144-
),
146+
bw=_DEFAULT_RBF_BANDWIDTHS,
145147
),
146148
variant="biased",
147149
)
@@ -159,9 +161,7 @@ def __init__(
159161
reference_graphs=reference_graphs,
160162
kernel=AdaptiveRBFKernel(
161163
descriptor_fn=OrbitCounts(),
162-
bw=np.array(
163-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
164-
),
164+
bw=_DEFAULT_RBF_BANDWIDTHS,
165165
),
166166
subsample_size=subsample_size,
167167
num_samples=num_samples,
@@ -176,9 +176,7 @@ def __init__(self, reference_graphs: Collection[nx.Graph]):
176176
reference_graphs=reference_graphs,
177177
kernel=AdaptiveRBFKernel(
178178
descriptor_fn=ClusteringHistogram(bins=100),
179-
bw=np.array(
180-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
181-
),
179+
bw=_DEFAULT_RBF_BANDWIDTHS,
182180
),
183181
variant="biased",
184182
)
@@ -196,9 +194,7 @@ def __init__(
196194
reference_graphs=reference_graphs,
197195
kernel=AdaptiveRBFKernel(
198196
descriptor_fn=ClusteringHistogram(bins=100),
199-
bw=np.array(
200-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
201-
),
197+
bw=_DEFAULT_RBF_BANDWIDTHS,
202198
),
203199
subsample_size=subsample_size,
204200
num_samples=num_samples,
@@ -213,9 +209,7 @@ def __init__(self, reference_graphs: Collection[nx.Graph]):
213209
reference_graphs=reference_graphs,
214210
kernel=AdaptiveRBFKernel(
215211
descriptor_fn=SparseDegreeHistogram(),
216-
bw=np.array(
217-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
218-
),
212+
bw=_DEFAULT_RBF_BANDWIDTHS,
219213
),
220214
variant="biased",
221215
)
@@ -233,9 +227,7 @@ def __init__(
233227
reference_graphs=reference_graphs,
234228
kernel=AdaptiveRBFKernel(
235229
descriptor_fn=SparseDegreeHistogram(),
236-
bw=np.array(
237-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
238-
),
230+
bw=_DEFAULT_RBF_BANDWIDTHS,
239231
),
240232
subsample_size=subsample_size,
241233
num_samples=num_samples,
@@ -250,9 +242,7 @@ def __init__(self, reference_graphs: Collection[nx.Graph]):
250242
reference_graphs=reference_graphs,
251243
kernel=AdaptiveRBFKernel(
252244
descriptor_fn=EigenvalueHistogram(),
253-
bw=np.array(
254-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
255-
),
245+
bw=_DEFAULT_RBF_BANDWIDTHS,
256246
),
257247
variant="biased",
258248
)
@@ -270,9 +260,7 @@ def __init__(
270260
reference_graphs=reference_graphs,
271261
kernel=AdaptiveRBFKernel(
272262
descriptor_fn=EigenvalueHistogram(),
273-
bw=np.array(
274-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
275-
),
263+
bw=_DEFAULT_RBF_BANDWIDTHS,
276264
),
277265
subsample_size=subsample_size,
278266
num_samples=num_samples,
@@ -304,9 +292,7 @@ def __init__(
304292
),
305293
reference_graphs,
306294
),
307-
bw=np.array(
308-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
309-
),
295+
bw=_DEFAULT_RBF_BANDWIDTHS,
310296
),
311297
variant="biased",
312298
)
@@ -338,9 +324,7 @@ def __init__(
338324
),
339325
reference_graphs,
340326
),
341-
bw=np.array(
342-
[0.01, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
343-
),
327+
bw=_DEFAULT_RBF_BANDWIDTHS,
344328
),
345329
subsample_size=subsample_size,
346330
num_samples=num_samples,

0 commit comments

Comments
 (0)