Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
*,
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
super().__init__(
service=service,
Expand Down Expand Up @@ -156,6 +157,7 @@ def __init__(
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down Expand Up @@ -354,6 +356,7 @@ def __init__(
service: ConformanceServiceSync,
interceptors: Iterable[InterceptorSync] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
super().__init__(
endpoints={
Expand Down Expand Up @@ -420,6 +423,7 @@ def __init__(
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down
4 changes: 4 additions & 0 deletions example/example/eliza_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
*,
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
super().__init__(
service=service,
Expand Down Expand Up @@ -93,6 +94,7 @@ def __init__(
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down Expand Up @@ -190,6 +192,7 @@ def __init__(
service: ElizaServiceSync,
interceptors: Iterable[InterceptorSync] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
super().__init__(
endpoints={
Expand Down Expand Up @@ -226,6 +229,7 @@ def __init__(
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down
6 changes: 4 additions & 2 deletions protoc-gen-connect-python/generator/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class {{.Name}}(Protocol):{{- range .Methods }}
{{ end }}

class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None) -> None:
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[str] | None = None) -> None:
super().__init__(
service=service,
endpoints=lambda svc: { {{- range .Methods }}
Expand All @@ -85,6 +85,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down Expand Up @@ -128,7 +129,7 @@ class {{.Name}}Sync(Protocol):{{- range .Methods }}


class {{.Name}}WSGIApplication(ConnectWSGIApplication):
def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None) -> None:
def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[str] | None = None) -> None:
super().__init__(
endpoints={ {{- range .Methods }}
"/{{.ServiceName}}/{{.Name}}": EndpointSync.{{.EndpointType}}(
Expand All @@ -144,6 +145,7 @@ class {{.Name}}WSGIApplication(ConnectWSGIApplication):
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down
7 changes: 5 additions & 2 deletions src/connectrpc/_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,12 @@ def get_accept_encoding() -> str:
)


def negotiate_compression(accept_encoding: str) -> Compression:
def negotiate_compression(
accept_encoding: str, compressions: dict[str, Compression] | None
) -> Compression:
compressions = compressions if compressions is not None else _compressions
for accept in accept_encoding.split(","):
compression = _compressions.get(accept.strip())
compression = compressions.get(accept.strip())
if compression:
return compression
return _identity
2 changes: 1 addition & 1 deletion src/connectrpc/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st
...

def negotiate_stream_compression(
self, headers: Headers
self, headers: Headers, compressions: dict[str, Compression] | None
) -> tuple[Compression | None, Compression]:
"""Negotiates request and response compression based on headers."""
...
Expand Down
4 changes: 2 additions & 2 deletions src/connectrpc/_protocol_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st
return codec_name_from_content_type(content_type, stream=stream)

def negotiate_stream_compression(
self, headers: Headers
self, headers: Headers, compressions: dict[str, Compression] | None
) -> tuple[Compression, Compression]:
req_compression_name = headers.get(
CONNECT_STREAMING_HEADER_COMPRESSION, "identity"
Expand All @@ -132,7 +132,7 @@ def negotiate_stream_compression(
accept_compression = headers.get(
CONNECT_STREAMING_HEADER_ACCEPT_COMPRESSION, ""
)
resp_compression = negotiate_compression(accept_compression)
resp_compression = negotiate_compression(accept_compression, compressions)
return req_compression, resp_compression


Expand Down
4 changes: 2 additions & 2 deletions src/connectrpc/_protocol_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def codec_name_from_content_type(self, content_type: str, *, stream: bool) -> st
return "proto"

def negotiate_stream_compression(
self, headers: Headers
self, headers: Headers, compressions: dict[str, Compression] | None
) -> tuple[Compression | None, Compression]:
req_compression_name = headers.get(GRPC_HEADER_COMPRESSION, "identity")
req_compression = get_compression(req_compression_name)
accept_compression = headers.get(GRPC_HEADER_ACCEPT_COMPRESSION, "")
resp_compression = negotiate_compression(accept_compression)
resp_compression = negotiate_compression(accept_compression, compressions)
return req_compression, resp_compression


Expand Down
30 changes: 27 additions & 3 deletions src/connectrpc/_server_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,36 @@ def __init__(
endpoints: Callable[[_SVC], Mapping[str, Endpoint]],
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
"""Initialize the ASGI application."""
"""Initialize the ASGI application.

Args:
service: The service instance or async generator that yields the service during lifespan.
endpoints: A mapping of URL paths to endpoints resolved from service.
interceptors: A sequence of interceptors to apply to the endpoints.
read_max_bytes: Maximum size of request messages.
compressions: Supported compression algorithms. If unset,
defaults to gzip along with zstd and br if available.
If set to empty, disables compression.
"""
super().__init__()
self._service = service
self._endpoints = endpoints
self._interceptors = interceptors
self._resolved_endpoints = None
self._read_max_bytes = read_max_bytes
if compressions is not None:
compressions_dict: dict[str, _compression.Compression] = {}
for name in compressions:
comp = _compression.get_compression(name)
if not comp:
msg = f"unknown compression: '{name}': supported encodings are {', '.join(_compression.get_available_compressions())}"
raise ValueError(msg)
compressions_dict[name] = comp
self._compressions = compressions_dict
else:
self._compressions = None

async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
Expand Down Expand Up @@ -225,7 +247,9 @@ async def _handle_unary_connect(
ctx: RequestContext,
) -> None:
accept_encoding = headers.get("accept-encoding", "")
compression = _compression.negotiate_compression(accept_encoding)
compression = _compression.negotiate_compression(
accept_encoding, self._compressions
)

if http_method == "GET":
request = await self._read_get_request(endpoint, codec, query_params)
Expand Down Expand Up @@ -347,7 +371,7 @@ async def _handle_stream(
ctx: _server_shared.RequestContext,
) -> None:
req_compression, resp_compression = protocol.negotiate_stream_compression(
headers
headers, self._compressions
)

writer = protocol.create_envelope_writer(codec, resp_compression)
Expand Down
29 changes: 26 additions & 3 deletions src/connectrpc/_server_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,18 @@ def __init__(
endpoints: Mapping[str, EndpointSync],
interceptors: Iterable[InterceptorSync] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
"""Initialize the WSGI application."""
"""Initialize the WSGI application.

Args:
endpoints: A mapping of URL paths to service endpoints.
interceptors: A sequence of interceptors to apply to the endpoints.
read_max_bytes: Maximum size of request messages.
compressions: Supported compression algorithms. If unset,
defaults to gzip along with zstd and br if available.
If set to empty, disables compression.
"""
super().__init__()
if interceptors:
interceptors = [
Expand All @@ -181,6 +191,17 @@ def __init__(
}
self._endpoints = endpoints
self._read_max_bytes = read_max_bytes
if compressions is not None:
compressions_dict: dict[str, _compression.Compression] = {}
for name in compressions:
comp = _compression.get_compression(name)
if not comp:
msg = f"unknown compression: '{name}': supported encodings are {', '.join(_compression.get_available_compressions())}"
raise ValueError(msg)
compressions_dict[name] = comp
self._compressions = compressions_dict
else:
self._compressions = None

def __call__(
self, environ: WSGIEnvironment, start_response: StartResponse
Expand Down Expand Up @@ -253,7 +274,9 @@ def _handle_unary(

# Handle compression if accepted
accept_encoding = headers.get("accept-encoding", "identity")
compression = _compression.negotiate_compression(accept_encoding)
compression = _compression.negotiate_compression(
accept_encoding, self._compressions
)
res_bytes = compression.compress(res_bytes)
response_headers = prepare_response_headers(base_headers, compression.name())

Expand Down Expand Up @@ -403,7 +426,7 @@ def _handle_stream(
ctx: RequestContext[_REQ, _RES],
) -> Iterable[bytes]:
req_compression, resp_compression = protocol.negotiate_stream_compression(
headers
headers, self._compressions
)

codec_name = protocol.codec_name_from_content_type(
Expand Down
4 changes: 4 additions & 0 deletions test/haberdasher_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
*,
interceptors: Iterable[Interceptor] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
super().__init__(
service=service,
Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down Expand Up @@ -308,6 +310,7 @@ def __init__(
service: HaberdasherSync,
interceptors: Iterable[InterceptorSync] = (),
read_max_bytes: int | None = None,
compressions: Iterable[str] | None = None,
) -> None:
super().__init__(
endpoints={
Expand Down Expand Up @@ -374,6 +377,7 @@ def __init__(
},
interceptors=interceptors,
read_max_bytes=read_max_bytes,
compressions=compressions,
)

@property
Expand Down
94 changes: 94 additions & 0 deletions test/test_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from __future__ import annotations

import pytest
from pyqwest import Client, SyncClient
from pyqwest.testing import ASGITransport, WSGITransport

from connectrpc.client import ResponseMetadata

from .haberdasher_connect import (
Haberdasher,
HaberdasherASGIApplication,
HaberdasherClient,
HaberdasherClientSync,
HaberdasherSync,
HaberdasherWSGIApplication,
)
from .haberdasher_pb2 import Hat, Size


@pytest.mark.asyncio
@pytest.mark.parametrize(
("compressions", "encoding"),
[
pytest.param((), "identity", id="none"),
pytest.param(("gzip",), "gzip", id="gzip"),
pytest.param(("zstd",), "zstd", id="zstd"),
pytest.param(("br",), "br", id="br"),
pytest.param(("gzip", "br", "zstd"), "zstd", id="all"),
],
)
async def test_server_compressions_async(
compressions: tuple[str], encoding: str
) -> None:
class SimpleHaberdasher(Haberdasher):
async def make_hat(self, request, ctx):
return Hat(size=10, color="blue")

app = HaberdasherASGIApplication(SimpleHaberdasher(), compressions=compressions)
with ResponseMetadata() as meta:
client = HaberdasherClient(
"http://localhost",
http_client=Client(ASGITransport(app)),
accept_compression=["zstd", "gzip", "br"],
)
res = await client.make_hat(Size(inches=10))
assert res.size == 10
assert res.color == "blue"
assert meta.headers().get("content-encoding") == encoding


@pytest.mark.parametrize(
("compressions", "encoding"),
[
pytest.param((), "identity", id="none"),
pytest.param(("gzip",), "gzip", id="gzip"),
pytest.param(("zstd",), "zstd", id="zstd"),
pytest.param(("br",), "br", id="br"),
pytest.param(("gzip", "br", "zstd"), "zstd", id="all"),
],
)
def test_server_compressions_sync(compressions: tuple[str], encoding: str) -> None:
class SimpleHaberdasher(HaberdasherSync):
def make_hat(self, request, ctx):
return Hat(size=10, color="blue")

app = HaberdasherWSGIApplication(SimpleHaberdasher(), compressions=compressions)
client = HaberdasherClientSync(
"http://localhost",
http_client=SyncClient(WSGITransport(app)),
accept_compression=["zstd", "gzip", "br"],
)
with ResponseMetadata() as meta:
res = client.make_hat(Size(inches=10))
assert res.size == 10
assert res.color == "blue"
assert meta.headers().get("content-encoding") == encoding


def test_server_unsupported_compression_async() -> None:
class SimpleHaberdasher(HaberdasherSync):
def make_hat(self, request, ctx):
return Hat(size=10, color="blue")

with pytest.raises(ValueError, match="unknown compression"):
HaberdasherWSGIApplication(SimpleHaberdasher(), compressions=("unknown",))


def test_server_unsupported_compression_sync() -> None:
class SimpleHaberdasher(Haberdasher):
async def make_hat(self, request, ctx):
return Hat(size=10, color="blue")

with pytest.raises(ValueError, match="unknown compression"):
HaberdasherASGIApplication(SimpleHaberdasher(), compressions=("unknown",))