3939"""
4040
4141from abc import ABC , abstractmethod
42- from collections import namedtuple
43- from typing import Collection , Dict , Literal , Tuple , Union
42+ from typing import Any , Collection , Literal , Union
4443
4544import networkx as nx
4645import numpy as np
@@ -67,6 +66,8 @@ class DescriptorMMD2(GenerationMetric):
6766 variant: Which MMD estimator to use ('biased', 'umve', or 'ustat')
6867 """
6968
69+ _variant : Literal ["biased" , "umve" , "ustat" ]
70+
7071 def __init__ (
7172 self ,
7273 reference_graphs : Collection [nx .Graph ],
@@ -136,13 +137,16 @@ def compute(self, generated_graphs: Collection[nx.Graph]) -> float:
136137 Maximum MMD² value across kernel parameters
137138 """
138139 multi_kernel_result = super ().compute (generated_graphs )
140+ assert isinstance (multi_kernel_result , np .ndarray )
139141 idx = int (np .argmax (multi_kernel_result ))
140142 return multi_kernel_result [idx ]
141143
142144
143145class _DescriptorMMD2Interval (ABC ):
144146 """Base class for computing MMD² confidence intervals through subsampling."""
145147
148+ _variant : Literal ["biased" , "umve" , "ustat" ]
149+
146150 def __init__ (
147151 self ,
148152 reference_graphs : Collection [nx .Graph ],
0 commit comments