Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def __init__(self, *, entry: AttackResult) -> None:

# Retry events
self.retry_events_json = (
json.dumps([evt.to_dict() for evt in entry.retry_events]) if entry.retry_events else None
json.dumps([evt.model_dump(mode="json") for evt in entry.retry_events]) if entry.retry_events else None
)
self.total_retries = entry.total_retries

Expand Down Expand Up @@ -923,7 +923,7 @@ def get_attack_result(self) -> AttackResult:
if self.retry_events_json:
from pyrit.models.retry_event import RetryEvent

retry_events = [RetryEvent.from_dict(evt_dict) for evt_dict in json.loads(self.retry_events_json)]
retry_events = [RetryEvent.model_validate(evt_dict) for evt_dict in json.loads(self.retry_events_json)]

return AttackResult(
conversation_id=self.conversation_id,
Expand Down
10 changes: 6 additions & 4 deletions pyrit/models/attack_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,15 @@ def to_dict(self) -> dict[str, Any]:
"outcome_reason": self.outcome_reason,
"timestamp": self.timestamp.isoformat(),
"related_conversations": sorted(
[ref.to_dict() for ref in self.related_conversations],
[ref.model_dump(mode="json") for ref in self.related_conversations],
key=lambda r: r["conversation_id"],
),
"metadata": self.metadata,
"labels": self.labels,
"error_message": self.error_message,
"error_type": self.error_type,
"error_traceback": self.error_traceback,
"retry_events": [e.to_dict() for e in self.retry_events],
"retry_events": [e.model_dump(mode="json") for e in self.retry_events],
"total_retries": self.total_retries,
}

Expand Down Expand Up @@ -291,13 +291,15 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult:
timestamp=(
datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else datetime.now(timezone.utc)
),
related_conversations={ConversationReference.from_dict(r) for r in data.get("related_conversations", [])},
related_conversations={
ConversationReference.model_validate(r) for r in data.get("related_conversations", [])
},
metadata=data.get("metadata", {}),
labels=data.get("labels", {}),
error_message=data.get("error_message"),
error_type=data.get("error_type"),
error_traceback=data.get("error_traceback"),
retry_events=[RetryEvent.from_dict(e) for e in data.get("retry_events", [])],
retry_events=[RetryEvent.model_validate(e) for e in data.get("retry_events", [])],
total_retries=data.get("total_retries", 0),
)

Expand Down
67 changes: 40 additions & 27 deletions pyrit/models/conversation_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Optional

from pydantic import BaseModel, ConfigDict

from pyrit.common.deprecation import print_deprecation_message


class ConversationType(Enum):
"""Types of conversations that can be associated with an attack."""
Expand All @@ -17,15 +20,15 @@ class ConversationType(Enum):
CONVERTER = "converter"


@dataclass(frozen=True)
class ConversationReference:
class ConversationReference(BaseModel):
"""Immutable reference to a conversation that played a role in the attack."""

model_config = ConfigDict(frozen=True)

conversation_id: str
conversation_type: ConversationType
description: Optional[str] = None

# Allow use in set / dict
def __hash__(self) -> int:
"""
Return a hash derived from conversation ID.
Expand All @@ -36,45 +39,55 @@ def __hash__(self) -> int:
"""
return hash(self.conversation_id)

def __eq__(self, other: object) -> bool:
"""
Compare two references by conversation ID.

Args:
other (object): Other object to compare.

Returns:
bool: True when the other object is a matching ConversationReference.

"""
return isinstance(other, ConversationReference) and self.conversation_id == other.conversation_id

def to_dict(self) -> dict[str, str | None]:
"""
Serialize to a JSON-compatible dictionary.

.. deprecated::
Use :meth:`model_dump` with ``mode="json"`` instead. This method
will be removed in version 0.16.0.

Returns:
dict[str, str | None]: Dictionary with conversation_id, conversation_type, and description.
"""
return {
"conversation_id": self.conversation_id,
"conversation_type": self.conversation_type.value,
"description": self.description,
}
print_deprecation_message(
old_item=ConversationReference.to_dict,
new_item='ConversationReference.model_dump(mode="json")',
removed_in="0.16.0",
)
return self.model_dump(mode="json")

