11"""StackCoin WebSocket Gateway client."""
22
3+ from __future__ import annotations
4+
35import asyncio
46import json
57from collections .abc import Awaitable , Callable
6- from typing import Any
8+ from typing import TYPE_CHECKING , Any
79
810from .client import AnyEvent
911from .models import Event
1012
13+ if TYPE_CHECKING :
14+ from .client import Client
15+
1116EventHandler = Callable [[AnyEvent ], Awaitable [None ]]
1217
1318
@@ -16,25 +21,33 @@ class Gateway:
1621
1722 Usage::
1823
19- gateway = stackcoin.Gateway(token="...")
24+ async with stackcoin.Client(token="...") as client:
25+ gateway = stackcoin.Gateway(token="...", client=client)
2026
21- @gateway.on("request.accepted")
22- async def handle_accepted(event: stackcoin.RequestAcceptedEvent):
23- print(event.data.request_id)
27+ @gateway.on("request.accepted")
28+ async def handle_accepted(event: stackcoin.RequestAcceptedEvent):
29+ print(event.data.request_id)
2430
25- await gateway.connect()
31+ await gateway.connect()
32+
33+ If a ``client`` is provided and the bot has been offline too long (>100
34+ missed events), the gateway automatically catches up via the REST API
35+ before reconnecting. Without a ``client``, the error is raised to the
36+ caller.
2637 """
2738
2839 def __init__ (
2940 self ,
3041 token : str ,
3142 * ,
3243 ws_url : str = "wss://stackcoin.world/ws" ,
44+ client : Client | None = None ,
3345 last_event_id : int = 0 ,
3446 on_event_id : Callable [[int ], None ] | None = None ,
3547 ):
3648 self ._ws_url = ws_url .rstrip ("/" )
3749 self ._token = token
50+ self ._client = client
3851 self ._handlers : dict [str , list [EventHandler ]] = {}
3952 self ._last_event_id = last_event_id
4053 self ._on_event_id = on_event_id # callback to persist cursor position
@@ -62,7 +75,13 @@ def register_handler(self, event_type: str, handler: EventHandler) -> None:
6275 self ._handlers [event_type ].append (handler )
6376
6477 async def connect (self ) -> None :
65- """Connect and listen for events. Reconnects automatically on failure."""
78+ """Connect and listen for events. Reconnects automatically on failure.
79+
80+ If the gateway rejects a join because too many events were missed
81+ and a ``client`` was provided, the gateway catches up via the REST
82+ API and reconnects. Without a ``client``, raises
83+ :class:`TooManyMissedEventsError`.
84+ """
6685 import websockets
6786
6887 from .errors import TooManyMissedEventsError
@@ -86,11 +105,42 @@ async def connect(self) -> None:
86105 heartbeat_task .cancel ()
87106
88107 except TooManyMissedEventsError :
89- raise # Don't retry — caller must catch up via REST
108+ if self ._client is None :
109+ raise # No client — caller must handle catch-up
110+ await self ._catch_up_via_rest ()
111+ # Loop back to reconnect with updated cursor
90112 except Exception :
91113 if self ._running :
92114 await asyncio .sleep (5 )
93115
116+ async def _catch_up_via_rest (self ) -> None :
117+ """Paginate through missed events via the REST API.
118+
119+ Dispatches each event through the registered handlers, exactly
120+ as if it arrived over the WebSocket.
121+ """
122+ assert self ._client is not None
123+ events = await self ._client .get_events (since_id = self ._last_event_id )
124+ for event in events :
125+ await self ._dispatch_event (event )
126+
127+ async def _dispatch_event (self , typed_event : AnyEvent ) -> None :
128+ """Dispatch a typed event to registered handlers and update the cursor."""
129+ if typed_event .id > self ._last_event_id :
130+ self ._last_event_id = typed_event .id
131+
132+ for handler in self ._handlers .get (typed_event .type , []):
133+ try :
134+ await handler (typed_event )
135+ except Exception :
136+ pass
137+
138+ if typed_event .id > 0 and self ._on_event_id :
139+ try :
140+ self ._on_event_id (typed_event .id )
141+ except Exception :
142+ pass
143+
94144 async def _join_channel (self , ws : Any ) -> None :
95145 """Join the user:self channel with event replay."""
96146 from .errors import TooManyMissedEventsError
@@ -141,21 +191,7 @@ async def _handle_message(self, msg: list[Any]) -> None:
141191 if event_name == "event" :
142192 # Parse via discriminated union RootModel, then unwrap
143193 typed_event = Event .model_validate (payload ).root
144-
145- if typed_event .id > self ._last_event_id :
146- self ._last_event_id = typed_event .id
147-
148- for handler in self ._handlers .get (typed_event .type , []):
149- try :
150- await handler (typed_event )
151- except Exception :
152- pass
153-
154- if typed_event .id > 0 and self ._on_event_id :
155- try :
156- self ._on_event_id (typed_event .id )
157- except Exception :
158- pass
194+ await self ._dispatch_event (typed_event )
159195
160196 def stop (self ) -> None :
161197 """Signal the gateway to stop."""
0 commit comments