From 5612f3226ef6c58cc2e53a1ae102e2e65144d5ca Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 19:31:41 -0700 Subject: [PATCH 01/12] latest azure-ai-agentserver-core --- sdk/agentserver/AGENTS.md | 214 +++++ sdk/agentserver/CLAUDE.md | 1 + sdk/agentserver/PLANNING.md | 100 +++ .../azure-ai-agentserver-core/CHANGELOG.md | 129 +++ .../azure-ai-agentserver-core/README.md | 2 +- .../azure/ai/agentserver/core/__init__.py | 7 +- .../azure/ai/agentserver/core/_version.py | 2 +- .../agentserver/core/application/__init__.py | 21 + .../agentserver/core/application/_builder.py | 5 + .../core/application/_configuration.py | 42 + .../agentserver/core/application/_metadata.py | 98 +++ .../agentserver/core/application/_options.py | 45 ++ .../agentserver/core/checkpoints/__init__.py | 18 + .../core/checkpoints/client/__init__.py | 6 + .../core/checkpoints/client/_client.py | 158 ++++ .../core/checkpoints/client/_configuration.py | 37 + .../core/checkpoints/client/_models.py | 201 +++++ .../checkpoints/client/operations/__init__.py | 12 + .../checkpoints/client/operations/_items.py | 198 +++++ .../client/operations/_sessions.py | 132 ++++ .../azure/ai/agentserver/core/constants.py | 6 +- .../azure/ai/agentserver/core/logger.py | 179 +++-- .../ai/agentserver/core/models/__init__.py | 3 +- .../core/models/_create_response.py | 8 +- .../models/{openai => _openai}/__init__.py | 0 .../{projects => _projects}/__init__.py | 0 .../models/{projects => _projects}/_enums.py | 0 .../models/{projects => _projects}/_models.py | 0 .../models/{projects => _projects}/_patch.py | 0 .../_patch_evaluations.py | 0 .../_utils/__init__.py | 0 .../_utils/model_base.py | 0 .../_utils/serialization.py | 0 .../azure/ai/agentserver/core/server/_base.py | 732 ++++++++++++++++++ .../ai/agentserver/core/server/_context.py | 32 + .../core/server/_response_metadata.py | 61 ++ .../azure/ai/agentserver/core/server/base.py | 315 -------- ...t_run_context.py => _agent_run_context.py} | 31 +- 
.../core/server/common/_constants.py | 6 + ..._generator.py => _foundry_id_generator.py} | 102 ++- .../{id_generator.py => _id_generator.py} | 3 + .../ai/agentserver/core/tools/__init__.py | 82 ++ .../ai/agentserver/core/tools/_exceptions.py | 74 ++ .../agentserver/core/tools/client/__init__.py | 5 + .../agentserver/core/tools/client/_client.py | 215 +++++ .../core/tools/client/_configuration.py | 35 + .../agentserver/core/tools/client/_models.py | 615 +++++++++++++++ .../core/tools/client/operations/_base.py | 73 ++ .../operations/_foundry_connected_tools.py | 180 +++++ .../operations/_foundry_hosted_mcp_tools.py | 168 ++++ .../core/tools/runtime/__init__.py | 5 + .../core/tools/runtime/_catalog.py | 143 ++++ .../agentserver/core/tools/runtime/_facade.py | 95 +++ .../core/tools/runtime/_invoker.py | 69 ++ .../core/tools/runtime/_resolver.py | 60 ++ .../core/tools/runtime/_runtime.py | 147 ++++ .../core/tools/runtime/_starlette.py | 67 ++ .../agentserver/core/tools/runtime/_user.py | 60 ++ .../agentserver/core/tools/utils/__init__.py | 11 + .../core/tools/utils/_name_resolver.py | 37 + .../ai/agentserver/core/utils/__init__.py | 5 + .../ai/agentserver/core/utils/_credential.py | 100 +++ .../azure-ai-agentserver-core/cspell.json | 6 +- .../azure.ai.agentserver.core.application.rst | 7 + ...ver.core.checkpoints.client.operations.rst | 7 + ...ai.agentserver.core.checkpoints.client.rst | 15 + .../azure.ai.agentserver.core.checkpoints.rst | 15 + .../doc/azure.ai.agentserver.core.models.rst | 8 + .../doc/azure.ai.agentserver.core.rst | 5 + ...server.core.server.common.id_generator.rst | 12 +- ...zure.ai.agentserver.core.server.common.rst | 14 +- .../doc/azure.ai.agentserver.core.server.rst | 6 +- ...azure.ai.agentserver.core.tools.client.rst | 7 + .../doc/azure.ai.agentserver.core.tools.rst | 17 + ...zure.ai.agentserver.core.tools.runtime.rst | 7 + .../azure.ai.agentserver.core.tools.utils.rst | 7 + .../doc/azure.ai.agentserver.core.utils.rst | 7 + 
.../azure-ai-agentserver-core/pyproject.toml | 18 +- .../samples/bilingual_weekend_planner/main.py | 2 +- .../samples/mcp_simple/mcp_simple.py | 2 +- .../custom_mock_agent_test.py | 4 +- .../tests/unit_tests/__init__.py | 5 + .../common/test_foundry_id_generator.py | 27 + .../server/test_conversation_persistence.py | 266 +++++++ .../unit_tests/server/test_otel_context.py | 14 + .../server/test_response_metadata.py | 140 ++++ .../tests/unit_tests/test_logger.py | 144 ++++ .../tests/unit_tests/tools/__init__.py | 4 + .../tests/unit_tests/tools/client/__init__.py | 5 + .../tools/client/operations/__init__.py | 4 + .../test_foundry_connected_tools.py | 479 ++++++++++++ .../test_foundry_hosted_mcp_tools.py | 309 ++++++++ .../unit_tests/tools/client/test_client.py | 485 ++++++++++++ .../tools/client/test_configuration.py | 25 + .../tests/unit_tests/tools/conftest.py | 127 +++ .../unit_tests/tools/runtime/__init__.py | 4 + .../unit_tests/tools/runtime/conftest.py | 39 + .../unit_tests/tools/runtime/test_catalog.py | 350 +++++++++ .../unit_tests/tools/runtime/test_facade.py | 180 +++++ .../unit_tests/tools/runtime/test_invoker.py | 198 +++++ .../unit_tests/tools/runtime/test_resolver.py | 202 +++++ .../unit_tests/tools/runtime/test_runtime.py | 401 ++++++++++ .../tools/runtime/test_starlette.py | 261 +++++++ .../unit_tests/tools/runtime/test_user.py | 210 +++++ .../tests/unit_tests/tools/utils/__init__.py | 4 + .../tests/unit_tests/tools/utils/conftest.py | 56 ++ .../tools/utils/test_name_resolver.py | 260 +++++++ 107 files changed, 9085 insertions(+), 432 deletions(-) create mode 100644 sdk/agentserver/AGENTS.md create mode 100644 sdk/agentserver/CLAUDE.md create mode 100644 sdk/agentserver/PLANNING.md create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_builder.py create mode 100644 
sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_client.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_configuration.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_models.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_items.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_sessions.py rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{openai => _openai}/__init__.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/__init__.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/_enums.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/_models.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/_patch.py (100%) rename 
sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/_patch_evaluations.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/_utils/__init__.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/_utils/model_base.py (100%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/{projects => _projects}/_utils/serialization.py (100%) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_context.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_response_metadata.py delete mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/{agent_run_context.py => _agent_run_context.py} (71%) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/_constants.py rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/{foundry_id_generator.py => _foundry_id_generator.py} (53%) rename sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/{id_generator.py => _id_generator.py} (87%) create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py create mode 100644 
sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_connected_tools.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_invoker.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py create mode 100644 
sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.operations.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/common/test_foundry_id_generator.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/__init__.py create mode 100644 
sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_connected_tools.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_hosted_mcp_tools.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_client.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_configuration.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/conftest.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/conftest.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_catalog.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_facade.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_invoker.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_resolver.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_runtime.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_starlette.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_user.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/__init__.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/conftest.py create mode 100644 sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/test_name_resolver.py diff --git a/sdk/agentserver/AGENTS.md 
b/sdk/agentserver/AGENTS.md new file mode 100644 index 000000000000..f34b271faf32 --- /dev/null +++ b/sdk/agentserver/AGENTS.md @@ -0,0 +1,214 @@ +# AGENTS.md + +This file provides comprehensive guidance to Coding Agents(Codex, Claude Code, GitHub Copilot, etc.) when working with Python code in this repository. + + +## 🎯 (Read-first) Project Awareness & Context + +- **Always read `PLANNING.md`** at the start of a new conversation to understand the project's architecture, goals, style, and constraints. +- **Check `TASK.md`** before starting a new task. If the task isn't listed, add it with a brief description and today's date. + +## πŸ† Core Development Philosophy + +### KISS (Keep It Simple, Stupid) + +Simplicity should be a key goal in design. Choose straightforward solutions over complex ones whenever possible. Simple solutions are easier to understand, maintain, and debug. + +### YAGNI (You Aren't Gonna Need It) + +Avoid building functionality on speculation. Implement features only when they are needed, not when you anticipate they might be useful in the future. + +### Design Principles + +- **Dependency Inversion**: High-level modules should not depend on low-level modules. Both should depend on abstractions. +- **Open/Closed Principle**: Software entities should be open for extension but closed for modification. +- **Single Responsibility**: Each function, class, and module should have one clear purpose. +- **Fail Fast**: Check for potential errors early and raise exceptions immediately when issues occur. +- **Encapsulation**: Hide internal state and require all interaction to be performed through an object's methods. + +### Implementation Patterns + +- **Type-Safe**: Prefer explicit, type‑safe structures (TypedDicts, dataclasses, enums, unions) over Dict, Any, or other untyped containers, unless there is no viable alternative. 
+- **Explicit validation at boundaries**: Validate inputs and identifiers early; reject malformed descriptors and missing required fields. +- **Separation of resolution and execution**: Keep discovery/selection separate from invocation to allow late binding and interchangeable implementations. +- **Context-aware execution**: Thread request/user context through calls via scoped providers; avoid global mutable state. +- **Cache with safety**: Use bounded TTL caching with concurrency-safe in-flight de-duplication; invalidate on errors. +- **Stable naming and deterministic mapping**: Derive stable, unique names deterministically to avoid collisions. +- **Graceful defaults, loud misconfiguration**: Provide sensible defaults when optional data is missing; raise clear errors when required configuration is absent. +- **Thin integration layers**: Use adapters/middleware to translate between layers without leaking internals. + +## βœ… Work Process (required) + +- **Before coding**: confirm the task in `TASK.md` β†’ **Now**. +- **While working**: Add new sub-tasks or TODOs discovered during development to `TASK.md` under a "Discovered During Work" section. +- **After finishing**: mark the task done immediately, and note what changed (files/areas) in `TASK.md`. +- **Update CHANGELOG.md only when required** by release policy. + +### TASK.md template +```markdown +## Now (active) +- [ ] YYYY-MM-DD β€” + - Scope: + - Exit criteria: + +## Next (queued) +- [ ] YYYY-MM-DD β€” + +## Discovered During Work +- [ ] YYYY-MM-DD β€” + +## Done +- [x] YYYY-MM-DD β€” +``` + +## πŸ“Ž Style & Conventions & Standards + +### File and Function Limits + +- **Never create a file longer than 500 lines of code**. If approaching this limit, refactor by splitting into modules. +- **Functions should be under 50 lines** with a single, clear responsibility. +- **Classes should be under 100 lines** and represent a single concept or entity. 
+- **Organize code into clearly separated modules**, grouped by feature or responsibility. +- **Line length should be max 120 characters**, as enforced by Ruff in `pyproject.toml`. +- **Use the standard repo workflow**: from the package root, run tests and linters via `tox` (for example, `tox -e pytest` or `tox -e pylint`) rather than relying on a custom virtual environment name. +- **Keep modules focused and cohesive**; split by feature responsibility when a file grows large or mixes concerns. +- **Avoid drive-by refactors** unless required by the task. +- **Preserve public API stability** and match existing patterns in the package you touch. + +### Naming Conventions + +- **Variables and functions**: `snake_case` +- **Classes**: `PascalCase` +- **Constants**: `UPPER_SNAKE_CASE` +- **Private attributes/methods**: `_leading_underscore` +- **Type aliases**: `PascalCase` +- **Enum values**: `UPPER_SNAKE_CASE` + +### Project File Conventions + +- **Follow per-package `pyproject.toml`** and `[tool.azure-sdk-build]` settings. +- **Do not edit generated code**. +- **Do not edit files with the header** `Code generated by Microsoft (R) Python Code Generator.` + - `sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/` + - `sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py` + - `sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py` + - `sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py` +- **Do not introduce secrets or credentials.** +- **Do not disable TLS/SSL verification** without explicit approval. +- **Keep logs free of sensitive data.** + +### Python Code Standards + +- **Follow existing package patterns** and Azure SDK for Python guidelines. +- **Type hints and async patterns** should match existing code. +- **Respect Ruff settings** (line length 120, isort rules) in each package `pyproject.toml`. 
+- **Avoid new dependencies** unless necessary and approved. + +### Docstring Standards +Use reStructuredText (reST) / Sphinx Style docstrings for all public functions, classes, and modules: + +```python +def calculate_discount( + price: Decimal, + discount_percent: float, + min_amount: Decimal = Decimal("0.01") +) -> Decimal: + """Calculate the discounted price for a product. + + :param price: The original price of the product. + :type price: Decimal + :param discount_percent: The discount percentage to apply (0-100). + :type discount_percent: float + :param min_amount: The minimum amount after discount (default is 0.01). + :type min_amount: Decimal + :return: The price after applying the discount, not less than min_amount. + :rtype: Decimal + :raises ValueError: If discount_percent is not between 0 and 100. + + Example: + Calculate a 15% discount on a $100 product: + + .. code-block:: python + + calculate_discount(Decimal("100.00"), 15.0) + """ +``` + +### Error Handling Standards + +- Prefer explicit validation at API boundaries and raise errors **as early as possible**. +- Use standard Python exceptions (`ValueError`, `TypeError`, `KeyError`, etc.) when they accurately describe the problem. +- When a domain-specific error is needed, define a clear, documented exception type and reuse it consistently. +- Do **not** silently swallow exceptions. Either handle them meaningfully (with clear recovery behavior) or let them propagate. +- Preserve the original traceback when re-raising (`raise` without arguments) so issues remain diagnosable. +- Fail fast on programmer errors (e.g., inconsistent state, impossible branches) using assertions or explicit exceptions. +- For public APIs, validate user input and return helpful, actionable messages without leaking secrets or internal implementation details. + +#### Exception Best Practices + +- Avoid `except Exception:` and **never** use bare `except:`; always catch the most specific exception type possible. 
+- Keep `try` blocks **small** and focused so that it is clear which statements may raise the handled exception. +- When adding context to an error, use either `raise NewError("message") from exc` or log the context and re-raise with `raise`. +- Do not use exceptions for normal control flow; reserve them for truly exceptional or error conditions. +- When a function can raise non-obvious exceptions, document them in the docstring under a `:raises:` section. +- In asynchronous code, make sure exceptions are not lost in background tasks; gather and handle them explicitly where needed. + +### Logging Standards + +- Use the standard library `logging` module for all diagnostic output; **do not** use `print` in library or service code. +- Create a module-level logger via `logger = logging.getLogger(__name__)` and use it consistently within that module. +- Choose log levels appropriately: + - `logger.debug(...)` for detailed diagnostics and tracing. + - `logger.info(...)` for high-level lifecycle events (startup, shutdown, major state changes). + - `logger.warning(...)` for recoverable issues or unexpected-but-tolerated conditions. + - `logger.error(...)` for failures where the current operation cannot succeed. + - `logger.critical(...)` for unrecoverable conditions affecting process health. +- Never log secrets, credentials, access tokens, full connection strings, or sensitive customer data. +- When logging exceptions, prefer `logger.exception("message")` inside an `except` block so the traceback is included. +- Keep log messages clear and structured (include identifiers like request IDs, resource names, or correlation IDs when available). 
+## ⚠️ Important Notes + +- **NEVER ASSUME OR GUESS** - When in doubt, ask for clarification +- **Always verify file paths and module names** before use +- **Keep this file (`AGENTS.md`) updated** when adding new patterns or dependencies +- **Test your code** - No feature is complete without tests +- **Document your decisions** - Future developers (including yourself) will thank you + +## πŸ“š Documentation & Explainability + +- **Keep samples runnable** and focused on SDK usage. +- **Follow third-party dependency guidance** in `CONTRIBUTING.md`. + +## πŸ› οΈ Development Environment + +- **Use the repo-root virtual environment for validation work**: activate and use `azuresdk-env` from the repository root (`azure-sdk-for-python/azuresdk-env`) when running tests, linters, type checks, samples, or other validation commands. + +### Dev Environment Setup + +1. From the repository root, check for `azure-sdk-for-python/azuresdk-env`. +2. If it exists, activate it. If it does not exist, create it with `python -m venv azuresdk-env` and then activate it. +3. Install shared tooling into that environment: `uv` and `tox`. +4. Install the local AgentServer packages from source: + - `sdk/agentserver/azure-ai-agentserver-core/` + - `sdk/agentserver/azure-ai-agentserver-agentframework/` + - `sdk/agentserver/azure-ai-agentserver-langgraph/` +5. Install the common local test/runtime dependencies: `python-dotenv`, `pytest`, and `pytest-asyncio`. +6. Use that same environment for all subsequent local validation commands. + +### βœ… Testing & Quality Gates (tox) + +Run from a package root (e.g., `sdk/agentserver/azure-ai-agentserver-core`). 
+ +- `tox run -e sphinx -c ../../../eng/tox/tox.ini --root .` + - Docs output: `.tox/sphinx/tmp/dist/site/index.html` +- `tox run -e pylint -c ../../../eng/tox/tox.ini --root .` + - Uses repo `pylintrc`; `next-pylint` uses `eng/pylintrc` +- `tox run -e mypy -c ../../../eng/tox/tox.ini --root .` +- `tox run -e pyright -c ../../../eng/tox/tox.ini --root .` +- `tox run -e verifytypes -c ../../../eng/tox/tox.ini --root .` +- `tox run -e whl -c ../../../eng/tox/tox.ini --root .` +- `tox run -e sdist -c ../../../eng/tox/tox.ini --root .` +- `tox run -e samples -c ../../../eng/tox/tox.ini --root .` (runs all samples) +- `tox run -e apistub -c ../../../eng/tox/tox.ini --root .` + +Check each package `pyproject.toml` under `[tool.azure-sdk-build]` to see which checks are enabled/disabled. diff --git a/sdk/agentserver/CLAUDE.md b/sdk/agentserver/CLAUDE.md new file mode 100644 index 000000000000..43c994c2d361 --- /dev/null +++ b/sdk/agentserver/CLAUDE.md @@ -0,0 +1 @@ +@AGENTS.md diff --git a/sdk/agentserver/PLANNING.md b/sdk/agentserver/PLANNING.md new file mode 100644 index 000000000000..6a65e2925a83 --- /dev/null +++ b/sdk/agentserver/PLANNING.md @@ -0,0 +1,100 @@ +# 🧭 PLANNING.md + +## 🎯 What this project is +AgentServer is a set of Python packages under `sdk/agentserver` that host agents for +Azure AI Foundry. The core package provides the runtime/server, tooling runtime, and +Responses API models, while the adapter packages wrap popular frameworks. The primary +users are SDK consumers who want to run agents locally and deploy them as Foundry-hosted +containers. Work is β€œdone” when adapters faithfully translate framework execution into +Responses API-compatible outputs and the packages pass their expected tests and samples. + +**Behavioral/policy rules live in `AGENTS.md`.** This document is architecture + repo map + doc index. + +## 🎯 Goals / Non-goals +Goals: +- Keep a stable architecture snapshot and repo map for fast onboarding. 
+- Document key request/response flows, including streaming. +- Clarify the development workflow and testing expectations for AgentServer packages. + +Non-goals: +- Detailed API documentation (belongs in package docs and docstrings). +- Per-initiative plans (belong in `TASK.md` or a dedicated plan file). +- Speculative refactors (align with KISS/YAGNI in `AGENTS.md`). + +## 🧩 Architecture (snapshot) +### πŸ—οΈ Project Structure +- **azure-ai-agentserver-core**: Core library + - Runtime/context + - HTTP gateway + - Foundry integrations + - Responses API protocol (current) +- **azure-ai-agentserver-agentframework**: adapters for Agent Framework agents/workflows, + thread and checkpoint persistence. +- **azure-ai-agentserver-langgraph**: adapter and converters for LangGraph agents and + Response API events. + +### Current vs target +- Current: OpenAI Responses API protocol lives in `azure-ai-agentserver-core` alongside + core runtime and HTTP gateway code; framework adapters layer on top. +- Target (planned, not fully implemented): + - Core layer: app/runtime/context, foundry integrations (tools, checkpointing), HTTP gateway + - Protocol layer: Responses API in its own package + - Framework layer: adapters (agentframework, langgraph, other frameworks) + +### Key flows +- Request path: `/runs` or `/responses` β†’ `AgentRunContext` β†’ agent execution β†’ Responses + API payload. +- Streaming path: generator/async generator β†’ SSE event stream. +- Framework adapter path: framework input β†’ converter β†’ Response API output (streaming + or non-streaming). +- Tools path: Foundry tool runtime invoked via core `tools/runtime` APIs. + +## πŸ—ΊοΈ Repo map +- `azure-ai-agentserver-core`: Core library (runtime/context, HTTP gateway, Foundry integrations, + Responses API protocol today). +- `azure-ai-agentserver-agentframework`: Agent Framework adapter. +- `azure-ai-agentserver-langgraph`: LangGraph adapter. 
+- Core runtime and models: `azure-ai-agentserver-core/azure/ai/agentserver/core/` +- Agent Framework adapter: `azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/` +- LangGraph adapter: `azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/` +- Samples: `azure-ai-agentserver-*/samples/` +- Tests: `azure-ai-agentserver-*/tests/` +- Package docs (Sphinx inputs): `azure-ai-agentserver-*/doc/` +- Repo-wide guidance: `CONTRIBUTING.md`, `doc/dev/tests.md`, `doc/eng_sys_checks.md` + +## πŸ“š Doc index +### **Read repo-wide guidance**: +- `CONTRIBUTING.md` +- `doc/dev/tests.md` +- `doc/eng_sys_checks.md` + +### **Read the package READMEs**: + - `sdk/agentserver/azure-ai-agentserver-core/README.md` + - `sdk/agentserver/azure-ai-agentserver-agentframework/README.md` + - `sdk/agentserver/azure-ai-agentserver-langgraph/README.md` + +### β€œIf you need X, look at Y” +- Enable/disable checks for a package β†’ that package `pyproject.toml` β†’ `[tool.azure-sdk-build]` +- How to run tests / live-recorded tests β†’ `doc/dev/tests.md` +- Engineering system checks / gates β†’ `doc/eng_sys_checks.md` +- Adapter conversion behavior β†’ the relevant adapter package + its tests + samples + +## βœ… Testing strategy +- Unit/integration tests live in each package’s `tests/` directory. +- Samples are part of validation via the `samples` tox environment. +- For live/recorded testing patterns, follow `doc/dev/tests.md`. + +## πŸš€ Rollout / migrations +- Preserve public API stability and follow Azure SDK release policy. +- Do not modify generated code (see paths in `AGENTS.md`). +- CI checks are controlled per package in `pyproject.toml` under + `[tool.azure-sdk-build]`. + +## ⚠️ Risks / edge cases +- Streaming event ordering and keep-alive behavior. +- Credential handling (async credentials and adapters). +- Response API schema compatibility across adapters. +- Tool invocation failures and error surfacing. 
+ +## πŸ“Œ Progress +See `TASK.md` for active work items; no checklists here. diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index cfcf2445e256..b312c9d01737 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -1,5 +1,134 @@ # Release History + +## 1.0.0b16 (2026-03-10) + +### Other Changes + +- Upgraded to support agent-framework v1.0.0rc3. + + +## 1.0.0b15 (2026-03-02) + +### Other Changes + +- Upgraded starlette>=1.0.0rc1. + + +## 1.0.0b14 (2026-02-24) + +### Other Changes + +- Pin opentelemetry-semantic-conventions-ai==0.4.13. + + +## 1.0.0b13 (2026-02-20) + +### Other Changes + +- Upgraded azure-ai-projects version. + + +## 1.0.0b12 (2026-02-12) + +### Bugs Fixed + +- Minor bugs fixed for -langgraph and -agentframework. + + +## 1.0.0b11 (2026-02-10) + +### Features Added + +- Added conversation persistence: automatically save input and output items to conversation when `store=True` in request +- Added deduplication check to avoid saving duplicate input items +- Added server startup success log message + +### Other Changes + +- Improved logging: replaced confusing print statements with proper logger calls +- Changed logger to use stdout instead of stderr for consistency with uvicorn +- Added `_items_are_equal()` method for comparing conversation items + +## 1.0.0b10 (2026-01-27) + +### Bugs Fixed + +- Make AZURE_AI_PROJECTS_ENDPOINT optional. + +## 1.0.0b9 (2026-01-23) + +### Features Added + +- Integrated with Foundry Tools + + +## 1.0.0b8 (2026-01-21) + +### Features Added + +- Support keep alive for long-running streaming responses. 
+
+
+## 1.0.0b7 (2025-12-05)
+
+### Features Added
+
+- Update response with created_by
+
+### Bugs Fixed
+
+- Fixed error response handling in stream and non-stream modes
+
+## 1.0.0b6 (2025-11-26)
+
+### Features Added
+
+- Support agent-framework versions greater than 251112
+
+
+## 1.0.0b5 (2025-11-16)
+
+### Features Added
+
+- Support Tools OAuth
+
+### Bugs Fixed
+
+- Fixed streaming generation issues.
+
+
+## 1.0.0b4 (2025-11-13)
+
+### Features Added
+
+- Adapters support tools
+
+### Bugs Fixed
+
+- Pin azure-ai-projects and azure-ai-agents versions to avoid version conflicts
+
+
+## 1.0.0b3 (2025-11-11)
+
+### Bugs Fixed
+
+- Fixed Id generator format.
+
+- Fixed trace initialization for agent-framework.
+
+
+## 1.0.0b2 (2025-11-10)
+
+### Bugs Fixed
+
+- Fixed Id generator format.
+
+- Improved stream mode error message.
+
+- Updated application insights related configuration environment variables.
+
+
 ## 1.0.0b1 (2025-11-07)
 
 ### Features Added
diff --git a/sdk/agentserver/azure-ai-agentserver-core/README.md b/sdk/agentserver/azure-ai-agentserver-core/README.md
index ff60cf460196..cc420579e5fe 100644
--- a/sdk/agentserver/azure-ai-agentserver-core/README.md
+++ b/sdk/agentserver/azure-ai-agentserver-core/README.md
@@ -26,7 +26,7 @@ from azure.ai.agentserver.core.models import (
     CreateResponse,
     Response as OpenAIResponse,
 )
-from azure.ai.agentserver.core.models.projects import (
+from azure.ai.agentserver.core.models._projects import (
     ItemContentOutputText,
     ResponsesAssistantMessageItemResource,
     ResponseTextDeltaEvent,
diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py
index 895074d32ae3..39de11cefe55 100644
--- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py
+++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py
@@ -5,10 +5,11 @@
 from ._version import VERSION
 from .logger
import configure as config_logging -from .server.base import FoundryCBAgent -from .server.common.agent_run_context import AgentRunContext +from .server._base import FoundryCBAgent +from .server.common._agent_run_context import AgentRunContext +from .server._context import AgentServerContext config_logging() -__all__ = ["FoundryCBAgent", "AgentRunContext"] +__all__ = ["FoundryCBAgent", "AgentRunContext", "AgentServerContext"] __version__ = VERSION diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py index be71c81bd282..e4218ac5b98d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b1" +VERSION = "1.0.0b16" diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py new file mode 100644 index 000000000000..052f9894497b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) + +__all__ = [ + "AgentServerMetadata", + "PackageMetadata", + "RuntimeMetadata", + "get_current_app", + "set_current_app" +] + +from ._metadata import ( + AgentServerMetadata, + PackageMetadata, + RuntimeMetadata, + get_current_app, + set_current_app, +) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_builder.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_builder.py new file mode 100644 index 000000000000..c09c253ab09f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_builder.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +class AgentServerBuilder: + pass diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py new file mode 100644 index 000000000000..1f8a01d57639 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from dataclasses import dataclass, field + +from azure.core.credentials_async import AsyncTokenCredential + + +@dataclass(frozen=True) +class HttpServerConfiguration: + """Resolved configuration for the HTTP server. + + :ivar str host: The host address the server listens on. Defaults to '0.0.0.0'. + :ivar int port: The port number the server listens on. Defaults to 8088. 
+ """ + + host: str = "0.0.0.0" + port: int = 8088 + + +class ToolsConfiguration: + """Resolved configuration for the Tools subsystem. + + :ivar int catalog_cache_ttl: The time-to-live (TTL) for the tool catalog cache in seconds. + Defaults to 600 seconds (10 minutes). + :ivar int catalog_cache_max_size: The maximum size of the tool catalog cache. + Defaults to 1024 entries. + """ + + catalog_cache_ttl: int = 600 + catalog_cache_max_size: int = 1024 + + +@dataclass(frozen=True) +class AgentServerConfiguration: + """Resolved configuration for the Agent Server application.""" + + project_endpoint: str + credential: AsyncTokenCredential + agent_name: str = "$default" + http: HttpServerConfiguration = field(default_factory=HttpServerConfiguration) + tools: ToolsConfiguration = field(default_factory=ToolsConfiguration) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py new file mode 100644 index 000000000000..7053622a8044 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py @@ -0,0 +1,98 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from __future__ import annotations + +import os +import platform # pylint: disable=unused-import +from dataclasses import dataclass, field +from importlib.metadata import Distribution, PackageNotFoundError + +@dataclass(frozen=True, kw_only=True) +class PackageMetadata: + name: str + version: str + + @staticmethod + def from_dist(dist_name: str) -> "PackageMetadata": + try: + ver = Distribution.from_name(dist_name).version + except PackageNotFoundError: + ver = "" + + return PackageMetadata( + name=dist_name, + version=ver, + ) + + +@dataclass(frozen=True, kw_only=True) +class RuntimeMetadata: + python_version: str = field(default_factory=platform.python_version) + platform: str = field(default_factory=platform.platform) + host_name: str = "" + replica_name: str = "" + + @staticmethod + def from_aca_app_env() -> "RuntimeMetadata | None": + host_name = os.environ.get("CONTAINER_APP_REVISION_FQDN") + replica_name = os.environ.get("CONTAINER_APP_REPLICA_NAME") + + if not host_name and not replica_name: + return None + + return RuntimeMetadata( + host_name=host_name or "", + replica_name=replica_name or "", + ) + + @staticmethod + def resolve(host_name: str | None = None, replica_name: str | None = None) -> "RuntimeMetadata": + runtime = RuntimeMetadata.from_aca_app_env() + + override = RuntimeMetadata(host_name=host_name or "", replica_name=replica_name or "") + return runtime.merged_with(override) if runtime else override + + def merged_with(self, override: "RuntimeMetadata | None") -> "RuntimeMetadata": + if override is None: + return self + + return RuntimeMetadata( + python_version=override.python_version or self.python_version, + platform=override.platform or self.platform, + host_name=override.host_name or self.host_name, + replica_name=override.replica_name or self.replica_name, + ) + + +@dataclass(frozen=True) +class AgentServerMetadata: + package: PackageMetadata + runtime: RuntimeMetadata + + def 
as_user_agent(self, component: str | None = None) -> str: + component_value = f" {component}" if component else "" + return ( + f"{self.package.name}/{self.package.version} " + f"Python {self.runtime.python_version}{component_value} " + f"({self.runtime.platform})" + ) + + +_default = AgentServerMetadata( + package=PackageMetadata.from_dist("azure-ai-agentserver-core"), + runtime=RuntimeMetadata.resolve(), +) +_app: AgentServerMetadata = _default + + +def set_current_app(app: PackageMetadata, runtime: RuntimeMetadata | None = None) -> None: + global _app # pylint: disable=W0603 + resolved_runtime = RuntimeMetadata.resolve() + merged_runtime = resolved_runtime.merged_with(runtime) + _app = AgentServerMetadata(package=app, runtime=merged_runtime) + + +def get_current_app() -> AgentServerMetadata: + global _app # pylint: disable=W0602 + return _app diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py new file mode 100644 index 000000000000..d70270261e7b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py @@ -0,0 +1,45 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Literal, TypedDict, Union + +from typing_extensions import NotRequired + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + + +class AgentServerOptions(TypedDict): + """Configuration options for the Agent Server. + + Attributes: + project_endpoint (str, optional): The endpoint URL for the project. Defaults to current project. + credential (Union[AsyncTokenCredential, TokenCredential], optional): The credential used for authentication. + Defaults to current project's managed identity. 
+ """ + project_endpoint: NotRequired[str] + credential: NotRequired[Union[AsyncTokenCredential, TokenCredential]] + http: NotRequired["HttpServerOptions"] + tools: NotRequired["ToolsOptions"] + + +class HttpServerOptions(TypedDict): + """Configuration options for the HTTP server. + + Attributes: + host (str, optional): The host address the server listens on. + """ + host: NotRequired[Literal["127.0.0.1", "localhost", "0.0.0.0"]] + + +class ToolsOptions(TypedDict): + """Configuration options for the Tools subsystem. + + Attributes: + catalog_cache_ttl (int, optional): The time-to-live (TTL) for the tool catalog cache in seconds. + Defaults to 600 seconds (10 minutes). + catalog_cache_max_size (int, optional): The maximum size of the tool catalog cache. + Defaults to 1024 entries. + """ + catalog_cache_ttl: NotRequired[int] + catalog_cache_max_size: NotRequired[int] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/__init__.py new file mode 100644 index 000000000000..0ca387146579 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/__init__.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Checkpoint storage module for Azure AI Agent Server.""" + +from .client._client import FoundryCheckpointClient +from .client._models import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, +) + +__all__ = [ + "CheckpointItem", + "CheckpointItemId", + "CheckpointSession", + "FoundryCheckpointClient", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/__init__.py new file mode 100644 index 000000000000..901cbb3d70a8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/__init__.py @@ -0,0 +1,6 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Checkpoint client module for Azure AI Agent Server.""" + +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_client.py new file mode 100644 index 000000000000..fc2f45321968 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_client.py @@ -0,0 +1,158 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=client-method-missing-kwargs,client-accepts-api-version-keyword,missing-client-constructor-parameter-kwargs +# ^^^ azure-sdk pylint rules: internal client not intended as a public Azure SDK client +"""Asynchronous client for Azure AI Foundry checkpoint storage API.""" + +from typing import Any, AsyncContextManager, List, Optional + +from azure.core import AsyncPipelineClient +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.tracing.decorator_async import distributed_trace_async + +from ._configuration import FoundryCheckpointClientConfiguration +from ._models import CheckpointItem, CheckpointItemId, CheckpointSession +from .operations import CheckpointItemOperations, CheckpointSessionOperations + + +class FoundryCheckpointClient(AsyncContextManager["FoundryCheckpointClient"]): + """Asynchronous client for Azure AI Foundry checkpoint storage API. + + This client provides access to checkpoint storage for workflow state persistence, + enabling checkpoint save, load, list, and delete operations. + + :param endpoint: The fully qualified project endpoint for the Azure AI Foundry service. + Example: "https://.services.ai.azure.com/api/projects/" + :type endpoint: str + :param credential: Credential for authenticating requests to the service. + Use credentials from azure-identity like DefaultAzureCredential. + :type credential: ~azure.core.credentials_async.AsyncTokenCredential + """ + + def __init__( + self, + endpoint: str, + credential: "AsyncTokenCredential", + ) -> None: + """Initialize the asynchronous Azure AI Checkpoint Client. + + :param endpoint: The project endpoint URL (includes project context). + :type endpoint: str + :param credential: Credentials for authenticating requests. 
+ :type credential: ~azure.core.credentials_async.AsyncTokenCredential + """ + config = FoundryCheckpointClientConfiguration(credential) + self._client: AsyncPipelineClient = AsyncPipelineClient( + base_url=endpoint, config=config + ) + self._sessions = CheckpointSessionOperations(self._client) + self._items = CheckpointItemOperations(self._client) + + # Session operations + + @distributed_trace_async + async def upsert_session(self, session: CheckpointSession) -> CheckpointSession: + """Create or update a checkpoint session. + + :param session: The checkpoint session to upsert. + :type session: CheckpointSession + :return: The upserted checkpoint session. + :rtype: CheckpointSession + """ + return await self._sessions.upsert(session) + + @distributed_trace_async + async def read_session(self, session_id: str) -> Optional[CheckpointSession]: + """Read a checkpoint session by ID. + + :param session_id: The session identifier. + :type session_id: str + :return: The checkpoint session if found, None otherwise. + :rtype: Optional[CheckpointSession] + """ + return await self._sessions.read(session_id) + + @distributed_trace_async + async def delete_session(self, session_id: str) -> None: + """Delete a checkpoint session. + + :param session_id: The session identifier. + :type session_id: str + """ + await self._sessions.delete(session_id) + + # Item operations + + @distributed_trace_async + async def create_items(self, items: List[CheckpointItem]) -> List[CheckpointItem]: + """Create checkpoint items in batch. + + :param items: The checkpoint items to create. + :type items: List[CheckpointItem] + :return: The created checkpoint items. + :rtype: List[CheckpointItem] + """ + return await self._items.create_batch(items) + + @distributed_trace_async + async def read_item(self, item_id: CheckpointItemId) -> Optional[CheckpointItem]: + """Read a checkpoint item by ID. + + :param item_id: The checkpoint item identifier. 
+ :type item_id: CheckpointItemId + :return: The checkpoint item if found, None otherwise. + :rtype: Optional[CheckpointItem] + """ + return await self._items.read(item_id) + + @distributed_trace_async + async def delete_item(self, item_id: CheckpointItemId) -> bool: + """Delete a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: True if the item was deleted, False if not found. + :rtype: bool + """ + return await self._items.delete(item_id) + + @distributed_trace_async + async def list_item_ids( + self, session_id: str, parent_id: Optional[str] = None + ) -> List[CheckpointItemId]: + """List checkpoint item IDs for a session. + + :param session_id: The session identifier. + :type session_id: str + :param parent_id: Optional parent item identifier for filtering. + :type parent_id: Optional[str] + :return: List of checkpoint item identifiers. + :rtype: List[CheckpointItemId] + """ + return await self._items.list_ids(session_id, parent_id) + + # Context manager methods + + async def close(self) -> None: + """Close the underlying HTTP pipeline.""" + await self._client.close() + + async def __aenter__(self) -> "FoundryCheckpointClient": + """Enter the async context manager. + + :return: The client instance. + :rtype: FoundryCheckpointClient + """ + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details: Any) -> None: + """Exit the async context manager. + + :param exc_details: Exception details if an exception occurred. 
+ :type exc_details: Any + :return: None + :rtype: None + """ + await self._client.__aexit__(*exc_details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_configuration.py new file mode 100644 index 000000000000..cd9ed9ee7ff7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_configuration.py @@ -0,0 +1,37 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Configuration for Azure AI Checkpoint Client.""" + +from azure.core.configuration import Configuration +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.pipeline import policies + +from ...application._metadata import get_current_app + + +class FoundryCheckpointClientConfiguration(Configuration): + """Configuration for Azure AI Checkpoint Client. + + Manages authentication, endpoint configuration, and policy settings for the + Azure AI Checkpoint Client. This class is used internally by the client and should + not typically be instantiated directly. + + :param credential: Azure TokenCredential for authentication. 
+ :type credential: ~azure.core.credentials_async.AsyncTokenCredential + """ + + def __init__(self, credential: "AsyncTokenCredential") -> None: + super().__init__() + + self.retry_policy = policies.AsyncRetryPolicy() + self.logging_policy = policies.NetworkTraceLoggingPolicy() + self.request_id_policy = policies.RequestIdPolicy() + self.http_logging_policy = policies.HttpLoggingPolicy() + self.user_agent_policy = policies.UserAgentPolicy( + base_user_agent=get_current_app().as_user_agent("FoundryCheckpointClient") + ) + self.authentication_policy = policies.AsyncBearerTokenCredentialPolicy( + credential, "https://ai.azure.com/.default" + ) + self.redirect_policy = policies.AsyncRedirectPolicy() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_models.py new file mode 100644 index 000000000000..626bcbfaba56 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_models.py @@ -0,0 +1,201 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Data models for checkpoint storage API.""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +@dataclass +class CheckpointSession: + """Represents a checkpoint session. + + A session maps to a conversation and groups related checkpoints together. + + :ivar session_id: The session identifier (maps to conversation_id). + :ivar metadata: Optional metadata for the session. + """ + + session_id: str + metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class CheckpointItemId: + """Identifier for a checkpoint item. + + :ivar session_id: The session identifier this checkpoint belongs to. 
+ :ivar item_id: The unique checkpoint item identifier. + :ivar parent_id: Optional parent checkpoint identifier for hierarchical checkpoints. + """ + + session_id: str + item_id: str + parent_id: Optional[str] = None + + +@dataclass +class CheckpointItem: + """Represents a single checkpoint item. + + Contains the serialized checkpoint data along with identifiers. + + :ivar session_id: The session identifier this checkpoint belongs to. + :ivar item_id: The unique checkpoint item identifier. + :ivar data: Serialized checkpoint data as bytes. + :ivar parent_id: Optional parent checkpoint identifier. + """ + + session_id: str + item_id: str + data: bytes + parent_id: Optional[str] = None + + def to_item_id(self) -> CheckpointItemId: + """Convert to a CheckpointItemId. + + :return: The checkpoint item identifier. + :rtype: CheckpointItemId + """ + return CheckpointItemId( + session_id=self.session_id, + item_id=self.item_id, + parent_id=self.parent_id, + ) + + +# Pydantic models for API request/response serialization + + +class CheckpointSessionRequest(BaseModel): + """Request model for creating/updating a checkpoint session.""" + + session_id: str = Field(alias="sessionId") + metadata: Optional[Dict[str, Any]] = None + + model_config = {"populate_by_name": True} + + @classmethod + def from_session(cls, session: CheckpointSession) -> "CheckpointSessionRequest": + """Create a request from a CheckpointSession. + + :param session: The checkpoint session. + :type session: CheckpointSession + :return: The request model. 
+ :rtype: CheckpointSessionRequest + """ + return cls( + session_id=session.session_id, + metadata=session.metadata, + ) + + +class CheckpointSessionResponse(BaseModel): + """Response model for checkpoint session operations.""" + + session_id: str = Field(alias="sessionId") + metadata: Optional[Dict[str, Any]] = None + etag: Optional[str] = None + + model_config = {"populate_by_name": True} + + def to_session(self) -> CheckpointSession: + """Convert to a CheckpointSession. + + :return: The checkpoint session. + :rtype: CheckpointSession + """ + return CheckpointSession( + session_id=self.session_id, + metadata=self.metadata, + ) + + +class CheckpointItemIdResponse(BaseModel): + """Response model for checkpoint item identifiers.""" + + session_id: str = Field(alias="sessionId") + item_id: str = Field(alias="itemId") + parent_id: Optional[str] = Field(default=None, alias="parentId") + + model_config = {"populate_by_name": True} + + def to_item_id(self) -> CheckpointItemId: + """Convert to a CheckpointItemId. + + :return: The checkpoint item identifier. + :rtype: CheckpointItemId + """ + return CheckpointItemId( + session_id=self.session_id, + item_id=self.item_id, + parent_id=self.parent_id, + ) + + +class CheckpointItemRequest(BaseModel): + """Request model for creating checkpoint items.""" + + session_id: str = Field(alias="sessionId") + item_id: str = Field(alias="itemId") + data: str # Base64-encoded bytes + parent_id: Optional[str] = Field(default=None, alias="parentId") + + model_config = {"populate_by_name": True} + + @classmethod + def from_item(cls, item: CheckpointItem) -> "CheckpointItemRequest": + """Create a request from a CheckpointItem. + + :param item: The checkpoint item. + :type item: CheckpointItem + :return: The request model. 
+ :rtype: CheckpointItemRequest + """ + import base64 + + return cls( + session_id=item.session_id, + item_id=item.item_id, + data=base64.b64encode(item.data).decode("utf-8"), + parent_id=item.parent_id, + ) + + +class CheckpointItemResponse(BaseModel): + """Response model for checkpoint item operations.""" + + session_id: str = Field(alias="sessionId") + item_id: str = Field(alias="itemId") + data: str # Base64-encoded bytes + parent_id: Optional[str] = Field(default=None, alias="parentId") + etag: Optional[str] = None + + model_config = {"populate_by_name": True} + + def to_item(self) -> CheckpointItem: + """Convert to a CheckpointItem. + + :return: The checkpoint item. + :rtype: CheckpointItem + """ + import base64 + + return CheckpointItem( + session_id=self.session_id, + item_id=self.item_id, + data=base64.b64decode(self.data), + parent_id=self.parent_id, + ) + + +class ListCheckpointItemIdsResponse(BaseModel): + """Response model for listing checkpoint item identifiers.""" + + value: List[CheckpointItemIdResponse] = Field(default_factory=list) + next_link: Optional[str] = Field(default=None, alias="nextLink") + + model_config = {"populate_by_name": True} diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/__init__.py new file mode 100644 index 000000000000..42ba9bacd20b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/__init__.py @@ -0,0 +1,12 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
class CheckpointItemOperations(BaseOperations):
    """Operations for managing checkpoint items.

    Wraps the REST calls for creating, reading, deleting, and listing
    checkpoint items that are scoped to a session (and, optionally, a
    parent item).
    """

    # Service API version stamped on every request issued by this class.
    _API_VERSION: ClassVar[str] = "2025-11-15-preview"

    _HEADERS: ClassVar[Dict[str, str]] = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    _QUERY_PARAMS: ClassVar[Dict[str, Any]] = {"api-version": _API_VERSION}

    def _items_path(self, item_id: Optional[str] = None) -> str:
        """Get the API path for item operations.

        :param item_id: Optional item identifier.
        :type item_id: Optional[str]
        :return: The API path.
        :rtype: str
        """
        base = "/checkpoints/items"
        return f"{base}/{item_id}" if item_id else base

    def _scope_params(self, session_id: str, parent_id: Optional[str] = None) -> Dict[str, Any]:
        """Build query parameters that scope a request to a session and parent.

        Shared by the read, delete, and list request builders so the
        ``sessionId``/``parentId`` handling lives in exactly one place.

        :param session_id: The session identifier.
        :type session_id: str
        :param parent_id: Optional parent item identifier.
        :type parent_id: Optional[str]
        :return: Query parameters, including the API version.
        :rtype: Dict[str, Any]
        """
        params = dict(self._QUERY_PARAMS)
        params["sessionId"] = session_id
        if parent_id:
            params["parentId"] = parent_id
        return params

    def _build_create_batch_request(self, items: List[CheckpointItem]) -> HttpRequest:
        """Build the HTTP request for creating items in batch.

        :param items: The checkpoint items to create.
        :type items: List[CheckpointItem]
        :return: The HTTP request.
        :rtype: HttpRequest
        """
        request_models = [CheckpointItemRequest.from_item(item) for item in items]
        return self._client.post(
            self._items_path(),
            params=self._QUERY_PARAMS,
            headers=self._HEADERS,
            content=[model.model_dump(by_alias=True) for model in request_models],
        )

    def _build_read_request(self, item_id: CheckpointItemId) -> HttpRequest:
        """Build the HTTP request for reading an item.

        :param item_id: The checkpoint item identifier.
        :type item_id: CheckpointItemId
        :return: The HTTP request.
        :rtype: HttpRequest
        """
        return self._client.get(
            self._items_path(item_id.item_id),
            params=self._scope_params(item_id.session_id, item_id.parent_id),
            headers=self._HEADERS,
        )

    def _build_delete_request(self, item_id: CheckpointItemId) -> HttpRequest:
        """Build the HTTP request for deleting an item.

        :param item_id: The checkpoint item identifier.
        :type item_id: CheckpointItemId
        :return: The HTTP request.
        :rtype: HttpRequest
        """
        return self._client.delete(
            self._items_path(item_id.item_id),
            params=self._scope_params(item_id.session_id, item_id.parent_id),
            headers=self._HEADERS,
        )

    def _build_list_ids_request(
        self, session_id: str, parent_id: Optional[str] = None
    ) -> HttpRequest:
        """Build the HTTP request for listing item IDs.

        :param session_id: The session identifier.
        :type session_id: str
        :param parent_id: Optional parent item identifier.
        :type parent_id: Optional[str]
        :return: The HTTP request.
        :rtype: HttpRequest
        """
        return self._client.get(
            self._items_path(),
            params=self._scope_params(session_id, parent_id),
            headers=self._HEADERS,
        )

    @distributed_trace_async
    async def create_batch(self, items: List[CheckpointItem]) -> List[CheckpointItem]:
        """Create checkpoint items in batch.

        :param items: The checkpoint items to create.
        :type items: List[CheckpointItem]
        :return: The created checkpoint items.
        :rtype: List[CheckpointItem]
        """
        if not items:
            return []

        request = self._build_create_batch_request(items)
        response = await self._send_request(request)
        async with response:
            json_response = self._extract_response_json(response)
            if isinstance(json_response, list):
                return [
                    CheckpointItemResponse.model_validate(item).to_item()
                    for item in json_response
                ]
            # Service may answer a batch of one with a single object.
            return [CheckpointItemResponse.model_validate(json_response).to_item()]

    @distributed_trace_async
    async def read(self, item_id: CheckpointItemId) -> Optional[CheckpointItem]:
        """Read a checkpoint item by ID.

        :param item_id: The checkpoint item identifier.
        :type item_id: CheckpointItemId
        :return: The checkpoint item if found, None otherwise.
        :rtype: Optional[CheckpointItem]
        """
        request = self._build_read_request(item_id)
        try:
            response = await self._send_request(request)
            async with response:
                json_response = self._extract_response_json(response)
                item_response = CheckpointItemResponse.model_validate(json_response)
                return item_response.to_item()
        except ResourceNotFoundError:
            # A 404 is an expected outcome, not an error, for read.
            return None

    @distributed_trace_async
    async def delete(self, item_id: CheckpointItemId) -> bool:
        """Delete a checkpoint item.

        :param item_id: The checkpoint item identifier.
        :type item_id: CheckpointItemId
        :return: True if the item was deleted, False if not found.
        :rtype: bool
        """
        request = self._build_delete_request(item_id)
        try:
            response = await self._send_request(request)
            async with response:
                pass  # No response body expected
            return True
        except ResourceNotFoundError:
            return False

    @distributed_trace_async
    async def list_ids(
        self, session_id: str, parent_id: Optional[str] = None
    ) -> List[CheckpointItemId]:
        """List checkpoint item IDs for a session.

        :param session_id: The session identifier.
        :type session_id: str
        :param parent_id: Optional parent item identifier for filtering.
        :type parent_id: Optional[str]
        :return: List of checkpoint item identifiers.
        :rtype: List[CheckpointItemId]
        """
        request = self._build_list_ids_request(session_id, parent_id)
        response = await self._send_request(request)
        async with response:
            json_response = self._extract_response_json(response)
            list_response = ListCheckpointItemIdsResponse.model_validate(json_response)
            return [item.to_item_id() for item in list_response.value]
class CheckpointSessionOperations(BaseOperations):
    """Operations for managing checkpoint sessions.

    Provides upsert, read, and delete over the ``/checkpoints/sessions``
    resource; request construction is split out into small builders.
    """

    # API version sent with every request from this operations class.
    _API_VERSION: ClassVar[str] = "2025-11-15-preview"

    _HEADERS: ClassVar[Dict[str, str]] = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    _QUERY_PARAMS: ClassVar[Dict[str, Any]] = {"api-version": _API_VERSION}

    def _session_path(self, session_id: Optional[str] = None) -> str:
        """Get the API path for session operations.

        :param session_id: Optional session identifier.
        :type session_id: Optional[str]
        :return: The API path.
        :rtype: str
        """
        base = "/checkpoints/sessions"
        if session_id:
            return f"{base}/{session_id}"
        return base

    def _build_upsert_request(self, session: CheckpointSession) -> HttpRequest:
        """Build the HTTP request for upserting a session.

        :param session: The checkpoint session.
        :type session: CheckpointSession
        :return: The HTTP request.
        :rtype: HttpRequest
        """
        body_model = CheckpointSessionRequest.from_session(session)
        path = self._session_path(session.session_id)
        return self._client.put(
            path,
            params=self._QUERY_PARAMS,
            headers=self._HEADERS,
            content=body_model.model_dump(by_alias=True),
        )

    def _build_read_request(self, session_id: str) -> HttpRequest:
        """Build the HTTP request for reading a session.

        :param session_id: The session identifier.
        :type session_id: str
        :return: The HTTP request.
        :rtype: HttpRequest
        """
        return self._client.get(
            self._session_path(session_id),
            params=self._QUERY_PARAMS,
            headers=self._HEADERS,
        )

    def _build_delete_request(self, session_id: str) -> HttpRequest:
        """Build the HTTP request for deleting a session.

        :param session_id: The session identifier.
        :type session_id: str
        :return: The HTTP request.
        :rtype: HttpRequest
        """
        return self._client.delete(
            self._session_path(session_id),
            params=self._QUERY_PARAMS,
            headers=self._HEADERS,
        )

    @distributed_trace_async
    async def upsert(self, session: CheckpointSession) -> CheckpointSession:
        """Create or update a checkpoint session.

        :param session: The checkpoint session to upsert.
        :type session: CheckpointSession
        :return: The upserted checkpoint session.
        :rtype: CheckpointSession
        """
        req = self._build_upsert_request(session)
        resp = await self._send_request(req)
        async with resp:
            payload = self._extract_response_json(resp)
        parsed = CheckpointSessionResponse.model_validate(payload)
        return parsed.to_session()

    @distributed_trace_async
    async def read(self, session_id: str) -> Optional[CheckpointSession]:
        """Read a checkpoint session by ID.

        :param session_id: The session identifier.
        :type session_id: str
        :return: The checkpoint session if found, None otherwise.
        :rtype: Optional[CheckpointSession]
        """
        req = self._build_read_request(session_id)
        try:
            resp = await self._send_request(req)
            async with resp:
                payload = self._extract_response_json(resp)
                parsed = CheckpointSessionResponse.model_validate(payload)
                return parsed.to_session()
        except ResourceNotFoundError:
            # Missing session is reported as None rather than raised.
            return None

    @distributed_trace_async
    async def delete(self, session_id: str) -> None:
        """Delete a checkpoint session.

        :param session_id: The session identifier.
        :type session_id: str
        """
        req = self._build_delete_request(session_id)
        resp = await self._send_request(req)
        async with resp:
            pass  # No response body expected
# --------------------------------------------------------- @@ -6,30 +5,66 @@ import logging import os from logging import config +from typing import Any, Optional from ._version import VERSION from .constants import Constants -default_log_config = { - "version": 1, - "disable_existing_loggers": False, - "loggers": { - "azure.ai.agentserver": { - "handlers": ["console"], - "level": "INFO", - "propagate": False, +def _get_default_log_config() -> dict[str, Any]: + """Build default log config with level from environment. + + :return: A dictionary containing logging configuration. + :rtype: dict[str, Any] + """ + log_level = _get_log_level() + return { + "version": 1, + "disable_existing_loggers": False, + "loggers": { + "azure.ai.agentserver": { + "handlers": ["console"], + "level": log_level, + "propagate": False, + }, + }, + "handlers": { + "console": { + "formatter": "std_out", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + "level": log_level}, }, - }, - "handlers": { - "console": {"formatter": "std_out", "class": "logging.StreamHandler", "level": "INFO"}, - }, - "formatters": {"std_out": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}}, -} + "formatters": {"std_out": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}}, + } + + +def _get_log_level() -> str: + """Read log level from the ``AGENT_LOG_LEVEL`` environment variable. + + Falls back to ``"INFO"`` if the variable is unset or contains an invalid value. + + :return: A valid Python logging level name. + :rtype: str + """ + log_level = os.getenv(Constants.AGENT_LOG_LEVEL, "INFO").upper() + valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + if log_level not in valid_levels: + print(f"Invalid log level '{log_level}' specified. 
Defaulting to 'INFO'.") + log_level = "INFO" + return log_level + request_context = contextvars.ContextVar("request_context", default=None) +APPINSIGHT_CONNSTR_ENV_NAME = "APPLICATIONINSIGHTS_CONNECTION_STRING" -def get_dimensions(): + +def _get_dimensions() -> dict[str, str]: + """Collect environment-based dimensions for structured logging. + + :return: A mapping of dimension keys to their runtime values. + :rtype: dict[str, str] + """ env_values = {name: value for name, value in vars(Constants).items() if not name.startswith("_")} res = {"azure.ai.agentserver.version": VERSION} for name, env_name in env_values.items(): @@ -40,49 +75,100 @@ def get_dimensions(): return res -def get_project_endpoint(): +def get_project_endpoint(logger: Optional[logging.Logger] = None) -> Optional[str]: + """Resolve the project endpoint from environment variables. + + Checks ``AZURE_AI_PROJECT_ENDPOINT`` first, then falls back to deriving + an endpoint from ``AGENT_PROJECT_NAME``. + + :param logger: Optional logger for diagnostic messages. + :type logger: Optional[logging.Logger] + :return: The resolved project endpoint URL, or ``None`` if unavailable. 
+ :rtype: Optional[str] + """ + project_endpoint = os.environ.get(Constants.AZURE_AI_PROJECT_ENDPOINT) + if project_endpoint: + if logger: + logger.info( + "Using project endpoint from %s: %s", + Constants.AZURE_AI_PROJECT_ENDPOINT, + project_endpoint, + ) + return project_endpoint project_resource_id = os.environ.get(Constants.AGENT_PROJECT_RESOURCE_ID) if project_resource_id: last_part = project_resource_id.split("/")[-1] parts = last_part.split("@") if len(parts) < 2: - print(f"invalid project resource id: {project_resource_id}") + if logger: + logger.warning("Invalid project resource id format: %s", project_resource_id) return None account = parts[0] project = parts[1] - return f"https://{account}.services.ai.azure.com/api/projects/{project}" - print("environment variable AGENT_PROJECT_RESOURCE_ID not set.") + endpoint = f"https://{account}.services.ai.azure.com/api/projects/{project}" + if logger: + logger.info( + "Using project endpoint derived from %s: %s", + Constants.AGENT_PROJECT_RESOURCE_ID, + endpoint, + ) + return endpoint return None -def get_application_insights_connstr(): +def _get_application_insights_connstr(logger: Optional[logging.Logger] = None) -> Optional[str]: + """Retrieve or derive the Application Insights connection string. + + Looks in the ``APPLICATIONINSIGHTS_CONNECTION_STRING`` environment variable first, + then attempts to fetch it from the project endpoint. + + :param logger: Optional logger for diagnostic messages. + :type logger: Optional[logging.Logger] + :return: The connection string, or ``None`` if unavailable. 
+ :rtype: Optional[str] + """ try: - conn_str = os.environ.get(Constants.APPLICATION_INSIGHTS_CONNECTION_STRING) + conn_str = os.environ.get(APPINSIGHT_CONNSTR_ENV_NAME) if not conn_str: - print("environment variable APPLICATION_INSIGHTS_CONNECTION_STRING not set.") - project_endpoint = get_project_endpoint() + project_endpoint = get_project_endpoint(logger=logger) if project_endpoint: # try to get the project connected application insights from azure.ai.projects import AIProjectClient from azure.identity import DefaultAzureCredential - project_client = AIProjectClient(credential=DefaultAzureCredential(), endpoint=project_endpoint) conn_str = project_client.telemetry.get_application_insights_connection_string() - if not conn_str: - print(f"no connected application insights found for project:{project_endpoint}") - else: - os.environ[Constants.APPLICATION_INSIGHTS_CONNECTION_STRING] = conn_str + if not conn_str and logger: + logger.info( + "No Application Insights connection found for project: %s", + project_endpoint, + ) + elif conn_str: + os.environ[APPINSIGHT_CONNSTR_ENV_NAME] = conn_str + elif logger: + logger.info("Application Insights not configured, telemetry export disabled.") return conn_str - except Exception as e: - print(f"failed to get application insights with error: {e}") + except Exception as e: # pylint: disable=broad-exception-caught # bootstrap: many failure modes possible + if logger: + logger.warning( + "Failed to get Application Insights connection string, telemetry export disabled: %s", + e, + ) return None class CustomDimensionsFilter(logging.Filter): - def filter(self, record): - # Add custom dimensions to every log record - dimensions = get_dimensions() + """Logging filter that attaches environment dimensions and request context to log records.""" + + def filter(self, record: logging.LogRecord) -> bool: + """Inject custom dimensions into *record* and allow it through. + + :param record: The log record to enrich. 
+ :type record: logging.LogRecord + :return: Always ``True`` so the record is never discarded. + :rtype: bool + """ + dimensions = _get_dimensions() for key, value in dimensions.items(): setattr(record, key, value) cur_request_context = request_context.get() @@ -92,18 +178,21 @@ def filter(self, record): return True -def configure(log_config: dict = default_log_config): +def configure(log_config: Optional[dict[str, Any]] = None): """ Configure logging based on the provided configuration dictionary. The dictionary should contain the logging configuration in a format compatible with `logging.config.dictConfig`. - :param log_config: A dictionary containing logging configuration. - :type log_config: dict + :param log_config: A dictionary containing logging configuration. If None, uses default config with AGENT_LOG_LEVEL. + :type log_config: Optional[dict[str, Any]] """ try: + if log_config is None: + log_config = _get_default_log_config() config.dictConfig(log_config) + app_logger = logging.getLogger("azure.ai.agentserver") - application_insights_connection_string = get_application_insights_connstr() + application_insights_connection_string = _get_application_insights_connstr(logger=app_logger) enable_application_insights_logger = ( os.environ.get(Constants.ENABLE_APPLICATION_INSIGHTS_LOGGER, "true").lower() == "true" ) @@ -132,23 +221,13 @@ def configure(log_config: dict = default_log_config): handler.addFilter(custom_filter) # Only add to azure.ai.agentserver namespace to avoid infrastructure logs - app_logger = logging.getLogger("azure.ai.agentserver") - app_logger.setLevel(get_log_level()) + app_logger.setLevel(_get_log_level()) app_logger.addHandler(handler) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught print(f"Failed to configure logging: {e}") -def get_log_level(): - log_level = os.getenv(Constants.AGENT_LOG_LEVEL, "INFO").upper() - valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - if log_level not in 
valid_levels: - print(f"Invalid log level '{log_level}' specified. Defaulting to 'INFO'.") - log_level = "INFO" - return log_level - - def get_logger() -> logging.Logger: """ If the logger is not already configured, it will be initialized with default settings. diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/__init__.py index d5622ebe7732..b6a1895a3868 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/__init__.py @@ -1,7 +1,8 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +# TypedDict module; __all__ cannot be statically typed because the list is built at runtime. from ._create_response import CreateResponse # type: ignore -from .projects import Response, ResponseStreamEvent +from ._projects import Response, ResponseStreamEvent __all__ = ["CreateResponse", "Response", "ResponseStreamEvent"] # type: ignore[var-annotated] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py index a38f55408c7f..5ec72115734a 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py @@ -1,12 +1,14 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- -# pylint: disable=no-name-in-module +# pylint: disable=no-name-in-module # openai re-exports are dynamically generated from typing import Optional -from .openai import response_create_params # type: ignore -from . import projects as _azure_ai_projects_models +# ResponseCreateParamsBase is a TypedDict β€” mypy cannot verify total=False on mixed bases. +from ._openai import response_create_params # type: ignore +from . import _projects as _azure_ai_projects_models class CreateResponse(response_create_params.ResponseCreateParamsBase, total=False): # type: ignore agent: Optional[_azure_ai_projects_models.AgentReference] stream: Optional[bool] + tools: Optional[list[_azure_ai_projects_models.Tool]] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/openai/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_openai/__init__.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/openai/__init__.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_openai/__init__.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/__init__.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/__init__.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/__init__.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_enums.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_enums.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_enums.py rename to 
sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_enums.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_models.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_models.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_models.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_patch.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_patch.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_patch.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_patch.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_patch_evaluations.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_patch_evaluations.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_patch_evaluations.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_patch_evaluations.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_utils/__init__.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_utils/__init__.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_utils/__init__.py diff --git 
a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_utils/model_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_utils/model_base.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_utils/model_base.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_utils/model_base.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_utils/serialization.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_utils/serialization.py similarity index 100% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/_utils/serialization.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_projects/_utils/serialization.py diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py new file mode 100644 index 000000000000..e1ce45188c34 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py @@ -0,0 +1,732 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +import asyncio # pylint: disable=C4763 # azure-sdk: async-client-bad-name (false positive on module) +import contextlib +import inspect +import json +import os +import time +from abc import abstractmethod +from typing import Any, AsyncGenerator, Generator, Optional, Union + +import uvicorn +from openai import AsyncOpenAI +from opentelemetry import context as otel_context, trace +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator +from starlette.applications import Starlette +from starlette.concurrency import iterate_in_threadpool +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response, StreamingResponse +from starlette.routing import Route +from starlette.types import ASGIApp + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.identity.aio import ( + DefaultAzureCredential as AsyncDefaultTokenCredential, + get_bearer_token_provider, +) + +from ._context import AgentServerContext +from ._response_metadata import ( + attach_foundry_metadata_to_response, + build_foundry_agents_metadata_headers, + try_attach_foundry_metadata_to_event, +) +from .common._agent_run_context import AgentRunContext +from ..constants import Constants +from ..logger import APPINSIGHT_CONNSTR_ENV_NAME, get_logger, get_project_endpoint, request_context +from ..models import ( + Response as OpenAIResponse, + ResponseStreamEvent, + _projects as project_models +) +from ..tools import UserInfoContextMiddleware, create_tool_runtime +from ..utils._credential import AsyncTokenCredentialAdapter + +logger = get_logger() +DEBUG_ERRORS = os.environ.get(Constants.AGENT_DEBUG_ERRORS, "false").lower() == "true" +KEEP_ALIVE_INTERVAL = 15.0 # seconds + +class 
class AgentRunContextMiddleware(BaseHTTPMiddleware):
    """Starlette middleware that parses run payloads and seeds per-request context.

    For ``/runs`` and ``/responses`` requests it records the request id,
    parses the JSON body, and attaches an :class:`AgentRunContext` to
    ``request.state`` before the endpoint executes.
    """

    def __init__(self, app: ASGIApp, agent: Optional['FoundryCBAgent'] = None):
        super().__init__(app)
        self.agent = agent

    async def dispatch(self, request: Request, call_next):  # type: ignore[override]
        # Only agent-run routes need payload parsing; everything else passes through.
        if request.url.path not in ("/runs", "/responses"):
            return await call_next(request)
        try:
            self.set_request_id_to_context_var(request)
            payload = await request.json()
        except Exception as err:  # pylint: disable=broad-exception-caught # middleware catch-all for bad payload
            logger.error("Invalid JSON payload: %s", err)
            return JSONResponse({"error": f"Invalid JSON payload: {err}"}, status_code=400)
        try:
            request.state.agent_run_context = AgentRunContext(payload)
            self.set_run_context_to_context_var(request.state.agent_run_context)
        except Exception as err:  # pylint: disable=broad-exception-caught # middleware catch-all for context build
            logger.error("Context build failed: %s.", err, exc_info=True)
            return JSONResponse({"error": f"Context build failed: {err}"}, status_code=500)
        return await call_next(request)

    def set_request_id_to_context_var(self, request):
        """Copy the ``X-Request-Id`` header (when present) into the logging context var."""
        rid = request.headers.get("X-Request-Id", None)
        if not rid:
            return
        current = request_context.get() or {}
        current["azure.ai.agentserver.x-request-id"] = rid
        request_context.set(current)

    def set_run_context_to_context_var(self, run_context):
        """Publish run/agent identifiers from *run_context* into the logging context var."""
        agent_name = ""
        agent_id = ""
        agent_obj = run_context.get_agent_id_object()
        if agent_obj:
            agent_name = getattr(agent_obj, "name", "")
            version = getattr(agent_obj, "version", "")
            agent_id = f"{agent_name}:{version}"

        dimensions = {
            "azure.ai.agentserver.response_id": run_context.response_id or "",
            "azure.ai.agentserver.conversation_id": run_context.conversation_id or "",
            "azure.ai.agentserver.streaming": str(run_context.stream or False),
            "gen_ai.agent.id": agent_id,
            "gen_ai.agent.name": agent_name,
            "gen_ai.provider.name": "AzureAI Hosted Agents",
            "gen_ai.response.id": run_context.response_id or "",
        }
        current = request_context.get() or {}
        current.update(dimensions)
        request_context.set(current)
self._save_output_to_conversation(context, result) + return JSONResponse(result.as_dict(), headers=self.create_response_headers()) + + async def gen_async(ex): + ctx = TraceContextTextMapPropagator().extract(carrier=context_carrier) + prev_ctx = otel_context.get_current() + otel_context.attach(ctx) + seq = 0 + output_events = [] # Collect events for saving to conversation + try: + if ex: + return + it = iterate_in_threadpool(resp) if inspect.isgenerator(resp) else resp + # Wrap iterator with keep-alive mechanism + async for event in _iter_with_keep_alive(it): + if event is None: + # Keep-alive signal + yield _keep_alive_comment() + else: + try_attach_foundry_metadata_to_event(event) + seq += 1 + output_events.append(event) + yield _event_to_sse_chunk(event) + logger.info("End of processing CreateResponse request.") + # Save output to conversation if store=True + if self._should_store(context): + logger.debug("Storing output to conversation.") + await self._save_output_events_to_conversation(context, output_events) + except Exception as e: # noqa: BLE001 # pylint: disable=broad-exception-caught + logger.error("Error in async generator: %s", e, exc_info=True) + ex = e + finally: + if ex: + err = project_models.ResponseErrorEvent( + sequence_number=seq + 1, + code=project_models.ResponseErrorCode.SERVER_ERROR, + message=_format_error(ex), + param="") + yield _event_to_sse_chunk(err) + otel_context.attach(prev_ctx) + + return StreamingResponse( + gen_async(ex), + media_type="text/event-stream", + headers=self.create_response_headers(), + ) + + async def liveness_endpoint(request): + result = await self.agent_liveness(request) + return _to_response(result) + + async def readiness_endpoint(request): + result = await self.agent_readiness(request) + return _to_response(result) + + routes = [ + Route("/runs", runs_endpoint, methods=["POST"], name="agent_run"), + Route("/responses", runs_endpoint, methods=["POST"], name="agent_response"), + Route("/liveness", 
liveness_endpoint, methods=["GET"], name="agent_liveness"), + Route("/readiness", readiness_endpoint, methods=["GET"], name="agent_readiness"), + ] + + @contextlib.asynccontextmanager + async def _lifespan(app): # pylint: disable=unused-argument + import logging + + # Log server started successfully + port = getattr(self, '_port', 'unknown') + logger.info("FoundryCBAgent server started successfully on port %s", port) + + # Attach App Insights handler to uvicorn loggers + for handler in logger.handlers: + if handler.name == "appinsights_handler": + for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]: + uv_logger = logging.getLogger(logger_name) + uv_logger.addHandler(handler) + uv_logger.setLevel(logger.level) + uv_logger.propagate = False + + yield + + self.app = Starlette(routes=routes, lifespan=_lifespan) + UserInfoContextMiddleware.install(self.app) + self.app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + self.app.add_middleware(AgentRunContextMiddleware, agent=self) # type: ignore[arg-type] + + self.tracer: trace.Tracer = trace.get_tracer(__name__) + + def _should_store(self, context: AgentRunContext) -> bool: + """Determine whether conversation artifacts should be persisted. + + :param context: Agent run context that contains the incoming request payload. + :type context: AgentRunContext + :return: ``True`` when storage is requested and the conversation is scoped to a project. + :rtype: bool + """ + return bool(context.request.get("store", False) and context.conversation_id and self._project_endpoint) + + def _items_are_equal(self, item1: dict, item2: dict) -> bool: + """Compare two conversation items for equality based on type and content. + + :param item1: First conversation item. + :type item1: dict + :param item2: Second conversation item. + :type item2: dict + :return: ``True`` when both the metadata and content match. 
+ :rtype: bool + """ + if item1.get("type") != item2.get("type"): + return False + if item1.get("role") != item2.get("role"): + return False + # Compare content - handle both string and structured content + content1 = item1.get("content") + content2 = item2.get("content") + if isinstance(content1, str) and isinstance(content2, str): + return content1 == content2 + if isinstance(content1, list) and isinstance(content2, list): + # For structured content, compare text parts + text1 = "".join(p.get("text", "") for p in content1 if isinstance(p, dict)) + text2 = "".join(p.get("text", "") for p in content2 if isinstance(p, dict)) + return text1 == text2 + return content1 == content2 + + async def _create_openai_client(self) -> AsyncOpenAI: + """Create an AsyncOpenAI client for conversation operations. + + :return: Configured AsyncOpenAI client scoped to the Foundry project endpoint. + :rtype: AsyncOpenAI + """ + from openai import AsyncOpenAI + + token_provider = get_bearer_token_provider( + self.credentials, "https://ai.azure.com/.default" + ) + token = await token_provider() + return AsyncOpenAI( + base_url=f"{self._project_endpoint}/openai", + api_key=token, + default_query={"api-version": "2025-11-15-preview"}, + ) + + async def _save_input_to_conversation(self, context: AgentRunContext) -> None: + """Persist request input items when storage is enabled on the request. + + :param context: Agent run context containing the request payload and conversation metadata. 
+ :type context: AgentRunContext + :return: None + :rtype: None + """ + try: + conversation_id = context.conversation_id + input_items = context.request.get("input", []) + if not input_items or not conversation_id: + return + + # Handle string input as a single item + if isinstance(input_items, str): + input_items = [input_items] + + # Convert input items to the format expected by the API + items_to_save = [] + for item in input_items: + if isinstance(item, str): + items_to_save.append({"type": "message", "role": "user", "content": item}) + elif isinstance(item, dict): + items_to_save.append(item) + elif hasattr(item, 'as_dict'): + items_to_save.append(item.as_dict()) + else: + items_to_save.append({"type": "message", "role": "user", "content": str(item)}) + + if not items_to_save: + return + + openai_client = await self._create_openai_client() + + # Check for duplicates by comparing the last N historical items with current N items + try: + historical_items = [] + async for item in openai_client.conversations.items.list(conversation_id): + historical_items.append(item) + # API returns items in reverse order (newest first), so reverse to get chronological order + historical_items.reverse() + + n = len(items_to_save) + if len(historical_items) >= n: + # Get last N historical items (in chronological order) + last_n_historical = historical_items[-n:] + # Compare as a whole - all N items must match in order + all_match = True + for i in range(n): + hist_dict = last_n_historical[i].model_dump() \ + if hasattr(last_n_historical[i], 'model_dump') \ + else dict(last_n_historical[i]) + if not self._items_are_equal(hist_dict, items_to_save[i]): + all_match = False + break + if all_match: + logger.debug( + "All %d input items already exist in conversation %s, skipping save", + n, + conversation_id, + ) + return + except Exception as e: # pylint: disable=broad-exception-caught # best-effort duplicate check + logger.debug("Could not check for duplicates: %s", e) + + await 
openai_client.conversations.items.create( + conversation_id=conversation_id, + items=items_to_save, + ) + logger.debug("Saved %d input items to conversation %s", len(items_to_save), conversation_id) + except Exception as e: # pylint: disable=broad-exception-caught # best-effort conversation persistence + logger.warning("Failed to save input items to conversation: %s", e, exc_info=True) + + async def _save_output_to_conversation( + self, context: AgentRunContext, response: project_models.Response) -> None: + """ + Save output items from a non-streaming response to the conversation. + + :param context: The agent run context containing conversation information. + :type context: AgentRunContext + :param response: The response object containing output items to save. + :type response: project_models.Response + :return: None + :rtype: None + """ + try: + conversation_id = context.conversation_id + output_items = response.get("output", []) + if not output_items: + return + + # Convert output items to the format expected by the API + items_to_save = [] + for item in output_items: + if isinstance(item, dict): + items_to_save.append(item) + elif hasattr(item, 'as_dict'): + items_to_save.append(item.as_dict()) + else: + items_to_save.append(dict(item)) + + openai_client = await self._create_openai_client() + await openai_client.conversations.items.create( + conversation_id=conversation_id, + items=items_to_save, + ) + logger.debug("Saved %d output items to conversation %s", len(items_to_save), conversation_id) + except Exception as e: # pylint: disable=broad-exception-caught # best-effort conversation persistence + logger.warning("Failed to save output items to conversation: %s", e, exc_info=True) + + async def _save_output_events_to_conversation(self, context: AgentRunContext, events: list) -> None: + """Persist streaming output events for later retrieval. + + :param context: Agent run context containing conversation identifiers. 
+ :type context: AgentRunContext + :param events: Response stream events captured during execution. + :type events: list + :return: None + :rtype: None + """ + try: + conversation_id = context.conversation_id + # Extract completed items from ResponseOutputItemDoneEvent + items_to_save = [] + for event in events: + if hasattr(event, 'type') and event.type == 'response.output_item.done': + item = getattr(event, 'item', None) + if item: + if isinstance(item, dict): + items_to_save.append(item) + elif hasattr(item, 'as_dict'): + items_to_save.append(item.as_dict()) + else: + items_to_save.append(dict(item)) + + if not items_to_save: + return + + openai_client = await self._create_openai_client() + await openai_client.conversations.items.create( + conversation_id=conversation_id, + items=items_to_save, + ) + logger.debug("Saved %d output items to conversation %s", len(items_to_save), conversation_id) + except Exception as e: # pylint: disable=broad-exception-caught # best-effort conversation persistence + logger.warning("Failed to save output items to conversation: %s", e, exc_info=True) + + @abstractmethod + async def agent_run( + self, context: AgentRunContext + ) -> Union[OpenAIResponse, Generator[ResponseStreamEvent, Any, Any], AsyncGenerator[ResponseStreamEvent, Any]]: + raise NotImplementedError + + async def respond_with_oauth_consent(self, context, error) -> project_models.Response: + """Generate a response indicating that OAuth consent is required. + + :param context: The agent run context. + :type context: AgentRunContext + :param error: The OAuthConsentRequiredError instance. + :type error: OAuthConsentRequiredError + :return: A Response indicating the need for OAuth consent. 
+ :rtype: project_models.Response + """ + output = [ + project_models.OAuthConsentRequestItemResource( + id=context.id_generator.generate_oauthreq_id(), + consent_link=error.consent_url, + server_label="server_label" + ) + ] + agent_id = context.get_agent_id_object() + conversation = context.get_conversation_object() + response = project_models.Response({ + "object": "response", + "id": context.response_id, + "agent": agent_id, + "conversation": conversation, + "metadata": context.request.get("metadata"), + "created_at": int(time.time()), + "output": output, + }) + return response + + async def respond_with_oauth_consent_astream(self, context, error) -> AsyncGenerator[ResponseStreamEvent, None]: + """Generate a response stream indicating that OAuth consent is required. + + :param context: The agent run context. + :type context: AgentRunContext + :param error: The OAuthConsentRequiredError instance. + :type error: OAuthConsentRequiredError + :return: An async generator yielding ResponseStreamEvent instances. 
    async def respond_with_oauth_consent_astream(self, context, error) -> AsyncGenerator[ResponseStreamEvent, None]:
        """Generate a response stream indicating that OAuth consent is required.

        Emits the standard lifecycle: created -> in_progress -> output_item.added
        -> oauth_consent_requested -> output_item.done -> completed, with strictly
        increasing sequence numbers.

        :param context: The agent run context.
        :type context: AgentRunContext
        :param error: The OAuthConsentRequiredError instance.
        :type error: OAuthConsentRequiredError
        :return: An async generator yielding ResponseStreamEvent instances.
        :rtype: AsyncGenerator[ResponseStreamEvent, None]
        """
        sequence_number = 0
        agent_id = context.get_agent_id_object()
        conversation = context.get_conversation_object()

        # Event 1: response.created with an empty, in-progress response body.
        response = project_models.Response({
            "object": "response",
            "id": context.response_id,
            "agent": agent_id,
            "conversation": conversation,
            "metadata": context.request.get("metadata"),
            "status": "in_progress",
            "created_at": int(time.time()),
            "output": []
        })
        yield project_models.ResponseCreatedEvent(sequence_number=sequence_number, response=response)
        sequence_number += 1

        # Event 2: response.in_progress (fresh Response object, same shape).
        response = project_models.Response({
            "object": "response",
            "id": context.response_id,
            "agent": agent_id,
            "conversation": conversation,
            "metadata": context.request.get("metadata"),
            "status": "in_progress",
            "created_at": int(time.time()),
            "output": []
        })
        yield project_models.ResponseInProgressEvent(sequence_number=sequence_number, response=response)

        sequence_number += 1
        output_index = 0
        oauth_id = context.id_generator.generate_oauthreq_id()
        # NOTE(review): server_label is hard-coded; presumably it should come from
        # the tool/server that raised the consent error — confirm.
        item = project_models.OAuthConsentRequestItemResource({
            "id": oauth_id,
            "type": "oauth_consent_request",
            "consent_link": error.consent_url,
            "server_label": "server_label",
        })
        yield project_models.ResponseOutputItemAddedEvent(sequence_number=sequence_number,
                                                          output_index=output_index, item=item)
        sequence_number += 1
        # Custom (non-standard) event advertising the consent link itself.
        yield project_models.ResponseStreamEvent({
            "sequence_number": sequence_number,
            "output_index": output_index,
            "id": oauth_id,
            "type": "response.oauth_consent_requested",
            "consent_link": error.consent_url,
            "server_label": "server_label",
        })

        sequence_number += 1
        yield project_models.ResponseOutputItemDoneEvent(sequence_number=sequence_number,
                                                         output_index=output_index, item=item)
        sequence_number += 1
        output = [
            project_models.OAuthConsentRequestItemResource(
                id=oauth_id,
                consent_link=error.consent_url,
                server_label="server_label"
            )
        ]

        # Final event: completed response carrying the consent-request item.
        response = project_models.Response({
            "object": "response",
            "id": context.response_id,
            "agent": agent_id,
            "conversation": conversation,
            "metadata": context.request.get("metadata"),
            "created_at": int(time.time()),
            "status": "completed",
            "output": output,
        })
        yield project_models.ResponseCompletedEvent(sequence_number=sequence_number, response=response)

    async def agent_liveness(self, request) -> Union[Response, dict]:  # pylint: disable=unused-argument
        # Default liveness probe: plain 200 with no body. Subclasses may override.
        return Response(status_code=200)

    async def agent_readiness(self, request) -> Union[Response, dict]:  # pylint: disable=unused-argument
        # Default readiness probe: JSON status payload. Subclasses may override.
        return {"status": "ready"}

    async def run_async(
        self,
        port: int = int(os.environ.get("DEFAULT_AD_PORT", 8088)),
    ) -> None:
        """
        Awaitable server starter for use **inside** an existing event loop.

        :param port: Port to listen on.
        :type port: int
        """
        self.init_tracing()
        config = uvicorn.Config(self.app, host="0.0.0.0", port=port, loop="asyncio")
        server = uvicorn.Server(config)
        self._port = port
        logger.info("Starting FoundryCBAgent server async on port %s", port)
        await server.serve()

    def run(self, port: int = int(os.environ.get("DEFAULT_AD_PORT", 8088))) -> None:
        """
        Start a blocking Starlette/uvicorn server on 0.0.0.0 exposing:
            POST /runs
            POST /responses
            GET /liveness
            GET /readiness

        :param port: Port to listen on.
        :type port: int
        """
        self.init_tracing()
        self._port = port
        logger.info("Starting FoundryCBAgent server on port %s", port)
        uvicorn.run(self.app, host="0.0.0.0", port=port)
    def init_tracing(self):
        """Configure OpenTelemetry tracing from environment settings.

        A TracerProvider is installed only when an OTLP endpoint or an
        Application Insights connection string is configured; otherwise the
        default (no-op) provider stays in place.
        """
        exporter = os.environ.get(Constants.OTEL_EXPORTER_ENDPOINT)
        app_insights_conn_str = os.environ.get(APPINSIGHT_CONNSTR_ENV_NAME)
        if exporter or app_insights_conn_str:
            # Imported lazily so the SDK dependency is only needed when tracing is on.
            from opentelemetry.sdk.resources import Resource
            from opentelemetry.sdk.trace import TracerProvider

            resource = Resource.create(self.get_trace_attributes())
            provider = TracerProvider(resource=resource)
            if exporter:
                self.setup_otlp_exporter(exporter, provider)
            if app_insights_conn_str:
                self.setup_application_insights_exporter(app_insights_conn_str, provider)
            trace.set_tracer_provider(provider)
        # Subclass hook runs regardless of whether a provider was installed.
        self.init_tracing_internal(exporter_endpoint=exporter, app_insights_conn_str=app_insights_conn_str)
        self.tracer = trace.get_tracer(__name__)

    def get_trace_attributes(self):
        """Return resource attributes attached to all emitted spans."""
        return {
            "service.name": "azure.ai.agentserver",
        }

    def init_tracing_internal(  # pylint: disable=unused-argument # base class hook, params used by subclasses
        self, exporter_endpoint=None, app_insights_conn_str=None
    ):
        """Extension point for subclasses, invoked at the end of init_tracing."""
        pass

    def setup_application_insights_exporter(self, connection_string, provider):
        """Attach an Azure Monitor (Application Insights) span exporter to *provider*."""
        from opentelemetry.sdk.trace.export import BatchSpanProcessor

        from azure.monitor.opentelemetry.exporter import AzureMonitorTraceExporter

        exporter_instance = AzureMonitorTraceExporter.from_connection_string(connection_string)
        processor = BatchSpanProcessor(exporter_instance)
        provider.add_span_processor(processor)
        logger.info("Tracing setup with Application Insights exporter.")

    def setup_otlp_exporter(self, endpoint, provider):
        """Attach an OTLP/HTTP span exporter targeting *endpoint* to *provider*."""
        from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
        from opentelemetry.sdk.trace.export import BatchSpanProcessor

        exporter_instance = OTLPSpanExporter(endpoint=endpoint)
        processor = BatchSpanProcessor(exporter_instance)
        provider.add_span_processor(processor)
        logger.info("Tracing setup with OTLP exporter: %s", endpoint)

    def create_response_headers(self) -> dict[str, str]:
        """Build the response headers carrying the foundry agents metadata header."""
        headers = {}
        headers.update(build_foundry_agents_metadata_headers())
        return headers
def _event_to_sse_chunk(event: "ResponseStreamEvent") -> str:
    """Serialize a stream event into a Server-Sent-Events chunk.

    Includes an ``event:`` field only when the event carries a type name.

    :param event: The stream event to serialize.
    :return: The SSE-formatted chunk.
    :rtype: str
    """
    payload = json.dumps(event.as_dict())
    event_name = event.type
    if not event_name:
        return f"data: {payload}\n\n"
    return f"event: {event_name}\ndata: {payload}\n\n"


def _keep_alive_comment() -> str:
    """Generate a keep-alive SSE comment to maintain connection.

    :return: The keep-alive comment string.
    :rtype: str
    """
    return ": keep-alive\n\n"
async def _iter_with_keep_alive(
    it: "AsyncGenerator[ResponseStreamEvent, None]"
) -> "AsyncGenerator[Optional[ResponseStreamEvent], None]":
    """Wrap an async iterator with keep-alive mechanism.

    If no event is received within KEEP_ALIVE_INTERVAL seconds,
    yields None as a signal to send a keep-alive comment.
    The original iterator is protected with asyncio.shield to ensure
    it continues running even when timeout occurs.

    :param it: The async generator to wrap.
    :type it: AsyncGenerator[ResponseStreamEvent, None]
    :return: An async generator that yields events or None for keep-alive.
    :rtype: AsyncGenerator[Optional[ResponseStreamEvent], None]
    """
    it_anext = it.__anext__
    pending_task: Optional[asyncio.Task] = None

    while True:
        try:
            # If there's a pending task from a previous timeout, wait for it first
            if pending_task is not None:
                event = await pending_task
                pending_task = None
                yield event
                continue

            # Create a task for the next event
            next_event_task = asyncio.create_task(it_anext())

            try:
                # Shield the task so a timeout does not cancel the underlying fetch
                event = await asyncio.wait_for(
                    asyncio.shield(next_event_task),
                    timeout=KEEP_ALIVE_INTERVAL
                )
                yield event
            except asyncio.TimeoutError:
                # Timeout occurred, but the task continues due to shield;
                # remember it and emit a keep-alive signal.
                pending_task = next_event_task
                yield None

        except StopAsyncIteration:
            # Iterator exhausted
            break


def _format_error(exc: Exception) -> str:
    """Best-effort human-readable message for *exc*.

    Prefers the exception's own message; falls back to ``repr`` when debug
    errors are enabled, otherwise a generic message tagged with the type name.

    :param exc: The exception to format.
    :return: The formatted message.
    :rtype: str
    """
    message = str(exc)
    if message:
        return message
    if DEBUG_ERRORS:
        return repr(exc)
    # Use the type's __name__: f"{type(exc)}" renders as "<class 'ValueError'>",
    # which leaks an unhelpful class repr into client-visible output.
    return f"{type(exc).__name__}: Internal error"


def _to_response(result: "Union[Response, dict]") -> "Response":
    """Pass Starlette responses through unchanged; wrap plain dicts as JSON."""
    return result if isinstance(result, Response) else JSONResponse(result)
class AgentServerContext(AsyncContextManager["AgentServerContext"]):
    """Process-wide singleton giving access to the agent's tool runtime.

    Constructing an instance registers it as the singleton returned by
    :meth:`get`; the most recently constructed context wins.
    """

    _INSTANCE: ClassVar[Optional["AgentServerContext"]] = None

    def __init__(self, tool_runtime: "FoundryToolRuntime"):
        self._tool_runtime = tool_runtime
        # Register this instance as the class-level singleton.
        type(self)._INSTANCE = self

    @classmethod
    def get(cls) -> "AgentServerContext":
        """Return the active context, raising if none has been created yet."""
        instance = cls._INSTANCE
        if instance is None:
            raise ValueError("AgentServerContext has not been initialized.")
        return instance

    @property
    def tools(self) -> "FoundryToolRuntime":
        """The tool runtime owned by this context."""
        return self._tool_runtime

    async def __aenter__(self) -> "AgentServerContext":
        # Entering the context enters the underlying tool runtime.
        await self._tool_runtime.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self._tool_runtime.__aexit__(exc_type, exc_value, traceback)
HEADER_NAME = "x-aml-foundry-agents-metadata"
METADATA_KEY = "foundry_agents_metadata"


def _metadata_json() -> str:
    """Serialize the current application's metadata dataclass to a JSON string."""
    return json.dumps(asdict(get_current_app()))


def build_foundry_agents_metadata_headers() -> Dict[str, str]:
    """
    Return header dict containing the foundry metadata header.

    :return: A dictionary with the foundry metadata header.
    :rtype: Dict[str, str]
    """
    return {HEADER_NAME: _metadata_json()}


def attach_foundry_metadata_to_response(response: "OpenAIResponse") -> None:
    """
    Attach metadata into ``response.metadata[METADATA_KEY]``.

    :param response: The OpenAIResponse object to attach metadata to.
    :type response: OpenAIResponse
    :return: None
    :rtype: None
    """
    metadata = response.metadata
    if not metadata:
        metadata = {}
    metadata[METADATA_KEY] = _metadata_json()
    response.metadata = metadata


def try_attach_foundry_metadata_to_event(event: "ResponseStreamEvent") -> None:
    """
    Attach metadata to supported stream events; skip others.

    :param event: The ResponseStreamEvent object to attach metadata to.
    :type event: ResponseStreamEvent
    :return: None
    :rtype: None
    """
    # Only lifecycle events that embed a full Response carry the metadata.
    supported = (ResponseCreatedEvent, ResponseInProgressEvent, ResponseCompletedEvent)
    if isinstance(event, supported):
        attach_foundry_metadata_to_response(event.response)
super().__init__(app) - - async def dispatch(self, request: Request, call_next): - if request.url.path in ("/runs", "/responses"): - try: - self.set_request_id_to_context_var(request) - payload = await request.json() - except Exception as e: - logger.error(f"Invalid JSON payload: {e}") - return JSONResponse({"error": f"Invalid JSON payload: {e}"}, status_code=400) - try: - request.state.agent_run_context = AgentRunContext(payload) - self.set_run_context_to_context_var(request.state.agent_run_context) - except Exception as e: - logger.error(f"Context build failed: {e}.", exc_info=True) - return JSONResponse({"error": f"Context build failed: {e}"}, status_code=500) - return await call_next(request) - - def set_request_id_to_context_var(self, request): - request_id = request.headers.get("X-Request-Id", None) - if request_id: - ctx = request_context.get() or {} - ctx["azure.ai.agentserver.x-request-id"] = request_id - request_context.set(ctx) - - def set_run_context_to_context_var(self, run_context): - agent_id, agent_name = "", "" - agent_obj = run_context.get_agent_id_object() - if agent_obj: - agent_name = getattr(agent_obj, "name", "") - agent_version = getattr(agent_obj, "version", "") - agent_id = f"{agent_name}:{agent_version}" - - res = { - "azure.ai.agentserver.response_id": run_context.response_id or "", - "azure.ai.agentserver.conversation_id": run_context.conversation_id or "", - "azure.ai.agentserver.streaming": str(run_context.stream or False), - "gen_ai.agent.id": agent_id, - "gen_ai.agent.name": agent_name, - "gen_ai.provider.name": "AzureAI Hosted Agents", - "gen_ai.response.id": run_context.response_id or "", - } - ctx = request_context.get() or {} - ctx.update(res) - request_context.set(ctx) - - -class FoundryCBAgent: - def __init__(self): - async def runs_endpoint(request): - # Set up tracing context and span - context = request.state.agent_run_context - ctx = request_context.get() - with self.tracer.start_as_current_span( - 
name=f"HostedAgents-{context.response_id}", - attributes=ctx, - kind=trace.SpanKind.SERVER, - ): - try: - logger.info("Start processing CreateResponse request:") - - context_carrier = {} - TraceContextTextMapPropagator().inject(context_carrier) - - resp = await self.agent_run(context) - - if inspect.isgenerator(resp): - # Prefetch first event to allow 500 status if generation fails immediately - try: - first_event = next(resp) - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Generator initialization failed: %s\n%s", e, traceback.format_exc()) - return JSONResponse({"error": err_msg}, status_code=500) - - def gen(): - ctx = TraceContextTextMapPropagator().extract(carrier=context_carrier) - token = otel_context.attach(ctx) - error_sent = False - try: - # yield prefetched first event - yield _event_to_sse_chunk(first_event) - for event in resp: - yield _event_to_sse_chunk(event) - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Error in non-async generator: %s\n%s", e, traceback.format_exc()) - payload = {"error": err_msg} - yield f"event: error\ndata: {json.dumps(payload)}\n\n" - yield "data: [DONE]\n\n" - error_sent = True - finally: - logger.info("End of processing CreateResponse request:") - otel_context.detach(token) - if not error_sent: - yield "data: [DONE]\n\n" - - return StreamingResponse(gen(), media_type="text/event-stream") - if inspect.isasyncgen(resp): - # Prefetch first async event to allow early 500 - try: - first_event = await resp.__anext__() - except StopAsyncIteration: - # No items produced; treat as empty successful stream - def empty_gen(): - yield "data: [DONE]\n\n" - - return StreamingResponse(empty_gen(), media_type="text/event-stream") - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Async generator initialization failed: %s\n%s", e, 
traceback.format_exc()) - return JSONResponse({"error": err_msg}, status_code=500) - - async def gen_async(): - ctx = TraceContextTextMapPropagator().extract(carrier=context_carrier) - token = otel_context.attach(ctx) - error_sent = False - try: - # yield prefetched first event - yield _event_to_sse_chunk(first_event) - async for event in resp: - yield _event_to_sse_chunk(event) - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Error in async generator: %s\n%s", e, traceback.format_exc()) - payload = {"error": err_msg} - yield f"event: error\ndata: {json.dumps(payload)}\n\n" - yield "data: [DONE]\n\n" - error_sent = True - finally: - logger.info("End of processing CreateResponse request.") - otel_context.detach(token) - if not error_sent: - yield "data: [DONE]\n\n" - - return StreamingResponse(gen_async(), media_type="text/event-stream") - logger.info("End of processing CreateResponse request.") - return JSONResponse(resp.as_dict()) - except Exception as e: - # TODO: extract status code from exception - logger.error(f"Error processing CreateResponse request: {traceback.format_exc()}") - return JSONResponse({"error": str(e)}, status_code=500) - - async def liveness_endpoint(request): - result = await self.agent_liveness(request) - return _to_response(result) - - async def readiness_endpoint(request): - result = await self.agent_readiness(request) - return _to_response(result) - - routes = [ - Route("/runs", runs_endpoint, methods=["POST"], name="agent_run"), - Route("/responses", runs_endpoint, methods=["POST"], name="agent_response"), - Route("/liveness", liveness_endpoint, methods=["GET"], name="agent_liveness"), - Route("/readiness", readiness_endpoint, methods=["GET"], name="agent_readiness"), - ] - - self.app = Starlette(routes=routes) - self.app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - 
self.app.add_middleware(AgentRunContextMiddleware) - - @self.app.on_event("startup") - async def attach_appinsights_logger(): - import logging - - for handler in logger.handlers: - if handler.name == "appinsights_handler": - for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]: - uv_logger = logging.getLogger(logger_name) - uv_logger.addHandler(handler) - uv_logger.setLevel(logger.level) - uv_logger.propagate = False - - self.tracer = None - - @abstractmethod - async def agent_run( - self, context: AgentRunContext - ) -> Union[OpenAIResponse, Generator[ResponseStreamEvent, Any, Any], AsyncGenerator[ResponseStreamEvent, Any]]: - raise NotImplementedError - - async def agent_liveness(self, request) -> Union[Response, dict]: - return Response(status_code=200) - - async def agent_readiness(self, request) -> Union[Response, dict]: - return {"status": "ready"} - - async def run_async( - self, - port: int = int(os.environ.get("DEFAULT_AD_PORT", 8088)), - ) -> None: - """ - Awaitable server starter for use **inside** an existing event loop. - - :param port: Port to listen on. - :type port: int - """ - self.init_tracing() - config = uvicorn.Config(self.app, host="0.0.0.0", port=port, loop="asyncio") - server = uvicorn.Server(config) - logger.info(f"Starting FoundryCBAgent server async on port {port}") - await server.serve() - - def run(self, port: int = int(os.environ.get("DEFAULT_AD_PORT", 8088))) -> None: - """ - Start a Starlette server on localhost: exposing: - POST /runs - POST /responses - GET /liveness - GET /readiness - - :param port: Port to listen on. 
- :type port: int - """ - self.init_tracing() - logger.info(f"Starting FoundryCBAgent server on port {port}") - uvicorn.run(self.app, host="0.0.0.0", port=port) - - def init_tracing(self): - exporter = os.environ.get(Constants.OTEL_EXPORTER_ENDPOINT) - app_insights_conn_str = os.environ.get(Constants.APPLICATION_INSIGHTS_CONNECTION_STRING) - if exporter or app_insights_conn_str: - from opentelemetry.sdk.resources import Resource - from opentelemetry.sdk.trace import TracerProvider - - resource = Resource.create(self.get_trace_attributes()) - provider = TracerProvider(resource=resource) - if exporter: - self.setup_otlp_exporter(exporter, provider) - if app_insights_conn_str: - self.setup_application_insights_exporter(app_insights_conn_str, provider) - trace.set_tracer_provider(provider) - self.init_tracing_internal(exporter_endpoint=exporter, app_insights_conn_str=app_insights_conn_str) - self.tracer = trace.get_tracer(__name__) - - def get_trace_attributes(self): - return { - "service.name": "azure.ai.agentserver", - } - - def init_tracing_internal(self, exporter_endpoint=None, app_insights_conn_str=None): - pass - - def setup_application_insights_exporter(self, connection_string, provider): - from opentelemetry.sdk.trace.export import BatchSpanProcessor - - from azure.monitor.opentelemetry.exporter import AzureMonitorTraceExporter - - exporter_instance = AzureMonitorTraceExporter.from_connection_string(connection_string) - processor = BatchSpanProcessor(exporter_instance) - provider.add_span_processor(processor) - logger.info("Tracing setup with Application Insights exporter.") - - def setup_otlp_exporter(self, endpoint, provider): - from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter - from opentelemetry.sdk.trace.export import BatchSpanProcessor - - exporter_instance = OTLPSpanExporter(endpoint=endpoint) - processor = BatchSpanProcessor(exporter_instance) - provider.add_span_processor(processor) - logger.info(f"Tracing setup with 
OTLP exporter: {endpoint}") - - -def _event_to_sse_chunk(event: ResponseStreamEvent) -> str: - event_data = json.dumps(event.as_dict()) - if event.type: - return f"event: {event.type}\ndata: {event_data}\n\n" - return f"data: {event_data}\n\n" - - -def _to_response(result: Union[Response, dict]) -> Response: - return result if isinstance(result, Response) else JSONResponse(result) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/_agent_run_context.py similarity index 71% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/_agent_run_context.py index 6fae56f0027d..750e4209d9e5 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/_agent_run_context.py @@ -1,17 +1,22 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- +from typing import Optional + +from .id_generator._foundry_id_generator import FoundryIdGenerator +from .id_generator._id_generator import IdGenerator from ...logger import get_logger from ...models import CreateResponse -from ...models.projects import AgentId, AgentReference, ResponseConversation1 -from .id_generator.foundry_id_generator import FoundryIdGenerator -from .id_generator.id_generator import IdGenerator +from ...models._projects import AgentId, AgentReference, ResponseConversation1 logger = get_logger() class AgentRunContext: - def __init__(self, payload: dict): + """ + :meta private: + """ + def __init__(self, payload: dict) -> None: self._raw_payload = payload self._request = _deserialize_create_response(payload) self._id_generator = FoundryIdGenerator.from_request(payload) @@ -36,17 +41,17 @@ def response_id(self) -> str: return self._response_id @property - def conversation_id(self) -> str: + def conversation_id(self) -> Optional[str]: return self._conversation_id @property def stream(self) -> bool: return self._stream - def get_agent_id_object(self) -> AgentId: + def get_agent_id_object(self) -> Optional[AgentId]: agent = self.request.get("agent") if not agent: - return None # type: ignore + return None return AgentId( { "type": agent.type, @@ -55,9 +60,9 @@ def get_agent_id_object(self) -> AgentId: } ) - def get_conversation_object(self) -> ResponseConversation1: + def get_conversation_object(self) -> Optional[ResponseConversation1]: if not self._conversation_id: - return None # type: ignore + return None return ResponseConversation1(id=self._conversation_id) @@ -67,10 +72,14 @@ def _deserialize_create_response(payload: dict) -> CreateResponse: raw_agent_reference = payload.get("agent") if raw_agent_reference: _deserialized["agent"] = _deserialize_agent_reference(raw_agent_reference) + + tools = payload.get("tools") + if tools: + _deserialized["tools"] = list(tools) return _deserialized 
-def _deserialize_agent_reference(payload: dict) -> AgentReference: +def _deserialize_agent_reference(payload: dict) -> Optional[AgentReference]: if not payload: - return None # type: ignore + return None return AgentReference(**payload) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/_constants.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/_constants.py new file mode 100644 index 000000000000..6d4fb628a7f2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/_constants.py @@ -0,0 +1,6 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# Reserved function name for HITL. +HUMAN_IN_THE_LOOP_FUNCTION_NAME = "__hosted_agent_adapter_hitl__" diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/_foundry_id_generator.py similarity index 53% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/_foundry_id_generator.py index 910a7c481daa..0c0f91cbb36d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/_foundry_id_generator.py @@ -1,4 +1,3 @@ -# pylint: disable=docstring-missing-return,docstring-missing-param,docstring-missing-rtype # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# --------------------------------------------------------- @@ -9,7 +8,7 @@ import re from typing import Optional -from .id_generator import IdGenerator +from ._id_generator import IdGenerator _WATERMARK_RE = re.compile(r"^[A-Za-z0-9]*$") @@ -26,23 +25,48 @@ class FoundryIdGenerator(IdGenerator): """ def __init__(self, response_id: Optional[str], conversation_id: Optional[str]): + """Initialize the ID generator. + + :param response_id: An existing response ID, or ``None`` to generate one. + :type response_id: Optional[str] + :param conversation_id: An existing conversation ID, or ``None``. + :type conversation_id: Optional[str] + """ self.response_id = response_id or self._new_id("resp") - self.conversation_id = conversation_id or self._new_id("conv") - self._partition_id = self._extract_partition_id(self.conversation_id) + self.conversation_id = conversation_id + partition_source = self.conversation_id or self.response_id + try: + self._partition_id = self._extract_partition_id(partition_source) + except ValueError: + self._partition_id = self._secure_entropy(18) @classmethod def from_request(cls, payload: dict) -> "FoundryIdGenerator": + """Create a generator from an incoming request payload. + + :param payload: The raw request payload dictionary. + :type payload: dict + :return: A configured :class:`FoundryIdGenerator` instance. + :rtype: FoundryIdGenerator + """ response_id = payload.get("metadata", {}).get("response_id", None) conv_id_raw = payload.get("conversation", None) if isinstance(conv_id_raw, str): conv_id = conv_id_raw elif isinstance(conv_id_raw, dict): - conv_id = conv_id_raw.get("id", None) + conv_id = conv_id_raw.get("id", None) # type: ignore[assignment] else: conv_id = None return cls(response_id, conv_id) def generate(self, category: Optional[str] = None) -> str: + """Generate a new unique ID for the given category. + + :param category: Optional prefix category (e.g. ``"msg"``, ``"func"``). Defaults to ``"id"``. 
+ :type category: Optional[str] + :return: The generated unique identifier string. + :rtype: str + """ prefix = "id" if not category else category return self._new_id(prefix, partition_key=self._partition_id) @@ -59,12 +83,29 @@ def _new_id( partition_key: Optional[str] = None, partition_key_hint: str = "", ) -> str: - """ - Generates a new ID. - - Format matches the C# logic: - f"{prefix}{delimiter}{infix}{partitionKey}{entropy}" - (i.e., exactly one delimiter after prefix; no delimiter between entropy and partition key) + """Generate a new ID matching the C# FoundryIdGenerator format. + + Format: ``"{prefix}{delimiter}{infix}{partitionKey}{entropy}"`` + + :param prefix: The ID prefix (e.g. ``"resp"``, ``"msg"``). + :type prefix: str + :param string_length: Length of the random entropy portion. + :type string_length: int + :param partition_key_length: Length of the partition key. + :type partition_key_length: int + :param infix: Optional infix inserted between delimiter and partition key. + :type infix: Optional[str] + :param watermark: Optional alphanumeric watermark inserted mid-entropy. + :type watermark: str + :param delimiter: Delimiter between prefix and the rest of the ID. + :type delimiter: str + :param partition_key: Explicit partition key; if ``None``, derived or generated. + :type partition_key: Optional[str] + :param partition_key_hint: ID string to extract a partition key from. + :type partition_key_hint: str + :return: The generated ID string. + :rtype: str + :raises ValueError: If the watermark contains non-alphanumeric characters. """ entropy = FoundryIdGenerator._secure_entropy(string_length) @@ -88,14 +129,20 @@ def _new_id( infix = infix or "" prefix_part = f"{prefix}{delimiter}" if prefix else "" - return f"{prefix_part}{entropy}{infix}{pkey}" + return f"{prefix_part}{infix}{pkey}{entropy}" @staticmethod def _secure_entropy(string_length: int) -> str: - """ - Generates a secure random alphanumeric string of exactly `string_length`. 
- Re-tries whole generation until the filtered base64 string is exactly the desired length, - matching the C# behavior. + """Generate a cryptographically secure alphanumeric string. + + Uses :func:`os.urandom` and base64 encoding, filtering to alphanumeric + characters and retrying until the exact length is reached. + + :param string_length: Desired length of the output string. + :type string_length: int + :return: A random alphanumeric string of exactly *string_length* characters. + :rtype: str + :raises ValueError: If *string_length* is less than 1. """ if string_length < 1: raise ValueError("Must be greater than or equal to 1") @@ -116,11 +163,22 @@ def _extract_partition_id( partition_key_length: int = 18, delimiter: str = "_", ) -> str: - """ - Extracts partition key from an existing ID. - - Expected shape (per C# logic): "_" - We take the last `partition_key_length` characters from the *second* segment. + """Extract the partition key from an existing ID. + + Expected shape: ``"_"``. + Returns the first *partition_key_length* characters of the second segment. + + :param id_str: The ID string to extract from. + :type id_str: str + :param string_length: Expected entropy length used for validation. + :type string_length: int + :param partition_key_length: Number of characters to extract as partition key. + :type partition_key_length: int + :param delimiter: The delimiter separating ID segments. + :type delimiter: str + :return: The extracted partition key. + :rtype: str + :raises ValueError: If the ID format is invalid. 
""" if not id_str: raise ValueError("Id cannot be null or empty") @@ -133,4 +191,4 @@ def _extract_partition_id( if len(segment) < string_length + partition_key_length: raise ValueError(f"Id '{id_str}' does not contain a valid id.") - return segment[-partition_key_length:] + return segment[:partition_key_length] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/_id_generator.py similarity index 87% rename from sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/id_generator.py rename to sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/_id_generator.py index 48f0d9add17d..5b602a7fc686 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/id_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/_id_generator.py @@ -17,3 +17,6 @@ def generate_function_output_id(self) -> str: def generate_message_id(self) -> str: return self.generate("msg") + + def generate_oauthreq_id(self) -> str: + return self.generate("oauthreq") diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py new file mode 100644 index 000000000000..fa58e50368bf --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py @@ -0,0 +1,82 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) + +from .client._client import FoundryToolClient +from ._exceptions import ( + ToolInvocationError, + OAuthConsentRequiredError, + UnableToResolveToolInvocationError, + InvalidToolFacadeError, +) +from .client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryTool, + FoundryToolDetails, + FoundryToolProtocol, + FoundryToolSource, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, + UserInfo, +) +from .runtime._catalog import ( + FoundryToolCatalog, + CachedFoundryToolCatalog, + DefaultFoundryToolCatalog, +) +from .runtime._facade import FoundryToolFacade, FoundryToolLike, ensure_foundry_tool +from .runtime._invoker import FoundryToolInvoker, DefaultFoundryToolInvoker +from .runtime._resolver import FoundryToolInvocationResolver, DefaultFoundryToolInvocationResolver +from .runtime._runtime import create_tool_runtime, FoundryToolRuntime, DefaultFoundryToolRuntime +from .runtime._starlette import UserInfoContextMiddleware +from .runtime._user import UserProvider, ContextVarUserProvider + +__all__ = [ + # Client + "FoundryToolClient", + # Exceptions + "ToolInvocationError", + "OAuthConsentRequiredError", + "UnableToResolveToolInvocationError", + "InvalidToolFacadeError", + # Models + "FoundryConnectedTool", + "FoundryHostedMcpTool", + "FoundryTool", + "FoundryToolDetails", + "FoundryToolProtocol", + "FoundryToolSource", + "ResolvedFoundryTool", + "SchemaDefinition", + "SchemaProperty", + "SchemaType", + "UserInfo", + # Catalog + "FoundryToolCatalog", + "CachedFoundryToolCatalog", + "DefaultFoundryToolCatalog", + # Facade + "FoundryToolFacade", + "FoundryToolLike", + "ensure_foundry_tool", + # Invoker + "FoundryToolInvoker", + "DefaultFoundryToolInvoker", + # Resolver + "FoundryToolInvocationResolver", + "DefaultFoundryToolInvocationResolver", + # Runtime + "create_tool_runtime", + 
"FoundryToolRuntime", + "DefaultFoundryToolRuntime", + # Starlette + "UserInfoContextMiddleware", + # User + "UserProvider", + "ContextVarUserProvider", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py new file mode 100644 index 000000000000..a5fe7726e9f1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py @@ -0,0 +1,74 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .client._models import FoundryTool, ResolvedFoundryTool + + +class ToolInvocationError(RuntimeError): + """Raised when a tool invocation fails. + + :ivar ResolvedFoundryTool tool: The tool that failed during invocation. + + :param str message: Human-readable message describing the error. + :param ResolvedFoundryTool tool: The tool that failed during invocation. + + This exception is raised when an error occurs during the invocation of a tool, + providing details about the failure. + """ + + def __init__(self, message: str, tool: ResolvedFoundryTool): + super().__init__(message) + self.tool = tool + + +class OAuthConsentRequiredError(RuntimeError): + """Raised when the service requires end-user OAuth consent. + + This exception is raised when a tool or service operation requires explicit + OAuth consent from the end user before the operation can proceed. + + :ivar str message: Human-readable guidance returned by the service. + :ivar str consent_url: Link that the end user must visit to provide consent. + :ivar str project_connection_id: The project connection ID related to the consent request. + + :param str message: Human-readable guidance returned by the service. 
+ :param str consent_url: Link that the end user must visit to provide the required consent. + :param str project_connection_id: The project connection ID related to the consent request. + """ + + def __init__(self, message: str, consent_url: str, project_connection_id: str): + super().__init__(message) + self.message = message + self.consent_url = consent_url + self.project_connection_id = project_connection_id + + +class UnableToResolveToolInvocationError(RuntimeError): + """Raised when a tool cannot be resolved. + + :ivar str message: Human-readable message describing the error. + :ivar FoundryTool tool: The tool that could not be resolved. + + :param str message: Human-readable message describing the error. + :param FoundryTool tool: The tool that could not be resolved. + + This exception is raised when a tool cannot be found or resolved + from the available tool sources. + """ + + def __init__(self, message: str, tool: FoundryTool): + super().__init__(message) + self.tool = tool + + +class InvalidToolFacadeError(RuntimeError): + """Raised when a tool facade is invalid. + + This exception is raised when a tool facade does not conform + to the expected structure or contains invalid data. + """ diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py new file mode 100644 index 000000000000..0efcf1c6f20b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py @@ -0,0 +1,215 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import asyncio # pylint: disable=C4763 # azure-sdk: async-client-bad-name +import itertools +from collections import defaultdict +from typing import ( + Any, + AsyncContextManager, + Awaitable, Collection, + DefaultDict, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + cast, +) + +from azure.core import AsyncPipelineClient +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.tracing.decorator_async import distributed_trace_async + +from ._configuration import FoundryToolClientConfiguration +from ._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryTool, + FoundryToolDetails, + FoundryToolSource, + ResolvedFoundryTool, + UserInfo, +) +from .operations._foundry_connected_tools import FoundryConnectedToolsOperations +from .operations._foundry_hosted_mcp_tools import FoundryMcpToolsOperations +from .._exceptions import ToolInvocationError + + +class FoundryToolClient(AsyncContextManager["FoundryToolClient"]): # pylint: disable=C4748 # azure-sdk: client-paging-methods-use-list + """Asynchronous client for aggregating tools from Azure AI MCP and Tools APIs. + + This client provides access to tools from both MCP (Model Context Protocol) servers + and Azure AI Tools API endpoints, enabling unified tool discovery and invocation. 
+ + :param endpoint: + The fully qualified endpoint for the Azure AI Agents service. + Example: "https://.api.azureml.ms" + :type endpoint: str + :param credential: + Credential for authenticating requests to the service. + Use credentials from azure-identity like DefaultAzureCredential. + :type credential: ~azure.core.credentials.TokenCredential + :param api_version: The API version to use for this operation. + :type api_version: str or None + """ + + def __init__( # pylint: disable=C4718 # azure-sdk: client-method-name-no-double-underscore + self, + endpoint: str, + credential: "AsyncTokenCredential", + ) -> None: + """Initialize the asynchronous Azure AI Tool Client. + + :param endpoint: The service endpoint URL. + :type endpoint: str + :param credential: Credentials for authenticating requests. + :type credential: ~azure.core.credentials.TokenCredential + :param api_version: The API version to use for this operation. + :type api_version: str or None + """ + # noinspection PyTypeChecker + config = FoundryToolClientConfiguration(credential) + self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=endpoint, config=config) + + self._hosted_mcp_tools = FoundryMcpToolsOperations(self._client) + self._connected_tools = FoundryConnectedToolsOperations(self._client) + + @distributed_trace_async + async def list_tools( + self, + tools: Collection[FoundryTool], + agent_name: str, + user: Optional[UserInfo] = None, + **kwargs: Any + ) -> List[ResolvedFoundryTool]: + """List all available tools from configured sources. + + Retrieves tools from both MCP servers and Azure AI Tools API endpoints, + returning them as ResolvedFoundryTool instances ready for invocation. + + :param tools: Collection of FoundryTool instances to resolve. + :type tools: Collection[~FoundryTool] + :param user: Information about the user requesting the tools. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent requesting the tools. 
+ :type agent_name: str + + :return: List of resolved Foundry tools. + :rtype: List[ResolvedFoundryTool] + :raises ~azure.ai.agentserver.core.tools._exceptions.OAuthConsentRequiredError: + Raised when the service requires user OAuth consent. + :raises ~azure.core.exceptions.HttpResponseError: + Raised for HTTP communication failures. + + """ + _ = kwargs # Reserved for future use + resolved_tools: List[ResolvedFoundryTool] = [] + results = await self._list_tools_details_internal(tools, agent_name, user) + for definition, details in results: + resolved_tools.append(ResolvedFoundryTool(definition=definition, details=details)) + return resolved_tools + + @distributed_trace_async + async def list_tools_details( + self, + tools: Collection[FoundryTool], + agent_name: str, + user: Optional[UserInfo] = None, + **kwargs: Any + ) -> Mapping[str, List[FoundryToolDetails]]: + """List all available tools from configured sources. + + Retrieves tools from both MCP servers and Azure AI Tools API endpoints, + returning them as ResolvedFoundryTool instances ready for invocation. + + :param tools: Collection of FoundryTool instances to resolve. + :type tools: Collection[~FoundryTool] + :param user: Information about the user requesting the tools. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent requesting the tools. + :type agent_name: str + + :return: Mapping of tool IDs to lists of FoundryToolDetails. + :rtype: Mapping[str, List[FoundryToolDetails]] + :raises ~azure.ai.agentserver.core.tools._exceptions.OAuthConsentRequiredError: + Raised when the service requires user OAuth consent. + :raises ~azure.core.exceptions.HttpResponseError: + Raised for HTTP communication failures. 
+ + """ + _ = kwargs # Reserved for future use + resolved_tools: Dict[str, List[FoundryToolDetails]] = defaultdict(list) + results = await self._list_tools_details_internal(tools, agent_name, user) + for definition, details in results: + resolved_tools[definition.id].append(details) + return resolved_tools + + async def _list_tools_details_internal( + self, + tools: Collection[FoundryTool], + agent_name: str, + user: Optional[UserInfo] = None, + ) -> Iterable[Tuple[FoundryTool, FoundryToolDetails]]: + tools_by_source: DefaultDict[FoundryToolSource, List[FoundryTool]] = defaultdict(list) + for t in tools: + tools_by_source[t.source].append(t) + + listing_tools: List[Awaitable[Iterable[Tuple[FoundryTool, FoundryToolDetails]]]] = [] + if FoundryToolSource.HOSTED_MCP in tools_by_source: + hosted_mcp_tools = cast(List[FoundryHostedMcpTool], tools_by_source[FoundryToolSource.HOSTED_MCP]) + listing_tools.append(self._hosted_mcp_tools.list_tools(hosted_mcp_tools)) + if FoundryToolSource.CONNECTED in tools_by_source: + connected_tools = cast(List[FoundryConnectedTool], tools_by_source[FoundryToolSource.CONNECTED]) + listing_tools.append(self._connected_tools.list_tools(connected_tools, user, agent_name)) + iters = await asyncio.gather(*listing_tools) + return itertools.chain.from_iterable(iters) + + @distributed_trace_async + async def invoke_tool( + self, + tool: ResolvedFoundryTool, + arguments: Dict[str, Any], + agent_name: str, + user: Optional[UserInfo] = None, + **kwargs: Any + ) -> Any: + """Invoke a tool by instance, name, or descriptor. + + :param tool: Tool to invoke, specified as an AzureAITool instance, + tool name string, or FoundryTool. + :type tool: ResolvedFoundryTool + :param arguments: Arguments to pass to the tool. + :type arguments: Dict[str, Any] + :param user: Information about the user invoking the tool. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent invoking the tool. 
+ :type agent_name: str + :return: The result of invoking the tool. + :rtype: Any + :raises ~OAuthConsentRequiredError: + Raised when the service requires user OAuth consent. + :raises ~azure.core.exceptions.HttpResponseError: + Raised for HTTP communication failures. + :raises ~ToolInvocationError: + Raised when the tool invocation fails or source is not supported. + + """ + _ = kwargs # Reserved for future use + if tool.source is FoundryToolSource.HOSTED_MCP: + return await self._hosted_mcp_tools.invoke_tool(tool, arguments) + if tool.source is FoundryToolSource.CONNECTED: + return await self._connected_tools.invoke_tool(tool, arguments, user, agent_name) + raise ToolInvocationError(f"Unsupported tool source: {tool.source}", tool=tool) + + async def close(self) -> None: + """Close the underlying HTTP pipeline.""" + await self._client.close() + + async def __aenter__(self) -> "FoundryToolClient": + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details: Any) -> None: + await self._client.__aexit__(*exc_details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py new file mode 100644 index 000000000000..e09c80ed83f8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py @@ -0,0 +1,35 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from azure.core.configuration import Configuration +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.pipeline import policies + +from ...application._metadata import get_current_app + + +class FoundryToolClientConfiguration(Configuration): # pylint: disable=too-many-instance-attributes + """Configuration for Azure AI Tool Client. + + Manages authentication, endpoint configuration, and policy settings for the + Azure AI Tool Client. This class is used internally by the client and should + not typically be instantiated directly. + + :param credential: + Azure TokenCredential for authentication. + :type credential: ~azure.core.credentials.TokenCredential + """ + + def __init__(self, credential: "AsyncTokenCredential"): + super().__init__() + + self.retry_policy = policies.AsyncRetryPolicy() + self.logging_policy = policies.NetworkTraceLoggingPolicy() + self.request_id_policy = policies.RequestIdPolicy() + self.http_logging_policy = policies.HttpLoggingPolicy() + self.user_agent_policy = policies.UserAgentPolicy( + base_user_agent=get_current_app().as_user_agent("FoundryToolClient")) + self.authentication_policy = policies.AsyncBearerTokenCredentialPolicy( + credential, "https://ai.azure.com/.default" + ) + self.redirect_policy = policies.AsyncRedirectPolicy() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py new file mode 100644 index 000000000000..b3a505ae37ae --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py @@ -0,0 +1,615 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + Annotated, + Any, + ClassVar, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Set, + Type, + Union, +) + +from pydantic import ( + AliasChoices, + AliasPath, + BaseModel, + Discriminator, + Field, + ModelWrapValidatorHandler, + Tag, + TypeAdapter, + model_validator, +) + +from azure.core import CaseInsensitiveEnumMeta + +from .._exceptions import OAuthConsentRequiredError + + +class FoundryToolSource(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Identifies the origin of a tool. + + Specifies whether a tool comes from an MCP (Model Context Protocol) server + or from the Azure AI Tools API (remote tools). + """ + + HOSTED_MCP = "hosted_mcp" + CONNECTED = "connected" + + +class FoundryToolProtocol(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Identifies the protocol used by a connected tool.""" + + MCP = "mcp" + A2A = "a2a" + + +@dataclass(frozen=True, eq=False) +class FoundryTool(ABC): + """Definition of a foundry tool including its parameters.""" + source: FoundryToolSource = field(init=False) + + @property + @abstractmethod + def id(self) -> str: + """Unique identifier for the tool. + + :rtype: str + """ + raise NotImplementedError + + def __str__(self): + return self.id + + +@dataclass(frozen=True, eq=False) +class FoundryHostedMcpTool(FoundryTool): + """Foundry MCP tool definition. + + :ivar str name: Name of MCP tool. + :ivar Mapping[str, Any] configuration: Tools configuration. + """ + source: Literal[FoundryToolSource.HOSTED_MCP] = field(init=False, default=FoundryToolSource.HOSTED_MCP) + name: str + configuration: Optional[Mapping[str, Any]] = None + + @property + def id(self) -> str: + """Unique identifier for the tool. 
+ + :rtype: str + """ + return f"{self.source}:{self.name}" + + +@dataclass(frozen=True, eq=False) +class FoundryConnectedTool(FoundryTool): + """Foundry connected tool definition. + + :ivar str project_connection_id: connection name of foundry tool. + """ + source: Literal[FoundryToolSource.CONNECTED] = field(init=False, default=FoundryToolSource.CONNECTED) + protocol: str + project_connection_id: str + + @property + def id(self) -> str: + return f"{self.source}:{self.protocol}:{self.project_connection_id}" + + +@dataclass(frozen=True) +class FoundryToolDetails: + """Details about a Foundry tool. + + :ivar str name: Name of the tool. + :ivar str description: Description of the tool. + :ivar SchemaDefinition input_schema: Input schema for the tool parameters. + :ivar Optional[SchemaDefinition] metadata: Optional metadata schema for the tool. + """ + name: str + description: str + input_schema: "SchemaDefinition" + metadata: Optional["SchemaDefinition"] = None + + +@dataclass(frozen=True) +class ResolvedFoundryTool: + """Resolved Foundry tool with definition and details. + + :ivar ToolDefinition definition: + Optional tool definition object, or None. + :ivar FoundryToolDetails details: + Details about the tool, including name, description, and input schema. + """ + + definition: FoundryTool + details: FoundryToolDetails + + @property + def id(self) -> str: + return f"{self.definition.id}:{self.details.name}" + + @property + def source(self) -> FoundryToolSource: + """Origin of the tool. + + :rtype: FoundryToolSource + """ + return self.definition.source + + @property + def name(self) -> str: + """Name of the tool. + + :rtype: str + """ + return self.details.name + + @property + def description(self) -> str: + """Description of the tool. + + :rtype: str + """ + return self.details.description + + @property + def input_schema(self) -> "SchemaDefinition": + """Input schema of the tool. 
@dataclass(frozen=True)
class UserInfo:
    """Represents user information.

    :ivar str object_id: User's object identifier.
    :ivar str tenant_id: Tenant identifier.
    """

    object_id: str
    tenant_id: str


class SchemaType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
    """
    Enumeration of possible schema types.

    :ivar py_type: The corresponding Python runtime type for this schema type
        (e.g., ``SchemaType.STRING.py_type is str``).
    """

    py_type: Type[Any]
    """The corresponding Python runtime type for this schema type."""

    STRING = ("string", str)
    """Schema type for string values (maps to ``str``)."""

    NUMBER = ("number", float)
    """Schema type for numeric values with decimals (maps to ``float``)."""

    INTEGER = ("integer", int)
    """Schema type for integer values (maps to ``int``)."""

    BOOLEAN = ("boolean", bool)
    """Schema type for boolean values (maps to ``bool``)."""

    ARRAY = ("array", list)
    """Schema type for array values (maps to ``list``)."""

    OBJECT = ("object", dict)
    """Schema type for object/dictionary values (maps to ``dict``)."""

    def __new__(cls, value: str, py_type: Type[Any]):
        """
        Create an enum member whose value is the schema type string, while also
        attaching the mapped Python type.

        :param value: The serialized schema type string (e.g. ``"string"``).
        :type value: str
        :param py_type: The mapped Python runtime type (e.g. ``str``).
        :type py_type: Type[Any]
        :return: The created enum member.
        :rtype: SchemaType
        """
        obj = str.__new__(cls, value)
        # Only the first tuple element is the enum value; the second rides along.
        obj._value_ = value
        obj.py_type = py_type
        return obj

    @classmethod
    def from_python_type(cls, t: Type[Any]) -> "SchemaType":
        """
        Get the matching :class:`SchemaType` for a given Python runtime type.

        :param t: A Python runtime type (e.g. ``str``, ``int``, ``float``).
        :type t: Type[Any]
        :returns: The corresponding :class:`SchemaType`.
        :rtype: SchemaType
        :raises ValueError: If ``t`` is not supported by this enumeration.
        """
        # Identity comparison: bool is not matched by INTEGER even though
        # bool subclasses int, since the mapping uses `is`.
        for member in cls:
            if member.py_type is t:
                return member
        raise ValueError(f"Unsupported python type: {t!r}")


class SchemaProperty(BaseModel):
    """
    A JSON Schema-like description of a single property (field) or nested schema node.

    This model is intended to be recursively nestable via :attr:`items` (for arrays)
    and :attr:`properties` (for objects).

    :ivar type: The schema node type (e.g., ``string``, ``object``, ``array``).
    :ivar description: Optional human-readable description of the property.
    :ivar items: The item schema for an ``array`` type. Typically set when
        :attr:`type` is :data:`~SchemaType.ARRAY`.
    :ivar properties: Nested properties for an ``object`` type. Typically set when
        :attr:`type` is :data:`~SchemaType.OBJECT`. Keys are property names, values
        are their respective schemas.
    :ivar default: Optional default value for the property.
    :ivar required: For an ``object`` schema node, the set of required property
        names within :attr:`properties`. (This mirrors JSON Schema's ``required``
        keyword; it is *not* "this property is required in a parent object".)
    """

    type: Optional[SchemaType] = None
    """The schema node type (e.g., ``string``, ``object``, ``array``). May be ``None``
    if the upstream tool manifest supplies an empty or unrecognised type string."""
    description: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def _coerce_empty_type(cls, data: Any) -> Any:
        """
        Coerce an empty ``type`` string to ``None`` so that properties with
        invalid or missing type information are still deserialized instead of
        raising a validation error.

        :param data: The input data to validate.
        :type data: Any
        :return: The validated data with empty type coerced to None.
        :rtype: Any
        """
        if isinstance(data, dict) and data.get("type") == "":
            # Copy rather than mutate the caller's dict.
            data = {**data, "type": None}
        return data

    items: Optional["SchemaProperty"] = None
    properties: Optional[Mapping[str, "SchemaProperty"]] = None
    default: Any = None
    required: Optional[Set[str]] = None

    def has_default(self) -> bool:
        """
        Check if the property has a default value defined.

        Uses ``model_fields_set`` so that an explicit ``default: null`` in the
        manifest is distinguished from no default at all.

        :return: True if a default value is set, False otherwise.
        :rtype: bool
        """
        return "default" in self.model_fields_set
class SchemaDefinition(BaseModel):
    """
    A top-level JSON Schema-like definition for an object.

    :ivar type: The schema type of the root. Typically :data:`~SchemaType.OBJECT`.
    :ivar properties: Mapping of top-level property names to their schemas.
    :ivar required: Set of required top-level property names within
        :attr:`properties`.
    """

    type: SchemaType = SchemaType.OBJECT
    # Use pydantic's Field here: dataclasses.field() is not interpreted by
    # BaseModel, so the previous `field(default_factory=dict)` default would
    # leave `SchemaDefinition().properties` as a dataclasses.Field sentinel
    # instead of an empty mapping, breaking `extract_from` on default instances.
    properties: Mapping[str, SchemaProperty] = Field(default_factory=dict)
    required: Optional[Set[str]] = None

    def extract_from(self,
                     datasource: Mapping[str, Any],
                     property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]:
        """Extract the values described by this schema from ``datasource``.

        :param datasource: Source mapping to pull property values from.
        :type datasource: Mapping[str, Any]
        :param property_alias: Optional alternative key names per property,
            tried in order after the property's own name.
        :type property_alias: Optional[Dict[str, List[str]]]
        :return: A dict containing only the schema-described values found.
        :rtype: Dict[str, Any]
        :raises KeyError: If a required property is absent and has no default.
        """
        return self._extract(datasource, self.properties, self.required, property_alias)

    @classmethod
    def _extract(cls,
                 datasource: Mapping[str, Any],
                 properties: Mapping[str, SchemaProperty],
                 required: Optional[Set[str]] = None,
                 property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]:
        """Recursive worker for :meth:`extract_from`.

        Walks ``properties``, resolving each value from ``datasource`` (trying
        aliases, then schema defaults), and recursing into object/array schemas.

        :param datasource: Source mapping to pull property values from.
        :type datasource: Mapping[str, Any]
        :param properties: Property schemas to extract.
        :type properties: Mapping[str, SchemaProperty]
        :param required: Required property names at this nesting level.
        :type required: Optional[Set[str]]
        :param property_alias: Alternative key names per property.
        :type property_alias: Optional[Dict[str, List[str]]]
        :return: Extracted values keyed by canonical property name.
        :rtype: Dict[str, Any]
        :raises KeyError: If a required property is absent and has no default.
        """
        result: Dict[str, Any] = {}

        for property_name, schema in properties.items():
            # Determine the keys to look for in the datasource
            keys_to_check = [property_name]
            if property_alias and property_name in property_alias:
                keys_to_check.extend(property_alias[property_name])

            # Find the first matching key in the datasource
            value_found = False
            for key in keys_to_check:
                if key in datasource:
                    value = datasource[key]
                    value_found = True
                    break

            if not value_found and schema.has_default():
                value = schema.default
                value_found = True

            if not value_found:
                # If the property is required but not found, raise an error
                if required and property_name in required:
                    raise KeyError(f"Required property '{property_name}' not found in datasource.")
                # If not found and not required, skip to next property
                continue

            # Process the value based on its schema type
            if schema.type == SchemaType.OBJECT and schema.properties:
                # Only recurse when the value is actually a mapping; otherwise drop it.
                if isinstance(value, Mapping):
                    nested_value = cls._extract(
                        value,
                        schema.properties,
                        schema.required,
                        property_alias
                    )
                    result[property_name] = nested_value
            elif schema.type == SchemaType.ARRAY and schema.items:
                if isinstance(value, Iterable):
                    nested_list = []
                    for item in value:
                        if schema.items.type == SchemaType.OBJECT and schema.items.properties:
                            nested_item = SchemaDefinition._extract(
                                item,
                                schema.items.properties,
                                schema.items.required,
                                property_alias
                            )
                            nested_list.append(nested_item)
                        else:
                            nested_list.append(item)
                    result[property_name] = nested_list
            else:
                # Scalar (or untyped) values are copied through unchanged.
                result[property_name] = value

        return result


class RawFoundryHostedMcpTool(BaseModel):
    """Pydantic model for a single MCP tool.

    :ivar str name: Unique name identifier of the tool.
    :ivar Optional[str] title: Display title of the tool, defaults to name if not provided.
    :ivar str description: Human-readable description of the tool.
    :ivar SchemaDefinition input_schema: JSON schema for tool input parameters.
    :ivar Optional[SchemaDefinition] meta: Optional metadata for the tool.
    """

    name: str
    title: Optional[str] = None
    description: str = ""
    input_schema: SchemaDefinition = Field(
        default_factory=SchemaDefinition,
        validation_alias="inputSchema"
    )
    meta: Optional[SchemaDefinition] = Field(default=None, validation_alias="_meta")

    def model_post_init(self, __context: Any) -> None:
        """Default ``title`` to ``name`` when the manifest omits it.

        :param __context: Pydantic post-init context (unused).
        :type __context: Any
        """
        if self.title is None:
            self.title = self.name
class RawFoundryHostedMcpTools(BaseModel):
    """Pydantic model for the result containing list of tools.

    :ivar List[RawFoundryHostedMcpTool] tools: List of MCP tool definitions.
    """

    tools: List[RawFoundryHostedMcpTool] = Field(default_factory=list)


class ListFoundryHostedMcpToolsResponse(BaseModel):
    """Pydantic model for the complete MCP tools/list JSON-RPC response.

    :ivar str jsonrpc: JSON-RPC version, defaults to "2.0".
    :ivar int id: Request identifier, defaults to 0.
    :ivar RawFoundryHostedMcpTools result: Result containing the list of tools.
    """

    jsonrpc: str = "2.0"
    id: int = 0
    result: RawFoundryHostedMcpTools = Field(
        default_factory=RawFoundryHostedMcpTools
    )


class BaseConnectedToolsErrorResult(BaseModel, ABC):
    """Base model for connected tools error responses."""

    @abstractmethod
    def as_exception(self) -> Exception:
        """Convert the error result to an appropriate exception.

        :return: An exception representing the error.
        :rtype: Exception
        """
        raise NotImplementedError


class OAuthConsentRequiredErrorResult(BaseConnectedToolsErrorResult):
    """Model for OAuth consent required error responses.

    All fields are required for successful validation of this error shape.

    :ivar Literal["OAuthConsentRequired"] type: Error type identifier.
    :ivar str consent_url: URL for user consent.
    :ivar str message: Human-readable error message.
    :ivar str project_connection_id: Project connection ID related to the error.
    """

    type: Literal["OAuthConsentRequired"]
    consent_url: str = Field(
        validation_alias=AliasChoices(
            AliasPath("toolResult", "consentUrl"),
            # Falls back to the "message" field when "consentUrl" is absent.
            AliasPath("toolResult", "message"),
        ),
    )
    message: str = Field(
        validation_alias=AliasPath("toolResult", "message"),
    )
    project_connection_id: str = Field(
        validation_alias=AliasPath("toolResult", "projectConnectionId"),
    )

    def as_exception(self) -> Exception:
        """Convert this error payload into an :class:`OAuthConsentRequiredError`.

        :return: Exception carrying the message, consent URL, and connection id.
        :rtype: Exception
        """
        return OAuthConsentRequiredError(self.message, self.consent_url, self.project_connection_id)


class RawFoundryConnectedTool(BaseModel):
    """Pydantic model for a single connected tool.

    :ivar str name: Name of the tool.
    :ivar str description: Description of the tool.
    :ivar SchemaDefinition input_schema: Input schema for the tool parameters
        (deserialized from the service's ``parameters`` field).
    """

    name: str
    description: str
    input_schema: SchemaDefinition = Field(
        default_factory=SchemaDefinition,
        validation_alias="parameters",
    )


class RawFoundryConnectedRemoteServer(BaseModel):
    """Pydantic model for a connected remote server.

    :ivar str protocol: Protocol used by the remote server.
    :ivar str project_connection_id: Project connection ID of the remote server.
    :ivar List[RawFoundryConnectedTool] tools: List of connected tools from this server
        (deserialized from the service's ``manifest`` field).
    """

    protocol: str = Field(
        validation_alias=AliasPath("remoteServer", "protocol"),
    )
    project_connection_id: str = Field(
        validation_alias=AliasPath("remoteServer", "projectConnectionId"),
    )
    tools: List[RawFoundryConnectedTool] = Field(
        default_factory=list,
        validation_alias="manifest",
    )


class ListConnectedToolsResult(BaseModel):
    """Pydantic model for the result of listing connected tools.

    :ivar List[RawFoundryConnectedRemoteServer] servers: List of connected remote
        servers (deserialized from the service's ``tools`` field).
    """

    servers: List[RawFoundryConnectedRemoteServer] = Field(
        default_factory=list,
        validation_alias="tools",
    )


class ListFoundryConnectedToolsResponse(BaseModel):
    """Pydantic model for the response of listing the connected tools.

    Exactly one of ``result`` / ``error`` is populated, chosen by the
    discriminated union in ``_TYPE_ADAPTER``.

    :ivar Optional[ListConnectedToolsResult] result: Result containing connected tool servers.
    :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any.
    """

    result: Optional[ListConnectedToolsResult] = None
    error: Optional[BaseConnectedToolsErrorResult] = None

    # Discriminated union: payloads carrying a "type" key are errors,
    # everything else is parsed as a successful listing result.
    # noinspection DuplicatedCode
    _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter(
        Annotated[
            Union[
                Annotated[
                    Annotated[
                        Union[OAuthConsentRequiredErrorResult],
                        Field(discriminator="type")
                    ],
                    Tag("ErrorType")
                ],
                Annotated[ListConnectedToolsResult, Tag("ResultType")],
            ],
            Discriminator(
                lambda payload: "ErrorType" if isinstance(payload, dict) and "type" in payload else "ResultType"
            ),
        ])

    @model_validator(mode="wrap")
    @classmethod
    def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "ListFoundryConnectedToolsResponse":
        """Route the raw payload into either ``result`` or ``error``.

        :param data: Raw payload to validate.
        :type data: Any
        :param handler: Pydantic's inner validation handler.
        :type handler: ModelWrapValidatorHandler
        :return: The validated response model.
        :rtype: ListFoundryConnectedToolsResponse
        """
        parsed = cls._TYPE_ADAPTER.validate_python(data)
        normalized = {}
        if isinstance(parsed, ListConnectedToolsResult):
            normalized["result"] = parsed
        elif isinstance(parsed, BaseConnectedToolsErrorResult):
            normalized["error"] = parsed  # type: ignore[assignment]
        return handler(normalized)


class InvokeConnectedToolsResult(BaseModel):
    """Pydantic model for the result of invoking a connected tool.

    :ivar Any value: The result value from the tool invocation
        (deserialized from the service's ``toolResult`` field).
    """

    value: Any = Field(validation_alias="toolResult")


class InvokeFoundryConnectedToolsResponse(BaseModel):
    """Pydantic model for the response of invoking a connected tool.

    Exactly one of ``result`` / ``error`` is populated, chosen by the
    discriminated union in ``_TYPE_ADAPTER``.

    :ivar Optional[InvokeConnectedToolsResult] result: Result of the tool invocation.
    :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any.
    """

    result: Optional[InvokeConnectedToolsResult] = None
    error: Optional[BaseConnectedToolsErrorResult] = None

    # noinspection DuplicatedCode
    _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter(
        Annotated[
            Union[
                Annotated[
                    Annotated[
                        Union[OAuthConsentRequiredErrorResult],
                        Field(discriminator="type")
                    ],
                    Tag("ErrorType")
                ],
                Annotated[InvokeConnectedToolsResult, Tag("ResultType")],
            ],
            Discriminator(
                lambda payload: "ErrorType" if isinstance(payload, dict) and
                # handle other error types in the future
                payload.get("type") == "OAuthConsentRequired"
                else "ResultType"
            ),
        ])

    @model_validator(mode="wrap")
    @classmethod
    def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "InvokeFoundryConnectedToolsResponse":
        """Route the raw payload into either ``result`` or ``error``.

        :param data: Raw payload to validate.
        :type data: Any
        :param handler: Pydantic's inner validation handler.
        :type handler: ModelWrapValidatorHandler
        :return: The validated response model.
        :rtype: InvokeFoundryConnectedToolsResponse
        """
        parsed: Union[InvokeConnectedToolsResult, BaseConnectedToolsErrorResult] = (cls._TYPE_ADAPTER
                                                                                    .validate_python(data))
        normalized: Dict[str, Any] = {}
        if isinstance(parsed, InvokeConnectedToolsResult):
            normalized["result"] = parsed
        elif isinstance(parsed, BaseConnectedToolsErrorResult):
            normalized["error"] = parsed
        return handler(normalized)
+ """ + result: Optional[InvokeConnectedToolsResult] = None + error: Optional[BaseConnectedToolsErrorResult] = None + + # noinspection DuplicatedCode + _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter( + Annotated[ + Union[ + Annotated[ + Annotated[ + Union[OAuthConsentRequiredErrorResult], + Field(discriminator="type") + ], + Tag("ErrorType") + ], + Annotated[InvokeConnectedToolsResult, Tag("ResultType")], + ], + Discriminator( + lambda payload: "ErrorType" if isinstance(payload, dict) and + # handle other error types in the future + payload.get("type") == "OAuthConsentRequired" + else "ResultType" + ), + ]) + + @model_validator(mode="wrap") + @classmethod + def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "InvokeFoundryConnectedToolsResponse": + parsed: Union[InvokeConnectedToolsResult, BaseConnectedToolsErrorResult] = (cls._TYPE_ADAPTER + .validate_python(data)) + normalized: Dict[str, Any] = {} + if isinstance(parsed, InvokeConnectedToolsResult): + normalized["result"] = parsed + elif isinstance(parsed, BaseConnectedToolsErrorResult): + normalized["error"] = parsed + return handler(normalized) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py new file mode 100644 index 000000000000..a3c552fe2575 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py @@ -0,0 +1,73 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from __future__ import annotations + +from abc import ABC +import json +from typing import Any, ClassVar, MutableMapping, Type + +from azure.core import AsyncPipelineClient +from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceExistsError, \ + ResourceNotFoundError, ResourceNotModifiedError, map_error +from azure.core.pipeline.transport import AsyncHttpResponse, HttpRequest + +ErrorMapping = MutableMapping[int, Type[HttpResponseError]] + + +class BaseOperations(ABC): + DEFAULT_ERROR_MAP: ClassVar[ErrorMapping] = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + + def __init__(self, client: AsyncPipelineClient, error_map: ErrorMapping | None = None) -> None: + self._client = client + self._error_map = self._prepare_error_map(error_map) + + @classmethod + def _prepare_error_map(cls, custom_error_map: ErrorMapping | None = None) -> MutableMapping: + """Prepare error map by merging default and custom error mappings. + + :param custom_error_map: Custom error mappings to merge + :return: Merged error map + """ + error_map = cls.DEFAULT_ERROR_MAP + if custom_error_map: + error_map = dict(cls.DEFAULT_ERROR_MAP) + error_map.update(custom_error_map) + return error_map + + async def _send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> AsyncHttpResponse: + """Send an HTTP request. + + :param request: HTTP request + :param stream: Stream to be used for HTTP requests + :param kwargs: Keyword arguments + + :return: Response object + """ + response: AsyncHttpResponse = await self._client.send_request(request, stream=stream, **kwargs) + self._handle_response_error(response) + return response + + def _handle_response_error(self, response: AsyncHttpResponse) -> None: + """Handle HTTP response errors. 
+ + :param response: HTTP response to check + :raises HttpResponseError: If response status is not 200 + """ + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=self._error_map) + raise HttpResponseError(response=response) + + def _extract_response_json(self, response: AsyncHttpResponse) -> Any: + try: + payload_text = response.text() + payload_json = json.loads(payload_text) if payload_text else {} + except AttributeError: + payload_bytes = response.body() + payload_json = json.loads(payload_bytes.decode("utf-8")) if payload_bytes else {} + return payload_json \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_connected_tools.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_connected_tools.py new file mode 100644 index 000000000000..83138a17ad9a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_connected_tools.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------
from abc import ABC
from typing import Any, AsyncIterable, ClassVar, Dict, Iterable, List, Mapping, Optional, Tuple, cast

from azure.core.pipeline.transport import HttpRequest
from azure.core.tracing.decorator_async import distributed_trace_async

from ._base import BaseOperations
from .._models import FoundryConnectedTool, FoundryToolDetails, FoundryToolSource, InvokeFoundryConnectedToolsResponse, \
    ListFoundryConnectedToolsResponse, ResolvedFoundryTool, UserInfo
from ..._exceptions import ToolInvocationError


class BaseFoundryConnectedToolsOperations(BaseOperations, ABC):
    """Base operations for Foundry connected tools: request building and
    response conversion shared by the concrete operations class."""

    _API_VERSION: ClassVar[str] = "2025-11-15-preview"

    _HEADERS: ClassVar[Dict[str, str]] = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    _QUERY_PARAMS: ClassVar[Dict[str, Any]] = {
        "api-version": _API_VERSION
    }

    @staticmethod
    def _list_tools_path(agent_name: str) -> str:
        # Endpoint that resolves (lists) connected tools for the agent.
        return f"/agents/{agent_name}/tools/resolve"

    @staticmethod
    def _invoke_tool_path(agent_name: str) -> str:
        # Endpoint that invokes a connected tool on behalf of the agent.
        return f"/agents/{agent_name}/tools/invoke"

    def _build_list_tools_request(
            self,
            tools: List[FoundryConnectedTool],
            user: Optional[UserInfo],
            agent_name: str,) -> HttpRequest:
        """Build the POST request that resolves the given connected tools.

        :param tools: Connected tool definitions to resolve.
        :type tools: List[FoundryConnectedTool]
        :param user: Optional user context added to the payload.
        :type user: Optional[UserInfo]
        :param agent_name: Name of the agent.
        :type agent_name: str
        :return: The prepared HTTP request.
        :rtype: HttpRequest
        """
        payload: Dict[str, Any] = {
            "remoteServers": [
                {
                    "projectConnectionId": tool.project_connection_id,
                    "protocol": tool.protocol,
                } for tool in tools
            ],
        }
        if user:
            payload["user"] = {
                "objectId": user.object_id,
                "tenantId": user.tenant_id,
            }
        return self._client.post(
            self._list_tools_path(agent_name),
            params=self._QUERY_PARAMS,
            headers=self._HEADERS,
            content=payload)

    @classmethod
    def _convert_listed_tools(
            cls,
            resp: ListFoundryConnectedToolsResponse,
            input_tools: List[FoundryConnectedTool]) -> Iterable[Tuple[FoundryConnectedTool, FoundryToolDetails]]:
        """Pair each listed tool with the input definition it came from.

        Servers that do not correspond to any input tool are skipped.

        :param resp: Parsed list-tools response.
        :type resp: ListFoundryConnectedToolsResponse
        :param input_tools: The tool definitions that were sent in the request.
        :type input_tools: List[FoundryConnectedTool]
        :return: Pairs of (input definition, resolved details).
        :rtype: Iterable[Tuple[FoundryConnectedTool, FoundryToolDetails]]
        :raises Exception: The error converted via ``as_exception`` when the
            response carries an error payload.
        """
        if resp.error:
            raise resp.error.as_exception()
        if not resp.result:
            return

        # Match servers back to the requested definitions by (connection, protocol).
        tool_map = {(tool.project_connection_id, tool.protocol): tool for tool in input_tools}
        for server in resp.result.servers:
            input_tool = tool_map.get((server.project_connection_id, server.protocol))
            if not input_tool:
                continue

            for tool in server.tools:
                details = FoundryToolDetails(
                    name=tool.name,
                    description=tool.description,
                    input_schema=tool.input_schema,
                )
                yield input_tool, details

    def _build_invoke_tool_request(
            self,
            tool: ResolvedFoundryTool,
            arguments: Dict[str, Any],
            user: Optional[UserInfo],
            agent_name: str) -> HttpRequest:
        """Build the POST request that invokes a connected tool.

        :param tool: The resolved tool to invoke; must be a connected tool.
        :type tool: ResolvedFoundryTool
        :param arguments: Input arguments for the tool.
        :type arguments: Dict[str, Any]
        :param user: Optional user context added to the payload.
        :type user: Optional[UserInfo]
        :param agent_name: Name of the agent.
        :type agent_name: str
        :return: The prepared HTTP request.
        :rtype: HttpRequest
        :raises ToolInvocationError: If the tool is not a connected tool.
        """
        if tool.definition.source != FoundryToolSource.CONNECTED:
            raise ToolInvocationError(f"Tool {tool.name} is not a Foundry connected tool.", tool=tool)

        tool_def = cast(FoundryConnectedTool, tool.definition)
        payload: Dict[str, Any] = {
            "toolName": tool.name,
            "arguments": arguments,
            "remoteServer": {
                "projectConnectionId": tool_def.project_connection_id,
                "protocol": tool_def.protocol,
            },
        }
        if user:
            payload["user"] = {
                "objectId": user.object_id,
                "tenantId": user.tenant_id,
            }
        return self._client.post(
            self._invoke_tool_path(agent_name),
            params=self._QUERY_PARAMS,
            headers=self._HEADERS,
            content=payload)

    @classmethod
    def _convert_invoke_result(cls, resp: InvokeFoundryConnectedToolsResponse) -> Any:
        """Unwrap an invoke response, raising on error payloads.

        :param resp: Parsed invoke response.
        :type resp: InvokeFoundryConnectedToolsResponse
        :return: The tool's result value, or ``None`` if no result was returned.
        :rtype: Any
        """
        if resp.error:
            raise resp.error.as_exception()
        if not resp.result:
            return None
        return resp.result.value


class FoundryConnectedToolsOperations(BaseFoundryConnectedToolsOperations):
    """Operations for managing Foundry connected tools."""

    @distributed_trace_async
    async def list_tools(self,
                         tools: List[FoundryConnectedTool],
                         user: Optional[UserInfo],
                         agent_name: str) -> Iterable[Tuple[FoundryConnectedTool, FoundryToolDetails]]:
        """List connected tools.

        :param tools: List of connected tool definitions.
        :type tools: List[FoundryConnectedTool]
        :param user: User information for the request. Value can be None if running in local.
        :type user: Optional[UserInfo]
        :param agent_name: Name of the agent.
        :type agent_name: str
        :return: An iterable of tuples containing the tool definition and its details.
        :rtype: Iterable[Tuple[FoundryConnectedTool, FoundryToolDetails]]
        """
        if not tools:
            return []

        request = self._build_list_tools_request(tools, user, agent_name)
        response = await self._send_request(request)
        async with response:
            json_response = self._extract_response_json(response)
            tools_response = ListFoundryConnectedToolsResponse.model_validate(json_response)
            return self._convert_listed_tools(tools_response, tools)


    @distributed_trace_async
    async def invoke_tool(
            self,
            tool: ResolvedFoundryTool,
            arguments: Dict[str, Any],
            user: Optional[UserInfo],
            agent_name: str) -> Any:
        """Invoke a connected tool.

        :param tool: Tool descriptor to invoke.
        :type tool: ResolvedFoundryTool
        :param arguments: Input arguments for the tool.
        :type arguments: Dict[str, Any]
        :param user: User information for the request. Value can be None if running in local.
        :type user: Optional[UserInfo]
        :param agent_name: Name of the agent.
        :type agent_name: str
        :return: Result of the tool invocation.
        :rtype: Any
        """
        request = self._build_invoke_tool_request(tool, arguments, user, agent_name)
        response = await self._send_request(request)
        async with response:
            json_response = self._extract_response_json(response)
            invoke_response = InvokeFoundryConnectedToolsResponse.model_validate(json_response)
            return self._convert_invoke_result(invoke_response)
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from abc import ABC
from typing import Any, AsyncIterable, ClassVar, Dict, Iterable, List, Tuple, cast

from azure.core.pipeline.transport import HttpRequest
from azure.core.tracing.decorator_async import distributed_trace_async

from ._base import BaseOperations
from .._models import FoundryHostedMcpTool, FoundryToolDetails, FoundryToolSource, ListFoundryHostedMcpToolsResponse, \
    ResolvedFoundryTool
from ..._exceptions import ToolInvocationError


class BaseFoundryHostedMcpToolsOperations(BaseOperations, ABC):
    """Base operations for Foundry-hosted MCP tools: JSON-RPC request building
    and response conversion shared by the concrete operations class."""

    _PATH: ClassVar[str] = "/mcp_tools"

    _API_VERSION: ClassVar[str] = "2025-11-15-preview"

    _HEADERS: ClassVar[Dict[str, str]] = {
        "Content-Type": "application/json",
        "Accept": "application/json,text/event-stream",
        "Connection": "keep-alive",
        "Cache-Control": "no-cache",
    }

    _QUERY_PARAMS: ClassVar[Dict[str, Any]] = {
        "api-version": _API_VERSION
    }

    # Static JSON-RPC body for tools/list.
    _LIST_TOOLS_REQUEST_BODY: ClassVar[Dict[str, Any]] = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/list",
        "params": {}
    }

    # Template JSON-RPC body for tools/call; "params" is filled per invocation.
    _INVOKE_TOOL_REQUEST_BODY_TEMPLATE: ClassVar[Dict[str, Any]] = {
        "jsonrpc": "2.0",
        "id": 2,
        "method": "tools/call",
    }

    # Tool-specific property key overrides
    # Format: {"tool_name": {"tool_def_key": "meta_schema_key"}}
    _TOOL_PROPERTY_ALIAS: ClassVar[Dict[str, Dict[str, List[str]]]] = {
        "_default": {
            "imagegen_model_deployment_name": ["model_deployment_name"],
            "model_deployment_name": ["model"],
            "deployment_name": ["model"],
        },
        "image_generation": {
            "imagegen_model_deployment_name": ["model"]
        },
        # Add more tool-specific mappings as needed
    }

    def _build_list_tools_request(self) -> HttpRequest:
        """Build request for listing MCP tools.

        :return: Request for listing MCP tools.
        :rtype: HttpRequest
        """
        return self._client.post(self._PATH,
                                 params=self._QUERY_PARAMS,
                                 headers=self._HEADERS,
                                 content=self._LIST_TOOLS_REQUEST_BODY)

    @staticmethod
    def _convert_listed_tools(
            response: ListFoundryHostedMcpToolsResponse,
            allowed_tools: List[FoundryHostedMcpTool]) -> Iterable[Tuple[FoundryHostedMcpTool, FoundryToolDetails]]:
        """Pair listed MCP tools with their allowed definitions, filtering by name.

        Tools not present in ``allowed_tools`` are silently dropped.

        :param response: Parsed tools/list JSON-RPC response.
        :type response: ListFoundryHostedMcpToolsResponse
        :param allowed_tools: Definitions acting as an allowlist (keyed by name).
        :type allowed_tools: List[FoundryHostedMcpTool]
        :return: Pairs of (definition, resolved details).
        :rtype: Iterable[Tuple[FoundryHostedMcpTool, FoundryToolDetails]]
        """
        allowlist = {tool.name: tool for tool in allowed_tools}
        for tool in response.result.tools:
            definition = allowlist.get(tool.name)
            if not definition:
                continue
            details = FoundryToolDetails(
                name=tool.name,
                description=tool.description,
                metadata=tool.meta,
                input_schema=tool.input_schema)
            yield definition, details

    def _build_invoke_tool_request(self, tool: ResolvedFoundryTool, arguments: Dict[str, Any]) -> HttpRequest:
        """Build the JSON-RPC tools/call request for an MCP tool.

        :param tool: The resolved tool to invoke; must be a hosted MCP tool.
        :type tool: ResolvedFoundryTool
        :param arguments: Input arguments for the tool.
        :type arguments: Dict[str, Any]
        :return: The prepared HTTP request.
        :rtype: HttpRequest
        :raises ToolInvocationError: If the tool is not a Foundry-hosted MCP tool.
        """
        if tool.definition.source != FoundryToolSource.HOSTED_MCP:
            raise ToolInvocationError(f"Tool {tool.name} is not a Foundry-hosted MCP tool.", tool=tool)
        definition = cast(FoundryHostedMcpTool, tool.definition)

        payload = dict(self._INVOKE_TOOL_REQUEST_BODY_TEMPLATE)
        payload["params"] = {
            "name": tool.name,
            "arguments": arguments
        }
        # Project tool configuration into the "_meta" field using the tool's
        # metadata schema, honoring per-tool property aliases.
        if tool.metadata and definition.configuration:
            payload["_meta"] = tool.metadata.extract_from(definition.configuration,
                                                          self._resolve_property_alias(tool.name))

        return self._client.post(self._PATH,
                                 params=self._QUERY_PARAMS,
                                 headers=self._HEADERS,
                                 content=payload)

    @classmethod
    def _resolve_property_alias(cls, tool_name: str) -> Dict[str, List[str]]:
        """Get property key overrides for a specific tool.

        Tool-specific aliases take precedence over the "_default" aliases.

        :param tool_name: Name of the tool.
        :type tool_name: str
        :return: Property key overrides.
        :rtype: Dict[str, List[str]]
        """
        overrides = dict(cls._TOOL_PROPERTY_ALIAS.get("_default", {}))
        tool_specific = cls._TOOL_PROPERTY_ALIAS.get(tool_name, {})
        overrides.update(tool_specific)
        return overrides


class FoundryMcpToolsOperations(BaseFoundryHostedMcpToolsOperations):
    """Operations for Foundry-hosted MCP tools."""

    @distributed_trace_async
    async def list_tools(
            self,
            allowed_tools: List[FoundryHostedMcpTool]
    ) -> Iterable[Tuple[FoundryHostedMcpTool, FoundryToolDetails]]:
        """List MCP tools.

        :param allowed_tools: List of allowed MCP tools to filter.
        :type allowed_tools: List[FoundryHostedMcpTool]
        :return: An iterable of tuples containing tool definitions and their details.
        :rtype: Iterable[Tuple[FoundryHostedMcpTool, FoundryToolDetails]]
        """
        if not allowed_tools:
            return []

        request = self._build_list_tools_request()
        response = await self._send_request(request)
        async with response:
            json_response = self._extract_response_json(response)
            tools_response = ListFoundryHostedMcpToolsResponse.model_validate(json_response)

        return self._convert_listed_tools(tools_response, allowed_tools)

    @distributed_trace_async
    async def invoke_tool(
            self,
            tool: ResolvedFoundryTool,
            arguments: Dict[str, Any],
    ) -> Any:
        """Invoke an MCP tool.

        :param tool: Tool descriptor for the tool to invoke.
        :type tool: ResolvedFoundryTool
        :param arguments: Input arguments for the tool.
        :type arguments: Dict[str, Any]
        :return: Result of the tool invocation (raw JSON-RPC response payload).
        :rtype: Any
        """
        request = self._build_invoke_tool_request(tool, arguments)
        response = await self._send_request(request)
        async with response:
            json_response = self._extract_response_json(response)
            # NOTE(review): unlike connected tools, the raw JSON-RPC payload is
            # returned without error unwrapping -- confirm callers handle this.
            invoke_response = json_response
            return invoke_response
# --- file: tools/runtime/__init__.py ---
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

__path__ = __import__('pkgutil').extend_path(__path__, __name__)

# --- file: tools/runtime/_catalog.py ---
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import asyncio  # pylint: disable=C4763 # azure-sdk: async-client-bad-name
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Collection, List, Mapping, MutableMapping, Optional, Union

from cachetools import TTLCache  # type: ignore[import-untyped]

from ._facade import FoundryToolLike, ensure_foundry_tool
from ._user import UserProvider
from ..client._client import FoundryToolClient
from ..client._models import FoundryTool, FoundryToolDetails, FoundryToolSource, ResolvedFoundryTool, UserInfo


class FoundryToolCatalog(ABC):
    """Base class for Foundry tool catalogs."""

    def __init__(self, user_provider: UserProvider):
        # Provider used to fetch the current user's identity for scoping lookups.
        self._user_provider = user_provider

    async def get(self, tool: FoundryToolLike) -> Optional[ResolvedFoundryTool]:
        """Gets a Foundry tool by its definition.

        Convenience wrapper over :meth:`list` for a single tool.

        :param tool: The Foundry tool to resolve.
        :type tool: FoundryToolLike
        :return: The resolved Foundry tool, or ``None`` if it could not be resolved.
        :rtype: Optional[ResolvedFoundryTool]
        """
        tools = await self.list([tool])
        return tools[0] if tools else None

    @abstractmethod
    async def list(self, tools: List[FoundryToolLike]) -> List[ResolvedFoundryTool]:
        """Lists all available Foundry tools.

        :param tools: The list of Foundry tools to resolve.
        :type tools: List[FoundryToolLike]
        :return: A list of resolved Foundry tools.
        :rtype: List[ResolvedFoundryTool]
        """
        raise NotImplementedError


# Cache slots hold either a completed details list or the in-flight task
# producing it, so concurrent callers can await the same fetch.
_CachedValueType = Union[Awaitable[List[FoundryToolDetails]], List[FoundryToolDetails]]


class CachedFoundryToolCatalog(FoundryToolCatalog, ABC):
    """Cached implementation of FoundryToolCatalog with concurrency-safe caching."""

    def __init__(self, user_provider: UserProvider):
        super().__init__(user_provider)
        self._cache: MutableMapping[Any, _CachedValueType] = self._create_cache()

    def _create_cache(self) -> MutableMapping[Any, _CachedValueType]:
        # Entries expire after 10 minutes; at most 1024 cached tool lookups.
        return TTLCache(maxsize=1024, ttl=600)

    def _get_key(self, user: Optional[UserInfo], tool: FoundryTool) -> Any:
        # Hosted MCP tools are not user-scoped; connected tools are cached per user.
        if tool.source is FoundryToolSource.HOSTED_MCP:
            return tool.id
        return user, tool.id

    async def list(self, tools: List[FoundryToolLike]) -> List[ResolvedFoundryTool]:
        user = await self._user_provider.get_user()
        foundry_tools = {}
        tools_to_fetch = {}
        fetching_tasks = []
        for t in tools:
            tool = ensure_foundry_tool(t)
            key = self._get_key(user, tool)
            foundry_tools[key] = tool
            if key not in self._cache:
                tools_to_fetch[key] = tool
            elif (task := self._cache[key]) and isinstance(task, Awaitable):
                # A fetch for this key is already in flight; await it with the rest.
                fetching_tasks.append(task)

        # for tools that are not being listed, create a batch task, convert to per-tool resolving tasks, and cache them
        if tools_to_fetch:
            # Awaitable[Mapping[str, List[FoundryToolDetails]]]
            fetched_tools = asyncio.create_task(self._fetch_tools(tools_to_fetch.values(), user))

            for k, t in tools_to_fetch.items():
                # safe to write cache since it's the only runner in this event loop
                task = asyncio.create_task(self._per_tool_fetching_task(k, t, fetched_tools))
                self._cache[k] = task
                fetching_tasks.append(task)

        try:
            # now we have every tool associated with a task
            if fetching_tasks:
                await asyncio.gather(*fetching_tasks)
        # NOTE(review): bare except also swallows CancelledError/KeyboardInterrupt;
        # consider `except BaseException: ... raise` semantics -- confirm intent.
        except:
            # exception can only be caused by fetching tasks, remove them from cache
            for k, _ in tools_to_fetch.items():
                if k in self._cache:
+ del self._cache[k] + raise + + resolved_tools = [] + for key, tool in foundry_tools.items(): + # this acts like a lock - every task of the same tool waits for the same underlying fetch + task_or_value = self._cache[key] + details_list = (await task_or_value) if isinstance(task_or_value, Awaitable) else task_or_value + for details in details_list: + resolved_tools.append( + ResolvedFoundryTool( + definition=tool, + details=details + ) + ) + + return resolved_tools + + async def _per_tool_fetching_task( + self, + cache_key: Any, + tool: FoundryTool, + fetching: Awaitable[Mapping[str, List[FoundryToolDetails]]] + ) -> List[FoundryToolDetails]: + details = await fetching + details_list = details.get(tool.id, []) + # replace the task in cache with the actual value to optimize memory usage + self._cache[cache_key] = details_list + return details_list + + @abstractmethod + async def _fetch_tools(self, + tools: Collection[FoundryTool], + user: Optional[UserInfo]) -> Mapping[str, List[FoundryToolDetails]]: + raise NotImplementedError + + +class DefaultFoundryToolCatalog(CachedFoundryToolCatalog): + """Default implementation of FoundryToolCatalog.""" + + def __init__(self, client: FoundryToolClient, user_provider: UserProvider, agent_name: str): + super().__init__(user_provider) + self._client = client + self._agent_name = agent_name + + async def _fetch_tools(self, + tools: Collection[FoundryTool], + user: Optional[UserInfo]) -> Mapping[str, List[FoundryToolDetails]]: + return await self._client.list_tools_details(tools, self._agent_name, user) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py new file mode 100644 index 000000000000..71c4601cb525 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py @@ -0,0 +1,95 @@ +# 
--------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import re +from typing import Any, Dict, Union + +from .. import FoundryConnectedTool, FoundryHostedMcpTool +from .._exceptions import InvalidToolFacadeError +from ..client._models import FoundryTool, FoundryToolProtocol + +# FoundryToolFacade: a β€œtool descriptor” bag. +# +# Reserved keys: +# Required: +# - "type": str Discriminator, e.g. "mcp" | "a2a" | "code_interpreter" | ... +# Optional: +# - "project_connection_id": str Project connection id of Foundry connected tools, +# required when "type" is "mcp" or "a2a". +# +# Custom keys: +# - Allowed, but MUST NOT shadow reserved keys. +FoundryToolFacade = Dict[str, Any] + +FoundryToolLike = Union[FoundryToolFacade, FoundryTool] + + +def ensure_foundry_tool(tool: FoundryToolLike) -> FoundryTool: + """Ensure the input is a FoundryTool instance. + + :param tool: The tool descriptor, either as a FoundryToolFacade or FoundryTool. + :type tool: FoundryToolLike + :return: The corresponding FoundryTool instance. 
+ :rtype: FoundryTool + """ + if isinstance(tool, FoundryTool): + return tool + + tool = tool.copy() + tool_type = tool.pop("type", None) + if not isinstance(tool_type, str) or not tool_type: + raise InvalidToolFacadeError("FoundryToolFacade must have a valid 'type' field of type str.") + + try: + protocol = FoundryToolProtocol(tool_type) + project_connection_id = tool.pop("project_connection_id", None) + if not isinstance(project_connection_id, str) or not project_connection_id: + raise InvalidToolFacadeError(f"project_connection_id is required for tool protocol {protocol}.") + + # Parse the connection identifier to extract the connection name + connection_name = _parse_connection_id(project_connection_id) + return FoundryConnectedTool(protocol=protocol, project_connection_id=connection_name) + except ValueError: + return FoundryHostedMcpTool(name=tool_type, configuration=tool) + + +# Pattern for Azure resource ID format: +# /subscriptions//resourceGroups//providers/Microsoft.CognitiveServices/accounts +# //projects//connections/ +_RESOURCE_ID_PATTERN = re.compile( + r"^/subscriptions/[^/]+/resourceGroups/[^/]+/providers/Microsoft\.CognitiveServices/" + r"accounts/[^/]+/projects/[^/]+/connections/(?P[^/]+)$", + re.IGNORECASE, +) + + +def _parse_connection_id(connection_id: str) -> str: + """Parse the connection identifier and extract the connection name. + + Supports two formats: + 1. Simple name: "my-connection-name" + 2. Resource ID: "/subscriptions//resourceGroups//providers + /Microsoft.CognitiveServices/accounts//projects//connections/" + + :param connection_id: The connection identifier, either a simple name or a full resource ID. + :type connection_id: str + :return: The connection name extracted from the identifier. + :rtype: str + :raises InvalidToolFacadeError: If the connection_id format is invalid. 
+ """ + if not connection_id: + raise InvalidToolFacadeError("Connection identifier cannot be empty.") + + # Check if it's a resource ID format (starts with /) + if connection_id.startswith("/"): + match = _RESOURCE_ID_PATTERN.match(connection_id) + if not match: + raise InvalidToolFacadeError( + f"Invalid resource ID format for connection: '{connection_id}'. " + "Expected format: /subscriptions//resourceGroups//providers/" + "Microsoft.CognitiveServices/accounts//projects//connections/" + ) + return match.group("name") + + # Otherwise, treat it as a simple connection name + return connection_id diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_invoker.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_invoker.py new file mode 100644 index 000000000000..d24c79dd4d12 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_invoker.py @@ -0,0 +1,69 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC, abstractmethod +from typing import Any, Dict + +from ._user import UserProvider +from ..client._client import FoundryToolClient +from ..client._models import ResolvedFoundryTool + + +class FoundryToolInvoker(ABC): + """Abstract base class for Foundry tool invokers.""" + + @property + @abstractmethod + def resolved_tool(self) -> ResolvedFoundryTool: + """Get the resolved tool definition. + + :return: The tool definition. + :rtype: ResolvedFoundryTool + """ + raise NotImplementedError + + @abstractmethod + async def invoke(self, arguments: Dict[str, Any]) -> Any: + """Invoke the tool with the given arguments. + + :param arguments: The arguments to pass to the tool. 
+ :type arguments: Dict[str, Any] + :return: The result of the tool invocation + :rtype: Any + """ + raise NotImplementedError + + +class DefaultFoundryToolInvoker(FoundryToolInvoker): + """Default implementation of FoundryToolInvoker.""" + + def __init__(self, + resolved_tool: ResolvedFoundryTool, + client: FoundryToolClient, + user_provider: UserProvider, + agent_name: str): + self._resolved_tool = resolved_tool + self._client = client + self._user_provider = user_provider + self._agent_name = agent_name + + @property + def resolved_tool(self) -> ResolvedFoundryTool: + """Get the resolved tool definition. + + :return: The tool definition. + :rtype: ResolvedFoundryTool + """ + return self._resolved_tool + + async def invoke(self, arguments: Dict[str, Any]) -> Any: + """Invoke the tool with the given arguments. + + :param arguments: The arguments to pass to the tool + :type arguments: Dict[str, Any] + :return: The result of the tool invocation + :rtype: Any + """ + user = await self._user_provider.get_user() + result = await self._client.invoke_tool(self._resolved_tool, arguments, self._agent_name, user) + return result diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py new file mode 100644 index 000000000000..9596124d9b55 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC, abstractmethod +from typing import Union + +from ._catalog import FoundryToolCatalog +from ._facade import FoundryToolLike, ensure_foundry_tool +from ._invoker import DefaultFoundryToolInvoker, FoundryToolInvoker +from ._user import UserProvider +from .. 
import FoundryToolClient +from .._exceptions import UnableToResolveToolInvocationError +from ..client._models import ResolvedFoundryTool + + +class FoundryToolInvocationResolver(ABC): + """Resolver for Foundry tool invocations.""" + + @abstractmethod + async def resolve(self, tool: Union[FoundryToolLike, ResolvedFoundryTool]) -> FoundryToolInvoker: + """Resolves a Foundry tool invocation. + + :param tool: The Foundry tool to resolve. + :type tool: Union[FoundryToolLike, ResolvedFoundryTool] + :return: The resolved Foundry tool invoker. + :rtype: FoundryToolInvoker + """ + raise NotImplementedError + + +class DefaultFoundryToolInvocationResolver(FoundryToolInvocationResolver): + """Default implementation of FoundryToolInvocationResolver.""" + + def __init__(self, + catalog: FoundryToolCatalog, + client: FoundryToolClient, + user_provider: UserProvider, + agent_name: str): + self._catalog = catalog + self._client = client + self._user_provider = user_provider + self._agent_name = agent_name + + async def resolve(self, tool: Union[FoundryToolLike, ResolvedFoundryTool]) -> FoundryToolInvoker: + """Resolves a Foundry tool invocation. + + :param tool: The Foundry tool to resolve. + :type tool: Union[FoundryToolLike, ResolvedFoundryTool] + :return: The resolved Foundry tool invoker. 
+ :rtype: FoundryToolInvoker + """ + if isinstance(tool, ResolvedFoundryTool): + resolved_tool = tool + else: + foundry_tool = ensure_foundry_tool(tool) + resolved_tool = await self._catalog.get(foundry_tool) # type: ignore[assignment] + if not resolved_tool: + raise UnableToResolveToolInvocationError(f"Unable to resolve tool {foundry_tool} from catalog", + foundry_tool) + return DefaultFoundryToolInvoker(resolved_tool, self._client, self._user_provider, self._agent_name) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py new file mode 100644 index 000000000000..6335d6ee7c58 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py @@ -0,0 +1,147 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from typing import Any, AsyncContextManager, ClassVar, Dict, Optional, Union + +from azure.core.credentials_async import AsyncTokenCredential + +from ._catalog import DefaultFoundryToolCatalog, FoundryToolCatalog +from ._facade import FoundryToolLike +from ._resolver import DefaultFoundryToolInvocationResolver, FoundryToolInvocationResolver +from ._user import ContextVarUserProvider, UserProvider +from ..client._models import ResolvedFoundryTool +from ..client._client import FoundryToolClient +from ...constants import Constants + + +def create_tool_runtime(project_endpoint: str | None, + credential: AsyncTokenCredential | None) -> "FoundryToolRuntime": + """Create a Foundry tool runtime. 
+ Returns a DefaultFoundryToolRuntime if both project_endpoint and credential are provided, + otherwise returns a ThrowingFoundryToolRuntime which raises errors on usage. + + :param project_endpoint: The project endpoint. + :type project_endpoint: str | None + :param credential: The credential. + :type credential: AsyncTokenCredential | None + :return: The Foundry tool runtime. + :rtype: FoundryToolRuntime + """ + if project_endpoint and credential: + return DefaultFoundryToolRuntime(project_endpoint=project_endpoint, credential=credential) + return ThrowingFoundryToolRuntime() + +class FoundryToolRuntime(AsyncContextManager["FoundryToolRuntime"], ABC): + """Base class for Foundry tool runtimes.""" + + @property + @abstractmethod + def catalog(self) -> FoundryToolCatalog: + """The tool catalog. + + :return: The tool catalog. + :rtype: FoundryToolCatalog + """ + raise NotImplementedError + + @property + @abstractmethod + def invocation(self) -> FoundryToolInvocationResolver: + """The tool invocation resolver. + + :return: The tool invocation resolver. + :rtype: FoundryToolInvocationResolver + """ + raise NotImplementedError + + async def invoke(self, tool: Union[FoundryToolLike, ResolvedFoundryTool], arguments: Dict[str, Any]) -> Any: + """Invoke a tool with the given arguments. + + :param tool: The tool to invoke. + :type tool: Union[FoundryToolLike, ResolvedFoundryTool] + :param arguments: The arguments to pass to the tool. + :type arguments: Dict[str, Any] + :return: The result of the tool invocation. + :rtype: Any + """ + invoker = await self.invocation.resolve(tool) + return await invoker.invoke(arguments) + + +class DefaultFoundryToolRuntime(FoundryToolRuntime): + """Default implementation of FoundryToolRuntime.""" + + def __init__(self, + project_endpoint: str, + credential: "AsyncTokenCredential", + user_provider: Optional[UserProvider] = None): + # Do we need introduce DI here? 
+ self._user_provider = user_provider or ContextVarUserProvider() + self._agent_name = os.getenv(Constants.AGENT_NAME, "$default") + self._client = FoundryToolClient(endpoint=project_endpoint, credential=credential) + self._catalog = DefaultFoundryToolCatalog(client=self._client, + user_provider=self._user_provider, + agent_name=self._agent_name) + self._invocation = DefaultFoundryToolInvocationResolver(catalog=self._catalog, + client=self._client, + user_provider=self._user_provider, + agent_name=self._agent_name) + + @property + def catalog(self) -> FoundryToolCatalog: + """The tool catalog. + + :rtype: FoundryToolCatalog + """ + return self._catalog + + @property + def invocation(self) -> FoundryToolInvocationResolver: + """The tool invocation resolver. + + :rtype: FoundryToolInvocationResolver + """ + return self._invocation + + async def __aenter__(self) -> "DefaultFoundryToolRuntime": + await self._client.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self._client.__aexit__(exc_type, exc_value, traceback) + + +class ThrowingFoundryToolRuntime(FoundryToolRuntime): + """A FoundryToolRuntime that raises errors on usage.""" + _ERROR_MESSAGE: ClassVar[str] = ("FoundryToolRuntime is not configured. " + "Please provide a valid project endpoint and credential.") + + @property + def catalog(self) -> FoundryToolCatalog: + """The tool catalog. + + :returns: The tool catalog. + :rtype: FoundryToolCatalog + :raises RuntimeError: Always raised to indicate the runtime is not configured. + """ + raise RuntimeError(self._ERROR_MESSAGE) + + @property + def invocation(self) -> FoundryToolInvocationResolver: + """The tool invocation resolver. + + :returns: The tool invocation resolver. + :rtype: FoundryToolInvocationResolver + :raises RuntimeError: Always raised to indicate the runtime is not configured. 
+ """ + raise RuntimeError(self._ERROR_MESSAGE) + + async def __aenter__(self) -> "ThrowingFoundryToolRuntime": + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + pass diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py new file mode 100644 index 000000000000..9604124cde9b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py @@ -0,0 +1,67 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from contextvars import ContextVar +from typing import Awaitable, Callable, Optional + +from starlette.applications import Starlette +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.types import ASGIApp + +from ._user import ContextVarUserProvider, resolve_user_from_headers +from ..client._models import UserInfo + +_UserContextType = ContextVar[Optional[UserInfo]] +_ResolverType = Callable[[Request], Awaitable[Optional[UserInfo]]] + +class UserInfoContextMiddleware(BaseHTTPMiddleware): + """Middleware to set user information in a context variable for each request.""" + + def __init__(self, app: ASGIApp, user_info_var: _UserContextType, user_resolver: _ResolverType): + super().__init__(app) + self._user_info_var = user_info_var + self._user_resolver = user_resolver + + @classmethod + def install(cls, + app: Starlette, + user_context: Optional[_UserContextType] = None, + user_resolver: Optional[_ResolverType] = None): + """Install the middleware into a Starlette application. + + :param app: The Starlette application to install the middleware into. 
+ :type app: Starlette + :param user_context: Optional context variable to use for storing user info. + If not provided, a default context variable will be used. + :type user_context: Optional[ContextVar[Optional[UserInfo]]] + :param user_resolver: Optional function to resolve user info from the request. + If not provided, a default resolver will be used. + :type user_resolver: Optional[Callable[[Request], Awaitable[Optional[UserInfo]]]] + + """ + user_info_var : _UserContextType = user_context or ContextVarUserProvider.default_user_info_context + app.add_middleware(UserInfoContextMiddleware, # type: ignore[arg-type] + user_info_var=user_info_var, + user_resolver=user_resolver or cls._default_user_resolver) + + @staticmethod + async def _default_user_resolver(request: Request) -> Optional[UserInfo]: + return resolve_user_from_headers(request.headers) + + async def dispatch(self, request: Request, call_next): + """Process the incoming request, setting the user info in the context variable. + + :param request: The incoming Starlette request. + :type request: Request + :param call_next: The next middleware or endpoint to call. + :type call_next: Callable[[Request], Awaitable[Response]] + :return: The response from the next middleware or endpoint. + :rtype: Response + """ + user = await self._user_resolver(request) + token = self._user_info_var.set(user) + try: + return await call_next(request) + finally: + self._user_info_var.reset(token) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py new file mode 100644 index 000000000000..f72b30c0d3d3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from contextvars import ContextVar +from abc import ABC, abstractmethod +from typing import ClassVar, Mapping, Optional + +from ..client._models import UserInfo + + +class UserProvider(ABC): + """Base class for user providers.""" + + @abstractmethod + async def get_user(self) -> Optional[UserInfo]: + """Get the user information. + + :return: The user information or None if not found. + :rtype: Optional[UserInfo] + """ + raise NotImplementedError + + +class ContextVarUserProvider(UserProvider): + """User provider that retrieves user information from a ContextVar.""" + default_user_info_context: ClassVar[ContextVar[UserInfo]] = ContextVar("user_info_context") + + def __init__(self, context: Optional[ContextVar[UserInfo]] = None): + self.context = context or self.default_user_info_context + + async def get_user(self) -> Optional[UserInfo]: + """Get the user information from the context variable. + + :return: The user information or None if not found. + :rtype: Optional[UserInfo] + """ + return self.context.get(None) + + +def resolve_user_from_headers(headers: Mapping[str, str], + object_id_header: str = "x-aml-oid", + tenant_id_header: str = "x-aml-tid") -> Optional[UserInfo]: + """Resolve user information from HTTP headers. + + :param headers: The HTTP headers. + :type headers: Mapping[str, str] + :param object_id_header: The header name for the object ID. + :type object_id_header: str + :param tenant_id_header: The header name for the tenant ID. + :type tenant_id_header: str + :return: The user information or None if not found. 
+ :rtype: Optional[UserInfo] + """ + object_id = headers.get(object_id_header, "") + tenant_id = headers.get(tenant_id_header, "") + + if not object_id or not tenant_id: + return None + + return UserInfo(object_id=object_id, tenant_id=tenant_id) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py new file mode 100644 index 000000000000..037fb1dc04de --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py @@ -0,0 +1,11 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) + +from ._name_resolver import ToolNameResolver + +__all__ = [ + "ToolNameResolver", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py new file mode 100644 index 000000000000..9f1b7874f52c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py @@ -0,0 +1,37 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from ..client._models import ResolvedFoundryTool + + +class ToolNameResolver: + """Utility class for resolving tool names to be registered to model.""" + + def __init__(self): + self._count_by_name = {} + self._stable_names = {} + + def resolve(self, tool: ResolvedFoundryTool) -> str: + """Resolve a stable name for the given tool. + If the tool name has not been used before, use it as is. 
+ If it has been used, append an underscore and a count to make it unique. + + :param tool: The tool to resolve the name for. + :type tool: ResolvedFoundryTool + :return: The resolved stable name for the tool. + :rtype: str + """ + final_name = self._stable_names.get(tool.id) + if final_name is not None: + return final_name + + dup_count = self._count_by_name.setdefault(tool.details.name, 0) + + if dup_count == 0: + final_name = tool.details.name + else: + final_name = f"{tool.details.name}_{dup_count}" + + self._stable_names[tool.id] = final_name + self._count_by_name[tool.details.name] = dup_count + 1 + return self._stable_names[tool.id] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py new file mode 100644 index 000000000000..0b6600de7d6a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py @@ -0,0 +1,100 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from __future__ import annotations + +import asyncio # pylint: disable=C4763 # azure-sdk: async-client-bad-name +import inspect +from types import TracebackType +from typing import Any, Type, cast + +from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + + +async def _to_thread(func, *args, **kwargs): # pylint: disable=C4743 # azure-sdk: client-method-should-not-use-static-method + """Compatibility wrapper for asyncio.to_thread (Python 3.8+). + + :param func: The function to run in a thread. + :type func: Callable + :param args: Positional arguments to pass to the function. + :type args: Any + :param kwargs: Keyword arguments to pass to the function. + :type kwargs: Any + :return: The result of the function call. + :rtype: Any + """ + if hasattr(asyncio, "to_thread"): + return await asyncio.to_thread(func, *args, **kwargs) # py>=3.9 + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: func(*args, **kwargs)) + + +class AsyncTokenCredentialAdapter(AsyncTokenCredential): + """ + AsyncTokenCredential adapter for either: + - azure.core.credentials.TokenCredential (sync) + - azure.core.credentials_async.AsyncTokenCredential (async) + """ + + def __init__(self, credential: TokenCredential | AsyncTokenCredential) -> None: + if not hasattr(credential, "get_token"): + raise TypeError("credential must have a get_token method") + self._credential = credential + self._is_async = isinstance(credential, AsyncTokenCredential) or inspect.iscoroutinefunction( + getattr(credential, "get_token", None) + ) + + async def get_token( + self, + *scopes: str, + claims: str | None = None, + tenant_id: str | None = None, + enable_cae: bool = False, + **kwargs: Any, + ) -> AccessToken: + if self._is_async: + cred = cast(AsyncTokenCredential, self._credential) + return await cred.get_token(*scopes, + claims=claims, + 
tenant_id=tenant_id, + enable_cae=enable_cae, + **kwargs) + return await _to_thread(self._credential.get_token, + *scopes, + claims=claims, + tenant_id=tenant_id, + enable_cae=enable_cae, + **kwargs) + + async def close(self) -> None: + """ + Best-effort resource cleanup: + - if underlying has async close(): await it + - else if underlying has sync close(): run it in a thread + """ + close_fn = getattr(self._credential, "close", None) + if close_fn is None: + return + + if inspect.iscoroutinefunction(close_fn): + await close_fn() + else: + await _to_thread(close_fn) + + async def __aenter__(self) -> "AsyncTokenCredentialAdapter": + enter = getattr(self._credential, "__aenter__", None) + if enter is not None and inspect.iscoroutinefunction(enter): + await enter() + return self + + async def __aexit__( + self, + exc_type: Type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + aexit = getattr(self._credential, "__aexit__", None) + if aexit is not None and inspect.iscoroutinefunction(aexit): + return await aexit(exc_type, exc_value, traceback) + await self.close() diff --git a/sdk/agentserver/azure-ai-agentserver-core/cspell.json b/sdk/agentserver/azure-ai-agentserver-core/cspell.json index 126cadc0625c..d5003af37fe1 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/cspell.json +++ b/sdk/agentserver/azure-ai-agentserver-core/cspell.json @@ -16,7 +16,11 @@ "GETFL", "DETFL", "SETFL", - "Planifica" + "Planifica", + "ainvoke", + "oauthreq", + "hitl", + "HITL" ], "ignorePaths": [ "*.csv", diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst new file mode 100644 index 000000000000..415b7d3b2538 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.application 
package +============================================= + +.. automodule:: azure.ai.agentserver.core.application + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.operations.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.operations.rst new file mode 100644 index 000000000000..3076ff010e1b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.operations.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.checkpoints.client.operations package +=============================================================== + +.. automodule:: azure.ai.agentserver.core.checkpoints.client.operations + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.rst new file mode 100644 index 000000000000..cd6763335948 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.client.rst @@ -0,0 +1,15 @@ +azure.ai.agentserver.core.checkpoints.client package +==================================================== + +.. automodule:: azure.ai.agentserver.core.checkpoints.client + :inherited-members: + :members: + :undoc-members: + +Subpackages +----------- + +.. 
toctree:: + :maxdepth: 4 + + azure.ai.agentserver.core.checkpoints.client.operations diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.rst new file mode 100644 index 000000000000..99b9dfa2ef50 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.checkpoints.rst @@ -0,0 +1,15 @@ +azure.ai.agentserver.core.checkpoints package +============================================= + +.. automodule:: azure.ai.agentserver.core.checkpoints + :inherited-members: + :members: + :undoc-members: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.core.checkpoints.client diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst new file mode 100644 index 000000000000..120b01cccc5a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst @@ -0,0 +1,8 @@ +azure.ai.agentserver.core.models package +======================================== + +.. automodule:: azure.ai.agentserver.core.models + :inherited-members: + :members: + :undoc-members: + :ignore-module-all: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst index da01b083b0b3..60005f2b04cc 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst @@ -12,7 +12,12 @@ Subpackages .. 
toctree:: :maxdepth: 4 + azure.ai.agentserver.core.application + azure.ai.agentserver.core.checkpoints + azure.ai.agentserver.core.models azure.ai.agentserver.core.server + azure.ai.agentserver.core.tools + azure.ai.agentserver.core.utils Submodules ---------- diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst index cf935aa1d1ed..68f155131f5c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst @@ -9,18 +9,18 @@ azure.ai.agentserver.core.server.common.id\_generator package Submodules ---------- -azure.ai.agentserver.core.server.common.id\_generator.foundry\_id\_generator module ------------------------------------------------------------------------------------ +azure.ai.agentserver.core.server.common.id\_generator.\_foundry\_id\_generator module +------------------------------------------------------------------------------------ -.. automodule:: azure.ai.agentserver.core.server.common.id_generator.foundry_id_generator +.. automodule:: azure.ai.agentserver.core.server.common.id_generator._foundry_id_generator :inherited-members: :members: :undoc-members: -azure.ai.agentserver.core.server.common.id\_generator.id\_generator module --------------------------------------------------------------------------- +azure.ai.agentserver.core.server.common.id\_generator.\_id\_generator module +--------------------------------------------------------------------------- -.. automodule:: azure.ai.agentserver.core.server.common.id_generator.id_generator +.. 
automodule:: azure.ai.agentserver.core.server.common.id_generator._id_generator :inherited-members: :members: :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst index 26c4aaf4d15a..fd02e856642c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst @@ -17,10 +17,18 @@ Subpackages Submodules ---------- -azure.ai.agentserver.core.server.common.agent\_run\_context module ------------------------------------------------------------------- +azure.ai.agentserver.core.server.common.\_agent\_run\_context module +------------------------------------------------------------------- -.. automodule:: azure.ai.agentserver.core.server.common.agent_run_context +.. automodule:: azure.ai.agentserver.core.server.common._agent_run_context + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.core.server.common.\_constants module +---------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.core.server.common._constants :inherited-members: :members: :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst index b82fa765b839..8363ec9e32d8 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst @@ -17,10 +17,10 @@ Subpackages Submodules ---------- -azure.ai.agentserver.core.server.base module --------------------------------------------- +azure.ai.agentserver.core.server.\_base module +---------------------------------------------- -.. 
automodule:: azure.ai.agentserver.core.server.base +.. automodule:: azure.ai.agentserver.core.server._base :inherited-members: :members: :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst new file mode 100644 index 000000000000..14304731f5e7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.client package +============================================== + +.. automodule:: azure.ai.agentserver.core.tools.client + :inherited-members: BaseModel + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst new file mode 100644 index 000000000000..6b798851fed2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst @@ -0,0 +1,17 @@ +azure.ai.agentserver.core.tools package +======================================= + +.. automodule:: azure.ai.agentserver.core.tools + :inherited-members: BaseModel + :members: + :undoc-members: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.core.tools.client + azure.ai.agentserver.core.tools.runtime + azure.ai.agentserver.core.tools.utils diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst new file mode 100644 index 000000000000..c502d56b42f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.runtime package +=============================================== + +.. 
automodule:: azure.ai.agentserver.core.tools.runtime + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst new file mode 100644 index 000000000000..94d3f310e112 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.utils package +============================================= + +.. automodule:: azure.ai.agentserver.core.tools.utils + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst new file mode 100644 index 000000000000..5250167cf7e6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.utils package +======================================= + +.. 
automodule:: azure.ai.agentserver.core.utils + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index f574360722bb..a0bca5c434fa 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -8,6 +8,7 @@ authors = [ ] license = "MIT" classifiers = [ + "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3", @@ -19,22 +20,27 @@ classifiers = [ keywords = ["azure", "azure sdk"] dependencies = [ - "azure-monitor-opentelemetry>=1.5.0", - "azure-ai-projects", - "azure-ai-agents>=1.2.0b5", + "azure-monitor-opentelemetry>=1.5.0,<1.8.5", + "azure-ai-projects>=2.0.0b1", + "azure-ai-agents==1.2.0b5", "azure-core>=1.35.0", - "azure-identity", + "azure-identity>=1.25.1", "openai>=1.80.0", "opentelemetry-api>=1.35", "opentelemetry-exporter-otlp-proto-http", - "starlette>=0.45.0", + "starlette>=1.0.0rc1,<2.0.0", "uvicorn>=0.31.0", + "aiohttp>=3.13.0", # used by azure-identity aio + "cachetools>=6.0.0" ] [build-system] requires = ["setuptools>=69", "wheel"] build-backend = "setuptools.build_meta" +[project.urls] +repository = "https://github.com/Azure/azure-sdk-for-python" + [tool.setuptools.packages.find] exclude = [ "tests*", @@ -68,4 +74,4 @@ combine-as-imports = true [tool.azure-sdk-build] breaking = false # incompatible python version pyright = false -verifytypes = false \ No newline at end of file +verifytypes = false diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/bilingual_weekend_planner/main.py b/sdk/agentserver/azure-ai-agentserver-core/samples/bilingual_weekend_planner/main.py index 099d8dc45181..2cf533eb33fb 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/samples/bilingual_weekend_planner/main.py +++ 
b/sdk/agentserver/azure-ai-agentserver-core/samples/bilingual_weekend_planner/main.py @@ -33,7 +33,7 @@ CreateResponse, Response as OpenAIResponse, ) -from azure.ai.agentserver.core.models.projects import ( +from azure.ai.agentserver.core.models._projects import ( ItemContentOutputText, ResponseCompletedEvent, ResponseCreatedEvent, diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/mcp_simple/mcp_simple.py b/sdk/agentserver/azure-ai-agentserver-core/samples/mcp_simple/mcp_simple.py index af9812826941..3831f702564d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/samples/mcp_simple/mcp_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/mcp_simple/mcp_simple.py @@ -29,7 +29,7 @@ from azure.ai.agentserver.core import AgentRunContext, FoundryCBAgent from azure.ai.agentserver.core.models import Response as OpenAIResponse -from azure.ai.agentserver.core.models.projects import ( +from azure.ai.agentserver.core.models._projects import ( ItemContentOutputText, MCPListToolsItemResource, MCPListToolsTool, diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py b/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py index 3d4187a188f2..f4298d21d39c 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py @@ -3,7 +3,7 @@ from azure.ai.agentserver.core import AgentRunContext, FoundryCBAgent from azure.ai.agentserver.core.models import Response as OpenAIResponse -from azure.ai.agentserver.core.models.projects import ( +from azure.ai.agentserver.core.models._projects import ( ItemContentOutputText, ResponseCompletedEvent, ResponseCreatedEvent, @@ -97,7 +97,7 @@ async def agent_run(context: AgentRunContext): return response -my_agent = FoundryCBAgent() +my_agent = FoundryCBAgent(project_endpoint="mock-endpoint") 
my_agent.agent_run = agent_run if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/common/test_foundry_id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/common/test_foundry_id_generator.py new file mode 100644 index 000000000000..fb6dc8858c86 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/common/test_foundry_id_generator.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +from azure.ai.agentserver.core.server.common.id_generator._foundry_id_generator import FoundryIdGenerator + + +def test_conversation_id_none_uses_response_partition(): + response_id = FoundryIdGenerator._new_id("resp") + generator = FoundryIdGenerator(response_id=response_id, conversation_id=None) + + assert generator.conversation_id is None + + expected_partition = FoundryIdGenerator._extract_partition_id(response_id) + generated_id = generator.generate("msg") + assert FoundryIdGenerator._extract_partition_id(generated_id) == expected_partition + + +def test_conversation_id_present_uses_conversation_partition(): + response_id = FoundryIdGenerator._new_id("resp") + conversation_id = FoundryIdGenerator._new_id("conv") + generator = FoundryIdGenerator(response_id=response_id, conversation_id=conversation_id) + + assert generator.conversation_id == conversation_id + + expected_partition = FoundryIdGenerator._extract_partition_id(conversation_id) + generated_id = generator.generate("msg") + assert FoundryIdGenerator._extract_partition_id(generated_id) == expected_partition diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py new file mode 100644 index 000000000000..00137abecf15 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py @@ -0,0 +1,266 @@ +"""Tests for conversation persistence functionality in FoundryCBAgent.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + + +class MockAgentRunContext: + """Mock AgentRunContext for testing.""" + + def __init__(self, conversation_id=None, store=False, input_items=None): + self._conversation_id = conversation_id + self._request = { + "store": store, + "input": input_items or [], + } + + @property + def 
conversation_id(self): + return self._conversation_id + + @property + def request(self): + return self._request + + +class AsyncIteratorMock: + """Helper class to create an async iterator from a list.""" + + def __init__(self, items): + self.items = items + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.items): + raise StopAsyncIteration + item = self.items[self.index] + self.index += 1 + return item + + +def create_mock_agent(): + """Create a mock FoundryCBAgent without calling __init__.""" + from azure.ai.agentserver.core.server._base import FoundryCBAgent + + # Create instance without calling __init__ + agent = object.__new__(FoundryCBAgent) + agent._project_endpoint = None + agent.credentials = None + agent.tracer = None + agent._port = 8088 + return agent + + +@pytest.mark.unit +class TestShouldStore: + """Tests for _should_store method.""" + + def test_should_store_returns_true_when_all_conditions_met(self): + """Test that _should_store returns True when store=True, conversation_id exists, and endpoint exists.""" + agent = create_mock_agent() + agent._project_endpoint = "https://test.endpoint.com" + + context = MockAgentRunContext( + conversation_id="conv_123", + store=True, + ) + + result = agent._should_store(context) + assert result # Truthy value when all conditions met + + def test_should_store_returns_false_when_store_is_false(self): + """Test that _should_store returns False when store=False.""" + agent = create_mock_agent() + agent._project_endpoint = "https://test.endpoint.com" + + context = MockAgentRunContext( + conversation_id="conv_123", + store=False, + ) + + result = agent._should_store(context) + assert not result # Falsy value when store=False + + def test_should_store_returns_false_when_no_conversation_id(self): + """Test that _should_store returns False when conversation_id is None.""" + agent = create_mock_agent() + agent._project_endpoint = "https://test.endpoint.com" + + 
context = MockAgentRunContext( + conversation_id=None, + store=True, + ) + + result = agent._should_store(context) + assert not result # Falsy value when conversation_id is None + + def test_should_store_returns_false_when_no_endpoint(self): + """Test that _should_store returns False when project_endpoint is None.""" + agent = create_mock_agent() + agent._project_endpoint = None + + context = MockAgentRunContext( + conversation_id="conv_123", + store=True, + ) + + result = agent._should_store(context) + assert not result # Falsy value when endpoint is None + + +@pytest.mark.unit +class TestItemsAreEqual: + """Tests for _items_are_equal method.""" + + def test_items_equal_with_same_string_content(self): + """Test items are equal when type, role, and string content match.""" + agent = create_mock_agent() + + item1 = {"type": "message", "role": "user", "content": "Hello"} + item2 = {"type": "message", "role": "user", "content": "Hello"} + + assert agent._items_are_equal(item1, item2) is True + + def test_items_not_equal_with_different_content(self): + """Test items are not equal when content differs.""" + agent = create_mock_agent() + + item1 = {"type": "message", "role": "user", "content": "Hello"} + item2 = {"type": "message", "role": "user", "content": "Goodbye"} + + assert agent._items_are_equal(item1, item2) is False + + def test_items_not_equal_with_different_type(self): + """Test items are not equal when type differs.""" + agent = create_mock_agent() + + item1 = {"type": "message", "role": "user", "content": "Hello"} + item2 = {"type": "function_call", "role": "user", "content": "Hello"} + + assert agent._items_are_equal(item1, item2) is False + + def test_items_not_equal_with_different_role(self): + """Test items are not equal when role differs.""" + agent = create_mock_agent() + + item1 = {"type": "message", "role": "user", "content": "Hello"} + item2 = {"type": "message", "role": "assistant", "content": "Hello"} + + assert agent._items_are_equal(item1, 
item2) is False + + def test_items_equal_with_structured_content(self): + """Test items are equal when structured content text matches.""" + agent = create_mock_agent() + + item1 = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + item2 = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + + assert agent._items_are_equal(item1, item2) is True + + def test_items_not_equal_with_different_structured_content(self): + """Test items are not equal when structured content text differs.""" + agent = create_mock_agent() + + item1 = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Hello"}], + } + item2 = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Goodbye"}], + } + + assert agent._items_are_equal(item1, item2) is False + + +@pytest.mark.unit +class TestSaveInputToConversation: + """Tests for _save_input_to_conversation method.""" + + @pytest.mark.asyncio + async def test_save_input_skips_when_no_input_items(self): + """Test that save is skipped when there are no input items.""" + agent = create_mock_agent() + agent._project_endpoint = "https://test.endpoint.com" + agent._create_openai_client = AsyncMock() + + context = MockAgentRunContext( + conversation_id="conv_123", + store=True, + input_items=[], + ) + + await agent._save_input_to_conversation(context) + + # OpenAI client should not be created if no items + agent._create_openai_client.assert_not_called() + + @pytest.mark.asyncio + async def test_save_input_converts_string_to_message(self): + """Test that string input is converted to message format.""" + agent = create_mock_agent() + agent._project_endpoint = "https://test.endpoint.com" + + mock_client = AsyncMock() + mock_client.conversations.items.list = MagicMock(return_value=AsyncIteratorMock([])) + mock_client.conversations.items.create = AsyncMock() + agent._create_openai_client = 
AsyncMock(return_value=mock_client) + + context = MockAgentRunContext( + conversation_id="conv_123", + store=True, + input_items=["Hello, world!"], + ) + + await agent._save_input_to_conversation(context) + + mock_client.conversations.items.create.assert_called_once() + call_args = mock_client.conversations.items.create.call_args + assert call_args.kwargs["conversation_id"] == "conv_123" + assert call_args.kwargs["items"][0]["type"] == "message" + assert call_args.kwargs["items"][0]["role"] == "user" + assert call_args.kwargs["items"][0]["content"] == "Hello, world!" + + @pytest.mark.asyncio + async def test_save_input_skips_duplicates(self): + """Test that save is skipped when input matches last historical items.""" + agent = create_mock_agent() + agent._project_endpoint = "https://test.endpoint.com" + + # Create mock historical item + mock_historical_item = MagicMock() + mock_historical_item.model_dump = MagicMock(return_value={ + "type": "message", + "role": "user", + "content": "Hello, world!", + }) + + mock_client = AsyncMock() + mock_client.conversations.items.list = MagicMock( + return_value=AsyncIteratorMock([mock_historical_item]) + ) + mock_client.conversations.items.create = AsyncMock() + agent._create_openai_client = AsyncMock(return_value=mock_client) + + context = MockAgentRunContext( + conversation_id="conv_123", + store=True, + input_items=["Hello, world!"], + ) + + await agent._save_input_to_conversation(context) + + # Create should not be called because input is duplicate + mock_client.conversations.items.create.assert_not_called() diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py new file mode 100644 index 000000000000..f5ee395ff48f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py @@ -0,0 +1,14 @@ +import pytest +from opentelemetry import context as 
otel_context +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + +@pytest.mark.asyncio +async def test_streaming_context_restore_uses_previous_context() -> None: + prev_ctx = otel_context.get_current() + ctx = TraceContextTextMapPropagator().extract(carrier={}) + + otel_context.attach(ctx) + otel_context.attach(prev_ctx) + + assert otel_context.get_current() is prev_ctx diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py new file mode 100644 index 000000000000..f01c4977cfb0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py @@ -0,0 +1,140 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import json + +from azure.ai.agentserver.core.application import ( + PackageMetadata, + RuntimeMetadata, + get_current_app, + set_current_app, +) +from azure.ai.agentserver.core.models import Response as OpenAIResponse +from azure.ai.agentserver.core.models._projects import ResponseCreatedEvent, ResponseErrorEvent +from azure.ai.agentserver.core.server._response_metadata import ( + METADATA_KEY, + attach_foundry_metadata_to_response, + build_foundry_agents_metadata_headers, + try_attach_foundry_metadata_to_event, +) + + +def _set_test_app(): + previous = get_current_app() + set_current_app( + PackageMetadata( + name="test-package", + version="1.2.3", + ), + RuntimeMetadata( + python_version="3.11.0", + platform="test-platform", + host_name="test-host", + replica_name="test-replica", + ), + ) + return previous + + +def _expected_payload() -> dict[str, dict[str, str]]: + return { + "package": { + "name": "test-package", + "version": "1.2.3", + }, + "runtime": { + "python_version": "3.11.0", + "platform": 
"test-platform", + "host_name": "test-host", + "replica_name": "test-replica", + }, + } + + +def test_build_foundry_agents_metadata_headers_returns_json(): + previous = _set_test_app() + try: + headers = build_foundry_agents_metadata_headers() + payload = json.loads(headers["x-aml-foundry-agents-metadata"]) + assert payload == _expected_payload() + finally: + set_current_app(previous.package, previous.runtime) + + +def test_attach_foundry_metadata_to_response_sets_metadata_key(): + previous = _set_test_app() + try: + response = OpenAIResponse({"object": "response", "id": "resp", "metadata": {}}) + attach_foundry_metadata_to_response(response) + assert METADATA_KEY in response.metadata + assert json.loads(response.metadata[METADATA_KEY]) == _expected_payload() + finally: + set_current_app(previous.package, previous.runtime) + + +def test_try_attach_foundry_metadata_to_event_attaches_for_supported_events(): + previous = _set_test_app() + try: + response = OpenAIResponse({"object": "response", "id": "resp", "metadata": {}}) + event = ResponseCreatedEvent({"sequence_number": 0, "response": response}) + try_attach_foundry_metadata_to_event(event) + assert METADATA_KEY in response.metadata + + unsupported = ResponseErrorEvent( + {"sequence_number": 1, "code": "server_error", "message": "boom", "param": ""} + ) + try_attach_foundry_metadata_to_event(unsupported) + finally: + set_current_app(previous.package, previous.runtime) + + +def test_runtime_metadata_merge_overrides_non_empty_fields(): + base = RuntimeMetadata( + python_version="3.10.0", + platform="base-platform", + host_name="base-host", + replica_name="base-replica", + ) + override = RuntimeMetadata( + python_version="", + platform="override-platform", + host_name="", + replica_name="override-replica", + ) + + merged = base.merged_with(override) + + assert merged.python_version == "3.10.0" + assert merged.platform == "override-platform" + assert merged.host_name == "base-host" + assert merged.replica_name == 
"override-replica" + + +def test_runtime_metadata_resolve_falls_back_when_env_missing(monkeypatch): + monkeypatch.delenv("CONTAINER_APP_REVISION_FQDN", raising=False) + monkeypatch.delenv("CONTAINER_APP_REPLICA_NAME", raising=False) + runtime = RuntimeMetadata.resolve() + + assert runtime.host_name == "" + assert runtime.replica_name == "" + assert runtime.python_version + assert runtime.platform + + +def test_runtime_metadata_resolve_aca_env(monkeypatch): + monkeypatch.setenv("CONTAINER_APP_REVISION_FQDN", "aca-host") + monkeypatch.setenv("CONTAINER_APP_REPLICA_NAME", "aca-replica") + runtime = RuntimeMetadata.resolve() + + assert runtime.host_name == "aca-host" + assert runtime.replica_name == "aca-replica" + + +def test_runtime_metadata_resolve_explicit_overrides(monkeypatch): + monkeypatch.setenv("CONTAINER_APP_REVISION_FQDN", "aca-host") + monkeypatch.setenv("CONTAINER_APP_REPLICA_NAME", "aca-replica") + + runtime = RuntimeMetadata.resolve(host_name="override-host", replica_name="override-replica") + + assert runtime.host_name == "override-host" + assert runtime.replica_name == "override-replica" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py new file mode 100644 index 000000000000..35639ea8ae2c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py @@ -0,0 +1,144 @@ +"""Tests for logger functionality.""" +import os +import pytest +from unittest.mock import MagicMock, patch + + +@pytest.mark.unit +class TestGetProjectEndpoint: + """Tests for get_project_endpoint function.""" + + def test_returns_endpoint_from_azure_ai_project_endpoint_env(self): + """Test that endpoint is returned from AZURE_AI_PROJECT_ENDPOINT env var.""" + from azure.ai.agentserver.core.logger import get_project_endpoint + + with patch.dict(os.environ, {"AZURE_AI_PROJECT_ENDPOINT": "https://test.endpoint.com"}, clear=False): + result = 
get_project_endpoint() + assert result == "https://test.endpoint.com" + + def test_returns_none_when_no_env_vars_set(self): + """Test that None is returned when no environment variables are set.""" + from azure.ai.agentserver.core.logger import get_project_endpoint + from azure.ai.agentserver.core.constants import Constants + + # Temporarily remove the env vars if they exist + original_endpoint = os.environ.pop(Constants.AZURE_AI_PROJECT_ENDPOINT, None) + original_resource_id = os.environ.pop(Constants.AGENT_PROJECT_RESOURCE_ID, None) + + try: + result = get_project_endpoint() + assert result is None + finally: + # Restore original values + if original_endpoint: + os.environ[Constants.AZURE_AI_PROJECT_ENDPOINT] = original_endpoint + if original_resource_id: + os.environ[Constants.AGENT_PROJECT_RESOURCE_ID] = original_resource_id + + def test_derives_endpoint_from_agent_project_resource_id(self): + """Test that endpoint is derived from AGENT_PROJECT_RESOURCE_ID.""" + from azure.ai.agentserver.core.logger import get_project_endpoint + + resource_id = "/subscriptions/sub123/resourceGroups/rg/providers/Microsoft.MachineLearningServices/workspaces/account@project" + + with patch.dict(os.environ, { + "AZURE_AI_PROJECT_ENDPOINT": "", + "AGENT_PROJECT_RESOURCE_ID": resource_id, + }, clear=False): + result = get_project_endpoint() + # When AZURE_AI_PROJECT_ENDPOINT is empty, it should fall through to resource ID + assert result is None or "account" in str(result) or result == "" + + def test_logs_debug_when_logger_provided(self): + """Test that debug message is logged when logger is provided.""" + from azure.ai.agentserver.core.logger import get_project_endpoint + + mock_logger = MagicMock() + + with patch.dict(os.environ, {"AZURE_AI_PROJECT_ENDPOINT": "https://test.endpoint.com"}, clear=False): + result = get_project_endpoint(logger=mock_logger) + # Logger debug should be called when endpoint is found + assert mock_logger.debug.called or result == 
"https://test.endpoint.com" + + def test_logs_warning_for_invalid_resource_id(self): + """Test that warning is logged for invalid resource ID format.""" + from azure.ai.agentserver.core.logger import get_project_endpoint + + mock_logger = MagicMock() + invalid_resource_id = "/invalid/resource/id" + + with patch.dict(os.environ, { + "AZURE_AI_PROJECT_ENDPOINT": "", + "AGENT_PROJECT_RESOURCE_ID": invalid_resource_id, + }, clear=False): + result = get_project_endpoint(logger=mock_logger) + # Result should be None for invalid resource ID + assert result is None or result == "" + + +@pytest.mark.unit +class TestGetApplicationInsightsConnstr: + """Tests for _get_application_insights_connstr function.""" + + def test_returns_connstr_from_env_var(self): + """Test that connection string is returned from environment variable.""" + from azure.ai.agentserver.core.logger import _get_application_insights_connstr + + with patch.dict(os.environ, {"APPLICATIONINSIGHTS_CONNECTION_STRING": "InstrumentationKey=test123"}, clear=False): + result = _get_application_insights_connstr() + assert result == "InstrumentationKey=test123" + + def test_returns_none_when_no_connstr_and_no_project(self): + """Test that None is returned when no connection string and no project endpoint.""" + from azure.ai.agentserver.core.logger import _get_application_insights_connstr + + with patch.dict(os.environ, { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "", + "AZURE_AI_PROJECT_ENDPOINT": "", + "AGENT_PROJECT_RESOURCE_ID": "", + }, clear=False): + result = _get_application_insights_connstr() + assert result is None or result == "" + + def test_logs_debug_when_not_configured(self): + """Test that debug message is logged when not configured.""" + from azure.ai.agentserver.core.logger import _get_application_insights_connstr + + mock_logger = MagicMock() + + with patch.dict(os.environ, { + "APPLICATIONINSIGHTS_CONNECTION_STRING": "", + "AZURE_AI_PROJECT_ENDPOINT": "", + "AGENT_PROJECT_RESOURCE_ID": "", + }, 
@pytest.mark.unit
class TestLoggerConfiguration:
    """Tests for the default logging configuration dictionary."""

    def test_default_log_config_uses_stdout(self):
        """The console handler must write to stdout, not stderr."""
        from azure.ai.agentserver.core.logger import _get_default_log_config

        console = _get_default_log_config()["handlers"]["console"]
        assert console["stream"] == "ext://sys.stdout"

    def test_default_log_config_has_correct_format(self):
        """The dictConfig payload contains the expected top-level sections."""
        from azure.ai.agentserver.core.logger import _get_default_log_config

        config = _get_default_log_config()

        # dictConfig schema version plus the sections this SDK's logger relies on.
        assert "version" in config and config["version"] == 1
        assert "loggers" in config and "azure.ai.agentserver" in config["loggers"]
        assert "handlers" in config and "console" in config["handlers"]
        assert "formatters" in config
+# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_connected_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_connected_tools.py new file mode 100644 index 000000000000..e7273f37a7e7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_connected_tools.py @@ -0,0 +1,479 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryConnectedToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryToolDetails, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_connected_tools import ( + FoundryConnectedToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import OAuthConsentRequiredError, ToolInvocationError + +from ...conftest import create_mock_http_response + + +class TestFoundryConnectedToolsOperationsListTools: + """Tests for FoundryConnectedToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when tools list is empty.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + result = await ops.list_tools([], None, "test-agent") + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_tools_from_server( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools returns tools from server response.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "remote_tool", + "description": "A remote connected tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"} + } + } + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await 
ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_connected_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == "remote_tool" + assert details.description == "A remote connected tool" + + @pytest.mark.asyncio + async def test_list_tools_without_user_info(self, sample_connected_tool): + """Test list_tools works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_no_user", + "description": "Tool without user", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], None, "test-agent")) + + assert len(result) == 1 + assert result[0][1].name == "tool_no_user" + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_connections(self, sample_user_info): + """Test list_tools with multiple connected tool definitions.""" + mock_client = AsyncMock() + + tool1 = FoundryConnectedTool(protocol="mcp", project_connection_id="conn-1") + tool2 = FoundryConnectedTool(protocol="a2a", project_connection_id="conn-2") + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "conn-1" + }, + "manifest": [ + { + "name": "tool_from_conn1", + "description": "From connection 1", + "parameters": {"type": "object", "properties": {}} + } + ] + }, + { + "remoteServer": { + "protocol": "a2a", + "projectConnectionId": "conn-2" + }, + "manifest": [ + { + "name": "tool_from_conn2", + 
    @pytest.mark.asyncio
    async def test_list_tools_filters_by_connection_id(self, sample_user_info):
        """Test list_tools only returns tools from requested connections.

        The mocked server response advertises tools for two connection IDs, but the
        call requests only "requested-conn"; tools from the unrequested connection
        must be filtered out of the result.
        """
        mock_client = AsyncMock()

        requested_tool = FoundryConnectedTool(protocol="mcp", project_connection_id="requested-conn")

        # Server returns tools from multiple connections, but we only requested one
        response_data = {
            "tools": [
                {
                    "remoteServer": {
                        "protocol": "mcp",
                        "projectConnectionId": "requested-conn"
                    },
                    "manifest": [
                        {
                            "name": "requested_tool",
                            "description": "Requested",
                            "parameters": {"type": "object", "properties": {}}
                        }
                    ]
                },
                {
                    "remoteServer": {
                        "protocol": "mcp",
                        "projectConnectionId": "unrequested-conn"
                    },
                    "manifest": [
                        {
                            "name": "unrequested_tool",
                            "description": "Not requested",
                            "parameters": {"type": "object", "properties": {}}
                        }
                    ]
                }
            ]
        }
        mock_response = create_mock_http_response(200, response_data)
        mock_client.send_request.return_value = mock_response
        mock_client.post.return_value = MagicMock()

        ops = FoundryConnectedToolsOperations(mock_client)
        # list_tools yields (definition, details) pairs; materialize for assertions.
        result = list(await ops.list_tools([requested_tool], sample_user_info, "test-agent"))

        # Should only return tools from requested connection
        assert len(result) == 1
        # result[0][1] is the FoundryToolDetails half of the (definition, details) pair.
        assert result[0][1].name == "requested_tool"
sample_user_info + ): + """Test list_tools returns multiple tools from same connection.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_one", + "description": "First tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_three", + "description": "Third tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 3 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two", "tool_three"} + + @pytest.mark.asyncio + async def test_list_tools_raises_oauth_consent_error( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/consent", + "message": "User consent is required to access this resource", + "projectConnectionId": sample_connected_tool.project_connection_id + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + list(await ops.list_tools([sample_connected_tool], sample_user_info, 
"test-agent")) + + assert exc_info.value.consent_url == "https://login.microsoftonline.com/consent" + assert "consent" in exc_info.value.message.lower() + + +class TestFoundryConnectedToolsOperationsInvokeTool: + """Tests for FoundryConnectedToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_result_value( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool returns the result value from server.""" + mock_client = AsyncMock() + + expected_result = {"data": "some output", "status": "success"} + response_data = {"toolResult": expected_result} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_tool_without_user_info(self, sample_resolved_connected_tool): + """Test invoke_tool works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = {"toolResult": "local result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + None, # No user info + "test-agent" + ) + + assert result == "local result" + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "processed"} + mock_response = create_mock_http_response(200, response_data) + 
mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + complex_args = { + "query": "search term", + "filters": { + "date_range": {"start": "2025-01-01", "end": "2025-12-31"}, + "categories": ["A", "B", "C"] + }, + "limit": 50 + } + + result = await ops.invoke_tool( + sample_resolved_connected_tool, + complex_args, + sample_user_info, + "test-agent" + ) + + assert result == "processed" + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_returns_none_for_empty_result( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool returns None when server returns no result.""" + mock_client = AsyncMock() + + # Server returns empty response (no toolResult) + response_data = { + "toolResult": None + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert result is None + + @pytest.mark.asyncio + async def test_invoke_tool_with_mcp_tool_raises_error( + self, + sample_resolved_mcp_tool, + sample_user_info + ): + """Test invoke_tool raises ToolInvocationError for non-connected tool.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool( + sample_resolved_mcp_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert "not a Foundry connected tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_raises_oauth_consent_error( + self, + sample_resolved_connected_tool, + sample_user_info + ): + 
"""Test invoke_tool raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/oauth/consent", + "message": "Please provide consent to continue", + "projectConnectionId": "test-connection-id" + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert "https://login.microsoftonline.com/oauth/consent" in exc_info.value.consent_url + + @pytest.mark.asyncio + async def test_invoke_tool_with_different_agent_names( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool uses correct agent name in request.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + # Invoke with different agent names + for agent_name in ["agent-1", "my-custom-agent", "production-agent"]: + await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + agent_name + ) + + # Verify the correct path was used + call_args = mock_client.post.call_args + assert agent_name in call_args[0][0] + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_hosted_mcp_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_hosted_mcp_tools.py new file mode 100644 index 000000000000..473b27cc8768 --- /dev/null 
+++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_hosted_mcp_tools.py @@ -0,0 +1,309 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryMcpToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_hosted_mcp_tools import ( + FoundryMcpToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError + +from ...conftest import create_mock_http_response + + +class TestFoundryMcpToolsOperationsListTools: + """Tests for FoundryMcpToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when allowed_tools is empty.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + result = await ops.list_tools([]) + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_matching_tools(self, sample_hosted_mcp_tool): + """Test list_tools returns tools that match the allowed list.""" + mock_client = AsyncMock() + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test MCP tool", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + 
mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_hosted_mcp_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == sample_hosted_mcp_tool.name + assert details.description == "Test MCP tool" + + @pytest.mark.asyncio + async def test_list_tools_filters_out_non_allowed_tools(self, sample_hosted_mcp_tool): + """Test list_tools only returns tools in the allowed list.""" + mock_client = AsyncMock() + + # Server returns multiple tools but only one is allowed + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "other_tool_not_in_list", + "description": "Not allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "another_unlisted_tool", + "description": "Also not allowed", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + assert result[0][1].name == sample_hosted_mcp_tool.name + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_allowed_tools(self): + """Test list_tools with multiple tools in allowed list.""" + mock_client = AsyncMock() + + tool1 = FoundryHostedMcpTool(name="tool_one") + tool2 = FoundryHostedMcpTool(name="tool_two") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_one", + "description": 
"First tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool1, tool2])) + + assert len(result) == 2 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two"} + + @pytest.mark.asyncio + async def test_list_tools_preserves_tool_metadata(self): + """Test list_tools preserves metadata from server response.""" + mock_client = AsyncMock() + + tool = FoundryHostedMcpTool(name="tool_with_meta") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_with_meta", + "description": "Tool with metadata", + "inputSchema": { + "type": "object", + "properties": { + "param1": {"type": "string"} + }, + "required": ["param1"] + }, + "_meta": { + "type": "object", + "properties": { + "model": {"type": "string"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool])) + + assert len(result) == 1 + details = result[0][1] + assert details.metadata is not None + + +class TestFoundryMcpToolsOperationsInvokeTool: + """Tests for FoundryMcpToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_server_response(self, sample_resolved_mcp_tool): + """Test invoke_tool returns the response from server.""" + mock_client = AsyncMock() + + expected_response = { + "jsonrpc": "2.0", + "id": 2, + "result": { + "content": [{"type": "text", "text": "Hello World"}] + } + } + 
mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {"query": "test"}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_empty_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool works with empty arguments.""" + mock_client = AsyncMock() + + expected_response = {"result": "success"} + mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + mock_response = create_mock_http_response(200, {"result": "ok"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + complex_args = { + "text": "sample text", + "options": { + "temperature": 0.7, + "max_tokens": 100 + }, + "tags": ["tag1", "tag2"] + } + + result = await ops.invoke_tool(sample_resolved_mcp_tool, complex_args) + + assert result == {"result": "ok"} + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_with_connected_tool_raises_error( + self, + sample_resolved_connected_tool + ): + """Test invoke_tool raises ToolInvocationError for non-MCP tool.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool(sample_resolved_connected_tool, 
{}) + + assert "not a Foundry-hosted MCP tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_with_configuration_and_metadata(self): + """Test invoke_tool handles tool with configuration and metadata.""" + mock_client = AsyncMock() + + # Create tool with configuration + tool_def = FoundryHostedMcpTool( + name="image_generation", + configuration={"model_deployment_name": "dall-e-3"} + ) + + # Create tool details with metadata schema + meta_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "model": SchemaProperty(type=SchemaType.STRING) + } + ) + details = FoundryToolDetails( + name="image_generation", + description="Generate images", + input_schema=SchemaDefinition(type=SchemaType.OBJECT, properties={}), + metadata=meta_schema + ) + resolved_tool = ResolvedFoundryTool(definition=tool_def, details=details) + + mock_response = create_mock_http_response(200, {"result": "image_url"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(resolved_tool, {"prompt": "a cat"}) + + assert result == {"result": "image_url"} + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_client.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_client.py new file mode 100644 index 000000000000..de60f545e089 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_client.py @@ -0,0 +1,485 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------
"""Unit tests for FoundryToolClient - testing only public methods."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from azure.ai.agentserver.core.tools.client._client import FoundryToolClient
from azure.ai.agentserver.core.tools.client._models import (
    FoundryToolDetails,
    FoundryToolSource,
    ResolvedFoundryTool,
)
from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError

