Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 154 additions & 4 deletions tests/nexus/test_workflow_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import concurrent.futures
import threading
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
Expand Down Expand Up @@ -38,7 +39,11 @@
)
from temporalio.common import WorkflowIDConflictPolicy
from temporalio.converter import PayloadConverter
from temporalio.exceptions import ApplicationError, CancelledError, NexusOperationError
from temporalio.exceptions import (
ApplicationError,
CancelledError,
NexusOperationError,
)
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
from temporalio.runtime import (
BUFFERED_METRIC_KIND_COUNTER,
Expand Down Expand Up @@ -813,9 +818,6 @@ class ServiceClassNameOutput:
name: str


# TODO(nexus-prerelease): async and non-async cancel methods


@nexusrpc.service
class ServiceInterfaceWithoutNameOverride:
op: nexusrpc.Operation[None, ServiceClassNameOutput]
Expand Down Expand Up @@ -1608,3 +1610,151 @@ async def test_workflow_caller_buffered_metrics(
and update.value == 30
for update in updates
)


@workflow.defn()
class CancelTestCallerWorkflow:
def __init__(self) -> None:
self.released = False

@workflow.run
async def run(self, use_async_cancel: bool, task_queue: str) -> str:
nexus_client = workflow.create_nexus_client(
service=TestAsyncAndNonAsyncCancel.CancelTestService,
endpoint=make_nexus_endpoint_name(task_queue),
)

op = (
TestAsyncAndNonAsyncCancel.CancelTestService.async_cancel_op
if use_async_cancel
else TestAsyncAndNonAsyncCancel.CancelTestService.non_async_cancel_op
)

# Start the operation and immediately request cancellation
# Use WAIT_REQUESTED since we just need to verify the cancel handler was called
handle = await nexus_client.start_operation(
op,
None,
cancellation_type=workflow.NexusOperationCancellationType.WAIT_REQUESTED,
)

# Cancel the handle to trigger the cancel method on the handler
handle.cancel()

try:
await handle
except NexusOperationError:
# Wait for release signal before completing
await workflow.wait_condition(lambda: self.released)
return "cancelled_successfully"

return "unexpected_completion"

@workflow.signal
def release(self) -> None:
self.released = True


@pytest.fixture(scope="class")
def cancel_test_events(request: pytest.FixtureRequest):
if request.cls:
request.cls.called_async = asyncio.Event()
request.cls.called_non_async = threading.Event()
yield


@pytest.mark.usefixtures("cancel_test_events")
class TestAsyncAndNonAsyncCancel:
called_async: asyncio.Event # pyright: ignore[reportUninitializedInstanceVariable]
called_non_async: threading.Event # pyright: ignore[reportUninitializedInstanceVariable]

class OpWithAsyncCancel(OperationHandler[None, str]):
def __init__(self, evt: asyncio.Event) -> None:
self.evt = evt

async def start(
self, ctx: StartOperationContext, input: None
) -> StartOperationResultAsync:
return StartOperationResultAsync("test-token")

async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
self.evt.set()

class OpWithNonAsyncCancel(OperationHandler[None, str]):
def __init__(self, evt: threading.Event) -> None:
self.evt = evt

def start(
self, ctx: StartOperationContext, input: None
) -> StartOperationResultAsync:
return StartOperationResultAsync("test-token")

def cancel(self, ctx: CancelOperationContext, token: str) -> None:
self.evt.set()

@nexusrpc.service
class CancelTestService:
async_cancel_op: nexusrpc.Operation[None, str]
non_async_cancel_op: nexusrpc.Operation[None, str]

@service_handler(service=CancelTestService)
class CancelTestServiceHandler:
def __init__(
self, async_evt: asyncio.Event, non_async_evt: threading.Event
) -> None:
self.async_evt = async_evt
self.non_async_evt = non_async_evt

@operation_handler
def async_cancel_op(self) -> OperationHandler[None, str]:
return TestAsyncAndNonAsyncCancel.OpWithAsyncCancel(self.async_evt)

@operation_handler
def non_async_cancel_op(self) -> OperationHandler[None, str]:
return TestAsyncAndNonAsyncCancel.OpWithNonAsyncCancel(self.non_async_evt)

@pytest.mark.parametrize("use_async_cancel", [True, False])
async def test_task_executor_operation_cancel_method(
self, client: Client, env: WorkflowEnvironment, use_async_cancel: bool
):
"""Test that both async and non-async cancel methods work for TaskExecutor-based operations."""
if env.supports_time_skipping:
pytest.skip("Nexus tests don't work with time-skipping server")

task_queue = str(uuid.uuid4())
async with Worker(
client,
task_queue=task_queue,
workflows=[CancelTestCallerWorkflow],
nexus_service_handlers=[
TestAsyncAndNonAsyncCancel.CancelTestServiceHandler(
self.called_async, self.called_non_async
)
],
nexus_task_executor=concurrent.futures.ThreadPoolExecutor(),
):
await create_nexus_endpoint(task_queue, client)

caller_wf_handle = await client.start_workflow(
CancelTestCallerWorkflow.run,
args=[use_async_cancel, task_queue],
id=f"caller-wf-{uuid.uuid4()}",
task_queue=task_queue,
)

# Wait for the cancel method to be called
fut = (
self.called_async.wait()
if use_async_cancel
else asyncio.get_running_loop().run_in_executor(
None, self.called_non_async.wait
)
)
await asyncio.wait_for(fut, timeout=30)

# Release the workflow to complete
await caller_wf_handle.signal(CancelTestCallerWorkflow.release)

# Verify the workflow completed successfully
result = await caller_wf_handle.result()
assert result == "cancelled_successfully"
Loading