Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 66 additions & 59 deletions src/a2a/server/agent_execution/active_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,36 +104,60 @@ async def run(self) -> None:
)
except Exception as e:
logger.exception('Consumer[%s]: Failed', self.active_task._task_id)
async with self.active_task._lock:
await self.active_task._mark_task_as_failed(e)

updated_task = None
task = await self.active_task._task_manager.get_task()
if task:
handled_event = TaskStatusUpdateEvent(
task_id=task.id,
context_id=task.context_id,
status=TaskStatus(
state=TaskState.TASK_STATE_FAILED,
),
)
updated_task = await self._handle_task_event(handled_event)
Comment thread
bartek-w marked this conversation as resolved.
Comment thread
bartek-w marked this conversation as resolved.

await self._enqueue_to_subscribers(cast('Event', e), updated_task)

async def _process_event(self, event: Event) -> None:
updated_task = None
handled_event: (
Task
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent
| PushNotificationEvent
| None
) = None

if isinstance(event, _RequestCompleted):
logger.debug(
'Consumer[%s]: Request completed', self.active_task._task_id
)
self.active_task._request_lock.release()
elif isinstance(event, _RequestStarted):
logger.debug(
'Consumer[%s]: Request started', self.active_task._task_id
)
self.message_to_save = event.request_context.message
elif isinstance(event, BaseException):
raise event
elif isinstance(event, Message):
self._handle_message_event(event)
elif isinstance(
event,
TaskStatusUpdateEvent
| TaskArtifactUpdateEvent
| PushNotificationEvent
| Task,
):
updated_task = await self._handle_task_event(event)
handled_event = updated_task if isinstance(event, Task) else event

try:
if isinstance(event, _RequestCompleted):
logger.debug(
'Consumer[%s]: Request completed', self.active_task._task_id
)
self.active_task._request_lock.release()
elif isinstance(event, _RequestStarted):
logger.debug(
'Consumer[%s]: Request started', self.active_task._task_id
)
self.message_to_save = event.request_context.message
elif isinstance(event, Message):
self._handle_message_event(event)
else:
updated_task = await self._handle_task_event(event)
if isinstance(event, Task):
event = updated_task

if updated_task is not None:
await self._update_task_state(updated_task, event)
self.active_task._task_created.set()
if updated_task is not None and handled_event is not None:
await self._update_task_state(updated_task, handled_event)
self.active_task._task_created.set()

finally:
await self._enqueue_to_subscribers(event, updated_task)
await self._enqueue_to_subscribers(event, updated_task)

def _handle_message_event(self, event: Message) -> None:
if self.task_mode is True:
Expand Down Expand Up @@ -286,9 +310,6 @@ class ActiveTask:
- `self._lock` (asyncio.Lock) ensures mutually exclusive access for critical
lifecycle state changes, such as starting the task, subscribing, and
determining if cleanup is safe to trigger.

mutation to the observable result state (like `_exception`,
or `_is_finished`) notifies waiting coroutines (like `wait()`).
- `self._is_finished` (asyncio.Event) provides a thread-safe, non-blocking way
for external observers and internal loops to check if the ActiveTask has
permanently ceased execution and closed its queues.
Expand Down Expand Up @@ -349,10 +370,6 @@ def __init__(
# Protected by `_lock`.
self._reference_count = 0

# Holds any fatal exception that crashed the producer or consumer.
# TODO: Synchronize exception handling (ideally mix it in the queue).
self._exception: Exception | None = None

# Queue for incoming requests
self._request_queue: AsyncQueue[tuple[RequestContext, uuid.UUID]] = (
_create_async_queue()
Expand Down Expand Up @@ -481,22 +498,17 @@ async def _run_producer(self) -> None:
_RequestStarted(request_id, request_context),
)
)

await self._agent_executor.execute(
request_context, self._event_queue_agent
)
logger.debug(
'Producer[%s]: Execution finished successfully',
self._task_id,
)
finally:
logger.debug(
'Producer[%s]: Enqueuing request completed event',
self._task_id,
)
await self._event_queue_agent.enqueue_event(
cast('Event', _RequestCompleted(request_id))
)
Comment thread
bartek-w marked this conversation as resolved.
finally:
self._request_queue.task_done()
except asyncio.CancelledError:
logger.debug('Producer[%s]: Cancelled', self._task_id)
Expand All @@ -516,8 +528,7 @@ async def _run_producer(self) -> None:
request_context.context_id or '',
)
self._task_created.set()
async with self._lock:
await self._mark_task_as_failed(e)
await self._event_queue_agent.enqueue_event(cast('Event', e))

finally:
self._request_queue.shutdown(immediate=True)
Expand All @@ -537,7 +548,7 @@ async def _run_consumer(self) -> None:
logger.debug('Consumer[%s]: Finishing', self._task_id)
await self._maybe_cleanup()