from ..conftest import create_mock_http_response


class TestFoundryToolClientInit:
    """Tests for FoundryToolClient.__init__ public method."""

    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    def test_init_with_valid_endpoint_and_credential(self, mock_pipeline_client_class, mock_credential):
        """Test client can be initialized with valid endpoint and credential."""
        endpoint = "https://fake-project-endpoint.site"

        client = FoundryToolClient(endpoint, mock_credential)

        # Verify client was created with correct base_url
        call_kwargs = mock_pipeline_client_class.call_args
        assert call_kwargs[1]["base_url"] == endpoint
        assert client is not None


class TestFoundryToolClientListTools:
    """Tests for FoundryToolClient.list_tools public method."""

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_list_tools_empty_collection_returns_empty_list(
        self,
        mock_pipeline_client_class,
        mock_credential
    ):
        """Test list_tools returns empty list when given empty collection."""
        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)

        result = await client.list_tools([], agent_name="test-agent")

        assert result == []

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_list_tools_with_single_mcp_tool_returns_resolved_tools(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_hosted_mcp_tool
    ):
        """Test list_tools with a single MCP tool returns resolved tools."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        # Mock HTTP response for MCP tools listing
        response_data = {
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [
                    {
                        "name": sample_hosted_mcp_tool.name,
                        "description": "Test MCP tool description",
                        "inputSchema": {"type": "object", "properties": {}}
                    }
                ]
            }
        }
        mock_response = create_mock_http_response(200, response_data)
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent")

        assert len(result) == 1
        assert isinstance(result[0], ResolvedFoundryTool)
        assert result[0].name == sample_hosted_mcp_tool.name
        assert result[0].source == FoundryToolSource.HOSTED_MCP

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_list_tools_with_single_connected_tool_returns_resolved_tools(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_connected_tool,
        sample_user_info
    ):
        """Test list_tools with a single connected tool returns resolved tools."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        # Mock HTTP response for connected tools listing
        response_data = {
            "tools": [
                {
                    "remoteServer": {
                        "protocol": sample_connected_tool.protocol,
                        "projectConnectionId": sample_connected_tool.project_connection_id
                    },
                    "manifest": [
                        {
                            "name": "connected_test_tool",
                            "description": "Test connected tool",
                            "parameters": {"type": "object", "properties": {}}
                        }
                    ]
                }
            ]
        }
        mock_response = create_mock_http_response(200, response_data)
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.list_tools(
            [sample_connected_tool],
            agent_name="test-agent",
            user=sample_user_info
        )

        assert len(result) == 1
        assert isinstance(result[0], ResolvedFoundryTool)
        assert result[0].name == "connected_test_tool"
        assert result[0].source == FoundryToolSource.CONNECTED

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_list_tools_with_mixed_tool_types_returns_all_resolved(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_hosted_mcp_tool,
        sample_connected_tool,
        sample_user_info
    ):
        """Test list_tools with both MCP and connected tools returns all resolved tools."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        # We need to return different responses based on the request
        mcp_response_data = {
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [
                    {
                        "name": sample_hosted_mcp_tool.name,
                        "description": "MCP tool",
                        "inputSchema": {"type": "object", "properties": {}}
                    }
                ]
            }
        }
        connected_response_data = {
            "tools": [
                {
                    "remoteServer": {
                        "protocol": sample_connected_tool.protocol,
                        "projectConnectionId": sample_connected_tool.project_connection_id
                    },
                    "manifest": [
                        {
                            "name": "connected_tool",
                            "description": "Connected tool",
                            "parameters": {"type": "object", "properties": {}}
                        }
                    ]
                }
            ]
        }

        # Mock to return different responses for different requests
        mock_client_instance.send_request.side_effect = [
            create_mock_http_response(200, mcp_response_data),
            create_mock_http_response(200, connected_response_data)
        ]
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.list_tools(
            [sample_hosted_mcp_tool, sample_connected_tool],
            agent_name="test-agent",
            user=sample_user_info
        )

        assert len(result) == 2
        sources = {tool.source for tool in result}
        assert FoundryToolSource.HOSTED_MCP in sources
        assert FoundryToolSource.CONNECTED in sources

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_list_tools_filters_unlisted_mcp_tools(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_hosted_mcp_tool
    ):
        """Test list_tools only returns tools that are in the allowed list."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        # Server returns more tools than requested
        response_data = {
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [
                    {
                        "name": sample_hosted_mcp_tool.name,
                        "description": "Requested tool",
                        "inputSchema": {"type": "object", "properties": {}}
                    },
                    {
                        "name": "unrequested_tool",
                        "description": "This tool was not requested",
                        "inputSchema": {"type": "object", "properties": {}}
                    }
                ]
            }
        }
        mock_response = create_mock_http_response(200, response_data)
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent")

        # Should only return the requested tool
        assert len(result) == 1
        assert result[0].name == sample_hosted_mcp_tool.name


