4848from polygraph .utils .kernels import DescriptorKernel , GramBlocks
4949from polygraph .utils .mmd_utils import mmd_from_gram
5050from polygraph .metrics .base .interfaces import GenerationMetric , GenerationMetricInterval
51+ from polygraph .metrics .base .metric_interval import MetricInterval
5152
5253__all__ = [
5354 "DescriptorMMD2" ,
5455 "MaxDescriptorMMD2" ,
55- "MMDInterval" ,
5656 "DescriptorMMD2Interval" ,
5757 "MaxDescriptorMMD2Interval" ,
5858]
5959
6060
61- MMDInterval = namedtuple ("MMDInterval" , ["mean" , "std" , "low" , "high" ])
62-
63-
6461class DescriptorMMD2 (GenerationMetric ):
6562 """Computes squared MMD between reference and generated graphs using a kernel.
6663
@@ -201,9 +198,9 @@ def _generate_mmd_samples(
201198 def compute (
202199 * args , ** kwargs
203200 ) -> Union [
204- MMDInterval ,
201+ MetricInterval ,
205202 Dict [str , float ],
206- Tuple [Union [MMDInterval , Dict [str , float ]], np .ndarray ],
203+ Tuple [Union [MetricInterval , Dict [str , float ]], np .ndarray ],
207204 ]: ...
208205
209206
@@ -225,13 +222,7 @@ def compute(
225222 subsample_size : int ,
226223 num_samples : int = 500 ,
227224 coverage : float = 0.95 ,
228- as_scalar_value_dict : bool = False ,
229- return_samples : bool = False ,
230- ) -> Union [
231- MMDInterval ,
232- Dict [str , float ],
233- Tuple [Union [MMDInterval , Dict [str , float ]], np .ndarray ],
234- ]:
225+ ) -> MetricInterval :
235226 """Computes MMD² confidence intervals through subsampling.
236227
237228 Args:
@@ -248,26 +239,8 @@ def compute(
248239 subsample_size = subsample_size ,
249240 num_samples = num_samples ,
250241 )
251- low , high = (
252- np .quantile (mmd_samples , (1 - coverage ) / 2 , axis = 0 ),
253- np .quantile (mmd_samples , coverage + (1 - coverage ) / 2 , axis = 0 ),
254- )
255- avg = np .mean (mmd_samples , axis = 0 )
256- std = np .std (mmd_samples , axis = 0 )
257- if as_scalar_value_dict :
258- return_result = {
259- "mean" : avg ,
260- "std" : std ,
261- "low" : low ,
262- "high" : high ,
263- }
264- else :
265- return_result = MMDInterval (mean = avg , std = std , low = low , high = high )
266-
267- if return_samples :
268- return return_result , mmd_samples
269- else :
270- return return_result
242+ assert mmd_samples .ndim == 1
243+ return MetricInterval .from_samples (mmd_samples , coverage = coverage )
271244
272245
273246class MaxDescriptorMMD2Interval (_DescriptorMMD2Interval , GenerationMetricInterval ):
@@ -306,13 +279,7 @@ def compute(
306279 subsample_size : int ,
307280 num_samples : int = 500 ,
308281 coverage : float = 0.95 ,
309- as_scalar_value_dict : bool = False ,
310- return_samples : bool = False ,
311- ) -> Union [
312- MMDInterval ,
313- Dict [str , float ],
314- Tuple [Union [MMDInterval , Dict [str , float ]], np .ndarray ],
315- ]:
282+ ) -> MetricInterval :
316283 """Computes confidence intervals for maximum MMD² through subsampling.
317284
318285 Args:
@@ -331,24 +298,6 @@ def compute(
331298 )
332299 assert mmd_samples .ndim == 2
333300 mmd_samples = np .max (mmd_samples , axis = 1 )
334- low , high = (
335- np .quantile (mmd_samples , (1 - coverage ) / 2 , axis = 0 ),
336- np .quantile (mmd_samples , coverage + (1 - coverage ) / 2 , axis = 0 ),
337- )
338- avg = np .mean (mmd_samples , axis = 0 )
339- std = np .std (mmd_samples , axis = 0 )
340-
341- if as_scalar_value_dict :
342- return_result = {
343- "mean" : avg ,
344- "std" : std ,
345- "low" : low ,
346- "high" : high ,
347- }
348- else :
349- return_result = MMDInterval (mean = avg , std = std , low = low , high = high )
350-
351- if return_samples :
352- return return_result , mmd_samples
353- else :
354- return return_result
301+ return MetricInterval .from_samples (mmd_samples , coverage = coverage )
302+
303+
0 commit comments