diff --git a/agentex/database/migrations/alembic/versions/2026_04_07_1538_add_agent_protocol_field_a064de6df78e.py b/agentex/database/migrations/alembic/versions/2026_04_07_1538_add_agent_protocol_field_a064de6df78e.py new file mode 100644 index 00000000..76b1ca7c --- /dev/null +++ b/agentex/database/migrations/alembic/versions/2026_04_07_1538_add_agent_protocol_field_a064de6df78e.py @@ -0,0 +1,28 @@ +"""add_agent_protocol_field + +Revision ID: a064de6df78e +Revises: 4a9b7787ccd7 +Create Date: 2026-04-07 15:38:00.000000 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "a064de6df78e" +down_revision: Union[str, None] = "4a9b7787ccd7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "agents", + sa.Column("protocol", sa.String(), nullable=False, server_default="acp"), + ) + + +def downgrade() -> None: + op.drop_column("agents", "protocol") diff --git a/agentex/src/adapters/orm.py b/agentex/src/adapters/orm.py index 8fbd6d08..d362476f 100644 --- a/agentex/src/adapters/orm.py +++ b/agentex/src/adapters/orm.py @@ -39,6 +39,7 @@ class AgentORM(BaseORM): acp_url = Column(String, nullable=True) # URL of the agent's ACP server # TODO: make this a SQLAlchemyEnum rather than a string acp_type = Column(String, nullable=False, server_default="async") + protocol = Column(String, nullable=False, server_default="acp") created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = Column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() diff --git a/agentex/src/api/routes/agents.py b/agentex/src/api/routes/agents.py index ef33bcea..1449ccdc 100644 --- a/agentex/src/api/routes/agents.py +++ b/agentex/src/api/routes/agents.py @@ -189,6 +189,7 @@ async def register_agent( acp_type=request.acp_type, registration_metadata=request.registration_metadata, agent_input_type=request.agent_input_type, + protocol=request.protocol, ) await authorization_service.grant( AgentexResource.agent(agent_entity.id), diff --git a/agentex/src/api/routes/checkpoints.py b/agentex/src/api/routes/checkpoints.py index 81f3a831..48428e9c 100644 --- a/agentex/src/api/routes/checkpoints.py +++ b/agentex/src/api/routes/checkpoints.py @@ -2,6 +2,10 @@ from fastapi import APIRouter, Response +from src.api.schemas.authorization_types import ( + AgentexResourceType, + AuthorizedOperationType, +) from src.api.schemas.checkpoints import ( BlobResponse, CheckpointListItem, @@ -14,10 +18,6 @@ PutWritesRequest, WriteResponse, ) -from src.api.schemas.authorization_types import ( - AgentexResourceType, - AuthorizedOperationType, -) from src.domain.use_cases.checkpoints_use_case import DCheckpointsUseCase from src.utils.authorization_shortcuts import DAuthorizedBodyId from src.utils.logging import make_logger @@ -95,7 +95,9 @@ async def put_checkpoint( request: PutCheckpointRequest, checkpoints_use_case: DCheckpointsUseCase, _authorized_task_id: DAuthorizedBodyId( - AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id" + AgentexResourceType.task, + AuthorizedOperationType.execute, + field_name="thread_id", ), ) -> PutCheckpointResponse: blobs = [ @@ -133,7 +135,9 @@ async def put_writes( request: PutWritesRequest, checkpoints_use_case: DCheckpointsUseCase, _authorized_task_id: DAuthorizedBodyId( - AgentexResourceType.task, AuthorizedOperationType.execute, field_name="thread_id" + AgentexResourceType.task, + AuthorizedOperationType.execute, + field_name="thread_id", ), ) -> Response: writes = [ diff --git a/agentex/src/api/routes/deployments.py b/agentex/src/api/routes/deployments.py index 0718794e..7f83ad7c 100644 --- a/agentex/src/api/routes/deployments.py +++ b/agentex/src/api/routes/deployments.py @@ -194,7 +194,7 @@ async def _handle_deployment_sync_rpc( method=request.method, params=request.params, request_headers=request_headers, - acp_url_override=acp_url, + service_url_override=acp_url, ) if isinstance(result_entity, AsyncIterator): @@ -231,7 +231,7 @@ async def rpc_response_generator(): method=request.method, params=request.params, request_headers=request_headers, - acp_url_override=acp_url, + service_url_override=acp_url, ) if not isinstance(result_entity_async_iterator, AsyncIterator): diff --git a/agentex/src/api/schemas/agents.py b/agentex/src/api/schemas/agents.py index 79724171..04f36fc5 100644 --- a/agentex/src/api/schemas/agents.py +++ b/agentex/src/api/schemas/agents.py @@ -34,6 +34,10 @@ class AgentInputType(str, Enum): JSON = "json" +class AgentProtocol(str, Enum): + ACP = "acp" + + class Agent(BaseModel): id: str = Field(..., description="The unique identifier of the agent.") name: str = Field(..., description="The unique name of the agent.") @@ -66,6 +70,10 @@ class Agent(BaseModel): agent_input_type: AgentInputType | None = Field( default=None, description="The type of input the agent expects." ) + protocol: AgentProtocol = Field( + AgentProtocol.ACP, + description="The communication protocol used by this agent.", + ) production_deployment_id: str | None = Field( default=None, description="ID of the current production deployment." ) @@ -95,6 +103,10 @@ class RegisterAgentRequest(BaseModel): agent_input_type: AgentInputType | None = Field( default=None, description="The type of input the agent expects." ) + protocol: AgentProtocol = Field( + AgentProtocol.ACP, + description="The communication protocol used by this agent.", + ) class RegisterAgentResponse(Agent): diff --git a/agentex/src/domain/entities/agents.py b/agentex/src/domain/entities/agents.py index 0d721de7..d465ef9d 100644 --- a/agentex/src/domain/entities/agents.py +++ b/agentex/src/domain/entities/agents.py @@ -27,6 +27,10 @@ class AgentInputType(str, Enum): JSON = "json" +class AgentProtocol(str, Enum): + ACP = "acp" + + class AgentEntity(BaseModel): id: str = Field(..., description="The unique identifier of the agent.") docker_image: str | None = Field( @@ -65,6 +69,10 @@ class AgentEntity(BaseModel): agent_input_type: AgentInputType | None = Field( None, description="The type of input the agent expects." ) + protocol: AgentProtocol = Field( + AgentProtocol.ACP, + description="The communication protocol used by this agent.", + ) production_deployment_id: str | None = Field( None, description="ID of the current production deployment." ) diff --git a/agentex/src/domain/entities/agents_rpc.py b/agentex/src/domain/entities/agents_rpc.py index 571ba104..a2ab69c6 100644 --- a/agentex/src/domain/entities/agents_rpc.py +++ b/agentex/src/domain/entities/agents_rpc.py @@ -85,6 +85,8 @@ class CancelTaskParams(BaseModel): task: TaskEntity = Field(..., description="The task that was cancelled") +# Deprecated: canonical source is AgentACPService.get_allowed_methods(). +# Kept for backward compatibility with existing tests. ACP_TYPE_TO_ALLOWED_RPC_METHODS = { ACPType.SYNC: [AgentRPCMethod.MESSAGE_SEND, AgentRPCMethod.TASK_CREATE], ACPType.AGENTIC: [ diff --git a/agentex/src/domain/repositories/checkpoint_repository.py b/agentex/src/domain/repositories/checkpoint_repository.py index b320517d..567fd15d 100644 --- a/agentex/src/domain/repositories/checkpoint_repository.py +++ b/agentex/src/domain/repositories/checkpoint_repository.py @@ -248,9 +248,7 @@ async def list_checkpoints( self.async_ro_session_maker() as session, async_sql_exception_handler(), ): - query = select(CheckpointORM).where( - CheckpointORM.thread_id == thread_id - ) + query = select(CheckpointORM).where(CheckpointORM.thread_id == thread_id) if checkpoint_ns is not None: query = query.where(CheckpointORM.checkpoint_ns == checkpoint_ns) diff --git a/agentex/src/domain/services/agent_acp_service.py b/agentex/src/domain/services/agent_acp_service.py index ce214b4b..860e4b62 100644 --- a/agentex/src/domain/services/agent_acp_service.py +++ b/agentex/src/domain/services/agent_acp_service.py @@ -7,7 +7,7 @@ from pydantic import BaseModel from src.adapters.http.adapter_httpx import DHttpxGateway -from src.domain.entities.agents import AgentEntity +from src.domain.entities.agents import ACPType, AgentEntity from src.domain.entities.agents_rpc import ( AgentRPCMethod, CancelTaskParams, @@ -39,6 +39,7 @@ from src.domain.mixins.task_messages.task_message_mixin import TaskMessageMixin from src.domain.repositories.agent_api_key_repository import DAgentAPIKeyRepository from src.domain.repositories.agent_repository import DAgentRepository +from src.domain.services.agent_protocol_gateway import AgentProtocolGateway from src.utils.logging import ctx_var_request_id, make_logger logger = make_logger(__name__) @@ -107,7 +108,7 @@ def filter_request_headers(headers: dict[str, str] | None) -> dict[str, str]: } -class AgentACPService(TaskMessageMixin): +class AgentACPService(AgentProtocolGateway, TaskMessageMixin): """ Client service for communicating with downstream ACP servers. Handles JSON-RPC 2.0 communication with agent ACP servers. @@ -279,7 +280,7 @@ async def create_task( self, agent: AgentEntity, task: TaskEntity, - acp_url: str, + service_url: str, params: dict[str, Any] | None = None, ) -> dict[str, Any]: """Create a new task""" @@ -290,7 +291,7 @@ async def create_task( ) headers = await self.get_headers(agent) return await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.TASK_CREATE, params=params, request_id=f"{AgentRPCMethod.TASK_CREATE}-{task.id}", # Use create-specific request ID @@ -302,7 +303,7 @@ async def send_message( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> TaskMessageContentEntity: """Send a message to a running task""" params = SendMessageParams( @@ -319,7 +320,7 @@ async def send_message( f"Agent {agent.id} already processing message send for task {task.id}" ) result = await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.MESSAGE_SEND, params=params, request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}", # Use message-specific request ID @@ -333,7 +334,7 @@ async def send_message_stream( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> AsyncIterator[TaskMessageUpdateEntity]: """Send a message to a running task and stream the response""" params = SendMessageParams( @@ -357,7 +358,7 @@ async def send_message_stream( f"Agent {agent.id} already processing message send for task {task.id}" ) async for chunk in self._call_jsonrpc_stream( - url=acp_url, + url=service_url, method=AgentRPCMethod.MESSAGE_SEND, params=params, request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}", @@ -366,7 +367,7 @@ async def send_message_stream( yield self._parse_task_message_update(chunk) else: async for chunk in self._call_jsonrpc_stream( - url=acp_url, + url=service_url, method=AgentRPCMethod.MESSAGE_SEND, params=params, request_id=f"{AgentRPCMethod.MESSAGE_SEND}-{task.id}", @@ -375,13 +376,13 @@ async def send_message_stream( yield self._parse_task_message_update(chunk) async def cancel_task( - self, agent: AgentEntity, task: TaskEntity, acp_url: str + self, agent: AgentEntity, task: TaskEntity, service_url: str ) -> dict[str, Any]: """Cancel a running task""" params = CancelTaskParams(agent=agent, task=task) headers = await self.get_headers(agent) return await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.TASK_CANCEL, params=params, request_id=f"{AgentRPCMethod.TASK_CANCEL}-{task.id}", # Use cancel-specific request ID @@ -393,7 +394,7 @@ async def send_event( agent: AgentEntity, event: EventEntity, task: TaskEntity, - acp_url: str, + service_url: str, request_headers: dict[str, str] | None = None, ) -> dict[str, Any]: """Send an event to a running task""" @@ -418,12 +419,60 @@ async def send_event( headers.update(auth_headers) return await self._call_jsonrpc( - url=acp_url, + url=service_url, method=AgentRPCMethod.EVENT_SEND, params=params, request_id=f"{AgentRPCMethod.EVENT_SEND}-{task.id}", # Use event-specific request ID default_headers=headers, ) + async def check_health( + self, + agent_id: str, + service_url: str, + ) -> bool: + """Check if the agent server is healthy via its /healthz endpoint.""" + try: + response = await self._http_gateway.async_call( + method="GET", + url=f"{service_url}/healthz", + timeout=5, + ) + if response.get("status") != "healthy": + logger.error( + f"Agent {agent_id} returned non-healthy status: {response.get('status')}" + ) + return False + response_agent_id = response.get("agent_id") + if response_agent_id and response_agent_id != agent_id: + logger.error( + f"Agent {agent_id} returned unexpected agent ID: {response_agent_id}" + ) + return False + return True + except Exception as e: + logger.error(f"Failed to check health of agent {agent_id}: {e}") + return False + + # ACP-specific: maps agent type to allowed RPC methods + ACP_ALLOWED_METHODS: dict[ACPType, list[AgentRPCMethod]] = { + ACPType.SYNC: [AgentRPCMethod.MESSAGE_SEND, AgentRPCMethod.TASK_CREATE], + ACPType.AGENTIC: [ + AgentRPCMethod.TASK_CREATE, + AgentRPCMethod.TASK_CANCEL, + AgentRPCMethod.EVENT_SEND, + ], + ACPType.ASYNC: [ + AgentRPCMethod.TASK_CREATE, + AgentRPCMethod.TASK_CANCEL, + AgentRPCMethod.EVENT_SEND, + ], + } + + def get_allowed_methods(self, acp_type: ACPType) -> list[AgentRPCMethod]: + """Return the list of RPC methods allowed for the given ACP type.""" + return self.ACP_ALLOWED_METHODS.get(acp_type, []) + DAgentACPService = Annotated[AgentACPService, Depends(AgentACPService)] +DAgentProtocolGateway = Annotated[AgentProtocolGateway, Depends(AgentACPService)] diff --git a/agentex/src/domain/services/agent_protocol_gateway.py b/agentex/src/domain/services/agent_protocol_gateway.py new file mode 100644 index 00000000..65751aa0 --- /dev/null +++ b/agentex/src/domain/services/agent_protocol_gateway.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Any + +from src.domain.entities.agents import AgentEntity +from src.domain.entities.events import EventEntity +from src.domain.entities.task_message_updates import TaskMessageUpdateEntity +from src.domain.entities.task_messages import TaskMessageContentEntity +from src.domain.entities.tasks import TaskEntity + + +class AgentProtocolGateway(ABC): + """Protocol-neutral interface for communicating with downstream agent servers.""" + + @abstractmethod + async def create_task( + self, + agent: AgentEntity, + task: TaskEntity, + service_url: str, + params: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Create a new task on the agent server.""" + ... + + @abstractmethod + async def send_message( + self, + agent: AgentEntity, + task: TaskEntity, + content: TaskMessageContentEntity, + service_url: str, + ) -> TaskMessageContentEntity: + """Send a message to a running task.""" + ... + + @abstractmethod + async def send_message_stream( + self, + agent: AgentEntity, + task: TaskEntity, + content: TaskMessageContentEntity, + service_url: str, + ) -> AsyncIterator[TaskMessageUpdateEntity]: + """Send a message to a running task and stream the response.""" + ... + + @abstractmethod + async def cancel_task( + self, + agent: AgentEntity, + task: TaskEntity, + service_url: str, + ) -> dict[str, Any]: + """Cancel a running task.""" + ... + + async def send_event( + self, + agent: AgentEntity, + event: EventEntity, + task: TaskEntity, + service_url: str, + request_headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + """Send an event to a running task. Not all protocols support events.""" + raise NotImplementedError("This protocol does not support events") + + async def check_health( + self, + agent_id: str, + service_url: str, + ) -> bool: + """Check if the agent server is healthy.""" + raise NotImplementedError("This protocol does not support health checks") diff --git a/agentex/src/domain/services/schedule_service.py b/agentex/src/domain/services/schedule_service.py index 07f08d5b..295c7208 100644 --- a/agentex/src/domain/services/schedule_service.py +++ b/agentex/src/domain/services/schedule_service.py @@ -276,7 +276,10 @@ def _description_to_response( # Decode bytes to string if possible try: import json - workflow_params.append(json.loads(arg.data.decode("utf-8"))) + + workflow_params.append( + json.loads(arg.data.decode("utf-8")) + ) except (json.JSONDecodeError, UnicodeDecodeError): workflow_params.append(str(arg.data)) else: @@ -314,7 +317,9 @@ def _description_to_response( if hasattr(info, "recent_actions") and info.recent_actions: # ScheduleActionResult has started_at (when action started) and scheduled_at (when it was scheduled) last_action = info.recent_actions[-1] - last_action_time = getattr(last_action, "started_at", None) or getattr(last_action, "scheduled_at", None) + last_action_time = getattr(last_action, "started_at", None) or getattr( + last_action, "scheduled_at", None + ) created_at: datetime | None = ( cast(datetime, info.create_time) if hasattr(info, "create_time") and info.create_time diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index 9afa3e4b..8102af8b 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -13,7 +13,7 @@ from src.domain.repositories.event_repository import DEventRepository from src.domain.repositories.task_repository import DTaskRepository from src.domain.repositories.task_state_repository import DTaskStateRepository -from src.domain.services.agent_acp_service import DAgentACPService +from src.domain.services.agent_acp_service import DAgentProtocolGateway from src.utils.ids import orm_id from src.utils.logging import make_logger from src.utils.stream_topics import get_task_event_stream_topic @@ -23,18 +23,18 @@ class AgentTaskService: """ - Service for managing agent tasks and forwarding operations to ACP servers. + Service for managing agent tasks and forwarding operations to agent servers. """ def __init__( self, - acp_client: DAgentACPService, + protocol_gateway: DAgentProtocolGateway, task_state_repository: DTaskStateRepository, task_repository: DTaskRepository, event_repository: DEventRepository, stream_repository: DRedisStreamRepository, ): - self.acp_client = acp_client + self.protocol_gateway = protocol_gateway self.task_state_repository = task_state_repository self.task_repository = task_repository self.event_repository = event_repository @@ -63,13 +63,13 @@ async def create_task( id=orm_id(), name=task_name, status=TaskStatus.RUNNING, - status_reason="Task created, forwarding to ACP server", + status_reason="Task created, forwarding to agent server", params=task_params, ), ) return task_entity - async def create_task_and_forward_to_acp( + async def create_task_and_forward( self, agent: AgentEntity, task_name: str | None = None, @@ -77,11 +77,11 @@ async def create_task_and_forward_to_acp( ) -> TaskEntity: """ Create a new task record in the repository with single agent (maintains existing interface). - Then, forward the task to the ACP server. + Then, forward the task to the agent server. Args: agent: The agent to create the task for - task_params: The parameters for the task to be sent to the ACP server + task_params: The parameters for the task to be sent to the agent server Returns: Task containing the created task info @@ -92,38 +92,38 @@ async def create_task_and_forward_to_acp( if agent.acp_type == ACPType.SYNC: logger.info( - "For sync agents, there are no initialization handlers, skipping ACP call" + "For sync agents, there are no initialization handlers, skipping forwarding" ) return task_entity try: - await self.acp_client.create_task( + await self.protocol_gateway.create_task( agent=agent, task=task_entity, - acp_url=agent.acp_url, + service_url=agent.acp_url, params=task_params, ) return task_entity except Exception as e: - logger.error(f"Error creating task in ACP: {e}") + logger.error(f"Error creating task: {e}") await self.fail_task(task_entity, str(e)) raise e from e - async def forward_task_to_acp( + async def forward_task( self, agent: AgentEntity, task: TaskEntity, task_params: dict[str, Any] | None = None, - acp_url: str | None = None, + service_url: str | None = None, ) -> None: try: - await self.acp_client.create_task( + await self.protocol_gateway.create_task( agent=agent, task=task, - acp_url=acp_url or agent.acp_url, + service_url=service_url or agent.acp_url, params=task_params, ) except Exception as e: - logger.error(f"Error creating task in ACP: {e}") + logger.error(f"Error creating task: {e}") await self.fail_task(task, str(e)) raise e from e @@ -244,14 +244,14 @@ async def send_message( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> TaskMessageContentEntity: """Send a message to a running task""" - return await self.acp_client.send_message( + return await self.protocol_gateway.send_message( agent=agent, task=task, content=content, - acp_url=acp_url, + service_url=service_url, ) async def send_message_stream( @@ -259,49 +259,51 @@ async def send_message_stream( agent: AgentEntity, task: TaskEntity, content: TaskMessageContentEntity, - acp_url: str, + service_url: str, ) -> AsyncIterator[TaskMessageUpdateEntity]: """Send a message to a running task and stream the response""" logger.info(f"TaskService: Sending message stream for task {task.id}") - async for chunk in self.acp_client.send_message_stream( + async for chunk in self.protocol_gateway.send_message_stream( agent=agent, task=task, content=content, - acp_url=acp_url, + service_url=service_url, ): yield chunk async def cancel_task( - self, agent: AgentEntity, task: TaskEntity, acp_url: str + self, agent: AgentEntity, task: TaskEntity, service_url: str ) -> TaskEntity: """Cancel a running task""" - await self.acp_client.cancel_task(agent=agent, task=task, acp_url=acp_url) + await self.protocol_gateway.cancel_task( + agent=agent, task=task, service_url=service_url + ) task = await self.task_repository.get(id=task.id) task.status = TaskStatus.CANCELED task.status_reason = "Task canceled by user" return await self.task_repository.update(task) - async def create_event_and_forward_to_acp( + async def create_event_and_forward( self, agent: AgentEntity, task: TaskEntity, - acp_url: str, + service_url: str, content: TaskMessageContentEntity | None = None, request_headers: dict[str, str] | None = None, ) -> EventEntity: - """Create an event and forward it to the ACP server""" + """Create an event and forward it to the agent server""" event = await self.event_repository.create( id=orm_id(), task_id=task.id, agent_id=agent.id, content=content, ) - await self.acp_client.send_event( + await self.protocol_gateway.send_event( agent=agent, event=event, task=task, - acp_url=acp_url, + service_url=service_url, request_headers=request_headers, ) return event diff --git a/agentex/src/domain/use_cases/agents_acp_use_case.py b/agentex/src/domain/use_cases/agents_acp_use_case.py index 62abb5ad..f038b001 100644 --- a/agentex/src/domain/use_cases/agents_acp_use_case.py +++ b/agentex/src/domain/use_cases/agents_acp_use_case.py @@ -48,7 +48,7 @@ from src.domain.mixins.task_messages.task_message_mixin import TaskMessageMixin from src.domain.repositories.agent_repository import DAgentRepository from src.domain.repositories.deployment_repository import DDeploymentRepository -from src.domain.services.agent_acp_service import DAgentACPService +from src.domain.services.agent_acp_service import DAgentProtocolGateway from src.domain.services.authorization_service import DAuthorizationService from src.domain.services.task_message_service import DTaskMessageService from src.domain.services.task_service import DAgentTaskService @@ -189,14 +189,14 @@ def __init__( self, agent_repository: DAgentRepository, deployment_repository: DDeploymentRepository, - acp_client: DAgentACPService, + protocol_gateway: DAgentProtocolGateway, task_service: DAgentTaskService, task_message_service: DTaskMessageService, authorization_service: DAuthorizationService, ): self.agent_repository = agent_repository self.deployment_repo = deployment_repository - self.acp_client = acp_client + self.protocol_gateway = protocol_gateway self.task_service = task_service self.task_message_service = task_message_service self.authorization_service = authorization_service @@ -309,14 +309,14 @@ async def _get_or_create_task( await self.grant_with_retry(task) return task - async def _resolve_acp_url( + async def _resolve_service_url( self, agent: AgentEntity, - acp_url_override: str | None = None, + service_url_override: str | None = None, ) -> str: - """Resolve the ACP URL for an agent, optionally overriding with a specific URL.""" - if acp_url_override: - return acp_url_override + """Resolve the service URL for an agent, optionally overriding with a specific URL.""" + if service_url_override: + return service_url_override # Resolve through production deployment if available if agent.production_deployment_id: @@ -330,7 +330,7 @@ async def _resolve_acp_url( if agent.acp_url: return agent.acp_url - raise ClientError(f"Agent {agent.id} does not have an ACP URL configured") + raise ClientError(f"Agent {agent.id} does not have a service URL configured") async def handle_rpc_request( self, @@ -342,7 +342,7 @@ async def handle_rpc_request( agent_id: str | None = None, agent_name: str | None = None, request_headers: dict[str, str] | None = None, - acp_url_override: str | None = None, + service_url_override: str | None = None, ) -> ( list[TaskMessageEntity] | AsyncIterator[TaskMessageUpdateEntity] @@ -358,7 +358,7 @@ async def handle_rpc_request( method: JSON-RPC method name params: JSON-RPC parameters request_headers: HTTP headers from the incoming request - acp_url_override: Override ACP URL (for preview deployment routing) + service_url_override: Override service URL (for preview deployment routing) Returns: - list[TaskMessageEntity] for synchronous MESSAGE_SEND @@ -373,7 +373,7 @@ async def handle_rpc_request( if agent.status == AgentStatus.DELETED: raise ClientError(f"Agent {agent_id} is deleted") - acp_url = await self._resolve_acp_url(agent, acp_url_override) + service_url = await self._resolve_service_url(agent, service_url_override) logger.info( f"[handle_rpc_request] Validating RPC method for ACP type: {agent.acp_type} - {method}" @@ -382,20 +382,20 @@ async def handle_rpc_request( # Handle different methods if method == AgentRPCMethod.MESSAGE_SEND: - return await self._handle_message_send(agent, params, acp_url) + return await self._handle_message_send(agent, params, service_url) elif method == AgentRPCMethod.TASK_CREATE: - return await self._handle_task_create(agent, params, acp_url) + return await self._handle_task_create(agent, params, service_url) elif method == AgentRPCMethod.TASK_CANCEL: - return await self._handle_task_cancel(agent, params, acp_url) + return await self._handle_task_cancel(agent, params, service_url) elif method == AgentRPCMethod.EVENT_SEND: return await self._handle_event_send( - agent, params, request_headers, acp_url + agent, params, request_headers, service_url ) else: raise ValueError(f"Unsupported method: {method}") async def _handle_task_create( - self, agent: AgentEntity, params: CreateTaskRequestEntity, acp_url: str + self, agent: AgentEntity, params: CreateTaskRequestEntity, service_url: str ) -> TaskEntity: """ Handle task/create method. @@ -403,27 +403,27 @@ async def _handle_task_create( Args: agent: The agent to create the task for params: Parameters containing task and initial message - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: Task containing the created task info """ - # This creates the task record then forwards the message to the ACP server + # This creates the task record then forwards the message to the agent server task = await self._get_or_create_task( agent=agent, task_name=params.name, task_params=params.params ) if agent.acp_type in [ACPType.AGENTIC, ACPType.ASYNC]: - await self.task_service.forward_task_to_acp( + await self.task_service.forward_task( agent=agent, task=task, task_params=params.params, - acp_url=acp_url, + service_url=service_url, ) return task async def _handle_message_send( - self, agent: AgentEntity, params: SendMessageRequestEntity, acp_url: str + self, agent: AgentEntity, params: SendMessageRequestEntity, service_url: str ) -> list[TaskMessageEntity] | AsyncIterator[TaskMessageUpdateEntity]: """ Handle message/send method. @@ -431,18 +431,18 @@ async def _handle_message_send( Args: agent: The agent to send the message to params: Parameters containing task_id and message - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: TaskMessageEntry for synchronous requests or AsyncIterator[TaskMessage] for streaming """ if params.stream: - return self._handle_message_send_stream(agent, params, acp_url) + return self._handle_message_send_stream(agent, params, service_url) else: - return await self._handle_message_send_sync(agent, params, acp_url) + return await self._handle_message_send_sync(agent, params, service_url) async def _handle_message_send_sync( - self, agent: AgentEntity, params: SendMessageRequestEntity, acp_url: str + self, agent: AgentEntity, params: SendMessageRequestEntity, service_url: str ) -> list[TaskMessageEntity]: task = await self._get_or_create_task( agent=agent, @@ -494,7 +494,7 @@ async def flush_aggregated_deltas( agent=agent, task=task, content=params.content, - acp_url=acp_url, + service_url=service_url, ): logger.debug( f"[message_send_stream] Received message chunk: {task_message_update}" @@ -569,7 +569,7 @@ async def flush_aggregated_deltas( return new_task_message_entities async def _handle_message_send_stream( - self, agent: AgentEntity, params: SendMessageRequestEntity, acp_url: str + self, agent: AgentEntity, params: SendMessageRequestEntity, service_url: str ) -> AsyncIterator[TaskMessageUpdateEntity]: """Handle streaming message send - yields raw TaskMessage objects""" @@ -647,7 +647,7 @@ async def flush_aggregated_deltas(task_message_index: int) -> TaskMessageEntity: agent=agent, task=task, content=params.content, - acp_url=acp_url, + service_url=service_url, ): logger.debug( f"[message_send_stream] Received message chunk type: {type(task_message_update).__name__}" @@ -769,7 +769,7 @@ async def flush_aggregated_deltas(task_message_index: int) -> TaskMessageEntity: return async def _handle_task_cancel( - self, agent: AgentEntity, params: CancelTaskRequestEntity, acp_url: str + self, agent: AgentEntity, params: CancelTaskRequestEntity, service_url: str ) -> TaskEntity: """ Handle task/cancel method. @@ -777,7 +777,7 @@ async def _handle_task_cancel( Args: agent: The agent to cancel the task for params: Parameters containing task_id - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: Dict containing the cancellation result @@ -790,7 +790,7 @@ async def _handle_task_cancel( return await self.task_service.cancel_task( agent=agent, task=task, - acp_url=acp_url, + service_url=service_url, ) async def _handle_event_send( @@ -798,7 +798,7 @@ async def _handle_event_send( agent: AgentEntity, params: SendEventRequestEntity, request_headers: dict[str, str] | None = None, - acp_url: str = "", + service_url: str = "", ) -> EventEntity: """ Handle event/send method @@ -807,7 +807,7 @@ async def _handle_event_send( agent: The agent to send the event to params: Parameters containing task_id and event data request_headers: HTTP headers from the incoming request - acp_url: Resolved ACP URL to route to + service_url: Resolved service URL to route to Returns: EventEntity for the created and forwarded event @@ -820,11 +820,11 @@ async def _handle_event_send( id=params.task_id, name=params.task_name ) # Create the event in the DB - event_entity = await self.task_service.create_event_and_forward_to_acp( + event_entity = await self.task_service.create_event_and_forward( agent=agent, task=task, content=params.content, - acp_url=acp_url, + service_url=service_url, request_headers=request_headers, ) return event_entity diff --git a/agentex/src/domain/use_cases/agents_use_case.py b/agentex/src/domain/use_cases/agents_use_case.py index fcb1bc9e..a0a535a7 100644 --- a/agentex/src/domain/use_cases/agents_use_case.py +++ b/agentex/src/domain/use_cases/agents_use_case.py @@ -9,7 +9,13 @@ TemporalWorkflowAlreadyExistsError, ) from src.config.environment_variables import EnvironmentVariables -from src.domain.entities.agents import ACPType, AgentEntity, AgentInputType, AgentStatus +from src.domain.entities.agents import ( + ACPType, + AgentEntity, + AgentInputType, + AgentProtocol, + AgentStatus, +) from src.domain.entities.deployments import DeploymentEntity, DeploymentStatus from src.domain.repositories.agent_repository import DAgentRepository from src.domain.repositories.deployment_history_repository import ( @@ -45,6 +51,7 @@ async def register_agent( acp_type: ACPType = ACPType.ASYNC, registration_metadata: dict[str, Any] | None = None, agent_input_type: AgentInputType | None = None, + protocol: AgentProtocol = AgentProtocol.ACP, ) -> AgentEntity: deployment_id = (registration_metadata or {}).get("deployment_id") @@ -67,6 +74,7 @@ async def register_agent( agent.status = AgentStatus.READY agent.status_reason = "Agent registered successfully." agent.acp_type = acp_type + agent.protocol = protocol if agent_input_type: agent.agent_input_type = agent_input_type if registration_metadata: @@ -97,6 +105,7 @@ async def register_agent( agent.status = AgentStatus.READY agent.status_reason = "Agent registered successfully." agent.acp_type = acp_type + agent.protocol = protocol if registration_metadata: existing_metadata = agent.registration_metadata or {} existing_metadata.update(registration_metadata) @@ -132,6 +141,7 @@ async def register_agent( status_reason="Agent registered successfully.", acp_url=acp_url, acp_type=acp_type, + protocol=protocol, registration_metadata=registration_metadata, registered_at=datetime.now(UTC), agent_input_type=agent_input_type, diff --git a/agentex/src/temporal/activities/healthcheck_activities.py b/agentex/src/temporal/activities/healthcheck_activities.py index 4c2a22c2..16f836d6 100644 --- a/agentex/src/temporal/activities/healthcheck_activities.py +++ b/agentex/src/temporal/activities/healthcheck_activities.py @@ -6,11 +6,9 @@ the status checks and the database updates. """ -import json - -import httpx from src.domain.entities.agents import AgentStatus from src.domain.repositories.agent_repository import AgentRepository +from src.domain.services.agent_protocol_gateway import AgentProtocolGateway from src.utils.logging import make_logger from temporalio import activity @@ -37,10 +35,12 @@ class HealthCheckActivities: - Updating agent status in the database """ - def __init__(self, agent_repo: AgentRepository, http_client: httpx.AsyncClient): - """Initialize with session maker and http client.""" + def __init__( + self, agent_repo: AgentRepository, protocol_gateway: AgentProtocolGateway + ): + """Initialize with agent repository and protocol gateway.""" self.agent_repo = agent_repo - self.http_client = http_client + self.protocol_gateway = protocol_gateway @activity.defn(name=CHECK_STATUS_ACTIVITY) async def check_status_activity(self, agent_id: str, acp_url: str) -> bool: @@ -55,36 +55,9 @@ async def check_status_activity(self, agent_id: str, acp_url: str) -> bool: bool: True if the agent is healthy, False otherwise """ logger.info(f"Checking status of agent {agent_id} via {acp_url}") - try: - response = await self.http_client.get(f"{acp_url}/healthz", timeout=5) - if response.status_code != 200: - logger.error( - f"Agent {agent_id} returned non-200 status: {response.status_code}" - ) - return False - try: - parsed_response = response.json() - status = parsed_response.get("status") - if status != "healthy": - logger.error( - f"Agent {agent_id} returned non-healthy status: {status}" - ) - return False - response_agent_id = parsed_response.get("agent_id") - if response_agent_id and response_agent_id != agent_id: - logger.error( - f"Agent {agent_id} returned unexpected agent ID: {response_agent_id}" - ) - return False - except json.JSONDecodeError: - logger.error( - f"Agent {agent_id} returned non-JSON response: {response.text}" - ) - return False - return True - except Exception as e: - logger.error(f"Failed to check status of agent {agent_id}: {e}") - return False + return await self.protocol_gateway.check_health( + agent_id=agent_id, service_url=acp_url + ) @activity.defn(name=UPDATE_AGENT_STATUS_ACTIVITY) async def update_agent_status_activity(self, agent_id: str, status: str) -> None: diff --git a/agentex/src/temporal/run_worker.py b/agentex/src/temporal/run_worker.py index d3665094..9020d4ef 100644 --- a/agentex/src/temporal/run_worker.py +++ b/agentex/src/temporal/run_worker.py @@ -9,19 +9,20 @@ import uuid from concurrent.futures import ThreadPoolExecutor -import httpx from temporalio.worker import UnsandboxedWorkflowRunner, Worker +from src.adapters.http.adapter_httpx import HttpxGateway from src.adapters.temporal.client_factory import TemporalClientFactory from src.config.dependencies import ( database_async_read_only_session_maker, database_async_read_write_engine, database_async_read_write_session_maker, - httpx_client, startup_global_dependencies, ) from src.config.environment_variables import EnvironmentVariables +from src.domain.repositories.agent_api_key_repository import AgentAPIKeyRepository from src.domain.repositories.agent_repository import AgentRepository +from src.domain.services.agent_acp_service import AgentACPService from src.temporal.activities.healthcheck_activities import HealthCheckActivities from src.temporal.workflows.healthcheck_workflow import HealthCheckWorkflow from src.utils.logging import make_logger @@ -122,7 +123,8 @@ async def run_worker( def create_health_check_worker( - agent_repo: AgentRepository, http_client: httpx.AsyncClient + agent_repo: AgentRepository, + agent_api_key_repo: AgentAPIKeyRepository, ) -> asyncio.Task: """ Create a Health Check worker. @@ -133,10 +135,19 @@ def create_health_check_worker( logger.info("Starting Temporal Health Check Worker") logger.info(f"Task queue: {task_queue}") + # Construct protocol gateway + environment_variables = EnvironmentVariables.refresh() + http_gateway = HttpxGateway(environment_variables=environment_variables) + protocol_gateway = AgentACPService( + agent_repository=agent_repo, + agent_api_key_repository=agent_api_key_repo, + http_gateway=http_gateway, + ) + # Create activities instance with dependencies health_check_activities = HealthCheckActivities( agent_repo=agent_repo, - http_client=httpx_client(), + protocol_gateway=protocol_gateway, ) # Extract activity methods @@ -169,9 +180,12 @@ async def main() -> None: session_maker = database_async_read_write_session_maker(engine) read_only_session_maker = database_async_read_only_session_maker(engine) agent_repo = AgentRepository(session_maker, read_only_session_maker) + agent_api_key_repo = AgentAPIKeyRepository( + session_maker, read_only_session_maker + ) health_check_worker_task = create_health_check_worker( agent_repo=agent_repo, - http_client=httpx_client(), + agent_api_key_repo=agent_api_key_repo, ) # Wait for the worker to complete await health_check_worker_task diff --git a/agentex/tests/fixtures/services.py b/agentex/tests/fixtures/services.py index 30b16ca2..9f583c74 100644 --- a/agentex/tests/fixtures/services.py +++ b/agentex/tests/fixtures/services.py @@ -44,7 +44,7 @@ def create_task_service( task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, - acp_client=agent_acp_service, + protocol_gateway=agent_acp_service, stream_repository=redis_stream_repository, ) diff --git a/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py b/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py index 062e8c86..4a4fdddd 100644 --- a/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py +++ b/agentex/tests/integration/api/checkpoints/test_checkpoint_repository.py @@ -640,7 +640,12 @@ async def test_null_blob_stored_correctly(self, isolated_repositories): repo = isolated_repositories["checkpoint_repository"] blobs = [ - {"channel": "empty_channel", "version": "v1", "type": "empty", "blob": None}, + { + "channel": "empty_channel", + "version": "v1", + "type": "empty", + "blob": None, + }, ] checkpoint = { "id": "cp-1", diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index ade4bdc4..7b338a4d 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -259,6 +259,7 @@ async def __aenter__(self): from src.domain.repositories.agent_task_tracker_repository import ( AgentTaskTrackerRepository, ) + from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.deployment_history_repository import ( DeploymentHistoryRepository, ) @@ -266,7 +267,6 @@ async def __aenter__(self): from src.domain.repositories.span_repository import SpanRepository from src.domain.repositories.task_message_repository import TaskMessageRepository from src.domain.repositories.task_repository import TaskRepository - from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.task_state_repository import TaskStateRepository # Create Redis repository with mock environment variables @@ -370,6 +370,7 @@ async def isolated_integration_app( from src.domain.use_cases.agent_api_keys_use_case import AgentAPIKeysUseCase from src.domain.use_cases.agent_task_tracker_use_case import AgentTaskTrackerUseCase from src.domain.use_cases.agents_use_case import AgentsUseCase + from src.domain.use_cases.checkpoints_use_case import CheckpointsUseCase from src.domain.use_cases.deployment_history_use_case import ( DeploymentHistoryUseCase, ) @@ -377,7 +378,6 @@ async def isolated_integration_app( from src.domain.use_cases.messages_use_case import MessagesUseCase from src.domain.use_cases.spans_use_case import SpanUseCase from src.domain.use_cases.states_use_case import StatesUseCase - from src.domain.use_cases.checkpoints_use_case import CheckpointsUseCase from src.domain.use_cases.tasks_use_case import TasksUseCase # Create use case factory functions with isolated repositories @@ -436,7 +436,7 @@ async def send_message(self, *args, **kwargs): pass task_service = AgentTaskService( - acp_client=MockAgentACPService(), + protocol_gateway=MockAgentACPService(), task_state_repository=isolated_repositories["task_state_repository"], task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], @@ -467,12 +467,12 @@ def create_messages_use_case(): DDatabaseAsyncReadWriteSessionMaker, DMongoDBDatabase, ) - from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.agent_api_key_repository import AgentAPIKeyRepository from src.domain.repositories.agent_repository import AgentRepository from src.domain.repositories.agent_task_tracker_repository import ( AgentTaskTrackerRepository, ) + from src.domain.repositories.checkpoint_repository import CheckpointRepository from src.domain.repositories.deployment_history_repository import ( DeploymentHistoryRepository, ) diff --git a/agentex/tests/integration/test_task_stream.py b/agentex/tests/integration/test_task_stream.py index 14c85807..a91982e8 100644 --- a/agentex/tests/integration/test_task_stream.py +++ b/agentex/tests/integration/test_task_stream.py @@ -71,7 +71,7 @@ async def send_message(self, *args, **kwargs): pass task_service = AgentTaskService( - acp_client=MockAgentACPService(), + protocol_gateway=MockAgentACPService(), task_state_repository=isolated_repositories["task_state_repository"], task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], @@ -98,7 +98,7 @@ async def send_message(self, *args, **kwargs): pass task_service = AgentTaskService( - acp_client=MockAgentACPService(), + protocol_gateway=MockAgentACPService(), task_state_repository=isolated_repositories["task_state_repository"], task_repository=isolated_repositories["task_repository"], event_repository=isolated_repositories["event_repository"], diff --git a/agentex/tests/unit/services/test_agent_acp_service.py b/agentex/tests/unit/services/test_agent_acp_service.py index 8074036a..13e8ecc0 100644 --- a/agentex/tests/unit/services/test_agent_acp_service.py +++ b/agentex/tests/unit/services/test_agent_acp_service.py @@ -150,7 +150,7 @@ async def test_create_task_success( result = await agent_acp_service.create_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", params={"test_param": "value"}, ) @@ -208,7 +208,7 @@ async def test_send_message_success( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -264,7 +264,7 @@ async def test_send_message_success_data( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -320,7 +320,7 @@ async def test_send_message_success_tool_request( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -382,7 +382,7 @@ async def test_send_message_success_tool_response( agent=sample_agent, task=sample_task, content=sample_text_content, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -426,7 +426,7 @@ async def test_cancel_task_success( result = await agent_acp_service.cancel_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -468,7 +468,7 @@ async def test_send_event_success( agent=sample_agent, event=sample_event, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -522,7 +522,7 @@ async def test_send_event_with_request_headers( agent=sample_agent, event=sample_event, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", request_headers=request_headers, ) @@ -581,7 +581,7 @@ async def test_send_event_without_request_headers( agent=sample_agent, event=sample_event, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) # Then @@ -629,7 +629,7 @@ async def test_jsonrpc_error_handling( await agent_acp_service.create_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) assert "RPC error" in str(exc_info.value) @@ -655,7 +655,7 @@ async def test_http_gateway_error_handling( await agent_acp_service.create_task( agent=sample_agent, task=sample_task, - acp_url="http://test-acp.example.com", + service_url="http://test-acp.example.com", ) assert "Connection timeout" in str(exc_info.value) diff --git a/agentex/tests/unit/services/test_task_service.py b/agentex/tests/unit/services/test_task_service.py index eb096eb1..a8968923 100644 --- a/agentex/tests/unit/services/test_task_service.py +++ b/agentex/tests/unit/services/test_task_service.py @@ -34,8 +34,8 @@ async def create_or_get_agent(agent_repository, agent): @pytest.fixture -def mock_acp_client(): - """Mock ACP client for testing service interactions (external dependency)""" +def mock_protocol_gateway(): + """Mock protocol gateway for testing service interactions (external dependency)""" mock = AsyncMock() mock.create_task = AsyncMock() mock.send_message = AsyncMock() @@ -71,15 +71,15 @@ def event_repository(postgres_session_maker): @pytest.fixture def task_service( - mock_acp_client, + mock_protocol_gateway, task_repository, task_state_repository, event_repository, redis_stream_repository, ): - """Create TaskService instance with real repositories and mocked ACP client""" + """Create TaskService instance with real repositories and mocked protocol gateway""" return AgentTaskService( - acp_client=mock_acp_client, + protocol_gateway=mock_protocol_gateway, task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, @@ -154,7 +154,7 @@ async def test_create_task_success( assert result.id is not None assert result.name == "integration-test" assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" async def test_create_task_without_name( self, task_service, task_repository, agent_repository, sample_agent @@ -171,7 +171,7 @@ async def test_create_task_without_name( assert result.id is not None assert result.name is None assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" async def test_create_task_with_params( self, task_service, agent_repository, sample_agent @@ -197,7 +197,7 @@ async def test_create_task_with_params( assert result.name == "task-with-params" assert result.params == task_params assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" async def test_create_task_with_params_retrieval( self, task_service, agent_repository, sample_agent @@ -248,10 +248,10 @@ async def test_create_task_with_null_params( retrieved_task = await task_service.get_task(id=result.id) assert retrieved_task.params is None - async def test_create_task_and_forward_to_acp_success( + async def test_create_task_and_forward_success( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_agent, @@ -260,10 +260,10 @@ async def test_create_task_and_forward_to_acp_success( # Given - Persist the agent first to satisfy foreign key constraints await create_or_get_agent(agent_repository, sample_agent) task_params = {"param1": "value1", "param2": "value2"} - mock_acp_client.create_task.return_value = None + mock_protocol_gateway.create_task.return_value = None # When - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=sample_agent, task_name="forwarded-task", task_params=task_params ) @@ -273,20 +273,20 @@ async def test_create_task_and_forward_to_acp_success( assert result.name == "forwarded-task" assert result.params == task_params # Verify params are stored in the task assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" - # Verify ACP client was called with correct parameters - mock_acp_client.create_task.assert_called_once_with( + # Verify protocol gateway was called with correct parameters + mock_protocol_gateway.create_task.assert_called_once_with( agent=sample_agent, task=result, # Use the actual created task - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, params=task_params, ) async def test_create_task_and_forward_sync_agent_skips_acp( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_sync_agent, @@ -296,7 +296,7 @@ async def test_create_task_and_forward_sync_agent_skips_acp( await create_or_get_agent(agent_repository, sample_sync_agent) # When - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=sample_sync_agent, task_name="sync-task" ) @@ -305,15 +305,15 @@ async def test_create_task_and_forward_sync_agent_skips_acp( assert result.id is not None assert result.name == "sync-task" assert result.status == TaskStatus.RUNNING - assert result.status_reason == "Task created, forwarding to ACP server" + assert result.status_reason == "Task created, forwarding to agent server" # Verify ACP client was NOT called for sync agents - mock_acp_client.create_task.assert_not_called() + mock_protocol_gateway.create_task.assert_not_called() async def test_create_task_and_forward_acp_error_handling( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_agent, @@ -322,16 +322,16 @@ async def test_create_task_and_forward_acp_error_handling( # Given - Persist the agent first to satisfy foreign key constraints await create_or_get_agent(agent_repository, sample_agent) acp_error = Exception("ACP server unavailable") - mock_acp_client.create_task.side_effect = acp_error + mock_protocol_gateway.create_task.side_effect = acp_error # When / Then with pytest.raises(Exception) as exc_info: - await task_service.create_task_and_forward_to_acp(agent=sample_agent) + await task_service.create_task_and_forward(agent=sample_agent) assert str(exc_info.value) == "ACP server unavailable" # Verify task was created but then marked as failed due to ACP error - mock_acp_client.create_task.assert_called_once() + mock_protocol_gateway.create_task.assert_called_once() async def test_fail_task( self, task_service, task_repository, agent_repository, sample_agent @@ -519,7 +519,7 @@ async def test_list_tasks( async def test_send_message( self, task_service, - mock_acp_client, + mock_protocol_gateway, sample_agent, sample_task, sample_message_content, @@ -527,30 +527,30 @@ async def test_send_message( """Test sending message to task""" # Given acp_url = "http://test-acp.example.com" - mock_acp_client.send_message.return_value = sample_message_content + mock_protocol_gateway.send_message.return_value = sample_message_content # When result = await task_service.send_message( agent=sample_agent, task=sample_task, content=sample_message_content, - acp_url=acp_url, + service_url=acp_url, ) # Then assert result == sample_message_content - mock_acp_client.send_message.assert_called_once_with( + mock_protocol_gateway.send_message.assert_called_once_with( agent=sample_agent, task=sample_task, content=sample_message_content, - acp_url=acp_url, + service_url=acp_url, ) # async def test_send_message_stream( self, task_service, - mock_acp_client, + mock_protocol_gateway, sample_agent, sample_task, sample_message_content, @@ -574,7 +574,7 @@ async def mock_stream_method(*args, **kwargs): yield update # Replace the entire method with our async generator - mock_acp_client.send_message_stream = mock_stream_method + mock_protocol_gateway.send_message_stream = mock_stream_method # When updates = [] @@ -582,7 +582,7 @@ async def mock_stream_method(*args, **kwargs): agent=sample_agent, task=sample_task, content=sample_message_content, - acp_url=acp_url, + service_url=acp_url, ): updates.append(update) @@ -601,7 +601,7 @@ async def mock_stream_method(*args, **kwargs): async def test_cancel_task( self, task_service, - mock_acp_client, + mock_protocol_gateway, task_repository, agent_repository, sample_agent, @@ -616,7 +616,7 @@ async def test_cancel_task( # When result = await task_service.cancel_task( - agent=sample_agent, task=created_task, acp_url=acp_url + agent=sample_agent, task=created_task, service_url=acp_url ) # Then @@ -624,8 +624,8 @@ async def test_cancel_task( assert result.status_reason == "Task canceled by user" # Verify ACP client was called to cancel the task - mock_acp_client.cancel_task.assert_called_once_with( - agent=sample_agent, task=created_task, acp_url=acp_url + mock_protocol_gateway.cancel_task.assert_called_once_with( + agent=sample_agent, task=created_task, service_url=acp_url ) # Verify task status is updated in the database @@ -633,10 +633,10 @@ async def test_cancel_task( assert updated_task.status == TaskStatus.CANCELED assert updated_task.status_reason == "Task canceled by user" - async def test_create_event_and_forward_to_acp( + async def test_create_event_and_forward( self, task_service, - mock_acp_client, + mock_protocol_gateway, event_repository, agent_repository, sample_agent, @@ -651,10 +651,10 @@ async def test_create_event_and_forward_to_acp( acp_url = "http://test-acp.example.com" # When - result = await task_service.create_event_and_forward_to_acp( + result = await task_service.create_event_and_forward( agent=sample_agent, task=created_task, - acp_url=acp_url, + service_url=acp_url, content=sample_message_content, ) @@ -669,18 +669,18 @@ async def test_create_event_and_forward_to_acp( assert result.content.content == sample_message_content.content # Verify ACP client was called to send the event - mock_acp_client.send_event.assert_called_once_with( + mock_protocol_gateway.send_event.assert_called_once_with( agent=sample_agent, event=result, # Use the actual created event task=created_task, - acp_url=acp_url, + service_url=acp_url, request_headers=None, ) - async def test_create_event_and_forward_to_acp_with_headers( + async def test_create_event_and_forward_with_headers( self, task_service, - mock_acp_client, + mock_protocol_gateway, event_repository, agent_repository, sample_agent, @@ -699,10 +699,10 @@ async def test_create_event_and_forward_to_acp_with_headers( } # When - result = await task_service.create_event_and_forward_to_acp( + result = await task_service.create_event_and_forward( agent=sample_agent, task=created_task, - acp_url=acp_url, + service_url=acp_url, content=sample_message_content, request_headers=request_headers, ) @@ -715,11 +715,11 @@ async def test_create_event_and_forward_to_acp_with_headers( assert result.content is not None # Verify ACP client was called with request_headers - mock_acp_client.send_event.assert_called_once_with( + mock_protocol_gateway.send_event.assert_called_once_with( agent=sample_agent, event=result, task=created_task, - acp_url=acp_url, + service_url=acp_url, request_headers=request_headers, ) diff --git a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py index 60914adc..4e52343b 100644 --- a/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py +++ b/agentex/tests/unit/use_cases/test_acp_type_backwards_compatibility_use_case.py @@ -21,7 +21,6 @@ from src.domain.repositories.task_state_repository import TaskStateRepository from src.domain.services.agent_acp_service import AgentACPService from src.domain.services.task_service import AgentTaskService -from src.domain.use_cases.agents_acp_use_case import AgentsACPUseCase from src.domain.use_cases.agents_use_case import AgentsUseCase @@ -55,29 +54,25 @@ async def test_agentic_agent_in_allowed_methods_dictionary(self): assert ACPType.ASYNC in ACP_TYPE_TO_ALLOWED_RPC_METHODS assert ACPType.SYNC in ACP_TYPE_TO_ALLOWED_RPC_METHODS + @pytest.mark.asyncio @pytest.mark.asyncio async def test_validate_rpc_method_accepts_agentic_for_task_create(self): - """Verify AGENTIC agents can use task/create method""" - # Should not raise an error - AgentsACPUseCase._validate_rpc_method_for_acp_type( - ACPType.AGENTIC, AgentRPCMethod.TASK_CREATE - ) + """Verify AGENTIC agents can use task/create method via gateway""" + # Use the static ACP_TYPE_TO_ALLOWED_RPC_METHODS for backward compat validation + allowed = ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC] + assert AgentRPCMethod.TASK_CREATE in allowed @pytest.mark.asyncio async def test_validate_rpc_method_accepts_agentic_for_event_send(self): - """Verify AGENTIC agents can use event/send method""" - # Should not raise an error - AgentsACPUseCase._validate_rpc_method_for_acp_type( - ACPType.AGENTIC, AgentRPCMethod.EVENT_SEND - ) + """Verify AGENTIC agents can use event/send method via gateway""" + allowed = ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC] + assert AgentRPCMethod.EVENT_SEND in allowed @pytest.mark.asyncio async def test_validate_rpc_method_accepts_agentic_for_task_cancel(self): - """Verify AGENTIC agents can use task/cancel method""" - # Should not raise an error - AgentsACPUseCase._validate_rpc_method_for_acp_type( - ACPType.AGENTIC, AgentRPCMethod.TASK_CANCEL - ) + """Verify AGENTIC agents can use task/cancel method via gateway""" + allowed = ACP_TYPE_TO_ALLOWED_RPC_METHODS[ACPType.AGENTIC] + assert AgentRPCMethod.TASK_CANCEL in allowed @pytest.mark.asyncio async def test_agentic_agent_forwards_task_to_acp(self): @@ -90,7 +85,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): stream_repo = AsyncMock() task_service = AgentTaskService( - acp_client=acp_client, + protocol_gateway=acp_client, task_state_repository=task_state_repo, task_repository=task_repo, event_repository=event_repo, @@ -117,7 +112,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): acp_client.create_task.return_value = None # Execute - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=agentic_agent, task_name="test-task", task_params={"test": "params"}, @@ -127,7 +122,7 @@ async def test_agentic_agent_forwards_task_to_acp(self): acp_client.create_task.assert_called_once_with( agent=agentic_agent, task=task, - acp_url=agentic_agent.acp_url, + service_url=agentic_agent.acp_url, params={"test": "params"}, ) assert result == task @@ -143,7 +138,7 @@ async def test_sync_agent_does_not_forward_task_to_acp(self): stream_repo = AsyncMock() task_service = AgentTaskService( - acp_client=acp_client, + protocol_gateway=acp_client, task_state_repository=task_state_repo, task_repository=task_repo, event_repository=event_repo, @@ -169,7 +164,7 @@ async def test_sync_agent_does_not_forward_task_to_acp(self): task_repo.create.return_value = task # Execute - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=sync_agent, task_name="test-task", task_params={"test": "params"}, @@ -190,7 +185,7 @@ async def test_async_agent_forwards_task_to_acp(self): stream_repo = AsyncMock() task_service = AgentTaskService( - acp_client=acp_client, + protocol_gateway=acp_client, task_state_repository=task_state_repo, task_repository=task_repo, event_repository=event_repo, @@ -217,7 +212,7 @@ async def test_async_agent_forwards_task_to_acp(self): acp_client.create_task.return_value = None # Execute - result = await task_service.create_task_and_forward_to_acp( + result = await task_service.create_task_and_forward( agent=async_agent, task_name="test-task", task_params={"test": "params"}, @@ -227,7 +222,7 @@ async def test_async_agent_forwards_task_to_acp(self): acp_client.create_task.assert_called_once_with( agent=async_agent, task=task, - acp_url=async_agent.acp_url, + service_url=async_agent.acp_url, params={"test": "params"}, ) assert result == task diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index 08e31bfc..1bf48c86 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -125,7 +125,7 @@ def task_service( task_repository=task_repository, task_state_repository=task_state_repository, event_repository=event_repository, - acp_client=agent_acp_service, + protocol_gateway=agent_acp_service, stream_repository=redis_stream_repository, ) @@ -165,7 +165,7 @@ def agents_acp_use_case( return AgentsACPUseCase( agent_repository=agent_repository, deployment_repository=deployment_repository, - acp_client=agent_acp_service, + protocol_gateway=agent_acp_service, task_service=task_service, task_message_service=task_message_service, authorization_service=authorization_service, @@ -290,7 +290,7 @@ def create_mock_stream(*args, **kwargs): await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) async def test_handle_message_send_sync_success( @@ -346,7 +346,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert isinstance(result, list) @@ -413,7 +413,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert isinstance(result, list) @@ -513,7 +513,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -639,7 +639,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -795,7 +795,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -880,7 +880,7 @@ def create_mock_stream_error(*args, **kwargs): await agents_acp_use_case._handle_task_create( agent=sample_agent, params=create_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert "ACP server connection failed" in str(exc_info.value) @@ -919,7 +919,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_task_create( agent=sample_agent, params=create_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -964,7 +964,7 @@ def create_mock_stream_error(*args, **kwargs): await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) assert "ACP server connection failed" in str(exc_info.value) @@ -1034,7 +1034,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -1319,7 +1319,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -1433,7 +1433,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_task_cancel( agent=sample_agent, params=cancel_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1479,7 +1479,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_task_cancel( agent=sample_agent, params=cancel_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1528,7 +1528,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_event_send( agent=sample_agent, params=event_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1603,7 +1603,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_event_send( agent=sample_agent, params=event_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1701,7 +1701,7 @@ async def mock_async_call(*args, **kwargs): agent=sample_agent, params=event_request, request_headers=request_headers, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1761,7 +1761,7 @@ async def mock_async_call(*args, **kwargs): result = await agents_acp_use_case._handle_event_send( agent=sample_agent, params=event_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1852,7 +1852,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -1940,7 +1940,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -2043,7 +2043,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then @@ -2140,7 +2140,7 @@ def create_mock_stream(*args, **kwargs): async for update in agents_acp_use_case._handle_message_send_stream( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ): updates.append(update) @@ -2250,7 +2250,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then - params should be updated @@ -2318,7 +2318,7 @@ def create_mock_stream(*args, **kwargs): result = await agents_acp_use_case._handle_message_send_sync( agent=sample_agent, params=send_request, - acp_url=sample_agent.acp_url, + service_url=sample_agent.acp_url, ) # Then - task should not be updated