class TestFoundryToolClientListToolsDetails:
    """Tests for FoundryToolClient.list_tools_details public method."""

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_list_tools_details_returns_mapping_structure(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_hosted_mcp_tool
    ):
        """Test list_tools_details returns correct mapping structure."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        response_data = {
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [
                    {
                        "name": sample_hosted_mcp_tool.name,
                        "description": "Test tool",
                        "inputSchema": {"type": "object", "properties": {}}
                    }
                ]
            }
        }
        mock_response = create_mock_http_response(200, response_data)
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent")

        assert isinstance(result, dict)
        assert sample_hosted_mcp_tool.id in result
        assert len(result[sample_hosted_mcp_tool.id]) == 1
        assert isinstance(result[sample_hosted_mcp_tool.id][0], FoundryToolDetails)

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_list_tools_details_groups_multiple_tools_by_definition(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_hosted_mcp_tool
    ):
        """Test list_tools_details groups returned tools under the owning definition's ID."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        # Server returns a tool for the MCP source; it must be keyed by the
        # definition's ID in the resulting mapping.
        response_data = {
            "jsonrpc": "2.0",
            "id": 1,
            "result": {
                "tools": [
                    {
                        "name": sample_hosted_mcp_tool.name,
                        "description": "Tool variant 1",
                        "inputSchema": {"type": "object", "properties": {}}
                    }
                ]
            }
        }
        mock_response = create_mock_http_response(200, response_data)
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent")

        # All tools should be grouped under the same definition ID
        assert sample_hosted_mcp_tool.id in result


