Skip to content

Commit 17da595

Browse files
authored
Verify indirect deps reachable on incremental run (#20735)
Fixes #19477. This is another surprising correctness hole in the indirect dependencies logic. We implicitly assume everywhere that indirect dependencies are a subset of transitive ones. Although this is true when the indirect dependencies are initially calculated, there is no guarantee it will stay true after an incremental update where some import structure has changed. Currently, we only handle the situation where an indirect dependency is removed from the graph completely, but if it is still in the graph and is now brought in by a _different_ dependent, this may cause out-of-order cache loading, and thus a bug or a crash. It is interesting that although in theory it is kind of a big deal, in practice it is very hard to create such a scenario. So far I have found only two complicated scenarios (see tests). Implementation note: although this additional check is needed for rare edge cases, it is quite computationally expensive. I tried a few options (including only computing import deltas), but the only viable/robust option seems to be an honest recursive check with some caching. I tested this on `torch` (it has ~1000 small SCCs plus a single SCC with around 1000 modules), and the performance penalty is negligible there. So hopefully it should be fine even for large code bases.
1 parent 9752e19 commit 17da595

6 files changed

Lines changed: 188 additions & 31 deletions

File tree

mypy/build.py

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,9 @@ def __init__(
870870
# Mapping from SCC id to corresponding SCC instance. This is populated
871871
# in process_graph().
872872
self.scc_by_id: dict[int, SCC] = {}
873+
# Mapping from module id to the SCC it belongs to. This is populated
874+
# in process_graph().
875+
self.scc_by_mod_id: dict[str, SCC] = {}
873876
# Global topological order for SCCs. This exists to make order of processing
874877
# SCCs more predictable.
875878
self.top_order: list[int] = []
@@ -892,6 +895,8 @@ def __init__(
892895
# raw parsed trees not analyzed with mypy. We use these to find absolute
893896
# location of a symbol used as a location for an error message.
894897
self.extra_trees: dict[str, MypyFile] = {}
898+
# Cache for transitive dependency check (expensive).
899+
self.transitive_deps_cache: dict[tuple[int, int], bool] = {}
895900

896901
def dump_stats(self) -> None:
897902
if self.options.dump_build_stats:
@@ -1203,6 +1208,35 @@ def wait_for_done_workers(self) -> tuple[list[SCC], bool, dict[str, tuple[str, l
12031208
results,
12041209
)
12051210

1211+
def is_transitive_scc_dep(self, from_scc_id: int, to_scc_id: int) -> bool:
    """Check if one SCC is a (transitive) dependency of another.

    Walks the SCC dependency graph breadth-first, starting from the direct
    dependencies of ``from_scc_id``, and reports whether ``to_scc_id`` is
    reached. Answers are memoized in ``self.transitive_deps_cache`` keyed by
    ``(from_scc_id, to_scc_id)``.

    Args:
        from_scc_id: id of the SCC whose dependencies are searched.
        to_scc_id: id of the SCC we are looking for.

    Returns:
        True if ``to_scc_id`` is reachable from ``from_scc_id`` via ``deps``
        edges, False otherwise.
    """
    edge = (from_scc_id, to_scc_id)
    if (cached := self.transitive_deps_cache.get(edge)) is not None:
        return cached
    todo = self.scc_by_id[from_scc_id].deps
    seen: set[int] = set()
    while todo:
        more: set[int] = set()
        # Breadth-first search seems to be better here, because all
        # "lower-level" SCCs are processed and some may be cached.
        for dep in todo:
            if dep in seen:
                # Already examined on a shallower level; its cache entry and
                # its own deps cannot yield anything new during this call.
                continue
            seen.add(dep)
            if dep == to_scc_id:
                self.transitive_deps_cache[edge] = True
                return True
            cached = self.transitive_deps_cache.get((dep, to_scc_id))
            if cached:
                self.transitive_deps_cache[edge] = True
                return True
            if cached is None:
                # Unknown so far: keep exploring below this SCC. A cached
                # False means the subtree is known not to reach the target.
                more |= self.scc_by_id[dep].deps
        todo = more
    self.transitive_deps_cache[edge] = False
    for dep in seen:
        # We negative-cache all intermediate lookups, thus
        # trading time for space.
        self.transitive_deps_cache[(dep, to_scc_id)] = False
    return False
1239+
12061240

12071241
def deps_to_json(x: dict[str, set[str]]) -> bytes:
12081242
return json_dumps({k: list(v) for k, v in x.items()})
@@ -1841,6 +1875,7 @@ def write_cache(
18411875
dep_prios: list[int],
18421876
dep_lines: list[int],
18431877
old_interface_hash: bytes,
1878+
trans_dep_hash: bytes,
18441879
source_hash: str,
18451880
ignore_all: bool,
18461881
manager: BuildManager,
@@ -1957,6 +1992,7 @@ def write_cache(
19571992
dep_prios=dep_prios,
19581993
dep_lines=dep_lines,
19591994
interface_hash=interface_hash,
1995+
trans_dep_hash=trans_dep_hash,
19601996
version_id=manager.version_id,
19611997
ignore_all=ignore_all,
19621998
plugin_data=plugin_data,
@@ -2175,6 +2211,12 @@ class State:
21752211
# Contains a hash of the public interface in incremental mode
21762212
interface_hash: bytes = b""
21772213

2214+
# Hash of import structure that this module depends on. It is not 1:1 with
2215+
# transitive dependencies set, but if two hashes are equal, transitive
2216+
# dependencies are guaranteed to be identical. Some expensive checks can be
2217+
# skipped if this value is unchanged for a module.
2218+
trans_dep_hash: bytes = b""
2219+
21782220
# Options, specialized for this file
21792221
options: Options
21802222

@@ -2322,15 +2364,15 @@ def new_state(
23222364
if temporary:
23232365
state.load_tree(temporary=True)
23242366
if not manager.use_fine_grained_cache():
2325-
# Special case: if there were a previously missing package imported here
2367+
# Special case: if there were a previously missing package imported here,
23262368
# and it is not present, then we need to re-calculate dependencies.
23272369
# This is to support patterns like this:
23282370
# from missing_package import missing_module # type: ignore
23292371
# At first mypy doesn't know that `missing_module` is a module
23302372
# (it may be a variable, a class, or a function), so it is not added to
23312373
# suppressed dependencies. Therefore, when the package with module is added,
23322374
# we need to re-calculate dependencies.
2333-
# NOTE: see comment below for why we skip this in fine grained mode.
2375+
# NOTE: see comment below for why we skip this in fine-grained mode.
23342376
if exist_added_packages(suppressed, manager, options):
23352377
state.parse_file() # This is safe because the cache is anyway stale.
23362378
state.compute_dependencies()
@@ -2350,6 +2392,7 @@ def new_state(
23502392
# We don't need parsed trees in coordinator process, we parse only to
23512393
# compute dependencies.
23522394
state.tree = None
2395+
del manager.ast_cache[id]
23532396

23542397
return state
23552398

@@ -3012,6 +3055,7 @@ def write_cache(self) -> tuple[CacheMeta, str] | None:
30123055
dep_prios,
30133056
dep_lines,
30143057
self.interface_hash,
3058+
self.trans_dep_hash,
30153059
self.source_hash,
30163060
self.ignore_all,
30173061
self.manager,
@@ -3774,6 +3818,27 @@ def order_ascc_ex(graph: Graph, ascc: SCC) -> list[str]:
37743818
return scc
37753819

37763820

3821+
def verify_transitive_deps(ascc: SCC, graph: Graph, manager: BuildManager) -> str | None:
    """Verify all indirect dependencies of this SCC are still reachable via direct ones.

    Modules whose current ``trans_dep_hash`` equals the one stored in their
    cache meta are skipped. For the remaining modules, every dependency with
    priority ``PRI_INDIRECT`` that lives in another SCC must be a transitive
    dependency of this SCC according to ``manager.is_transitive_scc_dep()``.

    Return first unreachable dependency id, or None.
    """
    for mod_id in ascc.mod_ids:
        st = graph[mod_id]
        assert st.meta is not None, "Must be called on fresh SCCs only"
        if st.trans_dep_hash == st.meta.trans_dep_hash:
            # Import graph unchanged, skip this module.
            continue
        for dep in st.dependencies:
            if st.priorities.get(dep) != PRI_INDIRECT:
                # Only indirect dependencies need the reachability check.
                continue
            dep_scc_id = manager.scc_by_mod_id[dep].id
            if dep_scc_id == ascc.id:
                # Dependency within the same SCC is trivially reachable.
                continue
            if not manager.is_transitive_scc_dep(ascc.id, dep_scc_id):
                return dep
    return None
3840+
3841+
37773842
def find_stale_sccs(
37783843
sccs: list[SCC], graph: Graph, manager: BuildManager
37793844
) -> tuple[list[SCC], list[SCC]]:
@@ -3782,7 +3847,8 @@ def find_stale_sccs(
37823847
Fresh SCCs are those where:
37833848
* We have valid cache files for all modules in the SCC.
37843849
* There are no changes in dependencies (files removed from/added to the build).
3785-
* The interface hashes of direct dependents matches those recorded in the cache.
3850+
* The interface hashes of dependencies match those recorded in the cache.
3851+
* All indirect dependencies are still reachable via direct ones.
37863852
The first and second conditions are verified by is_fresh().
37873853
"""
37883854
stale_sccs = []
@@ -3799,6 +3865,15 @@ def find_stale_sccs(
37993865
stale_deps.add(dep)
38003866
fresh = fresh and not stale_deps
38013867

3868+
# Verify the invariant that indirect dependencies are a subset of transitive direct
3869+
# dependencies. Note: the case where indirect dependency is removed from the graph
3870+
# completely is already handled above.
3871+
stale_indirect = None
3872+
if fresh:
3873+
stale_indirect = verify_transitive_deps(ascc, graph, manager)
3874+
if stale_indirect is not None:
3875+
fresh = False
3876+
38023877
if fresh:
38033878
fresh_msg = "fresh"
38043879
elif stale_scc:
@@ -3807,8 +3882,11 @@ def find_stale_sccs(
38073882
fresh_msg += f" ({' '.join(sorted(stale_scc))})"
38083883
if stale_deps:
38093884
fresh_msg += f" with stale deps ({' '.join(sorted(stale_deps))})"
3810-
else:
3885+
elif stale_deps:
38113886
fresh_msg = f"stale due to deps ({' '.join(sorted(stale_deps))})"
3887+
else:
3888+
assert stale_indirect is not None
3889+
fresh_msg = f"stale due to stale indirect dep(s): first {stale_indirect}"
38123890

38133891
scc_str = " ".join(ascc.mod_ids)
38143892
if fresh:
@@ -3860,6 +3938,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
38603938
scc_by_id = {scc.id: scc for scc in sccs}
38613939
manager.scc_by_id = scc_by_id
38623940
manager.top_order = [scc.id for scc in sccs]
3941+
for scc in sccs:
3942+
for mod_id in scc.mod_ids:
3943+
manager.scc_by_mod_id[mod_id] = scc
38633944

38643945
# Broadcast SCC structure to the parallel workers, since they don't compute it.
38653946
sccs_message = SccsDataMessage(sccs=sccs)
@@ -3904,8 +3985,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
39043985
# type-checking this is already done and results should be empty here.
39053986
if not manager.workers:
39063987
assert not results
3907-
for id, (interface_cache, errors) in results.items():
3908-
new_hash = bytes.fromhex(interface_cache)
3988+
for id, (interface_hash, errors) in results.items():
3989+
new_hash = bytes.fromhex(interface_hash)
39093990
if new_hash != graph[id].interface_hash:
39103991
graph[id].mark_interface_stale()
39113992
graph[id].interface_hash = new_hash
@@ -3917,6 +3998,7 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
39173998
if not scc_by_id[dependent].not_ready_deps:
39183999
not_ready.remove(scc_by_id[dependent])
39194000
ready.append(scc_by_id[dependent])
4001+
manager.trace(f"Transitive deps cache size: {sys.getsizeof(manager.transitive_deps_cache)}")
39204002

39214003

39224004
def order_ascc(graph: Graph, ascc: AbstractSet[str], pri_max: int = PRI_INDIRECT) -> list[str]:
@@ -4168,6 +4250,11 @@ def sorted_components(graph: Graph) -> list[SCC]:
41684250
scc.size_hint = sum(graph[mid].size_hint for mid in scc.mod_ids)
41694251
for dep in scc_dep_map[scc]:
41704252
dep.direct_dependents.append(scc.id)
4253+
# We compute dependencies hash here since we know no direct
4254+
# dependencies will be added or suppressed after this point.
4255+
trans_dep_hash = transitive_dep_hash(scc, graph)
4256+
for id in scc.mod_ids:
4257+
graph[id].trans_dep_hash = trans_dep_hash
41714258
res.extend(sorted_ready)
41724259
return res
41734260

@@ -4201,6 +4288,21 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in
42014288
]
42024289

42034290

4291+
def transitive_dep_hash(scc: SCC, graph: Graph) -> bytes:
    """Compute stable snapshot of transitive import structure for given SCC.

    Collects every non-indirect dependency of every module in the SCC and maps
    it to the hex form of that dependency's own ``trans_dep_hash`` (or to an
    empty string for dependencies inside this SCC). The digest of the
    serialized mapping is returned.
    """
    all_direct_deps = {
        dep
        for mod_id in scc.mod_ids
        for dep in graph[mod_id].dependencies
        if graph[mod_id].priorities.get(dep) != PRI_INDIRECT
    }
    # Build the mapping in sorted key order so the serialized snapshot does not
    # depend on set iteration order (which varies between runs for strings).
    trans_dep_hash_map = {
        dep_id: "" if dep_id in scc.mod_ids else graph[dep_id].trans_dep_hash.hex()
        for dep_id in sorted(all_direct_deps)
    }
    return hash_digest_bytes(json_dumps(trans_dep_hash_map))
4304+
4305+
42044306
def missing_stubs_file(cache_dir: str) -> str:
42054307
return os.path.join(cache_dir, "missing_stubs")
42064308

mypy/cache.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
from mypy_extensions import u8
7070

7171
# High-level cache layout format
72-
CACHE_VERSION: Final = 2
72+
CACHE_VERSION: Final = 3
7373

7474
SerializedError: _TypeAlias = tuple[str | None, int | str, int, int, int, str, str, str | None]
7575

@@ -95,6 +95,7 @@ def __init__(
9595
dep_lines: list[int],
9696
dep_hashes: list[bytes],
9797
interface_hash: bytes,
98+
trans_dep_hash: bytes,
9899
error_lines: list[SerializedError],
99100
version_id: str,
100101
ignore_all: bool,
@@ -117,6 +118,7 @@ def __init__(
117118
# dep_hashes list is aligned with dependencies only
118119
self.dep_hashes = dep_hashes # list of interface_hash for dependencies
119120
self.interface_hash = interface_hash # hash representing the public interface
121+
self.trans_dep_hash = trans_dep_hash # hash of import structure (transitive)
120122
self.error_lines = error_lines
121123
self.version_id = version_id # mypy version for cache invalidation
122124
self.ignore_all = ignore_all # if errors were ignored
@@ -138,6 +140,7 @@ def serialize(self) -> dict[str, Any]:
138140
"dep_lines": self.dep_lines,
139141
"dep_hashes": [dep.hex() for dep in self.dep_hashes],
140142
"interface_hash": self.interface_hash.hex(),
143+
"trans_dep_hash": self.trans_dep_hash.hex(),
141144
"error_lines": self.error_lines,
142145
"version_id": self.version_id,
143146
"ignore_all": self.ignore_all,
@@ -165,6 +168,7 @@ def deserialize(cls, meta: dict[str, Any], data_file: str) -> CacheMeta | None:
165168
dep_lines=meta["dep_lines"],
166169
dep_hashes=[bytes.fromhex(dep) for dep in meta["dep_hashes"]],
167170
interface_hash=bytes.fromhex(meta["interface_hash"]),
171+
trans_dep_hash=bytes.fromhex(meta["trans_dep_hash"]),
168172
error_lines=[tuple(err) for err in meta["error_lines"]],
169173
version_id=meta["version_id"],
170174
ignore_all=meta["ignore_all"],
@@ -191,6 +195,7 @@ def write(self, data: WriteBuffer) -> None:
191195
write_int_list(data, self.dep_lines)
192196
write_bytes_list(data, self.dep_hashes)
193197
write_bytes(data, self.interface_hash)
198+
write_bytes(data, self.trans_dep_hash)
194199
write_errors(data, self.error_lines)
195200
write_str(data, self.version_id)
196201
write_bool(data, self.ignore_all)
@@ -219,6 +224,7 @@ def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None:
219224
dep_lines=read_int_list(data),
220225
dep_hashes=read_bytes_list(data),
221226
interface_hash=read_bytes(data),
227+
trans_dep_hash=read_bytes(data),
222228
error_lines=read_errors(data),
223229
version_id=read_str(data),
224230
ignore_all=read_bool(data),

mypy/semanal_main.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from itertools import groupby
3232
from typing import TYPE_CHECKING, Final, TypeAlias as _TypeAlias
3333

34-
import mypy.build
3534
import mypy.state
3635
from mypy.checker import FineGrainedDeferredNode
3736
from mypy.errors import Errors
@@ -416,11 +415,6 @@ def semantic_analyze_target(
416415
)
417416
if isinstance(node, Decorator):
418417
infer_decorator_signature_if_simple(node, analyzer)
419-
for dep in analyzer.imports:
420-
state.add_dependency(dep)
421-
priority = mypy.build.PRI_LOW
422-
if priority <= state.priorities.get(dep, priority):
423-
state.priorities[dep] = priority
424418

425419
# Clear out some stale data to avoid memory leaks and astmerge
426420
# validity check confusion

mypy/test/testgraph.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ def test_sorted_components(self) -> None:
6464
"d": State.new_state("d", None, "pass", manager),
6565
"b": State.new_state("b", None, "import c", manager),
6666
"c": State.new_state("c", None, "import b, d", manager),
67+
"builtins": State.new_state("builtins", None, "", manager),
6768
}
6869
res = [scc.mod_ids for scc in sorted_components(graph)]
69-
assert_equal(res, [{"d"}, {"c", "b"}, {"a"}])
70+
assert_equal(res, [{"builtins"}, {"d"}, {"c", "b"}, {"a"}])
7071

7172
def test_order_ascc(self) -> None:
7273
manager = self._make_manager()
@@ -75,9 +76,10 @@ def test_order_ascc(self) -> None:
7576
"d": State.new_state("d", None, "def f(): import a", manager),
7677
"b": State.new_state("b", None, "import c", manager),
7778
"c": State.new_state("c", None, "import b, d", manager),
79+
"builtins": State.new_state("builtins", None, "", manager),
7880
}
7981
res = [scc.mod_ids for scc in sorted_components(graph)]
80-
assert_equal(res, [frozenset({"a", "d", "c", "b"})])
81-
ascc = res[0]
82+
assert_equal(res, [{"builtins"}, {"a", "d", "c", "b"}])
83+
ascc = res[1]
8284
scc = order_ascc(graph, ascc)
8385
assert_equal(scc, ["d", "c", "b", "a"])

0 commit comments

Comments
 (0)