diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index d651658..f59439f 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -89,6 +89,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: super().__init__( service=service, @@ -156,6 +157,7 @@ def __init__( }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property @@ -354,6 +356,7 @@ def __init__( service: ConformanceServiceSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: super().__init__( endpoints={ @@ -420,6 +423,7 @@ def __init__( }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index 02e3750..a58afbc 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -56,6 +56,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: super().__init__( service=service, @@ -93,6 +94,7 @@ def __init__( }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property @@ -190,6 +192,7 @@ def __init__( service: ElizaServiceSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: super().__init__( endpoints={ @@ -226,6 +229,7 @@ def __init__( }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 12a6b38..07122f0 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -68,7 +68,7 @@ class {{.Name}}(Protocol):{{- range .Methods }} {{ end }} class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): - def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None: + def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[str] | None = None) -> None: super().__init__( service=service, endpoints=lambda svc: { {{- range .Methods }} @@ -85,6 +85,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property @@ -128,7 +129,7 @@ class {{.Name}}Sync(Protocol):{{- range .Methods }} class {{.Name}}WSGIApplication(ConnectWSGIApplication): - def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None) -> None: + def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[str] | None = None) -> None: super().__init__( endpoints={ {{- range .Methods }} "/{{.ServiceName}}/{{.Name}}": EndpointSync.{{.EndpointType}}( @@ -144,6 +145,7 @@ class {{.Name}}WSGIApplication(ConnectWSGIApplication): }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property diff --git a/src/connectrpc/_compression.py b/src/connectrpc/_compression.py index 7a5fe6d..ce60fc2 100644 --- a/src/connectrpc/_compression.py +++ b/src/connectrpc/_compression.py @@ -116,9 +116,12 @@ def get_accept_encoding() -> str: ) -def negotiate_compression(accept_encoding: str) -> Compression: +def negotiate_compression( + accept_encoding: str, compressions: dict[str, Compression] | None +) -> Compression: + compressions = compressions if compressions is not None else _compressions for accept in accept_encoding.split(","): - compression = _compressions.get(accept.strip()) + compression = compressions.get(accept.strip()) if compression: return compression return _identity diff --git a/src/connectrpc/_protocol.py b/src/connectrpc/_protocol.py index cdb19dd..3995ef0 100644 --- a/src/connectrpc/_protocol.py +++ b/src/connectrpc/_protocol.py @@ -207,7 +207,7 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st ... def negotiate_stream_compression( - self, headers: Headers + self, headers: Headers, compressions: dict[str, Compression] | None ) -> tuple[Compression | None, Compression]: """Negotiates request and response compression based on headers.""" ... diff --git a/src/connectrpc/_protocol_connect.py b/src/connectrpc/_protocol_connect.py index 6f89a40..0922ff2 100644 --- a/src/connectrpc/_protocol_connect.py +++ b/src/connectrpc/_protocol_connect.py @@ -123,7 +123,7 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st return codec_name_from_content_type(content_type, stream=stream) def negotiate_stream_compression( - self, headers: Headers + self, headers: Headers, compressions: dict[str, Compression] | None ) -> tuple[Compression, Compression]: req_compression_name = headers.get( CONNECT_STREAMING_HEADER_COMPRESSION, "identity" @@ -132,7 +132,7 @@ def negotiate_stream_compression( accept_compression = headers.get( CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, "" ) - resp_compression = negotiate_compression(accept_compression) + resp_compression = negotiate_compression(accept_compression, compressions) return req_compression, resp_compression diff --git a/src/connectrpc/_protocol_grpc.py b/src/connectrpc/_protocol_grpc.py index 2ae8f7f..8c0216f 100644 --- a/src/connectrpc/_protocol_grpc.py +++ b/src/connectrpc/_protocol_grpc.py @@ -82,12 +82,12 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st return "proto" def negotiate_stream_compression( - self, headers: Headers + self, headers: Headers, compressions: dict[str, Compression] | None ) -> tuple[Compression | None, Compression]: req_compression_name = headers.get(GRPC_HEADER_COMPRESSION, "identity") req_compression = get_compression(req_compression_name) accept_compression = headers.get(GRPC_HEADER_ACCEPT_COMPRESSION, "") - resp_compression = negotiate_compression(accept_compression) + resp_compression = negotiate_compression(accept_compression, compressions) return req_compression, resp_compression diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index 490cf48..52a171f 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -86,14 +86,36 @@ def __init__( endpoints: Callable[[_SVC], Mapping[str, Endpoint]], interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: - """Initialize the ASGI application.""" + """Initialize the ASGI application. + + Args: + service: The service instance or async generator that yields the service during lifespan. + endpoints: A mapping of URL paths to endpoints resolved from service. + interceptors: A sequence of interceptors to apply to the endpoints. + read_max_bytes: Maximum size of request messages. + compressions: Supported compression algorithms. If unset, + defaults to gzip along with zstd and br if available. + If set to empty, disables compression. + """ super().__init__() self._service = service self._endpoints = endpoints self._interceptors = interceptors self._resolved_endpoints = None self._read_max_bytes = read_max_bytes + if compressions is not None: + compressions_dict: dict[str, _compression.Compression] = {} + for name in compressions: + comp = _compression.get_compression(name) + if not comp: + msg = f"unknown compression: '{name}': supported encodings are {', '.join(_compression.get_available_compressions())}" + raise ValueError(msg) + compressions_dict[name] = comp + self._compressions = compressions_dict + else: + self._compressions = None async def __call__( self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @@ -225,7 +247,9 @@ async def _handle_unary_connect( ctx: RequestContext, ) -> None: accept_encoding = headers.get("accept-encoding", "") - compression = _compression.negotiate_compression(accept_encoding) + compression = _compression.negotiate_compression( + accept_encoding, self._compressions + ) if http_method == "GET": request = await self._read_get_request(endpoint, codec, query_params) @@ -347,7 +371,7 @@ async def _handle_stream( ctx: _server_shared.RequestContext, ) -> None: req_compression, resp_compression = protocol.negotiate_stream_compression( - headers + headers, self._compressions ) writer = protocol.create_envelope_writer(codec, resp_compression) diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index d45e95c..33184fb 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -165,8 +165,18 @@ def __init__( endpoints: Mapping[str, EndpointSync], interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: - """Initialize the WSGI application.""" + """Initialize the WSGI application. + + Args: + endpoints: A mapping of URL paths to service endpoints. + interceptors: A sequence of interceptors to apply to the endpoints. + read_max_bytes: Maximum size of request messages. + compressions: Supported compression algorithms. If unset, + defaults to gzip along with zstd and br if available. + If set to empty, disables compression. + """ super().__init__() if interceptors: interceptors = [ @@ -181,6 +191,17 @@ def __init__( } self._endpoints = endpoints self._read_max_bytes = read_max_bytes + if compressions is not None: + compressions_dict: dict[str, _compression.Compression] = {} + for name in compressions: + comp = _compression.get_compression(name) + if not comp: + msg = f"unknown compression: '{name}': supported encodings are {', '.join(_compression.get_available_compressions())}" + raise ValueError(msg) + compressions_dict[name] = comp + self._compressions = compressions_dict + else: + self._compressions = None def __call__( self, environ: WSGIEnvironment, start_response: StartResponse @@ -253,7 +274,9 @@ def _handle_unary( # Handle compression if accepted accept_encoding = headers.get("accept-encoding", "identity") - compression = _compression.negotiate_compression(accept_encoding) + compression = _compression.negotiate_compression( + accept_encoding, self._compressions + ) res_bytes = compression.compress(res_bytes) response_headers = prepare_response_headers(base_headers, compression.name()) @@ -403,7 +426,7 @@ def _handle_stream( ctx: RequestContext[_REQ, _RES], ) -> Iterable[bytes]: req_compression, resp_compression = protocol.negotiate_stream_compression( - headers + headers, self._compressions ) codec_name = protocol.codec_name_from_content_type( diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index d83c143..38e8718 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -71,6 +71,7 @@ def __init__( *, interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: super().__init__( service=service, @@ -138,6 +139,7 @@ def __init__( }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property @@ -308,6 +310,7 @@ def __init__( service: HaberdasherSync, interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, + compressions: Iterable[str] | None = None, ) -> None: super().__init__( endpoints={ @@ -374,6 +377,7 @@ def __init__( }, interceptors=interceptors, read_max_bytes=read_max_bytes, + compressions=compressions, ) @property diff --git a/test/test_compression.py b/test/test_compression.py new file mode 100644 index 0000000..3cb3a95 --- /dev/null +++ b/test/test_compression.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import pytest +from pyqwest import Client, SyncClient +from pyqwest.testing import ASGITransport, WSGITransport + +from connectrpc.client import ResponseMetadata + +from .haberdasher_connect import ( + Haberdasher, + HaberdasherASGIApplication, + HaberdasherClient, + HaberdasherClientSync, + HaberdasherSync, + HaberdasherWSGIApplication, +) +from .haberdasher_pb2 import Hat, Size + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("compressions", "encoding"), + [ + pytest.param((), "identity", id="none"), + pytest.param(("gzip",), "gzip", id="gzip"), + pytest.param(("zstd",), "zstd", id="zstd"), + pytest.param(("br",), "br", id="br"), + pytest.param(("gzip", "br", "zstd"), "zstd", id="all"), + ], +) +async def test_server_compressions_async( + compressions: tuple[str], encoding: str +) -> None: + class SimpleHaberdasher(Haberdasher): + async def make_hat(self, request, ctx): + return Hat(size=10, color="blue") + + app = HaberdasherASGIApplication(SimpleHaberdasher(), compressions=compressions) + with ResponseMetadata() as meta: + client = HaberdasherClient( + "http://localhost", + http_client=Client(ASGITransport(app)), + accept_compression=["zstd", "gzip", "br"], + ) + res = await client.make_hat(Size(inches=10)) + assert res.size == 10 + assert res.color == "blue" + assert meta.headers().get("content-encoding") == encoding + + +@pytest.mark.parametrize( + ("compressions", "encoding"), + [ + pytest.param((), "identity", id="none"), + pytest.param(("gzip",), "gzip", id="gzip"), + pytest.param(("zstd",), "zstd", id="zstd"), + pytest.param(("br",), "br", id="br"), + pytest.param(("gzip", "br", "zstd"), "zstd", id="all"), + ], +) +def test_server_compressions_sync(compressions: tuple[str], encoding: str) -> None: + class SimpleHaberdasher(HaberdasherSync): + def make_hat(self, request, ctx): + return Hat(size=10, color="blue") + + app = HaberdasherWSGIApplication(SimpleHaberdasher(), compressions=compressions) + client = HaberdasherClientSync( + "http://localhost", + http_client=SyncClient(WSGITransport(app)), + accept_compression=["zstd", "gzip", "br"], + ) + with ResponseMetadata() as meta: + res = client.make_hat(Size(inches=10)) + assert res.size == 10 + assert res.color == "blue" + assert meta.headers().get("content-encoding") == encoding + + +def test_server_unsupported_compression_async() -> None: + class SimpleHaberdasher(HaberdasherSync): + def make_hat(self, request, ctx): + return Hat(size=10, color="blue") + + with pytest.raises(ValueError, match="unknown compression"): + HaberdasherWSGIApplication(SimpleHaberdasher(), compressions=("unknown",)) + + +def test_server_unsupported_compression_sync() -> None: + class SimpleHaberdasher(Haberdasher): + async def make_hat(self, request, ctx): + return Hat(size=10, color="blue") + + with pytest.raises(ValueError, match="unknown compression"): + HaberdasherASGIApplication(SimpleHaberdasher(), compressions=("unknown",))