Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 945caa4

Browse files
authored
fix(bug): properly handle RST_STREAM responses by opening new stream (#48)
* fix(bug): properly handle RST_STREAM responses by opening new stream
* chore(format): ruff
* fix: correct license header & move file
* chore(format): f-strings
* chore(format): ruff
* fix: address pr comments

---------

Signed-off-by: Casper Nielsen <casper@diagrid.io>
1 parent c67b696 commit 945caa4

4 files changed

Lines changed: 223 additions & 103 deletions

File tree

durabletask/internal/shared.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ def get_default_host_address() -> str:
5050
return "localhost:4001"
5151

5252

53+
DEFAULT_GRPC_KEEPALIVE_OPTIONS: tuple[tuple[str, int], ...] = (
54+
("grpc.keepalive_time_ms", 30_000),
55+
("grpc.keepalive_timeout_ms", 10_000),
56+
("grpc.http2.max_pings_without_data", 0),
57+
("grpc.keepalive_permit_without_calls", 1),
58+
)
59+
60+
5361
def get_grpc_channel(
5462
host_address: Optional[str],
5563
secure_channel: bool = False,
@@ -81,10 +89,16 @@ def get_grpc_channel(
8189
host_address = host_address[len(protocol) :]
8290
break
8391

92+
merged = dict(DEFAULT_GRPC_KEEPALIVE_OPTIONS)
93+
if options:
94+
merged.update(dict(options))
95+
merged_options = list(merged.items())
8496
if secure_channel:
85-
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options)
97+
channel = grpc.secure_channel(
98+
host_address, grpc.ssl_channel_credentials(), options=merged_options
99+
)
86100
else:
87-
channel = grpc.insecure_channel(host_address, options=options)
101+
channel = grpc.insecure_channel(host_address, options=merged_options)
88102

89103
# Apply interceptors ONLY if they exist
90104
if interceptors:

durabletask/worker.py

Lines changed: 83 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ def __init__(
307307
self._channel_options = channel_options
308308
self._stop_timeout = stop_timeout
309309
self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup
310+
self._channel_cleanup_threads: list[threading.Thread] = [] # Deferred channel close threads
310311
self._stream_ready = threading.Event()
311312
# Use provided concurrency options or create default ones
312313
self._concurrency_options = (
@@ -384,15 +385,16 @@ async def _async_run_loop(self):
384385
current_stub = None
385386
current_reader_thread = None
386387
conn_retry_count = 0
387-
conn_max_retry_delay = 60
388+
conn_max_retry_delay = 15
388389

389390
def create_fresh_connection():
390391
nonlocal current_channel, current_stub, conn_retry_count
391-
if current_channel:
392-
try:
393-
current_channel.close()
394-
except Exception:
395-
pass
392+
# Schedule deferred close of old channel to avoid orphaned TCP
393+
# connections. In-flight RPCs on the old stub may still reference
394+
# the channel from another thread, so we wait a grace period
395+
# before closing instead of closing immediately.
396+
if current_channel is not None:
397+
self._schedule_deferred_channel_close(current_channel)
396398
current_channel = None
397399
current_stub = None
398400
try:
@@ -417,31 +419,20 @@ def create_fresh_connection():
417419

418420
def invalidate_connection():
419421
nonlocal current_channel, current_stub, current_reader_thread
420-
# Cancel the response stream first to signal the reader thread to stop
421-
if self._response_stream is not None:
422-
try:
423-
if hasattr(self._response_stream, "call"):
424-
self._response_stream.call.cancel() # type: ignore
425-
else:
426-
self._response_stream.cancel() # type: ignore
427-
except Exception as e:
428-
self._logger.warning(f"Error cancelling response stream: {e}")
429-
self._response_stream = None
430-
431-
# Wait for the reader thread to finish
432-
if current_reader_thread is not None:
433-
current_reader_thread.join(timeout=1)
434-
current_reader_thread = None
435-
436-
# Close the channel
437-
if current_channel:
438-
try:
439-
current_channel.close()
440-
except Exception:
441-
pass
422+
# Schedule deferred close of old channel to avoid orphaned TCP
423+
# connections. In-flight RPCs (e.g. CompleteActivityTask) may still
424+
# be using the stub on another thread, so we defer the close by a
425+
# grace period instead of closing immediately.
426+
if current_channel is not None:
427+
self._schedule_deferred_channel_close(current_channel)
442428
current_channel = None
443429
self._current_channel = None
444430
current_stub = None
431+
self._response_stream = None
432+
433+
if current_reader_thread is not None:
434+
current_reader_thread.join(timeout=5)
435+
current_reader_thread = None
445436

446437
def should_invalidate_connection(rpc_error):
447438
error_code = rpc_error.code() # type: ignore
@@ -451,6 +442,7 @@ def should_invalidate_connection(rpc_error):
451442
grpc.StatusCode.CANCELLED,
452443
grpc.StatusCode.UNAUTHENTICATED,
453444
grpc.StatusCode.ABORTED,
445+
grpc.StatusCode.INTERNAL, # RST_STREAM from proxy means connection is dead
454446
}
455447
return error_code in connection_level_errors
456448

@@ -532,7 +524,11 @@ def stream_reader():
532524
break
533525
# Other RPC errors - put in queue for async loop to handle
534526
self._logger.warning(
535-
f"Stream reader: RPC error (code={rpc_error.code()}): {rpc_error}"
527+
"Stream reader: RPC error (code=%s): %s",
528+
rpc_error.code(),
529+
rpc_error.details()
530+
if hasattr(rpc_error, "details")
531+
else rpc_error,
536532
)
537533
break
538534
except Exception as stream_error:
@@ -654,32 +650,19 @@ def stream_reader():
654650
if should_invalidate:
655651
invalidate_connection()
656652
error_code = rpc_error.code() # type: ignore
657-
error_details = str(rpc_error)
653+
error_detail = (
654+
rpc_error.details() if hasattr(rpc_error, "details") else str(rpc_error)
655+
)
658656

659657
if error_code == grpc.StatusCode.CANCELLED:
660658
self._logger.info(f"Disconnected from {self._host_address}")
661659
break
662-
elif error_code == grpc.StatusCode.UNAVAILABLE:
663-
# Check if this is a connection timeout scenario
664-
if (
665-
"Timeout occurred" in error_details
666-
or "Failed to connect to remote host" in error_details
667-
):
668-
self._logger.warning(
669-
f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection"
670-
)
671-
else:
672-
self._logger.warning(
673-
f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying"
674-
)
675660
elif should_invalidate:
676661
self._logger.warning(
677-
f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection"
662+
f"Connection error ({error_code}): {error_detail} resetting connection"
678663
)
679664
else:
680-
self._logger.warning(
681-
f"Application-level gRPC error ({error_code}): {rpc_error}"
682-
)
665+
self._logger.warning(f"gRPC error ({error_code}): {error_detail}")
683666
except RuntimeError as ex:
684667
# RuntimeError often indicates asyncio loop issues (e.g., "cannot schedule new futures after shutdown")
685668
# Check shutdown state first
@@ -738,22 +721,46 @@ def stream_reader():
738721
except Exception as e:
739722
self._logger.warning(f"Error while waiting for worker task shutdown: {e}")
740723

724+
def _schedule_deferred_channel_close(
    self, old_channel: grpc.Channel, grace_timeout: float = 10.0
):
    """Schedule a deferred close of an old gRPC channel.

    Waits up to *grace_timeout* seconds for in-flight RPCs to complete
    before closing the channel. This prevents orphaned TCP connections
    while still allowing in-flight work (e.g. ``CompleteActivityTask``
    calls on another thread) to finish gracefully.

    During ``stop()``, ``_shutdown`` is already set so the wait returns
    immediately and the channel is closed at once.
    """
    # Drop cleanup threads that already finished so the list stays bounded.
    still_running = [t for t in self._channel_cleanup_threads if t.is_alive()]
    self._channel_cleanup_threads = still_running

    def _close_later():
        try:
            # Normal reconnect: give in-flight RPCs the grace period to drain.
            # During shutdown the event is already set, so this returns at once.
            self._shutdown.wait(timeout=grace_timeout)
        finally:
            # Close no matter how the wait ended; failures here are non-fatal.
            try:
                old_channel.close()
                self._logger.debug("Deferred channel close completed")
            except Exception as e:
                self._logger.debug(f"Error during deferred channel close: {e}")

    cleanup_thread = threading.Thread(
        target=_close_later, daemon=True, name="ChannelCleanup"
    )
    cleanup_thread.start()
    self._channel_cleanup_threads.append(cleanup_thread)
741756
def stop(self):
742757
"""Stops the worker and waits for any pending work items to complete."""
743758
if not self._is_running:
744759
return
745760

746761
self._logger.info("Stopping gRPC worker...")
747-
if self._response_stream is not None:
748-
try:
749-
if hasattr(self._response_stream, "call"):
750-
self._response_stream.call.cancel() # type: ignore
751-
else:
752-
self._response_stream.cancel() # type: ignore
753-
except Exception as e:
754-
self._logger.warning(f"Error cancelling response stream: {e}")
755762
self._shutdown.set()
756-
# Explicitly close the gRPC channel to ensure OTel interceptors and other resources are cleaned up
763+
# Close the channel — propagates cancellation to all streams and cleans up resources
757764
if self._current_channel is not None:
758765
try:
759766
self._current_channel.close()
@@ -772,38 +779,39 @@ def stop(self):
772779
else:
773780
self._logger.debug("Worker thread completed successfully")
774781

782+
# Wait for any deferred channel-cleanup threads to finish
783+
for t in self._channel_cleanup_threads:
784+
t.join(timeout=5)
785+
self._channel_cleanup_threads.clear()
786+
775787
self._async_worker_manager.shutdown()
776788
self._logger.info("Worker shutdown completed")
777789
self._is_running = False
778790

779791
# TODO: This should be removed in the future as we do handle grpc errs
780792
def _handle_grpc_execution_error(self, rpc_error: grpc.RpcError, request_type: str):
    """Handle a gRPC execution error during shutdown or connection reset."""
    lowered_details = str(rpc_error).lower()
    # Transient status codes: the sidecar will re-dispatch the work item,
    # so a failed result delivery is recoverable rather than fatal.
    transient_codes = {
        grpc.StatusCode.CANCELLED,
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.UNKNOWN,
        grpc.StatusCode.INTERNAL,
    }
    is_transient = rpc_error.code() in transient_codes
    # Benign conditions recognized by message text (e.g. the instance was
    # already terminated, or the channel was closed underneath us).
    is_benign = any(
        marker in lowered_details
        for marker in (
            "unknown instance id/task id combo",
            "channel closed",
            "locally cancelled by application",
        )
    )
    if not (is_transient or is_benign or self._shutdown.is_set()):
        # Unexpected delivery failure: keep the full stack trace.
        self._logger.exception(f"Failed to deliver {request_type} result: {rpc_error}")
        return
    self._logger.warning(
        f"Could not deliver {request_type} result ({rpc_error.code()}): "
        f"{rpc_error.details() if hasattr(rpc_error, 'details') else rpc_error} — sidecar will re-dispatch"
    )
807815

808816
def _execute_orchestrator(
809817
self,

0 commit comments

Comments (0)