From b10de0fe31383f8f7139b21c8e622b2ce953cbfe Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 16 Jan 2026 20:46:03 -0800 Subject: [PATCH 01/16] Payload limit configuration and validation --- temporalio/bridge/worker.py | 4 +- temporalio/client.py | 4 +- temporalio/converter.py | 111 ++++++++- temporalio/exceptions.py | 24 ++ temporalio/worker/_activity.py | 26 +- temporalio/worker/_workflow.py | 55 +++-- tests/test_converter.py | 4 +- tests/worker/test_workflow.py | 439 +++++++++++++++++++++++++++++++++ 8 files changed, 633 insertions(+), 34 deletions(-) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index f37c76f19..b7c888e36 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -25,7 +25,9 @@ from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) -from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore +from temporalio.bridge.temporal_sdk_bridge import ( + PollShutdownError, # noqa +) from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor diff --git a/temporalio/client.py b/temporalio/client.py index b4d5af0fa..5038b2715 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -4249,7 +4249,7 @@ async def _to_proto( client.config(active_config=True)["header_codec_behavior"] == HeaderCodecBehavior.CODEC and not self._from_raw, - client.data_converter.payload_codec, + client.data_converter._payload_codec_chain, ) return action @@ -6870,7 +6870,7 @@ async def _apply_headers( dest, self._client.config(active_config=True)["header_codec_behavior"] == HeaderCodecBehavior.CODEC, - self._client.data_converter.payload_codec, + self._client.data_converter._payload_codec_chain, ) diff --git a/temporalio/converter.py b/temporalio/converter.py index 3849a47f4..fd5325b11 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -28,6 +28,7 @@ TypeVar, get_type_hints, overload, + override, ) import google.protobuf.json_format @@ -1238,6 +1239,85 @@ def __init__(self) -> None: super().__init__(encode_common_attributes=True) +@dataclass(frozen=True) +class PayloadLimitsConfig: + """Configuration for when payload sizes exceed limits.""" + + payload_upload_error_limit: int | Literal["disabled"] | None = None + """The limit at which a payloads size error is created.""" + payload_upload_warning_limit: int | Literal["disabled"] | None = None + """The limit at which a payloads size warning is created.""" + + +@dataclass(kw_only=True) +class _PayloadLimitsPayloadCodec(PayloadCodec, WithSerializationContext): + config: PayloadLimitsConfig + inner_codec: PayloadCodec | None + + @override + def with_context(self, context: SerializationContext) -> _PayloadLimitsPayloadCodec: + inner_codec = self.inner_codec + if isinstance(inner_codec, WithSerializationContext): + inner_codec = inner_codec.with_context(context) + if inner_codec == self.inner_codec: + return self + return _PayloadLimitsPayloadCodec(config=self.config, inner_codec=inner_codec) + + @override + async def encode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + result_payloads = list(payloads) + if self.inner_codec: + result_payloads = await self.inner_codec.encode(payloads=result_payloads) + + total_size = sum(payload.ByteSize() for payload in payloads) + + exceeded_limit_value = self._check_over_limit( + self.config.payload_upload_error_limit, + total_size, + ) + + if exceeded_limit_value: + raise temporalio.exceptions.PayloadsSizeError( + size=total_size, + limit=exceeded_limit_value, + ) + + exceeded_limit_value = self._check_over_limit( + self.config.payload_upload_warning_limit, + total_size, + ) + + if exceeded_limit_value: + # TODO: Use a context aware logger to log extra information about workflow/activity/etc + logger.warning( + "Payloads size has exceeded the warning limit. Size: %d bytes, Limit: %s bytes", + total_size, + exceeded_limit_value, + ) + + return list(result_payloads) + + @override + async def decode( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + result_payloads = list(payloads) + if self.inner_codec: + result_payloads = await self.inner_codec.decode(result_payloads) + return result_payloads + + def _check_over_limit( + self, + limit: int | Literal["disabled"] | None, + size: int, + ) -> int | None: + if limit and limit != "disabled" and limit > 0 and size > limit: + return limit + return None + + @dataclass(frozen=True) class DataConverter(WithSerializationContext): """Data converter for converting and encoding payloads to/from Python values. @@ -1261,12 +1341,25 @@ class DataConverter(WithSerializationContext): failure_converter: FailureConverter = dataclasses.field(init=False) """Failure converter created from the :py:attr:`failure_converter_class`.""" + payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() + """Settings for payload size limits.""" + + _payload_codec_chain: PayloadCodec = dataclasses.field(init=False) + default: ClassVar[DataConverter] """Singleton default data converter.""" def __post_init__(self) -> None: # noqa: D105 object.__setattr__(self, "payload_converter", self.payload_converter_class()) object.__setattr__(self, "failure_converter", self.failure_converter_class()) + object.__setattr__( + self, + "_payload_codec_chain", + _PayloadLimitsPayloadCodec( + config=self.payload_limits, + inner_codec=self.payload_codec, + ), + ) async def encode( self, values: Sequence[Any] @@ -1284,8 +1377,8 @@ async def encode( more than was given. """ payloads = self.payload_converter.to_payloads(values) - if self.payload_codec: - payloads = await self.payload_codec.encode(payloads) + payloads = await self._payload_codec_chain.encode(payloads) + return payloads async def decode( @@ -1303,8 +1396,7 @@ async def decode( Returns: Decoded and converted values. """ - if self.payload_codec: - payloads = await self.payload_codec.decode(payloads) + payloads = await self._payload_codec_chain.decode(payloads) return self.payload_converter.from_payloads(payloads, type_hints) async def encode_wrapper( @@ -1332,15 +1424,13 @@ async def encode_failure( ) -> None: """Convert and encode failure.""" self.failure_converter.to_failure(exception, self.payload_converter, failure) - if self.payload_codec: - await self.payload_codec.encode_failure(failure) + await self._payload_codec_chain.encode_failure(failure) async def decode_failure( self, failure: temporalio.api.failure.v1.Failure ) -> BaseException: """Decode and convert failure.""" - if self.payload_codec: - await self.payload_codec.decode_failure(failure) + await self._payload_codec_chain.decode_failure(failure) return self.failure_converter.from_failure(failure, self.payload_converter) def with_context(self, context: SerializationContext) -> Self: @@ -1348,18 +1438,22 @@ def with_context(self, context: SerializationContext) -> Self: payload_converter = self.payload_converter payload_codec = self.payload_codec failure_converter = self.failure_converter + codec_chain = self._payload_codec_chain if isinstance(payload_converter, WithSerializationContext): payload_converter = payload_converter.with_context(context) if isinstance(payload_codec, WithSerializationContext): payload_codec = payload_codec.with_context(context) if isinstance(failure_converter, WithSerializationContext): failure_converter = failure_converter.with_context(context) + if isinstance(codec_chain, WithSerializationContext): + codec_chain = codec_chain.with_context(context) if all( new is orig for new, orig in [ (payload_converter, self.payload_converter), (payload_codec, self.payload_codec), (failure_converter, self.failure_converter), + (codec_chain, self._payload_codec_chain), ] ): return self @@ -1367,6 +1461,7 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "payload_converter", payload_converter) object.__setattr__(cloned, "payload_codec", payload_codec) object.__setattr__(cloned, "failure_converter", failure_converter) + object.__setattr__(cloned, "_payload_codec_chain", codec_chain) return cloned diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index f8f8ca20c..f526f968c 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -446,3 +446,27 @@ def is_cancelled_exception(exception: BaseException) -> bool: and isinstance(exception.cause, CancelledError) ) ) + + +class PayloadsSizeError(TemporalError): + """Error raised when payloads size exceeds payload size limits.""" + + def __init__(self, size: int, limit: int): + """Initialize a payloads limit error. + + Args: + size: Actual payloads size in bytes. + limit: Payloads size limit in bytes. + """ + self._size = size + self._limit = limit + + @property + def payloads_size(self) -> int: + """Actual payloads size in bytes.""" + return self._size + + @property + def payloads_limit(self) -> int: + """Payloads size limit in bytes.""" + return self._limit diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 23f2ed5cc..8d41b6ca4 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -380,6 +380,26 @@ async def _handle_start_activity_task( temporalio.exceptions.CancelledError("Cancelled"), completion.result.cancelled.failure, ) + elif isinstance( + err, + temporalio.exceptions.PayloadsSizeError, + ): + temporalio.activity.logger.warning( + "Completing as failure due to payloads size exceeding the error limit. Size: %d bytes, Limit: %d bytes", + err.payloads_size, + err.payloads_limit, + extra={"__temporal_error_identifier": "ActivityFailure"}, + ) + await data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="PayloadsTooLarge", + message="Payloads size has exceeded the error limit.", + ), + completion.result.failed.failure, + ) + # TODO: Add force_cause to activity Failure bridge proto? + # TODO: Add WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE to API + # completion.result.failed.force_cause = WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE else: if ( isinstance( @@ -577,9 +597,11 @@ async def _execute_activity( else None, ) - if self._encode_headers and data_converter.payload_codec is not None: + if self._encode_headers: for payload in start.header_fields.values(): - new_payload = (await data_converter.payload_codec.decode([payload]))[0] + new_payload = ( + await data_converter._payload_codec_chain.decode([payload]) + )[0] payload.CopyFrom(new_payload) running_activity.info = info diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 16e0de5e8..5a084bc4c 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -268,22 +268,20 @@ async def _handle_activation( workflow_id=workflow_id, ) data_converter = self._data_converter.with_context(workflow_context) - if self._data_converter.payload_codec: - assert data_converter.payload_codec - if not workflow: - payload_codec = data_converter.payload_codec - else: - payload_codec = _CommandAwarePayloadCodec( - workflow.instance, - context_free_payload_codec=self._data_converter.payload_codec, - workflow_context_payload_codec=data_converter.payload_codec, - workflow_context=workflow_context, - ) - await temporalio.bridge.worker.decode_activation( - act, - payload_codec, - decode_headers=self._encode_headers, + if not workflow: + payload_codec = data_converter._payload_codec_chain + else: + payload_codec = _CommandAwarePayloadCodec( + workflow.instance, + context_free_payload_codec=self._data_converter._payload_codec_chain, + workflow_context_payload_codec=data_converter._payload_codec_chain, + workflow_context=workflow_context, ) + await temporalio.bridge.worker.decode_activation( + act, + payload_codec, + decode_headers=self._encode_headers, + ) if not workflow: assert init_job workflow = _RunningWorkflow( @@ -349,12 +347,11 @@ async def _handle_activation( completion.run_id = act.run_id # Encode completion - if self._data_converter.payload_codec and workflow: - assert data_converter.payload_codec + if workflow: payload_codec = _CommandAwarePayloadCodec( workflow.instance, - context_free_payload_codec=self._data_converter.payload_codec, - workflow_context_payload_codec=data_converter.payload_codec, + context_free_payload_codec=self._data_converter._payload_codec_chain, + workflow_context_payload_codec=data_converter._payload_codec_chain, workflow_context=temporalio.converter.WorkflowSerializationContext( namespace=self._namespace, workflow_id=workflow.workflow_id, @@ -366,6 +363,26 @@ async def _handle_activation( payload_codec, encode_headers=self._encode_headers, ) + except temporalio.exceptions.PayloadsSizeError as err: + # TODO: Would like to use temporalio.workflow.logger here, but + # that requires being in the workflow event loop. Possibly refactor + # the logger core functionality into shareable class and update + # LoggerAdapter to be a decorator. + logger.warning( + "Completing as failure due to payloads size exceeding the error limit. Size: %d bytes, Limit: %d bytes", + err.payloads_size, + err.payloads_limit, + ) + completion.failed.Clear() + await data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="PayloadsTooLarge", + message="Payloads size has exceeded the error limit.", + ), + completion.failed.failure, + ) + # TODO: Add WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE to API + # completion.failed.force_cause = WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE except Exception as err: logger.exception( "Failed encoding completion on workflow with run ID %s", act.run_id diff --git a/tests/test_converter.py b/tests/test_converter.py index bb5b3c8bc..64bb51dc1 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -574,7 +574,7 @@ async def test_failure_encoded_attributes(): # Encode it and check encoded orig_failure = Failure() orig_failure.CopyFrom(failure) - await conv.payload_codec.encode_failure(failure) + await conv._payload_codec_chain.encode_failure(failure) assert "encoding" not in failure.encoded_attributes.metadata assert "simple-codec" in failure.encoded_attributes.metadata assert ( @@ -585,7 +585,7 @@ async def test_failure_encoded_attributes(): ) # Decode and check - await conv.payload_codec.decode_failure(failure) + await conv._payload_codec_chain.decode_failure(failure) assert "encoding" in failure.encoded_attributes.metadata assert "simple-codec" not in failure.encoded_attributes.metadata assert "encoding" in failure.application_failure_info.details.payloads[0].metadata diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 719b567e7..739da0589 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -88,6 +88,7 @@ DefaultPayloadConverter, PayloadCodec, PayloadConverter, + PayloadLimitsConfig, ) from temporalio.exceptions import ( ActivityError, @@ -95,18 +96,23 @@ ApplicationErrorCategory, CancelledError, ChildWorkflowError, + PayloadsSizeError, TemporalError, TimeoutError, + TimeoutType, WorkflowAlreadyStartedError, ) from temporalio.runtime import ( BUFFERED_METRIC_KIND_COUNTER, BUFFERED_METRIC_KIND_HISTOGRAM, + LogForwardingConfig, + LoggingConfig, MetricBuffer, MetricBufferDurationFormat, PrometheusConfig, Runtime, TelemetryConfig, + TelemetryFilter, ) from temporalio.service import RPCError, RPCStatusCode, __version__ from temporalio.testing import WorkflowEnvironment @@ -8434,3 +8440,436 @@ async def test_activity_failure_with_encoded_payload_is_decoded_in_workflow( run_timeout=timedelta(seconds=5), ) assert result == "Handled encrypted failure successfully" + + +@dataclass +class LargePayloadWorkflowInput: + activity_input_data_size: int + activity_output_data_size: int + workflow_output_data_size: int + data: list[int] + + +@dataclass +class LargePayloadWorkflowOutput: + data: list[int] + + +@dataclass +class LargePayloadActivityInput: + output_data_size: int + data: list[int] + + +@dataclass +class LargePayloadActivityOutput: + data: list[int] + + +@activity.defn +async def large_payload_activity( + input: LargePayloadActivityInput, +) -> LargePayloadActivityOutput: + return LargePayloadActivityOutput(data=[0] * input.output_data_size) + + +@workflow.defn +class LargePayloadWorkflow: + @workflow.run + async def run(self, input: LargePayloadWorkflowInput) -> LargePayloadWorkflowOutput: + await workflow.execute_activity( + large_payload_activity, + LargePayloadActivityInput( + output_data_size=input.activity_output_data_size, + data=[0] * input.activity_input_data_size, + ), + schedule_to_close_timeout=timedelta(seconds=5), + ) + return LargePayloadWorkflowOutput(data=[0] * input.workflow_output_data_size) + + +async def test_large_payload_error_workflow_input(client: Client): + config = client.config() + error_limit = 5 * 1024 + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, payload_upload_warning_limit=1024 + ), + ) + client = Client(**config) + + with pytest.raises(PayloadsSizeError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[0] * 6 * 1024, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue="test-queue", + ) + + assert error_limit == err.value.payloads_limit + + +async def test_large_payload_warning_workflow_input(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + client = Client(**config) + + with ( + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[0] * 2 * 1024, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size has exceeded the warning limit." in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) + + +async def test_large_payload_error_workflow_result(client: Client): + # Create worker runtime with forwarded logger + worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") + worker_runtime = Runtime( + telemetry=TelemetryConfig( + logging=LoggingConfig( + filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), + forwarding=LogForwardingConfig(logger=worker_logger), + ) + ) + ) + + # Create client for worker with custom payload limits + error_limit = 5 * 1024 + worker_client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=worker_runtime, + data_converter=dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, + payload_upload_warning_limit=1024, + ), + ), + ) + + with ( + LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=6 * 1024, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=3), + ) + + assert isinstance(err.value.cause, TimeoutError) + assert err.value.cause.type == TimeoutType.START_TO_CLOSE + + def worker_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size has exceeded the error limit." in record.msg + ) + + assert worker_logger_capturer.find(worker_logger_predicate) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Completing as failure due to payloads size exceeding the error limit." + in record.msg + and f"Limit: {error_limit} bytes" in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) + + +async def test_large_payload_warning_workflow_result(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + worker_client = Client(**config) + + with ( + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=2 * 1024, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=3), + ) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size has exceeded the warning limit." in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) + + +async def test_large_payload_error_activity_input(client: Client): + # Create worker runtime with forwarded logger + worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") + worker_runtime = Runtime( + telemetry=TelemetryConfig( + logging=LoggingConfig( + filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), + forwarding=LogForwardingConfig(logger=worker_logger), + ) + ) + ) + + # Create client for worker with custom payload limits + error_limit = 5 * 1024 + worker_client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=worker_runtime, + data_converter=dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, + payload_upload_warning_limit=1024, + ), + ), + ) + + with ( + LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=6 * 1024, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=3), + ) + + assert isinstance(err.value.cause, TimeoutError) + + def worker_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size has exceeded the error limit." in record.msg + ) + + assert worker_logger_capturer.find(worker_logger_predicate) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Completing as failure due to payloads size exceeding the error limit." + in record.msg + and f"Limit: {error_limit} bytes" in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) + + +async def test_large_payload_warning_activity_input(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + worker_client = Client(**config) + + with ( + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=2 * 1024, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size has exceeded the warning limit." in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) + + +async def test_large_payload_error_activity_result(client: Client): + # Create worker runtime with forwarded logger + worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") + worker_runtime = Runtime( + telemetry=TelemetryConfig( + logging=LoggingConfig( + filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), + forwarding=LogForwardingConfig(logger=worker_logger), + ) + ) + ) + + # Create client for worker with custom payload limits + error_limit = 5 * 1024 + worker_client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=worker_runtime, + data_converter=dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, + payload_upload_warning_limit=1024, + ), + ), + ) + + with ( + LogCapturer().logs_captured( + activity.logger.base_logger + ) as activity_logger_capturer, + # LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=6 * 1024, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + assert isinstance(err.value.cause, ActivityError) + assert isinstance(err.value.cause.cause, ApplicationError) + + def activity_logger_predicate(record: logging.LogRecord) -> bool: + return ( + hasattr(record, "__temporal_error_identifier") + and getattr(record, "__temporal_error_identifier") == "ActivityFailure" + and record.levelname == "WARNING" + and "Completing as failure due to payloads size exceeding the error limit." + in record.msg + and f"Limit: {error_limit} bytes" in record.msg + ) + + assert activity_logger_capturer.find(activity_logger_predicate) + + # Worker logger is not emitting this follow message. Maybe activity completion failures + # are not routed through the log forwarder whereas workflow completion failures are? + # def worker_logger_predicate(record: logging.LogRecord) -> bool: + # return "Payloads size has exceeded the error limit." in record.msg + + # assert worker_logger_capturer.find(worker_logger_predicate) + + +async def test_large_payload_warning_activity_result(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=5 * 1024, payload_upload_warning_limit=1024 + ), + ) + worker_client = Client(**config) + + with ( + LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=2 * 1024, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + def root_logger_predicate(record: logging.LogRecord) -> bool: + return ( + record.levelname == "WARNING" + and "Payloads size has exceeded the warning limit." in record.msg + ) + + assert root_logger_capturer.find(root_logger_predicate) From b80b725e3a3491565d4c6a0e8291aace32dc79ef Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 20 Jan 2026 14:07:36 -0800 Subject: [PATCH 02/16] Remove override --- temporalio/converter.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index fd5325b11..b716f3111 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -28,7 +28,6 @@ TypeVar, get_type_hints, overload, - override, ) import google.protobuf.json_format @@ -1254,7 +1253,6 @@ class _PayloadLimitsPayloadCodec(PayloadCodec, WithSerializationContext): config: PayloadLimitsConfig inner_codec: PayloadCodec | None - @override def with_context(self, context: SerializationContext) -> _PayloadLimitsPayloadCodec: inner_codec = self.inner_codec if isinstance(inner_codec, WithSerializationContext): @@ -1263,7 +1261,6 @@ def with_context(self, context: SerializationContext) -> _PayloadLimitsPayloadCo return self return _PayloadLimitsPayloadCodec(config=self.config, inner_codec=inner_codec) - @override async def encode( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: @@ -1299,7 +1296,6 @@ async def encode( return list(result_payloads) - @override async def decode( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: From f8eca98e47acb5855e08bb840d49ce3026b00e38 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 20 Jan 2026 14:21:14 -0800 Subject: [PATCH 03/16] Call super.__init__ --- temporalio/exceptions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index f526f968c..a31e40466 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -458,6 +458,7 @@ def __init__(self, size: int, limit: int): size: Actual payloads size in bytes. limit: Payloads size limit in bytes. """ + super().__init__() self._size = size self._limit = limit From 79fe1755c02beffaa0b44b85abc21401f8a2fe3e Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 20 Jan 2026 14:34:20 -0800 Subject: [PATCH 04/16] Undo change to worker.py --- temporalio/bridge/worker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index b7c888e36..f37c76f19 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -25,9 +25,7 @@ from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) -from temporalio.bridge.temporal_sdk_bridge import ( - PollShutdownError, # noqa -) +from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore from temporalio.worker._command_aware_visitor import CommandAwarePayloadVisitor From 95a66d72397da636df04cda56d8ed29feea49610 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 20 Jan 2026 15:10:28 -0800 Subject: [PATCH 05/16] Remove activity.logger capture due to test run interference --- tests/worker/test_workflow.py | 85 ++++++++++------------------------- 1 file changed, 23 insertions(+), 62 deletions(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 739da0589..8fe77819d 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8767,75 +8767,36 @@ def root_logger_predicate(record: logging.LogRecord) -> bool: async def test_large_payload_error_activity_result(client: Client): - # Create worker runtime with forwarded logger - worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") - worker_runtime = Runtime( - telemetry=TelemetryConfig( - logging=LoggingConfig( - filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), - forwarding=LogForwardingConfig(logger=worker_logger), - ) - ) - ) - # Create client for worker with custom payload limits + config = client.config() error_limit = 5 * 1024 - worker_client = await Client.connect( - client.service_client.config.target_host, - namespace=client.namespace, - runtime=worker_runtime, - data_converter=dataclasses.replace( - temporalio.converter.default(), - payload_limits=PayloadLimitsConfig( - payload_upload_error_limit=error_limit, - payload_upload_warning_limit=1024, - ), + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, payload_upload_warning_limit=1024 ), ) + worker_client = Client(**config) - with ( - LogCapturer().logs_captured( - activity.logger.base_logger - ) as activity_logger_capturer, - # LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, - ): - async with new_worker( - worker_client, LargePayloadWorkflow, activities=[large_payload_activity] - ) as worker: - with pytest.raises(WorkflowFailureError) as err: - await client.execute_workflow( - LargePayloadWorkflow.run, - LargePayloadWorkflowInput( - activity_input_data_size=0, - activity_output_data_size=6 * 1024, - workflow_output_data_size=0, - data=[], - ), - id=f"workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - ) - - assert isinstance(err.value.cause, ActivityError) - assert isinstance(err.value.cause.cause, ApplicationError) - - def activity_logger_predicate(record: logging.LogRecord) -> bool: - return ( - hasattr(record, "__temporal_error_identifier") - and getattr(record, "__temporal_error_identifier") == "ActivityFailure" - and record.levelname == "WARNING" - and "Completing as failure due to payloads size exceeding the error limit." - in record.msg - and f"Limit: {error_limit} bytes" in record.msg + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=6 * 1024, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, ) - assert activity_logger_capturer.find(activity_logger_predicate) - - # Worker logger is not emitting this follow message. Maybe activity completion failures - # are not routed through the log forwarder whereas workflow completion failures are? - # def worker_logger_predicate(record: logging.LogRecord) -> bool: - # return "Payloads size has exceeded the error limit." in record.msg - - # assert worker_logger_capturer.find(worker_logger_predicate) + assert isinstance(err.value.cause, ActivityError) + assert isinstance(err.value.cause.cause, ApplicationError) + assert "PayloadsTooLarge" == err.value.cause.cause.type async def test_large_payload_warning_activity_result(client: Client): From ebca76c2667bb4cbb89ebc8a6cf2eca529b0e339 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 20 Jan 2026 15:57:31 -0800 Subject: [PATCH 06/16] Revert "Remove activity.logger capture due to test run interference" This reverts commit 95a66d72397da636df04cda56d8ed29feea49610. --- tests/worker/test_workflow.py | 85 +++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 23 deletions(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 8fe77819d..739da0589 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8767,36 +8767,75 @@ def root_logger_predicate(record: logging.LogRecord) -> bool: async def test_large_payload_error_activity_result(client: Client): + # Create worker runtime with forwarded logger + worker_logger = logging.getLogger(f"log-{uuid.uuid4()}") + worker_runtime = Runtime( + telemetry=TelemetryConfig( + logging=LoggingConfig( + filter=TelemetryFilter(core_level="WARN", other_level="ERROR"), + forwarding=LogForwardingConfig(logger=worker_logger), + ) + ) + ) + # Create client for worker with custom payload limits - config = client.config() error_limit = 5 * 1024 - config["data_converter"] = dataclasses.replace( - temporalio.converter.default(), - payload_limits=PayloadLimitsConfig( - payload_upload_error_limit=error_limit, payload_upload_warning_limit=1024 + worker_client = await Client.connect( + client.service_client.config.target_host, + namespace=client.namespace, + runtime=worker_runtime, + data_converter=dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig( + payload_upload_error_limit=error_limit, + payload_upload_warning_limit=1024, + ), ), ) - worker_client = Client(**config) - async with new_worker( - worker_client, LargePayloadWorkflow, activities=[large_payload_activity] - ) as worker: - with pytest.raises(WorkflowFailureError) as err: - await client.execute_workflow( - LargePayloadWorkflow.run, - LargePayloadWorkflowInput( - activity_input_data_size=0, - activity_output_data_size=6 * 1024, - workflow_output_data_size=0, - data=[], - ), - id=f"workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, + with ( + LogCapturer().logs_captured( + activity.logger.base_logger + ) as activity_logger_capturer, + # LogCapturer().logs_captured(worker_logger) as worker_logger_capturer, + ): + async with new_worker( + worker_client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=6 * 1024, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + assert isinstance(err.value.cause, ActivityError) + assert isinstance(err.value.cause.cause, ApplicationError) + + def activity_logger_predicate(record: logging.LogRecord) -> bool: + return ( + hasattr(record, "__temporal_error_identifier") + and getattr(record, "__temporal_error_identifier") == "ActivityFailure" + and record.levelname == "WARNING" + and "Completing as failure due to payloads size exceeding the error limit." + in record.msg + and f"Limit: {error_limit} bytes" in record.msg ) - assert isinstance(err.value.cause, ActivityError) - assert isinstance(err.value.cause.cause, ApplicationError) - assert "PayloadsTooLarge" == err.value.cause.cause.type + assert activity_logger_capturer.find(activity_logger_predicate) + + # Worker logger is not emitting this follow message. Maybe activity completion failures + # are not routed through the log forwarder whereas workflow completion failures are? + # def worker_logger_predicate(record: logging.LogRecord) -> bool: + # return "Payloads size has exceeded the error limit." in record.msg + + # assert worker_logger_capturer.find(worker_logger_predicate) async def test_large_payload_warning_activity_result(client: Client): From 9515b4e7667e12ca231c52201499d80288f68f6e Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 20 Jan 2026 15:58:11 -0800 Subject: [PATCH 07/16] Fix test_activity_failure_trace_identifier test --- tests/worker/test_activity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index e66a42dc0..b97f08a8a 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -1687,7 +1687,7 @@ async def raise_error(): assert handler._trace_identifiers == 1 finally: - activity.logger.base_logger.removeHandler(CustomLogHandler()) + activity.logger.base_logger.removeHandler(handler) async def test_activity_heartbeat_context( From 3687b2f106aee3fc2c05206f3fd64b19a4eb28a7 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Tue, 20 Jan 2026 16:05:27 -0800 Subject: [PATCH 08/16] Update error messages --- temporalio/worker/_activity.py | 2 +- temporalio/worker/_workflow.py | 2 +- tests/worker/test_workflow.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 8d41b6ca4..c30304acf 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -385,7 +385,7 @@ async def _handle_start_activity_task( temporalio.exceptions.PayloadsSizeError, ): temporalio.activity.logger.warning( - "Completing as failure due to payloads size exceeding the error limit. Size: %d bytes, Limit: %d bytes", + "Activity task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes", err.payloads_size, err.payloads_limit, extra={"__temporal_error_identifier": "ActivityFailure"}, diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 5a084bc4c..17f648aae 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -369,7 +369,7 @@ async def _handle_activation( # the logger core functionality into shareable class and update # LoggerAdapter to be a decorator. logger.warning( - "Completing as failure due to payloads size exceeding the error limit. Size: %d bytes, Limit: %d bytes", + "Workflow task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes", err.payloads_size, err.payloads_limit, ) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 739da0589..8b5eb782f 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8614,7 +8614,7 @@ def worker_logger_predicate(record: logging.LogRecord) -> bool: def root_logger_predicate(record: logging.LogRecord) -> bool: return ( record.levelname == "WARNING" - and "Completing as failure due to payloads size exceeding the error limit." + and "Workflow task failed: payloads size exceeded the error limit." in record.msg and f"Limit: {error_limit} bytes" in record.msg ) @@ -8721,7 +8721,7 @@ def worker_logger_predicate(record: logging.LogRecord) -> bool: def root_logger_predicate(record: logging.LogRecord) -> bool: return ( record.levelname == "WARNING" - and "Completing as failure due to payloads size exceeding the error limit." + and "Workflow task failed: payloads size exceeded the error limit." in record.msg and f"Limit: {error_limit} bytes" in record.msg ) @@ -8823,7 +8823,7 @@ def activity_logger_predicate(record: logging.LogRecord) -> bool: hasattr(record, "__temporal_error_identifier") and getattr(record, "__temporal_error_identifier") == "ActivityFailure" and record.levelname == "WARNING" - and "Completing as failure due to payloads size exceeding the error limit." + and "Activity task failed: payloads size exceeded the error limit." in record.msg and f"Limit: {error_limit} bytes" in record.msg ) From 1221a27198bb42cb62a80ba7d0375332cb297bdd Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 08:24:29 -0800 Subject: [PATCH 09/16] Remove disabled option for limits --- temporalio/converter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index b716f3111..645d39538 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1242,9 +1242,9 @@ def __init__(self) -> None: class PayloadLimitsConfig: """Configuration for when payload sizes exceed limits.""" - payload_upload_error_limit: int | Literal["disabled"] | None = None + payload_upload_error_limit: int | None = None """The limit at which a payloads size error is created.""" - payload_upload_warning_limit: int | Literal["disabled"] | None = None + payload_upload_warning_limit: int | None = None """The limit at which a payloads size warning is created.""" @@ -1306,10 +1306,10 @@ async def decode( def _check_over_limit( self, - limit: int | Literal["disabled"] | None, + limit: int | None, size: int, ) -> int | None: - if limit and limit != "disabled" and limit > 0 and size > limit: + if limit and limit > 0 and size > limit: return limit return None From 103317e8a888ee1e4014ee68aab3465b7185bcf8 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 09:09:26 -0800 Subject: [PATCH 10/16] Rename error type and extend from FailureError --- temporalio/converter.py | 8 ++++++-- temporalio/exceptions.py | 4 ++-- temporalio/worker/_activity.py | 11 ++--------- temporalio/worker/_workflow.py | 10 ++-------- tests/worker/test_workflow.py | 19 ++++++++++--------- 5 files changed, 22 insertions(+), 30 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 645d39538..9dcb7f2b4 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1066,6 +1066,10 @@ def _error_to_failure( failure.nexus_operation_execution_failure_info.operation_token = ( error.operation_token ) + elif isinstance(error, temporalio.exceptions.PayloadSizeError): + failure.application_failure_info.SetInParent() + failure.application_failure_info.type = "PayloadSizeError" + failure.application_failure_info.non_retryable = False def _nexus_handler_error_to_failure( self, @@ -1276,7 +1280,7 @@ async def encode( ) if exceeded_limit_value: - raise temporalio.exceptions.PayloadsSizeError( + raise temporalio.exceptions.PayloadSizeError( size=total_size, limit=exceeded_limit_value, ) @@ -1289,7 +1293,7 @@ async def encode( if exceeded_limit_value: # TODO: Use a context aware logger to log extra information about workflow/activity/etc logger.warning( - "Payloads size has exceeded the warning limit. Size: %d bytes, Limit: %s bytes", + "Payloads size exceeded the warning limit. Size: %d bytes, Limit: %s bytes", total_size, exceeded_limit_value, ) diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index a31e40466..7c2f78b71 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -448,7 +448,7 @@ def is_cancelled_exception(exception: BaseException) -> bool: ) -class PayloadsSizeError(TemporalError): +class PayloadSizeError(FailureError): """Error raised when payloads size exceeds payload size limits.""" def __init__(self, size: int, limit: int): @@ -458,7 +458,7 @@ def __init__(self, size: int, limit: int): size: Actual payloads size in bytes. limit: Payloads size limit in bytes. """ - super().__init__() + super().__init__("Payloads size exceeded the error limit") self._size = size self._limit = limit diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index c30304acf..a25b6ac36 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -382,7 +382,7 @@ async def _handle_start_activity_task( ) elif isinstance( err, - temporalio.exceptions.PayloadsSizeError, + temporalio.exceptions.PayloadSizeError, ): temporalio.activity.logger.warning( "Activity task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes", @@ -391,15 +391,8 @@ async def _handle_start_activity_task( extra={"__temporal_error_identifier": "ActivityFailure"}, ) await data_converter.encode_failure( - temporalio.exceptions.ApplicationError( - type="PayloadsTooLarge", - message="Payloads size has exceeded the error limit.", - ), - completion.result.failed.failure, + err, completion.result.failed.failure ) - # TODO: Add force_cause to activity Failure bridge proto? - # TODO: Add WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE to API - # completion.result.failed.force_cause = WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE else: if ( isinstance( diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 17f648aae..a07526000 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -363,7 +363,7 @@ async def _handle_activation( payload_codec, encode_headers=self._encode_headers, ) - except temporalio.exceptions.PayloadsSizeError as err: + except temporalio.exceptions.PayloadSizeError as err: # TODO: Would like to use temporalio.workflow.logger here, but # that requires being in the workflow event loop. Possibly refactor # the logger core functionality into shareable class and update @@ -374,13 +374,7 @@ async def _handle_activation( err.payloads_limit, ) completion.failed.Clear() - await data_converter.encode_failure( - temporalio.exceptions.ApplicationError( - type="PayloadsTooLarge", - message="Payloads size has exceeded the error limit.", - ), - completion.failed.failure, - ) + await data_converter.encode_failure(err, completion.failed.failure) # TODO: Add WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE to API # completion.failed.force_cause = WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE except Exception as err: diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 8b5eb782f..df7b0ea79 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -96,7 +96,7 @@ ApplicationErrorCategory, CancelledError, ChildWorkflowError, - PayloadsSizeError, + PayloadSizeError, TemporalError, TimeoutError, TimeoutType, @@ -8499,7 +8499,7 @@ async def test_large_payload_error_workflow_input(client: Client): ) client = Client(**config) - with pytest.raises(PayloadsSizeError) as err: + with pytest.raises(PayloadSizeError) as err: await client.execute_workflow( LargePayloadWorkflow.run, LargePayloadWorkflowInput( @@ -8546,7 +8546,7 @@ async def test_large_payload_warning_workflow_input(client: Client): def root_logger_predicate(record: logging.LogRecord) -> bool: return ( record.levelname == "WARNING" - and "Payloads size has exceeded the warning limit." in record.msg + and "Payloads size exceeded the warning limit" in record.msg ) assert root_logger_capturer.find(root_logger_predicate) @@ -8604,9 +8604,10 @@ async def test_large_payload_error_workflow_result(client: Client): assert err.value.cause.type == TimeoutType.START_TO_CLOSE def worker_logger_predicate(record: logging.LogRecord) -> bool: + print(f"Justin Record: {record}") return ( record.levelname == "WARNING" - and "Payloads size has exceeded the error limit." in record.msg + and "Payloads size exceeded the error limit" in record.msg ) assert worker_logger_capturer.find(worker_logger_predicate) @@ -8654,7 +8655,7 @@ async def test_large_payload_warning_workflow_result(client: Client): def root_logger_predicate(record: logging.LogRecord) -> bool: return ( record.levelname == "WARNING" - and "Payloads size has exceeded the warning limit." in record.msg + and "Payloads size exceeded the warning limit" in record.msg ) assert root_logger_capturer.find(root_logger_predicate) @@ -8713,7 +8714,7 @@ async def test_large_payload_error_activity_input(client: Client): def worker_logger_predicate(record: logging.LogRecord) -> bool: return ( record.levelname == "WARNING" - and "Payloads size has exceeded the error limit." in record.msg + and "Payloads size exceeded the error limit" in record.msg ) assert worker_logger_capturer.find(worker_logger_predicate) @@ -8760,7 +8761,7 @@ async def test_large_payload_warning_activity_input(client: Client): def root_logger_predicate(record: logging.LogRecord) -> bool: return ( record.levelname == "WARNING" - and "Payloads size has exceeded the warning limit." in record.msg + and "Payloads size exceeded the warning limit" in record.msg ) assert root_logger_capturer.find(root_logger_predicate) @@ -8833,7 +8834,7 @@ def activity_logger_predicate(record: logging.LogRecord) -> bool: # Worker logger is not emitting this follow message. Maybe activity completion failures # are not routed through the log forwarder whereas workflow completion failures are? # def worker_logger_predicate(record: logging.LogRecord) -> bool: - # return "Payloads size has exceeded the error limit." in record.msg + # return "Payloads size exceeded the error limit" in record.msg # assert worker_logger_capturer.find(worker_logger_predicate) @@ -8869,7 +8870,7 @@ async def test_large_payload_warning_activity_result(client: Client): def root_logger_predicate(record: logging.LogRecord) -> bool: return ( record.levelname == "WARNING" - and "Payloads size has exceeded the warning limit." in record.msg + and "Payloads size exceeded the warning limit" in record.msg ) assert root_logger_capturer.find(root_logger_predicate) From 7ecb7662f836932bdd08699fc54eda439687884d Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 09:23:20 -0800 Subject: [PATCH 11/16] Use warnings modules --- temporalio/converter.py | 6 ++-- tests/worker/test_workflow.py | 57 +++++++++++------------------------ 2 files changed, 19 insertions(+), 44 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 9dcb7f2b4..486d9271d 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1292,10 +1292,8 @@ async def encode( if exceeded_limit_value: # TODO: Use a context aware logger to log extra information about workflow/activity/etc - logger.warning( - "Payloads size exceeded the warning limit. Size: %d bytes, Limit: %s bytes", - total_size, - exceeded_limit_value, + warnings.warn( + f"Payloads size exceeded the warning limit. Size: {total_size} bytes, Limit: {exceeded_limit_value} bytes" ) return list(result_payloads) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index df7b0ea79..f6d276471 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -15,6 +15,7 @@ import time import typing import uuid +import warnings from abc import ABC, abstractmethod from collections.abc import Awaitable, Mapping, Sequence from dataclasses import dataclass @@ -8525,9 +8526,7 @@ async def test_large_payload_warning_workflow_input(client: Client): ) client = Client(**config) - with ( - LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, - ): + with warnings.catch_warnings(record=True) as w: async with new_worker( client, LargePayloadWorkflow, activities=[large_payload_activity] ) as worker: @@ -8543,13 +8542,9 @@ async def test_large_payload_warning_workflow_input(client: Client): task_queue=worker.task_queue, ) - def root_logger_predicate(record: logging.LogRecord) -> bool: - return ( - record.levelname == "WARNING" - and "Payloads size exceeded the warning limit" in record.msg - ) - - assert root_logger_capturer.find(root_logger_predicate) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) async def test_large_payload_error_workflow_result(client: Client): @@ -8633,9 +8628,7 @@ async def test_large_payload_warning_workflow_result(client: Client): ) worker_client = Client(**config) - with ( - LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, - ): + with warnings.catch_warnings(record=True) as w: async with new_worker( worker_client, LargePayloadWorkflow, activities=[large_payload_activity] ) as worker: @@ -8652,13 +8645,9 @@ async def test_large_payload_warning_workflow_result(client: Client): execution_timeout=timedelta(seconds=3), ) - def root_logger_predicate(record: logging.LogRecord) -> bool: - return ( - record.levelname == "WARNING" - and "Payloads size exceeded the warning limit" in record.msg - ) - - assert root_logger_capturer.find(root_logger_predicate) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) async def test_large_payload_error_activity_input(client: Client): @@ -8740,9 +8729,7 @@ async def test_large_payload_warning_activity_input(client: Client): ) worker_client = Client(**config) - with ( - LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, - ): + with warnings.catch_warnings(record=True) as w: async with new_worker( worker_client, LargePayloadWorkflow, activities=[large_payload_activity] ) as worker: @@ -8758,13 +8745,9 @@ async def test_large_payload_warning_activity_input(client: Client): task_queue=worker.task_queue, ) - def root_logger_predicate(record: logging.LogRecord) -> bool: - return ( - record.levelname == "WARNING" - and "Payloads size exceeded the warning limit" in record.msg - ) - - assert root_logger_capturer.find(root_logger_predicate) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) async def test_large_payload_error_activity_result(client: Client): @@ -8849,9 +8832,7 @@ async def test_large_payload_warning_activity_result(client: Client): ) worker_client = Client(**config) - with ( - LogCapturer().logs_captured(logging.getLogger()) as root_logger_capturer, - ): + with warnings.catch_warnings(record=True) as w: async with new_worker( worker_client, LargePayloadWorkflow, activities=[large_payload_activity] ) as worker: @@ -8867,10 +8848,6 @@ async def test_large_payload_warning_activity_result(client: Client): task_queue=worker.task_queue, ) - def root_logger_predicate(record: logging.LogRecord) -> bool: - return ( - record.levelname == "WARNING" - and "Payloads size exceeded the warning limit" in record.msg - ) - - assert root_logger_capturer.find(root_logger_predicate) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) From 5b4e79f9d9da3c6f35cb335c938d21cd9096f108 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 14:12:15 -0800 Subject: [PATCH 12/16] Replace _PayloadLimitsPayloadCodec with DataConverter methods --- temporalio/bridge/worker.py | 8 +- temporalio/client.py | 11 +- temporalio/converter.py | 229 ++++++++++++++++----------------- temporalio/worker/_activity.py | 5 +- temporalio/worker/_workflow.py | 100 +++++++------- tests/test_converter.py | 4 +- tests/worker/test_visitor.py | 9 +- 7 files changed, 188 insertions(+), 178 deletions(-) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index f37c76f19..84c3f215c 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -293,21 +293,21 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: async def decode_activation( activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation, - codec: temporalio.converter.PayloadCodec, + data_converter: temporalio.converter.DataConverter, decode_headers: bool, ) -> None: """Decode all payloads in the activation.""" await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not decode_headers - ).visit(_Visitor(codec.decode), activation) + ).visit(_Visitor(data_converter._decode_payload_sequence), activation) async def encode_completion( completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, - codec: temporalio.converter.PayloadCodec, + data_converter: temporalio.converter.DataConverter, encode_headers: bool, ) -> None: """Encode all payloads in the completion.""" await CommandAwarePayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(_Visitor(codec.encode), completion) + ).visit(_Visitor(data_converter._encode_payload_sequence), completion) diff --git a/temporalio/client.py b/temporalio/client.py index 5038b2715..53f81ad22 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -4249,7 +4249,7 @@ async def _to_proto( client.config(active_config=True)["header_codec_behavior"] == HeaderCodecBehavior.CODEC and not self._from_raw, - client.data_converter._payload_codec_chain, + client.data_converter, ) return action @@ -6870,7 +6870,7 @@ async def _apply_headers( dest, self._client.config(active_config=True)["header_codec_behavior"] == HeaderCodecBehavior.CODEC, - self._client.data_converter._payload_codec_chain, + self._client.data_converter, ) @@ -6878,14 +6878,13 @@ async def _apply_headers( source: Mapping[str, temporalio.api.common.v1.Payload] | None, dest: MessageMap[str, temporalio.api.common.v1.Payload], encode_headers: bool, - codec: temporalio.converter.PayloadCodec | None, + data_converter: DataConverter, ) -> None: if source is None: return - if encode_headers and codec is not None: + if encode_headers: for payload in source.values(): - new_payload = (await codec.encode([payload]))[0] - payload.CopyFrom(new_payload) + payload.CopyFrom(await data_converter._encode_payload(payload)) temporalio.common._apply_headers(source, dest) diff --git a/temporalio/converter.py b/temporalio/converter.py index 486d9271d..9cdcfd428 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -824,45 +824,14 @@ async def encode_failure(self, failure: temporalio.api.failure.v1.Failure) -> No It is not guaranteed that all failures will be encoded with this method rather than encoding the underlying payloads. """ - await self._apply_to_failure_payloads(failure, self.encode_wrapper) + await DataConverter._apply_to_failure_payloads(failure, self.encode_wrapper) async def decode_failure(self, failure: temporalio.api.failure.v1.Failure) -> None: """Decode payloads of a failure. Intended as a helper method, not for overriding. It is not guaranteed that all failures will be decoded with this method rather than decoding the underlying payloads. """ - await self._apply_to_failure_payloads(failure, self.decode_wrapper) - - async def _apply_to_failure_payloads( - self, - failure: temporalio.api.failure.v1.Failure, - cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], - ) -> None: - if failure.HasField("encoded_attributes"): - # Wrap in payloads and merge back - payloads = temporalio.api.common.v1.Payloads( - payloads=[failure.encoded_attributes] - ) - await cb(payloads) - failure.encoded_attributes.CopyFrom(payloads.payloads[0]) - if failure.HasField( - "application_failure_info" - ) and failure.application_failure_info.HasField("details"): - await cb(failure.application_failure_info.details) - elif failure.HasField( - "timeout_failure_info" - ) and failure.timeout_failure_info.HasField("last_heartbeat_details"): - await cb(failure.timeout_failure_info.last_heartbeat_details) - elif failure.HasField( - "canceled_failure_info" - ) and failure.canceled_failure_info.HasField("details"): - await cb(failure.canceled_failure_info.details) - elif failure.HasField( - "reset_workflow_failure_info" - ) and failure.reset_workflow_failure_info.HasField("last_heartbeat_details"): - await cb(failure.reset_workflow_failure_info.last_heartbeat_details) - if failure.HasField("cause"): - await self._apply_to_failure_payloads(failure.cause, cb) + await DataConverter._apply_to_failure_payloads(failure, self.decode_wrapper) class FailureConverter(ABC): @@ -1252,70 +1221,6 @@ class PayloadLimitsConfig: """The limit at which a payloads size warning is created.""" -@dataclass(kw_only=True) -class _PayloadLimitsPayloadCodec(PayloadCodec, WithSerializationContext): - config: PayloadLimitsConfig - inner_codec: PayloadCodec | None - - def with_context(self, context: SerializationContext) -> _PayloadLimitsPayloadCodec: - inner_codec = self.inner_codec - if isinstance(inner_codec, WithSerializationContext): - inner_codec = inner_codec.with_context(context) - if inner_codec == self.inner_codec: - return self - return _PayloadLimitsPayloadCodec(config=self.config, inner_codec=inner_codec) - - async def encode( - self, payloads: Sequence[temporalio.api.common.v1.Payload] - ) -> list[temporalio.api.common.v1.Payload]: - result_payloads = list(payloads) - if self.inner_codec: - result_payloads = await self.inner_codec.encode(payloads=result_payloads) - - total_size = sum(payload.ByteSize() for payload in payloads) - - exceeded_limit_value = self._check_over_limit( - self.config.payload_upload_error_limit, - total_size, - ) - - if exceeded_limit_value: - raise temporalio.exceptions.PayloadSizeError( - size=total_size, - limit=exceeded_limit_value, - ) - - exceeded_limit_value = self._check_over_limit( - self.config.payload_upload_warning_limit, - total_size, - ) - - if exceeded_limit_value: - # TODO: Use a context aware logger to log extra information about workflow/activity/etc - warnings.warn( - f"Payloads size exceeded the warning limit. Size: {total_size} bytes, Limit: {exceeded_limit_value} bytes" - ) - - return list(result_payloads) - - async def decode( - self, payloads: Sequence[temporalio.api.common.v1.Payload] - ) -> list[temporalio.api.common.v1.Payload]: - result_payloads = list(payloads) - if self.inner_codec: - result_payloads = await self.inner_codec.decode(result_payloads) - return result_payloads - - def _check_over_limit( - self, - limit: int | None, - size: int, - ) -> int | None: - if limit and limit > 0 and size > limit: - return limit - return None - - @dataclass(frozen=True) class DataConverter(WithSerializationContext): """Data converter for converting and encoding payloads to/from Python values. @@ -1342,22 +1247,12 @@ class DataConverter(WithSerializationContext): payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() """Settings for payload size limits.""" - _payload_codec_chain: PayloadCodec = dataclasses.field(init=False) - default: ClassVar[DataConverter] """Singleton default data converter.""" def __post_init__(self) -> None: # noqa: D105 object.__setattr__(self, "payload_converter", self.payload_converter_class()) object.__setattr__(self, "failure_converter", self.failure_converter_class()) - object.__setattr__( - self, - "_payload_codec_chain", - _PayloadLimitsPayloadCodec( - config=self.payload_limits, - inner_codec=self.payload_codec, - ), - ) async def encode( self, values: Sequence[Any] @@ -1375,7 +1270,7 @@ async def encode( more than was given. """ payloads = self.payload_converter.to_payloads(values) - payloads = await self._payload_codec_chain.encode(payloads) + payloads = await self._encode_payload_sequence(payloads) return payloads @@ -1394,7 +1289,7 @@ async def decode( Returns: Decoded and converted values. """ - payloads = await self._payload_codec_chain.decode(payloads) + payloads = await self._decode_payload_sequence(payloads) return self.payload_converter.from_payloads(payloads, type_hints) async def encode_wrapper( @@ -1422,13 +1317,13 @@ async def encode_failure( ) -> None: """Convert and encode failure.""" self.failure_converter.to_failure(exception, self.payload_converter, failure) - await self._payload_codec_chain.encode_failure(failure) + await DataConverter._apply_to_failure_payloads(failure, self._encode_payloads) async def decode_failure( self, failure: temporalio.api.failure.v1.Failure ) -> BaseException: """Decode and convert failure.""" - await self._payload_codec_chain.decode_failure(failure) + await DataConverter._apply_to_failure_payloads(failure, self._decode_payloads) return self.failure_converter.from_failure(failure, self.payload_converter) def with_context(self, context: SerializationContext) -> Self: @@ -1436,22 +1331,18 @@ def with_context(self, context: SerializationContext) -> Self: payload_converter = self.payload_converter payload_codec = self.payload_codec failure_converter = self.failure_converter - codec_chain = self._payload_codec_chain if isinstance(payload_converter, WithSerializationContext): payload_converter = payload_converter.with_context(context) if isinstance(payload_codec, WithSerializationContext): payload_codec = payload_codec.with_context(context) if isinstance(failure_converter, WithSerializationContext): failure_converter = failure_converter.with_context(context) - if isinstance(codec_chain, WithSerializationContext): - codec_chain = codec_chain.with_context(context) if all( new is orig for new, orig in [ (payload_converter, self.payload_converter), (payload_codec, self.payload_codec), (failure_converter, self.failure_converter), - (codec_chain, self._payload_codec_chain), ] ): return self @@ -1459,9 +1350,115 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "payload_converter", payload_converter) object.__setattr__(cloned, "payload_codec", payload_codec) object.__setattr__(cloned, "failure_converter", failure_converter) - object.__setattr__(cloned, "_payload_codec_chain", codec_chain) return cloned + async def _encode_payload( + self, payload: temporalio.api.common.v1.Payload + ) -> temporalio.api.common.v1.Payload: + if self.payload_codec: + payload = (await self.payload_codec.encode([payload]))[0] + self._validate_payload_limits([payload]) + return payload + + async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + if self.payload_codec: + await self.payload_codec.encode_wrapper(payloads) + self._validate_payload_limits(payloads.payloads) + + async def _encode_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + encoded_payloads = list(payloads) + if self.payload_codec: + encoded_payloads = await self.payload_codec.encode(encoded_payloads) + self._validate_payload_limits(encoded_payloads) + return encoded_payloads + + async def _decode_payload( + self, payload: temporalio.api.common.v1.Payload + ) -> temporalio.api.common.v1.Payload: + if self.payload_codec: + payload = (await self.payload_codec.decode([payload]))[0] + return payload + + async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + if self.payload_codec: + await self.payload_codec.decode_wrapper(payloads) + + async def _decode_payload_sequence( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ) -> list[temporalio.api.common.v1.Payload]: + decoded_payloads = list(payloads) + if self.payload_codec: + decoded_payloads = await self.payload_codec.decode(decoded_payloads) + return decoded_payloads + + @staticmethod + async def _apply_to_failure_payloads( + failure: temporalio.api.failure.v1.Failure, + cb: Callable[[temporalio.api.common.v1.Payloads], Awaitable[None]], + ) -> None: + if failure.HasField("encoded_attributes"): + # Wrap in payloads and merge back + payloads = temporalio.api.common.v1.Payloads( + payloads=[failure.encoded_attributes] + ) + await cb(payloads) + failure.encoded_attributes.CopyFrom(payloads.payloads[0]) + if failure.HasField( + "application_failure_info" + ) and failure.application_failure_info.HasField("details"): + await cb(failure.application_failure_info.details) + elif failure.HasField( + "timeout_failure_info" + ) and failure.timeout_failure_info.HasField("last_heartbeat_details"): + await cb(failure.timeout_failure_info.last_heartbeat_details) + elif failure.HasField( + "canceled_failure_info" + ) and failure.canceled_failure_info.HasField("details"): + await cb(failure.canceled_failure_info.details) + elif failure.HasField( + "reset_workflow_failure_info" + ) and failure.reset_workflow_failure_info.HasField("last_heartbeat_details"): + await cb(failure.reset_workflow_failure_info.last_heartbeat_details) + if failure.HasField("cause"): + await DataConverter._apply_to_failure_payloads(failure.cause, cb) + + def _validate_payload_limits( + self, payloads: Sequence[temporalio.api.common.v1.Payload] + ): + total_size = sum(payload.ByteSize() for payload in payloads) + + def _check_over_limit( + limit: int | None, + size: int, + ) -> int | None: + if limit and limit > 0 and size > limit: + return limit + return None + + exceeded_limit_value = _check_over_limit( + self.payload_limits.payload_upload_error_limit, + total_size, + ) + + if exceeded_limit_value: + raise temporalio.exceptions.PayloadSizeError( + size=total_size, + limit=exceeded_limit_value, + ) + + exceeded_limit_value = _check_over_limit( + self.payload_limits.payload_upload_warning_limit, + total_size, + ) + + if exceeded_limit_value: + # TODO: Use a context aware logger to log extra information about workflow/activity/etc + warnings.warn( + f"Payloads size exceeded the warning limit. Size: {total_size} bytes, Limit: {exceeded_limit_value} bytes" + ) + DefaultPayloadConverter.default_encoding_payload_converters = ( BinaryNullPayloadConverter(), diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index a25b6ac36..6d6126ace 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -592,10 +592,7 @@ async def _execute_activity( if self._encode_headers: for payload in start.header_fields.values(): - new_payload = ( - await data_converter._payload_codec_chain.decode([payload]) - )[0] - payload.CopyFrom(new_payload) + payload.CopyFrom(await data_converter._decode_payload(payload)) running_activity.info = info input = ExecuteActivityInput( diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index a07526000..48891f299 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -4,6 +4,7 @@ import asyncio import concurrent.futures +import dataclasses import logging import os import sys @@ -268,18 +269,21 @@ async def _handle_activation( workflow_id=workflow_id, ) data_converter = self._data_converter.with_context(workflow_context) - if not workflow: - payload_codec = data_converter._payload_codec_chain - else: - payload_codec = _CommandAwarePayloadCodec( - workflow.instance, - context_free_payload_codec=self._data_converter._payload_codec_chain, - workflow_context_payload_codec=data_converter._payload_codec_chain, - workflow_context=workflow_context, - ) + if self._data_converter.payload_codec: + assert data_converter.payload_codec + if workflow: + data_converter = dataclasses.replace( + data_converter, + payload_codec=_CommandAwarePayloadCodec( + workflow.instance, + context_free_payload_codec=self._data_converter.payload_codec, + workflow_context_payload_codec=data_converter.payload_codec, + workflow_context=workflow_context, + ), + ) await temporalio.bridge.worker.decode_activation( act, - payload_codec, + data_converter, decode_headers=self._encode_headers, ) if not workflow: @@ -347,42 +351,48 @@ async def _handle_activation( completion.run_id = act.run_id # Encode completion - if workflow: - payload_codec = _CommandAwarePayloadCodec( - workflow.instance, - context_free_payload_codec=self._data_converter._payload_codec_chain, - workflow_context_payload_codec=data_converter._payload_codec_chain, - workflow_context=temporalio.converter.WorkflowSerializationContext( - namespace=self._namespace, - workflow_id=workflow.workflow_id, - ), - ) - try: - await temporalio.bridge.worker.encode_completion( - completion, - payload_codec, - encode_headers=self._encode_headers, - ) - except temporalio.exceptions.PayloadSizeError as err: - # TODO: Would like to use temporalio.workflow.logger here, but - # that requires being in the workflow event loop. Possibly refactor - # the logger core functionality into shareable class and update - # LoggerAdapter to be a decorator. - logger.warning( - "Workflow task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes", - err.payloads_size, - err.payloads_limit, + if self._data_converter.payload_codec and workflow: + assert data_converter.payload_codec + if workflow: + data_converter = dataclasses.replace( + data_converter, + payload_codec=_CommandAwarePayloadCodec( + workflow.instance, + context_free_payload_codec=self._data_converter.payload_codec, + workflow_context_payload_codec=data_converter.payload_codec, + workflow_context=temporalio.converter.WorkflowSerializationContext( + namespace=self._namespace, + workflow_id=workflow.workflow_id, + ), + ), ) - completion.failed.Clear() - await data_converter.encode_failure(err, completion.failed.failure) - # TODO: Add WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE to API - # completion.failed.force_cause = WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE - except Exception as err: - logger.exception( - "Failed encoding completion on workflow with run ID %s", act.run_id - ) - completion.failed.Clear() - completion.failed.failure.message = f"Failed encoding completion: {err}" + + try: + await temporalio.bridge.worker.encode_completion( + completion, + data_converter, + encode_headers=self._encode_headers, + ) + except temporalio.exceptions.PayloadSizeError as err: + # TODO: Would like to use temporalio.workflow.logger here, but + # that requires being in the workflow event loop. Possibly refactor + # the logger core functionality into shareable class and update + # LoggerAdapter to be a decorator. + logger.warning( + "Workflow task failed: payloads size exceeded the error limit. Size: %d bytes, Limit: %d bytes", + err.payloads_size, + err.payloads_limit, + ) + completion.failed.Clear() + await data_converter.encode_failure(err, completion.failed.failure) + # TODO: Add WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE to API + # completion.failed.force_cause = WorkflowTaskFailedCause.WORKFLOW_TASK_FAILED_CAUSE_PAYLOADS_TOO_LARGE + except Exception as err: + logger.exception( + "Failed encoding completion on workflow with run ID %s", act.run_id + ) + completion.failed.Clear() + completion.failed.failure.message = f"Failed encoding completion: {err}" # Send off completion if LOG_PROTOS: diff --git a/tests/test_converter.py b/tests/test_converter.py index 64bb51dc1..bb5b3c8bc 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -574,7 +574,7 @@ async def test_failure_encoded_attributes(): # Encode it and check encoded orig_failure = Failure() orig_failure.CopyFrom(failure) - await conv._payload_codec_chain.encode_failure(failure) + await conv.payload_codec.encode_failure(failure) assert "encoding" not in failure.encoded_attributes.metadata assert "simple-codec" in failure.encoded_attributes.metadata assert ( @@ -585,7 +585,7 @@ async def test_failure_encoded_attributes(): ) # Decode and check - await conv._payload_codec_chain.decode_failure(failure) + await conv.payload_codec.decode_failure(failure) assert "encoding" in failure.encoded_attributes.metadata assert "simple-codec" not in failure.encoded_attributes.metadata assert "encoding" in failure.application_failure_info.details.payloads[0].metadata diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 41e6ccad9..5604b8542 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,8 +1,10 @@ +import dataclasses from collections.abc import MutableSequence from google.protobuf.duration_pb2 import Duration import temporalio.bridge.worker +import temporalio.converter from temporalio.api.common.v1.message_pb2 import ( Payload, Payloads, @@ -228,7 +230,12 @@ async def test_bridge_encoding(): ), ) - await temporalio.bridge.worker.encode_completion(comp, SimpleCodec(), True) + data_converter = dataclasses.replace( + temporalio.converter.default(), + payload_codec=SimpleCodec(), + ) + + await temporalio.bridge.worker.encode_completion(comp, data_converter, True) cmd = comp.successful.commands[0] sa = cmd.schedule_activity From 3ffe0d38492c9a0c8f5d1568cf6836c755292d6c Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 15:25:54 -0800 Subject: [PATCH 13/16] Change super type of PayloadSizeError to TemporalError --- temporalio/converter.py | 4 ---- temporalio/exceptions.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 9cdcfd428..2db46a173 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1035,10 +1035,6 @@ def _error_to_failure( failure.nexus_operation_execution_failure_info.operation_token = ( error.operation_token ) - elif isinstance(error, temporalio.exceptions.PayloadSizeError): - failure.application_failure_info.SetInParent() - failure.application_failure_info.type = "PayloadSizeError" - failure.application_failure_info.non_retryable = False def _nexus_handler_error_to_failure( self, diff --git a/temporalio/exceptions.py b/temporalio/exceptions.py index 7c2f78b71..96a644a61 100644 --- a/temporalio/exceptions.py +++ b/temporalio/exceptions.py @@ -448,7 +448,7 @@ def is_cancelled_exception(exception: BaseException) -> bool: ) -class PayloadSizeError(FailureError): +class PayloadSizeError(TemporalError): """Error raised when payloads size exceeds payload size limits.""" def __init__(self, size: int, limit: int): From 5dd575ada534da8345ce5bf00942003e401188f6 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:58:53 -0800 Subject: [PATCH 14/16] Change warning level to have default of 512 KiB --- temporalio/converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 2db46a173..cc042aaef 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1213,7 +1213,7 @@ class PayloadLimitsConfig: payload_upload_error_limit: int | None = None """The limit at which a payloads size error is created.""" - payload_upload_warning_limit: int | None = None + payload_upload_warning_limit: int = 512 * 1024 """The limit at which a payloads size warning is created.""" From c61c35f527a68f0496d23035dd4068f96bafbe00 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 20:28:22 -0800 Subject: [PATCH 15/16] Preserve unused import --- tests/worker/test_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index db1f535ed..704f9d33e 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8452,7 +8452,7 @@ async def run(self): class CustomLogHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: - pass # type: ignore[reportUnusedImport] + import httpx # type: ignore[reportUnusedImport] # noqa async def test_disable_logger_sandbox( From 8dbc49125b314822d8512e8feed160dcadf16c5b Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 21 Jan 2026 22:21:32 -0800 Subject: [PATCH 16/16] Add special handling for memos --- temporalio/client.py | 83 +++++++------------------------ temporalio/converter.py | 93 ++++++++++++++++++++++++++--------- tests/worker/test_workflow.py | 56 +++++++++++++++++++++ 3 files changed, 146 insertions(+), 86 deletions(-) diff --git a/temporalio/client.py b/temporalio/client.py index 53f81ad22..849fb953a 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -2977,10 +2977,7 @@ async def memo(self) -> Mapping[str, Any]: Returns: Mapping of all memo keys and they values without type hints. """ - return { - k: (await self.data_converter.decode([v]))[0] - for k, v in self.raw_info.memo.fields.items() - } + return await self.data_converter._convert_from_memo(self.raw_info.memo) @overload async def memo_value( @@ -3019,16 +3016,9 @@ async def memo_value( Raises: KeyError: Key not present and default not set. """ - payload = self.raw_info.memo.fields.get(key) - if not payload: - if default is temporalio.common._arg_unset: - raise KeyError(f"Memo does not have a value for key {key}") - return default - return ( - await self.data_converter.decode( - [payload], [type_hint] if type_hint else None - ) - )[0] + return await self.data_converter._convert_from_memo_field( + self.raw_info.memo, key, default, type_hint + ) @dataclass @@ -4209,18 +4199,9 @@ async def _to_proto( workflow_run_timeout=run_timeout, workflow_task_timeout=task_timeout, retry_policy=retry_policy, - memo=( - temporalio.api.common.v1.Memo( - fields={ - k: v - if isinstance(v, temporalio.api.common.v1.Payload) - else (await data_converter.encode([v]))[0] - for k, v in self.memo.items() - }, - ) - if self.memo - else None - ), + memo=await data_converter._convert_to_memo(self.memo) + if self.memo + else None, user_metadata=await _encode_user_metadata( data_converter, self.static_summary, self.static_details ), @@ -4521,10 +4502,7 @@ async def memo(self) -> Mapping[str, Any]: Returns: Mapping of all memo keys and they values without type hints. """ - return { - k: (await self.data_converter.decode([v]))[0] - for k, v in self.raw_description.memo.fields.items() - } + return await self.data_converter._convert_from_memo(self.raw_description.memo) @overload async def memo_value( @@ -4563,16 +4541,9 @@ async def memo_value( Raises: KeyError: Key not present and default not set. """ - payload = self.raw_description.memo.fields.get(key) - if not payload: - if default is temporalio.common._arg_unset: - raise KeyError(f"Memo does not have a value for key {key}") - return default - return ( - await self.data_converter.decode( - [payload], [type_hint] if type_hint else None - ) - )[0] + return await self.data_converter._convert_from_memo_field( + self.raw_description.memo, key, default, type_hint + ) @dataclass @@ -4770,10 +4741,7 @@ async def memo(self) -> Mapping[str, Any]: Returns: Mapping of all memo keys and they values without type hints. """ - return { - k: (await self.data_converter.decode([v]))[0] - for k, v in self.raw_entry.memo.fields.items() - } + return await self.data_converter._convert_from_memo(self.raw_entry.memo) @overload async def memo_value( @@ -4812,16 +4780,9 @@ async def memo_value( Raises: KeyError: Key not present and default not set. """ - payload = self.raw_entry.memo.fields.get(key) - if not payload: - if default is temporalio.common._arg_unset: - raise KeyError(f"Memo does not have a value for key {key}") - return default - return ( - await self.data_converter.decode( - [payload], [type_hint] if type_hint else None - ) - )[0] + return await self.data_converter._convert_from_memo_field( + self.raw_entry.memo, key, default, type_hint + ) @dataclass @@ -6014,8 +5975,7 @@ async def _populate_start_workflow_execution_request( input.retry_policy.apply_to_proto(req.retry_policy) req.cron_schedule = input.cron_schedule if input.memo is not None: - for k, v in input.memo.items(): - req.memo.fields[k].CopyFrom((await data_converter.encode([v]))[0]) + await data_converter._convert_into_memo(input.memo, req.memo) if input.search_attributes is not None: temporalio.converter.encode_search_attributes( input.search_attributes, req.search_attributes @@ -6641,14 +6601,9 @@ async def create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle: initial_patch=initial_patch, identity=self._client.identity, request_id=str(uuid.uuid4()), - memo=None - if not input.memo - else temporalio.api.common.v1.Memo( - fields={ - k: (await self._client.data_converter.encode([v]))[0] - for k, v in input.memo.items() - }, - ), + memo=await self._client.data_converter._convert_to_memo(input.memo) + if input.memo + else None, ) if input.search_attributes: temporalio.converter.encode_search_attributes( diff --git a/temporalio/converter.py b/temporalio/converter.py index cc042aaef..7de095c81 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1211,6 +1211,10 @@ def __init__(self) -> None: class PayloadLimitsConfig: """Configuration for when payload sizes exceed limits.""" + memo_upload_error_limit: int | None = None + """The limit at which a memo size error is created.""" + memo_upload_warning_limit: int = 2 * 1024 + """The limit at which a memo size warning is created.""" payload_upload_error_limit: int | None = None """The limit at which a payloads size error is created.""" payload_upload_warning_limit: int = 512 * 1024 @@ -1348,6 +1352,54 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "failure_converter", failure_converter) return cloned + async def _convert_from_memo( + self, + source: temporalio.api.common.v1.Memo, + ) -> Mapping[str, Any]: + mapping: dict[str, Any] = {} + for k, v in source.fields.items(): + mapping[k] = (await self.decode([v]))[0] + return mapping + + async def _convert_from_memo_field( + self, + source: temporalio.api.common.v1.Memo, + key: str, + default: Any, + type_hint: type | None, + ) -> dict[str, Any]: + payload = source.fields.get(key) + if not payload: + if default is temporalio.common._arg_unset: + raise KeyError(f"Memo does not have a value for key {key}") + return default + return (await self.decode([payload], [type_hint] if type_hint else None))[0] + + async def _convert_into_memo( + self, source: Mapping[str, Any], memo: temporalio.api.common.v1.Memo + ): + payloads: list[temporalio.api.common.v1.Payload] = [] + for k, v in source.items(): + payload = v + if not isinstance(v, temporalio.api.common.v1.Payload): + payload = (await self.encode([v]))[0] + memo.fields[k].CopyFrom(payload) + payloads.append(payload) + # Memos have their field payloads validated all together in one unit + self._validate_limits( + payloads, + self.payload_limits.memo_upload_error_limit, + self.payload_limits.memo_upload_warning_limit, + "Memo size exceeded the warning limit.", + ) + + async def _convert_to_memo( + self, source: Mapping[str, Any] + ) -> temporalio.api.common.v1.Memo: + memo = temporalio.api.common.v1.Memo() + await self._convert_into_memo(source, memo) + return memo + async def _encode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: @@ -1421,38 +1473,35 @@ async def _apply_to_failure_payloads( await DataConverter._apply_to_failure_payloads(failure.cause, cb) def _validate_payload_limits( - self, payloads: Sequence[temporalio.api.common.v1.Payload] + self, + payloads: Sequence[temporalio.api.common.v1.Payload], ): - total_size = sum(payload.ByteSize() for payload in payloads) - - def _check_over_limit( - limit: int | None, - size: int, - ) -> int | None: - if limit and limit > 0 and size > limit: - return limit - return None - - exceeded_limit_value = _check_over_limit( + self._validate_limits( + payloads, self.payload_limits.payload_upload_error_limit, - total_size, + self.payload_limits.payload_upload_warning_limit, + "Payloads size exceeded the warning limit.", ) - if exceeded_limit_value: + def _validate_limits( + self, + payloads: Sequence[temporalio.api.common.v1.Payload], + error_limit: int | None, + warning_limit: int, + warning_message: str, + ): + total_size = sum(payload.ByteSize() for payload in payloads) + + if error_limit and error_limit > 0 and total_size > error_limit: raise temporalio.exceptions.PayloadSizeError( size=total_size, - limit=exceeded_limit_value, + limit=error_limit, ) - exceeded_limit_value = _check_over_limit( - self.payload_limits.payload_upload_warning_limit, - total_size, - ) - - if exceeded_limit_value: + if warning_limit and warning_limit > 0 and total_size > warning_limit: # TODO: Use a context aware logger to log extra information about workflow/activity/etc warnings.warn( - f"Payloads size exceeded the warning limit. Size: {total_size} bytes, Limit: {exceeded_limit_value} bytes" + f"{warning_message} Size: {total_size} bytes, Limit: {warning_limit} bytes" ) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 704f9d33e..d8401d684 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8565,6 +8565,32 @@ async def test_large_payload_error_workflow_input(client: Client): assert error_limit == err.value.payloads_limit +async def test_large_payload_error_workflow_memo(client: Client): + config = client.config() + error_limit = 128 + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig(memo_upload_error_limit=error_limit), + ) + client = Client(**config) + + with pytest.raises(PayloadSizeError) as err: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue="test-queue", + memo={"key1": [0] * 256}, + ) + + assert error_limit == err.value.payloads_limit + + async def test_large_payload_warning_workflow_input(client: Client): config = client.config() config["data_converter"] = dataclasses.replace( @@ -8596,6 +8622,36 @@ async def test_large_payload_warning_workflow_input(client: Client): assert "Payloads size exceeded the warning limit" in str(w[-1].message) +async def test_large_payload_warning_workflow_memo(client: Client): + config = client.config() + config["data_converter"] = dataclasses.replace( + temporalio.converter.default(), + payload_limits=PayloadLimitsConfig(payload_upload_warning_limit=128), + ) + client = Client(**config) + + with warnings.catch_warnings(record=True) as w: + async with new_worker( + client, LargePayloadWorkflow, activities=[large_payload_activity] + ) as worker: + await client.execute_workflow( + LargePayloadWorkflow.run, + LargePayloadWorkflowInput( + activity_input_data_size=0, + activity_output_data_size=0, + workflow_output_data_size=0, + data=[], + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + memo={"key1": [0] * 256}, + ) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Payloads size exceeded the warning limit" in str(w[-1].message) + + async def test_large_payload_error_workflow_result(client: Client): # Create worker runtime with forwarded logger worker_logger = logging.getLogger(f"log-{uuid.uuid4()}")