Skip to content

Commit 41c6d57

Browse files
committed
Updated type hints
1 parent dbefab1 commit 41c6d57

15 files changed

Lines changed: 77 additions & 72 deletions

File tree

polygraph/datasets/base/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def url_for_split(self, split: str) -> str:
370370
...
371371

372372
@abstractmethod
373-
def hash_for_split(self, split: str) -> str:
373+
def hash_for_split(self, split: str) -> Optional[str]:
374374
"""Gets the expected hash for a specific split's data.
375375
376376
This hash is used to validate downloaded data.

polygraph/datasets/ego.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional
12
import networkx as nx
23

34
from polygraph.datasets.base import OnlineGraphDataset
@@ -42,7 +43,7 @@ def url_for_split(self, split: str):
4243
def is_valid(graph: nx.Graph):
4344
return graph.number_of_nodes() > 0 and graph.number_of_edges() > 0
4445

45-
def hash_for_split(self, split: str) -> str:
46+
def hash_for_split(self, split: str) -> Optional[str]:
4647
return self._HASH_FOR_SPLIT[split]
4748

4849

@@ -78,5 +79,5 @@ def url_for_split(self, split: str):
7879
def is_valid(graph: nx.Graph):
7980
return graph.number_of_nodes() > 0 and graph.number_of_edges() > 0
8081

81-
def hash_for_split(self, split: str) -> str:
82+
def hash_for_split(self, split: str) -> Optional[str]:
8283
return self._HASH_FOR_SPLIT[split]

polygraph/datasets/lobster.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ def is_lobster_graph(graph: nx.Graph) -> bool:
1616
"""Based on https://github.com/lrjconan/GRAN/blob/fc9c04a3f002c55acf892f864c03c6040947bc6b/utils/eval_helper.py#L426C3-L446C17"""
1717
graph = deepcopy(graph)
1818
if nx.is_tree(graph):
19-
leaves = [n for n, d in graph.degree() if d == 1]
19+
leaves = [n for n, d in graph.degree() if d == 1] # pyright: ignore
2020
graph.remove_nodes_from(leaves)
2121

22-
leaves = [n for n, d in graph.degree() if d == 1]
22+
leaves = [n for n, d in graph.degree() if d == 1] # pyright: ignore
2323
graph.remove_nodes_from(leaves)
2424

2525
num_nodes = len(graph.nodes())
26-
num_degree_one = [d for n, d in graph.degree() if d == 1]
27-
num_degree_two = [d for n, d in graph.degree() if d == 2]
26+
num_degree_one = [d for n, d in graph.degree() if d == 1] # pyright: ignore
27+
num_degree_two = [d for n, d in graph.degree() if d == 2] # pyright: ignore
2828

