diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index b131b3cf4f..b3ad2a770e 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -15,13 +15,18 @@ from __future__ import annotations +import argparse +import inspect import json import logging import sys from pathlib import Path from typing import TYPE_CHECKING, Any, Optional -from pyrit.setup import ConfigurationLoader +from pyrit.registry import InitializerRegistry, ScenarioRegistry +from pyrit.scenario import DatasetConfiguration +from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter +from pyrit.setup import ConfigurationLoader, initialize_pyrit_async, run_initializers_async from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP try: @@ -47,9 +52,7 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i from pyrit.models.scenario_result import ScenarioResult from pyrit.registry import ( InitializerMetadata, - InitializerRegistry, ScenarioMetadata, - ScenarioRegistry, ) logger = logging.getLogger(__name__) @@ -141,14 +144,17 @@ def __init__( logging.basicConfig(level=self._log_level) async def initialize_async(self) -> None: - """Initialize PyRIT and load registries (heavy operation).""" + """ + Initialize PyRIT and load registries (heavy operation). + + Sets up memory and loads scenario/initializer registries. + Initializers are NOT run here — they are run separately + (per-scenario in pyrit_scan, or up-front in pyrit_backend). + """ if self._initialized: return - from pyrit.registry import InitializerRegistry, ScenarioRegistry - from pyrit.setup import initialize_pyrit_async - - # Initialize PyRIT without initializers (they run per-scenario) + # Initialize PyRIT without initializers (they run separately) await initialize_pyrit_async( memory_db_type=self._database, initialization_scripts=None, @@ -167,6 +173,27 @@ async def initialize_async(self) -> None: self._initialized = True + async def run_initializers_async(self) -> None: + """ + Resolve and run all configured initializers and initialization scripts. + + Must be called after :meth:`initialize_async` so that registries are + available to resolve initializer names. This is the same pattern used + by :func:`run_scenario_async` before executing a scenario. + + If no initializers are configured this is a no-op. + """ + initializer_instances = None + if self._initializer_names: + print(f"Running {len(self._initializer_names)} initializer(s)...") + sys.stdout.flush() + initializer_instances = [self.initializer_registry.get_class(name)() for name in self._initializer_names] + + await run_initializers_async( + initializers=initializer_instances, + initialization_scripts=self._initialization_scripts, + ) + @property def scenario_registry(self) -> ScenarioRegistry: """ @@ -227,8 +254,6 @@ async def list_initializers_async( Sequence of initializer metadata dictionaries describing each initializer class. """ if discovery_path: - from pyrit.registry import InitializerRegistry - registry = InitializerRegistry(discovery_path=discovery_path) return registry.list_metadata() @@ -276,34 +301,13 @@ async def run_scenario_async( Note: Initializers from PyRITContext will be run before the scenario executes. """ - from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter - from pyrit.setup import initialize_pyrit_async - # Ensure context is initialized first (loads registries) # This must happen BEFORE we run initializers to avoid double-initialization if not context._initialized: await context.initialize_async() - # Run initializers before scenario - initializer_instances = None - if context._initializer_names: - print(f"Running {len(context._initializer_names)} initializer(s)...") - sys.stdout.flush() - - initializer_instances = [] - - for name in context._initializer_names: - initializer_class = context.initializer_registry.get_class(name) - initializer_instances.append(initializer_class()) - - # Re-initialize PyRIT with the scenario-specific initializers - # This resets memory and applies initializer defaults - await initialize_pyrit_async( - memory_db_type=context._database, - initialization_scripts=context._initialization_scripts, - initializers=initializer_instances, - env_files=context._env_files, - ) + # Resolve and run initializers + initialization scripts + await context.run_initializers_async() # Get scenario class scenario_class = context.scenario_registry.get_class(scenario_name) @@ -343,8 +347,6 @@ async def run_scenario_async( # - max_dataset_size only: default datasets with overridden limit if dataset_names: # User specified dataset names - create new config (fetches all unless max_dataset_size set) - from pyrit.scenario import DatasetConfiguration - init_kwargs["dataset_config"] = DatasetConfiguration( dataset_names=dataset_names, max_dataset_size=max_dataset_size, @@ -599,8 +601,6 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A Raises: ValueError: If validator_func has no parameters. """ - import inspect - # Get the first parameter name from the function signature sig = inspect.signature(validator_func) params = list(sig.parameters.keys()) @@ -609,13 +609,11 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A first_param = params[0] def wrapper(value: Any) -> Any: - import argparse as ap - try: # Call with keyword argument to support keyword-only parameters return validator_func(**{first_param: value}) except ValueError as e: - raise ap.ArgumentTypeError(str(e)) from e + raise argparse.ArgumentTypeError(str(e)) from e # Preserve function metadata for better debugging wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator") @@ -636,8 +634,6 @@ def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: Raises: FileNotFoundError: If a script path does not exist. """ - from pyrit.registry import InitializerRegistry - return InitializerRegistry.resolve_script_paths(script_paths=script_paths) diff --git a/pyrit/setup/__init__.py b/pyrit/setup/__init__.py index 2929a59ea3..2b0823e0f3 100644 --- a/pyrit/setup/__init__.py +++ b/pyrit/setup/__init__.py @@ -4,7 +4,14 @@ """Module containing initialization PyRIT.""" from pyrit.setup.configuration_loader import ConfigurationLoader, initialize_from_config_async -from pyrit.setup.initialization import AZURE_SQL, IN_MEMORY, SQLITE, MemoryDatabaseType, initialize_pyrit_async +from pyrit.setup.initialization import ( + AZURE_SQL, + IN_MEMORY, + SQLITE, + MemoryDatabaseType, + initialize_pyrit_async, + run_initializers_async, +) __all__ = [ "AZURE_SQL", @@ -12,6 +19,7 @@ "IN_MEMORY", "initialize_pyrit_async", "initialize_from_config_async", + "run_initializers_async", "MemoryDatabaseType", "ConfigurationLoader", ] diff --git a/pyrit/setup/initialization.py b/pyrit/setup/initialization.py index 0aff8deafc..2f29c38c1a 100644 --- a/pyrit/setup/initialization.py +++ b/pyrit/setup/initialization.py @@ -284,14 +284,35 @@ async def initialize_pyrit_async( ) CentralMemory.set_memory_instance(memory) - # Combine directly provided initializers with those loaded from scripts + await run_initializers_async(initializers=initializers, initialization_scripts=initialization_scripts) + + +async def run_initializers_async( + *, + initializers: Optional[Sequence["PyRITInitializer"]] = None, + initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None, +) -> None: + """ + Run initializers and initialization scripts without re-initializing memory or environment. + + This is useful when memory and environment are already set up (e.g. via + :func:`initialize_pyrit_async`) and only the initializer step needs to run. + + Args: + initializers: Optional sequence of PyRITInitializer instances to execute directly. + initialization_scripts: Optional sequence of Python script paths containing + PyRITInitializer classes. + + Raises: + ValueError: If initializers are invalid or scripts cannot be loaded. + FileNotFoundError: If an initialization script path does not exist. + Exception: If an initialization script fails to import or execute. + """ all_initializers = list(initializers) if initializers else [] - # Load additional initializers from scripts if initialization_scripts: script_initializers = _load_initializers_from_scripts(script_paths=initialization_scripts) all_initializers.extend(script_initializers) - # Execute all initializers (sorted by execution_order) if all_initializers: await _execute_initializers_async(initializers=all_initializers) diff --git a/pyrit/setup/initializers/airt_targets.py b/pyrit/setup/initializers/airt_targets.py index 5f335eb052..f3a9b0daa2 100644 --- a/pyrit/setup/initializers/airt_targets.py +++ b/pyrit/setup/initializers/airt_targets.py @@ -14,7 +14,7 @@ import logging import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Optional from pyrit.auth import get_azure_openai_auth, get_azure_token_provider @@ -46,6 +46,7 @@ class TargetConfig: key_var: str = "" # Empty string means no auth required model_var: Optional[str] = None underlying_model_var: Optional[str] = None + extra_kwargs: dict[str, Any] = field(default_factory=dict) # Define all supported target configurations. @@ -169,6 +170,15 @@ class TargetConfig: model_var="AZURE_OPENAI_GPT5_MODEL", underlying_model_var="AZURE_OPENAI_GPT5_UNDERLYING_MODEL", ), + TargetConfig( + registry_name="azure_openai_gpt5_responses_high_reasoning", + target_class=OpenAIResponseTarget, + endpoint_var="AZURE_OPENAI_GPT5_RESPONSES_ENDPOINT", + key_var="AZURE_OPENAI_GPT5_KEY", + model_var="AZURE_OPENAI_GPT5_MODEL", + underlying_model_var="AZURE_OPENAI_GPT5_UNDERLYING_MODEL", + extra_kwargs={"extra_body_parameters": {"reasoning": {"effort": "high"}}}, + ), TargetConfig( registry_name="platform_openai_responses", target_class=OpenAIResponseTarget, @@ -249,7 +259,6 @@ class TargetConfig: endpoint_var="AZURE_OPENAI_VIDEO_ENDPOINT", key_var="AZURE_OPENAI_VIDEO_KEY", model_var="AZURE_OPENAI_VIDEO_MODEL", - underlying_model_var="AZURE_OPENAI_VIDEO_UNDERLYING_MODEL", ), # ============================================ # Completion Targets (OpenAICompletionTarget) @@ -311,6 +320,7 @@ class AIRTTargetInitializer(PyRITInitializer): **OpenAI Responses Targets (OpenAIResponseTarget):** - AZURE_OPENAI_GPT5_RESPONSES_* - Azure OpenAI GPT-5 Responses + - AZURE_OPENAI_GPT5_RESPONSES_* (high reasoning) - Azure OpenAI GPT-5 Responses with high reasoning effort - PLATFORM_OPENAI_RESPONSES_* - Platform OpenAI Responses - AZURE_OPENAI_RESPONSES_* - Azure OpenAI Responses @@ -426,6 +436,10 @@ def _register_target(self, config: TargetConfig) -> None: if underlying_model is not None: kwargs["underlying_model"] = underlying_model + # Add any extra constructor kwargs (e.g. extra_body_parameters for reasoning) + if config.extra_kwargs: + kwargs.update(config.extra_kwargs) + target = config.target_class(**kwargs) registry = TargetRegistry.get_registry_singleton() registry.register_instance(target, name=config.registry_name) diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 3bf264a505..4f9e3b87e8 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -18,8 +18,10 @@ class TestFrontendCore: """Tests for FrontendCore class.""" - def test_init_with_defaults(self): + @patch("pyrit.setup.configuration_loader.DEFAULT_CONFIG_PATH") + def test_init_with_defaults(self, mock_default_config_path): """Test initialization with default parameters.""" + mock_default_config_path.exists.return_value = False context = frontend_core.FrontendCore() assert context._database == frontend_core.SQLITE @@ -53,9 +55,9 @@ def test_init_with_invalid_database(self): with pytest.raises(ValueError, match="Invalid database type"): frontend_core.FrontendCore(database="InvalidDB") - @patch("pyrit.registry.ScenarioRegistry") - @patch("pyrit.registry.InitializerRegistry") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ScenarioRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) def test_initialize_loads_registries( self, mock_init_pyrit: AsyncMock, @@ -73,9 +75,9 @@ def test_initialize_loads_registries( mock_scenario_registry.get_registry_singleton.assert_called_once() mock_init_registry.assert_called_once() - @patch("pyrit.registry.ScenarioRegistry") - @patch("pyrit.registry.InitializerRegistry") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ScenarioRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) async def test_scenario_registry_property_initializes( self, mock_init_pyrit: AsyncMock, @@ -92,9 +94,9 @@ async def test_scenario_registry_property_initializes( assert context._initialized is True assert registry is not None - @patch("pyrit.registry.ScenarioRegistry") - @patch("pyrit.registry.InitializerRegistry") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ScenarioRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") + @patch("pyrit.cli.frontend_core.initialize_pyrit_async", new_callable=AsyncMock) async def test_initializer_registry_property_initializes( self, mock_init_pyrit: AsyncMock, @@ -249,7 +251,7 @@ def test_parse_memory_labels_non_string_key(self): class TestResolveInitializationScripts: """Tests for resolve_initialization_scripts function.""" - @patch("pyrit.registry.InitializerRegistry.resolve_script_paths") + @patch("pyrit.cli.frontend_core.InitializerRegistry.resolve_script_paths") def test_resolve_initialization_scripts(self, mock_resolve: MagicMock): """Test resolve_initialization_scripts calls InitializerRegistry.""" mock_resolve.return_value = [Path("/test/script.py")] @@ -302,7 +304,7 @@ async def test_list_initializers_without_discovery_path(self): assert result == [{"name": "test_init"}] mock_registry.list_metadata.assert_called_once() - @patch("pyrit.registry.InitializerRegistry") + @patch("pyrit.cli.frontend_core.InitializerRegistry") async def test_list_initializers_with_discovery_path(self, mock_init_registry_class: MagicMock): """Test list_initializers_async with discovery path.""" mock_registry = MagicMock() @@ -620,12 +622,12 @@ def test_parse_run_arguments_missing_value(self): class TestRunScenarioAsync: """Tests for run_scenario_async function.""" - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_basic( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running a basic scenario.""" # Mock context @@ -660,8 +662,8 @@ async def test_run_scenario_async_basic( mock_scenario_instance.run_async.assert_called_once() mock_printer.print_summary_async.assert_called_once_with(mock_result) - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - async def test_run_scenario_async_not_found(self, mock_init_pyrit: AsyncMock): + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + async def test_run_scenario_async_not_found(self, mock_run_init: AsyncMock): """Test running non-existent scenario raises ValueError.""" context = frontend_core.FrontendCore() mock_scenario_registry = MagicMock() @@ -678,12 +680,12 @@ async def test_run_scenario_async_not_found(self, mock_init_pyrit: AsyncMock): context=context, ) - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_with_strategies( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario with strategies.""" context = frontend_core.FrontendCore() @@ -724,12 +726,12 @@ class MockStrategy(Enum): call_kwargs = mock_scenario_instance.initialize_async.call_args[1] assert "scenario_strategies" in call_kwargs - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_with_initializers( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario with initializers.""" context = frontend_core.FrontendCore(initializer_names=["test_init"]) @@ -763,12 +765,12 @@ async def test_run_scenario_async_with_initializers( # Verify initializer was retrieved mock_initializer_registry.get_class.assert_called_once_with("test_init") - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_with_max_concurrency( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario with max_concurrency.""" context = frontend_core.FrontendCore() @@ -802,12 +804,12 @@ async def test_run_scenario_async_with_max_concurrency( call_kwargs = mock_scenario_instance.initialize_async.call_args[1] assert call_kwargs["max_concurrency"] == 5 - @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) - @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") + @patch("pyrit.cli.frontend_core.run_initializers_async", new_callable=AsyncMock) + @patch("pyrit.cli.frontend_core.ConsoleScenarioResultPrinter") async def test_run_scenario_async_without_print_summary( self, mock_printer_class: MagicMock, - mock_init_pyrit: AsyncMock, + mock_run_init: AsyncMock, ): """Test running scenario without printing summary.""" context = frontend_core.FrontendCore() diff --git a/tests/unit/cli/test_pyrit_backend.py b/tests/unit/cli/test_pyrit_backend.py new file mode 100644 index 0000000000..e1e54f0c06 --- /dev/null +++ b/tests/unit/cli/test_pyrit_backend.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.cli import pyrit_backend + + +class TestParseArgs: + """Tests for pyrit_backend.parse_args.""" + + def test_parse_args_defaults(self) -> None: + """Should parse backend defaults correctly.""" + args = pyrit_backend.parse_args(args=[]) + + assert args.host == "0.0.0.0" + assert args.port == 8000 + assert args.config_file is None + + def test_parse_args_accepts_config_file(self) -> None: + """Should parse --config-file argument.""" + args = pyrit_backend.parse_args(args=["--config-file", "./custom_conf.yaml"]) + + assert args.config_file == "./custom_conf.yaml" + + +class TestInitializeAndRun: + """Tests for pyrit_backend.initialize_and_run.""" + + @pytest.mark.asyncio + async def test_initialize_and_run_passes_config_file_to_frontend_core(self) -> None: + """Should forward parsed config file path to FrontendCore.""" + parsed_args = pyrit_backend.parse_args(args=["--config-file", "./custom_conf.yaml"]) + + with ( + patch("pyrit.cli.pyrit_backend.frontend_core.FrontendCore") as mock_core_class, + patch("pyrit.cli.pyrit_backend.uvicorn.Config") as mock_uvicorn_config, + patch("pyrit.cli.pyrit_backend.uvicorn.Server") as mock_uvicorn_server, + ): + mock_core = MagicMock() + mock_core.initialize_async = AsyncMock() + mock_core.run_initializers_async = AsyncMock() + mock_core_class.return_value = mock_core + + mock_server = MagicMock() + mock_server.serve = AsyncMock() + mock_uvicorn_server.return_value = mock_server + + result = await pyrit_backend.initialize_and_run(parsed_args=parsed_args) + + assert result == 0 + mock_core_class.assert_called_once() + assert mock_core_class.call_args.kwargs["config_file"] == Path("./custom_conf.yaml") + mock_uvicorn_config.assert_called_once() + mock_uvicorn_server.assert_called_once() + mock_server.serve.assert_awaited_once() diff --git a/tests/unit/setup/test_initialization.py b/tests/unit/setup/test_initialization.py index 9c386ba68b..8e36683f6b 100644 --- a/tests/unit/setup/test_initialization.py +++ b/tests/unit/setup/test_initialization.py @@ -13,6 +13,7 @@ from pyrit.setup.initialization import ( _load_environment_files, _load_initializers_from_scripts, + run_initializers_async, ) @@ -54,6 +55,31 @@ def test_script_not_found_raises_error(self): _load_initializers_from_scripts(script_paths=["nonexistent_script.py"]) +class TestRunInitializersAsync: + """Tests for run_initializers_async function.""" + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization._execute_initializers_async") + async def test_calls_execute_when_initializers_provided(self, mock_execute): + """Test that it calls _execute_initializers_async when initializers are provided.""" + mock_initializer = mock.MagicMock() + await run_initializers_async(initializers=[mock_initializer]) + mock_execute.assert_called_once() + assert mock_initializer in mock_execute.call_args.kwargs["initializers"] + + @pytest.mark.asyncio + @mock.patch("pyrit.setup.initialization._execute_initializers_async") + @mock.patch("pyrit.setup.initialization._load_initializers_from_scripts") + async def test_loads_scripts_when_scripts_provided(self, mock_load_scripts, mock_execute): + """Test that it loads scripts via _load_initializers_from_scripts when scripts are provided.""" + mock_script_init = mock.MagicMock() + mock_load_scripts.return_value = [mock_script_init] + await run_initializers_async(initialization_scripts=["test_script.py"]) + mock_load_scripts.assert_called_once_with(script_paths=["test_script.py"]) + mock_execute.assert_called_once() + assert mock_script_init in mock_execute.call_args.kwargs["initializers"] + + class TestInitializePyrit: """Tests for initialize_pyrit_async function - basic orchestration tests."""