Skip to content

Commit a9b382a

Browse files
committed
Updated MMDInterval to use MetricInterval objects
1 parent 8f2de71 commit a9b382a

4 files changed

Lines changed: 28 additions & 66 deletions

File tree

polygraph/metrics/base/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
FrechetDistance,
44
)
55
from polygraph.metrics.base.mmd import (
6-
MMDInterval,
76
DescriptorMMD2,
87
DescriptorMMD2Interval,
98
MaxDescriptorMMD2,
@@ -27,7 +26,6 @@
2726
"MetricInterval",
2827
"FittedFrechetDistance",
2928
"FrechetDistance",
30-
"MMDInterval",
3129
"DescriptorMMD2",
3230
"DescriptorMMD2Interval",
3331
"MaxDescriptorMMD2",

polygraph/metrics/base/metric_interval.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44

55
class MetricInterval:
6+
"""Class for representing uncertainty quantifications of a metric."""
67
def __init__(
78
self,
89
mean: float,
@@ -46,6 +47,18 @@ def from_samples(
4647

4748
return cls(mean=mean, std=std, low=low, high=high, coverage=coverage)
4849

50+
def __getitem__(self, key: str) -> float:
51+
if key == "mean":
52+
return self.mean
53+
elif key == "std":
54+
return self.std
55+
elif key == "low":
56+
return self.low
57+
elif key == "high":
58+
return self.high
59+
else:
60+
raise ValueError(f"Invalid key: {key}")
61+
4962
def __repr__(self):
5063
if self.coverage is not None:
5164
return f"MetricInterval(mean={self.mean}, std={self.std}, low={self.low}, high={self.high}, coverage={self.coverage})"

polygraph/metrics/base/mmd.py

Lines changed: 10 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,16 @@
4848
from polygraph.utils.kernels import DescriptorKernel, GramBlocks
4949
from polygraph.utils.mmd_utils import mmd_from_gram
5050
from polygraph.metrics.base.interfaces import GenerationMetric, GenerationMetricInterval
51+
from polygraph.metrics.base.metric_interval import MetricInterval
5152

5253
__all__ = [
5354
"DescriptorMMD2",
5455
"MaxDescriptorMMD2",
55-
"MMDInterval",
5656
"DescriptorMMD2Interval",
5757
"MaxDescriptorMMD2Interval",
5858
]
5959

6060

61-
MMDInterval = namedtuple("MMDInterval", ["mean", "std", "low", "high"])
62-
63-
6461
class DescriptorMMD2(GenerationMetric):
6562
"""Computes squared MMD between reference and generated graphs using a kernel.
6663
@@ -201,9 +198,9 @@ def _generate_mmd_samples(
201198
def compute(
202199
*args, **kwargs
203200
) -> Union[
204-
MMDInterval,
201+
MetricInterval,
205202
Dict[str, float],
206-
Tuple[Union[MMDInterval, Dict[str, float]], np.ndarray],
203+
Tuple[Union[MetricInterval, Dict[str, float]], np.ndarray],
207204
]: ...
208205

209206

@@ -225,13 +222,7 @@ def compute(
225222
subsample_size: int,
226223
num_samples: int = 500,
227224
coverage: float = 0.95,
228-
as_scalar_value_dict: bool = False,
229-
return_samples: bool = False,
230-
) -> Union[
231-
MMDInterval,
232-
Dict[str, float],
233-
Tuple[Union[MMDInterval, Dict[str, float]], np.ndarray],
234-
]:
225+
) -> MetricInterval:
235226
"""Computes MMD² confidence intervals through subsampling.
236227
237228
Args:
@@ -248,26 +239,8 @@ def compute(
248239
subsample_size=subsample_size,
249240
num_samples=num_samples,
250241
)
251-
low, high = (
252-
np.quantile(mmd_samples, (1 - coverage) / 2, axis=0),
253-
np.quantile(mmd_samples, coverage + (1 - coverage) / 2, axis=0),
254-
)
255-
avg = np.mean(mmd_samples, axis=0)
256-
std = np.std(mmd_samples, axis=0)
257-
if as_scalar_value_dict:
258-
return_result = {
259-
"mean": avg,
260-
"std": std,
261-
"low": low,
262-
"high": high,
263-
}
264-
else:
265-
return_result = MMDInterval(mean=avg, std=std, low=low, high=high)
266-
267-
if return_samples:
268-
return return_result, mmd_samples
269-
else:
270-
return return_result
242+
assert mmd_samples.ndim == 1
243+
return MetricInterval.from_samples(mmd_samples, coverage=coverage)
271244

272245

273246
class MaxDescriptorMMD2Interval(_DescriptorMMD2Interval, GenerationMetricInterval):
@@ -306,13 +279,7 @@ def compute(
306279
subsample_size: int,
307280
num_samples: int = 500,
308281
coverage: float = 0.95,
309-
as_scalar_value_dict: bool = False,
310-
return_samples: bool = False,
311-
) -> Union[
312-
MMDInterval,
313-
Dict[str, float],
314-
Tuple[Union[MMDInterval, Dict[str, float]], np.ndarray],
315-
]:
282+
) -> MetricInterval:
316283
"""Computes confidence intervals for maximum MMD² through subsampling.
317284
318285
Args:
@@ -331,24 +298,6 @@ def compute(
331298
)
332299
assert mmd_samples.ndim == 2
333300
mmd_samples = np.max(mmd_samples, axis=1)
334-
low, high = (
335-
np.quantile(mmd_samples, (1 - coverage) / 2, axis=0),
336-
np.quantile(mmd_samples, coverage + (1 - coverage) / 2, axis=0),
337-
)
338-
avg = np.mean(mmd_samples, axis=0)
339-
std = np.std(mmd_samples, axis=0)
340-
341-
if as_scalar_value_dict:
342-
return_result = {
343-
"mean": avg,
344-
"std": std,
345-
"low": low,
346-
"high": high,
347-
}
348-
else:
349-
return_result = MMDInterval(mean=avg, std=std, low=low, high=high)
350-
351-
if return_samples:
352-
return return_result, mmd_samples
353-
else:
354-
return return_result
301+
return MetricInterval.from_samples(mmd_samples, coverage=coverage)
302+
303+

tests/test_mmd.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
DescriptorMMD2Interval,
1717
MaxDescriptorMMD2,
1818
MaxDescriptorMMD2Interval,
19-
MMDInterval,
2019
)
2120
from polygraph.metrics.gran import (
2221
GRANClusteringMMD2,
@@ -49,6 +48,8 @@
4948
from polygraph.utils.kernels import LinearKernel
5049
from polygraph.utils.graph_descriptors import WeisfeilerLehmanDescriptor
5150
from polygraph.utils.mmd_utils import mmd_from_gram
51+
from polygraph.metrics.base.metric_interval import MetricInterval
52+
5253
import grakel
5354

5455

@@ -191,7 +192,7 @@ def test_mmd_uncertainty(request, datasets, kernel, subsample_size, variant):
191192
kernel = request.getfixturevalue(kernel)
192193
mmd = DescriptorMMD2Interval(sbm, kernel, variant=variant)
193194
result = mmd.compute(planar, subsample_size=subsample_size)
194-
assert isinstance(result, MMDInterval)
195+
assert isinstance(result, MetricInterval)
195196
assert result.std > 0
196197

197198
rng = np.random.default_rng(42)
@@ -243,7 +244,7 @@ def test_concrete_uncertainty(
243244

244245
interval_mmd = interval_cls(planar)
245246
interval = interval_mmd.compute(sbm, subsample_size=subsample_size)
246-
assert isinstance(interval, MMDInterval)
247+
assert isinstance(interval, MetricInterval)
247248

248249
num_in_bounds = 0
249250
num_total = 10
@@ -258,6 +259,7 @@ def test_concrete_uncertainty(
258259

259260
single_mmd = single_cls(planar_samples)
260261
single_estimate = single_mmd.compute(sbm_samples)
262+
assert interval.low <= interval.high
261263
if interval.low <= single_estimate <= interval.high:
262264
num_in_bounds += 1
263265

0 commit comments

Comments
 (0)