diff --git a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py index be5f7199..61700f1b 100644 --- a/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py +++ b/packages/pynumaflow/pynumaflow/mapper/_servicer/_sync_servicer.py @@ -2,11 +2,12 @@ from concurrent.futures import ThreadPoolExecutor from collections.abc import Iterator +import grpc from google.protobuf import empty_pb2 as _empty_pb2 -from pynumaflow.shared.server import exit_on_error +from pynumaflow.shared.server import update_context_err from pynumaflow._metadata import _user_and_system_metadata_from_proto -from pynumaflow._constants import NUM_THREADS_DEFAULT, STREAM_EOF, _LOGGER, ERR_UDF_EXCEPTION_STRING +from pynumaflow._constants import NUM_THREADS_DEFAULT, _LOGGER, ERR_UDF_EXCEPTION_STRING from pynumaflow.mapper._dtypes import MapSyncCallable, Datum, MapError from pynumaflow.proto.mapper import map_pb2, map_pb2_grpc from pynumaflow.shared.synciter import SyncIterator @@ -26,6 +27,8 @@ def __init__(self, handler: MapSyncCallable, multiproc: bool = False): self.multiproc = multiproc # create a thread pool for executing UDF code self.executor = ThreadPoolExecutor(max_workers=NUM_THREADS_DEFAULT) + self.shutdown_event: threading.Event = threading.Event() + self.error: BaseException | None = None def MapFn( self, @@ -36,6 +39,7 @@ def MapFn( Applies a function to each datum element. The pascal case function name comes from the proto map_pb2_grpc.py file. """ + result_queue = None try: # The first message to be received should be a valid handshake req = next(request_iterator) @@ -57,10 +61,13 @@ def MapFn( for res in result_queue.read_iterator(): # if error handler accordingly if isinstance(res, BaseException): - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}", parent=self.multiproc - ) + err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(res)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, res, err_msg) + # Unblock the reader thread if it is waiting on queue.put() + result_queue.close() + self.error = res + self.shutdown_event.set() return # return the result yield res @@ -69,12 +76,22 @@ def MapFn( reader_thread.join() self.executor.shutdown(cancel_futures=True) + except grpc.RpcError: + _LOGGER.warning("gRPC stream closed, shutting down the server.") + if result_queue is not None: + result_queue.close() + self.shutdown_event.set() + return + except BaseException as err: - _LOGGER.critical("UDFError, re-raising the error", exc_info=True) - # Terminate the current server process due to exception - exit_on_error( - context, f"{ERR_UDF_EXCEPTION_STRING}: {repr(err)}", parent=self.multiproc - ) + err_msg = f"UDFError, {ERR_UDF_EXCEPTION_STRING}: {repr(err)}" + _LOGGER.critical(err_msg, exc_info=True) + update_context_err(context, err, err_msg) + # Unblock the reader thread if it is waiting on queue.put() + if result_queue is not None: + result_queue.close() + self.error = err + self.shutdown_event.set() return def _process_requests( @@ -91,9 +108,20 @@ def _process_requests( # wait for all tasks to finish after all requests exhausted self.executor.shutdown(wait=True) # Indicate to the result queue that no more messages left to process - result_queue.put(STREAM_EOF) + result_queue.close() + except grpc.RpcError: + # The only error that can occur here is the gRPC stream closing + # (e.g. client disconnected). UDF exceptions are caught inside _invoke_map + # and never propagate here. + _LOGGER.warning("gRPC stream closed in reader thread, shutting down the server.") + # Let already-submitted UDF tasks finish within the graceful shutdown period + self.executor.shutdown(wait=True) + # Signal MapFn's read_iterator() loop to exit cleanly + result_queue.close() + # Trigger server shutdown (not a UDF error, so self.error is not set) + self.shutdown_event.set() except BaseException as e: - _LOGGER.critical("MapFn Error, re-raising the error", exc_info=True) + _LOGGER.critical("MapFn Error while reading requests from gRPC stream", exc_info=True) # Surface the error to the consumer; MapFn will handle and exit result_queue.put(e) diff --git a/packages/pynumaflow/pynumaflow/mapper/sync_server.py b/packages/pynumaflow/pynumaflow/mapper/sync_server.py index 9c2431b6..a96ceb7f 100644 --- a/packages/pynumaflow/pynumaflow/mapper/sync_server.py +++ b/packages/pynumaflow/pynumaflow/mapper/sync_server.py @@ -1,3 +1,5 @@ +import sys + from pynumaflow.info.types import ( ServerInfo, MAP_MODE_KEY, @@ -112,4 +114,9 @@ def start(self) -> None: server_options=self._server_options, udf_type=UDFType.Map, server_info=serv_info, + shutdown_event=self.servicer.shutdown_event, ) + + if self.servicer.error: + _LOGGER.critical("Server exiting due to UDF error: %s", self.servicer.error) + sys.exit(1)