diff --git a/agentex/database/migrations/alembic/versions/2026_05_04_1111_add_tasks_metadata_gin_index_e9c4ff9e6542.py b/agentex/database/migrations/alembic/versions/2026_05_04_1111_add_tasks_metadata_gin_index_e9c4ff9e6542.py new file mode 100644 index 00000000..54a74aa0 --- /dev/null +++ b/agentex/database/migrations/alembic/versions/2026_05_04_1111_add_tasks_metadata_gin_index_e9c4ff9e6542.py @@ -0,0 +1,27 @@ +"""add_tasks_metadata_gin_index + +Revision ID: e9c4ff9e6542 +Revises: 9ff3ee32c81b +Create Date: 2026-05-04 11:11:35.017451 + +""" +from typing import Sequence, Union + +from alembic import op + + +revision: str = 'e9c4ff9e6542' +down_revision: Union[str, None] = '9ff3ee32c81b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute( + "CREATE INDEX IF NOT EXISTS ix_tasks_metadata_gin " + "ON tasks USING GIN (task_metadata jsonb_path_ops)" + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS ix_tasks_metadata_gin") diff --git a/agentex/database/migrations/migration_history.txt b/agentex/database/migrations/migration_history.txt index 921ae998..25e97ddb 100644 --- a/agentex/database/migrations/migration_history.txt +++ b/agentex/database/migrations/migration_history.txt @@ -1,4 +1,5 @@ -57c5ed4f59ae -> 9ff3ee32c81b (head), uppercase deployment status +9ff3ee32c81b -> e9c4ff9e6542 (head), add_tasks_metadata_gin_index +57c5ed4f59ae -> 9ff3ee32c81b, uppercase deployment status enum labels 4a9b7787ccd7 -> 57c5ed4f59ae, add_task_id_to_spans d1a6cde41b3f -> 4a9b7787ccd7, deployments d024851e790c -> d1a6cde41b3f, add_langgraph_checkpoint_tables diff --git a/agentex/src/api/routes/tasks.py b/agentex/src/api/routes/tasks.py index 076402ba..dc557ed3 100644 --- a/agentex/src/api/routes/tasks.py +++ b/agentex/src/api/routes/tasks.py @@ -1,6 +1,7 @@ +import json from typing import Annotated, Any -from fastapi import APIRouter, Query +from fastapi import APIRouter, HTTPException, Query from fastapi.responses import StreamingResponse from src.adapters.temporal.adapter_temporal import DTemporalAdapter @@ -14,9 +15,11 @@ Task, TaskRelationships, TaskResponse, + TaskStatus, TaskStatusReasonRequest, UpdateTaskRequest, ) +from src.domain.entities.tasks import TaskStatus as DomainTaskStatus from src.domain.services.authorization_service import DAuthorizationService from src.domain.use_cases.streams_use_case import DStreamsUseCase from src.domain.use_cases.tasks_use_case import DTaskUseCase @@ -79,6 +82,19 @@ async def list_tasks( authorized_ids: DAuthorizedResourceIds(AgentexResourceType.task), agent_id: str | None = None, agent_name: str | None = None, + status: Annotated[ + TaskStatus | None, + Query(description="Filter tasks by status (e.g. RUNNING, COMPLETED)."), + ] = None, + task_metadata: Annotated[ + str | None, + Query( + description=( + "JSON-encoded object used to filter tasks via JSONB containment. " + 'Example: {"created_by_user_id": "abc-123"}.' + ) + ), + ] = None, limit: int = 50, page_number: int = 1, order_by: str | None = None, @@ -86,11 +102,41 @@ async def list_tasks( relationships: Annotated[list[TaskRelationships], Query()] = None, ): """List all tasks.""" + parsed_metadata: dict | None = None + if task_metadata is not None: + try: + parsed_metadata = json.loads(task_metadata) + except json.JSONDecodeError as exc: + raise HTTPException( + status_code=400, + detail=f"Invalid JSON in task_metadata query parameter: {exc.msg}", + ) from exc + if not isinstance(parsed_metadata, dict): + raise HTTPException( + status_code=400, + detail="task_metadata must decode to a JSON object.", + ) + if not parsed_metadata: + raise HTTPException( + status_code=400, + detail="task_metadata cannot be empty; omit the parameter to skip filtering.", + ) + + if status == TaskStatus.DELETED: + # list_tasks always excludes DELETED rows at the repository layer, so + # filtering on it would silently return an empty list. Reject explicitly. + raise HTTPException( + status_code=400, + detail="Cannot filter by DELETED status; deleted tasks are not returned by list_tasks.", + ) + domain_status = DomainTaskStatus(status.value) if status is not None else None task_entities = await task_use_case.list_tasks( id=authorized_ids, agent_id=agent_id, agent_name=agent_name, + status=domain_status, + task_metadata=parsed_metadata, limit=limit, page_number=page_number, order_by=order_by, diff --git a/agentex/src/api/schemas/agents_rpc.py b/agentex/src/api/schemas/agents_rpc.py index e3571a7b..884d55d9 100644 --- a/agentex/src/api/schemas/agents_rpc.py +++ b/agentex/src/api/schemas/agents_rpc.py @@ -26,6 +26,14 @@ class CreateTaskRequest(BaseModel): params: dict[str, Any] | None = Field( None, description="The parameters for the task" ) + task_metadata: dict[str, Any] | None = Field( + None, + description=( + "Caller-provided metadata to persist on the task row. Only applied at " + "task creation; ignored if a task with this name already exists. " + "Forwarded to the agent inside the ACP payload for backward compatibility." + ), + ) class CancelTaskRequest(BaseModel): diff --git a/agentex/src/domain/entities/agents_rpc.py b/agentex/src/domain/entities/agents_rpc.py index 571ba104..864c9469 100644 --- a/agentex/src/domain/entities/agents_rpc.py +++ b/agentex/src/domain/entities/agents_rpc.py @@ -105,6 +105,14 @@ class CreateTaskRequestEntity(BaseModel): params: dict[str, Any] | None = Field( None, description="The parameters for the task" ) + task_metadata: dict[str, Any] | None = Field( + None, + description=( + "Caller-provided metadata to persist on the task row. Only applied at " + "task creation; ignored if a task with this name already exists. " + "Forwarded to the agent inside the ACP payload for backward compatibility." + ), + ) class CancelTaskRequestEntity(BaseModel): @@ -184,6 +192,7 @@ def from_api_request(cls, request: AgentRPCRequest) -> Self: params = CreateTaskRequestEntity( name=request.params.root.name, params=request.params.root.params, + task_metadata=request.params.root.task_metadata, ) elif request.method == AgentRPCMethod.TASK_CANCEL and isinstance( request.params.root, CancelTaskRequest diff --git a/agentex/src/domain/repositories/task_repository.py b/agentex/src/domain/repositories/task_repository.py index 884f4be1..fd67ea97 100644 --- a/agentex/src/domain/repositories/task_repository.py +++ b/agentex/src/domain/repositories/task_repository.py @@ -49,6 +49,7 @@ async def list_with_join( | None = None, agent_id: str | None = None, agent_name: str | None = None, + task_metadata: dict | None = None, order_by: str | None = None, order_direction: Literal["asc", "desc"] = "desc", limit: int | None = None, @@ -62,6 +63,8 @@ async def list_with_join( - task_filters: Filters on the task table itself - agent_id: Filter tasks by agent ID using the join table - agent_name: Filter tasks by agent name + - task_metadata: JSONB containment filter on `task_metadata`. Returns + tasks whose metadata is a JSON superset of the provided dict. - order_by: Column to order by - order_direction: Direction to order by - limit: Maximum number of results to return @@ -78,6 +81,8 @@ async def list_with_join( ).where(AgentORM.name == agent_name) if agent_id: query = query.where(TaskAgentORM.agent_id == agent_id) + if task_metadata is not None: + query = query.where(TaskORM.task_metadata.contains(task_metadata)) query = query.where(TaskORM.status != TaskStatus.DELETED) return await self.list( filters=task_filters, diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index 9afa3e4b..013c6903 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -45,6 +45,7 @@ async def create_task( agent: AgentEntity, task_name: str | None = None, task_params: dict[str, Any] | None = None, + task_metadata: dict[str, Any] | None = None, ) -> TaskEntity: """ Create a new task record in the repository with single agent (maintains existing interface). @@ -53,6 +54,8 @@ async def create_task( agent: The agent to create the task for task_name: The name of the task to be created task_params: The parameters for the task + task_metadata: Caller-provided metadata to persist on the task row. + Not forwarded to the agent. Returns: Task containing the created task info """ @@ -65,6 +68,7 @@ async def create_task( status=TaskStatus.RUNNING, status_reason="Task created, forwarding to ACP server", params=task_params, + task_metadata=task_metadata, ), ) return task_entity @@ -220,6 +224,8 @@ async def list_tasks( id: str | list[str] | None = None, agent_id: str | None = None, agent_name: str | None = None, + status: TaskStatus | list[TaskStatus] | None = None, + task_metadata: dict | None = None, order_by: str | None = None, order_direction: str = "desc", relationships: list[TaskRelationships] | None = None, @@ -227,11 +233,17 @@ async def list_tasks( """ List all tasks from the repository. """ + task_filters: dict = {} + if id is not None: + task_filters["id"] = id + if status is not None: + task_filters["status"] = status return await self.task_repository.list_with_join( - task_filters={"id": id} if id is not None else None, + task_filters=task_filters or None, agent_id=agent_id, agent_name=agent_name, + task_metadata=task_metadata, order_by=order_by, order_direction=order_direction, limit=limit, diff --git a/agentex/src/domain/use_cases/agents_acp_use_case.py b/agentex/src/domain/use_cases/agents_acp_use_case.py index 62abb5ad..fc727d88 100644 --- a/agentex/src/domain/use_cases/agents_acp_use_case.py +++ b/agentex/src/domain/use_cases/agents_acp_use_case.py @@ -267,6 +267,7 @@ async def _get_or_create_task( task_id: str | None = None, task_name: str | None = None, task_params: dict[str, Any] | None = None, + task_metadata: dict[str, Any] | None = None, ) -> TaskEntity: """Return the existing task if *task_id* is provided, otherwise create a new one. @@ -303,7 +304,10 @@ async def _get_or_create_task( # Create a new task if it doesn't exist task = await self.task_service.create_task( - agent=agent, task_name=task_name, task_params=task_params + agent=agent, + task_name=task_name, + task_params=task_params, + task_metadata=task_metadata, ) logger.info(f"[agent_id={agent.id}] Created task {task.id}") await self.grant_with_retry(task) @@ -408,9 +412,13 @@ async def _handle_task_create( Returns: Task containing the created task info """ - # This creates the task record then forwards the message to the ACP server + # This creates the task record then forwards the message to the ACP server. + # task_metadata is persisted on the task row but never forwarded to the agent. task = await self._get_or_create_task( - agent=agent, task_name=params.name, task_params=params.params + agent=agent, + task_name=params.name, + task_params=params.params, + task_metadata=params.task_metadata, ) if agent.acp_type in [ACPType.AGENTIC, ACPType.ASYNC]: diff --git a/agentex/src/domain/use_cases/tasks_use_case.py b/agentex/src/domain/use_cases/tasks_use_case.py index f358cf1b..4ad1a61a 100644 --- a/agentex/src/domain/use_cases/tasks_use_case.py +++ b/agentex/src/domain/use_cases/tasks_use_case.py @@ -69,6 +69,8 @@ async def list_tasks( id: str | list[str] | None = None, agent_id: str | None = None, agent_name: str | None = None, + status: TaskStatus | list[TaskStatus] | None = None, + task_metadata: dict | None = None, order_by: str | None = None, order_direction: str = "desc", relationships: list[TaskRelationships] | None = None, @@ -78,6 +80,8 @@ async def list_tasks( id=id, agent_id=agent_id, agent_name=agent_name, + status=status, + task_metadata=task_metadata, limit=limit, page_number=page_number, order_by=order_by, diff --git a/agentex/tests/integration/api/tasks/test_tasks_api.py b/agentex/tests/integration/api/tasks/test_tasks_api.py index 934e079f..83a5b9f8 100644 --- a/agentex/tests/integration/api/tasks/test_tasks_api.py +++ b/agentex/tests/integration/api/tasks/test_tasks_api.py @@ -340,6 +340,109 @@ async def test_list_tasks_with_both_agent_id_and_agent_name_filter( assert len(tasks) == 1 assert tasks[0]["id"] == target_task.id + async def test_list_tasks_with_task_metadata_filter( + self, isolated_client, isolated_repositories + ): + """list_tasks?task_metadata={...} should return only matching tasks.""" + agent_repo = isolated_repositories["agent_repository"] + agent = AgentEntity( + id=orm_id(), + name="metadata-filter-agent", + description="agent for metadata filter test", + acp_url="http://test-acp:8000", + acp_type=ACPType.SYNC, + ) + await agent_repo.create(agent) + + task_repo = isolated_repositories["task_repository"] + matching = TaskEntity( + id=orm_id(), + name="matching-task", + status=TaskStatus.RUNNING, + task_metadata={"created_by_user_id": "user-a"}, + ) + other = TaskEntity( + id=orm_id(), + name="other-task", + status=TaskStatus.RUNNING, + task_metadata={"created_by_user_id": "user-b"}, + ) + await task_repo.create(agent_id=agent.id, task=matching) + await task_repo.create(agent_id=agent.id, task=other) + + response = await isolated_client.get( + "/tasks", + params={"task_metadata": '{"created_by_user_id": "user-a"}'}, + ) + assert response.status_code == 200 + ids = {t["id"] for t in response.json()} + assert matching.id in ids + assert other.id not in ids + + async def test_list_tasks_rejects_malformed_task_metadata(self, isolated_client): + """Malformed JSON in task_metadata should yield a 400.""" + response = await isolated_client.get( + "/tasks", params={"task_metadata": "not-json"} + ) + assert response.status_code == 400 + + async def test_list_tasks_rejects_empty_task_metadata(self, isolated_client): + """Empty JSON object in task_metadata should yield a 400.""" + response = await isolated_client.get("/tasks", params={"task_metadata": "{}"}) + assert response.status_code == 400 + + async def test_list_tasks_rejects_non_object_task_metadata(self, isolated_client): + """Non-object JSON in task_metadata should yield a 400.""" + response = await isolated_client.get( + "/tasks", params={"task_metadata": '"some-string"'} + ) + assert response.status_code == 400 + + async def test_list_tasks_with_status_filter( + self, isolated_client, isolated_repositories + ): + """list_tasks?status=RUNNING should return only RUNNING tasks.""" + agent_repo = isolated_repositories["agent_repository"] + agent = AgentEntity( + id=orm_id(), + name="status-filter-agent", + description="agent for status filter test", + acp_url="http://test-acp:8000", + acp_type=ACPType.SYNC, + ) + await agent_repo.create(agent) + + task_repo = isolated_repositories["task_repository"] + running = TaskEntity( + id=orm_id(), + name="status-filter-running", + status=TaskStatus.RUNNING, + ) + completed = TaskEntity( + id=orm_id(), + name="status-filter-completed", + status=TaskStatus.COMPLETED, + ) + await task_repo.create(agent_id=agent.id, task=running) + await task_repo.create(agent_id=agent.id, task=completed) + + response = await isolated_client.get("/tasks", params={"status": "RUNNING"}) + assert response.status_code == 200 + ids = {t["id"] for t in response.json()} + assert running.id in ids + assert completed.id not in ids + + async def test_list_tasks_rejects_invalid_status(self, isolated_client): + """Invalid status enum value should yield a 422.""" + response = await isolated_client.get("/tasks", params={"status": "BOGUS"}) + assert response.status_code == 422 + + async def test_list_tasks_rejects_deleted_status(self, isolated_client): + """status=DELETED is contradictory with the always-on DELETED exclusion; + rejecting at the route avoids silently returning an empty list.""" + response = await isolated_client.get("/tasks", params={"status": "DELETED"}) + assert response.status_code == 400 + # async def test_get_task_by_id_returns_correct_task( self, isolated_client, test_task diff --git a/agentex/tests/unit/repositories/test_task_repository.py b/agentex/tests/unit/repositories/test_task_repository.py index eab1b0f1..18834b34 100644 --- a/agentex/tests/unit/repositories/test_task_repository.py +++ b/agentex/tests/unit/repositories/test_task_repository.py @@ -1129,3 +1129,83 @@ async def test_list_with_join(postgres_url): order_direction="asc", ) assert len(all_tasks_result) == 3 # all 3 tasks should be returned + + +@pytest.mark.asyncio +@pytest.mark.unit +async def test_list_with_join_filters_by_task_metadata(postgres_url): + """list_with_join should filter rows by JSONB containment on task_metadata.""" + + sqlalchemy_asyncpg_url = postgres_url.replace( + "postgresql+psycopg2://", "postgresql+asyncpg://" + ) + + for attempt in range(10): + try: + engine = create_async_engine(sqlalchemy_asyncpg_url, echo=True) + async with engine.begin() as conn: + await conn.run_sync(BaseORM.metadata.create_all) + await conn.execute(text("SELECT 1")) + break + except Exception as e: + if attempt < 9: + print( + f"Database not ready (attempt {attempt + 1}), retrying... Error: {e}" + ) + await asyncio.sleep(2) + continue + raise + + async_session_maker = async_sessionmaker(engine, expire_on_commit=False) + + task_repo = TaskRepository(async_session_maker, async_session_maker) + agent_repo = AgentRepository(async_session_maker, async_session_maker) + + unique_suffix = orm_id()[:8] + agent = AgentEntity( + id=orm_id(), + name=f"metadata-filter-agent-{unique_suffix}", + description="agent for metadata containment filter test", + docker_image="test/agent:latest", + status=AgentStatus.READY, + acp_url="http://localhost:8000/acp", + acp_type=ACPType.ASYNC, + ) + await agent_repo.create(agent) + + user_a_task = await task_repo.create( + agent.id, + TaskEntity( + id=orm_id(), + name=f"user-a-task-{unique_suffix}", + status=TaskStatus.RUNNING, + task_metadata={"created_by_user_id": "user-a", "other": "field"}, + ), + ) + user_b_task = await task_repo.create( + agent.id, + TaskEntity( + id=orm_id(), + name=f"user-b-task-{unique_suffix}", + status=TaskStatus.RUNNING, + task_metadata={"created_by_user_id": "user-b"}, + ), + ) + no_meta_task = await task_repo.create( + agent.id, + TaskEntity( + id=orm_id(), + name=f"no-meta-task-{unique_suffix}", + status=TaskStatus.RUNNING, + task_metadata=None, + ), + ) + + results = await task_repo.list_with_join( + task_metadata={"created_by_user_id": "user-a"}, + ) + + result_ids = {t.id for t in results} + assert user_a_task.id in result_ids + assert user_b_task.id not in result_ids + assert no_meta_task.id not in result_ids diff --git a/agentex/tests/unit/services/test_agent_acp_service.py b/agentex/tests/unit/services/test_agent_acp_service.py index 8074036a..c354b1b0 100644 --- a/agentex/tests/unit/services/test_agent_acp_service.py +++ b/agentex/tests/unit/services/test_agent_acp_service.py @@ -688,3 +688,197 @@ async def test_parse_task_message_update_invalid_type(self, agent_acp_service): agent_acp_service._parse_task_message_update(invalid_result) assert "Unknown update type" in str(exc_info.value) + + +class _AsyncStreamMock: + """Minimal async iterator for mocking HttpxGateway.stream_call.""" + + def __init__(self, responses): + self.responses = list(responses) + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.responses): + raise StopAsyncIteration + response = self.responses[self.index] + self.index += 1 + return response + + +@pytest.fixture +def task_with_metadata(): + """Task carrying caller-side metadata that is forwarded as-is to the agent.""" + return TaskEntity( + id=str(uuid4()), + name="task-with-meta", + status=TaskStatus.RUNNING, + status_reason="Test", + task_metadata={"created_by_user_id": "user-value"}, + ) + + +@pytest.mark.asyncio +@pytest.mark.unit +class TestACPPayloadForwardsTaskMetadata: + """task_metadata is forwarded to the agent unchanged. + + Pre-existing agents may rely on reading task_metadata that callers set via + PUT /tasks/{id}, so we keep the pass-through behaviour for backward + compatibility. + """ + + async def test_create_task_payload_forwards_metadata( + self, + agent_acp_service, + mock_http_gateway, + agent_repository, + sample_agent, + task_with_metadata, + ): + await create_or_get_agent(agent_repository, sample_agent) + mock_http_gateway.async_call.return_value = { + "jsonrpc": "2.0", + "result": {"status": "created", "task_id": task_with_metadata.id}, + "id": f"AgentRPCMethod.TASK_CREATE-{task_with_metadata.id}", + } + + await agent_acp_service.create_task( + agent=sample_agent, + task=task_with_metadata, + acp_url="http://test-acp.example.com", + ) + + payload = mock_http_gateway.async_call.call_args[1]["payload"] + assert payload["params"]["task"]["task_metadata"] == { + "created_by_user_id": "user-value" + } + + async def test_send_message_payload_forwards_metadata( + self, + agent_acp_service, + mock_http_gateway, + agent_repository, + sample_agent, + task_with_metadata, + sample_text_content, + ): + await create_or_get_agent(agent_repository, sample_agent) + mock_http_gateway.async_call.return_value = { + "jsonrpc": "2.0", + "result": { + "type": "text", + "author": "agent", + "style": "static", + "format": "plain", + "content": "ok", + "attachments": None, + }, + "id": f"AgentRPCMethod.MESSAGE_SEND-{task_with_metadata.id}", + } + + await agent_acp_service.send_message( + agent=sample_agent, + task=task_with_metadata, + content=sample_text_content, + acp_url="http://test-acp.example.com", + ) + + payload = mock_http_gateway.async_call.call_args[1]["payload"] + assert payload["params"]["task"]["task_metadata"] == { + "created_by_user_id": "user-value" + } + + async def test_send_message_stream_payload_forwards_metadata( + self, + agent_acp_service, + mock_http_gateway, + agent_repository, + sample_agent, + task_with_metadata, + sample_text_content, + ): + await create_or_get_agent(agent_repository, sample_agent) + + captured = {} + + def fake_stream_call(*args, **kwargs): + captured["payload"] = kwargs.get("payload") + return _AsyncStreamMock([]) + + mock_http_gateway.stream_call = fake_stream_call + + async for _ in agent_acp_service.send_message_stream( + agent=sample_agent, + task=task_with_metadata, + content=sample_text_content, + acp_url="http://test-acp.example.com", + ): + pass + + payload = captured["payload"] + assert payload["params"]["task"]["task_metadata"] == { + "created_by_user_id": "user-value" + } + + async def test_cancel_task_payload_forwards_metadata( + self, + agent_acp_service, + mock_http_gateway, + agent_repository, + sample_agent, + task_with_metadata, + ): + await create_or_get_agent(agent_repository, sample_agent) + mock_http_gateway.async_call.return_value = { + "jsonrpc": "2.0", + "result": {"status": "cancelled", "task_id": task_with_metadata.id}, + "id": f"AgentRPCMethod.TASK_CANCEL-{task_with_metadata.id}", + } + + await agent_acp_service.cancel_task( + agent=sample_agent, + task=task_with_metadata, + acp_url="http://test-acp.example.com", + ) + + payload = mock_http_gateway.async_call.call_args[1]["payload"] + assert payload["params"]["task"]["task_metadata"] == { + "created_by_user_id": "user-value" + } + + async def test_send_event_payload_forwards_metadata( + self, + agent_acp_service, + mock_http_gateway, + agent_repository, + sample_agent, + task_with_metadata, + ): + await create_or_get_agent(agent_repository, sample_agent) + event = EventEntity( + id=str(uuid4()), + task_id=task_with_metadata.id, + agent_id=sample_agent.id, + sequence_id=1, + content=TextContent(content="evt", author=MessageAuthor.AGENT), + ) + mock_http_gateway.async_call.return_value = { + "jsonrpc": "2.0", + "result": {"status": "event_sent", "event_id": event.id}, + "id": f"AgentRPCMethod.EVENT_SEND-{task_with_metadata.id}", + } + + await agent_acp_service.send_event( + agent=sample_agent, + event=event, + task=task_with_metadata, + acp_url="http://test-acp.example.com", + ) + + payload = mock_http_gateway.async_call.call_args[1]["payload"] + assert payload["params"]["task"]["task_metadata"] == { + "created_by_user_id": "user-value" + } diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index 08e31bfc..edde7ebd 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -933,6 +933,100 @@ async def mock_async_call(*args, **kwargs): # Verify HTTP call was made mock_http_gateway.async_call.assert_called_once() + async def test_handle_task_create_persists_task_metadata( + self, agents_acp_use_case, mock_http_gateway, agent_repository, sample_agent + ): + """task_metadata on CreateTaskRequest is persisted on the row and forwarded to ACP.""" + await create_or_get_agent(agent_repository, sample_agent) + + from src.api.schemas.agents_rpc import CreateTaskRequest + + async def mock_async_call(*args, **kwargs): + payload = kwargs.get("payload", {}) + return { + "jsonrpc": "2.0", + "result": {"status": "created", "task_id": "new-task-id"}, + "id": payload.get("id", ""), + } + + mock_http_gateway.async_call.side_effect = mock_async_call + + import uuid + + unique_task_name = f"test-task-meta-{uuid.uuid4().hex[:8]}" + create_request = CreateTaskRequest( + name=unique_task_name, + params={"param1": "value1"}, + task_metadata={"created_by_user_id": "user-a"}, + ) + + result = await agents_acp_use_case._handle_task_create( + agent=sample_agent, + params=create_request, + acp_url=sample_agent.acp_url, + ) + + assert result.task_metadata == {"created_by_user_id": "user-a"} + + # task_metadata is forwarded to the agent (kept for backward compat with + # agents that already read metadata set via PUT /tasks/{id}). + mock_http_gateway.async_call.assert_called_once() + sent_payload = mock_http_gateway.async_call.call_args.kwargs["payload"] + assert sent_payload["params"]["task"]["task_metadata"] == { + "created_by_user_id": "user-a" + } + + async def test_handle_task_create_ignores_task_metadata_for_existing_task( + self, agents_acp_use_case, mock_http_gateway, agent_repository, sample_agent + ): + """task_metadata is only stamped at creation; a re-issued task/create with the + same name must not overwrite the existing row's metadata. Update-via-PUT is the + supported way to mutate metadata after creation.""" + await create_or_get_agent(agent_repository, sample_agent) + + from src.api.schemas.agents_rpc import CreateTaskRequest + + async def mock_async_call(*args, **kwargs): + payload = kwargs.get("payload", {}) + return { + "jsonrpc": "2.0", + "result": {"status": "created", "task_id": "x"}, + "id": payload.get("id", ""), + } + + mock_http_gateway.async_call.side_effect = mock_async_call + + import uuid + + unique_task_name = f"test-task-existing-{uuid.uuid4().hex[:8]}" + + # First create — metadata is stamped. + first_request = CreateTaskRequest( + name=unique_task_name, + params={"param1": "v1"}, + task_metadata={"created_by_user_id": "user-a"}, + ) + first = await agents_acp_use_case._handle_task_create( + agent=sample_agent, + params=first_request, + acp_url=sample_agent.acp_url, + ) + assert first.task_metadata == {"created_by_user_id": "user-a"} + + # Second create with same name — metadata in the request must be ignored. + second_request = CreateTaskRequest( + name=unique_task_name, + params={"param1": "v1"}, + task_metadata={"created_by_user_id": "user-b-attacker"}, + ) + second = await agents_acp_use_case._handle_task_create( + agent=sample_agent, + params=second_request, + acp_url=sample_agent.acp_url, + ) + assert second.id == first.id + assert second.task_metadata == {"created_by_user_id": "user-a"} + # async def test_handle_message_send_sync_error_handling( self, diff --git a/agentex/tests/unit/use_cases/test_tasks_use_case.py b/agentex/tests/unit/use_cases/test_tasks_use_case.py index 3de88eea..5703de32 100644 --- a/agentex/tests/unit/use_cases/test_tasks_use_case.py +++ b/agentex/tests/unit/use_cases/test_tasks_use_case.py @@ -455,3 +455,40 @@ async def test_update_metadata_requires_id_or_name(self, tasks_use_case): await tasks_use_case.update_mutable_fields_on_task( task_metadata={"key": "value"} ) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTasksUseCaseListTasks: + """Test suite for list_tasks filtering""" + + async def test_list_tasks_forwards_task_metadata_filter( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """list_tasks should forward task_metadata filter to the service/repository.""" + await create_or_get_agent(agent_repository, sample_agent) + + suffix = uuid4().hex[:8] + matching = await task_service.create_task( + agent=sample_agent, task_name=f"match-{suffix}" + ) + other = await task_service.create_task( + agent=sample_agent, task_name=f"other-{suffix}" + ) + + await tasks_use_case.update_mutable_fields_on_task( + id=matching.id, task_metadata={"created_by_user_id": "user-a"} + ) + await tasks_use_case.update_mutable_fields_on_task( + id=other.id, task_metadata={"created_by_user_id": "user-b"} + ) + + results = await tasks_use_case.list_tasks( + limit=100, + page_number=1, + task_metadata={"created_by_user_id": "user-a"}, + ) + + result_ids = {t.id for t in results} + assert matching.id in result_ids + assert other.id not in result_ids