diff --git a/mypy/build.py b/mypy/build.py
index 7eee0f343c45..abb06605b8e8 100644
--- a/mypy/build.py
+++ b/mypy/build.py
@@ -56,18 +56,21 @@
 from mypy.cache import (
     CACHE_VERSION,
     DICT_STR_GEN,
+    LIST_GEN,
     LITERAL_NONE,
     CacheMeta,
     ReadBuffer,
     SerializedError,
     Tag,
     WriteBuffer,
+    read_bytes,
     read_int,
     read_int_list,
     read_int_opt,
     read_str,
     read_str_list,
     read_str_opt,
+    write_bytes,
     write_int,
     write_int_list,
     write_int_opt,
@@ -2391,6 +2394,93 @@ def __init__(
         self.add_ancestors()
         self.size_hint = size_hint
 
+    def write(self, buf: WriteBuffer) -> None:
+        """Serialize State for sending to build worker.
+
+        Note that unlike write() methods for most other classes, this one is
+        not idempotent. We erase some bulky values that should either be not needed
+        for processing by the worker, or can be re-created from other data relatively
+        quickly. These are:
+          * self.meta: workers will call self.reload_meta() anyway.
+          * self.options: can be restored with Options.clone_for_module().
+          * self.error_lines: fresh errors are handled by the coordinator.
+        """
+        write_int(buf, self.order)
+        write_str(buf, self.id)
+        write_str_opt(buf, self.path)
+        write_str_opt(buf, self.source)  # mostly for mypy -c ''
+        write_bool(buf, self.ignore_all)
+        write_int(buf, self.caller_line)
+        write_tag(buf, LIST_GEN)
+        write_int_bare(buf, len(self.import_context))
+        for path, line in self.import_context:
+            write_str(buf, path)
+            write_int(buf, line)
+        write_bytes(buf, self.interface_hash)
+        write_str_opt(buf, self.meta_source_hash)
+        write_str_list(buf, self.dependencies)
+        write_str_list(buf, self.suppressed)
+        # TODO: we can possibly serialize these dictionaries in a more compact way.
+        # Most keys in the dictionaries should be the same, so we can write them once.
+        write_tag(buf, DICT_STR_GEN)
+        write_int_bare(buf, len(self.priorities))
+        for mod_id, prio in self.priorities.items():
+            write_str_bare(buf, mod_id)
+            write_int(buf, prio)
+        write_tag(buf, DICT_STR_GEN)
+        write_int_bare(buf, len(self.dep_line_map))
+        for mod_id, line in self.dep_line_map.items():
+            write_str_bare(buf, mod_id)
+            write_int(buf, line)
+        write_tag(buf, DICT_STR_GEN)
+        write_int_bare(buf, len(self.dep_hashes))
+        for mod_id, dep_hash in self.dep_hashes.items():
+            write_str_bare(buf, mod_id)
+            write_bytes(buf, dep_hash)
+        write_int(buf, self.size_hint)
+
+    @classmethod
+    def read(cls, buf: ReadBuffer, manager: BuildManager) -> State:
+        order = read_int(buf)
+        id = read_str(buf)
+        path = read_str_opt(buf)
+        source = read_str_opt(buf)
+        ignore_all = read_bool(buf)
+        caller_line = read_int(buf)
+        assert read_tag(buf) == LIST_GEN
+        import_context = [(read_str(buf), read_int(buf)) for _ in range(read_int_bare(buf))]
+        interface_hash = read_bytes(buf)
+        meta_source_hash = read_str_opt(buf)
+        dependencies = read_str_list(buf)
+        suppressed = read_str_list(buf)
+        assert read_tag(buf) == DICT_STR_GEN
+        priorities = {read_str_bare(buf): read_int(buf) for _ in range(read_int_bare(buf))}
+        assert read_tag(buf) == DICT_STR_GEN
+        dep_line_map = {read_str_bare(buf): read_int(buf) for _ in range(read_int_bare(buf))}
+        assert read_tag(buf) == DICT_STR_GEN
+        dep_hashes = {read_str_bare(buf): read_bytes(buf) for _ in range(read_int_bare(buf))}
+        return cls(
+            manager=manager,
+            order=order,
+            id=id,
+            path=path,
+            source=source,
+            options=manager.options.clone_for_module(id),
+            ignore_all=ignore_all,
+            caller_line=caller_line,
+            import_context=import_context,
+            meta=None,
+            interface_hash=interface_hash,
+            meta_source_hash=meta_source_hash,
+            dependencies=dependencies,
+            suppressed=suppressed,
+            priorities=priorities,
+            dep_line_map=dep_line_map,
+            dep_hashes=dep_hashes,
+            error_lines=[],
+            size_hint=read_int(buf),
+        )
+
     def reload_meta(self) -> None:
         """Force reload of cache meta.
 
@@ -3727,11 +3817,19 @@ def find_stale_sccs(
 
 def process_graph(graph: Graph, manager: BuildManager) -> None:
     """Process everything in dependency order."""
+    # Broadcast graph to workers before computing SCCs to save a bit of time.
+    graph_message = GraphMessage(graph=graph)
+    buf = WriteBuffer()
+    graph_message.write(buf)
+    graph_data = buf.getvalue()
+    for worker in manager.workers:
+        AckMessage.read(receive(worker.conn))
+        worker.conn.write_bytes(graph_data)
+
     sccs = sorted_components(graph)
     manager.log(
         "Found %d SCCs; largest has %d nodes" % (len(sccs), max(len(scc.mod_ids) for scc in sccs))
     )
-
     scc_by_id = {scc.id: scc for scc in sccs}
     manager.scc_by_id = scc_by_id
     manager.top_order = [scc.id for scc in sccs]
@@ -4186,6 +4284,7 @@ def deserialize_codes(errs: list[SerializedError]) -> list[ErrorTupleRaw]:
 SCC_RESPONSE_MESSAGE: Final[Tag] = 103
 SOURCES_DATA_MESSAGE: Final[Tag] = 104
 SCCS_DATA_MESSAGE: Final[Tag] = 105
+GRAPH_MESSAGE: Final[Tag] = 106
 
 
 class AckMessage(IPCMessage):
@@ -4336,3 +4435,24 @@ def write(self, buf: WriteBuffer) -> None:
         write_str_list(buf, sorted(scc.mod_ids))
         write_int(buf, scc.id)
         write_int_list(buf, sorted(scc.deps))
+
+
+class GraphMessage(IPCMessage):
+    """A message wrapping the build graph computed by the coordinator."""
+
+    def __init__(self, *, graph: Graph) -> None:
+        self.graph = graph
+
+    @classmethod
+    def read(cls, buf: ReadBuffer, manager: BuildManager | None = None) -> GraphMessage:
+        assert manager is not None
+        assert read_tag(buf) == GRAPH_MESSAGE
+        graph = {read_str_bare(buf): State.read(buf, manager) for _ in range(read_int_bare(buf))}
+        return GraphMessage(graph=graph)
+
+    def write(self, buf: WriteBuffer) -> None:
+        write_tag(buf, GRAPH_MESSAGE)
+        write_int_bare(buf, len(self.graph))
+        for mod_id, state in self.graph.items():
+            write_str_bare(buf, mod_id)
+            state.write(buf)
diff --git a/mypy/build_worker/worker.py b/mypy/build_worker/worker.py
index 86a0da3b6c7f..9cc7b25a8157 100644
--- a/mypy/build_worker/worker.py
+++ b/mypy/build_worker/worker.py
@@ -28,6 +28,7 @@
 from mypy.build import (
     AckMessage,
     BuildManager,
+    GraphMessage,
     SccRequestMessage,
     SccResponseMessage,
     SccsDataMessage,
@@ -128,6 +129,15 @@ def serve(server: IPCServer, ctx: ServerContext) -> None:
 
     # Notify worker we are done loading graph.
     send(server, AckMessage())
+
+    # Compare worker graph and coordinator, with parallel parser we will only use the latter.
+    coordinator_graph = GraphMessage.read(receive(server), manager).graph
+    assert coordinator_graph.keys() == graph.keys()
+    for id in graph:
+        assert graph[id].dependencies_set == coordinator_graph[id].dependencies_set
+        assert graph[id].suppressed_set == coordinator_graph[id].suppressed_set
+    send(server, AckMessage())
+
     sccs = SccsDataMessage.read(receive(server)).sccs
     manager.scc_by_id = {scc.id: scc for scc in sccs}
    manager.top_order = [scc.id for scc in sccs]
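
For context, here is a minimal sketch (not part of the patch) of the round trip the new `GraphMessage` performs, using only names introduced in the diff above. It assumes `ReadBuffer` can be constructed directly from the bytes returned by `WriteBuffer.getvalue()` (that constructor is not shown in this diff) and that a `BuildManager` and `Graph` already exist.

```python
# Sketch only: exercises the GraphMessage round trip the way process_graph()
# and the worker's serve() loop do, but in-process instead of over the IPC channel.
# Assumption: ReadBuffer(data) accepts the bytes produced by WriteBuffer.getvalue().
from mypy.build import BuildManager, Graph, GraphMessage
from mypy.cache import ReadBuffer, WriteBuffer


def roundtrip_graph(graph: Graph, manager: BuildManager) -> None:
    buf = WriteBuffer()
    GraphMessage(graph=graph).write(buf)  # coordinator side: serialize every State
    data = buf.getvalue()                 # bytes sent via worker.conn.write_bytes()

    # Worker side: rebuild State objects (options are recreated via clone_for_module()).
    restored = GraphMessage.read(ReadBuffer(data), manager).graph
    assert restored.keys() == graph.keys()
    for mod_id in graph:
        # Same consistency checks that serve() performs against the coordinator graph.
        assert graph[mod_id].dependencies_set == restored[mod_id].dependencies_set
        assert graph[mod_id].suppressed_set == restored[mod_id].suppressed_set
```

As the comments in the patch note, the coordinator broadcasts the graph before computing SCCs to save a bit of time, and for now the worker uses the received graph only for these consistency asserts; once the parallel parser lands, the coordinator's graph is intended to be the one the worker actually uses.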