Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/infrastructure/deepagent/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import AsyncIterator

from langgraph.types import Command
from pydantic import BaseModel

from src.domain.entities.message import Message, MessageRole, MessageStatus
from src.domain.entities.stream_event import StreamEvent, StreamEventType
Expand All @@ -16,9 +17,10 @@


class DeepAgentRunner(AgentRunner):
def __init__(self, graph, tracing_provider: TracingProvider | None = None):
def __init__(
    self,
    graph,
    tracing_provider: TracingProvider | None = None,
    response_format_model: type[BaseModel] | None = None,
):
    """Remember the graph and the optional collaborators of this runner.

    Args:
        graph: Agent graph this runner operates on.
        tracing_provider: Optional tracing provider.
        response_format_model: Optional Pydantic model used to validate
            ``structured_response`` payloads; ``None`` disables validation.
    """
    self._response_format_model = response_format_model
    self._tracing_provider = tracing_provider
    self._graph = graph

@staticmethod
def _try_parse_json(content: str) -> dict | None:
Expand All @@ -40,6 +42,31 @@ def _try_parse_json(content: str) -> dict | None:
pass
return None

def _validate_structured_response(self, data: dict) -> dict:
    """Validate ``data`` against the configured response_format model.

    The payload is round-tripped through ``model_validate`` and
    ``model_dump``; for models configured with ``extra='ignore'`` this drops
    every field the model does not define. Dropped fields are reported
    through ``_log_extra_fields``.

    Validation is best-effort: when no model is configured, or the payload
    cannot be validated, ``data`` is returned unchanged.

    Args:
        data: Structured response produced by the agent.

    Returns:
        The cleaned payload, or ``data`` itself when validation is skipped
        or fails.
    """
    model = self._response_format_model
    if model is None:
        return data

    try:
        cleaned = model.model_validate(data).model_dump()
    except Exception:
        # Deliberate best-effort boundary: a schema mismatch must not lose
        # the agent's answer, but the reason is recorded for diagnosis.
        logger.warning(
            "Failed to validate structured_response against schema, returning raw data",
            exc_info=True,
        )
        return data

    self._log_extra_fields(data, cleaned)
    return cleaned

@staticmethod
def _log_extra_fields(original: dict, cleaned: dict) -> None:
    """Log every field present in ``original`` but missing from ``cleaned``.

    Nested dicts and lists of dicts are compared at any depth, so a field
    removed several levels down is reported with its dotted path; list
    items appear as ``key[index]``.

    Args:
        original: Payload before validation.
        cleaned: The same payload after validation.
    """

    def walk(before: object, after: object, path: str) -> None:
        if isinstance(before, dict) and isinstance(after, dict):
            for key, value in before.items():
                if key in after:
                    walk(value, after[key], f"{path}.{key}" if path else str(key))
                elif path:
                    logger.warning("Stripped extra nested field: '%s.%s'", path, key)
                else:
                    logger.warning("Stripped extra field from structured_response: '%s'", key)
        elif isinstance(before, list) and isinstance(after, list):
            # Items are paired by position; zip stops at the shorter list.
            for index, (old_item, new_item) in enumerate(zip(before, after)):
                walk(old_item, new_item, f"{path}[{index}]")

    walk(original, cleaned, "")

@staticmethod
def _is_nonblank_str(val: object) -> bool:
    """Return True only for a string that is not empty after stripping whitespace."""
    if not isinstance(val, str):
        return False
    return val.strip() != ""
Expand Down Expand Up @@ -114,6 +141,10 @@ def _build_response(self, result: dict, config: dict, thinking: str | None) -> M
if structured_response is None:
structured_response = self._try_parse_json(last_message.content)

# 4. Validate against response_format schema (strip extra fields)
if structured_response is not None and self._response_format_model is not None:
structured_response = self._validate_structured_response(structured_response)