@classmethod
def from_dict(cls, data: dict[str, str | None]) -> ConversationReference:
"""
Reconstruct a ConversationReference from a dictionary.

.. deprecated::
Use :meth:`model_validate` instead. This method will be removed
in version 0.16.0.

Args:
data (dict[str, str | None]): Dictionary as produced by to_dict().
data (dict[str, str | None]): Dictionary as produced by ``model_dump(mode="json")``.

Returns:
ConversationReference: Reconstructed instance.
"""
return cls(
conversation_id=str(data["conversation_id"]),
conversation_type=ConversationType(data["conversation_type"]),
description=data.get("description"),
print_deprecation_message(
old_item=ConversationReference.from_dict,
new_item="ConversationReference.model_validate",
removed_in="0.16.0",
)

def __eq__(self, other: object) -> bool:
"""
Compare two references by conversation ID.

Args:
other (object): Other object to compare.

Returns:
bool: True when the other object is a matching ConversationReference.

"""
return isinstance(other, ConversationReference) and self.conversation_id == other.conversation_id
return cls.model_validate(data)
16 changes: 7 additions & 9 deletions pyrit/models/conversation_stats.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations
from datetime import datetime
from typing import ClassVar, Optional

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, ClassVar, Optional
from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
from datetime import datetime


@dataclass(frozen=True)
class ConversationStats:
class ConversationStats(BaseModel):
"""
Lightweight aggregate statistics for a conversation.

Used to build attack summaries without loading full message pieces.
"""

model_config = ConfigDict(frozen=True)

PREVIEW_MAX_LEN: ClassVar[int] = 100

message_count: int = 0
last_message_preview: Optional[str] = None
labels: dict[str, str] = field(default_factory=dict)
labels: dict[str, str] = Field(default_factory=dict)
created_at: Optional[datetime] = None
16 changes: 5 additions & 11 deletions pyrit/models/harm_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,18 @@

import logging
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Union

import yaml
from pydantic import BaseModel, Field

from pyrit.common.path import HARM_DEFINITION_PATH

logger = logging.getLogger(__name__)


@dataclass
class ScaleDescription:
class ScaleDescription(BaseModel):
"""
A single scale description entry from a harm definition.

Expand All @@ -35,8 +34,7 @@ class ScaleDescription:
description: str


@dataclass
class HarmDefinition:
class HarmDefinition(BaseModel):
"""
A harm definition loaded from a YAML file.

Expand All @@ -54,8 +52,8 @@ class HarmDefinition:

version: str
category: str
scale_descriptions: list[ScaleDescription] = field(default_factory=list)
source_path: Optional[str] = field(default=None, kw_only=True)
scale_descriptions: list[ScaleDescription] = Field(default_factory=list)
source_path: Optional[str] = None

def get_scale_description(self, score_value: str) -> Optional[str]:
"""
Expand Down Expand Up @@ -92,7 +90,6 @@ def validate_category(category: str, *, check_exists: bool = False) -> bool:
False otherwise.

"""
# Check if category matches pattern: only lowercase letters and underscores
if not re.match(r"^[a-z_]+$", category):
return False

Expand Down Expand Up @@ -127,7 +124,6 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition":
"""
path = Path(harm_definition_path)

# If it's just a filename (no directory separators), look in the standard directory
resolved_path = HARM_DEFINITION_PATH / path if path.parent == Path(".") else path

if not resolved_path.exists():
Expand All @@ -145,15 +141,13 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition":
if not isinstance(data, dict):
raise ValueError(f"Harm definition file {resolved_path} must contain a YAML mapping/dictionary.")

# Validate required fields
if "version" not in data:
raise ValueError(f"Harm definition file {resolved_path} is missing required 'version' field.")
if "category" not in data:
raise ValueError(f"Harm definition file {resolved_path} is missing required 'category' field.")
if "scale_descriptions" not in data:
raise ValueError(f"Harm definition file {resolved_path} is missing required 'scale_descriptions' field.")

# Parse scale descriptions
scale_descriptions = []
for item in data["scale_descriptions"]:
if not isinstance(item, dict) or "score_value" not in item or "description" not in item:
Expand Down
12 changes: 7 additions & 5 deletions pyrit/models/json_response_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict

# Would prefer StrEnum, but.... Python 3.10
_METADATAKEYS = {
"RESPONSE_FORMAT": "response_format",
Expand All @@ -16,8 +17,7 @@
}


@dataclass
class _JsonResponseConfig:
class _JsonResponseConfig(BaseModel):
Comment thread
rlundeen2 marked this conversation as resolved.
"""
Configuration for JSON responses (with OpenAI).

Expand All @@ -27,8 +27,10 @@ class _JsonResponseConfig:
https://platform.openai.com/docs/api-reference/responses/create#responses_create-text
"""

model_config = ConfigDict(extra="forbid")

enabled: bool = False
schema: Optional[dict[str, Any]] = None
json_schema: Optional[dict[str, Any]] = None
schema_name: str = "CustomSchema"
strict: bool = True

Expand All @@ -53,7 +55,7 @@ def from_metadata(cls, *, metadata: Optional[dict[str, Any]]) -> _JsonResponseCo

return cls(
enabled=True,
schema=schema,
json_schema=schema,
schema_name=metadata.get(_METADATAKEYS["JSON_SCHEMA_NAME"], "CustomSchema"),
strict=metadata.get(_METADATAKEYS["JSON_SCHEMA_STRICT"], True),
)
Expand Down
Loading
Loading