class TestFoundryToolClientInvokeTool:
    """Tests for FoundryToolClient.invoke_tool public method."""

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_invoke_mcp_tool_returns_result(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_resolved_mcp_tool
    ):
        """Test invoke_tool with MCP tool returns the invocation result."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        expected_result = {"result": {"content": [{"text": "Hello World"}]}}
        mock_response = create_mock_http_response(200, expected_result)
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.invoke_tool(
            sample_resolved_mcp_tool,
            arguments={"input": "test"},
            agent_name="test-agent"
        )

        assert result == expected_result

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_invoke_connected_tool_returns_result(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_resolved_connected_tool,
        sample_user_info
    ):
        """Test invoke_tool with connected tool returns the invocation result."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        expected_value = {"output": "Connected tool result"}
        response_data = {"toolResult": expected_value}
        mock_response = create_mock_http_response(200, response_data)
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        result = await client.invoke_tool(
            sample_resolved_connected_tool,
            arguments={"input": "test"},
            agent_name="test-agent",
            user=sample_user_info
        )

        assert result == expected_value

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_invoke_tool_with_complex_arguments(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_resolved_mcp_tool
    ):
        """Test invoke_tool completes successfully when passed complex arguments."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        mock_response = create_mock_http_response(200, {"result": "success"})
        mock_client_instance.send_request.return_value = mock_response
        mock_client_instance.post.return_value = MagicMock()

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        complex_args = {
            "string_param": "value",
            "number_param": 42,
            "bool_param": True,
            "list_param": [1, 2, 3],
            "nested_param": {"key": "value"}
        }

        result = await client.invoke_tool(
            sample_resolved_mcp_tool,
            arguments=complex_args,
            agent_name="test-agent"
        )

        # Verify the request was actually sent and the mocked result came back
        # (the original test left `result` unasserted).
        assert result == {"result": "success"}
        mock_client_instance.send_request.assert_called_once()

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_invoke_tool_with_unsupported_source_raises_error(
        self,
        mock_pipeline_client_class,
        mock_credential,
        sample_tool_details
    ):
        """Test invoke_tool raises ToolInvocationError for unsupported tool source."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        # Create a mock tool with unsupported source
        mock_definition = MagicMock()
        mock_definition.source = "unsupported_source"
        mock_tool = MagicMock(spec=ResolvedFoundryTool)
        mock_tool.definition = mock_definition
        mock_tool.source = "unsupported_source"
        mock_tool.details = sample_tool_details

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)

        with pytest.raises(ToolInvocationError) as exc_info:
            await client.invoke_tool(
                mock_tool,
                arguments={"input": "test"},
                agent_name="test-agent"
            )

        assert "Unsupported tool source" in str(exc_info.value)


