Commit b09b6d4

Renamed RBF and TV benchmarks
1 parent b218ac3 commit b09b6d4

8 files changed: 45 additions & 45 deletions

docs/metrics/gaussian_tv_mmd.md
Lines changed: 2 additions & 2 deletions

@@ -11,14 +11,14 @@
 
 ## Summary Benchmark
 
-::: polygraph.metrics.MMD2CollectionGaussianTV
+::: polygraph.metrics.GaussianTVMMD2Benchmark
     options:
       show_root_heading: true
       show_full_path: true
       show_source: false
       heading_level: 3
 
-::: polygraph.metrics.MMD2IntervalCollectionGaussianTV
+::: polygraph.metrics.GaussianTVMMD2BenchmarkInterval
     options:
       show_root_heading: true
       show_full_path: true

docs/metrics/rbf_mmd.md
Lines changed: 2 additions & 2 deletions

@@ -11,14 +11,14 @@
 
 ## Summary Benchmark
 
-::: polygraph.metrics.MMD2CollectionRBF
+::: polygraph.metrics.RBFMMD2Benchmark
    options:
      show_root_heading: true
      show_full_path: true
      show_source: false
      heading_level: 3
 
-::: polygraph.metrics.MMD2IntervalCollectionRBF
+::: polygraph.metrics.RBFMMD2BenchmarkInterval
    options:
      show_root_heading: true
      show_full_path: true

docs/tutorials/basic_usage.md
Lines changed: 4 additions & 4 deletions

@@ -30,22 +30,22 @@ print(planar_nx[0]) # (Networkx) Graph with 64 nodes and 177 edges
 When evaluating graph generative models, we want to compare a set of *generated* graphs to a set of *reference* graphs (typically the test set).
 In `polygraph`, we provide various different metrics to quantify how similar these two sets of graphs are.
 We usually pass collections of NetworkX graphs to metrics.
-Below, we demonstrate how a set of these metrics, combined in the [`MMD2CollectionGaussianTV`][polygraph.metrics.MMD2CollectionGaussianTV] benchmark may be computed:
+Below, we demonstrate how a set of these metrics, combined in the [`GaussianTVMMD2Benchmark`][polygraph.metrics.GaussianTVMMD2Benchmark] benchmark may be computed:
 
 ```python
-from polygraph.metrics import MMD2CollectionGaussianTV
+from polygraph.metrics import GaussianTVMMD2Benchmark
 
 reference = planar.to_nx()
 generated = sbm.to_nx()
 
-benchmark = MMD2CollectionGaussianTV(reference)
+benchmark = GaussianTVMMD2Benchmark(reference)
 print(benchmark.compute(generated)) # Dictionary of different metrics
 ```
 
 We discuss available metrics [in the next tutorial](metrics_overview.md).
 
 All metrics are evaluated in a similar fashion, as defined by the common [interface](../api_reference/metrics/interface.md):
 
-- We first initialize a metric object via `benchmark = MMD2CollectionGaussianTV(reference)`. This fits the metric to the `reference` set, caching data that is required in later computations
+- We first initialize a metric object via `benchmark = GaussianTVMMD2Benchmark(reference)`. This fits the metric to the `reference` set, caching data that is required in later computations
 - We then compute the metric against the generated set via `benchmark.compute(generated)`
 - We may call `benchmark.compute` repeatedly with different generated sets, e.g. over the course of training
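The bullet points above describe a fit-once, compute-many interface. As a minimal sketch of that pattern, independent of `polygraph` itself, the class below is a hypothetical stand-in (not a real `polygraph` class) that replaces the MMD² computation with a squared difference of means:

```python
# Minimal sketch of the fit/compute interface described above.
# ToyMMD2Benchmark is a hypothetical stand-in, NOT part of polygraph.
class ToyMMD2Benchmark:
    def __init__(self, reference):
        # Initialization "fits" the metric: reference statistics are
        # cached once so repeated compute() calls can reuse them.
        self._ref_mean = sum(reference) / len(reference)

    def compute(self, generated):
        gen_mean = sum(generated) / len(generated)
        # Squared mean discrepancy as a toy stand-in for MMD^2.
        return {"toy_mmd2": (gen_mean - self._ref_mean) ** 2}


benchmark = ToyMMD2Benchmark([1.0, 2.0, 3.0])   # fit once
print(benchmark.compute([2.0, 4.0]))            # {'toy_mmd2': 1.0}
print(benchmark.compute([1.0, 3.0]))            # {'toy_mmd2': 0.0}
```

As the last bullet suggests, repeated `compute` calls with different generated sets pay the fitting cost only once.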

docs/tutorials/metrics_overview.md
Lines changed: 6 additions & 6 deletions

@@ -9,16 +9,16 @@ For convenience, `polygraph` allows metrics that follow this interface to be bundled
 
 ```python
 from polygraph.metrics import MetricCollection
-from polygraph.metrics import MMD2CollectionRBF, MMD2CollectionGaussianTV
+from polygraph.metrics import RBFMMD2Benchmark, GaussianTVMMD2Benchmark
 from polygraph.datasets import PlanarGraphDataset, SBMGraphDataset
 
 reference_graphs = PlanarGraphDataset("val").to_nx()
 generated_graphs = SBMGraphDataset("val").to_nx()
 
 metrics = MetricCollection(
     metrics={
-        "rbf_mmd": MMD2CollectionRBF(reference_graphs),
-        "tv_mmd": MMD2CollectionGaussianTV(reference_graphs),
+        "rbf_mmd": RBFMMD2Benchmark(reference_graphs),
+        "tv_mmd": GaussianTVMMD2Benchmark(reference_graphs),
     }
 )
 print(metrics.compute(generated_graphs)) # Dictionary of metrics

@@ -31,17 +31,17 @@ We now proceed to give a high-level overview over the different types of metrics
 [Maximum Mean Discrepancy (MMD)](../api_reference/metrics/mmd.md) is the predominant method for comparing graph distributions.
 The two distributions are embedded in a reproducing kernel Hilbert space (RKHS) and their distance is then computed in this space.
 
-In `polygraph`, we bundle the most commonly used MMD metrics in two benchmark classes: [`MMD2CollectionGaussianTV`][polygraph.metrics.MMD2CollectionGaussianTV] and [`MMD2CollectionRBF`][polygraph.metrics.MMD2CollectionRBF]. These benchmarks may be evaluated in the following fashion:
+In `polygraph`, we bundle the most commonly used MMD metrics in two benchmark classes: [`GaussianTVMMD2Benchmark`][polygraph.metrics.GaussianTVMMD2Benchmark] and [`RBFMMD2Benchmark`][polygraph.metrics.RBFMMD2Benchmark]. These benchmarks may be evaluated in the following fashion:
 
 ```python
 from polygraph.datasets import PlanarGraphDataset, SBMGraphDataset
-from polygraph.metrics import MMD2CollectionGaussianTV, MMD2IntervalCollectionGaussianTV
+from polygraph.metrics import GaussianTVMMD2Benchmark, GaussianTVMMD2BenchmarkInterval
 
 reference = PlanarGraphDataset("val").to_nx()
 generated = SBMGraphDataset("val").to_nx()
 
 # Evaluate the benchmark with point estimates
-benchmark = MMD2CollectionGaussianTV(reference)
+benchmark = GaussianTVMMD2Benchmark(reference)
 print(benchmark.compute(generated)) # {'orbit': 1.067700488335175, 'clustering': 0.32549637224264394, 'degree': 0.3375409762261701, 'spectral': 0.0830197437100697}
 ```
 
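The sample output in the tutorial hunk above shows that `compute` returns a plain dictionary keyed by graph descriptor, so downstream code can post-process results like any mapping. A small sketch with made-up values (the numbers below are illustrative, not real benchmark output):

```python
# Hypothetical per-descriptor MMD^2 values, shaped like the
# dictionary returned by benchmark.compute(generated).
results = {
    "orbit": 1.0677,
    "clustering": 0.3255,
    "degree": 0.3375,
    "spectral": 0.0830,
}

# Aggregate into a single scalar, e.g. for logging during training.
mean_mmd2 = sum(results.values()) / len(results)

# Identify the descriptor on which the generated graphs deviate most.
worst = max(results, key=results.get)

print(f"mean MMD2 = {mean_mmd2:.4f}, worst descriptor = {worst}")
```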

polygraph/metrics/__init__.py
Lines changed: 7 additions & 7 deletions

@@ -1,19 +1,19 @@
 from .base import MetricCollection
 from .polygraphscore import PGS5, PGS5Interval
 from .gaussian_tv_mmd import (
-    MMD2CollectionGaussianTV,
-    MMD2IntervalCollectionGaussianTV,
+    GaussianTVMMD2Benchmark,
+    GaussianTVMMD2BenchmarkInterval,
 )
-from .rbf_mmd import MMD2CollectionRBF, MMD2IntervalCollectionRBF
+from .rbf_mmd import RBFMMD2Benchmark, RBFMMD2BenchmarkInterval
 from .vun import VUN
 
 __all__ = [
     "VUN",
     "MetricCollection",
     "PGS5",
     "PGS5Interval",
-    "MMD2CollectionGaussianTV",
-    "MMD2IntervalCollectionGaussianTV",
-    "MMD2CollectionRBF",
-    "MMD2IntervalCollectionRBF",
+    "GaussianTVMMD2Benchmark",
+    "GaussianTVMMD2BenchmarkInterval",
+    "RBFMMD2Benchmark",
+    "RBFMMD2BenchmarkInterval",
 ]
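This commit removes the old names outright, so any downstream `from polygraph.metrics import MMD2CollectionGaussianTV` breaks immediately. A softer migration could keep deprecated aliases alongside the new classes; the helper below is a sketch of that idea and is not part of this commit (`deprecated_alias` and the stand-in class are hypothetical):

```python
import warnings


def deprecated_alias(new_cls, old_name):
    """Return a subclass of new_cls that warns when the old name is used."""
    class _Alias(new_cls):
        def __init__(self, *args, **kwargs):
            warnings.warn(
                f"{old_name} was renamed to {new_cls.__name__}",
                DeprecationWarning,
                stacklevel=2,
            )
            super().__init__(*args, **kwargs)

    _Alias.__name__ = old_name
    _Alias.__qualname__ = old_name
    return _Alias


# Stand-in for the real benchmark class, for illustration only.
class GaussianTVMMD2Benchmark:
    def __init__(self, reference_graphs):
        self.reference_graphs = reference_graphs


# The old name keeps working, but emits a DeprecationWarning on use.
MMD2CollectionGaussianTV = deprecated_alias(
    GaussianTVMMD2Benchmark, "MMD2CollectionGaussianTV"
)
```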

polygraph/metrics/gaussian_tv_mmd.py
Lines changed: 10 additions & 10 deletions

@@ -17,23 +17,23 @@
 
 
 Below, we demonstrate how to evaluate all metrics in the benchmark with point estimates and with uncertainty quantification.
-Note that the parameter `subsample_size` in [`MMD2IntervalCollectionGaussianTV`][polygraph.metrics.MMD2IntervalCollectionGaussianTV]
-should match the number of generated and reference graphs in [`MMD2CollectionGaussianTV`][polygraph.metrics.MMD2CollectionGaussianTV]
+Note that the parameter `subsample_size` in [`GaussianTVMMD2BenchmarkInterval`][polygraph.metrics.GaussianTVMMD2BenchmarkInterval]
+should match the number of generated and reference graphs in [`GaussianTVMMD2Benchmark`][polygraph.metrics.GaussianTVMMD2Benchmark]
 to obtain comparable results:
 
 ```python
 from polygraph.datasets import PlanarGraphDataset, SBMGraphDataset
-from polygraph.metrics import MMD2CollectionGaussianTV, MMD2IntervalCollectionGaussianTV
+from polygraph.metrics import GaussianTVMMD2Benchmark, GaussianTVMMD2BenchmarkInterval
 
 reference = list(PlanarGraphDataset("val").to_nx())
 generated = list(SBMGraphDataset("val").to_nx())
 
 # Evaluate the benchmark with point estimates
-benchmark = MMD2CollectionGaussianTV(reference[:20])
+benchmark = GaussianTVMMD2Benchmark(reference[:20])
 print(benchmark.compute(generated[:20]))
 
 # Evaluate the benchmark with uncertainty quantification
-benchmark_with_uncertainty = MMD2IntervalCollectionGaussianTV(
+benchmark_with_uncertainty = GaussianTVMMD2BenchmarkInterval(
     reference,
     subsample_size=20,
     num_samples=100,

@@ -71,8 +71,8 @@
 
 
 __all__ = [
-    "MMD2CollectionGaussianTV",
-    "MMD2IntervalCollectionGaussianTV",
+    "GaussianTVMMD2Benchmark",
+    "GaussianTVMMD2BenchmarkInterval",
     "GaussianTVOrbitMMD2",
     "GaussianTVOrbitMMD2Interval",
     "GaussianTVClusteringMMD2",

@@ -84,7 +84,7 @@
 ]
 
 
-class MMD2CollectionGaussianTV(MetricCollection):
+class GaussianTVMMD2Benchmark(MetricCollection):
     """Collection of MMD2 metrics using the Gaussian TV kernel.
 
     This graphs combines the following graph descriptors into one benchmark:

@@ -109,10 +109,10 @@ def __init__(self, reference_graphs: Collection[nx.Graph]):
         )
 
 
-class MMD2IntervalCollectionGaussianTV(MetricCollection):
+class GaussianTVMMD2BenchmarkInterval(MetricCollection):
     """Collection of MMD2 metrics using the Gaussian TV kernel with uncertainty quantification.
 
-    This class provides the same metrics as [`MMD2CollectionGaussianTV`][polygraph.metrics.MMD2CollectionGaussianTV] but with uncertainty quantification.
+    This class provides the same metrics as [`GaussianTVMMD2Benchmark`][polygraph.metrics.GaussianTVMMD2Benchmark] but with uncertainty quantification.
 
     Args:
         reference_graphs: Collection of reference graphs to fit the metric to.
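The docstring above preserves an important caveat: interval estimates are computed from repeated subsamples, so they are only comparable to point estimates computed from the same number of graphs. The subsampling loop can be sketched without `polygraph`; the statistic below (absolute mean difference over lists of floats) and the helper `interval_estimate` are hypothetical stand-ins, not the library's implementation:

```python
import random
import statistics

def interval_estimate(reference, generated, subsample_size, num_samples, seed=0):
    """Toy percentile interval: repeatedly subsample both sets and
    collect a discrepancy statistic (here a stand-in for MMD^2)."""
    rng = random.Random(seed)
    stats = []
    for _ in range(num_samples):
        ref = rng.sample(reference, subsample_size)
        gen = rng.sample(generated, subsample_size)
        stats.append(abs(statistics.mean(ref) - statistics.mean(gen)))
    stats.sort()
    # Central 95% of the subsampled statistics.
    return stats[int(0.025 * num_samples)], stats[int(0.975 * num_samples)]

# Synthetic "reference" and "generated" samples with shifted means.
rng_ref, rng_gen = random.Random(1), random.Random(2)
reference = [rng_ref.gauss(0.0, 1.0) for _ in range(200)]
generated = [rng_gen.gauss(0.5, 1.0) for _ in range(200)]

# Mirror the parameters in the docstring: subsample_size=20, num_samples=100.
lo, hi = interval_estimate(reference, generated, subsample_size=20, num_samples=100)
print(f"95% interval: [{lo:.3f}, {hi:.3f}]")
```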

polygraph/metrics/rbf_mmd.py
Lines changed: 7 additions & 7 deletions

@@ -13,17 +13,17 @@
 
 ```python
 from polygraph.datasets import PlanarGraphDataset, SBMGraphDataset
-from polygraph.metrics import MMD2CollectionRBF, MMD2IntervalCollectionRBF
+from polygraph.metrics import RBFMMD2Benchmark, RBFMMD2BenchmarkInterval
 
 reference = list(PlanarGraphDataset("val").to_nx())
 generated = list(SBMGraphDataset("val").to_nx())
 
 # Evaluate the benchmark with point estimates
-benchmark = MMD2CollectionRBF(reference[:20])
+benchmark = RBFMMD2Benchmark(reference[:20])
 print(benchmark.compute(generated[:20]))
 
 # Evaluate the benchmark with uncertainty quantification
-benchmark_with_uncertainty = MMD2IntervalCollectionRBF(
+benchmark_with_uncertainty = RBFMMD2BenchmarkInterval(
     reference,
     subsample_size=20,
     num_samples=100,

@@ -58,8 +58,8 @@
 from polygraph.metrics.base import MetricCollection
 
 __all__ = [
-    "MMD2CollectionRBF",
-    "MMD2IntervalCollectionRBF",
+    "RBFMMD2Benchmark",
+    "RBFMMD2BenchmarkInterval",
     "RBFOrbitMMD2",
     "RBFOrbitMMD2Interval",
     "RBFClusteringMMD2",

@@ -73,7 +73,7 @@
 ]
 
 
-class MMD2CollectionRBF(MetricCollection):
+class RBFMMD2Benchmark(MetricCollection):
     """Collection of MMD2 metrics using RBF kernels with dynamic bandwidths."""
 
     def __init__(self, reference_graphs: Collection[nx.Graph]):

@@ -88,7 +88,7 @@ def __init__(self, reference_graphs: Collection[nx.Graph]):
         )
 
 
-class MMD2IntervalCollectionRBF(MetricCollection):
+class RBFMMD2BenchmarkInterval(MetricCollection):
     """Collection of MMD2 metrics using RBF kernels with dynamic bandwidths and uncertainty quantification."""
 
     def __init__(

tests/test_mmd.py
Lines changed: 7 additions & 7 deletions

@@ -39,10 +39,10 @@
     RBFGraphNeuralNetworkMMD2,
 )
 from polygraph.metrics import (
-    MMD2CollectionGaussianTV,
-    MMD2IntervalCollectionGaussianTV,
+    GaussianTVMMD2Benchmark,
+    GaussianTVMMD2BenchmarkInterval,
 )
-from polygraph.metrics import MMD2CollectionRBF, MMD2IntervalCollectionRBF
+from polygraph.metrics import RBFMMD2Benchmark, RBFMMD2BenchmarkInterval
 from polygraph.utils.kernels import LinearKernel
 from polygraph.utils.graph_descriptors import WeisfeilerLehmanDescriptor
 from polygraph.utils.mmd_utils import mmd_from_gram

@@ -332,15 +332,15 @@ def test_mmd_collections(datasets, variant):
             "spectral": RBFSpectralMMD2(planar),
             "gin": RBFGraphNeuralNetworkMMD2(planar),
         }
-        benchmark = MMD2CollectionRBF(planar)
+        benchmark = RBFMMD2Benchmark(planar)
     elif variant == "gaussian_tv":
         separate_metrics = {
            "orbit": GaussianTVOrbitMMD2(planar),
            "clustering": GaussianTVClusteringMMD2(planar),
            "degree": GaussianTVDegreeMMD2(planar),
            "spectral": GaussianTVSpectralMMD2(planar),
         }
-        benchmark = MMD2CollectionGaussianTV(planar)
+        benchmark = GaussianTVMMD2Benchmark(planar)
     else:
         raise ValueError(f"Invalid variant: {variant}")
 

@@ -357,9 +357,9 @@ def test_mmd_collections(datasets, variant):
     )
 
     if variant == "rbf":
-        metric = MMD2IntervalCollectionRBF(planar, subsample_size=16)
+        metric = RBFMMD2BenchmarkInterval(planar, subsample_size=16)
     elif variant == "gaussian_tv":
-        metric = MMD2IntervalCollectionGaussianTV(planar, subsample_size=16)
+        metric = GaussianTVMMD2BenchmarkInterval(planar, subsample_size=16)
     else:
         raise ValueError(f"Invalid variant: {variant}")
 
