Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,7 +1222,9 @@ def wait_for_done_workers(
done_sccs = []
results = {}
for idx in ready_to_read([w.conn for w in self.workers], WORKER_DONE_TIMEOUT):
data = SccResponseMessage.read(receive(self.workers[idx].conn))
buf = receive(self.workers[idx].conn)
assert read_tag(buf) == SCC_RESPONSE_MESSAGE
data = SccResponseMessage.read(buf)
self.free_workers.add(idx)
scc_id = data.scc_id
if data.blocker is not None:
Expand Down Expand Up @@ -4165,7 +4167,8 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
graph_message.write(buf)
graph_data = buf.getvalue()
for worker in manager.workers:
AckMessage.read(receive(worker.conn))
buf = receive(worker.conn)
assert read_tag(buf) == ACK_MESSAGE
worker.conn.write_bytes(graph_data)

sccs = sorted_components(graph)
Expand All @@ -4185,10 +4188,12 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
sccs_message.write(buf)
sccs_data = buf.getvalue()
for worker in manager.workers:
AckMessage.read(receive(worker.conn))
buf = receive(worker.conn)
assert read_tag(buf) == ACK_MESSAGE
worker.conn.write_bytes(sccs_data)
for worker in manager.workers:
AckMessage.read(receive(worker.conn))
buf = receive(worker.conn)
assert read_tag(buf) == ACK_MESSAGE

manager.free_workers = set(range(manager.options.num_workers))

Expand Down Expand Up @@ -4620,7 +4625,6 @@ class AckMessage(IPCMessage):

@classmethod
def read(cls, buf: ReadBuffer) -> AckMessage:
assert read_tag(buf) == ACK_MESSAGE
return AckMessage()

def write(self, buf: WriteBuffer) -> None:
Expand All @@ -4647,7 +4651,6 @@ def __init__(

@classmethod
def read(cls, buf: ReadBuffer) -> SccRequestMessage:
assert read_tag(buf) == SCC_REQUEST_MESSAGE
return SccRequestMessage(
scc_id=read_int_opt(buf),
import_errors={
Expand Down Expand Up @@ -4708,7 +4711,6 @@ def __init__(

@classmethod
def read(cls, buf: ReadBuffer) -> SccResponseMessage:
assert read_tag(buf) == SCC_RESPONSE_MESSAGE
scc_id = read_int(buf)
tag = read_tag(buf)
if tag == LITERAL_NONE:
Expand Down Expand Up @@ -4753,7 +4755,6 @@ def __init__(self, *, sources: list[BuildSource]) -> None:

@classmethod
def read(cls, buf: ReadBuffer) -> SourcesDataMessage:
assert read_tag(buf) == SOURCES_DATA_MESSAGE
sources = [
BuildSource(
read_str_opt(buf),
Expand Down Expand Up @@ -4785,7 +4786,6 @@ def __init__(self, *, sccs: list[SCC]) -> None:

@classmethod
def read(cls, buf: ReadBuffer) -> SccsDataMessage:
assert read_tag(buf) == SCCS_DATA_MESSAGE
sccs = [
SCC(set(read_str_list(buf)), read_int(buf), read_int_list(buf))
for _ in range(read_int_bare(buf))
Expand Down Expand Up @@ -4813,7 +4813,6 @@ def __init__(self, *, graph: Graph, missing_modules: dict[str, int]) -> None:
@classmethod
def read(cls, buf: ReadBuffer, manager: BuildManager | None = None) -> GraphMessage:
assert manager is not None
assert read_tag(buf) == GRAPH_MESSAGE
graph = {read_str_bare(buf): State.read(buf, manager) for _ in range(read_int_bare(buf))}
missing_modules = {read_str_bare(buf): read_int(buf) for _ in range(read_int_bare(buf))}
message = GraphMessage(graph=graph, missing_modules=missing_modules)
Expand Down
35 changes: 31 additions & 4 deletions mypy/build_worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@
from typing import NamedTuple

from librt.base64 import b64decode
from librt.internal import ReadBuffer, read_tag

from mypy import util
from mypy.build import (
GRAPH_MESSAGE,
SCC,
SCC_REQUEST_MESSAGE,
SCCS_DATA_MESSAGE,
SOURCES_DATA_MESSAGE,
AckMessage,
BuildManager,
Graph,
Expand All @@ -39,6 +44,7 @@
load_plugins,
process_stale_scc,
)
from mypy.cache import Tag, read_int_opt
from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
from mypy.fscache import FileSystemCache
Expand Down Expand Up @@ -113,21 +119,37 @@ def main(argv: list[str]) -> None:
util.hard_exit(0)


def should_shutdown(buf: ReadBuffer, expected_tag: Tag) -> bool:
"""Check if the message is a shutdown request."""
tag = read_tag(buf)
if tag == SCC_REQUEST_MESSAGE:
assert read_int_opt(buf) is None
return True
assert tag == expected_tag, f"Unexpected tag: {tag}"
return False


def serve(server: IPCServer, ctx: ServerContext) -> None:
"""Main server loop of the worker.

Receive initial state from the coordinator, then process each
SCC checking request and reply to client (coordinator). See module
docstring for more details on the protocol.
"""
sources = SourcesDataMessage.read(receive(server)).sources
buf = receive(server)
if should_shutdown(buf, SOURCES_DATA_MESSAGE):
return
sources = SourcesDataMessage.read(buf).sources
manager = setup_worker_manager(sources, ctx)
if manager is None:
return

# Notify coordinator we are done with setup.
send(server, AckMessage())
graph_data = GraphMessage.read(receive(server), manager)
buf = receive(server)
if should_shutdown(buf, GRAPH_MESSAGE):
return
graph_data = GraphMessage.read(buf, manager)
# Update some manager data in-place as it has been passed to semantic analyzer.
manager.missing_modules |= graph_data.missing_modules
graph = graph_data.graph
Expand All @@ -138,14 +160,19 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:

# Notify coordinator we are ready to receive computed graph SCC structure.
send(server, AckMessage())
sccs = SccsDataMessage.read(receive(server)).sccs
buf = receive(server)
if should_shutdown(buf, SCCS_DATA_MESSAGE):
return
sccs = SccsDataMessage.read(buf).sccs
manager.scc_by_id = {scc.id: scc for scc in sccs}
manager.top_order = [scc.id for scc in sccs]

# Notify coordinator we are ready to start processing SCCs.
send(server, AckMessage())
while True:
scc_message = SccRequestMessage.read(receive(server))
buf = receive(server)
assert read_tag(buf) == SCC_REQUEST_MESSAGE
scc_message = SccRequestMessage.read(buf)
scc_id = scc_message.scc_id
if scc_id is None:
manager.dump_stats()
Expand Down
10 changes: 10 additions & 0 deletions test-data/unit/cmdline.test
Original file line number Diff line number Diff line change
Expand Up @@ -1247,3 +1247,13 @@ class CodecKey(NamedTuple):
\[mypy-importlib.*]
follow_imports = skip
follow_imports_for_stubs = True

[case testParallelRunWithSyntaxError]
# cmd: mypy a.py --num-workers=2 --pretty
[file a.py]
1 2
[out]
a.py:1: error: Simple statements must be separated by newlines or semicolons
1 2
^
== Return code: 2
Loading