Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 38 additions & 42 deletions pyrit/cli/frontend_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@

from __future__ import annotations

import argparse
import inspect
import json
import logging
import sys
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

from pyrit.setup import ConfigurationLoader
from pyrit.registry import InitializerRegistry, ScenarioRegistry
from pyrit.scenario import DatasetConfiguration
from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter
from pyrit.setup import ConfigurationLoader, initialize_pyrit_async, run_initializers_async
from pyrit.setup.configuration_loader import _MEMORY_DB_TYPE_MAP

try:
Expand All @@ -47,9 +52,7 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i
from pyrit.models.scenario_result import ScenarioResult
from pyrit.registry import (
InitializerMetadata,
InitializerRegistry,
ScenarioMetadata,
ScenarioRegistry,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -141,14 +144,17 @@ def __init__(
logging.basicConfig(level=self._log_level)

async def initialize_async(self) -> None:
"""Initialize PyRIT and load registries (heavy operation)."""
"""
Initialize PyRIT and load registries (heavy operation).

Sets up memory and loads scenario/initializer registries.
Initializers are NOT run here — they are run separately
(per-scenario in pyrit_scan, or up-front in pyrit_backend).
"""
if self._initialized:
return

from pyrit.registry import InitializerRegistry, ScenarioRegistry
from pyrit.setup import initialize_pyrit_async

# Initialize PyRIT without initializers (they run per-scenario)
# Initialize PyRIT without initializers (they run separately)
await initialize_pyrit_async(
memory_db_type=self._database,
initialization_scripts=None,
Expand All @@ -167,6 +173,27 @@ async def initialize_async(self) -> None:

self._initialized = True

async def run_initializers_async(self) -> None:
"""
Resolve and run all configured initializers and initialization scripts.

Must be called after :meth:`initialize_async` so that registries are
available to resolve initializer names. This is the same pattern used
by :func:`run_scenario_async` before executing a scenario.

If no initializers are configured this is a no-op.
"""
initializer_instances = None
if self._initializer_names:
print(f"Running {len(self._initializer_names)} initializer(s)...")
sys.stdout.flush()
initializer_instances = [self.initializer_registry.get_class(name)() for name in self._initializer_names]

await run_initializers_async(
initializers=initializer_instances,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably too big for here, but one thing on my mind is that it'd be nice to be able to pass constructor arguments to initializers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also probably execution order; right now it's built into the class, but maybe it should actually be the order in which we call them, or something similar.

initialization_scripts=self._initialization_scripts,
)

@property
def scenario_registry(self) -> ScenarioRegistry:
"""
Expand Down Expand Up @@ -227,8 +254,6 @@ async def list_initializers_async(
Sequence of initializer metadata dictionaries describing each initializer class.
"""
if discovery_path:
from pyrit.registry import InitializerRegistry

registry = InitializerRegistry(discovery_path=discovery_path)
return registry.list_metadata()

Expand Down Expand Up @@ -276,34 +301,13 @@ async def run_scenario_async(
Note:
Initializers from PyRITContext will be run before the scenario executes.
"""
from pyrit.scenario.printer.console_printer import ConsoleScenarioResultPrinter
from pyrit.setup import initialize_pyrit_async

# Ensure context is initialized first (loads registries)
# This must happen BEFORE we run initializers to avoid double-initialization
if not context._initialized:
await context.initialize_async()

# Run initializers before scenario
initializer_instances = None
if context._initializer_names:
print(f"Running {len(context._initializer_names)} initializer(s)...")
sys.stdout.flush()

initializer_instances = []

for name in context._initializer_names:
initializer_class = context.initializer_registry.get_class(name)
initializer_instances.append(initializer_class())

# Re-initialize PyRIT with the scenario-specific initializers
# This resets memory and applies initializer defaults
await initialize_pyrit_async(
memory_db_type=context._database,
initialization_scripts=context._initialization_scripts,
initializers=initializer_instances,
env_files=context._env_files,
)
# Resolve and run initializers + initialization scripts
await context.run_initializers_async()

# Get scenario class
scenario_class = context.scenario_registry.get_class(scenario_name)
Expand Down Expand Up @@ -343,8 +347,6 @@ async def run_scenario_async(
# - max_dataset_size only: default datasets with overridden limit
if dataset_names:
# User specified dataset names - create new config (fetches all unless max_dataset_size set)
from pyrit.scenario import DatasetConfiguration

init_kwargs["dataset_config"] = DatasetConfiguration(
dataset_names=dataset_names,
max_dataset_size=max_dataset_size,
Expand Down Expand Up @@ -599,8 +601,6 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A
Raises:
ValueError: If validator_func has no parameters.
"""
import inspect

# Get the first parameter name from the function signature
sig = inspect.signature(validator_func)
params = list(sig.parameters.keys())
Expand All @@ -609,13 +609,11 @@ def _argparse_validator(validator_func: Callable[..., Any]) -> Callable[[Any], A
first_param = params[0]

def wrapper(value: Any) -> Any:
import argparse as ap

try:
# Call with keyword argument to support keyword-only parameters
return validator_func(**{first_param: value})
except ValueError as e:
raise ap.ArgumentTypeError(str(e)) from e
raise argparse.ArgumentTypeError(str(e)) from e

# Preserve function metadata for better debugging
wrapper.__name__ = getattr(validator_func, "__name__", "argparse_validator")
Expand All @@ -636,8 +634,6 @@ def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]:
Raises:
FileNotFoundError: If a script path does not exist.
"""
from pyrit.registry import InitializerRegistry

return InitializerRegistry.resolve_script_paths(script_paths=script_paths)


Expand Down
80 changes: 32 additions & 48 deletions pyrit/cli/pyrit_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
import asyncio
import sys
from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter
from pathlib import Path
from typing import Optional

# Ensure emoji and other Unicode characters don't crash on Windows consoles
# that use legacy encodings like cp1252.
sys.stdout.reconfigure(errors="replace") # type: ignore[union-attr]
sys.stderr.reconfigure(errors="replace") # type: ignore[union-attr]

import uvicorn

from pyrit.cli import frontend_core


Expand All @@ -41,6 +44,9 @@ def parse_args(*, args: Optional[list[str]] = None) -> Namespace:
# Start with custom initialization scripts
pyrit_backend --initialization-scripts ./my_targets.py
# Start with explicit config file
pyrit_backend --config-file ./my_backend_conf.yaml
# Start with custom port and host
pyrit_backend --host 0.0.0.0 --port 8080
Expand Down Expand Up @@ -80,13 +86,19 @@ def parse_args(*, args: Optional[list[str]] = None) -> Namespace:
parser.add_argument(
"--database",
type=frontend_core.validate_database_argparse,
default=frontend_core.SQLITE,
default=None,
help=(
f"Database type to use for memory storage ({frontend_core.IN_MEMORY}, "
f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: {frontend_core.SQLITE})"
f"{frontend_core.SQLITE}, {frontend_core.AZURE_SQL}) (default: from config file, or {frontend_core.SQLITE})"
),
)

parser.add_argument(
"--config-file",
type=str,
help=frontend_core.ARG_HELP["config_file"],
)

parser.add_argument(
"--initializers",
type=str,
Expand Down Expand Up @@ -124,55 +136,27 @@ async def initialize_and_run(*, parsed_args: Namespace) -> int:
Returns:
int: Exit code (0 for success, 1 for error).
"""
from pyrit.setup import initialize_pyrit_async

# Resolve initialization scripts if provided
initialization_scripts = None
if parsed_args.initialization_scripts:
try:
initialization_scripts = frontend_core.resolve_initialization_scripts(
script_paths=parsed_args.initialization_scripts
)
except FileNotFoundError as e:
print(f"Error: {e}")
return 1

# Resolve env files if provided
env_files = None
if parsed_args.env_files:
try:
env_files = frontend_core.resolve_env_files(env_file_paths=parsed_args.env_files)
except ValueError as e:
print(f"Error: {e}")
return 1

# Resolve initializer instances if names provided
initializer_instances = None
if parsed_args.initializers:
from pyrit.registry import InitializerRegistry

registry = InitializerRegistry()
initializer_instances = []
for name in parsed_args.initializers:
try:
initializer_class = registry.get_class(name)
initializer_instances.append(initializer_class())
except Exception as e:
print(f"Error: Could not load initializer '{name}': {e}")
return 1

# Initialize PyRIT with the provided configuration
try:
core = frontend_core.FrontendCore(
config_file=Path(parsed_args.config_file) if parsed_args.config_file else None,
database=parsed_args.database,
initialization_scripts=(
[Path(p) for p in parsed_args.initialization_scripts] if parsed_args.initialization_scripts else None
),
initializer_names=parsed_args.initializers,
env_files=([Path(p) for p in parsed_args.env_files] if parsed_args.env_files else None),
log_level=parsed_args.log_level,
)
except (ValueError, FileNotFoundError) as e:
print(f"Error: {e}")
return 1

# Initialize memory, registries, and run initializers.
print("🔧 Initializing PyRIT...")
await initialize_pyrit_async(
memory_db_type=parsed_args.database,
initialization_scripts=initialization_scripts,
initializers=initializer_instances,
env_files=env_files,
)
await core.initialize_async()
await core.run_initializers_async()

# Start uvicorn server
import uvicorn

print(f"🚀 Starting PyRIT backend on http://{parsed_args.host}:{parsed_args.port}")
print(f" API Docs: http://{parsed_args.host}:{parsed_args.port}/docs")

Expand Down
10 changes: 9 additions & 1 deletion pyrit/setup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@
"""Module containing PyRIT initialization."""

from pyrit.setup.configuration_loader import ConfigurationLoader, initialize_from_config_async
from pyrit.setup.initialization import AZURE_SQL, IN_MEMORY, SQLITE, MemoryDatabaseType, initialize_pyrit_async
from pyrit.setup.initialization import (
AZURE_SQL,
IN_MEMORY,
SQLITE,
MemoryDatabaseType,
initialize_pyrit_async,
run_initializers_async,
)

__all__ = [
"AZURE_SQL",
"SQLITE",
"IN_MEMORY",
"initialize_pyrit_async",
"initialize_from_config_async",
"run_initializers_async",
"MemoryDatabaseType",
"ConfigurationLoader",
]
25 changes: 22 additions & 3 deletions pyrit/setup/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,33 @@ async def initialize_pyrit_async(
)
CentralMemory.set_memory_instance(memory)

# Combine directly provided initializers with those loaded from scripts
await run_initializers_async(initializers=initializers, initialization_scripts=initialization_scripts)


async def run_initializers_async(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little worried about this as is. I think there are three things we need to make sure we've thought through:

  1. Idempotency: There are initializers that can re-register or duplicate things if you re-run
  2. reset_default_values: Right now this is called on setup, but if we run initializers directly the default values are not wiped clean first. E.g. if you run SimpleInitializer and then later AIRTInitializer, they can both set the same converter_target on overlapping scopes. Without a reset, you end up in states that are tricky to debug.
  3. Precondition enforcement: We'd need to verify initialize_pyrit is called first, env files are loaded, etc.

Because these are all lumped together, I like keeping it bundled. There is some performance hit because we're nuking the state, resetting things, and reloading the env files, but I don't think it's a huge deal. And in some ways it's nice (e.g. if we update the .env file). It's a bit awkward to re-call, but I think it makes it less error prone.

The biggest reason I think we'd want to separate is if we're doing additive initializer execution, but I don't think that's something we want near term?

*,
initializers: Optional[Sequence["PyRITInitializer"]] = None,
initialization_scripts: Optional[Sequence[Union[str, pathlib.Path]]] = None,
) -> None:
"""
Run initializers and initialization scripts without re-initializing memory or environment.

This is useful when memory and environment are already set up (e.g. via
:func:`initialize_pyrit_async`) and only the initializer step needs to run.

Args:
initializers: Optional sequence of PyRITInitializer instances to execute directly.
initialization_scripts: Optional sequence of Python script paths containing
PyRITInitializer classes.

Raises:
ValueError: If initializers are invalid or scripts cannot be loaded.
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The run_initializers_async docstring says it raises ValueError when scripts cannot be loaded, but _load_initializers_from_scripts() can also raise FileNotFoundError (and potentially other exceptions during import). Please update the Raises: section to reflect the actual exception types callers should handle.

Suggested change
ValueError: If initializers are invalid or scripts cannot be loaded.
ValueError: If one or more provided initializers are invalid.
FileNotFoundError: If an initialization script path does not exist.
Exception: If an error occurs while importing initialization scripts or executing initializers.

Copilot uses AI. Check for mistakes.
"""
all_initializers = list(initializers) if initializers else []

# Load additional initializers from scripts
if initialization_scripts:
script_initializers = _load_initializers_from_scripts(script_paths=initialization_scripts)
all_initializers.extend(script_initializers)

# Execute all initializers (sorted by execution_order)
if all_initializers:
await _execute_initializers_async(initializers=all_initializers)
Comment on lines +290 to 316
Copy link

Copilot AI Mar 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run_initializers_async is a new public API surface but there are no unit tests covering its behavior (e.g., that it loads initializers from scripts, executes in execution_order, and does not reset defaults / reinitialize memory). Since initialize_pyrit_async already has orchestration tests, consider adding focused tests for this new function as well.

Copilot uses AI. Check for mistakes.
Loading
Loading