From aedcc98cc46433c70ff4035cb0db492870710c91 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Sat, 28 Feb 2026 14:52:53 +0000 Subject: [PATCH 1/6] Expand memory interface and models for attack results - Add conversation_stats model and attack_result extensions - Add get_attack_results with filtering by harm categories, labels, attack type, and converter types to memory interface - Implement SQLite-specific JSON filtering for attack results - Add memory_models field for targeted_harm_categories - Add prompt_metadata support to openai image/video/response targets - Fix missing return statements in SQLite harm_category and label filters Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/azure_sql_memory.py | 119 +++++++-- pyrit/memory/memory_interface.py | 170 ++++++++++--- pyrit/memory/memory_models.py | 1 + pyrit/memory/sqlite_memory.py | 142 ++++++++--- pyrit/models/__init__.py | 2 + pyrit/models/attack_result.py | 4 + pyrit/models/conversation_stats.py | 23 ++ .../openai/openai_image_target.py | 10 + .../openai/openai_realtime_target.py | 3 +- .../openai/openai_response_target.py | 6 + pyrit/prompt_target/openai/openai_target.py | 6 +- .../openai/openai_video_target.py | 10 + .../test_interface_attack_results.py | 240 +++++++++++++----- tests/unit/memory/test_sqlite_memory.py | 120 +++++++++ tests/unit/target/test_image_target.py | 18 ++ tests/unit/target/test_video_target.py | 17 ++ 16 files changed, 749 insertions(+), 142 deletions(-) create mode 100644 pyrit/models/conversation_stats.py diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 6702f2240..d44119749 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import json import logging import struct from collections.abc import MutableSequence, Sequence @@ -27,6 +28,7 @@ ) from pyrit.models import ( AzureBlobStorageIO, + ConversationStats, MessagePiece, ) @@ -386,37 +388,37 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) ) - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: + def _get_attack_result_attack_type_condition(self, *, attack_type: str) -> Any: """ Azure SQL implementation for filtering AttackResults by attack type. Uses JSON_VALUE() to match class_name in the attack_identifier JSON column. Args: - attack_class (str): Exact attack class name to match. + attack_type (str): Exact attack type name to match. Returns: Any: SQLAlchemy text condition with bound parameter. """ return text( """ISJSON("AttackResultEntries".attack_identifier) = 1 - AND JSON_VALUE("AttackResultEntries".attack_identifier, '$.class_name') = :attack_class""" - ).bindparams(attack_class=attack_class) + AND JSON_VALUE("AttackResultEntries".attack_identifier, '$.class_name') = :attack_type""" + ).bindparams(attack_type=attack_type) - def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[str]) -> Any: + def _get_attack_result_converter_types_condition(self, *, converter_types: Sequence[str]) -> Any: """ - Azure SQL implementation for filtering AttackResults by converter classes. + Azure SQL implementation for filtering AttackResults by converter types. - When converter_classes is empty, matches attacks with no converters. - When non-empty, uses OPENJSON() to check all specified classes are present + When converter_types is empty, matches attacks with no converters. + When non-empty, uses OPENJSON() to check all specified types are present (AND logic, case-insensitive). Args: - converter_classes (Sequence[str]): List of converter class names. Empty list means no converters. + converter_types (Sequence[str]): List of converter type names. Empty list means no converters. Returns: Any: SQLAlchemy combined condition with bound parameters. """ - if len(converter_classes) == 0: + if len(converter_types) == 0: # Explicitly "no converters": match attacks where the converter list # is absent, null, or empty in the stored JSON. return text( @@ -428,7 +430,7 @@ def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[ conditions = [] bindparams_dict: dict[str, str] = {} - for i, cls in enumerate(converter_classes): + for i, cls in enumerate(converter_types): param_name = f"conv_cls_{i}" conditions.append( f'EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".attack_identifier, ' @@ -442,7 +444,7 @@ def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[ **bindparams_dict ) - def get_unique_attack_class_names(self) -> list[str]: + def get_unique_attack_type_names(self) -> list[str]: """ Azure SQL implementation: extract unique class_name values from attack_identifier JSON. @@ -460,7 +462,7 @@ def get_unique_attack_class_names(self) -> list[str]: ).fetchall() return sorted(row[0] for row in rows) - def get_unique_converter_class_names(self) -> list[str]: + def get_unique_converter_type_names(self) -> list[str]: """ Azure SQL implementation: extract unique converter class_name values from the request_converter_identifiers array in attack_identifier JSON. @@ -481,6 +483,87 @@ def get_unique_converter_class_names(self) -> list[str]: ).fetchall() return sorted(row[0] for row in rows) + def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: + """ + Azure SQL implementation: lightweight aggregate stats per conversation. + + Executes a single SQL query that returns message count (distinct + sequences), a truncated last-message preview, the first non-empty + labels dict, and the earliest timestamp for each conversation_id. + + Args: + conversation_ids (Sequence[str]): The conversation IDs to query. + + Returns: + Mapping from conversation_id to ConversationStats. + """ + if not conversation_ids: + return {} + + placeholders = ", ".join(f":cid{i}" for i in range(len(conversation_ids))) + params = {f"cid{i}": cid for i, cid in enumerate(conversation_ids)} + + max_len = ConversationStats.PREVIEW_MAX_LEN + sql = text( + f""" + SELECT + pme.conversation_id, + COUNT(DISTINCT pme.sequence) AS msg_count, + ( + SELECT TOP 1 LEFT(p2.converted_value, {max_len + 3}) + FROM "PromptMemoryEntries" p2 + WHERE p2.conversation_id = pme.conversation_id + ORDER BY p2.sequence DESC, p2.id DESC + ) AS last_preview, + ( + SELECT TOP 1 p3.labels + FROM "PromptMemoryEntries" p3 + WHERE p3.conversation_id = pme.conversation_id + AND p3.labels IS NOT NULL + AND p3.labels != '{{}}' + AND p3.labels != 'null' + ) AS first_labels, + MIN(pme.timestamp) AS created_at + FROM "PromptMemoryEntries" pme + WHERE pme.conversation_id IN ({placeholders}) + GROUP BY pme.conversation_id + """ + ) + + with closing(self.get_session()) as session: + rows = session.execute(sql, params).fetchall() + + result: dict[str, ConversationStats] = {} + for row in rows: + conv_id, msg_count, last_preview, raw_labels, raw_created_at = row + + preview = None + if last_preview: + preview = last_preview[:max_len] + "..." if len(last_preview) > max_len else last_preview + + labels: dict[str, str] = {} + if raw_labels and raw_labels not in ("null", "{}"): + try: + labels = json.loads(raw_labels) + except (ValueError, TypeError): + pass + + created_at = None + if raw_created_at is not None: + if isinstance(raw_created_at, str): + created_at = datetime.fromisoformat(raw_created_at) + else: + created_at = raw_created_at + + result[conv_id] = ConversationStats( + message_count=msg_count, + last_message_preview=preview, + labels=labels, + created_at=created_at, + ) + + return result + def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ Get the SQL Azure implementation for filtering ScenarioResults by labels. @@ -673,8 +756,14 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict with closing(self.get_session()) as session: try: for entry in entries: - # Ensure the entry is attached to the session. If it's detached, merge it. - entry_in_session = session.merge(entry) if not session.is_modified(entry) else entry + # Load a fresh copy by primary key so we only touch the + # requested fields. Using merge() would copy ALL + # attributes from the (potentially stale) detached object + # and silently overwrite concurrent updates to columns + # that are NOT in update_fields. + entry_in_session = session.get(type(entry), entry.id) # type: ignore[attr-defined] + if entry_in_session is None: + entry_in_session = session.merge(entry) for field, value in update_fields.items(): if field in vars(entry_in_session): setattr(entry_in_session, field, value) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 67e6dcfb6..c9c3753da 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -34,6 +34,7 @@ ) from pyrit.models import ( AttackResult, + ConversationStats, DataTypeSerializer, Message, MessagePiece, @@ -290,33 +291,33 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @abc.abstractmethod - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: + def _get_attack_result_attack_type_condition(self, *, attack_type: str) -> Any: """ Return a database-specific condition for filtering AttackResults by attack type (class_name in the attack_identifier JSON column). Args: - attack_class: Exact attack class name to match. + attack_type: Exact attack type name to match. Returns: Database-specific SQLAlchemy condition. """ @abc.abstractmethod - def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[str]) -> Any: + def _get_attack_result_converter_types_condition(self, *, converter_types: Sequence[str]) -> Any: """ - Return a database-specific condition for filtering AttackResults by converter classes + Return a database-specific condition for filtering AttackResults by converter types in the request_converter_identifiers array within attack_identifier JSON column. - This method is only called when converter filtering is requested (converter_classes + This method is only called when converter filtering is requested (converter_types is not None). The caller handles the None-vs-list distinction: - - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters. - - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter - class names to be present (AND logic, case-insensitive). + - ``len(converter_types) == 0``: return a condition matching attacks with NO converters. + - ``len(converter_types) > 0``: return a condition requiring ALL specified converter + type names to be present (AND logic, case-insensitive). Args: - converter_classes: Converter class names to require. An empty sequence means + converter_types: Converter type names to require. An empty sequence means "match only attacks that have no converters". Returns: @@ -324,27 +325,45 @@ class names to be present (AND logic, case-insensitive). """ @abc.abstractmethod - def get_unique_attack_class_names(self) -> list[str]: + def get_unique_attack_type_names(self) -> list[str]: """ - Return sorted unique attack class names from all stored attack results. + Return sorted unique attack type names from all stored attack results. Extracts class_name from the attack_identifier JSON column via a database-level DISTINCT query. Returns: - Sorted list of unique attack class name strings. + Sorted list of unique attack type name strings. """ @abc.abstractmethod - def get_unique_converter_class_names(self) -> list[str]: + def get_unique_converter_type_names(self) -> list[str]: """ - Return sorted unique converter class names used across all attack results. + Return sorted unique converter type names used across all attack results. Extracts class_name values from the request_converter_identifiers array within the attack_identifier JSON column via a database-level query. Returns: - Sorted list of unique converter class name strings. + Sorted list of unique converter type name strings. + """ + + @abc.abstractmethod + def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, "ConversationStats"]: + """ + Return lightweight aggregate statistics for one or more conversations. + + Computes per-conversation message count (distinct sequence numbers), + a truncated last-message preview, the first non-empty labels dict, + and the earliest message timestamp using efficient SQL aggregation + instead of loading full pieces. + + Args: + conversation_ids: The conversation IDs to query. + + Returns: + Mapping from conversation_id to ConversationStats. + Conversations with no pieces are omitted from the result. """ @abc.abstractmethod @@ -631,15 +650,18 @@ def get_message_pieces( logger.exception(f"Failed to retrieve prompts with error {e}") raise - def _duplicate_conversation(self, *, messages: Sequence[Message]) -> tuple[str, Sequence[MessagePiece]]: + def duplicate_messages(self, *, messages: Sequence[Message]) -> tuple[str, Sequence[MessagePiece]]: """ - Duplicate messages with new conversation ID. + Duplicate messages with a new conversation ID. + + Each duplicated piece gets a fresh ``id`` and ``timestamp`` while + preserving ``original_prompt_id`` for tracking lineage. Args: - messages (Sequence[Message]): The messages to duplicate. + messages: The messages to duplicate. Returns: - tuple[str, Sequence[MessagePiece]]: The new conversation ID and the duplicated message pieces. + Tuple of (new_conversation_id, duplicated_message_pieces). """ new_conversation_id = str(uuid.uuid4()) @@ -669,7 +691,7 @@ def duplicate_conversation(self, *, conversation_id: str) -> str: The uuid for the new conversation. """ messages = self.get_conversation(conversation_id=conversation_id) - new_conversation_id, all_pieces = self._duplicate_conversation(messages=messages) + new_conversation_id, all_pieces = self.duplicate_messages(messages=messages) self.add_message_pieces_to_memory(message_pieces=all_pieces) return new_conversation_id @@ -702,7 +724,7 @@ def duplicate_conversation_excluding_last_turn(self, *, conversation_id: str) -> message for message in messages if message.sequence <= last_message.sequence - length_of_sequence_to_remove ] - new_conversation_id, all_pieces = self._duplicate_conversation(messages=messages_to_duplicate) + new_conversation_id, all_pieces = self.duplicate_messages(messages=messages_to_duplicate) self.add_message_pieces_to_memory(message_pieces=all_pieces) return new_conversation_id @@ -1256,8 +1278,83 @@ def add_attack_results_to_memory(self, *, attack_results: Sequence[AttackResult] """ Insert a list of attack results into the memory storage. The database model automatically calculates objective_sha256 for consistency. + + Raises: + SQLAlchemyError: If the database transaction fails. + """ + entries = [AttackResultEntry(entry=attack_result) for attack_result in attack_results] + # Capture the DB-assigned IDs before insert (they'll be set after flush/commit). + # _insert_entries closes the session, so we must read `entry.id` *inside* + # the session. Since _insert_entries uses a context manager, we instead + # read the ids from the entries *before* the session closes by doing the + # insert inline. + from contextlib import closing + + with closing(self.get_session()) as session: + from sqlalchemy.exc import SQLAlchemyError + + try: + session.add_all(entries) + session.commit() + # Populate the attack_result_id back onto the domain objects so callers + # can reference the DB-assigned ID immediately after insert. + for ar, entry in zip(attack_results, entries): + ar.attack_result_id = str(entry.id) + except SQLAlchemyError: + session.rollback() + raise + + def update_attack_result(self, *, conversation_id: str, update_fields: dict[str, Any]) -> bool: """ - self._insert_entries(entries=[AttackResultEntry(entry=attack_result) for attack_result in attack_results]) + Update specific fields of an existing AttackResultEntry identified by conversation_id. + + This method queries for the raw database entry by conversation_id and updates + the specified fields in place, avoiding the creation of duplicate rows. + + Args: + conversation_id (str): The conversation ID of the attack result to update. + update_fields (dict[str, Any]): A dictionary of column names to new values. + Valid fields include 'adversarial_chat_conversation_ids', + 'pruned_conversation_ids', 'outcome', 'attack_metadata', etc. + + Returns: + bool: True if the update was successful, False if the entry was not found. + + Raises: + ValueError: If update_fields is empty. + """ + entries: MutableSequence[AttackResultEntry] = self._query_entries( + AttackResultEntry, + conditions=AttackResultEntry.conversation_id == conversation_id, + ) + if not entries: + return False + + # When duplicate rows exist for the same conversation_id (legacy bug), + # pick the newest entry — it has the most up-to-date data. + target_entry = max(entries, key=lambda e: e.timestamp) + self._update_entries(entries=[target_entry], update_fields=update_fields) + return True + + def update_attack_result_by_id(self, *, attack_result_id: str, update_fields: dict[str, Any]) -> bool: + """ + Update specific fields of an existing AttackResultEntry identified by its primary key. + + Args: + attack_result_id: The UUID primary key of the AttackResultEntry. + update_fields: Column names to new values. + + Returns: + True if the update was successful, False if the entry was not found. + """ + entries: MutableSequence[AttackResultEntry] = self._query_entries( + AttackResultEntry, + conditions=AttackResultEntry.id == attack_result_id, + ) + if not entries: + return False + self._update_entries(entries=[entries[0]], update_fields=update_fields) + return True def get_attack_results( self, @@ -1267,8 +1364,8 @@ def get_attack_results( objective: Optional[str] = None, objective_sha256: Optional[Sequence[str]] = None, outcome: Optional[str] = None, - attack_class: Optional[str] = None, - converter_classes: Optional[Sequence[str]] = None, + attack_type: Optional[str] = None, + converter_types: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, ) -> Sequence[AttackResult]: @@ -1283,9 +1380,9 @@ def get_attack_results( Defaults to None. outcome (Optional[str], optional): The outcome to filter by (success, failure, undetermined). Defaults to None. - attack_class (Optional[str], optional): Filter by exact attack class_name in attack_identifier. + attack_type (Optional[str], optional): Filter by exact attack class_name in attack_identifier. Defaults to None. - converter_classes (Optional[Sequence[str]], optional): Filter by converter class names. + converter_types (Optional[Sequence[str]], optional): Filter by converter type names. Returns only attacks that used ALL specified converters (AND logic, case-insensitive). Defaults to None. targeted_harm_categories (Optional[Sequence[str]], optional): @@ -1319,14 +1416,14 @@ def get_attack_results( if outcome: conditions.append(AttackResultEntry.outcome == outcome) - if attack_class: + if attack_type: # Use database-specific JSON query method - conditions.append(self._get_attack_result_attack_class_condition(attack_class=attack_class)) + conditions.append(self._get_attack_result_attack_type_condition(attack_type=attack_type)) - if converter_classes is not None: - # converter_classes=[] means "only attacks with no converters" - # converter_classes=["A","B"] means "must have all listed converters" - conditions.append(self._get_attack_result_converter_condition(converter_classes=converter_classes)) + if converter_types is not None: + # converter_types=[] means "only attacks with no converters" + # converter_types=["A","B"] means "must have all listed converters" + conditions.append(self._get_attack_result_converter_types_condition(converter_types=converter_types)) if targeted_harm_categories: # Use database-specific JSON query method @@ -1342,7 +1439,14 @@ def get_attack_results( entries: Sequence[AttackResultEntry] = self._query_entries( AttackResultEntry, conditions=and_(*conditions) if conditions else None ) - return [entry.get_attack_result() for entry in entries] + # Deduplicate by conversation_id — when duplicate rows exist + # (legacy bug), keep only the newest entry per conversation_id. + seen: dict[str, AttackResultEntry] = {} + for entry in entries: + prev = seen.get(entry.conversation_id) + if prev is None or entry.timestamp > prev.timestamp: + seen[entry.conversation_id] = entry + return [entry.get_attack_result() for entry in seen.values()] except Exception as e: logger.exception(f"Failed to retrieve attack results with error {e}") raise diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 04be633df..50ed0fb87 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -852,6 +852,7 @@ def get_attack_result(self) -> AttackResult: return AttackResult( conversation_id=self.conversation_id, + attack_result_id=str(self.id), objective=self.objective, attack_identifier=ComponentIdentifier.from_dict(self.attack_identifier) if self.attack_identifier else None, last_response=self.last_response.get_message_piece() if self.last_response else None, diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index cfce238fd..375c348e9 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import json import logging import uuid from collections.abc import MutableSequence, Sequence @@ -9,7 +10,7 @@ from pathlib import Path from typing import Any, Optional, TypeVar, Union -from sqlalchemy import and_, create_engine, func, or_, text +from sqlalchemy import and_, create_engine, exists, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload, sessionmaker @@ -24,8 +25,9 @@ Base, EmbeddingDataEntry, PromptMemoryEntry, + ScenarioResultEntry, ) -from pyrit.models import DiskStorageIO, MessagePiece +from pyrit.models import ConversationStats, DiskStorageIO, MessagePiece logger = logging.getLogger(__name__) @@ -298,8 +300,14 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict with closing(self.get_session()) as session: try: for entry in entries: - # Ensure the entry is attached to the session. If it's detached, merge it. - entry_in_session = session.merge(entry) if not session.is_modified(entry) else entry + # Load a fresh copy by primary key so we only touch the + # requested fields. Using merge() would copy ALL + # attributes from the (potentially stale) detached object + # and silently overwrite concurrent updates to columns + # that are NOT in update_fields. + entry_in_session = session.get(type(entry), entry.id) # type: ignore[attr-defined] + if entry_in_session is None: + entry_in_session = session.merge(entry) for field, value in update_fields.items(): if field in vars(entry_in_session): setattr(entry_in_session, field, value) @@ -412,8 +420,6 @@ def export_conversations( # Export to JSON manually since the exporter expects objects but we have dicts with open(file_path, "w") as f: - import json - json.dump(merged_data, f, indent=4) return file_path @@ -462,7 +468,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - return exists().where( + targeted_harm_categories_subquery = exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, # Exclude empty strings, None, and empty lists @@ -477,6 +483,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories ), ) ) + return targeted_harm_categories_subquery def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @@ -490,7 +497,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: from pyrit.memory.memory_models import AttackResultEntry, PromptMemoryEntry - return exists().where( + labels_subquery = exists().where( and_( PromptMemoryEntry.conversation_id == AttackResultEntry.conversation_id, PromptMemoryEntry.labels.isnot(None), @@ -499,8 +506,9 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ), ) ) + return labels_subquery - def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: + def _get_attack_result_attack_type_condition(self, *, attack_type: str) -> Any: """ SQLite implementation for filtering AttackResults by attack type. Uses json_extract() to match class_name in the attack_identifier JSON column. @@ -508,21 +516,21 @@ def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any Returns: Any: A SQLAlchemy condition for filtering by attack type. """ - return func.json_extract(AttackResultEntry.attack_identifier, "$.class_name") == attack_class + return func.json_extract(AttackResultEntry.attack_identifier, "$.class_name") == attack_type - def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[str]) -> Any: + def _get_attack_result_converter_types_condition(self, *, converter_types: Sequence[str]) -> Any: """ - SQLite implementation for filtering AttackResults by converter classes. + SQLite implementation for filtering AttackResults by converter types. - When converter_classes is empty, matches attacks with no converters + When converter_types is empty, matches attacks with no converters (request_converter_identifiers is absent or null in the JSON). - When non-empty, uses json_each() to check all specified classes are present + When non-empty, uses json_each() to check all specified types are present (AND logic, case-insensitive). Returns: - Any: A SQLAlchemy condition for filtering by converter classes. + Any: A SQLAlchemy condition for filtering by converter types. """ - if len(converter_classes) == 0: + if len(converter_types) == 0: # Explicitly "no converters": match attacks where the converter list # is absent, null, or empty in the stored JSON. converter_json = func.json_extract(AttackResultEntry.attack_identifier, "$.request_converter_identifiers") @@ -534,7 +542,7 @@ def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[ ) conditions = [] - for i, cls in enumerate(converter_classes): + for i, cls in enumerate(converter_types): param_name = f"conv_cls_{i}" conditions.append( text( @@ -545,7 +553,7 @@ def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[ ) return and_(*conditions) - def get_unique_attack_class_names(self) -> list[str]: + def get_unique_attack_type_names(self) -> list[str]: """ SQLite implementation: extract unique class_name values from attack_identifier JSON. @@ -561,7 +569,7 @@ def get_unique_attack_class_names(self) -> list[str]: ) return sorted(row[0] for row in rows) - def get_unique_converter_class_names(self) -> list[str]: + def get_unique_converter_type_names(self) -> list[str]: """ SQLite implementation: extract unique converter class_name values from the request_converter_identifiers array in attack_identifier JSON. @@ -581,6 +589,89 @@ def get_unique_converter_class_names(self) -> list[str]: ).fetchall() return sorted(row[0] for row in rows) + def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]: + """ + SQLite implementation: lightweight aggregate stats per conversation. + + Executes a single SQL query that returns message count (distinct + sequences), a truncated last-message preview, the first non-empty + labels dict, and the earliest timestamp for each conversation_id. + + Args: + conversation_ids: The conversation IDs to query. + + Returns: + Mapping from conversation_id to ConversationStats. + """ + if not conversation_ids: + return {} + + placeholders = ", ".join(f":cid{i}" for i in range(len(conversation_ids))) + params = {f"cid{i}": cid for i, cid in enumerate(conversation_ids)} + + max_len = ConversationStats.PREVIEW_MAX_LEN + sql = text( + f""" + SELECT + pme.conversation_id, + COUNT(DISTINCT pme.sequence) AS msg_count, + ( + SELECT SUBSTR(p2.converted_value, 1, {max_len + 3}) + FROM "PromptMemoryEntries" p2 + WHERE p2.conversation_id = pme.conversation_id + ORDER BY p2.sequence DESC, p2.id DESC + LIMIT 1 + ) AS last_preview, + ( + SELECT p3.labels + FROM "PromptMemoryEntries" p3 + WHERE p3.conversation_id = pme.conversation_id + AND p3.labels IS NOT NULL + AND p3.labels != '{{}}' + AND p3.labels != 'null' + LIMIT 1 + ) AS first_labels, + MIN(pme.timestamp) AS created_at + FROM "PromptMemoryEntries" pme + WHERE pme.conversation_id IN ({placeholders}) + GROUP BY pme.conversation_id + """ + ) + + with closing(self.get_session()) as session: + rows = session.execute(sql, params).fetchall() + + result: dict[str, ConversationStats] = {} + for row in rows: + conv_id, msg_count, last_preview, raw_labels, raw_created_at = row + + preview = None + if last_preview: + preview = last_preview[:max_len] + "..." if len(last_preview) > max_len else last_preview + + labels: dict[str, str] = {} + if raw_labels and raw_labels not in ("null", "{}"): + try: + labels = json.loads(raw_labels) + except (ValueError, TypeError): + pass + + created_at = None + if raw_created_at is not None: + if isinstance(raw_created_at, str): + created_at = datetime.fromisoformat(raw_created_at) + else: + created_at = raw_created_at + + result[conv_id] = ConversationStats( + message_count=msg_count, + last_message_preview=preview, + labels=labels, + created_at=created_at, + ) + + return result + def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ SQLite implementation for filtering ScenarioResults by labels. @@ -589,11 +680,6 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any Returns: Any: A SQLAlchemy exists subquery condition. """ - from sqlalchemy import and_, func - - from pyrit.memory.memory_models import ScenarioResultEntry - - # Return a combined condition that checks ALL labels must be present return and_( *[func.json_extract(ScenarioResultEntry.labels, f"$.{key}") == value for key, value in labels.items()] ) @@ -606,10 +692,6 @@ def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> An Returns: Any: A SQLAlchemy subquery for filtering by target endpoint. """ - from sqlalchemy import func - - from pyrit.memory.memory_models import ScenarioResultEntry - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.endpoint")).like( f"%{endpoint.lower()}%" ) @@ -622,10 +704,6 @@ def _get_scenario_result_target_model_condition(self, *, model_name: str) -> Any Returns: Any: A SQLAlchemy subquery for filtering by target model name. """ - from sqlalchemy import func - - from pyrit.memory.memory_models import ScenarioResultEntry - return func.lower(func.json_extract(ScenarioResultEntry.objective_target_identifier, "$.model_name")).like( f"%{model_name.lower()}%" ) diff --git a/pyrit/models/__init__.py b/pyrit/models/__init__.py index 6f6734cb8..26eeae2d1 100644 --- a/pyrit/models/__init__.py +++ b/pyrit/models/__init__.py @@ -11,6 +11,7 @@ ChatMessagesDataset, ) from pyrit.models.conversation_reference import ConversationReference, ConversationType +from pyrit.models.conversation_stats import ConversationStats from pyrit.models.data_type_serializer import ( AllowedCategories, AudioPathDataTypeSerializer, @@ -70,6 +71,7 @@ "ChatMessageRole", "ChatMessageListDictContent", "ConversationReference", + "ConversationStats", "ConversationType", "construct_response_from_request", "DataTypeSerializer", diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index cd9efff5c..499e3f6de 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -47,6 +47,10 @@ class AttackResult(StrategyResult): # Natural-language description of the attacker's objective objective: str + # Database-assigned unique ID for this AttackResult row. + # ``None`` for newly-constructed results that haven't been persisted yet. + attack_result_id: Optional[str] = None + # Identifier of the attack strategy that produced this result attack_identifier: Optional[ComponentIdentifier] = None diff --git a/pyrit/models/conversation_stats.py b/pyrit/models/conversation_stats.py new file mode 100644 index 000000000..c67f3d842 --- /dev/null +++ b/pyrit/models/conversation_stats.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime +from typing import ClassVar, Dict, Optional + + +@dataclass(frozen=True) +class ConversationStats: + """Lightweight aggregate statistics for a conversation. + + Used to build attack summaries without loading full message pieces. + """ + + PREVIEW_MAX_LEN: ClassVar[int] = 100 + + message_count: int = 0 + last_message_preview: Optional[str] = None + labels: Dict[str, str] = field(default_factory=dict) + created_at: Optional[datetime] = None diff --git a/pyrit/prompt_target/openai/openai_image_target.py b/pyrit/prompt_target/openai/openai_image_target.py index 1c65ed603..ebdbdfe1a 100644 --- a/pyrit/prompt_target/openai/openai_image_target.py +++ b/pyrit/prompt_target/openai/openai_image_target.py @@ -314,6 +314,16 @@ def _validate_request(self, *, message: Message) -> None: other_types = [p.converted_value_data_type for p in other_pieces] raise ValueError(f"The message contains unsupported piece types. Unsupported types: {other_types}.") + request = text_pieces[0] + messages = self._memory.get_conversation(conversation_id=request.conversation_id) + + n_messages = len(messages) + if n_messages > 0: + raise ValueError( + "This target only supports a single turn conversation. " + f"Received: {n_messages} messages which indicates a prior turn." + ) + def is_json_response_supported(self) -> bool: """ Check if the target supports JSON as a response format. diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 774eeb773..c57de2156 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -21,6 +21,7 @@ construct_response_from_request, data_serializer_factory, ) +from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.utils import limit_requests_per_minute from pyrit.prompt_target.openai.openai_target import OpenAITarget @@ -55,7 +56,7 @@ def flatten_transcripts(self) -> str: return "".join(self.transcripts) -class RealtimeTarget(OpenAITarget): +class RealtimeTarget(OpenAITarget, PromptChatTarget): """ A prompt target for Azure OpenAI Realtime API. diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 34ff23b70..5352573d7 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -171,6 +171,11 @@ def _build_identifier(self) -> ComponentIdentifier: Returns: ComponentIdentifier: The identifier for this target instance. """ + specific_params: dict[str, Any] = { + "max_output_tokens": self._max_output_tokens, + } + if self._extra_body_parameters: + specific_params["extra_body_parameters"] = self._extra_body_parameters return self._create_identifier( params={ "temperature": self._temperature, @@ -179,6 +184,7 @@ def _build_identifier(self) -> ComponentIdentifier: "reasoning_effort": self._reasoning_effort, "reasoning_summary": self._reasoning_summary, }, + target_specific_params=specific_params, ) def _set_openai_env_configuration_vars(self) -> None: diff --git a/pyrit/prompt_target/openai/openai_target.py b/pyrit/prompt_target/openai/openai_target.py index bf9c46bf6..e4fb8ecdd 100644 --- a/pyrit/prompt_target/openai/openai_target.py +++ b/pyrit/prompt_target/openai/openai_target.py @@ -29,7 +29,7 @@ handle_bad_request_exception, ) from pyrit.models import Message, MessagePiece -from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget +from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.openai.openai_error_handling import ( _extract_error_payload, _extract_request_id_from_exception, @@ -78,7 +78,7 @@ async def async_token_provider() -> str: return async_token_provider -class OpenAITarget(PromptChatTarget): +class OpenAITarget(PromptTarget): """ Abstract base class for OpenAI-based prompt targets. @@ -159,7 +159,7 @@ def __init__( ) # Initialize parent with endpoint and model_name - PromptChatTarget.__init__( + PromptTarget.__init__( self, max_requests_per_minute=max_requests_per_minute, endpoint=endpoint_value, diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index 276bbfc2c..b6b96956e 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -476,6 +476,16 @@ def _validate_request(self, *, message: Message) -> None: if remix_video_id and image_pieces: raise ValueError("Cannot use image input in remix mode. Remix uses existing video as reference.") + request = message.message_pieces[0] + messages = self._memory.get_conversation(conversation_id=request.conversation_id) + + n_messages = len(messages) + if n_messages > 0: + raise ValueError( + "This target only supports a single turn conversation. " + f"Received: {n_messages} messages which indicates a prior turn." + ) + def is_json_response_supported(self) -> bool: """ Check if the target supports JSON response data. diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index 3106409fd..b97753205 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -126,7 +126,11 @@ def test_get_attack_results_by_ids(sqlite_instance: MemoryInterface): def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface): - """Test retrieving attack results by conversation ID.""" + """Test retrieving attack results by conversation ID. + + When duplicate rows exist for the same conversation_id (legacy bug), + get_attack_results deduplicates and returns only the newest entry. + """ # Create and add attack results attack_result1 = AttackResult( conversation_id="conv_1", @@ -137,7 +141,7 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) ) attack_result2 = AttackResult( - conversation_id="conv_1", # Same conversation ID + conversation_id="conv_1", # Same conversation ID (simulates legacy duplicate) objective="Test objective 2", executed_turns=3, execution_time_ms=500, @@ -155,13 +159,11 @@ def test_get_attack_results_by_conversation_id(sqlite_instance: MemoryInterface) # Add all attack results to memory sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result1, attack_result2, attack_result3]) - # Retrieve attack results by conversation ID + # Retrieve attack results by conversation ID — deduplication keeps only the newest retrieved_results = sqlite_instance.get_attack_results(conversation_id="conv_1") - # Verify correct results were retrieved - assert len(retrieved_results) == 2 - for result in retrieved_results: - assert result.conversation_id == "conv_1" + assert len(retrieved_results) == 1 + assert retrieved_results[0].conversation_id == "conv_1" def test_get_attack_results_by_objective(sqlite_instance: MemoryInterface): @@ -593,6 +595,128 @@ def test_attack_result_without_attack_generation_conversation_ids(sqlite_instanc assert not retrieved_result.get_conversations_by_type(ConversationType.ADVERSARIAL) +def test_update_attack_result_adversarial_chat_conversation_ids_round_trip(sqlite_instance: MemoryInterface): + """Test that updating adversarial_chat_conversation_ids is reflected when reading back. + + This catches a regression where the conversation count in the attack history + was always showing 1 instead of the actual number of conversations. + """ + # Create attack with no related conversations + attack_result = AttackResult( + conversation_id="conv_1", + objective="Test conversation count", + outcome=AttackOutcome.UNDETERMINED, + metadata={"created_at": "2026-01-01T00:00:00", "updated_at": "2026-01-01T00:00:00"}, + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) + + # Verify initial state: no related conversations + results = sqlite_instance.get_attack_results(conversation_id="conv_1") + assert len(results) == 1 + assert len(results[0].related_conversations) == 0 + + # Add first related conversation + sqlite_instance.update_attack_result( + conversation_id="conv_1", + update_fields={"adversarial_chat_conversation_ids": ["branch-1"]}, + ) + + results = sqlite_instance.get_attack_results(conversation_id="conv_1") + assert len(results[0].related_conversations) == 1 + assert {r.conversation_id for r in results[0].related_conversations} == {"branch-1"} + + # Add second related conversation (preserving the first) + sqlite_instance.update_attack_result( + conversation_id="conv_1", + update_fields={"adversarial_chat_conversation_ids": ["branch-1", "branch-2"]}, + ) + + results = sqlite_instance.get_attack_results(conversation_id="conv_1") + assert len(results[0].related_conversations) == 2 + assert {r.conversation_id for r in results[0].related_conversations} == {"branch-1", "branch-2"} + + # Verify they are all ADVERSARIAL type + for ref in results[0].related_conversations: + assert ref.conversation_type == ConversationType.ADVERSARIAL + + +def test_update_attack_result_metadata_does_not_clobber_conversation_ids(sqlite_instance: MemoryInterface): + """Regression test: updating only attack_metadata must not erase adversarial_chat_conversation_ids. + + This was the root cause of the conversation-count bug. The old _update_entries + used session.merge() which copied ALL attributes from the (potentially stale) + detached entry, silently overwriting JSON columns that were not in update_fields. + """ + attack_result = AttackResult( + conversation_id="conv_1", + objective="Test metadata update preserves conversation ids", + outcome=AttackOutcome.UNDETERMINED, + metadata={"created_at": "2026-01-01T00:00:00"}, + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) + + # Step 1: add related conversations + sqlite_instance.update_attack_result( + conversation_id="conv_1", + update_fields={"adversarial_chat_conversation_ids": ["branch-1", "branch-2"]}, + ) + + # Step 2: update ONLY metadata (this is what add_message_async does) + sqlite_instance.update_attack_result( + conversation_id="conv_1", + update_fields={"attack_metadata": {"created_at": "2026-01-01T00:00:00", "updated_at": "2026-01-02T00:00:00"}}, + ) + + # Verify conversation ids are still present + results = sqlite_instance.get_attack_results(conversation_id="conv_1") + assert len(results[0].related_conversations) == 2, ( + "Updating attack_metadata must not erase adversarial_chat_conversation_ids" + ) + assert {r.conversation_id for r in results[0].related_conversations} == {"branch-1", "branch-2"} + + +def test_update_attack_result_stale_entry_does_not_overwrite(sqlite_instance: MemoryInterface): + """Regression test: merging a stale entry must not overwrite concurrent updates. + + Simulates the race condition where entry is loaded, then another update modifies + the DB, and finally the stale entry is used for an unrelated update. + """ + from pyrit.memory.memory_models import AttackResultEntry + + attack_result = AttackResult( + conversation_id="conv_1", + objective="Test stale merge", + outcome=AttackOutcome.UNDETERMINED, + metadata={"created_at": "2026-01-01T00:00:00"}, + ) + sqlite_instance.add_attack_results_to_memory(attack_results=[attack_result]) + + # Load entry (will become stale) + stale_entries = sqlite_instance._query_entries( + AttackResultEntry, conditions=AttackResultEntry.conversation_id == "conv_1" + ) + assert stale_entries[0].adversarial_chat_conversation_ids is None + + # Concurrent update adds conversation ids + sqlite_instance.update_attack_result( + conversation_id="conv_1", + update_fields={"adversarial_chat_conversation_ids": ["branch-1"]}, + ) + + # Now update with the stale entry (only metadata) + sqlite_instance._update_entries( + entries=[stale_entries[0]], + update_fields={"attack_metadata": {"updated_at": "2026-01-02T00:00:00"}}, + ) + + # Verify the concurrent update was NOT lost + results = sqlite_instance.get_attack_results(conversation_id="conv_1") + assert len(results[0].related_conversations) == 1, ( + "Stale entry merge must not overwrite concurrent adversarial_chat_conversation_ids update" + ) + assert results[0].related_conversations.pop().conversation_id == "branch-1" + + def test_get_attack_results_by_harm_category_single(sqlite_instance: MemoryInterface): """Test filtering attack results by a single harm category.""" @@ -1025,60 +1149,60 @@ def _make_attack_result_with_identifier( ) -def test_get_attack_results_by_attack_class(sqlite_instance: MemoryInterface): - """Test filtering attack results by attack_class matches class_name in JSON.""" +def test_get_attack_results_by_attack_type(sqlite_instance: MemoryInterface): + """Test filtering attack results by attack_type matches class_name in JSON.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") + results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack") assert len(results) == 2 assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} -def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInterface): - """Test that attack_class filter returns empty when nothing matches.""" +def test_get_attack_results_by_attack_type_no_match(sqlite_instance: MemoryInterface): + """Test that attack_type filter returns empty when nothing matches.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(attack_class="NonExistentAttack") + results = sqlite_instance.get_attack_results(attack_type="NonExistentAttack") assert len(results) == 0 -def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): - """Test that attack_class filter is case-sensitive (exact match).""" +def test_get_attack_results_by_attack_type_case_sensitive(sqlite_instance: MemoryInterface): + """Test that attack_type filter is case-sensitive (exact match).""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(attack_class="crescendoattack") + results = sqlite_instance.get_attack_results(attack_type="crescendoattack") assert len(results) == 0 -def test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): - """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_class filter.""" +def test_get_attack_results_by_attack_type_no_identifier(sqlite_instance: MemoryInterface): + """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_type filter.""" ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") + results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack") assert len(results) == 1 assert results[0].conversation_id == "conv_2" -def test_get_attack_results_converter_classes_none_returns_all(sqlite_instance: MemoryInterface): - """Test that converter_classes=None (omitted) returns all attacks unfiltered.""" +def test_get_attack_results_converter_types_none_returns_all(sqlite_instance: MemoryInterface): + """Test that converter_types=None (omitted) returns all attacks unfiltered.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack") # No converters (None) ar3 = create_attack_result("conv_3", 3) # No identifier at all sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(converter_classes=None) + results = sqlite_instance.get_attack_results(converter_types=None) assert len(results) == 3 -def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite_instance: MemoryInterface): - """Test that converter_classes=[] returns only attacks with no converters.""" +def test_get_attack_results_converter_types_empty_matches_no_converters(sqlite_instance: MemoryInterface): + """Test that converter_types=[] returns only attacks with no converters.""" ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar_no_conv_none = _make_attack_result_with_identifier("conv_2", "Attack") # converter_ids=None ar_no_conv_empty = _make_attack_result_with_identifier("conv_3", "Attack", []) # converter_ids=[] @@ -1087,7 +1211,7 @@ def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite attack_results=[ar_with_conv, ar_no_conv_none, ar_no_conv_empty, ar_no_identifier] ) - results = sqlite_instance.get_attack_results(converter_classes=[]) + results = sqlite_instance.get_attack_results(converter_types=[]) conv_ids = {r.conversation_id for r in results} # Should include attacks with no converters (None key, empty array, or empty identifier) assert "conv_1" not in conv_ids, "Should not include attacks that have converters" @@ -1096,130 +1220,130 @@ def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite assert "conv_4" in conv_ids, "Should include attacks with empty attack_identifier" -def test_get_attack_results_converter_classes_single_match(sqlite_instance: MemoryInterface): - """Test that converter_classes with one class returns attacks using that converter.""" +def test_get_attack_results_converter_types_single_match(sqlite_instance: MemoryInterface): + """Test that converter_types with one type returns attacks using that converter.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["ROT13Converter"]) ar3 = _make_attack_result_with_identifier("conv_3", "Attack", ["Base64Converter", "ROT13Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(converter_classes=["Base64Converter"]) + results = sqlite_instance.get_attack_results(converter_types=["Base64Converter"]) conv_ids = {r.conversation_id for r in results} assert conv_ids == {"conv_1", "conv_3"} -def test_get_attack_results_converter_classes_and_logic(sqlite_instance: MemoryInterface): - """Test that multiple converter_classes use AND logic — all must be present.""" +def test_get_attack_results_converter_types_and_logic(sqlite_instance: MemoryInterface): + """Test that multiple converter_types use AND logic — all must be present.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["ROT13Converter"]) ar3 = _make_attack_result_with_identifier("conv_3", "Attack", ["Base64Converter", "ROT13Converter"]) ar4 = _make_attack_result_with_identifier("conv_4", "Attack", ["Base64Converter", "ROT13Converter", "UrlConverter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) - results = sqlite_instance.get_attack_results(converter_classes=["Base64Converter", "ROT13Converter"]) + results = sqlite_instance.get_attack_results(converter_types=["Base64Converter", "ROT13Converter"]) conv_ids = {r.conversation_id for r in results} # conv_3 and conv_4 have both; conv_1 and conv_2 have only one assert conv_ids == {"conv_3", "conv_4"} -def test_get_attack_results_converter_classes_case_insensitive(sqlite_instance: MemoryInterface): - """Test that converter class matching is case-insensitive.""" +def test_get_attack_results_converter_types_case_insensitive(sqlite_instance: MemoryInterface): + """Test that converter type matching is case-insensitive.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(converter_classes=["base64converter"]) + results = sqlite_instance.get_attack_results(converter_types=["base64converter"]) assert len(results) == 1 assert results[0].conversation_id == "conv_1" -def test_get_attack_results_converter_classes_no_match(sqlite_instance: MemoryInterface): - """Test that converter_classes filter returns empty when no attack has the converter.""" +def test_get_attack_results_converter_types_no_match(sqlite_instance: MemoryInterface): + """Test that converter_types filter returns empty when no attack has the converter.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(converter_classes=["NonExistentConverter"]) + results = sqlite_instance.get_attack_results(converter_types=["NonExistentConverter"]) assert len(results) == 0 -def test_get_attack_results_attack_class_and_converter_classes_combined(sqlite_instance: MemoryInterface): - """Test combining attack_class and converter_classes filters.""" +def test_get_attack_results_attack_type_and_converter_types_combined(sqlite_instance: MemoryInterface): + """Test combining attack_type and converter_types filters.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack", ["Base64Converter"]) ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack", ["ROT13Converter"]) ar4 = _make_attack_result_with_identifier("conv_4", "CrescendoAttack") # No converters sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=["Base64Converter"]) + results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack", converter_types=["Base64Converter"]) assert len(results) == 1 assert results[0].conversation_id == "conv_1" -def test_get_attack_results_attack_class_with_no_converters(sqlite_instance: MemoryInterface): - """Test combining attack_class with converter_classes=[] (no converters).""" +def test_get_attack_results_attack_type_with_no_converters(sqlite_instance: MemoryInterface): + """Test combining attack_type with converter_types=[] (no converters).""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") # No converters ar3 = _make_attack_result_with_identifier("conv_3", "ManualAttack") # No converters sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=[]) + results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack", converter_types=[]) assert len(results) == 1 assert results[0].conversation_id == "conv_2" # ============================================================================ -# Unique attack class and converter class name tests +# Unique attack type and converter type name tests # ============================================================================ -def test_get_unique_attack_class_names_empty(sqlite_instance: MemoryInterface): +def test_get_unique_attack_type_names_empty(sqlite_instance: MemoryInterface): """Test that no attacks returns empty list.""" - result = sqlite_instance.get_unique_attack_class_names() + result = sqlite_instance.get_unique_attack_type_names() assert result == [] -def test_get_unique_attack_class_names_sorted_unique(sqlite_instance: MemoryInterface): - """Test that unique class names are returned sorted, with duplicates removed.""" +def test_get_unique_attack_type_names_sorted_unique(sqlite_instance: MemoryInterface): + """Test that unique type names are returned sorted, with duplicates removed.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - result = sqlite_instance.get_unique_attack_class_names() + result = sqlite_instance.get_unique_attack_type_names() assert result == ["CrescendoAttack", "ManualAttack"] -def test_get_unique_attack_class_names_skips_empty_identifier(sqlite_instance: MemoryInterface): +def test_get_unique_attack_type_names_skips_empty_identifier(sqlite_instance: MemoryInterface): """Test that attacks with empty attack_identifier (no class_name) are excluded.""" ar_no_id = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} ar_with_id = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar_no_id, ar_with_id]) - result = sqlite_instance.get_unique_attack_class_names() + result = sqlite_instance.get_unique_attack_type_names() assert result == ["CrescendoAttack"] -def test_get_unique_converter_class_names_empty(sqlite_instance: MemoryInterface): +def test_get_unique_converter_type_names_empty(sqlite_instance: MemoryInterface): """Test that no attacks returns empty list.""" - result = sqlite_instance.get_unique_converter_class_names() + result = sqlite_instance.get_unique_converter_type_names() assert result == [] -def test_get_unique_converter_class_names_sorted_unique(sqlite_instance: MemoryInterface): - """Test that unique converter class names are returned sorted, with duplicates removed.""" +def test_get_unique_converter_type_names_sorted_unique(sqlite_instance: MemoryInterface): + """Test that unique converter type names are returned sorted, with duplicates removed.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter", "ROT13Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["Base64Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) - result = sqlite_instance.get_unique_converter_class_names() + result = sqlite_instance.get_unique_converter_type_names() assert result == ["Base64Converter", "ROT13Converter"] -def test_get_unique_converter_class_names_skips_no_converters(sqlite_instance: MemoryInterface): +def test_get_unique_converter_type_names_skips_no_converters(sqlite_instance: MemoryInterface): """Test that attacks with no converters don't contribute names.""" ar_no_conv = _make_attack_result_with_identifier("conv_1", "Attack") # No converters ar_with_conv = _make_attack_result_with_identifier("conv_2", "Attack", ["Base64Converter"]) ar_empty_id = create_attack_result("conv_3", 3) # Empty attack_identifier sqlite_instance.add_attack_results_to_memory(attack_results=[ar_no_conv, ar_with_conv, ar_empty_id]) - result = sqlite_instance.get_unique_converter_class_names() + result = sqlite_instance.get_unique_converter_type_names() assert result == ["Base64Converter"] diff --git a/tests/unit/memory/test_sqlite_memory.py b/tests/unit/memory/test_sqlite_memory.py index f99b72525..de404d233 100644 --- a/tests/unit/memory/test_sqlite_memory.py +++ b/tests/unit/memory/test_sqlite_memory.py @@ -547,3 +547,123 @@ def test_update_prompt_metadata_by_conversation_id(sqlite_instance, sample_conve # Verify that the entry with a different conversation_id was not updated other_entry = session.query(PromptMemoryEntry).filter_by(conversation_id="other_id").first() assert other_entry.prompt_metadata == original_metadata # Metadata should remain unchanged + + +def test_get_conversation_stats_returns_empty_for_no_ids(sqlite_instance): + """Test that get_conversation_stats returns empty dict for empty input.""" + result = sqlite_instance.get_conversation_stats(conversation_ids=[]) + assert result == {} + + +def test_get_conversation_stats_returns_empty_for_unknown_ids(sqlite_instance): + """Test that get_conversation_stats omits unknown conversation IDs.""" + result = sqlite_instance.get_conversation_stats(conversation_ids=["nonexistent"]) + assert result == {} + + +def test_get_conversation_stats_counts_distinct_sequences(sqlite_instance, sample_conversation_entries): + """Test that message_count reflects distinct sequence numbers, not raw rows.""" + # Extract conversation IDs and sequences before inserting (entries get detached after commit) + from pyrit.models import Message + from unit.mocks import get_sample_conversations + + conversations = get_sample_conversations() + pieces = Message.flatten_to_message_pieces(conversations) + expected: dict[str, set[int]] = {} + for p in pieces: + expected.setdefault(p.conversation_id, set()).add(p.sequence) + + sqlite_instance._insert_entries(entries=sample_conversation_entries) + + conv_ids = list(expected.keys()) + result = sqlite_instance.get_conversation_stats(conversation_ids=conv_ids) + + for conv_id in conv_ids: + if conv_id in result: + assert result[conv_id].message_count == len(expected[conv_id]), ( + f"Conv {conv_id}: expected {len(expected[conv_id])}, got {result[conv_id].message_count}" + ) + + +def test_get_conversation_stats_returns_labels(sqlite_instance): + """Test that labels from the first piece with non-empty labels are returned.""" + import uuid + + from pyrit.models import MessagePiece + + conv_id = str(uuid.uuid4()) + piece = MessagePiece( + role="user", + original_value="hello", + original_value_data_type="text", + converted_value="hello", + converted_value_data_type="text", + conversation_id=conv_id, + sequence=0, + labels={"env": "prod", "source": "gui"}, + ) + entry = PromptMemoryEntry(entry=piece) + sqlite_instance._insert_entry(entry) + + result = sqlite_instance.get_conversation_stats(conversation_ids=[conv_id]) + assert conv_id in result + assert result[conv_id].labels == {"env": "prod", "source": "gui"} + + +def test_get_conversation_stats_preview_truncates(sqlite_instance): + """Test that last_message_preview is truncated to 100 chars + ellipsis.""" + import uuid + + from pyrit.models import MessagePiece + + conv_id = str(uuid.uuid4()) + long_text = "x" * 200 + piece = MessagePiece( + role="assistant", + original_value=long_text, + original_value_data_type="text", + converted_value=long_text, + converted_value_data_type="text", + conversation_id=conv_id, + sequence=0, + ) + entry = PromptMemoryEntry(entry=piece) + sqlite_instance._insert_entry(entry) + + result = sqlite_instance.get_conversation_stats(conversation_ids=[conv_id]) + assert conv_id in result + preview = result[conv_id].last_message_preview + assert preview is not None + assert len(preview) == 103 # 100 chars + "..." + assert preview.endswith("...") + + +def test_get_conversation_stats_batches_multiple_conversations(sqlite_instance): + """Test that a single call returns stats for multiple conversations.""" + import uuid + + from pyrit.models import MessagePiece + + conv_ids = [str(uuid.uuid4()) for _ in range(3)] + entries = [] + for i, cid in enumerate(conv_ids): + for seq in range(i + 1): # conv 0: 1 msg, conv 1: 2 msgs, conv 2: 3 msgs + piece = MessagePiece( + role="user", + original_value=f"msg-{seq}", + original_value_data_type="text", + converted_value=f"msg-{seq}", + converted_value_data_type="text", + conversation_id=cid, + sequence=seq, + ) + entries.append(PromptMemoryEntry(entry=piece)) + + sqlite_instance._insert_entries(entries=entries) + + result = sqlite_instance.get_conversation_stats(conversation_ids=conv_ids) + + assert len(result) == 3 + assert result[conv_ids[0]].message_count == 1 + assert result[conv_ids[1]].message_count == 2 + assert result[conv_ids[2]].message_count == 3 diff --git a/tests/unit/target/test_image_target.py b/tests/unit/target/test_image_target.py index ffba3a756..4c5056e24 100644 --- a/tests/unit/target/test_image_target.py +++ b/tests/unit/target/test_image_target.py @@ -504,3 +504,21 @@ async def test_validate_piece_type(image_target: OpenAIImageTarget): finally: if os.path.isfile(audio_piece.original_value): os.remove(audio_piece.original_value) + + +@pytest.mark.asyncio +async def test_validate_previous_conversations( + image_target: OpenAIImageTarget, sample_conversations: MutableSequence[MessagePiece] +): + message_piece = sample_conversations[0] + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = sample_conversations + mock_memory.add_message_to_memory = AsyncMock() + + image_target._memory = mock_memory + + request = Message(message_pieces=[message_piece]) + + with pytest.raises(ValueError, match="This target only supports a single turn conversation."): + await image_target.send_prompt_async(message=request) diff --git a/tests/unit/target/test_video_target.py b/tests/unit/target/test_video_target.py index eab0d81ac..877bce7d6 100644 --- a/tests/unit/target/test_video_target.py +++ b/tests/unit/target/test_video_target.py @@ -910,3 +910,20 @@ def test_supported_durations(self, video_target: OpenAIVideoTarget): n_seconds=duration, ) assert target._n_seconds == duration + + +def test_video_validate_previous_conversations( + video_target: OpenAIVideoTarget, sample_conversations: MutableSequence[MessagePiece] +): + message_piece = sample_conversations[0] + + mock_memory = MagicMock() + mock_memory.get_conversation.return_value = sample_conversations + mock_memory.add_message_to_memory = AsyncMock() + + video_target._memory = mock_memory + + request = Message(message_pieces=[message_piece]) + + with pytest.raises(ValueError, match="This target only supports a single turn conversation."): + video_target._validate_request(message=request) From bb16333542589b8ac7bf175f2101ba45185ac2a3 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 2 Mar 2026 12:54:31 -0800 Subject: [PATCH 2/6] Split type/class naming: API uses *_type, memory layer uses *_class Frontend-facing APIs (routes, service, response models) use attack_type and converter_types so the frontend doesn't need to know about classes. The memory layer (memory_interface, sqlite_memory, azure_sql_memory) keeps attack_class and converter_classes since they reference Python classes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/backend/models/attacks.py | 12 +-- pyrit/backend/routes/attacks.py | 26 ++--- pyrit/backend/services/attack_service.py | 12 +-- pyrit/memory/azure_sql_memory.py | 28 ++--- pyrit/memory/memory_interface.py | 52 ++++----- pyrit/memory/sqlite_memory.py | 28 ++--- tests/unit/backend/test_api_routes.py | 18 ++-- tests/unit/backend/test_attack_service.py | 26 ++--- .../test_interface_attack_results.py | 102 +++++++++--------- 9 files changed, 151 insertions(+), 153 deletions(-) diff --git a/pyrit/backend/models/attacks.py b/pyrit/backend/models/attacks.py index 9183d933c..fb40e0656 100644 --- a/pyrit/backend/models/attacks.py +++ b/pyrit/backend/models/attacks.py @@ -117,18 +117,16 @@ class AttackListResponse(BaseModel): class AttackOptionsResponse(BaseModel): - """Response containing unique attack class names used across attacks.""" + """Response containing unique attack type names used across attacks.""" - attack_classes: list[str] = Field( - ..., description="Sorted list of unique attack class names found in attack results" - ) + attack_types: list[str] = Field(..., description="Sorted list of unique attack type names found in attack results") class ConverterOptionsResponse(BaseModel): - """Response containing unique converter class names used across attacks.""" + """Response containing unique converter type names used across attacks.""" - converter_classes: list[str] = Field( - ..., description="Sorted list of unique converter class names found in attack results" + converter_types: list[str] = Field( + ..., description="Sorted list of unique converter type names found in attack results" ) diff --git a/pyrit/backend/routes/attacks.py b/pyrit/backend/routes/attacks.py index ed6fc4c02..6c08400e4 100644 --- a/pyrit/backend/routes/attacks.py +++ b/pyrit/backend/routes/attacks.py @@ -52,10 +52,10 @@ def _parse_labels(label_params: Optional[list[str]]) -> Optional[dict[str, str]] response_model=AttackListResponse, ) async def list_attacks( - attack_class: Optional[str] = Query(None, description="Filter by exact attack class name"), - converter_classes: Optional[list[str]] = Query( + attack_type: Optional[str] = Query(None, description="Filter by exact attack type name"), + converter_types: Optional[list[str]] = Query( None, - description="Filter by converter class names (repeatable, AND logic). Pass empty to match no-converter attacks.", + description="Filter by converter type names (repeatable, AND logic). Pass empty to match no-converter attacks.", ), outcome: Optional[Literal["undetermined", "success", "failure"]] = Query(None, description="Filter by outcome"), label: Optional[list[str]] = Query(None, description="Filter by labels (format: key:value, repeatable)"), @@ -76,8 +76,8 @@ async def list_attacks( service = get_attack_service() labels = _parse_labels(label) return await service.list_attacks_async( - attack_class=attack_class, - converter_classes=converter_classes, + attack_type=attack_type, + converter_types=converter_types, outcome=outcome, labels=labels, min_turns=min_turns, @@ -93,17 +93,17 @@ async def list_attacks( ) async def get_attack_options() -> AttackOptionsResponse: """ - Get unique attack class names used across all attacks. + Get unique attack type names used across all attacks. - Returns all attack class names found in stored attack results. + Returns all attack type names found in stored attack results. Useful for populating attack type filter dropdowns in the GUI. Returns: - AttackOptionsResponse: Sorted list of unique attack class names. + AttackOptionsResponse: Sorted list of unique attack type names. """ service = get_attack_service() class_names = await service.get_attack_options_async() - return AttackOptionsResponse(attack_classes=class_names) + return AttackOptionsResponse(attack_types=class_names) @router.get( @@ -112,17 +112,17 @@ async def get_attack_options() -> AttackOptionsResponse: ) async def get_converter_options() -> ConverterOptionsResponse: """ - Get unique converter class names used across all attacks. + Get unique converter type names used across all attacks. - Returns all converter class names found in stored attack results. + Returns all converter type names found in stored attack results. Useful for populating converter filter dropdowns in the GUI. Returns: - ConverterOptionsResponse: Sorted list of unique converter class names. + ConverterOptionsResponse: Sorted list of unique converter type names. """ service = get_attack_service() class_names = await service.get_converter_options_async() - return ConverterOptionsResponse(converter_classes=class_names) + return ConverterOptionsResponse(converter_types=class_names) @router.post( diff --git a/pyrit/backend/services/attack_service.py b/pyrit/backend/services/attack_service.py index 652cafee5..badc3ffe0 100644 --- a/pyrit/backend/services/attack_service.py +++ b/pyrit/backend/services/attack_service.py @@ -63,8 +63,8 @@ def __init__(self) -> None: async def list_attacks_async( self, *, - attack_class: Optional[str] = None, - converter_classes: Optional[list[str]] = None, + attack_type: Optional[str] = None, + converter_types: Optional[list[str]] = None, outcome: Optional[Literal["undetermined", "success", "failure"]] = None, labels: Optional[dict[str, str]] = None, min_turns: Optional[int] = None, @@ -78,8 +78,8 @@ async def list_attacks_async( Queries AttackResult entries from the database. Args: - attack_class: Filter by exact attack class_name (case-sensitive). - converter_classes: Filter by converter usage. + attack_type: Filter by exact attack class_name (case-sensitive). + converter_types: Filter by converter usage. None = no filter, [] = only attacks with no converters, ["A", "B"] = only attacks using ALL specified converters (AND logic, case-insensitive). outcome: Filter by attack outcome. @@ -96,8 +96,8 @@ async def list_attacks_async( attack_results = self._memory.get_attack_results( outcome=outcome, labels=labels, - attack_class=attack_class, - converter_classes=converter_classes, + attack_class=attack_type, + converter_classes=converter_types, ) filtered: list[AttackResult] = [] diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index d44119749..75e1b2b4a 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -388,37 +388,37 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) ) - def _get_attack_result_attack_type_condition(self, *, attack_type: str) -> Any: + def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: """ - Azure SQL implementation for filtering AttackResults by attack type. + Azure SQL implementation for filtering AttackResults by attack class. Uses JSON_VALUE() to match class_name in the attack_identifier JSON column. Args: - attack_type (str): Exact attack type name to match. + attack_class (str): Exact attack class name to match. Returns: Any: SQLAlchemy text condition with bound parameter. """ return text( """ISJSON("AttackResultEntries".attack_identifier) = 1 - AND JSON_VALUE("AttackResultEntries".attack_identifier, '$.class_name') = :attack_type""" - ).bindparams(attack_type=attack_type) + AND JSON_VALUE("AttackResultEntries".attack_identifier, '$.class_name') = :attack_class""" + ).bindparams(attack_class=attack_class) - def _get_attack_result_converter_types_condition(self, *, converter_types: Sequence[str]) -> Any: + def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: """ - Azure SQL implementation for filtering AttackResults by converter types. + Azure SQL implementation for filtering AttackResults by converter classes. - When converter_types is empty, matches attacks with no converters. - When non-empty, uses OPENJSON() to check all specified types are present + When converter_classes is empty, matches attacks with no converters. + When non-empty, uses OPENJSON() to check all specified classes are present (AND logic, case-insensitive). Args: - converter_types (Sequence[str]): List of converter type names. Empty list means no converters. + converter_classes (Sequence[str]): List of converter class names. Empty list means no converters. Returns: Any: SQLAlchemy combined condition with bound parameters. """ - if len(converter_types) == 0: + if len(converter_classes) == 0: # Explicitly "no converters": match attacks where the converter list # is absent, null, or empty in the stored JSON. return text( @@ -430,7 +430,7 @@ def _get_attack_result_converter_types_condition(self, *, converter_types: Seque conditions = [] bindparams_dict: dict[str, str] = {} - for i, cls in enumerate(converter_types): + for i, cls in enumerate(converter_classes): param_name = f"conv_cls_{i}" conditions.append( f'EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".attack_identifier, ' @@ -444,7 +444,7 @@ def _get_attack_result_converter_types_condition(self, *, converter_types: Seque **bindparams_dict ) - def get_unique_attack_type_names(self) -> list[str]: + def get_unique_attack_class_names(self) -> list[str]: """ Azure SQL implementation: extract unique class_name values from attack_identifier JSON. @@ -462,7 +462,7 @@ def get_unique_attack_type_names(self) -> list[str]: ).fetchall() return sorted(row[0] for row in rows) - def get_unique_converter_type_names(self) -> list[str]: + def get_unique_converter_class_names(self) -> list[str]: """ Azure SQL implementation: extract unique converter class_name values from the request_converter_identifiers array in attack_identifier JSON. diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index c9c3753da..4e7a19aa1 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -291,33 +291,33 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @abc.abstractmethod - def _get_attack_result_attack_type_condition(self, *, attack_type: str) -> Any: + def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: """ - Return a database-specific condition for filtering AttackResults by attack type + Return a database-specific condition for filtering AttackResults by attack class (class_name in the attack_identifier JSON column). Args: - attack_type: Exact attack type name to match. + attack_class: Exact attack class name to match. Returns: Database-specific SQLAlchemy condition. """ @abc.abstractmethod - def _get_attack_result_converter_types_condition(self, *, converter_types: Sequence[str]) -> Any: + def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: """ - Return a database-specific condition for filtering AttackResults by converter types + Return a database-specific condition for filtering AttackResults by converter classes in the request_converter_identifiers array within attack_identifier JSON column. - This method is only called when converter filtering is requested (converter_types + This method is only called when converter filtering is requested (converter_classes is not None). The caller handles the None-vs-list distinction: - - ``len(converter_types) == 0``: return a condition matching attacks with NO converters. - - ``len(converter_types) > 0``: return a condition requiring ALL specified converter - type names to be present (AND logic, case-insensitive). + - ``len(converter_classes) == 0``: return a condition matching attacks with NO converters. + - ``len(converter_classes) > 0``: return a condition requiring ALL specified converter + class names to be present (AND logic, case-insensitive). Args: - converter_types: Converter type names to require. An empty sequence means + converter_classes: Converter class names to require. An empty sequence means "match only attacks that have no converters". Returns: @@ -325,27 +325,27 @@ def _get_attack_result_converter_types_condition(self, *, converter_types: Seque """ @abc.abstractmethod - def get_unique_attack_type_names(self) -> list[str]: + def get_unique_attack_class_names(self) -> list[str]: """ - Return sorted unique attack type names from all stored attack results. + Return sorted unique attack class names from all stored attack results. Extracts class_name from the attack_identifier JSON column via a database-level DISTINCT query. Returns: - Sorted list of unique attack type name strings. + Sorted list of unique attack class name strings. """ @abc.abstractmethod - def get_unique_converter_type_names(self) -> list[str]: + def get_unique_converter_class_names(self) -> list[str]: """ - Return sorted unique converter type names used across all attack results. + Return sorted unique converter class names used across all attack results. Extracts class_name values from the request_converter_identifiers array within the attack_identifier JSON column via a database-level query. Returns: - Sorted list of unique converter type name strings. + Sorted list of unique converter class name strings. """ @abc.abstractmethod @@ -1364,8 +1364,8 @@ def get_attack_results( objective: Optional[str] = None, objective_sha256: Optional[Sequence[str]] = None, outcome: Optional[str] = None, - attack_type: Optional[str] = None, - converter_types: Optional[Sequence[str]] = None, + attack_class: Optional[str] = None, + converter_classes: Optional[Sequence[str]] = None, targeted_harm_categories: Optional[Sequence[str]] = None, labels: Optional[dict[str, str]] = None, ) -> Sequence[AttackResult]: @@ -1380,9 +1380,9 @@ def get_attack_results( Defaults to None. outcome (Optional[str], optional): The outcome to filter by (success, failure, undetermined). Defaults to None. - attack_type (Optional[str], optional): Filter by exact attack class_name in attack_identifier. + attack_class (Optional[str], optional): Filter by exact attack class_name in attack_identifier. Defaults to None. - converter_types (Optional[Sequence[str]], optional): Filter by converter type names. + converter_classes (Optional[Sequence[str]], optional): Filter by converter class names. Returns only attacks that used ALL specified converters (AND logic, case-insensitive). Defaults to None. targeted_harm_categories (Optional[Sequence[str]], optional): @@ -1416,14 +1416,14 @@ def get_attack_results( if outcome: conditions.append(AttackResultEntry.outcome == outcome) - if attack_type: + if attack_class: # Use database-specific JSON query method - conditions.append(self._get_attack_result_attack_type_condition(attack_type=attack_type)) + conditions.append(self._get_attack_result_attack_class_condition(attack_class=attack_class)) - if converter_types is not None: - # converter_types=[] means "only attacks with no converters" - # converter_types=["A","B"] means "must have all listed converters" - conditions.append(self._get_attack_result_converter_types_condition(converter_types=converter_types)) + if converter_classes is not None: + # converter_classes=[] means "only attacks with no converters" + # converter_classes=["A","B"] means "must have all listed converters" + conditions.append(self._get_attack_result_converter_classes_condition(converter_classes=converter_classes)) if targeted_harm_categories: # Use database-specific JSON query method diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 375c348e9..15c89bf95 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Optional, TypeVar, Union -from sqlalchemy import and_, create_engine, exists, func, or_, text +from sqlalchemy import and_, create_engine, func, or_, text from sqlalchemy.engine.base import Engine from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload, sessionmaker @@ -508,29 +508,29 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ) return labels_subquery - def _get_attack_result_attack_type_condition(self, *, attack_type: str) -> Any: + def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: """ - SQLite implementation for filtering AttackResults by attack type. + SQLite implementation for filtering AttackResults by attack class. Uses json_extract() to match class_name in the attack_identifier JSON column. Returns: - Any: A SQLAlchemy condition for filtering by attack type. + Any: A SQLAlchemy condition for filtering by attack class. """ - return func.json_extract(AttackResultEntry.attack_identifier, "$.class_name") == attack_type + return func.json_extract(AttackResultEntry.attack_identifier, "$.class_name") == attack_class - def _get_attack_result_converter_types_condition(self, *, converter_types: Sequence[str]) -> Any: + def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any: """ - SQLite implementation for filtering AttackResults by converter types. + SQLite implementation for filtering AttackResults by converter classes. - When converter_types is empty, matches attacks with no converters + When converter_classes is empty, matches attacks with no converters (request_converter_identifiers is absent or null in the JSON). - When non-empty, uses json_each() to check all specified types are present + When non-empty, uses json_each() to check all specified classes are present (AND logic, case-insensitive). Returns: - Any: A SQLAlchemy condition for filtering by converter types. + Any: A SQLAlchemy condition for filtering by converter classes. """ - if len(converter_types) == 0: + if len(converter_classes) == 0: # Explicitly "no converters": match attacks where the converter list # is absent, null, or empty in the stored JSON. converter_json = func.json_extract(AttackResultEntry.attack_identifier, "$.request_converter_identifiers") @@ -542,7 +542,7 @@ def _get_attack_result_converter_types_condition(self, *, converter_types: Seque ) conditions = [] - for i, cls in enumerate(converter_types): + for i, cls in enumerate(converter_classes): param_name = f"conv_cls_{i}" conditions.append( text( @@ -553,7 +553,7 @@ def _get_attack_result_converter_types_condition(self, *, converter_types: Seque ) return and_(*conditions) - def get_unique_attack_type_names(self) -> list[str]: + def get_unique_attack_class_names(self) -> list[str]: """ SQLite implementation: extract unique class_name values from attack_identifier JSON. @@ -569,7 +569,7 @@ def get_unique_attack_type_names(self) -> list[str]: ) return sorted(row[0] for row in rows) - def get_unique_converter_type_names(self) -> list[str]: + def get_unique_converter_class_names(self) -> list[str]: """ SQLite implementation: extract unique converter class_name values from the request_converter_identifiers array in attack_identifier JSON. diff --git a/tests/unit/backend/test_api_routes.py b/tests/unit/backend/test_api_routes.py index 44ad3f34b..3725dc21c 100644 --- a/tests/unit/backend/test_api_routes.py +++ b/tests/unit/backend/test_api_routes.py @@ -86,13 +86,13 @@ def test_list_attacks_with_filters(self, client: TestClient) -> None: response = client.get( "/api/attacks", - params={"attack_class": "CrescendoAttack", "outcome": "success", "limit": 10}, + params={"attack_type": "CrescendoAttack", "outcome": "success", "limit": 10}, ) assert response.status_code == status.HTTP_200_OK mock_service.list_attacks_async.assert_called_once_with( - attack_class="CrescendoAttack", - converter_classes=None, + attack_type="CrescendoAttack", + converter_types=None, outcome="success", labels=None, min_turns=None, @@ -416,7 +416,7 @@ def test_get_attack_options(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["attack_classes"] == ["CrescendoAttack", "ManualAttack"] + assert data["attack_types"] == ["CrescendoAttack", "ManualAttack"] def test_get_converter_options(self, client: TestClient) -> None: """Test getting converter options from attack results.""" @@ -429,7 +429,7 @@ def test_get_converter_options(self, client: TestClient) -> None: assert response.status_code == status.HTTP_200_OK data = response.json() - assert data["converter_classes"] == ["Base64Converter", "ROT13Converter"] + assert data["converter_types"] == ["Base64Converter", "ROT13Converter"] def test_parse_labels_skips_param_without_colon(self, client: TestClient) -> None: """Test that _parse_labels skips label params that have no colon.""" @@ -486,8 +486,8 @@ def test_parse_labels_value_with_extra_colons(self, client: TestClient) -> None: call_kwargs = mock_service.list_attacks_async.call_args[1] assert call_kwargs["labels"] == {"url": "http://example.com:8080"} - def test_list_attacks_forwards_converter_classes_param(self, client: TestClient) -> None: - """Test that converter_classes query params are forwarded to service.""" + def test_list_attacks_forwards_converter_types_param(self, client: TestClient) -> None: + """Test that converter_types query params are forwarded to service.""" with patch("pyrit.backend.routes.attacks.get_attack_service") as mock_get_service: mock_service = MagicMock() mock_service.list_attacks_async = AsyncMock( @@ -498,11 +498,11 @@ def test_list_attacks_forwards_converter_classes_param(self, client: TestClient) ) mock_get_service.return_value = mock_service - response = client.get("/api/attacks?converter_classes=Base64&converter_classes=ROT13") + response = client.get("/api/attacks?converter_types=Base64&converter_types=ROT13") assert response.status_code == status.HTTP_200_OK call_kwargs = mock_service.list_attacks_async.call_args[1] - assert call_kwargs["converter_classes"] == ["Base64", "ROT13"] + assert call_kwargs["converter_types"] == ["Base64", "ROT13"] # ============================================================================ diff --git a/tests/unit/backend/test_attack_service.py b/tests/unit/backend/test_attack_service.py index e184f8ad4..7ed17497e 100644 --- a/tests/unit/backend/test_attack_service.py +++ b/tests/unit/backend/test_attack_service.py @@ -173,45 +173,45 @@ async def test_list_attacks_returns_attacks(self, attack_service, mock_memory) - assert result.items[0].attack_type == "Test Attack" @pytest.mark.asyncio - async def test_list_attacks_filters_by_attack_class_exact(self, attack_service, mock_memory) -> None: - """Test that list_attacks passes attack_class to memory layer.""" + async def test_list_attacks_filters_by_attack_type_exact(self, attack_service, mock_memory) -> None: + """Test that list_attacks passes attack_type to memory layer as attack_class.""" ar1 = make_attack_result(conversation_id="attack-1", name="CrescendoAttack") mock_memory.get_attack_results.return_value = [ar1] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks_async(attack_class="CrescendoAttack") + result = await attack_service.list_attacks_async(attack_type="CrescendoAttack") assert len(result.items) == 1 assert result.items[0].conversation_id == "attack-1" - # Verify attack_class was forwarded to the memory layer + # Verify attack_type was forwarded to the memory layer as attack_class call_kwargs = mock_memory.get_attack_results.call_args[1] assert call_kwargs["attack_class"] == "CrescendoAttack" @pytest.mark.asyncio - async def test_list_attacks_attack_class_passed_to_memory(self, attack_service, mock_memory) -> None: - """Test that attack_class is forwarded to memory for DB-level filtering.""" + async def test_list_attacks_attack_type_passed_to_memory(self, attack_service, mock_memory) -> None: + """Test that attack_type is forwarded to memory as attack_class for DB-level filtering.""" mock_memory.get_attack_results.return_value = [] mock_memory.get_message_pieces.return_value = [] - await attack_service.list_attacks_async(attack_class="Crescendo") + await attack_service.list_attacks_async(attack_type="Crescendo") call_kwargs = mock_memory.get_attack_results.call_args[1] assert call_kwargs["attack_class"] == "Crescendo" @pytest.mark.asyncio async def test_list_attacks_filters_by_no_converters(self, attack_service, mock_memory) -> None: - """Test that converter_classes=[] is forwarded to memory for DB-level filtering.""" + """Test that converter_types=[] is forwarded to memory for DB-level filtering.""" mock_memory.get_attack_results.return_value = [] mock_memory.get_message_pieces.return_value = [] - await attack_service.list_attacks_async(converter_classes=[]) + await attack_service.list_attacks_async(converter_types=[]) call_kwargs = mock_memory.get_attack_results.call_args[1] assert call_kwargs["converter_classes"] == [] @pytest.mark.asyncio - async def test_list_attacks_filters_by_converter_classes_and_logic(self, attack_service, mock_memory) -> None: - """Test that list_attacks passes converter_classes to memory layer.""" + async def test_list_attacks_filters_by_converter_types_and_logic(self, attack_service, mock_memory) -> None: + """Test that list_attacks passes converter_types to memory layer.""" ar1 = make_attack_result(conversation_id="attack-1", name="Attack One") ar1.attack_identifier = ComponentIdentifier( class_name="Attack One", @@ -240,11 +240,11 @@ async def test_list_attacks_filters_by_converter_classes_and_logic(self, attack_ mock_memory.get_attack_results.return_value = [ar1] mock_memory.get_message_pieces.return_value = [] - result = await attack_service.list_attacks_async(converter_classes=["Base64Converter", "ROT13Converter"]) + result = await attack_service.list_attacks_async(converter_types=["Base64Converter", "ROT13Converter"]) assert len(result.items) == 1 assert result.items[0].conversation_id == "attack-1" - # Verify converter_classes was forwarded to the memory layer + # Verify converter_types was forwarded to the memory layer as converter_classes call_kwargs = mock_memory.get_attack_results.call_args[1] assert call_kwargs["converter_classes"] == ["Base64Converter", "ROT13Converter"] diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index b97753205..fe560d429 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -1149,60 +1149,60 @@ def _make_attack_result_with_identifier( ) -def test_get_attack_results_by_attack_type(sqlite_instance: MemoryInterface): - """Test filtering attack results by attack_type matches class_name in JSON.""" +def test_get_attack_results_by_attack_class(sqlite_instance: MemoryInterface): + """Test filtering attack results by attack_class matches class_name in JSON.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack") + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") assert len(results) == 2 assert {r.conversation_id for r in results} == {"conv_1", "conv_3"} -def test_get_attack_results_by_attack_type_no_match(sqlite_instance: MemoryInterface): - """Test that attack_type filter returns empty when nothing matches.""" +def test_get_attack_results_by_attack_class_no_match(sqlite_instance: MemoryInterface): + """Test that attack_class filter returns empty when nothing matches.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(attack_type="NonExistentAttack") + results = sqlite_instance.get_attack_results(attack_class="NonExistentAttack") assert len(results) == 0 -def test_get_attack_results_by_attack_type_case_sensitive(sqlite_instance: MemoryInterface): - """Test that attack_type filter is case-sensitive (exact match).""" +def test_get_attack_results_by_attack_class_case_sensitive(sqlite_instance: MemoryInterface): + """Test that attack_class filter is case-sensitive (exact match).""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(attack_type="crescendoattack") + results = sqlite_instance.get_attack_results(attack_class="crescendoattack") assert len(results) == 0 -def test_get_attack_results_by_attack_type_no_identifier(sqlite_instance: MemoryInterface): - """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_type filter.""" +def test_get_attack_results_by_attack_class_no_identifier(sqlite_instance: MemoryInterface): + """Test that attacks with no attack_identifier (empty JSON) are excluded by attack_class filter.""" ar1 = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) - results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack") + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack") assert len(results) == 1 assert results[0].conversation_id == "conv_2" -def test_get_attack_results_converter_types_none_returns_all(sqlite_instance: MemoryInterface): - """Test that converter_types=None (omitted) returns all attacks unfiltered.""" +def test_get_attack_results_converter_classes_none_returns_all(sqlite_instance: MemoryInterface): + """Test that converter_classes=None (omitted) returns all attacks unfiltered.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack") # No converters (None) ar3 = create_attack_result("conv_3", 3) # No identifier at all sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(converter_types=None) + results = sqlite_instance.get_attack_results(converter_classes=None) assert len(results) == 3 -def test_get_attack_results_converter_types_empty_matches_no_converters(sqlite_instance: MemoryInterface): - """Test that converter_types=[] returns only attacks with no converters.""" +def test_get_attack_results_converter_classes_empty_matches_no_converters(sqlite_instance: MemoryInterface): + """Test that converter_classes=[] returns only attacks with no converters.""" ar_with_conv = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar_no_conv_none = _make_attack_result_with_identifier("conv_2", "Attack") # converter_ids=None ar_no_conv_empty = _make_attack_result_with_identifier("conv_3", "Attack", []) # converter_ids=[] @@ -1211,7 +1211,7 @@ def test_get_attack_results_converter_types_empty_matches_no_converters(sqlite_i attack_results=[ar_with_conv, ar_no_conv_none, ar_no_conv_empty, ar_no_identifier] ) - results = sqlite_instance.get_attack_results(converter_types=[]) + results = sqlite_instance.get_attack_results(converter_classes=[]) conv_ids = {r.conversation_id for r in results} # Should include attacks with no converters (None key, empty array, or empty identifier) assert "conv_1" not in conv_ids, "Should not include attacks that have converters" @@ -1220,130 +1220,130 @@ def test_get_attack_results_converter_types_empty_matches_no_converters(sqlite_i assert "conv_4" in conv_ids, "Should include attacks with empty attack_identifier" -def test_get_attack_results_converter_types_single_match(sqlite_instance: MemoryInterface): - """Test that converter_types with one type returns attacks using that converter.""" +def test_get_attack_results_converter_classes_single_match(sqlite_instance: MemoryInterface): + """Test that converter_classes with one class returns attacks using that converter.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["ROT13Converter"]) ar3 = _make_attack_result_with_identifier("conv_3", "Attack", ["Base64Converter", "ROT13Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(converter_types=["Base64Converter"]) + results = sqlite_instance.get_attack_results(converter_classes=["Base64Converter"]) conv_ids = {r.conversation_id for r in results} assert conv_ids == {"conv_1", "conv_3"} -def test_get_attack_results_converter_types_and_logic(sqlite_instance: MemoryInterface): - """Test that multiple converter_types use AND logic — all must be present.""" +def test_get_attack_results_converter_classes_and_logic(sqlite_instance: MemoryInterface): + """Test that multiple converter_classes use AND logic — all must be present.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["ROT13Converter"]) ar3 = _make_attack_result_with_identifier("conv_3", "Attack", ["Base64Converter", "ROT13Converter"]) ar4 = _make_attack_result_with_identifier("conv_4", "Attack", ["Base64Converter", "ROT13Converter", "UrlConverter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) - results = sqlite_instance.get_attack_results(converter_types=["Base64Converter", "ROT13Converter"]) + results = sqlite_instance.get_attack_results(converter_classes=["Base64Converter", "ROT13Converter"]) conv_ids = {r.conversation_id for r in results} # conv_3 and conv_4 have both; conv_1 and conv_2 have only one assert conv_ids == {"conv_3", "conv_4"} -def test_get_attack_results_converter_types_case_insensitive(sqlite_instance: MemoryInterface): - """Test that converter type matching is case-insensitive.""" +def test_get_attack_results_converter_classes_case_insensitive(sqlite_instance: MemoryInterface): + """Test that converter class matching is case-insensitive.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(converter_types=["base64converter"]) + results = sqlite_instance.get_attack_results(converter_classes=["base64converter"]) assert len(results) == 1 assert results[0].conversation_id == "conv_1" -def test_get_attack_results_converter_types_no_match(sqlite_instance: MemoryInterface): - """Test that converter_types filter returns empty when no attack has the converter.""" +def test_get_attack_results_converter_classes_no_match(sqlite_instance: MemoryInterface): + """Test that converter_classes filter returns empty when no attack has the converter.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1]) - results = sqlite_instance.get_attack_results(converter_types=["NonExistentConverter"]) + results = sqlite_instance.get_attack_results(converter_classes=["NonExistentConverter"]) assert len(results) == 0 -def test_get_attack_results_attack_type_and_converter_types_combined(sqlite_instance: MemoryInterface): - """Test combining attack_type and converter_types filters.""" +def test_get_attack_results_attack_class_and_converter_classes_combined(sqlite_instance: MemoryInterface): + """Test combining attack_class and converter_classes filters.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack", ["Base64Converter"]) ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack", ["ROT13Converter"]) ar4 = _make_attack_result_with_identifier("conv_4", "CrescendoAttack") # No converters sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3, ar4]) - results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack", converter_types=["Base64Converter"]) + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=["Base64Converter"]) assert len(results) == 1 assert results[0].conversation_id == "conv_1" -def test_get_attack_results_attack_type_with_no_converters(sqlite_instance: MemoryInterface): - """Test combining attack_type with converter_types=[] (no converters).""" +def test_get_attack_results_attack_class_with_no_converters(sqlite_instance: MemoryInterface): + """Test combining attack_class with converter_classes=[] (no converters).""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack", ["Base64Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") # No converters ar3 = _make_attack_result_with_identifier("conv_3", "ManualAttack") # No converters sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - results = sqlite_instance.get_attack_results(attack_type="CrescendoAttack", converter_types=[]) + results = sqlite_instance.get_attack_results(attack_class="CrescendoAttack", converter_classes=[]) assert len(results) == 1 assert results[0].conversation_id == "conv_2" # ============================================================================ -# Unique attack type and converter type name tests +# Unique attack class and converter class name tests # ============================================================================ -def test_get_unique_attack_type_names_empty(sqlite_instance: MemoryInterface): +def test_get_unique_attack_class_names_empty(sqlite_instance: MemoryInterface): """Test that no attacks returns empty list.""" - result = sqlite_instance.get_unique_attack_type_names() + result = sqlite_instance.get_unique_attack_class_names() assert result == [] -def test_get_unique_attack_type_names_sorted_unique(sqlite_instance: MemoryInterface): - """Test that unique type names are returned sorted, with duplicates removed.""" +def test_get_unique_attack_class_names_sorted_unique(sqlite_instance: MemoryInterface): + """Test that unique class names are returned sorted, with duplicates removed.""" ar1 = _make_attack_result_with_identifier("conv_1", "CrescendoAttack") ar2 = _make_attack_result_with_identifier("conv_2", "ManualAttack") ar3 = _make_attack_result_with_identifier("conv_3", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2, ar3]) - result = sqlite_instance.get_unique_attack_type_names() + result = sqlite_instance.get_unique_attack_class_names() assert result == ["CrescendoAttack", "ManualAttack"] -def test_get_unique_attack_type_names_skips_empty_identifier(sqlite_instance: MemoryInterface): +def test_get_unique_attack_class_names_skips_empty_identifier(sqlite_instance: MemoryInterface): """Test that attacks with empty attack_identifier (no class_name) are excluded.""" ar_no_id = create_attack_result("conv_1", 1) # No attack_identifier → stored as {} ar_with_id = _make_attack_result_with_identifier("conv_2", "CrescendoAttack") sqlite_instance.add_attack_results_to_memory(attack_results=[ar_no_id, ar_with_id]) - result = sqlite_instance.get_unique_attack_type_names() + result = sqlite_instance.get_unique_attack_class_names() assert result == ["CrescendoAttack"] -def test_get_unique_converter_type_names_empty(sqlite_instance: MemoryInterface): +def test_get_unique_converter_class_names_empty(sqlite_instance: MemoryInterface): """Test that no attacks returns empty list.""" - result = sqlite_instance.get_unique_converter_type_names() + result = sqlite_instance.get_unique_converter_class_names() assert result == [] -def test_get_unique_converter_type_names_sorted_unique(sqlite_instance: MemoryInterface): - """Test that unique converter type names are returned sorted, with duplicates removed.""" +def test_get_unique_converter_class_names_sorted_unique(sqlite_instance: MemoryInterface): + """Test that unique converter class names are returned sorted, with duplicates removed.""" ar1 = _make_attack_result_with_identifier("conv_1", "Attack", ["Base64Converter", "ROT13Converter"]) ar2 = _make_attack_result_with_identifier("conv_2", "Attack", ["Base64Converter"]) sqlite_instance.add_attack_results_to_memory(attack_results=[ar1, ar2]) - result = sqlite_instance.get_unique_converter_type_names() + result = sqlite_instance.get_unique_converter_class_names() assert result == ["Base64Converter", "ROT13Converter"] -def test_get_unique_converter_type_names_skips_no_converters(sqlite_instance: MemoryInterface): +def test_get_unique_converter_class_names_skips_no_converters(sqlite_instance: MemoryInterface): """Test that attacks with no converters don't contribute names.""" ar_no_conv = _make_attack_result_with_identifier("conv_1", "Attack") # No converters ar_with_conv = _make_attack_result_with_identifier("conv_2", "Attack", ["Base64Converter"]) ar_empty_id = create_attack_result("conv_3", 3) # Empty attack_identifier sqlite_instance.add_attack_results_to_memory(attack_results=[ar_no_conv, ar_with_conv, ar_empty_id]) - result = sqlite_instance.get_unique_converter_type_names() + result = sqlite_instance.get_unique_converter_class_names() assert result == ["Base64Converter"] From 13d8dff1bc41fcf247cf36784baacb00f5b556b4 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 2 Mar 2026 13:06:27 -0800 Subject: [PATCH 3/6] Fix update_attack_result: validate empty fields, convert ID to UUID - Add explicit ValueError raise when update_fields is empty (matches docstring) - Convert attack_result_id string to uuid.UUID before querying to avoid type mismatch with the UUID column in SQLAlchemy Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/memory_interface.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 4e7a19aa1..b5fee6651 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -1323,6 +1323,9 @@ def update_attack_result(self, *, conversation_id: str, update_fields: dict[str, Raises: ValueError: If update_fields is empty. """ + if not update_fields: + raise ValueError("update_fields must not be empty") + entries: MutableSequence[AttackResultEntry] = self._query_entries( AttackResultEntry, conditions=AttackResultEntry.conversation_id == conversation_id, @@ -1347,9 +1350,18 @@ def update_attack_result_by_id(self, *, attack_result_id: str, update_fields: di Returns: True if the update was successful, False if the entry was not found. """ + try: + attack_result_uuid = uuid.UUID(attack_result_id) + except (ValueError, TypeError): + logger.warning( + "Invalid attack_result_id '%s' passed to update_attack_result_by_id", + attack_result_id, + ) + return False + entries: MutableSequence[AttackResultEntry] = self._query_entries( AttackResultEntry, - conditions=AttackResultEntry.id == attack_result_id, + conditions=AttackResultEntry.id == attack_result_uuid, ) if not entries: return False From 904c8167bd348df0429aadf97e1d9df28f5ba648 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 2 Mar 2026 13:10:53 -0800 Subject: [PATCH 4/6] Address remaining review comments - Move inline imports (closing, SQLAlchemyError) to module top in memory_interface.py - Fix _build_identifier() in openai_response_target.py: merge target_specific_params into params to match _create_identifier() signature - Use text_piece.conversation_id instead of message.message_pieces[0] in openai_video_target.py for consistency Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/memory_interface.py | 11 +---------- pyrit/prompt_target/openai/openai_response_target.py | 7 +------ pyrit/prompt_target/openai/openai_video_target.py | 3 +-- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index b5fee6651..928d1ac81 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -15,6 +15,7 @@ from sqlalchemy import MetaData, and_, or_ from sqlalchemy.engine.base import Engine +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm.attributes import InstrumentedAttribute from pyrit.common.path import DB_DATA_PATH @@ -243,7 +244,6 @@ def _update_entry(self, entry: Base) -> None: Raises: SQLAlchemyError: If there's an error during the database operation. """ - from sqlalchemy.exc import SQLAlchemyError with closing(self.get_session()) as session: try: @@ -1283,16 +1283,7 @@ def add_attack_results_to_memory(self, *, attack_results: Sequence[AttackResult] SQLAlchemyError: If the database transaction fails. """ entries = [AttackResultEntry(entry=attack_result) for attack_result in attack_results] - # Capture the DB-assigned IDs before insert (they'll be set after flush/commit). - # _insert_entries closes the session, so we must read `entry.id` *inside* - # the session. Since _insert_entries uses a context manager, we instead - # read the ids from the entries *before* the session closes by doing the - # insert inline. - from contextlib import closing - with closing(self.get_session()) as session: - from sqlalchemy.exc import SQLAlchemyError - try: session.add_all(entries) session.commit() diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 5352573d7..01584989e 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -171,11 +171,6 @@ def _build_identifier(self) -> ComponentIdentifier: Returns: ComponentIdentifier: The identifier for this target instance. """ - specific_params: dict[str, Any] = { - "max_output_tokens": self._max_output_tokens, - } - if self._extra_body_parameters: - specific_params["extra_body_parameters"] = self._extra_body_parameters return self._create_identifier( params={ "temperature": self._temperature, @@ -183,8 +178,8 @@ def _build_identifier(self) -> ComponentIdentifier: "max_output_tokens": self._max_output_tokens, "reasoning_effort": self._reasoning_effort, "reasoning_summary": self._reasoning_summary, + "extra_body_parameters": self._extra_body_parameters, }, - target_specific_params=specific_params, ) def _set_openai_env_configuration_vars(self) -> None: diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index b6b96956e..77e3510cf 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -476,8 +476,7 @@ def _validate_request(self, *, message: Message) -> None: if remix_video_id and image_pieces: raise ValueError("Cannot use image input in remix mode. Remix uses existing video as reference.") - request = message.message_pieces[0] - messages = self._memory.get_conversation(conversation_id=request.conversation_id) + messages = self._memory.get_conversation(conversation_id=text_piece.conversation_id) n_messages = len(messages) if n_messages > 0: From 493abc17656c75a8b34e6fa12624b35ae7dff8da Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 2 Mar 2026 13:18:19 -0800 Subject: [PATCH 5/6] Fix lint: suppress, zip strict, RET504 noqa, TC003, ORDER BY for LIMIT 1 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pyrit/memory/azure_sql_memory.py | 7 +++---- pyrit/memory/memory_interface.py | 3 +-- pyrit/memory/sqlite_memory.py | 11 +++++------ pyrit/models/conversation_stats.py | 11 +++++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pyrit/memory/azure_sql_memory.py b/pyrit/memory/azure_sql_memory.py index 75e1b2b4a..c7e6235ee 100644 --- a/pyrit/memory/azure_sql_memory.py +++ b/pyrit/memory/azure_sql_memory.py @@ -5,7 +5,7 @@ import logging import struct from collections.abc import MutableSequence, Sequence -from contextlib import closing +from contextlib import closing, suppress from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union @@ -522,6 +522,7 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str AND p3.labels IS NOT NULL AND p3.labels != '{{}}' AND p3.labels != 'null' + ORDER BY p3.sequence ASC, p3.id ASC ) AS first_labels, MIN(pme.timestamp) AS created_at FROM "PromptMemoryEntries" pme @@ -543,10 +544,8 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str labels: dict[str, str] = {} if raw_labels and raw_labels not in ("null", "{}"): - try: + with suppress(ValueError, TypeError): labels = json.loads(raw_labels) - except (ValueError, TypeError): - pass created_at = None if raw_created_at is not None: diff --git a/pyrit/memory/memory_interface.py b/pyrit/memory/memory_interface.py index 928d1ac81..e66e239e8 100644 --- a/pyrit/memory/memory_interface.py +++ b/pyrit/memory/memory_interface.py @@ -244,7 +244,6 @@ def _update_entry(self, entry: Base) -> None: Raises: SQLAlchemyError: If there's an error during the database operation. """ - with closing(self.get_session()) as session: try: session.merge(entry) @@ -1289,7 +1288,7 @@ def add_attack_results_to_memory(self, *, attack_results: Sequence[AttackResult] session.commit() # Populate the attack_result_id back onto the domain objects so callers # can reference the DB-assigned ID immediately after insert. - for ar, entry in zip(attack_results, entries): + for ar, entry in zip(attack_results, entries, strict=False): ar.attack_result_id = str(entry.id) except SQLAlchemyError: session.rollback() diff --git a/pyrit/memory/sqlite_memory.py b/pyrit/memory/sqlite_memory.py index 15c89bf95..552953771 100644 --- a/pyrit/memory/sqlite_memory.py +++ b/pyrit/memory/sqlite_memory.py @@ -5,7 +5,7 @@ import logging import uuid from collections.abc import MutableSequence, Sequence -from contextlib import closing +from contextlib import closing, suppress from datetime import datetime from pathlib import Path from typing import Any, Optional, TypeVar, Union @@ -483,7 +483,7 @@ def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories ), ) ) - return targeted_harm_categories_subquery + return targeted_harm_categories_subquery # noqa: RET504 def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: """ @@ -506,7 +506,7 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any: ), ) ) - return labels_subquery + return labels_subquery # noqa: RET504 def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any: """ @@ -629,6 +629,7 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str AND p3.labels IS NOT NULL AND p3.labels != '{{}}' AND p3.labels != 'null' + ORDER BY p3.sequence ASC, p3.id ASC LIMIT 1 ) AS first_labels, MIN(pme.timestamp) AS created_at @@ -651,10 +652,8 @@ def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str labels: dict[str, str] = {} if raw_labels and raw_labels not in ("null", "{}"): - try: + with suppress(ValueError, TypeError): labels = json.loads(raw_labels) - except (ValueError, TypeError): - pass created_at = None if raw_created_at is not None: diff --git a/pyrit/models/conversation_stats.py b/pyrit/models/conversation_stats.py index c67f3d842..bb8283fcc 100644 --- a/pyrit/models/conversation_stats.py +++ b/pyrit/models/conversation_stats.py @@ -4,13 +4,16 @@ from __future__ import annotations from dataclasses import dataclass, field -from datetime import datetime -from typing import ClassVar, Dict, Optional +from typing import TYPE_CHECKING, ClassVar, Optional + +if TYPE_CHECKING: + from datetime import datetime @dataclass(frozen=True) class ConversationStats: - """Lightweight aggregate statistics for a conversation. + """ + Lightweight aggregate statistics for a conversation. Used to build attack summaries without loading full message pieces. """ @@ -19,5 +22,5 @@ class ConversationStats: message_count: int = 0 last_message_preview: Optional[str] = None - labels: Dict[str, str] = field(default_factory=dict) + labels: dict[str, str] = field(default_factory=dict) created_at: Optional[datetime] = None From 75b5255066794f2183b248549fd5657fd89f1a87 Mon Sep 17 00:00:00 2001 From: Roman Lutz Date: Mon, 2 Mar 2026 13:21:12 -0800 Subject: [PATCH 6/6] Add ConversationStats to api.rst documentation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index 8cc7a73ce..48de65a5e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -338,6 +338,7 @@ API Reference ChatMessageListDictContent construct_response_from_request ConversationReference + ConversationStats ConversationType DataTypeSerializer data_serializer_factory