2929
if sum(num_degree_one) == 2 and sum(num_degree_two) == 2 * (
3030
num_nodes - 2
@@ -166,5 +166,5 @@ def is_valid(graph: nx.Graph) -> bool:
166166
"""Check if a graph is a valid lobster graph."""
167167
return is_lobster_graph(graph)
168168

169-
def hash_for_split(self, split: str) -> str:
169+
def hash_for_split(self, split: str) -> Optional[str]:
170170
return self._HASH_FOR_SPLIT[split]

polygraph/datasets/modelnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional
12
import networkx as nx
23

34
from polygraph.datasets.base import OnlineGraphDataset
@@ -44,5 +45,5 @@ def url_for_split(self, split: str):
4445
def is_valid(graph: nx.Graph):
4546
return graph.number_of_nodes() > 0 and graph.number_of_edges() > 0
4647

47-
def hash_for_split(self, split: str) -> str:
48+
def hash_for_split(self, split: str) -> Optional[str]:
4849
return self._HASH_FOR_SPLIT[split]

polygraph/datasets/planar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal
1+
from typing import Literal, Optional
22

33
import joblib
44
import networkx as nx
@@ -37,7 +37,7 @@ def __init__(
3737
memmap: bool = False,
3838
show_generation_progress: bool = False,
3939
):
40-
config_hash = joblib.hash(
40+
config_hash: str = joblib.hash(
4141
(num_graphs, n_nodes, seed, split), hash_name="md5"
4242
)
4343
self._rng = np.random.default_rng(
@@ -132,5 +132,5 @@ def is_valid(graph: nx.Graph) -> bool:
132132
"""
133133
return is_planar_graph(graph)
134134

135-
def hash_for_split(self, split: str) -> str:
135+
def hash_for_split(self, split: str) -> Optional[str]:
136136
return self._HASH_FOR_SPLIT[split]

polygraph/datasets/point_clouds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional
12
import networkx as nx
23

34
from polygraph.datasets.base import OnlineGraphDataset
@@ -43,5 +44,5 @@ def url_for_split(self, split: str):
4344
def is_valid(graph: nx.Graph):
4445
return graph.number_of_nodes() > 0 and graph.number_of_edges() > 0
4546

46-
def hash_for_split(self, split: str) -> str:
47+
def hash_for_split(self, split: str) -> Optional[str]:
4748
return self._HASH_FOR_SPLIT[split]

polygraph/datasets/proteins.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Optional
12
import networkx as nx
23

34
from polygraph.datasets.base import OnlineGraphDataset
@@ -55,5 +56,5 @@ def url_for_split(self, split: str):
5556
def is_valid(graph: nx.Graph):
5657
return graph.number_of_nodes() > 0 and graph.number_of_edges() > 0
5758

58-
def hash_for_split(self, split: str) -> str:
59+
def hash_for_split(self, split: str) -> Optional[str]:
5960
return self._HASH_FOR_SPLIT[split]

polygraph/datasets/sbm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal, Tuple
1+
from typing import Literal, Tuple, Optional
22

33
import joblib
44
import networkx as nx
@@ -220,7 +220,7 @@ def __init__(
220220
memmap: bool = False,
221221
show_generation_progress: bool = False,
222222
):
223-
config_hash = joblib.hash(
223+
config_hash: str = joblib.hash(
224224
(
225225
num_graphs,
226226
intra_p,
@@ -358,5 +358,5 @@ def is_valid(graph: nx.Graph) -> bool:
358358
def is_valid_alt(graph: nx.Graph) -> bool:
359359
return is_sbm_graph_alt(graph)
360360

361-
def hash_for_split(self, split: str) -> str:
361+
def hash_for_split(self, split: str) -> Optional[str]:
362362
return self._HASH_FOR_SPLIT[split]

polygraph/metrics/base/metric_interval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def from_samples(
4747

4848
return cls(mean=mean, std=std, low=low, high=high, coverage=coverage)
4949

50-
def __getitem__(self, key: str) -> float:
50+
def __getitem__(self, key: str) -> Optional[float]:
5151
if key == "mean":
5252
return self.mean
5353
elif key == "std":

polygraph/metrics/base/mmd.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@
3939
"""
4040

4141
from abc import ABC, abstractmethod
42-
from collections import namedtuple
43-
from typing import Collection, Dict, Literal, Tuple, Union
42+
from typing import Any, Collection, Literal, Union
4443

4544
import networkx as nx
4645
import numpy as np
@@ -67,6 +66,8 @@ class DescriptorMMD2(GenerationMetric):
6766
variant: Which MMD estimator to use ('biased', 'umve', or 'ustat')
6867
"""
6968

69+
_variant: Literal["biased", "umve", "ustat"]
70+
7071
def __init__(
7172
self,
7273
reference_graphs: Collection[nx.Graph],
@@ -136,13 +137,16 @@ def compute(self, generated_graphs: Collection[nx.Graph]) -> float:
136137
Maximum MMD² value across kernel parameters
137138
"""
138139
multi_kernel_result = super().compute(generated_graphs)
140+
assert isinstance(multi_kernel_result, np.ndarray)
139141
idx = int(np.argmax(multi_kernel_result))
140142
return multi_kernel_result[idx]
141143

142144

143145
class _DescriptorMMD2Interval(ABC):
144146
"""Base class for computing MMD² confidence intervals through subsampling."""
145147

148+
_variant: Literal["biased", "umve", "ustat"]
149+
146150
def __init__(
147151
self,
148152
reference_graphs: Collection[nx.Graph],

0 commit comments

Comments
 (0)