Skip to content

Commit 237a929

Browse files
committed
feat: built-in REST catch-up when gateway join is rejected
Gateway now accepts an optional client param. When a join is rejected for too many missed events and a client is provided, the gateway automatically paginates through missed events via the REST API, dispatches them through the same registered handlers, then reconnects. Without a client, TooManyMissedEventsError is raised as before.
1 parent 142c672 commit 237a929

1 file changed

Lines changed: 59 additions & 23 deletions

File tree

stackcoin/stackcoin/gateway.py

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
"""StackCoin WebSocket Gateway client."""
22

3+
from __future__ import annotations
4+
35
import asyncio
46
import json
57
from collections.abc import Awaitable, Callable
6-
from typing import Any
8+
from typing import TYPE_CHECKING, Any
79

810
from .client import AnyEvent
911
from .models import Event
1012

13+
if TYPE_CHECKING:
14+
from .client import Client
15+
1116
EventHandler = Callable[[AnyEvent], Awaitable[None]]
1217

1318

@@ -16,25 +21,33 @@ class Gateway:
1621
1722
Usage::
1823
19-
gateway = stackcoin.Gateway(token="...")
24+
async with stackcoin.Client(token="...") as client:
25+
gateway = stackcoin.Gateway(token="...", client=client)
2026
21-
@gateway.on("request.accepted")
22-
async def handle_accepted(event: stackcoin.RequestAcceptedEvent):
23-
print(event.data.request_id)
27+
@gateway.on("request.accepted")
28+
async def handle_accepted(event: stackcoin.RequestAcceptedEvent):
29+
print(event.data.request_id)
2430
25-
await gateway.connect()
31+
await gateway.connect()
32+
33+
If a ``client`` is provided and the bot has been offline too long (>100
34+
missed events), the gateway automatically catches up via the REST API
35+
before reconnecting. Without a ``client``, the error is raised to the
36+
caller.
2637
"""
2738

2839
def __init__(
2940
self,
3041
token: str,
3142
*,
3243
ws_url: str = "wss://stackcoin.world/ws",
44+
client: Client | None = None,
3345
last_event_id: int = 0,
3446
on_event_id: Callable[[int], None] | None = None,
3547
):
3648
self._ws_url = ws_url.rstrip("/")
3749
self._token = token
50+
self._client = client
3851
self._handlers: dict[str, list[EventHandler]] = {}
3952
self._last_event_id = last_event_id
4053
self._on_event_id = on_event_id # callback to persist cursor position
@@ -62,7 +75,13 @@ def register_handler(self, event_type: str, handler: EventHandler) -> None:
6275
self._handlers[event_type].append(handler)
6376

6477
async def connect(self) -> None:
65-
"""Connect and listen for events. Reconnects automatically on failure."""
78+
"""Connect and listen for events. Reconnects automatically on failure.
79+
80+
If the gateway rejects a join because too many events were missed
81+
and a ``client`` was provided, the gateway catches up via the REST
82+
API and reconnects. Without a ``client``, raises
83+
:class:`TooManyMissedEventsError`.
84+
"""
6685
import websockets
6786

6887
from .errors import TooManyMissedEventsError
@@ -86,11 +105,42 @@ async def connect(self) -> None:
86105
heartbeat_task.cancel()
87106

88107
except TooManyMissedEventsError:
89-
raise # Don't retry — caller must catch up via REST
108+
if self._client is None:
109+
raise # No client — caller must handle catch-up
110+
await self._catch_up_via_rest()
111+
# Loop back to reconnect with updated cursor
90112
except Exception:
91113
if self._running:
92114
await asyncio.sleep(5)
93115

116+
async def _catch_up_via_rest(self) -> None:
117+
"""Paginate through missed events via the REST API.
118+
119+
Dispatches each event through the registered handlers, exactly
120+
as if it arrived over the WebSocket.
121+
"""
122+
assert self._client is not None
123+
events = await self._client.get_events(since_id=self._last_event_id)
124+
for event in events:
125+
await self._dispatch_event(event)
126+
127+
async def _dispatch_event(self, typed_event: AnyEvent) -> None:
128+
"""Dispatch a typed event to registered handlers and update the cursor."""
129+
if typed_event.id > self._last_event_id:
130+
self._last_event_id = typed_event.id
131+
132+
for handler in self._handlers.get(typed_event.type, []):
133+
try:
134+
await handler(typed_event)
135+
except Exception:
136+
pass
137+
138+
if typed_event.id > 0 and self._on_event_id:
139+
try:
140+
self._on_event_id(typed_event.id)
141+
except Exception:
142+
pass
143+
94144
async def _join_channel(self, ws: Any) -> None:
95145
"""Join the user:self channel with event replay."""
96146
from .errors import TooManyMissedEventsError
@@ -141,21 +191,7 @@ async def _handle_message(self, msg: list[Any]) -> None:
141191
if event_name == "event":
142192
# Parse via discriminated union RootModel, then unwrap
143193
typed_event = Event.model_validate(payload).root
144-
145-
if typed_event.id > self._last_event_id:
146-
self._last_event_id = typed_event.id
147-
148-
for handler in self._handlers.get(typed_event.type, []):
149-
try:
150-
await handler(typed_event)
151-
except Exception:
152-
pass
153-
154-
if typed_event.id > 0 and self._on_event_id:
155-
try:
156-
self._on_event_id(typed_event.id)
157-
except Exception:
158-
pass
194+
await self._dispatch_event(typed_event)
159195

160196
def stop(self) -> None:
161197
"""Signal the gateway to stop."""

0 commit comments

Comments
 (0)