From dd9090634d6609febd383948cc45abf86d7292ba Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 16 Jan 2026 10:15:55 -0800 Subject: [PATCH] Add workflow caller tests that confirm both async and non-async cancel methods are invoked correctly --- tests/nexus/test_workflow_caller.py | 158 +++++++++++++++++++++++++++- 1 file changed, 154 insertions(+), 4 deletions(-) diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 07c22e688..1473a4d40 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -2,6 +2,7 @@ import asyncio import concurrent.futures +import threading import uuid from collections.abc import Awaitable, Callable from dataclasses import dataclass @@ -38,7 +39,11 @@ ) from temporalio.common import WorkflowIDConflictPolicy from temporalio.converter import PayloadConverter -from temporalio.exceptions import ApplicationError, CancelledError, NexusOperationError +from temporalio.exceptions import ( + ApplicationError, + CancelledError, + NexusOperationError, +) from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation from temporalio.runtime import ( BUFFERED_METRIC_KIND_COUNTER, @@ -817,9 +822,6 @@ class ServiceClassNameOutput: name: str -# TODO(nexus-prerelease): async and non-async cancel methods - - @nexusrpc.service class ServiceInterfaceWithoutNameOverride: op: nexusrpc.Operation[None, ServiceClassNameOutput] @@ -1612,3 +1614,151 @@ async def test_workflow_caller_buffered_metrics( and update.value == 30 for update in updates ) + + +@workflow.defn() +class CancelTestCallerWorkflow: + def __init__(self) -> None: + self.released = False + + @workflow.run + async def run(self, use_async_cancel: bool, task_queue: str) -> str: + nexus_client = workflow.create_nexus_client( + service=TestAsyncAndNonAsyncCancel.CancelTestService, + endpoint=make_nexus_endpoint_name(task_queue), + ) + + op = ( + TestAsyncAndNonAsyncCancel.CancelTestService.async_cancel_op + if use_async_cancel + else TestAsyncAndNonAsyncCancel.CancelTestService.non_async_cancel_op + ) + + # Start the operation and immediately request cancellation + # Use WAIT_REQUESTED since we just need to verify the cancel handler was called + handle = await nexus_client.start_operation( + op, + None, + cancellation_type=workflow.NexusOperationCancellationType.WAIT_REQUESTED, + ) + + # Cancel the handle to trigger the cancel method on the handler + handle.cancel() + + try: + await handle + except NexusOperationError: + # Wait for release signal before completing + await workflow.wait_condition(lambda: self.released) + return "cancelled_successfully" + + return "unexpected_completion" + + @workflow.signal + def release(self) -> None: + self.released = True + + +@pytest.fixture(scope="class") +def cancel_test_events(request: pytest.FixtureRequest): + if request.cls: + request.cls.called_async = asyncio.Event() + request.cls.called_non_async = threading.Event() + yield + + +@pytest.mark.usefixtures("cancel_test_events") +class TestAsyncAndNonAsyncCancel: + called_async: asyncio.Event # pyright: ignore[reportUninitializedInstanceVariable] + called_non_async: threading.Event # pyright: ignore[reportUninitializedInstanceVariable] + + class OpWithAsyncCancel(OperationHandler[None, str]): + def __init__(self, evt: asyncio.Event) -> None: + self.evt = evt + + async def start( + self, ctx: StartOperationContext, input: None + ) -> StartOperationResultAsync: + return StartOperationResultAsync("test-token") + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + self.evt.set() + + class OpWithNonAsyncCancel(OperationHandler[None, str]): + def __init__(self, evt: threading.Event) -> None: + self.evt = evt + + def start( + self, ctx: StartOperationContext, input: None + ) -> StartOperationResultAsync: + return StartOperationResultAsync("test-token") + + def cancel(self, ctx: CancelOperationContext, token: str) -> None: + self.evt.set() + + @nexusrpc.service + class CancelTestService: + async_cancel_op: nexusrpc.Operation[None, str] + non_async_cancel_op: nexusrpc.Operation[None, str] + + @service_handler(service=CancelTestService) + class CancelTestServiceHandler: + def __init__( + self, async_evt: asyncio.Event, non_async_evt: threading.Event + ) -> None: + self.async_evt = async_evt + self.non_async_evt = non_async_evt + + @operation_handler + def async_cancel_op(self) -> OperationHandler[None, str]: + return TestAsyncAndNonAsyncCancel.OpWithAsyncCancel(self.async_evt) + + @operation_handler + def non_async_cancel_op(self) -> OperationHandler[None, str]: + return TestAsyncAndNonAsyncCancel.OpWithNonAsyncCancel(self.non_async_evt) + + @pytest.mark.parametrize("use_async_cancel", [True, False]) + async def test_task_executor_operation_cancel_method( + self, client: Client, env: WorkflowEnvironment, use_async_cancel: bool + ): + """Test that both async and non-async cancel methods work for TaskExecutor-based operations.""" + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + task_queue = str(uuid.uuid4()) + async with Worker( + client, + task_queue=task_queue, + workflows=[CancelTestCallerWorkflow], + nexus_service_handlers=[ + TestAsyncAndNonAsyncCancel.CancelTestServiceHandler( + self.called_async, self.called_non_async + ) + ], + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), + ): + await create_nexus_endpoint(task_queue, client) + + caller_wf_handle = await client.start_workflow( + CancelTestCallerWorkflow.run, + args=[use_async_cancel, task_queue], + id=f"caller-wf-{uuid.uuid4()}", + task_queue=task_queue, + ) + + # Wait for the cancel method to be called + fut = ( + self.called_async.wait() + if use_async_cancel + else asyncio.get_running_loop().run_in_executor( + None, self.called_non_async.wait + ) + ) + await asyncio.wait_for(fut, timeout=30) + + # Release the workflow to complete + await caller_wf_handle.signal(CancelTestCallerWorkflow.release) + + # Verify the workflow completed successfully + result = await caller_wf_handle.result() + assert result == "cancelled_successfully"