From 3d7cc5f1f7f65564d13b22cb2449277c53f6afcd Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 10 Apr 2026 20:48:11 +0100 Subject: [PATCH] Fix parallel worker crash on syntax error --- mypy/build.py | 19 +++++++++---------- mypy/build_worker/worker.py | 35 +++++++++++++++++++++++++++++++---- test-data/unit/cmdline.test | 10 ++++++++++ 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/mypy/build.py b/mypy/build.py index 2a9e7f5bf8f16..96ba59dd10956 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -1222,7 +1222,9 @@ def wait_for_done_workers( done_sccs = [] results = {} for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT): - data = SccResponseMessage.read(receive(self.workers[idx].conn)) + buf = receive(self.workers[idx].conn) + assert read_tag(buf) == SCC_RESPONSE_MESSAGE + data = SccResponseMessage.read(buf) self.free_workers.add(idx) scc_id = data.scc_id if data.blocker is not None: @@ -4165,7 +4167,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: graph_message.write(buf) graph_data = buf.getvalue() for worker in manager.workers: - AckMessage.read(receive(worker.conn)) + buf = receive(worker.conn) + assert read_tag(buf) == ACK_MESSAGE worker.conn.write_bytes(graph_data) sccs = sorted_components(graph) @@ -4185,10 +4188,12 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: sccs_message.write(buf) sccs_data = buf.getvalue() for worker in manager.workers: - AckMessage.read(receive(worker.conn)) + buf = receive(worker.conn) + assert read_tag(buf) == ACK_MESSAGE worker.conn.write_bytes(sccs_data) for worker in manager.workers: - AckMessage.read(receive(worker.conn)) + buf = receive(worker.conn) + assert read_tag(buf) == ACK_MESSAGE manager.free_workers = set(range(manager.options.num_workers)) @@ -4620,7 +4625,6 @@ class AckMessage(IPCMessage): @classmethod def read(cls, buf: ReadBuffer) -> AckMessage: - assert read_tag(buf) == ACK_MESSAGE return AckMessage() def write(self, buf: WriteBuffer) -> None: @@ -4647,7 +4651,6 @@ def __init__( @classmethod def read(cls, buf: ReadBuffer) -> SccRequestMessage: - assert read_tag(buf) == SCC_REQUEST_MESSAGE return SccRequestMessage( scc_id=read_int_opt(buf), import_errors={ @@ -4708,7 +4711,6 @@ def __init__( @classmethod def read(cls, buf: ReadBuffer) -> SccResponseMessage: - assert read_tag(buf) == SCC_RESPONSE_MESSAGE scc_id = read_int(buf) tag = read_tag(buf) if tag == LITERAL_NONE: @@ -4753,7 +4755,6 @@ def __init__(self, *, sources: list[BuildSource]) -> None: @classmethod def read(cls, buf: ReadBuffer) -> SourcesDataMessage: - assert read_tag(buf) == SOURCES_DATA_MESSAGE sources = [ BuildSource( read_str_opt(buf), @@ -4785,7 +4786,6 @@ def __init__(self, *, sccs: list[SCC]) -> None: @classmethod def read(cls, buf: ReadBuffer) -> SccsDataMessage: - assert read_tag(buf) == SCCS_DATA_MESSAGE sccs = [ SCC(set(read_str_list(buf)), read_int(buf), read_int_list(buf)) for _ in range(read_int_bare(buf)) @@ -4813,7 +4813,6 @@ def __init__(self, *, graph: Graph, missing_modules: dict[str, int]) -> None: @classmethod def read(cls, buf: ReadBuffer, manager: BuildManager | None = None) -> GraphMessage: assert manager is not None - assert read_tag(buf) == GRAPH_MESSAGE graph = {read_str_bare(buf): State.read(buf, manager) for _ in range(read_int_bare(buf))} missing_modules = {read_str_bare(buf): read_int(buf) for _ in range(read_int_bare(buf))} message = GraphMessage(graph=graph, missing_modules=missing_modules) diff --git a/mypy/build_worker/worker.py b/mypy/build_worker/worker.py index b35da8c412c73..66cfec6f6a36a 100644 --- a/mypy/build_worker/worker.py +++ b/mypy/build_worker/worker.py @@ -24,10 +24,15 @@ from typing import NamedTuple from librt.base64 import b64decode +from librt.internal import ReadBuffer, read_tag from mypy import util from mypy.build import ( + GRAPH_MESSAGE, SCC, + SCC_REQUEST_MESSAGE, + SCCS_DATA_MESSAGE, + SOURCES_DATA_MESSAGE, AckMessage, BuildManager, Graph, @@ -39,6 +44,7 @@ load_plugins, process_stale_scc, ) +from mypy.cache import Tag, read_int_opt from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error from mypy.fscache import FileSystemCache @@ -113,6 +119,16 @@ def main(argv: list[str]) -> None: util.hard_exit(0) +def should_shutdown(buf: ReadBuffer, expected_tag: Tag) -> bool: + """Check if the message is a shutdown request.""" + tag = read_tag(buf) + if tag == SCC_REQUEST_MESSAGE: + assert read_int_opt(buf) is None + return True + assert tag == expected_tag, f"Unexpected tag: {tag}" + return False + + def serve(server: IPCServer, ctx: ServerContext) -> None: """Main server loop of the worker. @@ -120,14 +136,20 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: SCC checking request and reply to client (coordinator). See module docstring for more details on the protocol. """ - sources = SourcesDataMessage.read(receive(server)).sources + buf = receive(server) + if should_shutdown(buf, SOURCES_DATA_MESSAGE): + return + sources = SourcesDataMessage.read(buf).sources manager = setup_worker_manager(sources, ctx) if manager is None: return # Notify coordinator we are done with setup. send(server, AckMessage()) - graph_data = GraphMessage.read(receive(server), manager) + buf = receive(server) + if should_shutdown(buf, GRAPH_MESSAGE): + return + graph_data = GraphMessage.read(buf, manager) # Update some manager data in-place as it has been passed to semantic analyzer. manager.missing_modules |= graph_data.missing_modules graph = graph_data.graph @@ -138,14 +160,19 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: # Notify coordinator we are ready to receive computed graph SCC structure. send(server, AckMessage()) - sccs = SccsDataMessage.read(receive(server)).sccs + buf = receive(server) + if should_shutdown(buf, SCCS_DATA_MESSAGE): + return + sccs = SccsDataMessage.read(buf).sccs manager.scc_by_id = {scc.id: scc for scc in sccs} manager.top_order = [scc.id for scc in sccs] # Notify coordinator we are ready to start processing SCCs. send(server, AckMessage()) while True: - scc_message = SccRequestMessage.read(receive(server)) + buf = receive(server) + assert read_tag(buf) == SCC_REQUEST_MESSAGE + scc_message = SccRequestMessage.read(buf) scc_id = scc_message.scc_id if scc_id is None: manager.dump_stats() diff --git a/test-data/unit/cmdline.test b/test-data/unit/cmdline.test index d8ae2dc4afb1b..8549ae277295f 100644 --- a/test-data/unit/cmdline.test +++ b/test-data/unit/cmdline.test @@ -1247,3 +1247,13 @@ class CodecKey(NamedTuple): \[mypy-importlib.*] follow_imports = skip follow_imports_for_stubs = True + +[case testParallelRunWithSyntaxError] +# cmd: mypy a.py --num-workers=2 --pretty +[file a.py] +1 2 +[out] +a.py:1: error: Simple statements must be separated by newlines or semicolons + 1 2 + ^ +== Return code: 2