Skip to content

Commit fed3593

Browse files
authored
Fix parallel worker crash on syntax error (#21202)
Fixes #21195 Fix is straightforward, and also makes tag reading convention similar to regular cache (i.e. caller reads the tag). Note that a cmdline (i.e. full) test is required for this, since this is not a crash from the point of view of the test harness, the workers simply dump tracebacks to stderr when they crash, while coordinator exits normally. Btw, as I mentioned in original PR #20280, at some point I am going to redirect workers' stdout/stderr to a log file, similar to how we do it with the daemon. This would also fix this issue, but I prefer that workers always exit normally when possible, i.e. only genuinely unexpected conditions should crash the workers.
1 parent be51e9c commit fed3593

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

mypy/build.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,7 +1222,9 @@ def wait_for_done_workers(
12221222
done_sccs = []
12231223
results = {}
12241224
for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT):
1225-
data = SccResponseMessage.read(receive(self.workers[idx].conn))
1225+
buf = receive(self.workers[idx].conn)
1226+
assert read_tag(buf) == SCC_RESPONSE_MESSAGE
1227+
data = SccResponseMessage.read(buf)
12261228
self.free_workers.add(idx)
12271229
scc_id = data.scc_id
12281230
if data.blocker is not None:
@@ -4165,7 +4167,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
41654167
graph_message.write(buf)
41664168
graph_data = buf.getvalue()
41674169
for worker in manager.workers:
4168-
AckMessage.read(receive(worker.conn))
4170+
buf = receive(worker.conn)
4171+
assert read_tag(buf) == ACK_MESSAGE
41694172
worker.conn.write_bytes(graph_data)
41704173

41714174
sccs = sorted_components(graph)
@@ -4185,10 +4188,12 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
41854188
sccs_message.write(buf)
41864189
sccs_data = buf.getvalue()
41874190
for worker in manager.workers:
4188-
AckMessage.read(receive(worker.conn))
4191+
buf = receive(worker.conn)
4192+
assert read_tag(buf) == ACK_MESSAGE
41894193
worker.conn.write_bytes(sccs_data)
41904194
for worker in manager.workers:
4191-
AckMessage.read(receive(worker.conn))
4195+
buf = receive(worker.conn)
4196+
assert read_tag(buf) == ACK_MESSAGE
41924197

41934198
manager.free_workers = set(range(manager.options.num_workers))
41944199

@@ -4620,7 +4625,6 @@ class AckMessage(IPCMessage):
46204625

46214626
@classmethod
46224627
def read(cls, buf: ReadBuffer) -> AckMessage:
4623-
assert read_tag(buf) == ACK_MESSAGE
46244628
return AckMessage()
46254629

46264630
def write(self, buf: WriteBuffer) -> None:
@@ -4647,7 +4651,6 @@ def __init__(
46474651

46484652
@classmethod
46494653
def read(cls, buf: ReadBuffer) -> SccRequestMessage:
4650-
assert read_tag(buf) == SCC_REQUEST_MESSAGE
46514654
return SccRequestMessage(
46524655
scc_id=read_int_opt(buf),
46534656
import_errors={
@@ -4708,7 +4711,6 @@ def __init__(
47084711

47094712
@classmethod
47104713
def read(cls, buf: ReadBuffer) -> SccResponseMessage:
4711-
assert read_tag(buf) == SCC_RESPONSE_MESSAGE
47124714
scc_id = read_int(buf)
47134715
tag = read_tag(buf)
47144716
if tag == LITERAL_NONE:
@@ -4753,7 +4755,6 @@ def __init__(self, *, sources: list[BuildSource]) -> None:
47534755

47544756
@classmethod
47554757
def read(cls, buf: ReadBuffer) -> SourcesDataMessage:
4756-
assert read_tag(buf) == SOURCES_DATA_MESSAGE
47574758
sources = [
47584759
BuildSource(
47594760
read_str_opt(buf),
@@ -4785,7 +4786,6 @@ def __init__(self, *, sccs: list[SCC]) -> None:
47854786

47864787
@classmethod
47874788
def read(cls, buf: ReadBuffer) -> SccsDataMessage:
4788-
assert read_tag(buf) == SCCS_DATA_MESSAGE
47894789
sccs = [
47904790
SCC(set(read_str_list(buf)), read_int(buf), read_int_list(buf))
47914791
for _ in range(read_int_bare(buf))
@@ -4813,7 +4813,6 @@ def __init__(self, *, graph: Graph, missing_modules: dict[str, int]) -> None:
48134813
@classmethod
48144814
def read(cls, buf: ReadBuffer, manager: BuildManager | None = None) -> GraphMessage:
48154815
assert manager is not None
4816-
assert read_tag(buf) == GRAPH_MESSAGE
48174816
graph = {read_str_bare(buf): State.read(buf, manager) for _ in range(read_int_bare(buf))}
48184817
missing_modules = {read_str_bare(buf): read_int(buf) for _ in range(read_int_bare(buf))}
48194818
message = GraphMessage(graph=graph, missing_modules=missing_modules)

mypy/build_worker/worker.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,15 @@
2424
from typing import NamedTuple
2525

2626
from librt.base64 import b64decode
27+
from librt.internal import ReadBuffer, read_tag
2728

2829
from mypy import util
2930
from mypy.build import (
31+
GRAPH_MESSAGE,
3032
SCC,
33+
SCC_REQUEST_MESSAGE,
34+
SCCS_DATA_MESSAGE,
35+
SOURCES_DATA_MESSAGE,
3136
AckMessage,
3237
BuildManager,
3338
Graph,
@@ -39,6 +44,7 @@
3944
load_plugins,
4045
process_stale_scc,
4146
)
47+
from mypy.cache import Tag, read_int_opt
4248
from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT
4349
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
4450
from mypy.fscache import FileSystemCache
@@ -113,21 +119,37 @@ def main(argv: list[str]) -> None:
113119
util.hard_exit(0)
114120

115121

122+
def should_shutdown(buf: ReadBuffer, expected_tag: Tag) -> bool:
123+
"""Check if the message is a shutdown request."""
124+
tag = read_tag(buf)
125+
if tag == SCC_REQUEST_MESSAGE:
126+
assert read_int_opt(buf) is None
127+
return True
128+
assert tag == expected_tag, f"Unexpected tag: {tag}"
129+
return False
130+
131+
116132
def serve(server: IPCServer, ctx: ServerContext) -> None:
117133
"""Main server loop of the worker.
118134
119135
Receive initial state from the coordinator, then process each
120136
SCC checking request and reply to client (coordinator). See module
121137
docstring for more details on the protocol.
122138
"""
123-
sources = SourcesDataMessage.read(receive(server)).sources
139+
buf = receive(server)
140+
if should_shutdown(buf, SOURCES_DATA_MESSAGE):
141+
return
142+
sources = SourcesDataMessage.read(buf).sources
124143
manager = setup_worker_manager(sources, ctx)
125144
if manager is None:
126145
return
127146

128147
# Notify coordinator we are done with setup.
129148
send(server, AckMessage())
130-
graph_data = GraphMessage.read(receive(server), manager)
149+
buf = receive(server)
150+
if should_shutdown(buf, GRAPH_MESSAGE):
151+
return
152+
graph_data = GraphMessage.read(buf, manager)
131153
# Update some manager data in-place as it has been passed to semantic analyzer.
132154
manager.missing_modules |= graph_data.missing_modules
133155
graph = graph_data.graph
@@ -138,14 +160,19 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
138160

139161
# Notify coordinator we are ready to receive computed graph SCC structure.
140162
send(server, AckMessage())
141-
sccs = SccsDataMessage.read(receive(server)).sccs
163+
buf = receive(server)
164+
if should_shutdown(buf, SCCS_DATA_MESSAGE):
165+
return
166+
sccs = SccsDataMessage.read(buf).sccs
142167
manager.scc_by_id = {scc.id: scc for scc in sccs}
143168
manager.top_order = [scc.id for scc in sccs]
144169

145170
# Notify coordinator we are ready to start processing SCCs.
146171
send(server, AckMessage())
147172
while True:
148-
scc_message = SccRequestMessage.read(receive(server))
173+
buf = receive(server)
174+
assert read_tag(buf) == SCC_REQUEST_MESSAGE
175+
scc_message = SccRequestMessage.read(buf)
149176
scc_id = scc_message.scc_id
150177
if scc_id is None:
151178
manager.dump_stats()

test-data/unit/cmdline.test

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,3 +1247,13 @@ class CodecKey(NamedTuple):
12471247
\[mypy-importlib.*]
12481248
follow_imports = skip
12491249
follow_imports_for_stubs = True
1250+
1251+
[case testParallelRunWithSyntaxError]
1252+
# cmd: mypy a.py --num-workers=2 --pretty
1253+
[file a.py]
1254+
1 2
1255+
[out]
1256+
a.py:1: error: Simple statements must be separated by newlines or semicolons
1257+
1 2
1258+
^
1259+
== Return code: 2

0 commit comments

Comments
 (0)