class TestFoundryToolClientClose:
    """Tests for FoundryToolClient.close public method."""

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_close_closes_underlying_client(
        self,
        mock_pipeline_client_class,
        mock_credential
    ):
        """Test close() properly closes the underlying HTTP client."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential)
        await client.close()

        mock_client_instance.close.assert_called_once()


class TestFoundryToolClientContextManager:
    """Tests for FoundryToolClient async context manager protocol."""

    @pytest.mark.asyncio
    @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient")
    async def test_async_context_manager_enters_and_exits(
        self,
        mock_pipeline_client_class,
        mock_credential
    ):
        """Test client can be used as async context manager."""
        mock_client_instance = AsyncMock()
        mock_pipeline_client_class.return_value = mock_client_instance

        async with FoundryToolClient("https://fake-project-endpoint.site", mock_credential) as client:
            assert client is not None
            mock_client_instance.__aenter__.assert_called_once()

        mock_client_instance.__aexit__.assert_called_once()
000000000000..2f3c2710a3fc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_configuration.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolClientConfiguration.""" + +from azure.core.pipeline import policies + +from azure.ai.agentserver.core.tools.client._configuration import FoundryToolClientConfiguration + + +class TestFoundryToolClientConfiguration: + """Tests for FoundryToolClientConfiguration class.""" + + def test_init_creates_all_required_policies(self, mock_credential): + """Test that initialization creates all required pipeline policies.""" + config = FoundryToolClientConfiguration(mock_credential) + + assert isinstance(config.retry_policy, policies.AsyncRetryPolicy) + assert isinstance(config.logging_policy, policies.NetworkTraceLoggingPolicy) + assert isinstance(config.request_id_policy, policies.RequestIdPolicy) + assert isinstance(config.http_logging_policy, policies.HttpLoggingPolicy) + assert isinstance(config.user_agent_policy, policies.UserAgentPolicy) + assert isinstance(config.authentication_policy, policies.AsyncBearerTokenCredentialPolicy) + assert isinstance(config.redirect_policy, policies.AsyncRedirectPolicy) + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/conftest.py new file mode 100644 index 000000000000..8849ce8aafbf --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/conftest.py @@ -0,0 +1,127 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------
"""Shared fixtures for tools unit tests."""
import json
from typing import Any, Dict, Optional
from unittest.mock import AsyncMock, MagicMock

