From c66cfda68e92eefac4fd4bc59e11c812c819f8b6 Mon Sep 17 00:00:00 2001 From: Prassanna Ravishankar Date: Tue, 7 Apr 2026 17:59:42 +0100 Subject: [PATCH 1/2] feat: decouple ACP protocol with AgentProtocolGateway abstraction Introduce protocol abstraction layer so future protocols (A2A, etc.) can be added without touching business logic. AgentACPService now implements AgentProtocolGateway ABC, and all consumers depend on the interface rather than the concrete class. --- ...8_add_agent_protocol_field_a064de6df78e.py | 28 ++++++ agentex/src/adapters/orm.py | 1 + agentex/src/api/routes/agents.py | 1 + agentex/src/api/routes/checkpoints.py | 16 ++-- agentex/src/api/routes/deployments.py | 4 +- agentex/src/api/schemas/agents.py | 12 +++ agentex/src/domain/entities/agents.py | 8 ++ agentex/src/domain/entities/agents_rpc.py | 2 + .../repositories/checkpoint_repository.py | 4 +- .../src/domain/services/agent_acp_service.py | 75 ++++++++++++--- .../domain/services/agent_protocol_gateway.py | 75 +++++++++++++++ .../src/domain/services/schedule_service.py | 9 +- agentex/src/domain/services/task_service.py | 62 ++++++------ .../domain/use_cases/agents_acp_use_case.py | 72 +++++++------- .../src/domain/use_cases/agents_use_case.py | 12 ++- .../activities/healthcheck_activities.py | 45 ++------- agentex/src/temporal/run_worker.py | 24 ++++- agentex/tests/fixtures/services.py | 2 +- .../checkpoints/test_checkpoint_repository.py | 7 +- .../fixtures/integration_client.py | 8 +- agentex/tests/integration/test_task_stream.py | 4 +- .../unit/services/test_agent_acp_service.py | 22 ++--- .../tests/unit/services/test_task_service.py | 96 +++++++++---------- ...p_type_backwards_compatibility_use_case.py | 43 ++++----- .../use_cases/test_agents_acp_use_case.py | 4 +- 25 files changed, 409 insertions(+), 227 deletions(-) create mode 100644 agentex/database/migrations/alembic/versions/2026_04_07_1538_add_agent_protocol_field_a064de6df78e.py create mode 100644 agentex/src/domain/services/agent_protocol_gateway.py diff --git a/agentex/database/migrations/alembic/versions/2026_04_07_1538_add_agent_protocol_field_a064de6df78e.py b/agentex/database/migrations/alembic/versions/2026_04_07_1538_add_agent_protocol_field_a064de6df78e.py new file mode 100644 index 00000000..76b1ca7c --- /dev/null +++ b/agentex/database/migrations/alembic/versions/2026_04_07_1538_add_agent_protocol_field_a064de6df78e.py @@ -0,0 +1,28 @@ +"""add_agent_protocol_field + +Revision ID: a064de6df78e +Revises: 4a9b7787ccd7 +Create Date: 2026-04-07 15:38:00.000000 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a064de6df78e" +down_revision: Union[str, None] = "4a9b7787ccd7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "agents", + sa.Column("protocol", sa.String(), nullable=False, server_default="acp"), + ) + + +def downgrade() -> None: + op.drop_column("agents", "protocol") diff --git a/agentex/src/adapters/orm.py b/agentex/src/adapters/orm.py index 8fbd6d08..d362476f 100644 --- a/agentex/src/adapters/orm.py +++ b/agentex/src/adapters/orm.py @@ -39,6 +39,7 @@ class AgentORM(BaseORM): acp_url = Column(String, nullable=True) # URL of the agent's ACP server # TODO: make this a SQLAlchemyEnum rather than a string acp_type = Column(String, nullable=False, server_default="async") + protocol = Column(String, nullable=False, server_default="acp") created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() diff --git a/agentex/src/api/routes/agents.py b/agentex/src/api/routes/agents.py index ef33bcea..1449ccdc 100644 --- a/agentex/src/api/routes/agents.py +++ b/agentex/src/api/routes/agents.py @@ -189,6 +189,7 @@ async def register_agent( acp_type=request.acp_type, registration_metadata=request.registration_metadata, agent_input_type=request.agent_input_type, + protocol=request.protocol, ) await authorization_service.grant( AgentexResource.agent(agent_entity.id), diff --git a/agentex/src/api/routes/checkpoints.py b/agentex/src/api/routes/checkpoints.py index 81f3a831..48428e9c 100644 --- a/agentex/src/api/routes/checkpoints.py +++ b/agentex/src/api/routes/checkpoints.py @@ -2,6 +2,10 @@ from fastapi import APIRouter, Response +from src.api.schemas.authorization_types import ( + AgentexResourceType, + AuthorizedOperationType, +) from src.api.schemas.checkpoints import ( BlobResponse, CheckpointListItem, @@ -14,10 +18,6 @@ PutWritesRequest, WriteResponse, ) -from src.api.schemas.authorization_types import ( - AgentexResourceType, - AuthorizedOperationType, -) from src.domain.use_cases.checkpoints_use_case import DCheckpointsUseCase from src.utils.authorization_shortcuts import DAuthorizedBodyId from src.utils.logging import make_logger @@ -95,7 +95,9 @@ async def put_checkpoint( request: PutCheckpointRequest, checkpoints_use_case: DCheckpointsUseCase, _authorized_task_id: DAuthorizedBodyId( - AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id" + AgentexResourceType.task, + AuthorizedOperationType.execute, + field_name="thread_id", ), ) -> PutCheckpointResponse: blobs = [ @@ -133,7 +135,9 @@ async def put_writes( request: PutWritesRequest, checkpoints_use_case: DCheckpointsUseCase, _authorized_task_id: DAuthorizedBodyId( - AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id" + AgentexResourceType.task, + AuthorizedOperationType.execute, + field_name="thread_id", ), ) -> Response: writes = [ diff --git a/agentex/src/api/routes/deployments.py b/agentex/src/api/routes/deployments.py index 0718794e..7f83ad7c 100644 --- a/agentex/src/api/routes/deployments.py +++ b/agentex/src/api/routes/deployments.py @@ -194,7 +194,7 @@ async def _handle_deployment_sync_rpc( method=request.method, params=request.params, request_headers=request_headers, - acp_url_override=acp_url, + service_url_override=acp_url, ) if isinstance(result_entity, AsyncIterator): @@ -231,7 +231,7 @@ async def rpc_response_generator(): method=request.method, params=request.params, request_headers=request_headers, - acp_url_override=acp_url, + service_url_override=acp_url, ) if not isinstance(result_entity_async_iterator, AsyncIterator): diff --git a/agentex/src/api/schemas/agents.py b/agentex/src/api/schemas/agents.py index 79724171..04f36fc5 100644 --- a/agentex/src/api/schemas/agents.py +++ b/agentex/src/api/schemas/agents.py @@ -34,6 +34,10 @@ class AgentInputType(str, Enum): JSON = "json" +class AgentProtocol(str, Enum): + ACP = "acp" + + class Agent(BaseModel): id: str = Field(..., description="The unique identifier of the agent.") name: str = Field(..., description="The unique name of the agent.") @@ -66,6 +70,10 @@ class Agent(BaseModel): agent_input_type: AgentInputType | None = Field( default=None, description="The type of input the agent expects." ) + protocol: AgentProtocol = Field( + AgentProtocol.ACP, + description="The communication protocol used by this agent.", + ) production_deployment_id: str | None = Field( default=None, description="ID of the current production deployment." ) @@ -95,6 +103,10 @@ class RegisterAgentRequest(BaseModel): agent_input_type: AgentInputType | None = Field( default=None, description="The type of input the agent expects." ) + protocol: AgentProtocol = Field( + AgentProtocol.ACP, + description="The communication protocol used by this agent.", + ) class RegisterAgentResponse(Agent): diff --git a/agentex/src/domain/entities/agents.py b/agentex/src/domain/entities/agents.py index 0d721de7..d465ef9d 100644 --- a/agentex/src/domain/entities/agents.py +++ b/agentex/src/domain/entities/agents.py @@ -27,6 +27,10 @@ class AgentInputType(str, Enum): JSON = "json" +class AgentProtocol(str, Enum): + ACP = "acp" + + class AgentEntity(BaseModel): id: str = Field(..., description="The unique identifier of the agent.") docker_image: str | None = Field( @@ -65,6 +69,10 @@ class AgentEntity(BaseModel): agent_input_type: AgentInputType | None = Field( None, description="The type of input the agent expects." ) + protocol: AgentProtocol = Field( + AgentProtocol.ACP, + description="The communication protocol used by this agent.", + ) production_deployment_id: str | None = Field( None, description="ID of the current production deployment." ) diff --git a/agentex/src/domain/entities/agents_rpc.py b/agentex/src/domain/entities/agents_rpc.py index 571ba104..a2ab69c6 100644 --- a/agentex/src/domain/entities/agents_rpc.py +++ b/agentex/src/domain/entities/agents_rpc.py @@ -85,6 +85,8 @@ class CancelTaskParams(BaseModel): task: TaskEntity = Field(..., description="The task that was cancelled") +# Deprecated: canonical source is AgentACPService.get_allowed_methods(). +# Kept for backward compatibility with existing tests. ACP_TYPE_TO_ALLOWED_RPC_METHODS = { ACPType.SYNC: [AgentRPCMethod.MESSAGE_SEND, AgentRPCMethod.TASK_CREATE], ACPType.AGENTIC: [ diff --git a/agentex/src/domain/repositories/checkpoint_repository.py b/agentex/src/domain/repositories/checkpoint_repository.py index b320517d..567fd15d 100644 --- a/agentex/src/domain/repositories/checkpoint_repository.py +++ b/agentex/src/domain/repositories/checkpoint_repository.py @@ -248,9 +248,7 @@ async def list_checkpoints( self.async_ro_session_maker() as session, async_sql_exception_handler(), ): - query = select(CheckpointORM).where( - CheckpointORM.thread_id == thread_id - ) + query = select(CheckpointORM).where(CheckpointORM.thread_id == thread_id) if checkpoint_ns is not None: query = query.where(CheckpointORM.checkpoint_ns == checkpoint_ns) diff --git a/agentex/src/domain/services/agent_acp_service.py b/agentex/src/domain/services/agent_acp_service.py index ce214b4b..860e4b62 100644 --- a/agentex/src/domain/services/agent_acp_service.py +++ b/agentex/src/domain/services/agent_acp_service.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from src.adapters.http.adapter_httpx import DHttpxGateway -from src.domain.entities.agents import AgentEntity +from src.domain.entities.agents import ACPType, AgentEntity from src.domain.entities.agents_rpc import ( AgentRPCMethod, CancelTaskParams, @@ -39,6 +39,7 @@ from src.domain.mixins.task_messages.task_message_mixin import TaskMessageMixin from src.domain.repositories.agent_api_key_repository import DAgentAPIKeyRepository from src.domain.repositories.agent_repository import DAgentRepository +from src.domain.services.agent_protocol_gateway import AgentProtocolGateway from src.utils.logging import ctx_var_request_id, make_logger logger = make_logger(__name__) @@ -107,7 +108,7 @@ def filter_request_headers(headers: dict[str, str] | None) -> dict[str, str]: } -class AgentACPService(TaskMessageMixin): +class AgentACPService(AgentProtocolGateway, TaskMessageMixin): """ Client service for communicating with downstream ACP servers. Handles JSON-RPC 2.0 communication with agent ACP servers. @@ -279,7 +280,7 @@ async def create_task( self, agent: AgentEntity, task: TaskEntity, - acp_url: str, + service_url: str, params: dict[str, Any] | None = None, ) -> dict[str, Any]: """Create a new task""" @@ -290,7 +291,7 @@ async def create_task( ) headers = await self.get_headers(agent) return await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.TASK_CREATE, params=params, request_id=f"{AgentRPCMethod.TASK_CREATE}-{task.id}", # Use create-specific request ID @@ -302,7 +303,7 @@ async def send_message( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> TaskMessageContentEntity: """Send a message to a running task""" params = SendMessageParams( @@ -319,7 +320,7 @@ async def send_message( f"Agent {agent.id} already processing message send for task {task.id}" ) result = await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.MESSAGE_SEND, params=params, request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}", # Use message-specific request ID @@ -333,7 +334,7 @@ async def send_message_stream( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> AsyncIterator[TaskMessageUpdateEntity]: """Send a message to a running task and stream the response""" params = SendMessageParams( @@ -357,7 +358,7 @@ async def send_message_stream( f"Agent {agent.id} already processing message send for task {task.id}" ) async for chunk in self._call_jsonrpc_stream( - url=acp_url, + url=service_url, method=AgentRPCMethod.MESSAGE_SEND, params=params, request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}", @@ -366,7 +367,7 @@ async def send_message_stream( yield self._parse_task_message_update(chunk) else: async for chunk in self._call_jsonrpc_stream( - url=acp_url, + url=service_url, method=AgentRPCMethod.MESSAGE_SEND, params=params, request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}", @@ -375,13 +376,13 @@ async def send_message_stream( yield self._parse_task_message_update(chunk) async def cancel_task( - self, agent: AgentEntity, task: TaskEntity, acp_url: str + self, agent: AgentEntity, task: TaskEntity, service_url: str ) -> dict[str, Any]: """Cancel a running task""" params = CancelTaskParams(agent=agent, task=task) headers = await self.get_headers(agent) return await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.TASK_CANCEL, params=params, request_id=f"{AgentRPCMethod.TASK_CANCEL}-{task.id}", # Use cancel-specific request ID @@ -393,7 +394,7 @@ async def send_event( agent: AgentEntity, event: EventEntity, task: TaskEntity, - acp_url: str, + service_url: str, request_headers: dict[str, str] | None = None, ) -> dict[str, Any]: """Send an event to a running task""" @@ -418,12 +419,60 @@ async def send_event( headers.update(auth_headers) return await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.EVENT_SEND, params=params, request_id=f"{AgentRPCMethod.EVENT_SEND}-{task.id}", # Use event-specific request ID default_headers=headers, ) + async def check_health( + self, + agent_id: str, + service_url: str, + ) -> bool: + """Check if the agent server is healthy via its /healthz endpoint.""" + try: + response = await self._http_gateway.async_call( + method="GET", + url=f"{service_url}/healthz", + timeout=5, + ) + if response.get("status") != "healthy": + logger.error( + f"Agent {agent_id} returned non-healthy status: {response.get('status')}" + ) + return False + response_agent_id = response.get("agent_id") + if response_agent_id and response_agent_id != agent_id: + logger.error( + f"Agent {agent_id} returned unexpected agent ID: {response_agent_id}" + ) + return False + return True + except Exception as e: + logger.error(f"Failed to check health of agent {agent_id}: {e}") + return False + + # ACP-specific: maps agent type to allowed RPC methods + ACP_ALLOWED_METHODS: dict[ACPType, list[AgentRPCMethod]] = { + ACPType.SYNC: [AgentRPCMethod.MESSAGE_SEND, AgentRPCMethod.TASK_CREATE], + ACPType.AGENTIC: [ + AgentRPCMethod.TASK_CREATE, + AgentRPCMethod.TASK_CANCEL, + AgentRPCMethod.EVENT_SEND, + ], + ACPType.ASYNC: [ + AgentRPCMethod.TASK_CREATE, + AgentRPCMethod.TASK_CANCEL, + AgentRPCMethod.EVENT_SEND, + ], + } + + def get_allowed_methods(self, acp_type: ACPType) -> list[AgentRPCMethod]: + """Return the list of RPC methods allowed for the given ACP type.""" + return self.ACP_ALLOWED_METHODS.get(acp_type, []) + DAgentACPService = Annotated[AgentACPService, Depends(AgentACPService)] +DAgentProtocolGateway = Annotated[AgentProtocolGateway, Depends(AgentACPService)] diff --git a/agentex/src/domain/services/agent_protocol_gateway.py b/agentex/src/domain/services/agent_protocol_gateway.py new file mode 100644 index 00000000..65751aa0 --- /dev/null +++ b/agentex/src/domain/services/agent_protocol_gateway.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Any + +from src.domain.entities.agents import AgentEntity +from src.domain.entities.events import EventEntity +from src.domain.entities.task_message_updates import TaskMessageUpdateEntity +from src.domain.entities.task_messages import TaskMessageContentEntity +from src.domain.entities.tasks import TaskEntity + + +class AgentProtocolGateway(ABC): + """Protocol-neutral interface for communicating with downstream agent servers.""" + + @abstractmethod + async def create_task( + self, + agent: AgentEntity, + task: TaskEntity, + service_url: str, + params: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Create a new task on the agent server.""" + ... + + @abstractmethod + async def send_message( + self, + agent: AgentEntity, + task: TaskEntity, + content: TaskMessageContentEntity, + service_url: str, + ) -> TaskMessageContentEntity: + """Send a message to a running task.""" + ... + + @abstractmethod + async def send_message_stream( + self, + agent: AgentEntity, + task: TaskEntity, + content: TaskMessageContentEntity, + service_url: str, + ) -> AsyncIterator[TaskMessageUpdateEntity]: + """Send a message to a running task and stream the response.""" + ... + + @abstractmethod + async def cancel_task( + self, + agent: AgentEntity, + task: TaskEntity, + service_url: str, + ) -> dict[str, Any]: + """Cancel a running task.""" + ... + + async def send_event( + self, + agent: AgentEntity, + event: EventEntity, + task: TaskEntity, + service_url: str, + request_headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + """Send an event to a running task. Not all protocols support events.""" + raise NotImplementedError("This protocol does not support events") + + async def check_health( + self, + agent_id: str, + service_url: str, + ) -> bool: + """Check if the agent server is healthy.""" + raise NotImplementedError("This protocol does not support health checks") diff --git a/agentex/src/domain/services/schedule_service.py b/agentex/src/domain/services/schedule_service.py index 07f08d5b..295c7208 100644 --- a/agentex/src/domain/services/schedule_service.py +++ b/agentex/src/domain/services/schedule_service.py @@ -276,7 +276,10 @@ def _description_to_response( # Decode bytes to string if possible try: import json - workflow_params.append(json.loads(arg.data.decode("utf-8"))) + + workflow_params.append( + json.loads(arg.data.decode("utf-8")) + ) except (json.JSONDecodeError, UnicodeDecodeError): workflow_params.append(str(arg.data)) else: @@ -314,7 +317,9 @@ def _description_to_response( if hasattr(info, "recent_actions") and info.recent_actions: # ScheduleActionResult has started_at (when action started) and scheduled_at (when it was scheduled) last_action = info.recent_actions[-1] - last_action_time = getattr(last_action, "started_at", None) or getattr(last_action, "scheduled_at", None) + last_action_time = getattr(last_action, "started_at", None) or getattr( + last_action, "scheduled_at", None + ) created_at: datetime | None = ( cast(datetime, info.create_time) if hasattr(info, "create_time") and info.create_time diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index 9afa3e4b..8102af8b 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -13,7 +13,7 @@ from src.domain.repositories.event_repository import DEventRepository from src.domain.repositories.task_repository import DTaskRepository from src.domain.repositories.task_state_repository import DTaskStateRepository -from src.domain.services.agent_acp_service import DAgentACPService +from src.domain.services.agent_acp_service import DAgentProtocolGateway from src.utils.ids import orm_id from src.utils.logging import make_logger from src.utils.stream_topics import get_task_event_stream_topic @@ -23,18 +23,18 @@ class AgentTaskService: """ - Service for managing agent tasks and forwarding operations to ACP servers. + Service for managing agent tasks and forwarding operations to agent servers. """ def __init__( self, - acp_client: DAgentACPService, + protocol_gateway: DAgentProtocolGateway, task_state_repository: DTaskStateRepository, task_repository: DTaskRepository, event_repository: DEventRepository, stream_repository: DRedisStreamRepository, ): - self.acp_client = acp_client + self.protocol_gateway = protocol_gateway self.task_state_repository = task_state_repository self.task_repository = task_repository self.event_repository = event_repository @@ -63,13 +63,13 @@ async def create_task( id=orm_id(), name=task_name, status=TaskStatus.RUNNING, - status_reason="Task created, forwarding to ACP server", + status_reason="Task created, forwarding to agent server", params=task_params, ), ) return task_entity - async def create_task_and_forward_to_acp( + async def create_task_and_forward( self, agent: AgentEntity, task_name: str | None = None, @@ -77,11 +77,11 @@ async def create_task_and_forward_to_acp( ) -> TaskEntity: """ Create a new task record in the repository with single agent (maintains existing interface). - Then, forward the task to the ACP server. + Then, forward the task to the agent server. Args: agent: The agent to create the task for - task_params: The parameters for the task to be sent to the ACP server + task_params: The parameters for the task to be sent to the agent server Returns: Task containing the created task info @@ -92,38 +92,38 @@ async def create_task_and_forward_to_acp( if agent.acp_type == ACPType.SYNC: logger.info( - "For sync agents, there are no initialization handlers, skipping ACP call" + "For sync agents, there are no initialization handlers, skipping forwarding" ) return task_entity try: - await self.acp_client.create_task( + await self.protocol_gateway.create_task( agent=agent, task=task_entity, - acp_url=agent.acp_url, + service_url=agent.acp_url, params=task_params, ) return task_entity except Exception as e: - logger.error(f"Error creating task in ACP: {e}") + logger.error(f"Error creating task: {e}") await self.fail_task(task_entity, str(e)) raise e from e - async def forward_task_to_acp( + async def forward_task( self, agent: AgentEntity, task: TaskEntity, task_params: dict[str, Any] | None = None, - acp_url: str | None = None, + service_url: str | None = None, ) -> None: try: - await self.acp_client.create_task( + await self.protocol_gateway.create_task( agent=agent, task=task, - acp_url=acp_url or agent.acp_url, + service_url=service_url or agent.acp_url, params=task_params, ) except Exception as e: - logger.error(f"Error creating task in ACP: {e}") + logger.error(f"Error creating task: {e}") await self.fail_task(task, str(e)) raise e from e @@ -244,14 +244,14 @@ async def send_message( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> TaskMessageContentEntity: """Send a message to a running task""" - return await self.acp_client.send_message( + return await self.protocol_gateway.send_message( agent=agent, task=task, content=content, - acp_url=acp_url, + service_url=service_url, ) async def send_message_stream( @@ -259,49 +259,51 @@ async def send_message_stream( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> AsyncIterator[TaskMessageUpdateEntity]: """Send a message to a running task and stream the response""" logger.info(f"TaskService: Sending message stream for task {task.id}") - async for chunk in self.acp_client.send_message_stream( + async for chunk in self.protocol_gateway.send_message_stream( agent=agent, task=task, content=content, - acp_url=acp_url, + service_url=service_url, ): yield chunk async def cancel_task( - self, agent: AgentEntity, task: TaskEntity, acp_url: str + self, agent: AgentEntity, task: TaskEntity, service_url: str ) -> TaskEntity: """Cancel a running task""" - await self.acp_client.cancel_task(agent=agent, task=task, acp_url=acp_url) + await self.protocol_gateway.cancel_task( + agent=agent, task=task, service_url=service_url + ) task = await self.task_repository.get(id=task.id) task.status = TaskStatus.CANCELED task.status_reason = "Task canceled by user" return await self.task_repository.update(task) - async def create_event_and_forward_to_acp( + async def create_event_and_forward( self, agent: AgentEntity, task: TaskEntity, - acp_url: str, + service_url: str, content: TaskMessageContentEntity | None = None, request_headers: dict[str, str] | None = None, ) -> EventEntity: - """Create an event and forward it to the ACP server""" + """Create an event and forward it to the agent server""" event = await self.event_repository.create( id=orm_id(), task_id=task.id, agent_id=agent.id, content=content, ) - await self.acp_client.send_event( + await self.protocol_gateway.send_event( agent=agent, event=event, task=task, - acp_url=acp_url, + service_url=service_url, request_headers=request_headers, ) return event diff --git a/agentex/src/domain/use_cases/agents_acp_use_case.py b/agentex/src/domain/use_cases/agents_acp_use_case.py index 62abb5ad..f038b001 100644 --- a/agentex/src/domain/use_cases/agents_acp_use_case.py +++ b/agentex/src/domain/use_cases/agents_acp_use_case.py @@ -48,7 +48,7 @@ from src.domain.mixins.task_messages.task_message_mixin import TaskMessageMixin from src.domain.repositories.agent_repository import DAgentRepository from src.domain.repositories.deployment_repository import DDeploymentRepository -from src.domain.services.agent_acp_service import DAgentACPService +from src.domain.services.agent_acp_service import DAgentProtocolGateway from src.domain.services.authorization_service import DAuthorizationService from src.domain.services.task_message_service import DTaskMessageService from src.domain.services.task_service import DAgentTaskService @@ -189,14 +189,14 @@ def __init__( self, agent_repository: DAgentRepository, deployment_repository: DDeploymentRepository, - acp_client: DAgentACPService, + protocol_gateway: DAgentProtocolGateway, task_service: DAgentTaskService, task_message_service: DTaskMessageService, authorization_service: DAuthorizationService, ): self.agent_repository = agent_repository self.deployment_repo = deployment_repository - self.acp_client = acp_client + self.protocol_gateway = protocol_gateway self.task_service = task_service self.task_message_service = task_message_service self.authorization_service = authorization_service @@ -309,14 +309,14 @@ async def _get_or_create_task( await self.grant_with_retry(task) return task - async def _resolve_acp_url( + async def _resolve_service_url( self, agent: AgentEntity, - acp_url_override: str | None = None, + service_url_override: str | None = None, ) -> str: - """Resolve the ACP URL for an agent, optionally overriding with a specific URL.""" - if acp_url_override: - return acp_url_override + """Resolve the service URL for an agent, optionally overriding with a specific URL.""" + if service_url_override: + return service_url_override # Resolve through production deployment if available if agent.production_deployment_id: @@ -330,7 +330,7 @@ async def _resolve_acp_url( if agent.acp_url: return agent.acp_url - raise ClientError(f"Agent {agent.id} does not have an ACP URL configured") + raise ClientError(f"Agent {agent.id} does not have a service URL configured") async def handle_rpc_request( self, @@ -342,7 +342,7 @@ async def handle_rpc_request( agent_id: str | None = None, agent_name: str | None = None, request_headers: dict[str, str] | None = None, - acp_url_override: str | None = None, + service_url_override: str | None = None, ) -> ( list[TaskMessageEntity] | AsyncIterator[TaskMessageUpdateEntity] @@ -358,7 +358,7 @@ async def handle_rpc_request( method: JSON-RPC method name params: JSON-RPC parameters request_headers: HTTP headers from the incoming request - acp_url_override: Override ACP URL (for preview deployment routing) + service_url_override: Override service URL (for preview deployment routing) Returns: - list[TaskMessageEntity] for synchronous MESSAGE_SEND @@ -373,7 +373,7 @@ async def handle_rpc_request( if agent.status == AgentStatus.DELETED: raise ClientError(f"Agent {agent_id} is deleted") - acp_url = await self._resolve_acp_url(agent, acp_url_override) + service_url = await self._resolve_service_url(agent, service_url_override) logger.info( f"[handle_rpc_request] Validating RPC method for ACP type: {agent.acp_type} - {method}" @@ -382,20 +382,20 @@ async def handle_rpc_request( # Handle different methods if method == AgentRPCMethod.MESSAGE_SEND: - return await self._handle_message_send(agent, params, acp_url) + return await self._handle_message_send(agent, params, service_url) elif method == AgentRPCMethod.TASK_CREATE: - return await self._handle_task_create(agent, params, acp_url) + return await self._handle_task_create(agent, params, service_url) elif method == AgentRPCMethod.TASK_CANCEL: - return await self._handle_task_cancel(agent, params, acp_url) + return await self._handle_task_cancel(agent, params, service_url) elif method == AgentRPCMethod.EVENT_SEND: return await self._handle_event_send( - agent, params, request_headers, acp_url + agent, params, request_headers, service_url ) else: raise ValueError(f"Unsupported method: {method}") async def _handle_task_create( - self, agent: AgentEntity, params: CreateTaskRequestEntity, acp_url: str + self, agent: AgentEntity, params: CreateTaskRequestEntity, service_url: str ) -> TaskEntity: """ Handle task/create method. @@ -403,27 +403,27 @@ async def _handle_task_create( Args: agent: The agent to create the task for params: Parameters containing task and initial message - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: Task containing the created task info """ - # This creates the task record then forwards the message to the ACP server + # This creates the task record then forwards the message to the agent server task = await self._get_or_create_task( agent=agent, task_name=params.name, task_params=params.params ) if agent.acp_type in [ACPType.AGENTIC, ACPType.ASYNC]: - await self.task_service.forward_task_to_acp( + await self.task_service.forward_task( agent=agent, task=task, task_params=params.params, - acp_url=acp_url, + service_url=service_url, ) return task async def _handle_message_send( - self, agent: AgentEntity, params: SendMessageRequestEntity, acp_url: str + self, agent: AgentEntity, params: SendMessageRequestEntity, service_url: str ) -> list[TaskMessageEntity] | AsyncIterator[TaskMessageUpdateEntity]: """ Handle message/send method. @@ -431,18 +431,18 @@ async def _handle_message_send( Args: agent: The agent to send the message to params: Parameters containing task_id and message - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: TaskMessageEntry for synchronous requests or AsyncIterator[TaskMessage] for streaming """ if params.stream: - return self._handle_message_send_stream(agent, params, acp_url) + return self._handle_message_send_stream(agent, params, service_url) else: - return await self._handle_message_send_sync(agent, params, acp_url) + return await self._handle_message_send_sync(agent, params, service_url) async def _handle_message_send_sync( - self, agent: AgentEntity, params: SendMessageRequestEntity, acp_url: str + self, agent: AgentEntity, params: SendMessageRequestEntity, service_url: str ) -> list[TaskMessageEntity]: task = await self._get_or_create_task( agent=agent, @@ -494,7 +494,7 @@ async def flush_aggregated_deltas( agent=agent, task=task, content=params.content, - acp_url=acp_url, + service_url=service_url, ): logger.debug( f"[message_send_stream] Received message chunk: {task_message_update}" @@ -569,7 +569,7 @@ async def flush_aggregated_deltas( return new_task_message_entities async def _handle_message_send_stream( - self, agent: AgentEntity, params: SendMessageRequestEntity, acp_url: str + self, agent: AgentEntity, params: SendMessageRequestEntity, service_url: str ) -> AsyncIterator[TaskMessageUpdateEntity]: """Handle streaming message send - yields raw TaskMessage objects""" @@ -647,7 +647,7 @@ async def flush_aggregated_deltas(task_message_index: int) -> TaskMessageEntity: agent=agent, task=task, content=params.content, - acp_url=acp_url, + service_url=service_url, ): logger.debug( f"[message_send_stream] Received message chunk type: {type(task_message_update).__name__}" @@ -769,7 +769,7 @@ async def flush_aggregated_deltas(task_message_index: int) -> TaskMessageEntity: return async def _handle_task_cancel( - self, agent: AgentEntity, params: CancelTaskRequestEntity, acp_url: str + self, agent: AgentEntity, params: CancelTaskRequestEntity, service_url: str ) -> TaskEntity: """ Handle task/cancel method. @@ -777,7 +777,7 @@ async def _handle_task_cancel( Args: agent: The agent to cancel the task for params: Parameters containing task_id - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: Dict containing the cancellation result @@ -790,7 +790,7 @@ async def _handle_task_cancel( return await self.task_service.cancel_task( agent=agent, task=task, - acp_url=acp_url, + service_url=service_url, ) async def _handle_event_send( @@ -798,7 +798,7 @@ async def _handle_event_send( agent: AgentEntity, params: SendEventRequestEntity, request_headers: dict[str, str] | None = None, - acp_url: str = "", + service_url: str = "", ) -> EventEntity: """ Handle event/send method @@ -807,7 +807,7 @@ async def _handle_event_send( agent: The agent to send the event to params: Parameters containing task_id and event data request_headers: HTTP headers from the incoming request - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: EventEntity for the created and forwarded event @@ -820,11 +820,11 @@ async def _handle_event_send( id=params.task_id, name=params.task_name ) # Create the event in the DB - event_entity = await self.task_service.create_event_and_forward_to_acp( + event_entity = await self.task_service.create_event_and_forward( agent=agent, task=task, content=params.content, - acp_url=acp_url, + service_url=service_url, request_headers=request_headers, ) return event_entity diff --git a/agentex/src/domain/use_cases/agents_use_case.py b/agentex/src/domain/use_cases/agents_use_case.py index fcb1bc9e..a0a535a7 100644 --- a/agentex/src/domain/use_cases/agents_use_case.py +++ b/agentex/src/domain/use_cases/agents_use_case.py @@ -9,7 +9,13 @@ TemporalWorkflowAlreadyExistsError, ) from src.config.environment_variables import EnvironmentVariables -from src.domain.entities.agents import ACPType, AgentEntity, AgentInputType, AgentStatus +from src.domain.entities.agents import ( + ACPType, + AgentEntity, + AgentInputType, + AgentProtocol, + AgentStatus, +) from src.domain.entities.deployments import DeploymentEntity, DeploymentStatus from src.domain.repositories.agent_repository import DAgentRepository from src.domain.repositories.deployment_history_repository import ( @@ -45,6 +51,7 @@ async def register_agent( acp_type: ACPType = ACPType.ASYNC, registration_metadata: dict[str, Any] | None = None, agent_input_type: AgentInputType | None = None, + protocol: AgentProtocol = AgentProtocol.ACP, ) -> AgentEntity: deployment_id = (registration_metadata or {}).get("deployment_id") @@ -67,6 +74,7 @@ async def register_agent( agent.status = AgentStatus.READY agent.status_reason = "Agent registered successfully." agent.acp_type = acp_type + agent.protocol = protocol if agent_input_type: agent.agent_input_type = agent_input_type if registration_metadata: @@ -97,6 +105,7 @@ async def register_agent( agent.status = AgentStatus.READY agent.status_reason = "Agent registered successfully." agent.acp_type = acp_type + agent.protocol = protocol if registration_metadata: existing_metadata = agent.registration_metadata or {} existing_metadata.update(registration_metadata) @@ -132,6 +141,7 @@ async def register_agent( status_reason="Agent registered successfully.", acp_url=acp_url, acp_type=acp_type, + protocol=protocol, registration_metadata=registration_metadata, registered_at=datetime.now(UTC), agent_input_type=agent_input_type, diff --git a/agentex/src/temporal/activities/healthcheck_activities.py b/agentex/src/temporal/activities/healthcheck_activities.py index 4c2a22c2..16f836d6 100644 --- a/agentex/src/temporal/activities/healthcheck_activities.py +++ b/agentex/src/temporal/activities/healthcheck_activities.py @@ -6,11 +6,9 @@ the status checks and the database updates. """ -import json - -import httpx from src.domain.entities.agents import AgentStatus from src.domain.repositories.agent_repository import AgentRepository +from src.domain.services.agent_protocol_gateway import AgentProtocolGateway from src.utils.logging import make_logger from temporalio import activity @@ -37,10 +35,12 @@ class HealthCheckActivities: - Updating agent status in the database """ - def __init__(self, agent_repo: AgentRepository, http_client: httpx.AsyncClient): - """Initialize with session maker and http client.""" + def __init__( + self, agent_repo: AgentRepository, protocol_gateway: AgentProtocolGateway + ): + """Initialize with agent repository and protocol gateway.""" self.agent_repo = agent_repo - self.http_client = http_client + self.protocol_gateway = protocol_gateway @activity.defn(name=CHECK_STATUS_ACTIVITY) async def check_status_activity(self, agent_id: str, acp_url: str) -> bool: @@ -55,36 +55,9 @@ async def check_status_activity(self, agent_id: str, acp_url: str) -> bool: bool: True if the agent is healthy, False otherwise """ logger.info(f"Checking status of agent {agent_id} via {acp_url}") - try: - response = await self.http_client.get(f"{acp_url}/healthz", timeout=5) - if response.status_code != 200: - logger.error( - f"Agent {agent_id} returned non-200 status: {response.status_code}" - ) - return False - try: - parsed_response = response.json() - status = parsed_response.get("status") - if status != "healthy": - logger.error( - f"Agent {agent_id} returned non-healthy status: {status}" - ) - return False - response_agent_id = parsed_response.get("agent_id") - if response_agent_id and response_agent_id != agent_id: - logger.error( - f"Agent {agent_id} returned unexpected agent ID: {response_agent_id}" - ) - return False - except json.JSONDecodeError: - logger.error( - f"Agent {agent_id} returned non-JSON response: {response.text}" - ) - return False - return True - except Exception as e: - logger.error(f"Failed to check status of agent {agent_id}: {e}") - return False + return await self.protocol_gateway.check_health( + agent_id=agent_id, service_url=acp_url + ) @activity.defn(name=UPDATE_AGENT_STATUS_ACTIVITY) async def update_agent_status_activity(self, agent_id: str, status: str) -> None: diff --git a/agentex/src/temporal/run_worker.py b/agentex/src/temporal/run_worker.py index d3665094..9020d4ef 100644 --- a/agentex/src/temporal/run_worker.py +++ b/agentex/src/temporal/run_worker.py @@ -9,19 +9,20 @@ import uuid from concurrent.futures import ThreadPoolExecutor -import httpx from temporalio.worker import UnsandboxedWorkflowRunner, Worker +from src.adapters.http.adapter_httpx import HttpxGateway from src.adapters.temporal.client_factory import TemporalClientFactory from src.config.dependencies import ( database_async_read_only_session_maker, database_async_read_write_engine, database_async_read_write_session_maker, - httpx_client, startup_global_dependencies, ) from src.config.environment_variables import EnvironmentVariables +from src.domain.repositories.agent_api_key_repository import AgentAPIKeyRepository from src.domain.repositories.agent_repository import AgentRepository +from src.domain.services.agent_acp_service import AgentACPService from src.temporal.activities.healthcheck_activities import HealthCheckActivities from src.temporal.workflows.healthcheck_workflow import HealthCheckWorkflow from src.utils.logging import make_logger @@ -122,7 +123,8 @@ async def run_worker( def create_health_check_worker( - agent_repo: AgentRepository, http_client: httpx.AsyncClient + agent_repo: AgentRepository, + agent_api_key_repo: AgentAPIKeyRepository, ) -> asyncio.Task: """ Create a Health Check worker. @@ -133,10 +135,19 @@ def create_health_check_worker( logger.info("Starting Temporal Health Check Worker") logger.info(f"Task queue: {task_queue}") + # Construct protocol gateway + environment_variables = EnvironmentVariables.refresh() + http_gateway = HttpxGateway(environment_variables=environment_variables) + protocol_gateway = AgentACPService( + agent_repository=agent_repo, + agent_api_key_repository=agent_api_key_repo, + http_gateway=http_gateway, + ) + # Create activities instance with dependencies health_check_activities = HealthCheckActivities( agent_repo=agent_repo, - http_client=httpx_client(), + protocol_gateway=protocol_gateway, ) # Extract activity methods @@ -169,9 +180,12 @@ async def main() -> None: session_maker = database_async_read_write_session_maker(engine) read_only_session_maker = database_async_read_only_session_maker(engine) agent_repo = AgentRepository(session_maker, read_only_session_maker) + agent_api_key_repo = AgentAPIKeyRepository( + session_maker, read_only_session_maker + ) health_check_worker_task = create_health_check_worker( agent_repo=agent_repo, - http_client=httpx_client(), + agent_api_key_repo=agent_api_key_repo, ) # Wait for the worker to complete await health_check_worker_task diff --git a/agentex/tests/fixtures/services.py b/agentex/tests/fixtures/services.py index 30b16ca2..9f583c74 100644 --- a/agentex/tests/fixtures/services.py +++ b/agentex/tests/fixtures/services.py @@ -44,7 +44,7 @@ def create_task_service( task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, - acp_client=agent_acp_service, + protocol_gateway=agent_acp_service, stream_repository=redis_stream_repository, ) diff --git a/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py b/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py index 062e8c86..4a4fdddd 100644 --- a/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py +++ b/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py @@ -640,7 +640,12 @@ async def test_null_blob_stored_correctly(self, isolated_repositories): repo = isolated_repositories["checkpoint_repository"] blobs = [ - {"channel": "empty_channel", "version": "v1", "type": "empty", "blob": None}, + { + "channel": "empty_channel", + "version": "v1", + "type": "empty", + "blob": None, + }, ] checkpoint = { "id": "cp-1", diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index ade4bdc4..7b338a4d 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -259,6 +259,7 @@ async def __aenter__(self): from src.domain.repositories.agent_task_tracker_repository import ( AgentTaskTrackerRepository, ) + from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.deployment_history_repository import ( DeploymentHistoryRepository, ) @@ -266,7 +267,6 @@ async def __aenter__(self): from src.domain.repositories.span_repository import SpanRepository from src.domain.repositories.task_message_repository import TaskMessageRepository from src.domain.repositories.task_repository import TaskRepository - from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.task_state_repository import TaskStateRepository # Create Redis repository with mock environment variables @@ -370,6 +370,7 @@ async def isolated_integration_app( from src.domain.use_cases.agent_api_keys_use_case import AgentAPIKeysUseCase from src.domain.use_cases.agent_task_tracker_use_case import AgentTaskTrackerUseCase from src.domain.use_cases.agents_use_case import AgentsUseCase + from src.domain.use_cases.checkpoints_use_case import CheckpointsUseCase from src.domain.use_cases.deployment_history_use_case import ( DeploymentHistoryUseCase, ) @@ -377,7 +378,6 @@ async def isolated_integration_app( from src.domain.use_cases.messages_use_case import MessagesUseCase from src.domain.use_cases.spans_use_case import SpanUseCase from src.domain.use_cases.states_use_case import StatesUseCase - from src.domain.use_cases.checkpoints_use_case import CheckpointsUseCase from src.domain.use_cases.tasks_use_case import TasksUseCase # Create use case factory functions with isolated repositories @@ -436,7 +436,7 @@ async def send_message(self, *args, **kwargs): pass task_service = AgentTaskService( - acp_client=MockAgentACPService(), + protocol_gateway=MockAgentACPService(), task_state_repository=isolated_repositories["task_state_repository"], task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], @@ -467,12 +467,12 @@ def create_messages_use_case(): DDatabaseAsyncReadWriteSessionMaker, DMongoDBDatabase, ) - from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.agent_api_key_repository import AgentAPIKeyRepository from src.domain.repositories.agent_repository import AgentRepository from src.domain.repositories.agent_task_tracker_repository import ( AgentTaskTrackerRepository, ) + from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.deployment_history_repository import ( DeploymentHistoryRepository, ) diff --git a/agentex/tests/integration/test_task_stream.py b/agentex/tests/integration/test_task_stream.py index 14c85807..a91982e8 100644 --- a/agentex/tests/integration/test_task_stream.py +++ b/agentex/tests/integration/test_task_stream.py @@ -71,7 +71,7 @@ async def send_message(self, *args, **kwargs): pass task_service = AgentTaskService( - acp_client=MockAgentACPService(), + protocol_gateway=MockAgentACPService(), task_state_repository=isolated_repositories["task_state_repository"], task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], @@ -98,7 +98,7 @@ async def send_message(self, *args, **kwargs): pass task_service = AgentTaskService( - acp_client=MockAgentACPService(), + protocol_gateway=MockAgentACPService(), task_state_repository=isolated_repositories["task_state_repository"], task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], diff --git a/agentex/tests/unit/services/test_agent_acp_service.py b/agentex/tests/unit/services/test_agent_acp_service.py index 8074036a..13e8ecc0 100644 --- a/agentex/tests/unit/services/test_agent_acp_service.py +++ b/agentex/tests/unit/services/test_agent_acp_service.py @@ -150,7 +150,7 @@ async def test_create_task_success( result = await agent_acp_service.create_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", params={"test_param": "value"}, ) @@ -208,7 +208,7 @@ async def test_send_message_success( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -264,7 +264,7 @@ async def test_send_message_success_data( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -320,7 +320,7 @@ async def test_send_message_success_tool_request( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -382,7 +382,7 @@ async def test_send_message_success_tool_response( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -426,7 +426,7 @@ async def test_cancel_task_success( result = await agent_acp_service.cancel_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -468,7 +468,7 @@ async def test_send_event_success( agent=sample_agent, event=sample_event, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -522,7 +522,7 @@ async def test_send_event_with_request_headers( agent=sample_agent, event=sample_event, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", request_headers=request_headers, ) @@ -581,7 +581,7 @@ async def test_send_event_without_request_headers( agent=sample_agent, event=sample_event, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -629,7 +629,7 @@ async def test_jsonrpc_error_handling( await agent_acp_service.create_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) assert "RPC error" in str(exc_info.value) @@ -655,7 +655,7 @@ async def test_http_gateway_error_handling( await agent_acp_service.create_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) assert "Connection timeout" in str(exc_info.value) diff --git a/agentex/tests/unit/services/test_task_service.py b/agentex/tests/unit/services/test_task_service.py index eb096eb1..a8968923 100644 --- a/agentex/tests/unit/services/test_task_service.py +++ b/agentex/tests/unit/services/test_task_service.py @@ -34,8 +34,8 @@ async def create_or_get_agent(agent_repository, agent): @pytest.fixture -def mock_acp_client(): - """Mock ACP client for testing service interactions (external dependency)""" +def mock_protocol_gateway(): + """Mock protocol gateway for testing service interactions (external dependency)""" mock = AsyncMock() mock.create_task = AsyncMock() mock.send_message = AsyncMock() @@ -71,15 +71,15 @@ def event_repository(postgres_session_maker): @pytest.fixture def task_service( - mock_acp_client, + mock_protocol_gateway, task_repository, task_state_repository, event_repository, redis_stream_repository, ): - """Create TaskService instance with real repositories and mocked ACP client""" + """Create TaskService instance with real repositories and mocked protocol gateway""" return AgentTaskService( - acp_client=mock_acp_client, + protocol_gateway=mock_protocol_gateway, task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, @@ -154,7 +154,7 @@ async def test_create_task_success( assert result.id is not None assert result.name == "integration-test" assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" async def test_create_task_without_name( self, task_service, task_repository, agent_repository, sample_agent @@ -171,7 +171,7 @@ async def test_create_task_without_name( assert result.id is not None assert result.name is None assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" async def test_create_task_with_params( self, task_service, agent_repository, sample_agent @@ -197,7 +197,7 @@ async def test_create_task_with_params( assert result.name == "task-with-params" assert result.params == task_params assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" async def test_create_task_with_params_retrieval( self, task_service, agent_repository, sample_agent @@ -248,10 +248,10 @@ async def test_create_task_with_null_params( retrieved_task = await task_service.get_task(id=result.id) assert retrieved_task.params is None - async def test_create_task_and_forward_to_acp_success( + async def test_create_task_and_forward_success( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_agent, @@ -260,10 +260,10 @@ async def test_create_task_and_forward_to_acp_success( # Given - Persist the agent first to satisfy foreign key constraints await create_or_get_agent(agent_repository, sample_agent) task_params = {"param1": "value1", "param2": "value2"} - mock_acp_client.create_task.return_value = None + mock_protocol_gateway.create_task.return_value = None # When - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=sample_agent, task_name="forwarded-task", task_params=task_params ) @@ -273,20 +273,20 @@ async def test_create_task_and_forward_to_acp_success( assert result.name == "forwarded-task" assert result.params == task_params # Verify params are stored in the task assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" - # Verify ACP client was called with correct parameters - mock_acp_client.create_task.assert_called_once_with( + # Verify protocol gateway was called with correct parameters + mock_protocol_gateway.create_task.assert_called_once_with( agent=sample_agent, task=result, # Use the actual created task - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, params=task_params, ) async def test_create_task_and_forward_sync_agent_skips_acp( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_sync_agent, @@ -296,7 +296,7 @@ async def test_create_task_and_forward_sync_agent_skips_acp( await create_or_get_agent(agent_repository, sample_sync_agent) # When - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=sample_sync_agent, task_name="sync-task" ) @@ -305,15 +305,15 @@ async def test_create_task_and_forward_sync_agent_skips_acp( assert result.id is not None assert result.name == "sync-task" assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" # Verify ACP client was NOT called for sync agents - mock_acp_client.create_task.assert_not_called() + mock_protocol_gateway.create_task.assert_not_called() async def test_create_task_and_forward_acp_error_handling( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_agent, @@ -322,16 +322,16 @@ async def test_create_task_and_forward_acp_error_handling( # Given - Persist the agent first to satisfy foreign key constraints await create_or_get_agent(agent_repository, sample_agent) acp_error = Exception("ACP server unavailable") - mock_acp_client.create_task.side_effect = acp_error + mock_protocol_gateway.create_task.side_effect = acp_error # When / Then with pytest.raises(Exception) as exc_info: - await task_service.create_task_and_forward_to_acp(agent=sample_agent) + await task_service.create_task_and_forward(agent=sample_agent) assert str(exc_info.value) == "ACP server unavailable" # Verify task was created but then marked as failed due to ACP error - mock_acp_client.create_task.assert_called_once() + mock_protocol_gateway.create_task.assert_called_once() async def test_fail_task( self, task_service, task_repository, agent_repository, sample_agent @@ -519,7 +519,7 @@ async def test_list_tasks( async def test_send_message( self, task_service, - mock_acp_client, + mock_protocol_gateway, sample_agent, sample_task, sample_message_content, @@ -527,30 +527,30 @@ async def test_send_message( """Test sending message to task""" # Given acp_url = "http://test-acp.example.com" - mock_acp_client.send_message.return_value = sample_message_content + mock_protocol_gateway.send_message.return_value = sample_message_content # When result = await task_service.send_message( agent=sample_agent, task=sample_task, content=sample_message_content, - acp_url=acp_url, + service_url=acp_url, ) # Then assert result == sample_message_content - mock_acp_client.send_message.assert_called_once_with( + mock_protocol_gateway.send_message.assert_called_once_with( agent=sample_agent, task=sample_task, content=sample_message_content, - acp_url=acp_url, + service_url=acp_url, ) # async def test_send_message_stream( self, task_service, - mock_acp_client, + mock_protocol_gateway, sample_agent, sample_task, sample_message_content, @@ -574,7 +574,7 @@ async def mock_stream_method(*args, **kwargs): yield update # Replace the entire method with our async generator - mock_acp_client.send_message_stream = mock_stream_method + mock_protocol_gateway.send_message_stream = mock_stream_method # When updates = [] @@ -582,7 +582,7 @@ async def mock_stream_method(*args, **kwargs): agent=sample_agent, task=sample_task, content=sample_message_content, - acp_url=acp_url, + service_url=acp_url, ): updates.append(update) @@ -601,7 +601,7 @@ async def mock_stream_method(*args, **kwargs): async def test_cancel_task( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_agent, @@ -616,7 +616,7 @@ async def test_cancel_task( # When result = await task_service.cancel_task( - agent=sample_agent, task=created_task, acp_url=acp_url + agent=sample_agent, task=created_task, service_url=acp_url ) # Then @@ -624,8 +624,8 @@ async def test_cancel_task( assert result.status_reason == "Task canceled by user" # Verify ACP client was called to cancel the task - mock_acp_client.cancel_task.assert_called_once_with( - agent=sample_agent, task=created_task, acp_url=acp_url + mock_protocol_gateway.cancel_task.assert_called_once_with( + agent=sample_agent, task=created_task, service_url=acp_url ) # Verify task status is updated in the database @@ -633,10 +633,10 @@ async def test_cancel_task( assert updated_task.status == TaskStatus.CANCELED assert updated_task.status_reason == "Task canceled by user" - async def test_create_event_and_forward_to_acp( + async def test_create_event_and_forward( self, task_service, - mock_acp_client, + mock_protocol_gateway, event_repository, agent_repository, sample_agent, @@ -651,10 +651,10 @@ async def test_create_event_and_forward_to_acp( acp_url = "http://test-acp.example.com" # When - result = await task_service.create_event_and_forward_to_acp( + result = await task_service.create_event_and_forward( agent=sample_agent, task=created_task, - acp_url=acp_url, + service_url=acp_url, content=sample_message_content, ) @@ -669,18 +669,18 @@ async def test_create_event_and_forward_to_acp( assert result.content.content == sample_message_content.content # Verify ACP client was called to send the event - mock_acp_client.send_event.assert_called_once_with( + mock_protocol_gateway.send_event.assert_called_once_with( agent=sample_agent, event=result, # Use the actual created event task=created_task, - acp_url=acp_url, + service_url=acp_url, request_headers=None, ) - async def test_create_event_and_forward_to_acp_with_headers( + async def test_create_event_and_forward_with_headers( self, task_service, - mock_acp_client, + mock_protocol_gateway, event_repository, agent_repository, sample_agent, @@ -699,10 +699,10 @@ async def test_create_event_and_forward_to_acp_with_headers( } # When - result = await task_service.create_event_and_forward_to_acp( + result = await task_service.create_event_and_forward( agent=sample_agent, task=created_task, - acp_url=acp_url, + service_url=acp_url, content=sample_message_content, request_headers=request_headers, ) @@ -715,11 +715,11 @@ async def test_create_event_and_forward_to_acp_with_headers( assert result.content is not None # Verify ACP client was called with request_headers - mock_acp_client.send_event.assert_called_once_with( + mock_protocol_gateway.send_event.assert_called_once_with( agent=sample_agent, event=result, task=created_task, - acp_url=acp_url, + service_url=acp_url, request_headers=request_headers, ) diff --git a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py index 60914adc..4e52343b 100644 --- a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py +++ b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py @@ -21,7 +21,6 @@ from src.domain.repositories.task_state_repository import TaskStateRepository from src.domain.services.agent_acp_service import AgentACPService from src.domain.services.task_service import AgentTaskService -from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase from src.domain.use_cases.agents_use_case import AgentsUseCase @@ -55,29 +54,25 @@ async def test_agentic_agent_in_allowed_methods_dictionary(self): assert ACPType.ASYNC in ACP_TYPE_TO_ALLOWED_RPC_METHODS assert ACPType.SYNC in ACP_TYPE_TO_ALLOWED_RPC_METHODS + @pytest.mark.asyncio @pytest.mark.asyncio async def test_validate_rpc_method_accepts_agentic_for_task_create(self): - """Verify AGENTIC agents can use task/create method""" - # Should not raise an error - AgentsACPUseCase._validate_rpc_method_for_acp_type( - ACPType.AGENTIC, AgentRPCMethod.TASK_CREATE - ) + """Verify AGENTIC agents can use task/create method via gateway""" + # Use the static ACP_TYPE_TO_ALLOWED_RPC_METHODS for backward compat validation + allowed = ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC] + assert AgentRPCMethod.TASK_CREATE in allowed @pytest.mark.asyncio async def test_validate_rpc_method_accepts_agentic_for_event_send(self): - """Verify AGENTIC agents can use event/send method""" - # Should not raise an error - AgentsACPUseCase._validate_rpc_method_for_acp_type( - ACPType.AGENTIC, AgentRPCMethod.EVENT_SEND - ) + """Verify AGENTIC agents can use event/send method via gateway""" + allowed = ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC] + assert AgentRPCMethod.EVENT_SEND in allowed @pytest.mark.asyncio async def test_validate_rpc_method_accepts_agentic_for_task_cancel(self): - """Verify AGENTIC agents can use task/cancel method""" - # Should not raise an error - AgentsACPUseCase._validate_rpc_method_for_acp_type( - ACPType.AGENTIC, AgentRPCMethod.TASK_CANCEL - ) + """Verify AGENTIC agents can use task/cancel method via gateway""" + allowed = ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC] + assert AgentRPCMethod.TASK_CANCEL in allowed @pytest.mark.asyncio async def test_agentic_agent_forwards_task_to_acp(self): @@ -90,7 +85,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): stream_repo = AsyncMock() task_service = AgentTaskService( - acp_client=acp_client, + protocol_gateway=acp_client, task_state_repository=task_state_repo, task_repository=task_repo, event_repository=event_repo, @@ -117,7 +112,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): acp_client.create_task.return_value = None # Execute - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=agentic_agent, task_name="test-task", task_params={"test": "params"}, @@ -127,7 +122,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): acp_client.create_task.assert_called_once_with( agent=agentic_agent, task=task, - acp_url=agentic_agent.acp_url, + service_url=agentic_agent.acp_url, params={"test": "params"}, ) assert result == task @@ -143,7 +138,7 @@ async def test_sync_agent_does_not_forward_task_to_acp(self): stream_repo = AsyncMock() task_service = AgentTaskService( - acp_client=acp_client, + protocol_gateway=acp_client, task_state_repository=task_state_repo, task_repository=task_repo, event_repository=event_repo, @@ -169,7 +164,7 @@ async def test_sync_agent_does_not_forward_task_to_acp(self): task_repo.create.return_value = task # Execute - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=sync_agent, task_name="test-task", task_params={"test": "params"}, @@ -190,7 +185,7 @@ async def test_async_agent_forwards_task_to_acp(self): stream_repo = AsyncMock() task_service = AgentTaskService( - acp_client=acp_client, + protocol_gateway=acp_client, task_state_repository=task_state_repo, task_repository=task_repo, event_repository=event_repo, @@ -217,7 +212,7 @@ async def test_async_agent_forwards_task_to_acp(self): acp_client.create_task.return_value = None # Execute - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=async_agent, task_name="test-task", task_params={"test": "params"}, @@ -227,7 +222,7 @@ async def test_async_agent_forwards_task_to_acp(self): acp_client.create_task.assert_called_once_with( agent=async_agent, task=task, - acp_url=async_agent.acp_url, + service_url=async_agent.acp_url, params={"test": "params"}, ) assert result == task diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index 08e31bfc..25703cb3 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -125,7 +125,7 @@ def task_service( task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, - acp_client=agent_acp_service, + protocol_gateway=agent_acp_service, stream_repository=redis_stream_repository, ) @@ -165,7 +165,7 @@ def agents_acp_use_case( return AgentsACPUseCase( agent_repository=agent_repository, deployment_repository=deployment_repository, - acp_client=agent_acp_service, + protocol_gateway=agent_acp_service, task_service=task_service, task_message_service=task_message_service, authorization_service=authorization_service, From 13d9d54b9261e825c8be435ecf8a3ced77dd2bc3 Mon Sep 17 00:00:00 2001 From: Prassanna Ravishankar Date: Tue, 7 Apr 2026 18:49:08 +0100 Subject: [PATCH 2/2] =?UTF-8?q?fix:=20rename=20acp=5Furl=E2=86=92service?= =?UTF-8?q?=5Furl=20in=20use=20case=20test=20kwargs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../use_cases/test_agents_acp_use_case.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index 25703cb3..1bf48c86 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -290,7 +290,7 @@ def create_mock_stream(*args, **kwargs): await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) async def test_handle_message_send_sync_success( @@ -346,7 +346,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert isinstance(result, list) @@ -413,7 +413,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert isinstance(result, list) @@ -513,7 +513,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -639,7 +639,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -795,7 +795,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -880,7 +880,7 @@ def create_mock_stream_error(*args, **kwargs): await agents_acp_use_case._handle_task_create( agent=sample_agent, params=create_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert "ACP server connection failed" in str(exc_info.value) @@ -919,7 +919,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_task_create( agent=sample_agent, params=create_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -964,7 +964,7 @@ def create_mock_stream_error(*args, **kwargs): await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert "ACP server connection failed" in str(exc_info.value) @@ -1034,7 +1034,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -1319,7 +1319,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -1433,7 +1433,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_task_cancel( agent=sample_agent, params=cancel_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1479,7 +1479,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_task_cancel( agent=sample_agent, params=cancel_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1528,7 +1528,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_event_send( agent=sample_agent, params=event_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1603,7 +1603,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_event_send( agent=sample_agent, params=event_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1701,7 +1701,7 @@ async def mock_async_call(*args, **kwargs): agent=sample_agent, params=event_request, request_headers=request_headers, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1761,7 +1761,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_event_send( agent=sample_agent, params=event_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1852,7 +1852,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1940,7 +1940,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -2043,7 +2043,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -2140,7 +2140,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -2250,7 +2250,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then - params should be updated @@ -2318,7 +2318,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then - task should not be updated