return Message(
role=MessageRole.AI,
content=last_message.content,
Expand Down
14 changes: 6 additions & 8 deletions src/infrastructure/deepagent/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from src.domain.entities.agent_config import AgentConfig, BackendType
from src.domain.ports.mcp_tool_loader import McpToolLoader
from src.domain.ports.prompt_manager import PromptManager
from src.infrastructure.deepagent.schema_utils import make_validation_model

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -193,7 +194,7 @@ async def create_agent_from_config(
mcp_tool_loader: Optional MCP tool loader for loading remote tools.

Returns:
The compiled agent ready for execution.
Tuple of (compiled agent graph, response_format_model or None).
"""
logger.info("Creating agent '%s' (model=%s)", config.name, config.model)
checkpointer = MemorySaver()
Expand Down Expand Up @@ -238,17 +239,14 @@ async def create_agent_from_config(
kwargs["skills"] = config.skills

if config.response_format:
# Use tool-based structured output instead of ProviderStrategy to bypass
# Bedrock schema limitations (max 16 anyOf, max 24 optionals, max grammar size).
# The structured_response tool is injected into the agent's tool list so the LLM
# sees it, but we do NOT pass response_format to avoid create_agent forcing
# tool_choice="any", which suppresses intermediate streaming messages.
# The system prompt is augmented with an instruction to use the tool.
response_format_model = make_validation_model(config.response_format)
response_tool = _create_response_tool(config.response_format)
all_tools = (all_tools or []) + [response_tool]
kwargs["tools"] = all_tools
current_prompt = kwargs.get("system_prompt", "")
kwargs["system_prompt"] = (current_prompt or "") + STRUCTURED_OUTPUT_INSTRUCTION
else:
response_format_model = None

subagents = await _resolve_subagents(config, mcp_tool_loader, prompt_manager)
if subagents:
Expand All @@ -260,7 +258,7 @@ async def create_agent_from_config(
logger.error(f"Error creating agent '{config.name}': {e}")
raise
logger.info("Agent '%s' created successfully", config.name)
return graph
return graph, response_format_model


# helper to get system_prompt from Phoenix
Expand Down
101 changes: 101 additions & 0 deletions src/infrastructure/deepagent/schema_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import hashlib
import json
from typing import Any

from pydantic import BaseModel, ConfigDict, Field, create_model


def schema_to_pydantic_model(schema: dict[str, Any], model_name: str = "DynamicModel") -> type[BaseModel]:
    """Build a Pydantic model class from a JSON Schema dict.

    Public entry point; construction is delegated to ``_build_model``.
    Object schemas yield models configured with ``extra='ignore'`` (nested
    object properties included), so keys the schema does not declare are
    dropped when ``model_validate`` is called.

    Args:
        schema: JSON Schema describing the expected payload.
        model_name: Name given to the generated top-level model.

    Returns:
        Whatever ``_build_model`` produces for ``schema``.
    """
    built = _build_model(schema, model_name)
    return built


def _build_model(schema: dict[str, Any], name: str) -> type:
    """Map a JSON Schema node to the Python type used to validate it.

    Object schemas are handed to ``_build_object_model``, array schemas to
    ``_build_array_model``, and primitive schemas map to the matching
    builtin. A missing ``type`` is treated as ``"object"``; an unrecognised
    one falls back to ``str``.

    JSON Schema also allows ``type`` to be a list such as
    ``["string", "null"]``; the first entry other than ``"null"`` decides
    which branch is taken.

    Args:
        schema: JSON Schema node.
        name: Name for the generated model, when one is created.

    Returns:
        A model class, a ``list[...]`` alias, or a builtin type.
    """
    schema_type = schema.get("type", "object")
    if isinstance(schema_type, (list, tuple)):
        # A list is unhashable and cannot be used as a lookup key below,
        # so reduce it to a single type name first.
        schema_type = next((entry for entry in schema_type if entry != "null"), "string")

    if schema_type == "object":
        return _build_object_model(schema, name)
    if schema_type == "array":
        return _build_array_model(schema, name)

    primitives = {"string": str, "number": float, "integer": int, "boolean": bool}
    if not isinstance(schema_type, str):
        return str
    return primitives.get(schema_type, str)


def _build_object_model(schema: dict[str, Any], name: str) -> type[BaseModel]:
    """Create a Pydantic model with ``extra='ignore'`` for an object schema.

    Each entry of ``properties`` becomes one field:

    * ``object`` properties become nested models named ``{name}_{Prop}``.
    * ``array`` properties become ``list[...]``; object items get a nested
      model named ``{name}_{Prop}Item`` and nested arrays are resolved
      recursively.
    * Any other type maps to ``str``/``float``/``int``/``bool``, with ``str``
      as the fallback for unknown types. A property without ``type`` is
      treated as a string.

    Properties listed in ``required`` get no ``None`` default, except that
    required arrays default to an empty list. All other properties are
    optional and default to ``None``. When ``type`` is a list containing
    ``"null"`` the field also accepts ``None``.

    A schema without ``properties`` yields a model with no fields, so every
    key of such an object is dropped on validation.

    Args:
        schema: JSON Schema node of type ``object``.
        name: Name of the generated model; also the prefix for nested models.

    Returns:
        The generated model class.
    """
    primitives = {"string": str, "number": float, "integer": int, "boolean": bool}

    def split_type(node: dict[str, Any]) -> tuple[Any, bool]:
        """Return ``(type name, nullable)`` for a schema node."""
        declared = node.get("type", "string")
        if isinstance(declared, (list, tuple)):
            # JSON Schema permits e.g. ["string", "null"]; a list cannot be
            # used as a dict key, so pick the first concrete type name.
            names = [entry for entry in declared if entry != "null"]
            return (names[0] if names else "string"), len(names) != len(declared)
        return declared, False

    def resolve(node: dict[str, Any], nested_name: str) -> Any:
        """Return the Python annotation that validates ``node``."""
        kind, _ = split_type(node)
        if kind == "object":
            return _build_object_model(node, nested_name)
        if kind == "array":
            return list[resolve(node.get("items", {}), f"{nested_name}Item")]
        return primitives.get(kind, str) if isinstance(kind, str) else str

    properties = schema.get("properties") or {}
    required = set(schema.get("required") or [])

    field_defs: dict[str, Any] = {}
    for prop_name, prop_schema in properties.items():
        # NOTE(review): fields are registered under the sanitized name with no
        # alias, so a property such as "first-name" is validated and dumped
        # as "first_name" - confirm this is intended.
        python_name = _sanitize(prop_name)
        kind, nullable = split_type(prop_schema)
        annotation = resolve(prop_schema, f"{name}_{python_name.capitalize()}")
        description = prop_schema.get("description", "")

        if prop_name not in required:
            field_defs[python_name] = (annotation | None, Field(default=None, description=description))
            continue

        if nullable:
            annotation = annotation | None
        if kind == "array":
            # Required arrays fall back to an empty list when omitted.
            field_info = Field(default_factory=list, description=description)
        else:
            field_info = Field(description=description)
        field_defs[python_name] = (annotation, field_info)

    return create_model(name, __config__=ConfigDict(extra="ignore"), **field_defs)


def _build_array_model(schema: dict[str, Any], name: str) -> type:
    """Return the ``list[...]`` type that validates an array schema.

    Object items get a nested model named ``{name}Item``; any other item
    type maps to the matching builtin, with ``str`` used when the type is
    missing or unrecognised.

    Args:
        schema: JSON Schema node of type ``array``.
        name: Prefix for the nested item model, when one is created.

    Returns:
        A ``list[...]`` alias parameterised with the item type.
    """
    item_schema = schema.get("items", {})
    item_kind = item_schema.get("type", "string")

    if item_kind != "object":
        scalar_types = {"string": str, "number": float, "integer": int, "boolean": bool}
        element = scalar_types.get(item_kind, str)
    else:
        element = _build_object_model(item_schema, f"{name}Item")
    return list[element]


def _sanitize(name: str) -> str:
    """Turn a JSON property name into a usable field name.

    Hyphens and spaces become underscores, and a name that is empty or
    starts with a digit is prefixed with ``field_``.

    Args:
        name: Property name as written in the schema.

    Returns:
        The sanitized field name; never an empty string.
    """
    sanitized = name.replace("-", "_").replace(" ", "_")
    # Check for emptiness first: indexing an empty string raises IndexError.
    if not sanitized or sanitized[0].isdigit():
        sanitized = f"field_{sanitized}"
    return sanitized


def make_validation_model(response_format: dict[str, Any]) -> type[BaseModel]:
    """Build the validation model for an agent's ``response_format`` schema.

    The model is named ``ResponseFormat_<hash>``, where ``<hash>`` is the
    first eight hex digits of the SHA-256 of the schema serialised as JSON
    with sorted keys. Construction is delegated to
    ``schema_to_pydantic_model``, whose models use ``extra='ignore'`` and
    therefore drop undeclared fields on ``model_validate``.

    Args:
        response_format: JSON Schema taken from the agent configuration.

    Returns:
        The generated model class.
    """
    canonical_json = json.dumps(response_format, sort_keys=True)
    digest = hashlib.sha256(canonical_json.encode()).hexdigest()
    model_name = f"ResponseFormat_{digest[:8]}"
    return schema_to_pydantic_model(response_format, model_name)
4 changes: 2 additions & 2 deletions src/infrastructure/persistent_registry/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ async def get_runner(self, agent_name: str) -> AgentRunner:
logger.info("Building agent '%s' from persistent store", agent_name)
yaml_content = await self._config_store.get(agent_name)
config = self._config_loader.load_from_string(yaml_content)
graph = await create_agent_from_config(config, self._mcp_tool_loader, self._prompt_manager)
runner = DeepAgentRunner(graph, tracing_provider=self._tracing_provider)
graph, response_format_model = await create_agent_from_config(config, self._mcp_tool_loader, self._prompt_manager)
runner = DeepAgentRunner(graph, tracing_provider=self._tracing_provider, response_format_model=response_format_model)
self._runners[agent_name] = runner
logger.info("Agent '%s' ready and cached", agent_name)
return runner
Expand Down
120 changes: 120 additions & 0 deletions tests/unit/test_deep_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,123 @@ async def test_build_response_no_structured_response(self):
result = await runner.invoke("thread-1", "hi")

assert result.structured_response is None

# --- Post-validation tests ---

async def test_validate_structured_response_strips_extra_top_level_fields(self):
    """Top-level keys the schema does not declare are absent from the result."""
    from src.infrastructure.deepagent.schema_utils import make_validation_model

    response_model = make_validation_model(
        {
            "type": "object",
            "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
            "required": ["name"],
        }
    )
    final_message = MagicMock()
    final_message.content = "Result"
    final_message.tool_calls = None
    # Two of these keys are not part of the schema above.
    llm_payload = {"name": "Alice", "age": 30, "terraceArea": 50, "parkingSpaces": 2}
    graph = _make_graph([final_message])
    graph.ainvoke.return_value = {"messages": [final_message], "structured_response": llm_payload}

    runner = DeepAgentRunner(graph, response_format_model=response_model)
    result = await runner.invoke("thread-1", "analyze")

    cleaned = result.structured_response
    assert cleaned == {"name": "Alice", "age": 30}
    for invented_key in ("terraceArea", "parkingSpaces"):
        assert invented_key not in cleaned

async def test_validate_structured_response_strips_nested_extra_fields(self):
    """Keys undeclared inside a nested object are absent from the result."""
    from src.infrastructure.deepagent.schema_utils import make_validation_model

    building_schema = {
        "type": "object",
        "properties": {"floors": {"type": "integer"}},
        "required": ["floors"],
    }
    response_model = make_validation_model(
        {
            "type": "object",
            "properties": {"building": building_schema},
            "required": ["building"],
        }
    )
    final_message = MagicMock()
    final_message.content = "Result"
    final_message.tool_calls = None
    graph = _make_graph([final_message])
    graph.ainvoke.return_value = {
        "messages": [final_message],
        # "rooftop" is not declared by the nested schema.
        "structured_response": {"building": {"floors": 3, "rooftop": True}},
    }

    runner = DeepAgentRunner(graph, response_format_model=response_model)
    result = await runner.invoke("thread-1", "analyze")

    cleaned = result.structured_response
    assert cleaned == {"building": {"floors": 3}}
    assert "rooftop" not in cleaned["building"]

async def test_validate_structured_response_no_model_returns_raw(self):
    """Without a response_format_model the payload is returned untouched."""
    final_message = MagicMock()
    final_message.content = "Result"
    final_message.tool_calls = None
    graph = _make_graph([final_message])
    graph.ainvoke.return_value = {
        "messages": [final_message],
        "structured_response": {"name": "test", "extra": True},
    }

    runner = DeepAgentRunner(graph, response_format_model=None)
    result = await runner.invoke("thread-1", "analyze")

    expected = {"name": "test", "extra": True}
    assert result.structured_response == expected

def test_log_extra_fields_logs_top_level(self, caplog):
    """A top-level key missing from the cleaned dict is named in the log."""
    import logging

    raw_payload = {"name": "a", "invented": 1}
    kept_payload = {"name": "a"}
    with caplog.at_level(logging.WARNING):
        DeepAgentRunner._log_extra_fields(raw_payload, kept_payload)

    assert "invented" in caplog.text

def test_log_extra_fields_logs_nested(self, caplog):
    """A nested key missing from the cleaned dict is logged with its parent."""
    import logging

    raw_payload = {"building": {"floors": 3, "bogus": 1}}
    kept_payload = {"building": {"floors": 3}}
    with caplog.at_level(logging.WARNING):
        DeepAgentRunner._log_extra_fields(raw_payload, kept_payload)

    assert "building.bogus" in caplog.text

async def test_validate_structured_response_from_tool_call(self):
    """Arguments taken from a structured_response tool call are validated too."""
    from src.infrastructure.deepagent.schema_utils import make_validation_model

    response_model = make_validation_model(
        {
            "type": "object",
            "properties": {"summary": {"type": "string"}},
            "required": ["summary"],
        }
    )
    # "hallucinated" is not declared by the schema above.
    tool_call = {"name": "structured_response", "args": {"summary": "ok", "hallucinated": 99}, "id": "tc-1"}
    ai_msg = MagicMock()
    ai_msg.content = "Done"
    ai_msg.tool_calls = [tool_call]
    graph = _make_graph([ai_msg])

    runner = DeepAgentRunner(graph, response_format_model=response_model)
    result = await runner.invoke("thread-1", "summarize")

    cleaned = result.structured_response
    assert cleaned == {"summary": "ok"}
    assert "hallucinated" not in cleaned
8 changes: 4 additions & 4 deletions tests/unit/test_persistent_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def test_get_runner_loads_from_store(self, mock_runner_cls, mock_create, r
"""get_runner should fetch YAML from MinIO, parse it, create agent, cache runner."""
mock_store.get.return_value = VALID_YAML
mock_graph = MagicMock()
mock_create.return_value = mock_graph
mock_create.return_value = (mock_graph, None)
mock_runner_instance = MagicMock()
mock_runner_cls.return_value = mock_runner_instance

Expand All @@ -85,7 +85,7 @@ async def test_get_runner_loads_from_store(self, mock_runner_cls, mock_create, r
async def test_get_runner_cache_hit(self, mock_runner_cls, mock_create, registry, mock_store):
"""Second call should return cached runner without fetching from store again."""
mock_store.get.return_value = VALID_YAML
mock_create.return_value = MagicMock()
mock_create.return_value = (MagicMock(), None)
mock_runner_instance = MagicMock()
mock_runner_cls.return_value = mock_runner_instance

Expand Down Expand Up @@ -124,7 +124,7 @@ async def test_list_agents_queries_repository(self, registry, mock_repository):
async def test_invalidate_clears_cache(self, mock_runner_cls, mock_create, registry, mock_store):
"""After invalidate, next get_runner should re-fetch from store."""
mock_store.get.return_value = VALID_YAML
mock_create.return_value = MagicMock()
mock_create.return_value = (MagicMock(), None)
runner_a = MagicMock()
runner_b = MagicMock()
mock_runner_cls.side_effect = [runner_a, runner_b]
Expand All @@ -148,7 +148,7 @@ async def test_invalidate_clears_cache(self, mock_runner_cls, mock_create, regis
async def test_close_clears_all_runners(self, mock_runner_cls, mock_create, registry, mock_store):
"""close should empty the runners cache."""
mock_store.get.return_value = VALID_YAML
mock_create.return_value = MagicMock()
mock_create.return_value = (MagicMock(), None)
mock_runner_cls.return_value = MagicMock()

await registry.get_runner("test-agent")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _wire_dependencies(
patch(
"src.infrastructure.persistent_registry.adapter.create_agent_from_config",
new_callable=AsyncMock,
return_value=MagicMock(),
return_value=(MagicMock(), None),
),
patch(
"src.infrastructure.persistent_registry.adapter.DeepAgentRunner",
Expand Down
Loading
Loading