import pytest

from azure.ai.agentserver.core.tools.client._models import (
    FoundryConnectedTool,
    FoundryHostedMcpTool,
    FoundryToolDetails,
    ResolvedFoundryTool,
    SchemaDefinition,
    SchemaProperty,
    SchemaType,
    UserInfo,
)


@pytest.fixture
def mock_credential():
    """Create a mock async token credential."""
    credential = AsyncMock()
    credential.get_token = AsyncMock(return_value=MagicMock(token="test-token"))
    return credential


@pytest.fixture
def sample_user_info():
    """Create a sample UserInfo instance."""
    return UserInfo(object_id="test-object-id", tenant_id="test-tenant-id")


@pytest.fixture
def sample_hosted_mcp_tool():
    """Create a sample FoundryHostedMcpTool."""
    return FoundryHostedMcpTool(
        name="test_mcp_tool",
        configuration={"model_deployment_name": "gpt-4"}
    )


@pytest.fixture
def sample_connected_tool():
    """Create a sample FoundryConnectedTool."""
    return FoundryConnectedTool(
        protocol="mcp",
        project_connection_id="test-connection-id"
    )


@pytest.fixture
def sample_schema_definition():
    """Create a sample SchemaDefinition."""
    return SchemaDefinition(
        type=SchemaType.OBJECT,
        properties={
            "input": SchemaProperty(type=SchemaType.STRING, description="Input parameter")
        },
        required={"input"}
    )


@pytest.fixture
def sample_tool_details(sample_schema_definition):
    """Create a sample FoundryToolDetails."""
    return FoundryToolDetails(
        name="test_tool",
        description="A test tool",
        input_schema=sample_schema_definition
    )


@pytest.fixture
def sample_resolved_mcp_tool(sample_hosted_mcp_tool, sample_tool_details):
    """Create a sample ResolvedFoundryTool for MCP."""
    return ResolvedFoundryTool(
        definition=sample_hosted_mcp_tool,
        details=sample_tool_details
    )


@pytest.fixture
def sample_resolved_connected_tool(sample_connected_tool, sample_tool_details):
    """Create a sample ResolvedFoundryTool for connected tools."""
    return ResolvedFoundryTool(
        definition=sample_connected_tool,
        details=sample_tool_details
    )


def create_mock_http_response(
    status_code: int = 200,
    json_data: Optional[Dict[str, Any]] = None
) -> AsyncMock:
    """Create a mock HTTP response that simulates real Azure SDK response behavior.

    This mock matches the behavior expected by BaseOperations._extract_response_json,
    where response.text() and response.body() are synchronous methods that return
    the actual string/bytes values directly.

    :param status_code: HTTP status code.
    :param json_data: JSON data to return; ``None`` yields an empty payload.
    :return: Mock response object.
    """
    response = AsyncMock()
    response.status_code = status_code

    # Serialize once; an absent payload becomes an empty string/bytes body.
    # (Collapses the duplicated branch wiring of the original implementation.)
    json_str = "" if json_data is None else json.dumps(json_data)
    # text() and body() are synchronous methods in AsyncHttpResponse.
    # They must be MagicMock (not AsyncMock) to return values directly when called.
    response.text = MagicMock(return_value=json_str)
    response.body = MagicMock(return_value=json_str.encode("utf-8"))

    # Support async context manager
    response.__aenter__ = AsyncMock(return_value=response)
    response.__aexit__ = AsyncMock(return_value=None)

    return response
