#!/usr/bin/env python # Copyright (c) ONNX Project Contributors # Copyright (c) ONNX Project Contributors # # SPDX-License-Identifier: Apache-2.0 # NOTE: This is deprecated in favor of protogen's own .pyi generation method. # See: https://github.com/onnx/onnx/pull/6096 # Taken from https://github.com/dropbox/mypy-protobuf/blob/d984389124eae6dbbb517f766b9266bb32171510/python/protoc-gen-mypy # (Apache 2.0 License) # with own fixes to # - appease flake8 # - exit without error when protobuf isn't installed # - fix recognition of whether an identifier is defined locally # (unfortunately, we use a python package name ONNX_NAMESPACE_FOO_BAR_FOR_CI # on CI, which by the original protoc-gen-mypy script was recognized to be # camel case and therefore handled as an entry in the local package) """Protoc Plugin to generate mypy stubs. Loosely based on @zbarsky's go implementation""" from __future__ import annotations import sys from collections import defaultdict from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Callable, cast if TYPE_CHECKING: from collections.abc import Generator try: import google.protobuf.descriptor_pb2 as d_typed from google.protobuf.compiler import plugin_pb2 as plugin except ImportError as e: raise RuntimeError("Failed to generate mypy stubs") from e # Hax to get around fact that google protobuf libraries aren't in typeshed yet d: Any = d_typed # Split the string so phabricator doesn't think this file is generated GENERATED = "@ge" + "nerated" HEADER = ( f"# {GENERATED} by protoc-gen-mypy.py. Do not edit!\n" "# mypy: disable-error-code=override\n" ) class Descriptors: def __init__(self, request: plugin.CodeGeneratorRequest) -> None: files = {f.name: f for f in request.proto_file} to_generate = {n: files[n] for n in request.file_to_generate} self.files: dict[str, d.FileDescriptorProto] = files self.to_generate: dict[str, d.FileDescriptorProto] = to_generate self.messages: dict[str, d.DescriptorProto] = {} self.message_to_fd: dict[str, d.FileDescriptorProto] = {} def _add_enums( enums: d.EnumDescriptorProto, prefix: str, fd: d.FileDescriptorProto ) -> None: for enum in enums: self.message_to_fd[prefix + enum.name] = fd def _add_messages( messages: d.DescriptorProto, prefix: str, fd: d.FileDescriptorProto ) -> None: for message in messages: self.messages[prefix + message.name] = message self.message_to_fd[prefix + message.name] = fd sub_prefix = prefix + message.name + "." _add_messages(message.nested_type, sub_prefix, fd) _add_enums(message.enum_type, sub_prefix, fd) for fd in request.proto_file: start_prefix = "." + fd.package + "." _add_messages(fd.message_type, start_prefix, fd) _add_enums(fd.enum_type, start_prefix, fd) class PkgWriter: """Writes a single pyi file""" def __init__(self, fd: d.FileDescriptorProto, descriptors: Descriptors) -> None: self.fd = fd self.descriptors = descriptors self.lines: list[str] = [] self.indent = "" # dictionary of x->y for `from {x} import {y}` self.imports: dict[str, set[str]] = defaultdict(set) self.locals: set[str] = set() def _import(self, path: str, name: str, import_as: str | None = None) -> str: """Imports a stdlib path and returns a handle to it eg. self._import("typing", "Optional") -> "Optional" """ imp = path.replace("/", ".") if import_as is not None: self.imports[imp].add(f"{name} as {import_as}") return import_as else: self.imports[imp].add(name) return name def _import_message(self, type_name: d.FieldDescriptorProto) -> str: """Import a referenced message and return a handle""" name = cast(str, type_name) if name[0] == "." and name[1].isupper() and name[2].islower(): # Message defined in this file return name[1:] message_fd = self.descriptors.message_to_fd[name] if message_fd.name == self.fd.name: # message defined in this package split = name.split(".") for i, segment in enumerate(split): if segment and segment[0].isupper() and segment[1].islower(): return ".".join(split[i:]) # Not in package. Must import split = name.split(".") for i, segment in enumerate(split): if segment and segment[0].isupper() and segment[1].islower(): assert message_fd.name.endswith(".proto") import_name = self._import( message_fd.name[:-6].replace("-", "_") + "_pb2", segment ) remains = ".".join(split[i + 1 :]) if not remains: return import_name raise AssertionError("Don't support nested imports yet") # return new_nested_import(import_name, remains) raise AssertionError("Could not parse local name " + name) @contextmanager def _indent(self) -> Generator[None, None, None]: self.indent = self.indent + " " yield self.indent = self.indent[:-4] def _write_line(self, line: str, *args: str) -> None: self.lines.append(self.indent + line.format(*args)) def write_enums(self, enums: list[d.EnumDescriptorProto]) -> None: line = self._write_line for enum in enums: line("class {}(int):", enum.name) with self._indent(): line("@classmethod") line("def Name(cls, number: int) -> str: ...") line("@classmethod") line("def Value(cls, name: str) -> int: ...") line("@classmethod") line("def keys(cls) -> {}[str]: ...", self._import("typing", "List")) line("@classmethod") line("def values(cls) -> {}[int]: ...", self._import("typing", "List")) line("@classmethod") line( "def items(cls) -> {}[{}[str, int]]: ...", self._import("typing", "List"), self._import("typing", "Tuple"), ) for val in enum.value: line( "{} = {}({}, {})", val.name, self._import("typing", "cast"), enum.name, val.number, ) line("") def write_messages(self, messages: list[d.DescriptorProto], prefix: str) -> None: line = self._write_line message_class = self._import("google.protobuf.message", "Message") for desc in messages: self.locals.add(desc.name) qualified_name = prefix + desc.name line("class {}({}):", desc.name, message_class) with self._indent(): # Nested enums/messages self.write_enums(desc.enum_type) self.write_messages(desc.nested_type, qualified_name + ".") # Scalar fields for field in [f for f in desc.field if is_scalar(f)]: if field.label == d.FieldDescriptorProto.LABEL_REPEATED: container = self._import( "google.protobuf.internal.containers", "RepeatedScalarFieldContainer", ) line( "{} = ... # type: {}[{}]", field.name, container, self.python_type(field), ) else: line("{} = ... # type: {}", field.name, self.python_type(field)) line("") # Getters for non-scalar fields for field in [f for f in desc.field if not is_scalar(f)]: line("@property") if field.label == d.FieldDescriptorProto.LABEL_REPEATED: msg = self.descriptors.messages[field.type_name] if msg.options.map_entry: # map generates a special Entry wrapper message container = self._import("typing", "MutableMapping") line( "def {}(self) -> {}[{}, {}]: ...", field.name, container, self.python_type(msg.field[0]), self.python_type(msg.field[1]), ) else: container = self._import( "google.protobuf.internal.containers", "RepeatedCompositeFieldContainer", ) line( "def {}(self) -> {}[{}]: ...", field.name, container, self.python_type(field), ) else: line( "def {}(self) -> {}: ...", field.name, self.python_type(field), ) line("") # Constructor line("def __init__(self,") with self._indent(): # Required args for field in [ f for f in desc.field if f.label == d.FieldDescriptorProto.LABEL_REQUIRED ]: line("{} : {},", field.name, self.python_type(field)) for field in [ f for f in desc.field if f.label != d.FieldDescriptorProto.LABEL_REQUIRED ]: if field.label == d.FieldDescriptorProto.LABEL_REPEATED: if ( field.type_name != "" and self.descriptors.messages[ field.type_name ].options.map_entry ): msg = self.descriptors.messages[field.type_name] line( "{} : {}[{}[{}, {}]] = None,", field.name, self._import("typing", "Optional", "OptionalType"), self._import("typing", "Mapping"), self.python_type(msg.field[0]), self.python_type(msg.field[1]), ) else: line( "{} : {}[{}[{}]] = None,", field.name, self._import("typing", "Optional", "OptionalType"), self._import("typing", "Iterable"), self.python_type(field), ) else: line( "{} : {}[{}] = None,", field.name, self._import("typing", "Optional", "OptionalType"), self.python_type(field), ) line(") -> None: ...") # Standard message methods line("@classmethod") line("def FromString(cls, s: bytes) -> {}: ...", qualified_name) line("def MergeFrom(self, other_msg: {}) -> None: ...", message_class) line("def CopyFrom(self, other_msg: {}) -> None: ...", message_class) line("") def write_services(self, services: d.ServiceDescriptorProto) -> None: line = self._write_line for service in services: # The service definition interface line( "class {}({}, metaclass={}):", service.name, self._import("google.protobuf.service", "Service"), self._import("abc", "ABCMeta"), ) with self._indent(): for method in service.method: line("@{}", self._import("abc", "abstractmethod")) line("def {}(self,", method.name) with self._indent(): line( "rpc_controller: {},", self._import("google.protobuf.service", "RpcController"), ) line("request: {},", self._import_message(method.input_type)) line( "done: {}[{}[[{}], None]],", self._import("typing", "Optional"), self._import("typing", "Callable"), self._import_message(method.output_type), ) line( ") -> {}[{}]: ...", self._import("concurrent.futures", "Future"), self._import_message(method.output_type), ) # The stub client line("class {}({}):", service.name + "_Stub", service.name) with self._indent(): line( "def __init__(self, rpc_channel: {}) -> None: ...", self._import("google.protobuf.service", "RpcChannel"), ) def python_type(self, field: d.FieldDescriptorProto) -> str: mapping: dict[int, Callable[[], str]] = { d.FieldDescriptorProto.TYPE_DOUBLE: lambda: "float", d.FieldDescriptorProto.TYPE_FLOAT: lambda: "float", d.FieldDescriptorProto.TYPE_INT64: lambda: "int", d.FieldDescriptorProto.TYPE_UINT64: lambda: "int", d.FieldDescriptorProto.TYPE_FIXED64: lambda: "int", d.FieldDescriptorProto.TYPE_SFIXED64: lambda: "int", d.FieldDescriptorProto.TYPE_SINT64: lambda: "int", d.FieldDescriptorProto.TYPE_INT32: lambda: "int", d.FieldDescriptorProto.TYPE_UINT32: lambda: "int", d.FieldDescriptorProto.TYPE_FIXED32: lambda: "int", d.FieldDescriptorProto.TYPE_SFIXED32: lambda: "int", d.FieldDescriptorProto.TYPE_SINT32: lambda: "int", d.FieldDescriptorProto.TYPE_BOOL: lambda: "bool", d.FieldDescriptorProto.TYPE_STRING: lambda: "str", d.FieldDescriptorProto.TYPE_BYTES: lambda: "bytes", d.FieldDescriptorProto.TYPE_ENUM: lambda: self._import_message( field.type_name ), d.FieldDescriptorProto.TYPE_MESSAGE: lambda: self._import_message( field.type_name ), d.FieldDescriptorProto.TYPE_GROUP: lambda: self._import_message( field.type_name ), } assert field.type in mapping, "Unrecognized type: " + field.type return mapping[field.type]() def write(self) -> str: imports = [] for pkg, items in self.imports.items(): if pkg.startswith("google."): imports.append(f"from {pkg} import ( # type: ignore") else: imports.append(f"from {pkg} import (") for item in sorted(items): imports.append(f" {item},") # noqa: PERF401 imports.append(")\n") return "\n".join(imports + self.lines) def is_scalar(fd: d.FileDescriptorProto) -> bool: return not ( fd.type == d.FieldDescriptorProto.TYPE_MESSAGE # noqa: PLR1714 or fd.type == d.FieldDescriptorProto.TYPE_GROUP ) def generate_mypy_stubs( descriptors: Descriptors, response: plugin.CodeGeneratorResponse ) -> None: for name, fd in descriptors.to_generate.items(): pkg_writer = PkgWriter(fd, descriptors) pkg_writer.write_enums(fd.enum_type) pkg_writer.write_messages(fd.message_type, "") pkg_writer.write_services(fd.service) assert name == fd.name assert fd.name.endswith(".proto") output = response.file.add() output.name = fd.name[:-6].replace("-", "_") + "_pb2.pyi" output.content = HEADER + pkg_writer.write() print("Writing mypy to", output.name, file=sys.stderr) def main() -> None: # Read request message from stdin data = sys.stdin.buffer.read() # Parse request request = plugin.CodeGeneratorRequest() request.ParseFromString(data) # Create response response = plugin.CodeGeneratorResponse() # Generate mypy generate_mypy_stubs(Descriptors(request), response) # Serialise response message output = response.SerializeToString() # Write to stdout sys.stdout.buffer.write(output) if __name__ == "__main__": main()