async def subscribe( # noqa: PLR0912, PLR0915
async def subscribe(
self,
*,
request: RequestContext | None = None,
Expand All @@ -554,12 +565,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
logger.debug('Subscribe[%s]: New subscriber', self._task_id)

async with self._lock:
if self._exception:
logger.debug(
'Subscribe[%s]: Failed, exception already set',
self._task_id,
)
raise self._exception
if self._is_finished.is_set():
raise InvalidParamsError(
f'Task {self._task_id} is already completed.'
Expand All @@ -585,17 +590,23 @@ async def subscribe( # noqa: PLR0912, PLR0915

while True:
try:
if self._exception:
raise self._exception

dequeued = await tapped_queue.dequeue_event()
event, updated_task = cast('Any', dequeued)
logger.debug(
'Subscriber[%s]\nDequeued event %s\nUpdated task %s\n',
'Subscriber[%s] Dequeued event [%s]:\n %s\nUpdated task:\n%s\n',
self._task_id,
type(event).__name__,
event,
updated_task,
)
if isinstance(event, BaseException):
logger.debug(
'Subscriber[%s]: Raising exception: %s',
self._task_id,
event,
)
raise event

if replace_status_update_with_task and isinstance(
event, TaskStatusUpdateEvent
):
Expand All @@ -605,8 +616,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
updated_task,
)
event = updated_task
if self._exception:
raise self._exception from None
if isinstance(event, _RequestCompleted):
if (
request_id is not None
Expand All @@ -629,8 +638,6 @@ async def subscribe( # noqa: PLR0912, PLR0915
finally:
tapped_queue.task_done()
except (QueueShutDown, asyncio.CancelledError):
if self._exception:
raise self._exception from None
break
finally:
logger.debug('Subscribe[%s]: Unsubscribing', self._task_id)
Expand Down Expand Up @@ -714,9 +721,9 @@ async def _maybe_cleanup(self) -> None:
logger.debug('Cleanup[%s]: Triggering cleanup', self._task_id)
self._on_cleanup(self)

async def _mark_task_as_failed(self, exception: Exception) -> None:
if self._exception is None:
self._exception = exception
async def _mark_task_as_failed(self, exception: Exception) -> Task | None:
logger.debug('Marking task %s as failed: %s', self._task_id, exception)
task = None
if self._task_created.is_set():
try:
task = await self._task_manager.get_task()
Expand All @@ -732,10 +739,10 @@ async def _mark_task_as_failed(self, exception: Exception) -> None:
)
except QueueShutDown:
pass
return task

async def get_task(self) -> Task:
"""Get task from db."""
# TODO: THERE IS ZERO CONCURRENCY SAFETY HERE (Except inital task creation).
await self._task_created.wait()
task = await self._task_manager.get_task()
if not task:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
SubscribeToTaskRequest,
Task,
TaskPushNotificationConfig,
TaskStatusUpdateEvent,
)
from a2a.utils.errors import (
ExtendedAgentCardNotConfiguredError,
Expand Down Expand Up @@ -252,13 +251,6 @@ async def on_message_send( # noqa: D102
type(event).__name__,
event,
)
if isinstance(event, TaskStatusUpdateEvent):
self._validate_task_id_match(task_id, event.task_id)
event = await active_task.get_task()
logger.debug(
'Replaced TaskStatusUpdateEvent with Task: %s', event
)

if isinstance(event, Task) and (
params.configuration.return_immediately
or event.status.state
Expand Down
9 changes: 9 additions & 0 deletions tests/integration/test_scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,15 @@ async def cancel(
(task,) = (await client.list_tasks(ListTasksRequest())).tasks
assert task.status.state == TaskState.TASK_STATE_FAILED

if streaming:
with pytest.raises(
InvalidParamsError,
match='Task .* is already completed',
):
await client.subscribe(
SubscribeToTaskRequest(id=task.id)
).__anext__()


# Scenario 12/13: Exception after initial event
@pytest.mark.timeout(2.0)
Expand Down
44 changes: 25 additions & 19 deletions tests/server/agent_execution/test_active_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,26 +316,42 @@ async def test_active_task_subscribe_exception_handling(
active_task: ActiveTask,
agent_executor: Mock,
request_context: Mock,
task_manager: Mock,
) -> None:
"""Test exception handling in subscribe."""
agent_executor.execute = AsyncMock(
side_effect=ValueError('Producer failure')
event = asyncio.Event()

task_manager.get_task.return_value = Task(
id='test-task-id',
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
)

async def execute_mock(req, q):
await q.enqueue_event(
Task(
id='test-task-id',
status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED),
)
)
await event.wait()
raise ValueError('Producer failure')

agent_executor.execute = AsyncMock(side_effect=execute_mock)

await active_task.enqueue_request(request_context)
await active_task.start(
call_context=ServerCallContext(), create_task_if_missing=True
)

# Give it a moment to fail
for _ in range(10):
if active_task._exception:
break
await asyncio.sleep(0.05)
subscriber = active_task.subscribe()
task = await anext(subscriber)
assert task.status.state == TaskState.TASK_STATE_SUBMITTED

# Now trigger the exception
event.set()

with pytest.raises(ValueError, match='Producer failure'):
async for _ in active_task.subscribe():
pass
await anext(subscriber)

@pytest.mark.asyncio
async def test_active_task_cancel_not_started(
Expand Down Expand Up @@ -766,16 +782,6 @@ async def test_active_task_maybe_cleanup_not_finished(
await active_task._maybe_cleanup()
on_cleanup.assert_not_called()

@pytest.mark.asyncio
async def test_active_task_subscribe_exception_already_set(
self, active_task: ActiveTask
) -> None:
"""Test subscribe when exception is already set."""
active_task._exception = ValueError('Pre-existing error')
with pytest.raises(ValueError, match='Pre-existing error'):
async for _ in active_task.subscribe():
pass

@pytest.mark.asyncio
async def test_active_task_subscribe_inner_exception(
self,
Expand Down
Loading