+# --------------------------------------------------------- +"""Runtime unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/conftest.py new file mode 100644 index 000000000000..52a371bdc958 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/conftest.py @@ -0,0 +1,39 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for runtime unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. +""" +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_foundry_tool_client(): + """Create a mock FoundryToolClient.""" + client = AsyncMock() + client.list_tools = AsyncMock(return_value=[]) + client.list_tools_details = AsyncMock(return_value={}) + client.invoke_tool = AsyncMock(return_value={"result": "success"}) + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client + + +@pytest.fixture +def mock_user_provider(sample_user_info): + """Create a mock UserProvider.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=sample_user_info) + return provider + + +@pytest.fixture +def mock_user_provider_none(): + """Create a mock UserProvider that returns None.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=None) + return provider + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_catalog.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_catalog.py new file mode 100644 index 000000000000..a7013a7537e1 --- /dev/null +++ 
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Unit tests for _catalog.py - testing public methods of DefaultFoundryToolCatalog."""
import asyncio
import pytest
from unittest.mock import AsyncMock

from azure.ai.agentserver.core.tools import ensure_foundry_tool
from azure.ai.agentserver.core.tools.runtime._catalog import (
    DefaultFoundryToolCatalog,
)
from azure.ai.agentserver.core.tools.client._models import (
    FoundryToolDetails,
    ResolvedFoundryTool,
    UserInfo,
)


class TestFoundryToolCatalogGet:
    """Tests for DefaultFoundryToolCatalog.get method."""

    @pytest.mark.asyncio
    async def test_get_returns_resolved_tool_when_found(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool,
        sample_tool_details
    ):
        """Test get returns a resolved tool when the tool is found."""
        mock_foundry_tool_client.list_tools_details = AsyncMock(
            return_value={sample_hosted_mcp_tool.id: [sample_tool_details]}
        )

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        result = await catalog.get(sample_hosted_mcp_tool)

        assert result is not None
        assert isinstance(result, ResolvedFoundryTool)
        assert result.details == sample_tool_details

    @pytest.mark.asyncio
    async def test_get_returns_none_when_not_found(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool
    ):
        """Test get returns None when the tool is not found."""
        mock_foundry_tool_client.list_tools_details = AsyncMock(return_value={})

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        result = await catalog.get(sample_hosted_mcp_tool)

        assert result is None


class TestDefaultFoundryToolCatalogList:
    """Tests for DefaultFoundryToolCatalog.list method."""

    @pytest.mark.asyncio
    async def test_list_returns_empty_list_when_no_tools(
        self,
        mock_foundry_tool_client,
        mock_user_provider
    ):
        """Test list returns empty list when no tools are provided."""
        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        result = await catalog.list([])

        assert result == []

    @pytest.mark.asyncio
    async def test_list_returns_resolved_tools(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool,
        sample_tool_details
    ):
        """Test list returns resolved tools."""
        mock_foundry_tool_client.list_tools_details = AsyncMock(
            return_value={sample_hosted_mcp_tool.id: [sample_tool_details]}
        )

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        result = await catalog.list([sample_hosted_mcp_tool])

        assert len(result) == 1
        assert isinstance(result[0], ResolvedFoundryTool)
        assert result[0].definition == sample_hosted_mcp_tool
        assert result[0].details == sample_tool_details

    @pytest.mark.asyncio
    async def test_list_multiple_tools_with_multiple_details(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool,
        sample_connected_tool,
        sample_schema_definition
    ):
        """Test list resolves multiple tools, each contributing its own details."""
        details1 = FoundryToolDetails(
            name="tool1",
            description="First tool",
            input_schema=sample_schema_definition
        )
        details2 = FoundryToolDetails(
            name="tool2",
            description="Second tool",
            input_schema=sample_schema_definition
        )

        mock_foundry_tool_client.list_tools_details = AsyncMock(
            return_value={
                sample_hosted_mcp_tool.id: [details1],
                sample_connected_tool.id: [details2]
            }
        )

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        result = await catalog.list([sample_hosted_mcp_tool, sample_connected_tool])

        assert len(result) == 2
        names = {r.details.name for r in result}
        assert names == {"tool1", "tool2"}

    @pytest.mark.asyncio
    async def test_list_caches_results_for_hosted_mcp_tools(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool,
        sample_tool_details
    ):
        """Test that list caches results for hosted MCP tools."""
        mock_foundry_tool_client.list_tools_details = AsyncMock(
            return_value={sample_hosted_mcp_tool.id: [sample_tool_details]}
        )

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        # First call
        result1 = await catalog.list([sample_hosted_mcp_tool])
        # Second call should use cache
        result2 = await catalog.list([sample_hosted_mcp_tool])

        # Client should only be called once
        assert mock_foundry_tool_client.list_tools_details.call_count == 1
        assert len(result1) == len(result2) == 1

    @pytest.mark.asyncio
    async def test_list_with_facade_dict(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_tool_details
    ):
        """Test list works with facade dictionaries."""
        facade = {"type": "custom_tool", "config": "value"}
        expected_id = ensure_foundry_tool(facade).id

        mock_foundry_tool_client.list_tools_details = AsyncMock(
            return_value={expected_id: [sample_tool_details]}
        )

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        result = await catalog.list([facade])

        assert len(result) == 1
        assert result[0].details == sample_tool_details

    @pytest.mark.asyncio
    async def test_list_returns_multiple_details_per_tool(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool,
        sample_schema_definition
    ):
        """Test list returns multiple resolved tools when a tool has multiple details."""
        details1 = FoundryToolDetails(
            name="function1",
            description="First function",
            input_schema=sample_schema_definition
        )
        details2 = FoundryToolDetails(
            name="function2",
            description="Second function",
            input_schema=sample_schema_definition
        )

        mock_foundry_tool_client.list_tools_details = AsyncMock(
            return_value={sample_hosted_mcp_tool.id: [details1, details2]}
        )

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        result = await catalog.list([sample_hosted_mcp_tool])

        assert len(result) == 2
        names = {r.details.name for r in result}
        assert names == {"function1", "function2"}

    @pytest.mark.asyncio
    async def test_list_handles_exception_from_client(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool
    ):
        """Test list propagates an exception raised by the client."""
        mock_foundry_tool_client.list_tools_details = AsyncMock(
            side_effect=RuntimeError("Network error")
        )

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        with pytest.raises(RuntimeError, match="Network error"):
            await catalog.list([sample_hosted_mcp_tool])

    @pytest.mark.asyncio
    async def test_list_connected_tool_cache_key_includes_user(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_connected_tool,
        sample_tool_details
    ):
        """Test that connected tool cache key includes user info."""
        mock_foundry_tool_client.list_tools_details = AsyncMock(
            return_value={sample_connected_tool.id: [sample_tool_details]}
        )

        # Create a new user provider returning a different user
        other_user = UserInfo(object_id="other-oid", tenant_id="other-tid")
        mock_user_provider2 = AsyncMock()
        mock_user_provider2.get_user = AsyncMock(return_value=other_user)

        catalog1 = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        catalog2 = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider2,
            agent_name="test-agent"
        )

        # Both catalogs should be able to list tools
        result1 = await catalog1.list([sample_connected_tool])
        result2 = await catalog2.list([sample_connected_tool])

        assert len(result1) == 1
        assert len(result2) == 1


class TestCachedFoundryToolCatalogConcurrency:
    """Tests for DefaultFoundryToolCatalog's cached-fetch concurrency handling."""

    @pytest.mark.asyncio
    async def test_concurrent_requests_share_single_fetch(
        self,
        mock_foundry_tool_client,
        mock_user_provider,
        sample_hosted_mcp_tool,
        sample_tool_details
    ):
        """Test that concurrent requests for the same tool share a single fetch."""
        call_count = 0
        fetch_event = asyncio.Event()

        async def slow_fetch(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            await fetch_event.wait()
            return {sample_hosted_mcp_tool.id: [sample_tool_details]}

        mock_foundry_tool_client.list_tools_details = slow_fetch

        catalog = DefaultFoundryToolCatalog(
            client=mock_foundry_tool_client,
            user_provider=mock_user_provider,
            agent_name="test-agent"
        )

        # Start two concurrent requests
        task1 = asyncio.create_task(catalog.list([sample_hosted_mcp_tool]))
        task2 = asyncio.create_task(catalog.list([sample_hosted_mcp_tool]))

        # Allow tasks to start
        await asyncio.sleep(0.01)

        # Release the fetch
        fetch_event.set()

        results = await asyncio.gather(task1, task2)

        # Both should get results, but fetch should only be called once
        assert len(results[0]) == 1
        assert len(results[1]) == 1
        assert call_count == 1
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
"""Unit tests for _facade.py - testing public function ensure_foundry_tool."""
import pytest

from azure.ai.agentserver.core.tools.runtime._facade import ensure_foundry_tool
from azure.ai.agentserver.core.tools.client._models import (
    FoundryConnectedTool,
    FoundryHostedMcpTool,
    FoundryToolProtocol,
    FoundryToolSource,
)
from azure.ai.agentserver.core.tools._exceptions import InvalidToolFacadeError


class TestEnsureFoundryTool:
    """Tests for ensure_foundry_tool public function."""

    def test_returns_same_instance_when_given_foundry_tool(self, sample_hosted_mcp_tool):
        """Test that passing a FoundryTool returns the same instance."""
        result = ensure_foundry_tool(sample_hosted_mcp_tool)

        assert result is sample_hosted_mcp_tool

    def test_returns_same_instance_for_connected_tool(self, sample_connected_tool):
        """Test that passing a FoundryConnectedTool returns the same instance."""
        result = ensure_foundry_tool(sample_connected_tool)

        assert result is sample_connected_tool

    def test_converts_facade_with_mcp_protocol_to_connected_tool(self):
        """Test that a facade with 'mcp' protocol is converted to FoundryConnectedTool."""
        facade = {
            "type": "mcp",
            "project_connection_id": "my-connection"
        }

        result = ensure_foundry_tool(facade)

        assert isinstance(result, FoundryConnectedTool)
        assert result.protocol == FoundryToolProtocol.MCP
        assert result.project_connection_id == "my-connection"
        assert result.source == FoundryToolSource.CONNECTED

    def test_converts_facade_with_a2a_protocol_to_connected_tool(self):
        """Test that a facade with 'a2a' protocol is converted to FoundryConnectedTool."""
        facade = {
            "type": "a2a",
            "project_connection_id": "my-a2a-connection"
        }

        result = ensure_foundry_tool(facade)

        assert isinstance(result, FoundryConnectedTool)
        assert result.protocol == FoundryToolProtocol.A2A
        assert result.project_connection_id == "my-a2a-connection"

    def test_converts_facade_with_unknown_type_to_hosted_mcp_tool(self):
        """Test that a facade with unknown type is converted to FoundryHostedMcpTool."""
        facade = {
            "type": "my_custom_tool",
            "some_config": "value123",
            "another_config": True
        }

        result = ensure_foundry_tool(facade)

        assert isinstance(result, FoundryHostedMcpTool)
        assert result.name == "my_custom_tool"
        assert result.configuration == {"some_config": "value123", "another_config": True}
        assert result.source == FoundryToolSource.HOSTED_MCP

    def test_raises_error_when_type_is_missing(self):
        """Test that InvalidToolFacadeError is raised when 'type' is missing."""
        facade = {"project_connection_id": "my-connection"}

        # Error message casing varies, so check case-insensitively rather than
        # with pytest.raises(match=...), which is case-sensitive.
        with pytest.raises(InvalidToolFacadeError) as exc_info:
            ensure_foundry_tool(facade)

        assert "type" in str(exc_info.value).lower()

    def test_raises_error_when_type_is_empty_string(self):
        """Test that InvalidToolFacadeError is raised when 'type' is empty string."""
        facade = {"type": "", "project_connection_id": "my-connection"}

        with pytest.raises(InvalidToolFacadeError) as exc_info:
            ensure_foundry_tool(facade)

        assert "type" in str(exc_info.value).lower()

    def test_raises_error_when_type_is_not_string(self):
        """Test that InvalidToolFacadeError is raised when 'type' is not a string."""
        facade = {"type": 123, "project_connection_id": "my-connection"}

        with pytest.raises(InvalidToolFacadeError) as exc_info:
            ensure_foundry_tool(facade)

        assert "type" in str(exc_info.value).lower()

    def test_raises_error_when_mcp_protocol_missing_connection_id(self):
        """Test that InvalidToolFacadeError is raised when mcp protocol is missing project_connection_id."""
        facade = {"type": "mcp"}

        with pytest.raises(InvalidToolFacadeError, match="project_connection_id"):
            ensure_foundry_tool(facade)

    def test_raises_error_when_a2a_protocol_has_empty_connection_id(self):
        """Test that InvalidToolFacadeError is raised when a2a protocol has empty project_connection_id."""
        facade = {"type": "a2a", "project_connection_id": ""}

        with pytest.raises(InvalidToolFacadeError, match="project_connection_id"):
            ensure_foundry_tool(facade)

    def test_parses_resource_id_format_connection_id(self):
        """Test that resource ID format project_connection_id is parsed correctly."""
        resource_id = (
            "/subscriptions/sub-123/resourceGroups/rg-test/providers/"
            "Microsoft.CognitiveServices/accounts/acc-test/projects/proj-test/connections/my-conn-name"
        )
        facade = {
            "type": "mcp",
            "project_connection_id": resource_id
        }

        result = ensure_foundry_tool(facade)

        assert isinstance(result, FoundryConnectedTool)
        assert result.project_connection_id == "my-conn-name"

    def test_raises_error_for_invalid_resource_id_format(self):
        """Test that InvalidToolFacadeError is raised for invalid resource ID format."""
        invalid_resource_id = "/subscriptions/sub-123/invalid/path"
        facade = {
            "type": "mcp",
            "project_connection_id": invalid_resource_id
        }

        with pytest.raises(InvalidToolFacadeError, match="Invalid resource ID format"):
            ensure_foundry_tool(facade)

    def test_uses_simple_connection_name_as_is(self):
        """Test that simple connection name is used as-is without parsing."""
        facade = {
            "type": "mcp",
            "project_connection_id": "simple-connection-name"
        }

        result = ensure_foundry_tool(facade)

        assert isinstance(result, FoundryConnectedTool)
        assert result.project_connection_id == "simple-connection-name"

    def test_original_facade_not_modified(self):
        """Test that the original facade dictionary is not modified."""
        facade = {
            "type": "my_tool",
            "config_key": "config_value"
        }
        original_facade = facade.copy()

        ensure_foundry_tool(facade)

        assert facade == original_facade

    def test_hosted_mcp_tool_with_no_extra_configuration(self):
        """Test that hosted MCP tool works with no extra configuration."""
        facade = {"type": "simple_tool"}

        result = ensure_foundry_tool(facade)

        assert isinstance(result, FoundryHostedMcpTool)
        assert result.name == "simple_tool"
        assert result.configuration == {}
+# --------------------------------------------------------- +"""Unit tests for _invoker.py - testing public methods of DefaultFoundryToolInvoker.""" +import pytest +from unittest.mock import AsyncMock + +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker + + +class TestDefaultFoundryToolInvokerResolvedTool: + """Tests for DefaultFoundryToolInvoker.resolved_tool property.""" + + def test_resolved_tool_returns_tool_passed_at_init( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns the tool passed during initialization.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_mcp_tool + + def test_resolved_tool_returns_connected_tool( + self, + sample_resolved_connected_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns connected tool.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_connected_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_connected_tool + + +class TestDefaultFoundryToolInvokerInvoke: + """Tests for DefaultFoundryToolInvoker.invoke method.""" + + @pytest.mark.asyncio + async def test_invoke_calls_client_with_correct_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke calls client.invoke_tool with correct arguments.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + arguments = {"input": "test value", "count": 5} + + await invoker.invoke(arguments) + + 
mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + arguments, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_returns_result_from_client( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke returns the result from client.invoke_tool.""" + expected_result = {"output": "test result", "status": "completed"} + mock_foundry_tool_client.invoke_tool = AsyncMock(return_value=expected_result) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await invoker.invoke({"input": "test"}) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_with_empty_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke works with empty arguments dictionary.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke({}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {}, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_with_none_user( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider_none + ): + """Test invoke works when user provider returns None.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider_none, + agent_name="test-agent" + ) + + await invoker.invoke({"input": "test"}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {"input": "test"}, + "test-agent", + None + ) + + 
@pytest.mark.asyncio + async def test_invoke_propagates_client_exception( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke propagates exceptions from client.invoke_tool.""" + mock_foundry_tool_client.invoke_tool = AsyncMock( + side_effect=RuntimeError("Client error") + ) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(RuntimeError, match="Client error"): + await invoker.invoke({"input": "test"}) + + @pytest.mark.asyncio + async def test_invoke_with_complex_nested_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke with complex nested argument structure.""" + complex_args = { + "nested": {"key1": "value1", "key2": 123}, + "list": [1, 2, 3], + "mixed": [{"a": 1}, {"b": 2}] + } + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke(complex_args) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + complex_args, + "test-agent", + sample_user_info + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_resolver.py new file mode 100644 index 000000000000..7bdaa8f957a9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_resolver.py @@ -0,0 +1,202 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _resolver.py - testing public methods of DefaultFoundryToolInvocationResolver.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker +from azure.ai.agentserver.core.tools._exceptions import UnableToResolveToolInvocationError +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, +) + + +class TestDefaultFoundryToolInvocationResolverResolve: + """Tests for DefaultFoundryToolInvocationResolver.resolve method.""" + + @pytest.fixture + def mock_catalog(self, sample_resolved_mcp_tool): + """Create a mock FoundryToolCatalog.""" + catalog = AsyncMock() + catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + catalog.list = AsyncMock(return_value=[sample_resolved_mcp_tool]) + return catalog + + @pytest.mark.asyncio + async def test_resolve_with_resolved_tool_returns_invoker_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve returns invoker directly when given ResolvedFoundryTool.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + assert invoker.resolved_tool is sample_resolved_mcp_tool + # Catalog should not be called when ResolvedFoundryTool is passed + mock_catalog.get.assert_not_called() + + @pytest.mark.asyncio + async def test_resolve_with_foundry_tool_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_resolved_mcp_tool + ): + 
"""Test resolve uses catalog to resolve FoundryTool.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_hosted_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_hosted_mcp_tool) + + @pytest.mark.asyncio + async def test_resolve_with_facade_dict_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_connected_tool + ): + """Test resolve converts facade dict and uses catalog.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_connected_tool) + facade = { + "type": "mcp", + "project_connection_id": "test-connection" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once() + # Verify the facade was converted to FoundryConnectedTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryConnectedTool) + + @pytest.mark.asyncio + async def test_resolve_raises_error_when_tool_not_found_in_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test resolve raises UnableToResolveToolInvocationError when catalog returns None.""" + mock_catalog.get = AsyncMock(return_value=None) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(UnableToResolveToolInvocationError) as exc_info: + await 
resolver.resolve(sample_hosted_mcp_tool) + + assert exc_info.value.tool is sample_hosted_mcp_tool + assert "Unable to resolve tool" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_resolve_with_hosted_mcp_facade( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve with hosted MCP facade (unknown type becomes FoundryHostedMcpTool).""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + facade = { + "type": "custom_mcp_tool", + "config_key": "config_value" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + # Verify the facade was converted to FoundryHostedMcpTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryHostedMcpTool) + assert call_arg.name == "custom_mcp_tool" + + @pytest.mark.asyncio + async def test_resolve_returns_invoker_with_correct_agent_name( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve creates invoker with the correct agent name.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="custom-agent-name" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + # Verify invoker was created with correct agent name by checking internal state + assert invoker._agent_name == "custom-agent-name" + + @pytest.mark.asyncio + async def test_resolve_with_connected_tool_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_connected_tool, + sample_resolved_connected_tool + ): + """Test resolve with FoundryConnectedTool directly.""" + mock_catalog.get = 
AsyncMock(return_value=sample_resolved_connected_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_connected_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_connected_tool) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_runtime.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_runtime.py new file mode 100644 index 000000000000..a99935662941 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_runtime.py @@ -0,0 +1,401 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _runtime.py - testing public methods of DefaultFoundryToolRuntime.""" +import os +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.ai.agentserver.core.tools.runtime._runtime import ( + create_tool_runtime, + DefaultFoundryToolRuntime, + ThrowingFoundryToolRuntime, +) +from azure.ai.agentserver.core.tools.runtime._catalog import DefaultFoundryToolCatalog +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + +class TestDefaultFoundryToolRuntimeInit: + """Tests for DefaultFoundryToolRuntime initialization.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_creates_client_with_endpoint_and_credential( + self, + mock_client_class, + mock_credential + ): + """Test initialization creates client with correct endpoint and credential.""" + endpoint = "https://test-project.azure.com" + 
mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint=endpoint, + credential=mock_credential + ) + + mock_client_class.assert_called_once_with( + endpoint=endpoint, + credential=mock_credential + ) + assert runtime is not None + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_user_provider_when_none_provided( + self, + mock_client_class, + mock_credential + ): + """Test initialization uses ContextVarUserProvider when user_provider is None.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime._user_provider, ContextVarUserProvider) + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_custom_user_provider( + self, + mock_client_class, + mock_credential, + mock_user_provider + ): + """Test initialization uses custom user provider when provided.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential, + user_provider=mock_user_provider + ) + + assert runtime._user_provider is mock_user_provider + + @patch.dict(os.environ, {"AGENT_NAME": "custom-agent"}) + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_reads_agent_name_from_environment( + self, + mock_client_class, + mock_credential + ): + """Test initialization reads agent name from environment variable.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "custom-agent" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_agent_name_when_env_not_set( + self, + 
mock_client_class, + mock_credential + ): + """Test initialization uses default agent name when env var is not set.""" + mock_client_class.return_value = MagicMock() + + # Ensure AGENT_NAME is not set + env_copy = os.environ.copy() + if "AGENT_NAME" in env_copy: + del env_copy["AGENT_NAME"] + + with patch.dict(os.environ, env_copy, clear=True): + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "$default" + + +class TestDefaultFoundryToolRuntimeCatalog: + """Tests for DefaultFoundryToolRuntime.catalog property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_catalog_returns_default_catalog( + self, + mock_client_class, + mock_credential + ): + """Test catalog property returns DefaultFoundryToolCatalog.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.catalog, DefaultFoundryToolCatalog) + + +class TestDefaultFoundryToolRuntimeInvocation: + """Tests for DefaultFoundryToolRuntime.invocation property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_invocation_returns_default_resolver( + self, + mock_client_class, + mock_credential + ): + """Test invocation property returns DefaultFoundryToolInvocationResolver.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.invocation, DefaultFoundryToolInvocationResolver) + + +class TestDefaultFoundryToolRuntimeInvoke: + """Tests for DefaultFoundryToolRuntime.invoke method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_resolves_and_invokes_tool( + self, + 
mock_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke resolves the tool and calls the invoker.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"result": "success"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(sample_resolved_mcp_tool, {"input": "test"}) + + assert result == {"result": "success"} + runtime._invocation.resolve.assert_called_once_with(sample_resolved_mcp_tool) + mock_invoker.invoke.assert_called_once_with({"input": "test"}) + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_with_facade_dict( + self, + mock_client_class, + mock_credential + ): + """Test invoke works with facade dictionary.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + facade = {"type": "custom_tool", "config": "value"} + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"output": "done"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(facade, {"param": "value"}) + + assert result == {"output": "done"} + runtime._invocation.resolve.assert_called_once_with(facade) + + +class TestDefaultFoundryToolRuntimeContextManager: + """Tests for DefaultFoundryToolRuntime async context manager.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aenter_returns_runtime_and_enters_client( + self, + 
mock_client_class, + mock_credential + ): + """Test __aenter__ enters client and returns runtime.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime as r: + assert r is runtime + mock_client_instance.__aenter__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_exits_client( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ exits client properly.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime: + pass + + mock_client_instance.__aexit__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_called_on_exception( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ is called even when exception occurs.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + with pytest.raises(ValueError): + async with runtime: + raise ValueError("Test error") + + 
mock_client_instance.__aexit__.assert_called_once() + + +class TestCreateToolRuntime: + """Tests for create_tool_runtime factory function.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_create_tool_runtime_returns_default_runtime_with_valid_params( + self, + mock_client_class, + mock_credential + ): + """Test create_tool_runtime returns DefaultFoundryToolRuntime when both params are provided.""" + mock_client_class.return_value = MagicMock() + endpoint = "https://test-project.azure.com" + + runtime = create_tool_runtime(project_endpoint=endpoint, credential=mock_credential) + + assert isinstance(runtime, DefaultFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_endpoint_is_none( + self, + mock_credential + ): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when endpoint is None.""" + runtime = create_tool_runtime(project_endpoint=None, credential=mock_credential) + + assert isinstance(runtime, ThrowingFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_credential_is_none(self): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when credential is None.""" + runtime = create_tool_runtime(project_endpoint="https://test.azure.com", credential=None) + + assert isinstance(runtime, ThrowingFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_both_are_none(self): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when both params are None.""" + runtime = create_tool_runtime(project_endpoint=None, credential=None) + + assert isinstance(runtime, ThrowingFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_endpoint_is_empty_string( + self, + mock_credential + ): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when endpoint is empty string.""" + runtime = create_tool_runtime(project_endpoint="", credential=mock_credential) + + assert 
isinstance(runtime, ThrowingFoundryToolRuntime) + + +class TestThrowingFoundryToolRuntime: + """Tests for ThrowingFoundryToolRuntime.""" + + def test_catalog_raises_runtime_error(self): + """Test catalog property raises RuntimeError.""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(RuntimeError) as exc_info: + _ = runtime.catalog + + assert "FoundryToolRuntime is not configured" in str(exc_info.value) + assert "project endpoint and credential" in str(exc_info.value) + + def test_invocation_raises_runtime_error(self): + """Test invocation property raises RuntimeError.""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(RuntimeError) as exc_info: + _ = runtime.invocation + + assert "FoundryToolRuntime is not configured" in str(exc_info.value) + assert "project endpoint and credential" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_invoke_raises_runtime_error(self): + """Test invoke method raises RuntimeError (via invocation property).""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(RuntimeError) as exc_info: + await runtime.invoke({"type": "test"}, {"arg": "value"}) + + assert "FoundryToolRuntime is not configured" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_aenter_returns_self(self): + """Test __aenter__ returns the runtime instance.""" + runtime = ThrowingFoundryToolRuntime() + + async with runtime as r: + assert r is runtime + + @pytest.mark.asyncio + async def test_aexit_completes_successfully(self): + """Test __aexit__ completes without error.""" + runtime = ThrowingFoundryToolRuntime() + + # Should not raise any exception + async with runtime: + pass + + @pytest.mark.asyncio + async def test_context_manager_does_not_suppress_exceptions(self): + """Test context manager does not suppress exceptions.""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(ValueError): + async with runtime: + raise ValueError("Test error") + + def 
test_error_message_is_class_variable(self): + """Test _ERROR_MESSAGE is defined as a class variable.""" + assert hasattr(ThrowingFoundryToolRuntime, "_ERROR_MESSAGE") + assert isinstance(ThrowingFoundryToolRuntime._ERROR_MESSAGE, str) + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_starlette.py new file mode 100644 index 000000000000..d1d72004d011 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_starlette.py @@ -0,0 +1,261 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _starlette.py - testing public methods of UserInfoContextMiddleware.""" +import pytest +from contextvars import ContextVar +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestUserInfoContextMiddlewareInstall: + """Tests for UserInfoContextMiddleware.install class method.""" + + def test_install_adds_middleware_to_starlette_app(self): + """Test install adds middleware to Starlette application.""" + # Import here to avoid requiring starlette when not needed + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + mock_app.add_middleware.assert_called_once() + call_args = mock_app.add_middleware.call_args + assert call_args[0][0] == UserInfoContextMiddleware + + def test_install_uses_default_context_when_none_provided(self): + """Test install uses default user context when none is provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + 
mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is ContextVarUserProvider.default_user_info_context + + def test_install_uses_custom_context(self): + """Test install uses custom user context when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + custom_context = ContextVar("custom_context") + + UserInfoContextMiddleware.install(mock_app, user_context=custom_context) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is custom_context + + def test_install_uses_custom_resolver(self): + """Test install uses custom user resolver when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + async def custom_resolver(request): + return UserInfo(object_id="custom-oid", tenant_id="custom-tid") + + UserInfoContextMiddleware.install(mock_app, user_resolver=custom_resolver) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_resolver"] is custom_resolver + + +class TestUserInfoContextMiddlewareDispatch: + """Tests for UserInfoContextMiddleware.dispatch method.""" + + @pytest.mark.asyncio + async def test_dispatch_sets_user_in_context(self): + """Test dispatch sets user info in context variable.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + user_info = UserInfo(object_id="test-oid", tenant_id="test-tid") + + async def mock_resolver(request): + return user_info + + # Create a simple mock app + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = None + + async def call_next(request): + nonlocal 
captured_user + captured_user = user_context.get(None) + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is user_info + + @pytest.mark.asyncio + async def test_dispatch_resets_context_after_request(self): + """Test dispatch resets context variable after request completes.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + # During request, should have new_user + assert user_context.get(None) is new_user + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + # After request, context should be reset to original value + assert user_context.get(None) is original_user + + @pytest.mark.asyncio + async def test_dispatch_resets_context_on_exception(self): + """Test dispatch resets context even when call_next raises exception.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + raise RuntimeError("Request failed") + + with 
pytest.raises(RuntimeError, match="Request failed"): + await middleware.dispatch(mock_request, call_next) + + # Context should still be reset to original + assert user_context.get(None) is original_user + + @pytest.mark.asyncio + async def test_dispatch_handles_none_user(self): + """Test dispatch handles None user from resolver.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + + async def mock_resolver(request): + return None + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = "not_set" + + async def call_next(request): + nonlocal captured_user + captured_user = user_context.get("default") + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is None + + @pytest.mark.asyncio + async def test_dispatch_calls_resolver_with_request(self): + """Test dispatch calls user resolver with the request object.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + captured_request = None + + async def mock_resolver(request): + nonlocal captured_request + captured_request = request + return UserInfo(object_id="oid", tenant_id="tid") + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + mock_request.url = "https://test.com/api" + + async def call_next(request): + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_request is mock_request + + +class TestUserInfoContextMiddlewareDefaultResolver: + """Tests for UserInfoContextMiddleware default resolver.""" + + @pytest.mark.asyncio + async def 
test_default_resolver_extracts_user_from_headers(self): + """Test default resolver extracts user info from request headers.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + mock_request.headers = { + "x-aml-oid": "header-object-id", + "x-aml-tid": "header-tenant-id" + } + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is not None + assert result.object_id == "header-object-id" + assert result.tenant_id == "header-tenant-id" + + @pytest.mark.asyncio + async def test_default_resolver_returns_none_when_headers_missing(self): + """Test default resolver returns None when required headers are missing.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + mock_request.headers = {} + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is None diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_user.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_user.py new file mode 100644 index 000000000000..a909d9e5948a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_user.py @@ -0,0 +1,210 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _user.py - testing public methods of ContextVarUserProvider and resolve_user_from_headers.""" +import pytest +from contextvars import ContextVar + +from azure.ai.agentserver.core.tools.runtime._user import ( + ContextVarUserProvider, + resolve_user_from_headers, +) +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestContextVarUserProvider: + """Tests for ContextVarUserProvider public methods.""" + + @pytest.mark.asyncio + async def test_get_user_returns_none_when_context_not_set(self): + """Test get_user returns None when context variable is not set.""" + custom_context = ContextVar("test_user_context") + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is None + + @pytest.mark.asyncio + async def test_get_user_returns_user_when_context_is_set(self, sample_user_info): + """Test get_user returns UserInfo when context variable is set.""" + custom_context = ContextVar("test_user_context") + custom_context.set(sample_user_info) + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is sample_user_info + assert result.object_id == "test-object-id" + assert result.tenant_id == "test-tenant-id" + + @pytest.mark.asyncio + async def test_uses_default_context_when_none_provided(self, sample_user_info): + """Test that default context is used when no context is provided.""" + # Set value in default context + ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider = ContextVarUserProvider() + + result = await provider.get_user() + + assert result is sample_user_info + + @pytest.mark.asyncio + async def test_different_providers_share_same_default_context(self, sample_user_info): + """Test that different providers using default context share the same value.""" + 
ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider1 = ContextVarUserProvider() + provider2 = ContextVarUserProvider() + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is result2 is sample_user_info + + @pytest.mark.asyncio + async def test_custom_context_isolation(self, sample_user_info): + """Test that custom contexts are isolated from each other.""" + context1 = ContextVar("context1") + context2 = ContextVar("context2") + user2 = UserInfo(object_id="other-oid", tenant_id="other-tid") + + context1.set(sample_user_info) + context2.set(user2) + + provider1 = ContextVarUserProvider(context=context1) + provider2 = ContextVarUserProvider(context=context2) + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is sample_user_info + assert result2 is user2 + assert result1 is not result2 + + +class TestResolveUserFromHeaders: + """Tests for resolve_user_from_headers public function.""" + + def test_returns_user_info_when_both_headers_present(self): + """Test returns UserInfo when both object_id and tenant_id headers are present.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is not None + assert isinstance(result, UserInfo) + assert result.object_id == "user-object-id" + assert result.tenant_id == "user-tenant-id" + + def test_returns_none_when_object_id_missing(self): + """Test returns None when object_id header is missing.""" + headers = {"x-aml-tid": "user-tenant-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_missing(self): + """Test returns None when tenant_id header is missing.""" + headers = {"x-aml-oid": "user-object-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_both_headers_missing(self): + """Test 
returns None when both headers are missing.""" + headers = {} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_object_id_is_empty(self): + """Test returns None when object_id is empty string.""" + headers = { + "x-aml-oid": "", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_is_empty(self): + """Test returns None when tenant_id is empty string.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_custom_header_names(self): + """Test using custom header names for object_id and tenant_id.""" + headers = { + "custom-oid-header": "custom-object-id", + "custom-tid-header": "custom-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid-header", + tenant_id_header="custom-tid-header" + ) + + assert result is not None + assert result.object_id == "custom-object-id" + assert result.tenant_id == "custom-tenant-id" + + def test_default_headers_not_matched_with_custom_headers(self): + """Test that default headers are not matched when custom headers are specified.""" + headers = { + "x-aml-oid": "default-object-id", + "x-aml-tid": "default-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid", + tenant_id_header="custom-tid" + ) + + assert result is None + + def test_case_sensitive_header_matching(self): + """Test that header matching is case-sensitive.""" + headers = { + "X-AML-OID": "user-object-id", + "X-AML-TID": "user-tenant-id" + } + + # Default headers are lowercase, so these should not match + result = resolve_user_from_headers(headers) + + assert result is None + + def test_with_mapping_like_object(self): + """Test with a mapping-like object that supports .get().""" + class HeadersMapping: + def __init__(self, data): + 
self._data = data + + def get(self, key, default=""): + return self._data.get(key, default) + + headers = HeadersMapping({ + "x-aml-oid": "mapping-object-id", + "x-aml-tid": "mapping-tenant-id" + }) + + result = resolve_user_from_headers(headers) + + assert result is not None + assert result.object_id == "mapping-object-id" + assert result.tenant_id == "mapping-tenant-id" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/__init__.py new file mode 100644 index 000000000000..2d7503de198d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Utils unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/conftest.py new file mode 100644 index 000000000000..abd2f5145c29 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/conftest.py @@ -0,0 +1,56 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for utils unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. 
+""" +from typing import Optional + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaType, +) + + +def create_resolved_tool_with_name( + name: str, + tool_type: str = "mcp", + connection_id: Optional[str] = None +) -> ResolvedFoundryTool: + """Helper to create a ResolvedFoundryTool with a specific name. + + :param name: The name for the tool details. + :param tool_type: Either "mcp" or "connected". + :param connection_id: Connection ID for connected tools. If provided with tool_type="mcp", + will automatically use "connected" type to ensure unique tool IDs. + :return: A ResolvedFoundryTool instance. + """ + schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={}, + required=set() + ) + details = FoundryToolDetails( + name=name, + description=f"Tool named {name}", + input_schema=schema + ) + + # If connection_id is provided, use connected tool to ensure unique IDs + if connection_id is not None or tool_type == "connected": + definition = FoundryConnectedTool( + protocol="mcp", + project_connection_id=connection_id or f"conn-{name}" + ) + else: + definition = FoundryHostedMcpTool( + name=f"mcp-{name}", + configuration={} + ) + + return ResolvedFoundryTool(definition=definition, details=details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/test_name_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/test_name_resolver.py new file mode 100644 index 000000000000..14340799253b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/test_name_resolver.py @@ -0,0 +1,260 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for _name_resolver.py - testing public methods of ToolNameResolver.""" +from azure.ai.agentserver.core.tools.utils import ToolNameResolver +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, +) + +from .conftest import create_resolved_tool_with_name + + +class TestToolNameResolverResolve: + """Tests for ToolNameResolver.resolve method.""" + + def test_resolve_returns_tool_name_for_first_occurrence( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the original tool name for first occurrence.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_mcp_tool) + + assert result == sample_resolved_mcp_tool.details.name + + def test_resolve_returns_same_name_for_same_tool( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the same name when called multiple times for same tool.""" + resolver = ToolNameResolver() + + result1 = resolver.resolve(sample_resolved_mcp_tool) + result2 = resolver.resolve(sample_resolved_mcp_tool) + result3 = resolver.resolve(sample_resolved_mcp_tool) + + assert result1 == result2 == result3 + assert result1 == sample_resolved_mcp_tool.details.name + + def test_resolve_appends_count_for_duplicate_names(self): + """Test resolve appends count for tools with duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + tool3 = create_resolved_tool_with_name("my_tool", connection_id="conn-3") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "my_tool" + assert result2 == "my_tool_1" + assert result3 == "my_tool_2" + + def test_resolve_handles_multiple_unique_names(self): + """Test resolve handles 
multiple tools with unique names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_alpha") + tool2 = create_resolved_tool_with_name("tool_beta") + tool3 = create_resolved_tool_with_name("tool_gamma") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "tool_alpha" + assert result2 == "tool_beta" + assert result3 == "tool_gamma" + + def test_resolve_mixed_unique_and_duplicate_names(self): + """Test resolve handles a mix of unique and duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("shared_name", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("unique_name") + tool3 = create_resolved_tool_with_name("shared_name", connection_id="conn-2") + tool4 = create_resolved_tool_with_name("another_unique") + tool5 = create_resolved_tool_with_name("shared_name", connection_id="conn-3") + + assert resolver.resolve(tool1) == "shared_name" + assert resolver.resolve(tool2) == "unique_name" + assert resolver.resolve(tool3) == "shared_name_1" + assert resolver.resolve(tool4) == "another_unique" + assert resolver.resolve(tool5) == "shared_name_2" + + def test_resolve_returns_cached_name_after_duplicate_added(self): + """Test that resolving a tool again returns cached name even after duplicates are added.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + + # First resolution + first_result = resolver.resolve(tool1) + assert first_result == "my_tool" + + # Add duplicate + dup_result = resolver.resolve(tool2) + assert dup_result == "my_tool_1" + + # Resolve original again - should return cached value + second_result = resolver.resolve(tool1) + assert second_result == "my_tool" + + def test_resolve_with_connected_tool( + self, + sample_resolved_connected_tool + ): + """Test 
resolve works with connected tools.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_connected_tool) + + assert result == sample_resolved_connected_tool.details.name + + def test_resolve_different_tools_same_details_name(self, sample_schema_definition): + """Test resolve handles different tool definitions with same details name.""" + resolver = ToolNameResolver() + + details = FoundryToolDetails( + name="shared_function", + description="A shared function", + input_schema=sample_schema_definition + ) + + mcp_def = FoundryHostedMcpTool(name="mcp_server", configuration={}) + connected_def = FoundryConnectedTool(protocol="mcp", project_connection_id="my-conn") + + tool1 = ResolvedFoundryTool(definition=mcp_def, details=details) + tool2 = ResolvedFoundryTool(definition=connected_def, details=details) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "shared_function" + assert result2 == "shared_function_1" + + def test_resolve_empty_name(self): + """Test resolve handles tools with empty name.""" + resolver = ToolNameResolver() + + tool = create_resolved_tool_with_name("") + + result = resolver.resolve(tool) + + assert result == "" + + def test_resolve_special_characters_in_name(self): + """Test resolve handles tools with special characters in name.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-2") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "my-tool_v1.0" + assert result2 == "my-tool_v1.0_1" + + def test_independent_resolver_instances(self): + """Test that different resolver instances maintain independent state.""" + resolver1 = ToolNameResolver() + resolver2 = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_name", connection_id="conn-1") + tool2 = 
create_resolved_tool_with_name("tool_name", connection_id="conn-2") + + # Both resolvers resolve tool1 first + assert resolver1.resolve(tool1) == "tool_name" + assert resolver2.resolve(tool1) == "tool_name" + + # resolver1 resolves tool2 as duplicate + assert resolver1.resolve(tool2) == "tool_name_1" + + # resolver2 has not seen tool2 yet in its context + # but tool2 has same name, so it should be duplicate + assert resolver2.resolve(tool2) == "tool_name_1" + + def test_resolve_many_duplicates(self): + """Test resolve handles many tools with the same name.""" + resolver = ToolNameResolver() + + tools = [ + create_resolved_tool_with_name("common_name", connection_id=f"conn-{i}") + for i in range(10) + ] + + results = [resolver.resolve(tool) for tool in tools] + + expected = ["common_name"] + [f"common_name_{i}" for i in range(1, 10)] + assert results == expected + + def test_resolve_uses_tool_id_for_caching(self, sample_schema_definition): + """Test that resolve uses tool.id for caching, not just name.""" + resolver = ToolNameResolver() + + # Create two tools with same definition but different details names + definition = FoundryHostedMcpTool(name="same_definition", configuration={}) + + details1 = FoundryToolDetails( + name="function_a", + description="Function A", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="function_b", + description="Function B", + input_schema=sample_schema_definition + ) + + tool1 = ResolvedFoundryTool(definition=definition, details=details1) + tool2 = ResolvedFoundryTool(definition=definition, details=details2) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + # Both should get their respective names since they have different tool.id + assert result1 == "function_a" + assert result2 == "function_b" + + def test_resolve_idempotent_for_same_tool_id(self, sample_schema_definition): + """Test that resolve is idempotent for the same tool id.""" + resolver = ToolNameResolver() + + 
definition = FoundryHostedMcpTool(name="my_mcp", configuration={}) + details = FoundryToolDetails( + name="my_function", + description="My function", + input_schema=sample_schema_definition + ) + tool = ResolvedFoundryTool(definition=definition, details=details) + + # Call resolve many times + results = [resolver.resolve(tool) for _ in range(5)] + + # All should return the same name + assert all(r == "my_function" for r in results) + + def test_resolve_interleaved_tool_resolutions(self): + """Test resolve with interleaved resolutions of different tools.""" + resolver = ToolNameResolver() + + toolA_1 = create_resolved_tool_with_name("A", connection_id="A-1") + toolA_2 = create_resolved_tool_with_name("A", connection_id="A-2") + toolB_1 = create_resolved_tool_with_name("B", connection_id="B-1") + toolA_3 = create_resolved_tool_with_name("A", connection_id="A-3") + toolB_2 = create_resolved_tool_with_name("B", connection_id="B-2") + + assert resolver.resolve(toolA_1) == "A" + assert resolver.resolve(toolB_1) == "B" + assert resolver.resolve(toolA_2) == "A_1" + assert resolver.resolve(toolA_3) == "A_2" + assert resolver.resolve(toolB_2) == "B_1" From 0eb05266109c86714687ddc21c52f2cd652c1248 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 19:37:54 -0700 Subject: [PATCH 02/12] remove azure-ai-agents dependency --- sdk/agentserver/azure-ai-agentserver-core/pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index a0bca5c434fa..3de5d2dfb872 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -21,8 +21,7 @@ keywords = ["azure", "azure sdk"] dependencies = [ "azure-monitor-opentelemetry>=1.5.0,<1.8.5", - "azure-ai-projects>=2.0.0b1", - "azure-ai-agents==1.2.0b5", + "azure-ai-projects>=2.0.0b1,<3.0.0", "azure-core>=1.35.0", 
"azure-identity>=1.25.1", "openai>=1.80.0", From 4272a8e121ef80c3fb94c4e30d0ffb161cd80d4e Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 19:39:37 -0700 Subject: [PATCH 03/12] upgrade version --- sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md | 7 +++++++ .../azure/ai/agentserver/core/_version.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index b312c9d01737..fcbff2d16b74 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -1,6 +1,13 @@ # Release History +## 1.0.0b17 (2026-03-10) + +### Other Changes + +- Refined code + + ## 1.0.0b16 (2026-03-10) ### Other Changes diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py index e4218ac5b98d..62cffc00b2cc 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. 
# -------------------------------------------------------------------------- -VERSION = "1.0.0b16" +VERSION = "1.0.0b17" From a3c9cc0619dd5122d9aa6442533cd9181d448ca8 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 20:01:08 -0700 Subject: [PATCH 04/12] fix sphinx and version --- .../pyproject.toml | 2 +- ...server.core.server.common.id_generator.rst | 19 ------------------- ...zure.ai.agentserver.core.server.common.rst | 19 ------------------- .../doc/azure.ai.agentserver.core.server.rst | 11 ----------- .../pyproject.toml | 2 +- 5 files changed, 2 insertions(+), 51 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml index 814d1d6d1a1e..5b9313b190ce 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ keywords = ["azure", "azure sdk"] dependencies = [ - "azure-ai-agentserver-core", + "azure-ai-agentserver-core==1.0.0b1", "agent-framework-azure-ai==1.0.0b251007", "agent-framework-core==1.0.0b251007", "opentelemetry-exporter-otlp-proto-grpc>=1.36.0", diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst index 68f155131f5c..043b6bfb14fa 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.id_generator.rst @@ -5,22 +5,3 @@ azure.ai.agentserver.core.server.common.id\_generator package :inherited-members: :members: :undoc-members: - -Submodules ----------- - -azure.ai.agentserver.core.server.common.id\_generator.\_foundry\_id\_generator module 
------------------------------------------------------------------------------------- - -.. automodule:: azure.ai.agentserver.core.server.common.id_generator._foundry_id_generator - :inherited-members: - :members: - :undoc-members: - -azure.ai.agentserver.core.server.common.id\_generator.\_id\_generator module ---------------------------------------------------------------------------- - -.. automodule:: azure.ai.agentserver.core.server.common.id_generator._id_generator - :inherited-members: - :members: - :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst index fd02e856642c..b073e580cd10 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst @@ -13,22 +13,3 @@ Subpackages :maxdepth: 4 azure.ai.agentserver.core.server.common.id_generator - -Submodules ----------- - -azure.ai.agentserver.core.server.common.\_agent\_run\_context module -------------------------------------------------------------------- - -.. automodule:: azure.ai.agentserver.core.server.common._agent_run_context - :inherited-members: - :members: - :undoc-members: - -azure.ai.agentserver.core.server.common.\_constants module ----------------------------------------------------------- - -.. 
automodule:: azure.ai.agentserver.core.server.common._constants - :inherited-members: - :members: - :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst index 8363ec9e32d8..f4c838cae8b7 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.rst @@ -13,14 +13,3 @@ Subpackages :maxdepth: 4 azure.ai.agentserver.core.server.common - -Submodules ----------- - -azure.ai.agentserver.core.server.\_base module ----------------------------------------------- - -.. automodule:: azure.ai.agentserver.core.server._base - :inherited-members: - :members: - :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml index 5552ff8233d2..22dc18b55e3c 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ keywords = ["azure", "azure sdk"] dependencies = [ - "azure-ai-agentserver-core", + "azure-ai-agentserver-core==1.0.0b1", "langchain>0.3.5", "langchain-openai>0.3.10", "langchain-azure-ai[opentelemetry]>=0.1.4", From 9ee758f097df66ce557475d7d6d70845d4c23e6e Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 20:48:24 -0700 Subject: [PATCH 05/12] refined unittest --- .../azure-ai-agentserver-core/tests/test_custom.py | 11 +++++++++-- .../server/test_conversation_persistence.py | 3 --- .../tests/unit_tests/test_logger.py | 3 --- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/test_custom.py b/sdk/agentserver/azure-ai-agentserver-core/tests/test_custom.py index f8f2075e22e5..fe78aef9505d 100644 --- 
a/sdk/agentserver/azure-ai-agentserver-core/tests/test_custom.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/test_custom.py @@ -21,6 +21,13 @@ project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) +env_paths = [ + project_root / ".env", + Path.cwd() / ".env", + Path(__file__).parent / ".env", +] +has_env_file = any(env_path.exists() for env_path in env_paths) + class BaseCustomAgentTest: """Base class for Custom agent sample tests with common utilities.""" @@ -190,7 +197,7 @@ def test_streaming_response(self, mock_server): assert lines_read > 0, "Expected to read at least one line from streaming response" -@pytest.mark.skip +@pytest.mark.skipif(not has_env_file, reason="Requires a .env file for MCP sample configuration") class TestMcpSimple: """Test suite for Custom MCP Simple - uses Microsoft Learn MCP.""" @@ -234,7 +241,7 @@ def test_mcp_operations(self, mcp_server, input_text: str, expected_keywords: li assert found_keyword, f"Expected one of {expected_keywords} in response" -@pytest.mark.skip +@pytest.mark.skipif(not has_env_file, reason="Requires a .env file for bilingual weekend planner configuration") class TestBilingualWeekendPlanner: """Test suite for the bilingual weekend planner custom sample.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py index 00137abecf15..8e1726644724 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_conversation_persistence.py @@ -53,7 +53,6 @@ def create_mock_agent(): return agent -@pytest.mark.unit class TestShouldStore: """Tests for _should_store method.""" @@ -110,7 +109,6 @@ def test_should_store_returns_false_when_no_endpoint(self): assert not result # Falsy value when endpoint is None 
-@pytest.mark.unit class TestItemsAreEqual: """Tests for _items_are_equal method.""" @@ -185,7 +183,6 @@ def test_items_not_equal_with_different_structured_content(self): assert agent._items_are_equal(item1, item2) is False -@pytest.mark.unit class TestSaveInputToConversation: """Tests for _save_input_to_conversation method.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py index 35639ea8ae2c..cc89460b96d7 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/test_logger.py @@ -4,7 +4,6 @@ from unittest.mock import MagicMock, patch -@pytest.mark.unit class TestGetProjectEndpoint: """Tests for get_project_endpoint function.""" @@ -76,7 +75,6 @@ def test_logs_warning_for_invalid_resource_id(self): assert result is None or result == "" -@pytest.mark.unit class TestGetApplicationInsightsConnstr: """Tests for _get_application_insights_connstr function.""" @@ -116,7 +114,6 @@ def test_logs_debug_when_not_configured(self): assert mock_logger.debug.called or result is None or result == "" -@pytest.mark.unit class TestLoggerConfiguration: """Tests for logger configuration.""" From bfaa43c9598b74fc7c6ed9f26b7e0f6667b46dd1 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 20:48:46 -0700 Subject: [PATCH 06/12] try fix langgraph dependency --- sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml index 22dc18b55e3c..72470377e7e7 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "azure-ai-agentserver-core==1.0.0b1", "langchain>0.3.5", 
"langchain-openai>0.3.10", - "langchain-azure-ai[opentelemetry]>=0.1.4", + "langchain-azure-ai[opentelemetry]>=0.1.4,<1.1.0", "langgraph>0.5.0", "opentelemetry-exporter-otlp-proto-http", ] From 93a9fe4b107f051683e1e7b7eaafab404495cb98 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 21:52:48 -0700 Subject: [PATCH 07/12] disable langgraph --- .../azure-ai-agentserver-core/pyproject.toml | 14 -------------- .../azure-ai-agentserver-langgraph/pyproject.toml | 3 ++- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index 3de5d2dfb872..123f7e12e442 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -56,20 +56,6 @@ readme = { file = ["README.md"], content-type = "text/markdown" } [tool.setuptools.package-data] pytyped = ["py.typed"] -[tool.ruff] -line-length = 120 -target-version = "py311" -lint.select = ["E", "F", "B", "I"] # E=pycodestyle errors, F=Pyflakes, B=bugbear, I=import sort -lint.ignore = [] -fix = false -exclude = [ - "**/azure/ai/agentserver/core/models/", -] - -[tool.ruff.lint.isort] -known-first-party = ["azure.ai.agentserver.core"] -combine-as-imports = true - [tool.azure-sdk-build] breaking = false # incompatible python version pyright = false diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml index 72470377e7e7..455d55b702fd 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml @@ -64,4 +64,5 @@ pyright = false verifytypes = false # incompatible python version for -core verify_keywords = false mindependency = false # depends on -core package -whl_no_aio = false \ No newline at end of file +whl_no_aio = false +apistub = false From 
84a096be6fa535f3bab58ac3af30e2285efed5e8 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 22:52:32 -0700 Subject: [PATCH 08/12] disable af and lg --- .../azure-ai-agentserver-agentframework/CHANGELOG.md | 8 ++++++++ .../azure-ai-agentserver-agentframework/pyproject.toml | 3 ++- .../azure-ai-agentserver-langgraph/CHANGELOG.md | 8 ++++++++ .../azure-ai-agentserver-langgraph/pyproject.toml | 4 ++-- 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md index cfcf2445e256..ef89ac2ae33d 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History + +## 1.0.0b2 (2026-03-18) + +### Features Added + +DUMMY + + ## 1.0.0b1 (2025-11-07) ### Features Added diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml index 5b9313b190ce..e8457d47f985 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml @@ -63,4 +63,5 @@ pyright = false verifytypes = false # incompatible python version for -core verify_keywords = false mindependency = false # depends on -core package -whl_no_aio = false \ No newline at end of file +whl_no_aio = false +mypy = false \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md index cfcf2445e256..ef89ac2ae33d 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md @@ -1,5 +1,13 @@ # Release History + +## 1.0.0b2 (2026-03-18) + +### Features Added + +DUMMY + + ## 1.0.0b1 (2025-11-07) ### Features Added diff --git 
a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml index 455d55b702fd..6a4a44821df9 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "azure-ai-agentserver-core==1.0.0b1", "langchain>0.3.5", "langchain-openai>0.3.10", - "langchain-azure-ai[opentelemetry]>=0.1.4,<1.1.0", + "langchain-azure-ai[opentelemetry]>=0.1.4,<1.0.6", "langgraph>0.5.0", "opentelemetry-exporter-otlp-proto-http", ] @@ -65,4 +65,4 @@ verifytypes = false # incompatible python version for -core verify_keywords = false mindependency = false # depends on -core package whl_no_aio = false -apistub = false +mypy = false \ No newline at end of file From 8db35404e00b1285355fba5f1d41afa9a5408057 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 22:53:12 -0700 Subject: [PATCH 09/12] fix -core build --- .../azure/ai/agentserver/core/tools/runtime/_starlette.py | 2 +- sdk/agentserver/cspell.json | 5 +++++ shared_requirements.txt | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 sdk/agentserver/cspell.json diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py index 9604124cde9b..4b9024a58d49 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py @@ -40,7 +40,7 @@ def install(cls, :type user_resolver: Optional[Callable[[Request], Awaitable[Optional[UserInfo]]]] """ - user_info_var : _UserContextType = user_context or ContextVarUserProvider.default_user_info_context + user_info_var = user_context or ContextVarUserProvider.default_user_info_context # mypy: 
ignore[assignment] app.add_middleware(UserInfoContextMiddleware, # type: ignore[arg-type] user_info_var=user_info_var, user_resolver=user_resolver or cls._default_user_resolver) diff --git a/sdk/agentserver/cspell.json b/sdk/agentserver/cspell.json new file mode 100644 index 000000000000..9e335d64544d --- /dev/null +++ b/sdk/agentserver/cspell.json @@ -0,0 +1,5 @@ +{ + "ignoreWords": [ + "pylintrc" + ] + } \ No newline at end of file diff --git a/shared_requirements.txt b/shared_requirements.txt index b5e1b85f184a..630973de2eb2 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -96,4 +96,5 @@ opentelemetry-exporter-otlp-proto-grpc agent-framework-core langchain langchain-openai -azure-ai-language-questionanswering \ No newline at end of file +azure-ai-language-questionanswering +cachetools \ No newline at end of file From 0c0f258f6add2043a2aeb174cd9e2b64f754b38a Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 22:54:18 -0700 Subject: [PATCH 10/12] enable all checks --- sdk/agentserver/azure-ai-agentserver-core/pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index 123f7e12e442..ae989950dedc 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -56,7 +56,7 @@ readme = { file = ["README.md"], content-type = "text/markdown" } [tool.setuptools.package-data] pytyped = ["py.typed"] -[tool.azure-sdk-build] -breaking = false # incompatible python version -pyright = false -verifytypes = false +# [tool.azure-sdk-build] +# breaking = false # incompatible python version +# pyright = false +# verifytypes = false From e15c2d21d1e6f327ff4631d51ac949d1adc921c8 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 23:34:59 -0700 Subject: [PATCH 11/12] fix build --- .../azure/ai/agentserver/core/server/_base.py 
| 2 -- .../azure/ai/agentserver/core/tools/client/_models.py | 1 + sdk/agentserver/azure-ai-agentserver-core/pyproject.toml | 2 +- sdk/agentserver/azure-ai-agentserver-core/pyrightconfig.json | 2 ++ 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py index e1ce45188c34..79036dea1ee0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_base.py @@ -284,8 +284,6 @@ async def _create_openai_client(self) -> AsyncOpenAI: :return: Configured AsyncOpenAI client scoped to the Foundry project endpoint. :rtype: AsyncOpenAI """ - from openai import AsyncOpenAI - token_provider = get_bearer_token_provider( self.credentials, "https://ai.azure.com/.default" ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py index b3a505ae37ae..73a938904ea2 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py @@ -343,6 +343,7 @@ def _extract(cls, # Find the first matching key in the datasource value_found = False + value = None for key in keys_to_check: if key in datasource: value = datasource[key] diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index ae989950dedc..e981f148f68d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -10,8 +10,8 @@ license = "MIT" classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", 
- "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyrightconfig.json b/sdk/agentserver/azure-ai-agentserver-core/pyrightconfig.json index b7490ae2b8c7..d8f97b5746c7 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyrightconfig.json +++ b/sdk/agentserver/azure-ai-agentserver-core/pyrightconfig.json @@ -5,6 +5,8 @@ "reportMissingImports": "warning", "reportGeneralTypeIssues": "warning", "reportReturnType": "warning", + "reportPossiblyUnboundVariable":"warning", + "reportCallIssue":"warning", "exclude": [ "**/azure/ai/agentserver/core/models/**", From c79d00f18c513f509847ac6905b07314a53fbeb1 Mon Sep 17 00:00:00 2001 From: Lu Sun Date: Sun, 15 Mar 2026 23:36:38 -0700 Subject: [PATCH 12/12] fix changelog --- .../azure-ai-agentserver-agentframework/CHANGELOG.md | 6 +++--- sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md | 2 +- sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md index ef89ac2ae33d..c8bba88d91b5 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md @@ -1,11 +1,11 @@ # Release History -## 1.0.0b2 (2026-03-18) +## 1.0.0b2 (Unreleased) -### Features Added +### Other Changes -DUMMY +- To be updated ## 1.0.0b1 (2025-11-07) diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index fcbff2d16b74..4a66abdc4c08 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ 
b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -1,7 +1,7 @@ # Release History -## 1.0.0b16 (2026-03-10) +## 1.0.0b18 (Unreleased) ### Other Changes diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md index ef89ac2ae33d..c8bba88d91b5 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md @@ -1,11 +1,11 @@ # Release History -## 1.0.0b2 (2026-03-18) +## 1.0.0b2 (Unreleased) -### Features Added +### Other Changes -DUMMY +- To be updated ## 1.0.0b1 (2025-11-07)