From 0d10970b9fa4ea77c1f812ae72d299212906c7c1 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Sat, 28 Feb 2026 14:49:37 +0000 Subject: [PATCH 1/2] Add run_initializers_async, Entra auth, and config-file support - Add run_initializers_async to pyrit.setup for programmatic initialization - Switch AIRTInitializer to Entra (Azure AD) auth, removing API key requirements - Add --config-file flag to pyrit_backend CLI - Use PyRIT configuration loader in FrontendCore and pyrit_backend - Update AIRTTargetInitializer with new target types Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/cli/frontend_core.py | 80 +++++++++---------- pyrit/cli/pyrit_backend.py | 80 ++++++++----------- pyrit/setup/__init__.py | 10 ++- pyrit/setup/initialization.py | 25 +++++- pyrit/setup/initializers/airt.py | 48 ++++++----- pyrit/setup/initializers/airt_targets.py | 28 +++++-- tests/unit/cli/test_frontend_core.py | 58 +++++++------- tests/unit/cli/test_pyrit_backend.py | 59 ++++++++++++++ tests/unit/setup/test_airt_initializer.py | 13 +-- .../setup/test_airt_targets_initializer.py | 17 ++++ 10 files changed, 258 insertions(+), 160 deletions(-) create mode 100644 tests/unit/cli/test_pyrit_backend.py diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index 5c49525ec5..211b74ce05 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -15,13 +15,18 @@ from __future__ import annotations +import argparse +import inspect import json import logging import sys from pathlib import Path from typing import TYPE_CHECKING, Any, Optional -from pyrit.setup import ConfigurationLoader +from pyrit.registry import InitializerRegistry, ScenarioRegistry +from pyrit.scenario import DatasetConfiguration +from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter +from pyrit.setup import ConfigurationLoader, initialize_pyrit_async, run_initializers_async from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP try: @@ -47,9 +52,7 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i from pyrit.models.scenario_result import ScenarioResult from pyrit.registry import ( InitializerMetadata, - InitializerRegistry, ScenarioMetadata, - ScenarioRegistry, ) logger = logging.getLogger(__name__) @@ -141,14 +144,17 @@ def __init__( logging.basicConfig(level=self._log_level) async def initialize_async(self) -> None: - """Initialize PyRIT and load registries (heavy operation).""" + """ + Initialize PyRIT and load registries (heavy operation). + + Sets up memory and loads scenario/initializer registries. + Initializers are NOT run here — they are run separately + (per-scenario in pyrit_scan, or up-front in pyrit_backend). + """ if self._initialized: return - from pyrit.registry import InitializerRegistry, ScenarioRegistry - from pyrit.setup import initialize_pyrit_async - - # Initialize PyRIT without initializers (they run per-scenario) + # Initialize PyRIT without initializers (they run separately) await initialize_pyrit_async( memory_db_type=self._database, initialization_scripts=None, @@ -167,6 +173,27 @@ async def initialize_async(self) -> None: self._initialized = True + async def run_initializers_async(self) -> None: + """ + Resolve and run all configured initializers and initialization scripts. + + Must be called after :meth:`initialize_async` so that registries are + available to resolve initializer names. This is the same pattern used + by :func:`run_scenario_async` before executing a scenario. + + If no initializers are configured this is a no-op. + """ + initializer_instances = None + if self._initializer_names: + print(f"Running {len(self._initializer_names)} initializer(s)...") + sys.stdout.flush() + initializer_instances = [self.initializer_registry.get_class(name)() for name in self._initializer_names] + + await run_initializers_async( + initializers=initializer_instances, + initialization_scripts=self._initialization_scripts, + ) + @property def scenario_registry(self) -> ScenarioRegistry: """ @@ -227,8 +254,6 @@ async def list_initializers_async( Sequence of initializer metadata dictionaries describing each initializer class. """ if discovery_path: - from pyrit.registry import InitializerRegistry - registry = InitializerRegistry(discovery_path=discovery_path) return registry.list_metadata() @@ -276,34 +301,13 @@ async def run_scenario_async( Note: Initializers from PyRITContext will be run before the scenario executes. """ - from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter - from pyrit.setup import initialize_pyrit_async - # Ensure context is initialized first (loads registries) # This must happen BEFORE we run initializers to avoid double-initialization if not context._initialized: await context.initialize_async() - # Run initializers before scenario - initializer_instances = None - if context._initializer_names: - print(f"Running {len(context._initializer_names)} initializer(s)...") - sys.stdout.flush() - - initializer_instances = [] - - for name in context._initializer_names: - initializer_class = context.initializer_registry.get_class(name) - initializer_instances.append(initializer_class()) - - # Re-initialize PyRIT with the scenario-specific initializers - # This resets memory and applies initializer defaults - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - ) + # Resolve and run initializers + initialization scripts + await context.run_initializers_async() # Get scenario class scenario_class = context.scenario_registry.get_class(scenario_name) @@ -343,8 +347,6 @@ async def run_scenario_async( # - max_dataset_size only: default datasets with overridden limit if dataset_names: # User specified dataset names - create new config (fetches all unless max_dataset_size set) - from pyrit.scenario import DatasetConfiguration - init_kwargs["dataset_config"] = DatasetConfiguration( dataset_names=dataset_names, max_dataset_size=max_dataset_size, @@ -599,8 +601,6 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A Raises: ValueError: If validator_func has no parameters. """ - import inspect - # Get the first parameter name from the function signature sig = inspect.signature(validator_func) params = list(sig.parameters.keys()) @@ -609,13 +609,11 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A first_param = params[0] def wrapper(value: Any) -> Any: - import argparse as ap - try: # Call with keyword argument to support keyword-only parameters return validator_func(**{first_param: value}) except ValueError as e: - raise ap.ArgumentTypeError(str(e)) from e + raise argparse.ArgumentTypeError(str(e)) from e # Preserve function metadata for better debugging wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") @@ -636,8 +634,6 @@ def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: Raises: FileNotFoundError: If a script path does not exist. """ - from pyrit.registry import InitializerRegistry - return InitializerRegistry.resolve_script_paths(script_paths=script_paths) diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index a3e3fe647f..20a01dce3d 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -10,6 +10,7 @@ import asyncio import sys from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from pathlib import Path from typing import Optional # Ensure emoji and other Unicode characters don't crash on Windows consoles @@ -17,6 +18,8 @@ sys.stdout.reconfigure(errors="replace") # type: ignore[union-attr] sys.stderr.reconfigure(errors="replace") # type: ignore[union-attr] +import uvicorn + from pyrit.cli import frontend_core @@ -41,6 +44,9 @@ def parse_args(*, args: Optional[list[str]] = None) -> Namespace: # Start with custom initialization scripts pyrit_backend --initialization-scripts ./my_targets.py + # Start with explicit config file + pyrit_backend --config-file ./my_backend_conf.yaml + # Start with custom port and host pyrit_backend --host 0.0.0.0 --port 8080 @@ -80,13 +86,19 @@ def parse_args(*, args: Optional[list[str]] = None) -> Namespace: parser.add_argument( "--database", type=frontend_core.validate_database_argparse, - default=frontend_core.SQLITE, + default=None, help=( f"Database type to use for memory storage ({frontend_core.IN_MEMORY}, " - f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE})" + f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: from config file, or {frontend_core.SQLITE})" ), ) + parser.add_argument( + "--config-file", + type=str, + help=frontend_core.ARG_HELP["config_file"], + ) + parser.add_argument( "--initializers", type=str, @@ -124,55 +136,27 @@ async def initialize_and_run(*, parsed_args: Namespace) -> int: Returns: int: Exit code (0 for success, 1 for error). """ - from pyrit.setup import initialize_pyrit_async - - # Resolve initialization scripts if provided - initialization_scripts = None - if parsed_args.initialization_scripts: - try: - initialization_scripts = frontend_core.resolve_initialization_scripts( - script_paths=parsed_args.initialization_scripts - ) - except FileNotFoundError as e: - print(f"Error: {e}") - return 1 - - # Resolve env files if provided - env_files = None - if parsed_args.env_files: - try: - env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files) - except ValueError as e: - print(f"Error: {e}") - return 1 - - # Resolve initializer instances if names provided - initializer_instances = None - if parsed_args.initializers: - from pyrit.registry import InitializerRegistry - - registry = InitializerRegistry() - initializer_instances = [] - for name in parsed_args.initializers: - try: - initializer_class = registry.get_class(name) - initializer_instances.append(initializer_class()) - except Exception as e: - print(f"Error: Could not load initializer '{name}': {e}") - return 1 - - # Initialize PyRIT with the provided configuration + try: + core = frontend_core.FrontendCore( + config_file=Path(parsed_args.config_file) if parsed_args.config_file else None, + database=parsed_args.database, + initialization_scripts=( + [Path(p) for p in parsed_args.initialization_scripts] if parsed_args.initialization_scripts else None + ), + initializer_names=parsed_args.initializers, + env_files=([Path(p) for p in parsed_args.env_files] if parsed_args.env_files else None), + log_level=parsed_args.log_level, + ) + except (ValueError, FileNotFoundError) as e: + print(f"Error: {e}") + return 1 + + # Initialize memory, registries, and run initializers. print("🔧 Initializing PyRIT...") - await initialize_pyrit_async( - memory_db_type=parsed_args.database, - initialization_scripts=initialization_scripts, - initializers=initializer_instances, - env_files=env_files, - ) + await core.initialize_async() + await core.run_initializers_async() # Start uvicorn server - import uvicorn - print(f"🚀 Starting PyRIT backend on http://{parsed_args.host}:{parsed_args.port}") print(f" API Docs: http://{parsed_args.host}:{parsed_args.port}/docs") diff --git a/pyrit/setup/__init__.py b/pyrit/setup/__init__.py index 2929a59ea3..2b0823e0f3 100644 --- a/pyrit/setup/__init__.py +++ b/pyrit/setup/__init__.py @@ -4,7 +4,14 @@ """Module containing initialization PyRIT.""" from pyrit.setup.configuration_loader import ConfigurationLoader, initialize_from_config_async -from pyrit.setup.initialization import AZURE_SQL, IN_MEMORY, SQLITE, MemoryDatabaseType, initialize_pyrit_async +from pyrit.setup.initialization import ( + AZURE_SQL, + IN_MEMORY, + SQLITE, + MemoryDatabaseType, + initialize_pyrit_async, + run_initializers_async, +) __all__ = [ "AZURE_SQL", @@ -12,6 +19,7 @@ "IN_MEMORY", "initialize_pyrit_async", "initialize_from_config_async", + "run_initializers_async", "MemoryDatabaseType", "ConfigurationLoader", ] diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 0aff8deafc..64127e9f8f 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -284,14 +284,33 @@ async def initialize_pyrit_async( ) CentralMemory.set_memory_instance(memory) - # Combine directly provided initializers with those loaded from scripts + await run_initializers_async(initializers=initializers, initialization_scripts=initialization_scripts) + + +async def run_initializers_async( + *, + initializers: Optional[Sequence["PyRITInitializer"]] = None, + initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None, +) -> None: + """ + Run initializers and initialization scripts without re-initializing memory or environment. + + This is useful when memory and environment are already set up (e.g. via + :func:`initialize_pyrit_async`) and only the initializer step needs to run. + + Args: + initializers: Optional sequence of PyRITInitializer instances to execute directly. + initialization_scripts: Optional sequence of Python script paths containing + PyRITInitializer classes. + + Raises: + ValueError: If initializers are invalid or scripts cannot be loaded. + """ all_initializers = list(initializers) if initializers else [] - # Load additional initializers from scripts if initialization_scripts: script_initializers = _load_initializers_from_scripts(script_paths=initialization_scripts) all_initializers.extend(script_initializers) - # Execute all initializers (sorted by execution_order) if all_initializers: await _execute_initializers_async(initializers=all_initializers) diff --git a/pyrit/setup/initializers/airt.py b/pyrit/setup/initializers/airt.py index a0d81613df..cd990a87de 100644 --- a/pyrit/setup/initializers/airt.py +++ b/pyrit/setup/initializers/airt.py @@ -9,7 +9,9 @@ """ import os +from typing import Callable +from pyrit.auth import get_azure_openai_auth, get_azure_token_provider from pyrit.common.apply_defaults import set_default_value, set_global_variable from pyrit.executor.attack import ( AttackAdversarialConfig, @@ -44,11 +46,12 @@ class AIRTInitializer(PyRITInitializer): Required Environment Variables: - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT: Azure OpenAI endpoint for converters and targets - - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY: Azure OpenAI API key for converters and targets - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL: Azure OpenAI model name for converters and targets - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2: Azure OpenAI endpoint for scoring - - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2: Azure OpenAI API key for scoring - AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2: Azure OpenAI model name for scoring + - AZURE_CONTENT_SAFETY_API_ENDPOINT: Azure Content Safety endpoint + + Authentication is handled via Entra ID (Azure AD) using DefaultAzureCredential. This configuration is designed for full AI Red Team operations with: - Separate endpoints for attack execution vs scoring (security isolation) @@ -81,13 +84,10 @@ def required_env_vars(self) -> list[str]: """Get list of required environment variables.""" return [ "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", "AZURE_CONTENT_SAFETY_API_ENDPOINT", - "AZURE_CONTENT_SAFETY_API_KEY", ] async def initialize_async(self) -> None: @@ -102,37 +102,43 @@ async def initialize_async(self) -> None: """ # Get environment variables (validated by validate() method) converter_endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT") - converter_api_key = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY") converter_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL") scorer_endpoint = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2") - scorer_api_key = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2") scorer_model_name = os.getenv("AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2") # Type assertions - safe because validate() already checked these assert converter_endpoint is not None - assert converter_api_key is not None assert scorer_endpoint is not None - assert scorer_api_key is not None # model name can be empty in certain cases (e.g., custom model deployments that don't need model name) + # Use Entra authentication via Azure token providers + converter_auth = get_azure_openai_auth(converter_endpoint) + scorer_auth = get_azure_openai_auth(scorer_endpoint) + content_safety_auth = get_azure_token_provider("https://cognitiveservices.azure.com/.default") + # 1. Setup converter target self._setup_converter_target( - endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name + endpoint=converter_endpoint, credential=converter_auth, model_name=converter_model_name ) # 2. Setup scorers - self._setup_scorers(endpoint=scorer_endpoint, api_key=scorer_api_key, model_name=scorer_model_name) + self._setup_scorers( + endpoint=scorer_endpoint, + credential=scorer_auth, + model_name=scorer_model_name, + content_safety_credential=content_safety_auth, + ) # 3. Setup adversarial targets self._setup_adversarial_targets( - endpoint=converter_endpoint, api_key=converter_api_key, model_name=converter_model_name + endpoint=converter_endpoint, credential=converter_auth, model_name=converter_model_name ) - def _setup_converter_target(self, *, endpoint: str, api_key: str, model_name: str) -> None: + def _setup_converter_target(self, *, endpoint: str, credential: Callable, model_name: str) -> None: """Set up the default converter target configuration.""" default_converter_target = OpenAIChatTarget( endpoint=endpoint, - api_key=api_key, + api_key=credential, model_name=model_name, temperature=1.1, ) @@ -144,11 +150,13 @@ def _setup_converter_target(self, *, endpoint: str, api_key: str, model_name: st value=default_converter_target, ) - def _setup_scorers(self, *, endpoint: str, api_key: str, model_name: str) -> None: + def _setup_scorers( + self, *, endpoint: str, credential: Callable, model_name: str, content_safety_credential: Callable + ) -> None: """Set up the composite harm and objective scorers.""" scorer_target = OpenAIChatTarget( endpoint=endpoint, - api_key=api_key, + api_key=credential, model_name=model_name, temperature=0.3, ) @@ -161,7 +169,9 @@ def _setup_scorers(self, *, endpoint: str, api_key: str, model_name: str) -> Non default_harm_scorer = TrueFalseCompositeScorer( aggregator=TrueFalseScoreAggregator.AND, scorers=[ - FloatScaleThresholdScorer(scorer=AzureContentFilterScorer(), threshold=0.5), + FloatScaleThresholdScorer( + scorer=AzureContentFilterScorer(api_key=content_safety_credential), threshold=0.5 + ), TrueFalseInverterScorer( scorer=SelfAskRefusalScorer(chat_target=scorer_target), ), @@ -205,12 +215,12 @@ def _setup_scorers(self, *, endpoint: str, api_key: str, model_name: str) -> Non value=default_objective_scorer_config, ) - def _setup_adversarial_targets(self, *, endpoint: str, api_key: str, model_name: str) -> None: + def _setup_adversarial_targets(self, *, endpoint: str, credential: Callable, model_name: str) -> None: """Set up the adversarial target configurations for attacks.""" adversarial_config = AttackAdversarialConfig( target=OpenAIChatTarget( endpoint=endpoint, - api_key=api_key, + api_key=credential, model_name=model_name, temperature=1.2, ) diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py index be42a8c173..0f6f2f7c6f 100644 --- a/pyrit/setup/initializers/airt_targets.py +++ b/pyrit/setup/initializers/airt_targets.py @@ -14,8 +14,8 @@ import logging import os -from dataclasses import dataclass -from typing import Any, Optional +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Type from pyrit.prompt_target import ( AzureMLChatTarget, @@ -45,6 +45,7 @@ class TargetConfig: key_var: str = "" # Empty string means no auth required model_var: Optional[str] = None underlying_model_var: Optional[str] = None + extra_kwargs: Dict[str, Any] = field(default_factory=dict) # Define all supported target configurations. @@ -168,6 +169,15 @@ class TargetConfig: model_var="AZURE_OPENAI_GPT5_MODEL", underlying_model_var="AZURE_OPENAI_GPT5_UNDERLYING_MODEL", ), + TargetConfig( + registry_name="azure_openai_gpt5_responses_high_reasoning", + target_class=OpenAIResponseTarget, + endpoint_var="AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT", + key_var="AZURE_OPENAI_GPT5_KEY", + model_var="AZURE_OPENAI_GPT5_MODEL", + underlying_model_var="AZURE_OPENAI_GPT5_UNDERLYING_MODEL", + extra_kwargs={"extra_body_parameters": {"reasoning": {"effort": "high"}}}, + ), TargetConfig( registry_name="platform_openai_responses", target_class=OpenAIResponseTarget, @@ -243,12 +253,11 @@ class TargetConfig: # Video Targets (OpenAIVideoTarget) # ============================================ TargetConfig( - registry_name="azure_openai_video", + registry_name="openai_video", target_class=OpenAIVideoTarget, - endpoint_var="AZURE_OPENAI_VIDEO_ENDPOINT", - key_var="AZURE_OPENAI_VIDEO_KEY", - model_var="AZURE_OPENAI_VIDEO_MODEL", - underlying_model_var="AZURE_OPENAI_VIDEO_UNDERLYING_MODEL", + endpoint_var="OPENAI_VIDEO_ENDPOINT", + key_var="OPENAI_VIDEO_KEY", + model_var="OPENAI_VIDEO_MODEL", ), # ============================================ # Completion Targets (OpenAICompletionTarget) @@ -310,6 +319,7 @@ class AIRTTargetInitializer(PyRITInitializer): **OpenAI Responses Targets (OpenAIResponseTarget):** - AZURE_OPENAI_GPT5_RESPONSES_* - Azure OpenAI GPT-5 Responses + - AZURE_OPENAI_GPT5_RESPONSES_* (high reasoning) - Azure OpenAI GPT-5 Responses with high reasoning effort - PLATFORM_OPENAI_RESPONSES_* - Platform OpenAI Responses - AZURE_OPENAI_RESPONSES_* - Azure OpenAI Responses @@ -416,6 +426,10 @@ def _register_target(self, config: TargetConfig) -> None: if underlying_model is not None: kwargs["underlying_model"] = underlying_model + # Add any extra constructor kwargs (e.g. extra_body_parameters for reasoning) + if config.extra_kwargs: + kwargs.update(config.extra_kwargs) + target = config.target_class(**kwargs) registry = TargetRegistry.get_registry_singleton() registry.register_instance(target, name=config.registry_name) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 3bf264a505..f26d87a3e3 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -24,7 +24,7 @@ def test_init_with_defaults(self): assert context._database == frontend_core.SQLITE assert context._initialization_scripts is None - assert context._initializer_names is None + assert context._initializer_names == ["airt", "airt_targets"] assert context._log_level == logging.WARNING assert context._initialized is False @@ -53,9 +53,9 @@ def test_init_with_invalid_database(self): with pytest.raises(ValueError, match="Invalid database type"): frontend_core.FrontendCore(database="InvalidDB") - @patch("pyrit.registry.ScenarioRegistry") - @patch("pyrit.registry.InitializerRegistry") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ScenarioRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) def test_initialize_loads_registries( self, mock_init_pyrit: AsyncMock, @@ -73,9 +73,9 @@ def test_initialize_loads_registries( mock_scenario_registry.get_registry_singleton.assert_called_once() mock_init_registry.assert_called_once() - @patch("pyrit.registry.ScenarioRegistry") - @patch("pyrit.registry.InitializerRegistry") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ScenarioRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) async def test_scenario_registry_property_initializes( self, mock_init_pyrit: AsyncMock, @@ -92,9 +92,9 @@ async def test_scenario_registry_property_initializes( assert context._initialized is True assert registry is not None - @patch("pyrit.registry.ScenarioRegistry") - @patch("pyrit.registry.InitializerRegistry") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ScenarioRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) async def test_initializer_registry_property_initializes( self, mock_init_pyrit: AsyncMock, @@ -249,7 +249,7 @@ def test_parse_memory_labels_non_string_key(self): class TestResolveInitializationScripts: """Tests for resolve_initialization_scripts function.""" - @patch("pyrit.registry.InitializerRegistry.resolve_script_paths") + @patch("pyrit.cli.frontend_core.InitializerRegistry.resolve_script_paths") def test_resolve_initialization_scripts(self, mock_resolve: MagicMock): """Test resolve_initialization_scripts calls InitializerRegistry.""" mock_resolve.return_value = [Path("/test/script.py")] @@ -302,7 +302,7 @@ async def test_list_initializers_without_discovery_path(self): assert result == [{"name": "test_init"}] mock_registry.list_metadata.assert_called_once() - @patch("pyrit.registry.InitializerRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") async def test_list_initializers_with_discovery_path(self, mock_init_registry_class: MagicMock): """Test list_initializers_async with discovery path.""" mock_registry = MagicMock() @@ -620,12 +620,12 @@ def test_parse_run_arguments_missing_value(self): class TestRunScenarioAsync: """Tests for run_scenario_async function.""" - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_basic( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running a basic scenario.""" # Mock context @@ -660,8 +660,8 @@ async def test_run_scenario_async_basic( mock_scenario_instance.run_async.assert_called_once() mock_printer.print_summary_async.assert_called_once_with(mock_result) - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - async def test_run_scenario_async_not_found(self, mock_init_pyrit: AsyncMock): + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + async def test_run_scenario_async_not_found(self, mock_run_init: AsyncMock): """Test running non-existent scenario raises ValueError.""" context = frontend_core.FrontendCore() mock_scenario_registry = MagicMock() @@ -678,12 +678,12 @@ async def test_run_scenario_async_not_found(self, mock_init_pyrit: AsyncMock): context=context, ) - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_with_strategies( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario with strategies.""" context = frontend_core.FrontendCore() @@ -724,12 +724,12 @@ class MockStrategy(Enum): call_kwargs = mock_scenario_instance.initialize_async.call_args[1] assert "scenario_strategies" in call_kwargs - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_with_initializers( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario with initializers.""" context = frontend_core.FrontendCore(initializer_names=["test_init"]) @@ -763,12 +763,12 @@ async def test_run_scenario_async_with_initializers( # Verify initializer was retrieved mock_initializer_registry.get_class.assert_called_once_with("test_init") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_with_max_concurrency( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario with max_concurrency.""" context = frontend_core.FrontendCore() @@ -802,12 +802,12 @@ async def test_run_scenario_async_with_max_concurrency( call_kwargs = mock_scenario_instance.initialize_async.call_args[1] assert call_kwargs["max_concurrency"] == 5 - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_without_print_summary( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario without printing summary.""" context = frontend_core.FrontendCore() diff --git a/tests/unit/cli/test_pyrit_backend.py b/tests/unit/cli/test_pyrit_backend.py new file mode 100644 index 0000000000..e1e54f0c06 --- /dev/null +++ b/tests/unit/cli/test_pyrit_backend.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.cli import pyrit_backend + + +class TestParseArgs: + """Tests for pyrit_backend.parse_args.""" + + def test_parse_args_defaults(self) -> None: + """Should parse backend defaults correctly.""" + args = pyrit_backend.parse_args(args=[]) + + assert args.host == "0.0.0.0" + assert args.port == 8000 + assert args.config_file is None + + def test_parse_args_accepts_config_file(self) -> None: + """Should parse --config-file argument.""" + args = pyrit_backend.parse_args(args=["--config-file", "./custom_conf.yaml"]) + + assert args.config_file == "./custom_conf.yaml" + + +class TestInitializeAndRun: + """Tests for pyrit_backend.initialize_and_run.""" + + @pytest.mark.asyncio + async def test_initialize_and_run_passes_config_file_to_frontend_core(self) -> None: + """Should forward parsed config file path to FrontendCore.""" + parsed_args = pyrit_backend.parse_args(args=["--config-file", "./custom_conf.yaml"]) + + with ( + patch("pyrit.cli.pyrit_backend.frontend_core.FrontendCore") as mock_core_class, + patch("pyrit.cli.pyrit_backend.uvicorn.Config") as mock_uvicorn_config, + patch("pyrit.cli.pyrit_backend.uvicorn.Server") as mock_uvicorn_server, + ): + mock_core = MagicMock() + mock_core.initialize_async = AsyncMock() + mock_core.run_initializers_async = AsyncMock() + mock_core_class.return_value = mock_core + + mock_server = MagicMock() + mock_server.serve = AsyncMock() + mock_uvicorn_server.return_value = mock_server + + result = await pyrit_backend.initialize_and_run(parsed_args=parsed_args) + + assert result == 0 + mock_core_class.assert_called_once() + assert mock_core_class.call_args.kwargs["config_file"] == Path("./custom_conf.yaml") + mock_uvicorn_config.assert_called_once() + mock_uvicorn_server.assert_called_once() + mock_server.serve.assert_awaited_once() diff --git a/tests/unit/setup/test_airt_initializer.py b/tests/unit/setup/test_airt_initializer.py index 434833681c..ebd121fa78 100644 --- a/tests/unit/setup/test_airt_initializer.py +++ b/tests/unit/setup/test_airt_initializer.py @@ -36,13 +36,10 @@ def setup_method(self) -> None: reset_default_values() # Set up required env vars for AIRT os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT"] = "https://test-converter.openai.azure.com" - os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"] = "test_converter_key" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"] = "gpt-4" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2"] = "https://test-scorer.openai.azure.com" - os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2"] = "test_scorer_key" os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2"] = "gpt-4" os.environ["AZURE_CONTENT_SAFETY_API_ENDPOINT"] = "https://test-safety.cognitiveservices.azure.com" - os.environ["AZURE_CONTENT_SAFETY_API_KEY"] = "test_safety_key" # Clean up globals for attr in [ "default_converter_target", @@ -59,13 +56,10 @@ def teardown_method(self) -> None: # Clean up env vars for var in [ "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2", - "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2", "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL2", "AZURE_CONTENT_SAFETY_API_ENDPOINT", - "AZURE_CONTENT_SAFETY_API_KEY", ]: if var in os.environ: del os.environ[var] @@ -137,7 +131,7 @@ def test_validate_missing_multiple_env_vars_raises_error(self): """Test that validate raises error listing all missing env vars.""" # Remove multiple required env vars del os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT"] - del os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY"] + del os.environ["AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL"] init = AIRTInitializer() with pytest.raises(ValueError) as exc_info: @@ -145,7 +139,7 @@ def test_validate_missing_multiple_env_vars_raises_error(self): error_message = str(exc_info.value) assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT" in error_message - assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY" in error_message + assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_MODEL" in error_message class TestAIRTInitializerGetInfo: @@ -160,11 +154,8 @@ async def test_get_info_returns_expected_structure(self): assert info["class"] == "AIRTInitializer" assert "required_env_vars" in info assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT" in info["required_env_vars"] - assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY" in info["required_env_vars"] assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_ENDPOINT2" in info["required_env_vars"] - assert "AZURE_OPENAI_GPT4O_UNSAFE_CHAT_KEY2" in info["required_env_vars"] assert "AZURE_CONTENT_SAFETY_API_ENDPOINT" in info["required_env_vars"] - assert "AZURE_CONTENT_SAFETY_API_KEY" in info["required_env_vars"] async def test_get_info_includes_description(self): """Test that get_info_async includes the description field.""" diff --git a/tests/unit/setup/test_airt_targets_initializer.py b/tests/unit/setup/test_airt_targets_initializer.py index 356a6388d5..39571cfed8 100644 --- a/tests/unit/setup/test_airt_targets_initializer.py +++ b/tests/unit/setup/test_airt_targets_initializer.py @@ -165,6 +165,23 @@ async def test_registers_ollama_without_api_key(self): assert target is not None assert target._model_name == "llama2" + @pytest.mark.asyncio + async def test_registers_gpt5_high_reasoning_with_extra_body_parameters(self): + """Test that GPT-5 high-reasoning target has extra_body_parameters set.""" + os.environ["AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT"] = "https://gpt5.openai.azure.com" + os.environ["AZURE_OPENAI_GPT5_KEY"] = "test_key" + os.environ["AZURE_OPENAI_GPT5_MODEL"] = "gpt-5" + os.environ["AZURE_OPENAI_GPT5_UNDERLYING_MODEL"] = "gpt-5" + + init = AIRTTargetInitializer() + await init.initialize_async() + + registry = TargetRegistry.get_registry_singleton() + assert "azure_openai_gpt5_responses_high_reasoning" in registry + target = registry.get_instance_by_name("azure_openai_gpt5_responses_high_reasoning") + assert target is not None + assert target._extra_body_parameters == {"reasoning": {"effort": "high"}} + @pytest.mark.usefixtures("patch_central_database") class TestAIRTTargetInitializerTargetConfigs: From e453d4d5d22c4706aff014d352a98da5a0871aca Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 2 Mar 2026 21:30:16 -0800 Subject: [PATCH 2/2] fix: address copilot comments - video target config, docstring, tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/setup/initialization.py | 2 ++ pyrit/setup/initializers/airt_targets.py | 8 ++++---- tests/unit/cli/test_frontend_core.py | 6 ++++-- tests/unit/setup/test_initialization.py | 26 ++++++++++++++++++++++++ 4 files changed, 36 insertions(+), 6 deletions(-) diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 64127e9f8f..2f29c38c1a 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -305,6 +305,8 @@ async def run_initializers_async( Raises: ValueError: If initializers are invalid or scripts cannot be loaded. + FileNotFoundError: If an initialization script path does not exist. + Exception: If an initialization script fails to import or execute. """ all_initializers = list(initializers) if initializers else [] diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py index bc04ec05d0..f3a9b0daa2 100644 --- a/pyrit/setup/initializers/airt_targets.py +++ b/pyrit/setup/initializers/airt_targets.py @@ -254,11 +254,11 @@ class TargetConfig: # Video Targets (OpenAIVideoTarget) # ============================================ TargetConfig( - registry_name="openai_video", + registry_name="azure_openai_video", target_class=OpenAIVideoTarget, - endpoint_var="OPENAI_VIDEO_ENDPOINT", - key_var="OPENAI_VIDEO_KEY", - model_var="OPENAI_VIDEO_MODEL", + endpoint_var="AZURE_OPENAI_VIDEO_ENDPOINT", + key_var="AZURE_OPENAI_VIDEO_KEY", + model_var="AZURE_OPENAI_VIDEO_MODEL", ), # ============================================ # Completion Targets (OpenAICompletionTarget) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index f26d87a3e3..4f9e3b87e8 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -18,13 +18,15 @@ class TestFrontendCore: """Tests for FrontendCore class.""" - def test_init_with_defaults(self): + @patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + def test_init_with_defaults(self, mock_default_config_path): """Test initialization with default parameters.""" + mock_default_config_path.exists.return_value = False context = frontend_core.FrontendCore() assert context._database == frontend_core.SQLITE assert context._initialization_scripts is None - assert context._initializer_names == ["airt", "airt_targets"] + assert context._initializer_names is None assert context._log_level == logging.WARNING assert context._initialized is False diff --git a/tests/unit/setup/test_initialization.py b/tests/unit/setup/test_initialization.py index 9c386ba68b..8e36683f6b 100644 --- a/tests/unit/setup/test_initialization.py +++ b/tests/unit/setup/test_initialization.py @@ -13,6 +13,7 @@ from pyrit.setup.initialization import ( _load_environment_files, _load_initializers_from_scripts, + run_initializers_async, ) @@ -54,6 +55,31 @@ def test_script_not_found_raises_error(self): _load_initializers_from_scripts(script_paths=["nonexistent_script.py"]) +class TestRunInitializersAsync: + """Tests for run_initializers_async function.""" + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization._execute_initializers_async") + async def test_calls_execute_when_initializers_provided(self, mock_execute): + """Test that it calls _execute_initializers_async when initializers are provided.""" + mock_initializer = mock.MagicMock() + await run_initializers_async(initializers=[mock_initializer]) + mock_execute.assert_called_once() + assert mock_initializer in mock_execute.call_args.kwargs["initializers"] + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization._execute_initializers_async") + @mock.patch("pyrit.setup.initialization._load_initializers_from_scripts") + async def test_loads_scripts_when_scripts_provided(self, mock_load_scripts, mock_execute): + """Test that it loads scripts via _load_initializers_from_scripts when scripts are provided.""" + mock_script_init = mock.MagicMock() + mock_load_scripts.return_value = [mock_script_init] + await run_initializers_async(initialization_scripts=["test_script.py"]) + mock_load_scripts.assert_called_once_with(script_paths=["test_script.py"]) + mock_execute.assert_called_once() + assert mock_script_init in mock_execute.call_args.kwargs["initializers"] + + class TestInitializePyrit: """Tests for initialize_pyrit_async function - basic orchestration tests."""