Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 104 additions & 15 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import struct
from collections.abc import MutableSequence, Sequence
Expand All @@ -27,6 +28,7 @@
)
from pyrit.models import (
AzureBlobStorageIO,
ConversationStats,
MessagePiece,
)

Expand Down Expand Up @@ -386,37 +388,37 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
)
)

def _get_attack_result_attack_type_condition(self, *, attack_type: str) -> Any:
    """
    Azure SQL implementation for filtering AttackResults by attack type.

    Uses ISJSON() to skip rows with malformed JSON, then JSON_VALUE() to
    match class_name inside the attack_identifier JSON column.

    Args:
        attack_type (str): Exact attack type name to match (compared against
            the JSON property ``$.class_name``).

    Returns:
        Any: SQLAlchemy text condition with the value bound as a parameter
        (no string interpolation, so the filter is injection-safe).
    """
    return text(
        """ISJSON("AttackResultEntries".attack_identifier) = 1
        AND JSON_VALUE("AttackResultEntries".attack_identifier, '$.class_name') = :attack_type"""
    ).bindparams(attack_type=attack_type)

def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[str]) -> Any:
def _get_attack_result_converter_types_condition(self, *, converter_types: Sequence[str]) -> Any:
"""
Azure SQL implementation for filtering AttackResults by converter classes.
Azure SQL implementation for filtering AttackResults by converter types.

When converter_classes is empty, matches attacks with no converters.
When non-empty, uses OPENJSON() to check all specified classes are present
When converter_types is empty, matches attacks with no converters.
When non-empty, uses OPENJSON() to check all specified types are present
(AND logic, case-insensitive).

Args:
converter_classes (Sequence[str]): List of converter class names. Empty list means no converters.
converter_types (Sequence[str]): List of converter type names. Empty list means no converters.

Returns:
Any: SQLAlchemy combined condition with bound parameters.
"""
if len(converter_classes) == 0:
if len(converter_types) == 0:
# Explicitly "no converters": match attacks where the converter list
# is absent, null, or empty in the stored JSON.
return text(
Expand All @@ -428,7 +430,7 @@ def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[

conditions = []
bindparams_dict: dict[str, str] = {}
for i, cls in enumerate(converter_classes):
for i, cls in enumerate(converter_types):
param_name = f"conv_cls_{i}"
conditions.append(
f'EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".attack_identifier, '
Expand All @@ -442,7 +444,7 @@ def _get_attack_result_converter_condition(self, *, converter_classes: Sequence[
**bindparams_dict
)

def get_unique_attack_class_names(self) -> list[str]:
def get_unique_attack_type_names(self) -> list[str]:
"""
Azure SQL implementation: extract unique class_name values from attack_identifier JSON.

Expand All @@ -460,7 +462,7 @@ def get_unique_attack_class_names(self) -> list[str]:
).fetchall()
return sorted(row[0] for row in rows)

def get_unique_converter_class_names(self) -> list[str]:
def get_unique_converter_type_names(self) -> list[str]:
"""
Azure SQL implementation: extract unique converter class_name values
from the request_converter_identifiers array in attack_identifier JSON.
Expand All @@ -481,6 +483,87 @@ def get_unique_converter_class_names(self) -> list[str]:
).fetchall()
return sorted(row[0] for row in rows)

def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]:
    """
    Azure SQL implementation: lightweight aggregate stats per conversation.

    Executes a single SQL query that returns message count (distinct
    sequences), a truncated last-message preview, the first non-empty
    labels dict, and the earliest timestamp for each conversation_id.

    Args:
        conversation_ids (Sequence[str]): The conversation IDs to query.

    Returns:
        dict[str, ConversationStats]: Mapping from conversation_id to
        ConversationStats. IDs with no stored messages are simply absent
        from the result.
    """
    if not conversation_ids:
        return {}

    # Bind every id as its own named parameter: IN (:cid0, :cid1, ...).
    placeholders = ", ".join(f":cid{i}" for i in range(len(conversation_ids)))
    params = {f"cid{i}": cid for i, cid in enumerate(conversation_ids)}

    max_len = ConversationStats.PREVIEW_MAX_LEN
    # Fetch 3 extra characters so Python can tell "exactly max_len long"
    # apart from "truncated" and append an ellipsis only when truncated.
    # NOTE: both TOP 1 subqueries carry an ORDER BY — without one, SQL
    # Server returns an arbitrary row, making the "last preview" and
    # "first labels" values nondeterministic.
    sql = text(
        f"""
        SELECT
            pme.conversation_id,
            COUNT(DISTINCT pme.sequence) AS msg_count,
            (
                SELECT TOP 1 LEFT(p2.converted_value, {max_len + 3})
                FROM "PromptMemoryEntries" p2
                WHERE p2.conversation_id = pme.conversation_id
                ORDER BY p2.sequence DESC, p2.id DESC
            ) AS last_preview,
            (
                SELECT TOP 1 p3.labels
                FROM "PromptMemoryEntries" p3
                WHERE p3.conversation_id = pme.conversation_id
                    AND p3.labels IS NOT NULL
                    AND p3.labels != '{{}}'
                    AND p3.labels != 'null'
                ORDER BY p3.sequence ASC, p3.id ASC
            ) AS first_labels,
            MIN(pme.timestamp) AS created_at
        FROM "PromptMemoryEntries" pme
        WHERE pme.conversation_id IN ({placeholders})
        GROUP BY pme.conversation_id
        """
    )

    with closing(self.get_session()) as session:
        rows = session.execute(sql, params).fetchall()

    result: dict[str, ConversationStats] = {}
    for row in rows:
        conv_id, msg_count, last_preview, raw_labels, raw_created_at = row

        preview = None
        if last_preview:
            # Longer than max_len means the SQL LEFT() truncated (or the
            # message is within the 3-char slack): trim and mark with "...".
            preview = last_preview[:max_len] + "..." if len(last_preview) > max_len else last_preview

        labels: dict[str, str] = {}
        if raw_labels and raw_labels not in ("null", "{}"):
            try:
                parsed = json.loads(raw_labels)
                # Guard: the column could hold any JSON value; only a JSON
                # object satisfies the declared dict[str, str] contract.
                if isinstance(parsed, dict):
                    labels = parsed
            except (ValueError, TypeError):
                # Malformed JSON in the labels column: degrade to "no
                # labels" rather than failing the whole stats query.
                pass

        created_at = None
        if raw_created_at is not None:
            # Depending on driver/odbc settings the timestamp may arrive as
            # an ISO string or a datetime; normalize to datetime.
            # NOTE(review): assumes `datetime` is imported at module level
            # (not visible in this chunk) — confirm.
            if isinstance(raw_created_at, str):
                created_at = datetime.fromisoformat(raw_created_at)
            else:
                created_at = raw_created_at

        result[conv_id] = ConversationStats(
            message_count=msg_count,
            last_message_preview=preview,
            labels=labels,
            created_at=created_at,
        )

    return result

def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any:
"""
Get the SQL Azure implementation for filtering ScenarioResults by labels.
Expand Down Expand Up @@ -673,8 +756,14 @@ def _update_entries(self, *, entries: MutableSequence[Base], update_fields: dict
with closing(self.get_session()) as session:
try:
for entry in entries:
# Ensure the entry is attached to the session. If it's detached, merge it.
entry_in_session = session.merge(entry) if not session.is_modified(entry) else entry
# Load a fresh copy by primary key so we only touch the
# requested fields. Using merge() would copy ALL
# attributes from the (potentially stale) detached object
# and silently overwrite concurrent updates to columns
# that are NOT in update_fields.
entry_in_session = session.get(type(entry), entry.id) # type: ignore[attr-defined]
if entry_in_session is None:
entry_in_session = session.merge(entry)
for field, value in update_fields.items():
if field in vars(entry_in_session):
setattr(entry_in_session, field, value)
Expand Down
Loading
Loading