diff --git a/pyrit/memory/memory_models.py b/pyrit/memory/memory_models.py index 2afeb4533..c4ba80f40 100644 --- a/pyrit/memory/memory_models.py +++ b/pyrit/memory/memory_models.py @@ -824,7 +824,7 @@ def __init__(self, *, entry: AttackResult) -> None: # Retry events self.retry_events_json = ( - json.dumps([evt.to_dict() for evt in entry.retry_events]) if entry.retry_events else None + json.dumps([evt.model_dump(mode="json") for evt in entry.retry_events]) if entry.retry_events else None ) self.total_retries = entry.total_retries @@ -923,7 +923,7 @@ def get_attack_result(self) -> AttackResult: if self.retry_events_json: from pyrit.models.retry_event import RetryEvent - retry_events = [RetryEvent.from_dict(evt_dict) for evt_dict in json.loads(self.retry_events_json)] + retry_events = [RetryEvent.model_validate(evt_dict) for evt_dict in json.loads(self.retry_events_json)] return AttackResult( conversation_id=self.conversation_id, diff --git a/pyrit/models/attack_result.py b/pyrit/models/attack_result.py index 02fad197d..8fd8520c1 100644 --- a/pyrit/models/attack_result.py +++ b/pyrit/models/attack_result.py @@ -250,7 +250,7 @@ def to_dict(self) -> dict[str, Any]: "outcome_reason": self.outcome_reason, "timestamp": self.timestamp.isoformat(), "related_conversations": sorted( - [ref.to_dict() for ref in self.related_conversations], + [ref.model_dump(mode="json") for ref in self.related_conversations], key=lambda r: r["conversation_id"], ), "metadata": self.metadata, @@ -258,7 +258,7 @@ def to_dict(self) -> dict[str, Any]: "error_message": self.error_message, "error_type": self.error_type, "error_traceback": self.error_traceback, - "retry_events": [e.to_dict() for e in self.retry_events], + "retry_events": [e.model_dump(mode="json") for e in self.retry_events], "total_retries": self.total_retries, } @@ -291,13 +291,15 @@ def from_dict(cls, data: dict[str, Any]) -> AttackResult: timestamp=( datetime.fromisoformat(data["timestamp"]) if data.get("timestamp") else datetime.now(timezone.utc) ), - related_conversations={ConversationReference.from_dict(r) for r in data.get("related_conversations", [])}, + related_conversations={ + ConversationReference.model_validate(r) for r in data.get("related_conversations", []) + }, metadata=data.get("metadata", {}), labels=data.get("labels", {}), error_message=data.get("error_message"), error_type=data.get("error_type"), error_traceback=data.get("error_traceback"), - retry_events=[RetryEvent.from_dict(e) for e in data.get("retry_events", [])], + retry_events=[RetryEvent.model_validate(e) for e in data.get("retry_events", [])], total_retries=data.get("total_retries", 0), ) diff --git a/pyrit/models/conversation_reference.py b/pyrit/models/conversation_reference.py index 95c7b9d5e..33d5e2d88 100644 --- a/pyrit/models/conversation_reference.py +++ b/pyrit/models/conversation_reference.py @@ -3,10 +3,13 @@ from __future__ import annotations -from dataclasses import dataclass from enum import Enum from typing import Optional +from pydantic import BaseModel, ConfigDict + +from pyrit.common.deprecation import print_deprecation_message + class ConversationType(Enum): """Types of conversations that can be associated with an attack.""" @@ -17,15 +20,15 @@ class ConversationType(Enum): CONVERTER = "converter" -@dataclass(frozen=True) -class ConversationReference: +class ConversationReference(BaseModel): """Immutable reference to a conversation that played a role in the attack.""" + model_config = ConfigDict(frozen=True) + conversation_id: str conversation_type: ConversationType description: Optional[str] = None - # Allow use in set / dict def __hash__(self) -> int: """ Return a hash derived from conversation ID. @@ -36,45 +39,55 @@ def __hash__(self) -> int: """ return hash(self.conversation_id) + def __eq__(self, other: object) -> bool: + """ + Compare two references by conversation ID. + + Args: + other (object): Other object to compare. + + Returns: + bool: True when the other object is a matching ConversationReference. + + """ + return isinstance(other, ConversationReference) and self.conversation_id == other.conversation_id + def to_dict(self) -> dict[str, str | None]: """ Serialize to a JSON-compatible dictionary. + .. deprecated:: + Use :meth:`model_dump` with ``mode="json"`` instead. This method + will be removed in version 0.16.0. + Returns: dict[str, str | None]: Dictionary with conversation_id, conversation_type, and description. """ - return { - "conversation_id": self.conversation_id, - "conversation_type": self.conversation_type.value, - "description": self.description, - } + print_deprecation_message( + old_item=ConversationReference.to_dict, + new_item='ConversationReference.model_dump(mode="json")', + removed_in="0.16.0", + ) + return self.model_dump(mode="json") @classmethod def from_dict(cls, data: dict[str, str | None]) -> ConversationReference: """ Reconstruct a ConversationReference from a dictionary. + .. deprecated:: + Use :meth:`model_validate` instead. This method will be removed + in version 0.16.0. + Args: - data (dict[str, str | None]): Dictionary as produced by to_dict(). + data (dict[str, str | None]): Dictionary as produced by ``model_dump(mode="json")``. Returns: ConversationReference: Reconstructed instance. """ - return cls( - conversation_id=str(data["conversation_id"]), - conversation_type=ConversationType(data["conversation_type"]), - description=data.get("description"), + print_deprecation_message( + old_item=ConversationReference.from_dict, + new_item="ConversationReference.model_validate", + removed_in="0.16.0", ) - - def __eq__(self, other: object) -> bool: - """ - Compare two references by conversation ID. - - Args: - other (object): Other object to compare. - - Returns: - bool: True when the other object is a matching ConversationReference. - - """ - return isinstance(other, ConversationReference) and self.conversation_id == other.conversation_id + return cls.model_validate(data) diff --git a/pyrit/models/conversation_stats.py b/pyrit/models/conversation_stats.py index bb8283fcc..22dcefd6e 100644 --- a/pyrit/models/conversation_stats.py +++ b/pyrit/models/conversation_stats.py @@ -1,26 +1,24 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from __future__ import annotations +from datetime import datetime +from typing import ClassVar, Optional -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, ClassVar, Optional +from pydantic import BaseModel, ConfigDict, Field -if TYPE_CHECKING: - from datetime import datetime - -@dataclass(frozen=True) -class ConversationStats: +class ConversationStats(BaseModel): """ Lightweight aggregate statistics for a conversation. Used to build attack summaries without loading full message pieces. """ + model_config = ConfigDict(frozen=True) + PREVIEW_MAX_LEN: ClassVar[int] = 100 message_count: int = 0 last_message_preview: Optional[str] = None - labels: dict[str, str] = field(default_factory=dict) + labels: dict[str, str] = Field(default_factory=dict) created_at: Optional[datetime] = None diff --git a/pyrit/models/harm_definition.py b/pyrit/models/harm_definition.py index d8c836e97..9e739244a 100644 --- a/pyrit/models/harm_definition.py +++ b/pyrit/models/harm_definition.py @@ -9,19 +9,18 @@ import logging import re -from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Union import yaml +from pydantic import BaseModel, Field from pyrit.common.path import HARM_DEFINITION_PATH logger = logging.getLogger(__name__) -@dataclass -class ScaleDescription: +class ScaleDescription(BaseModel): """ A single scale description entry from a harm definition. @@ -35,8 +34,7 @@ class ScaleDescription: description: str -@dataclass -class HarmDefinition: +class HarmDefinition(BaseModel): """ A harm definition loaded from a YAML file. @@ -54,8 +52,8 @@ class HarmDefinition: version: str category: str - scale_descriptions: list[ScaleDescription] = field(default_factory=list) - source_path: Optional[str] = field(default=None, kw_only=True) + scale_descriptions: list[ScaleDescription] = Field(default_factory=list) + source_path: Optional[str] = None def get_scale_description(self, score_value: str) -> Optional[str]: """ @@ -92,7 +90,6 @@ def validate_category(category: str, *, check_exists: bool = False) -> bool: False otherwise. """ - # Check if category matches pattern: only lowercase letters and underscores if not re.match(r"^[a-z_]+$", category): return False @@ -127,7 +124,6 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": """ path = Path(harm_definition_path) - # If it's just a filename (no directory separators), look in the standard directory resolved_path = HARM_DEFINITION_PATH / path if path.parent == Path(".") else path if not resolved_path.exists(): @@ -145,7 +141,6 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": if not isinstance(data, dict): raise ValueError(f"Harm definition file {resolved_path} must contain a YAML mapping/dictionary.") - # Validate required fields if "version" not in data: raise ValueError(f"Harm definition file {resolved_path} is missing required 'version' field.") if "category" not in data: @@ -153,7 +148,6 @@ def from_yaml(cls, harm_definition_path: Union[str, Path]) -> "HarmDefinition": if "scale_descriptions" not in data: raise ValueError(f"Harm definition file {resolved_path} is missing required 'scale_descriptions' field.") - # Parse scale descriptions scale_descriptions = [] for item in data["scale_descriptions"]: if not isinstance(item, dict) or "score_value" not in item or "description" not in item: diff --git a/pyrit/models/json_response_config.py b/pyrit/models/json_response_config.py index 8485467cf..8c4c1b986 100644 --- a/pyrit/models/json_response_config.py +++ b/pyrit/models/json_response_config.py @@ -4,9 +4,10 @@ from __future__ import annotations import json -from dataclasses import dataclass from typing import Any, Optional +from pydantic import BaseModel, ConfigDict + # Would prefer StrEnum, but.... Python 3.10 _METADATAKEYS = { "RESPONSE_FORMAT": "response_format", @@ -16,8 +17,7 @@ } -@dataclass -class _JsonResponseConfig: +class _JsonResponseConfig(BaseModel): """ Configuration for JSON responses (with OpenAI). @@ -27,8 +27,10 @@ class _JsonResponseConfig: https://platform.openai.com/docs/api-reference/responses/create#responses_create-text """ + model_config = ConfigDict(extra="forbid") + enabled: bool = False - schema: Optional[dict[str, Any]] = None + json_schema: Optional[dict[str, Any]] = None schema_name: str = "CustomSchema" strict: bool = True @@ -53,7 +55,7 @@ def from_metadata(cls, *, metadata: Optional[dict[str, Any]]) -> _JsonResponseCo return cls( enabled=True, - schema=schema, + json_schema=schema, schema_name=metadata.get(_METADATAKEYS["JSON_SCHEMA_NAME"], "CustomSchema"), strict=metadata.get(_METADATAKEYS["JSON_SCHEMA_STRICT"], True), ) diff --git a/pyrit/models/retry_event.py b/pyrit/models/retry_event.py index 030bba350..79bb2bbb6 100644 --- a/pyrit/models/retry_event.py +++ b/pyrit/models/retry_event.py @@ -3,12 +3,17 @@ """Data model for capturing individual retry events during execution.""" -from dataclasses import dataclass, field +from __future__ import annotations + from datetime import datetime, timezone +from typing import Optional + +from pydantic import BaseModel, Field + +from pyrit.common.deprecation import print_deprecation_message -@dataclass -class RetryEvent: +class RetryEvent(BaseModel): """ A single retry attempt captured during attack execution. @@ -18,54 +23,52 @@ class RetryEvent: attached to AttackResult objects for persistence and REST API exposure. """ - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) attempt_number: int = 0 function_name: str = "" exception_type: str = "" exception_message: str = "" component_role: str = "" - component_name: str | None = None - endpoint: str | None = None + component_name: Optional[str] = None + endpoint: Optional[str] = None elapsed_seconds: float = 0.0 def to_dict(self) -> dict: """ Serialize to a dictionary suitable for JSON storage. + .. deprecated:: + Use :meth:`model_dump` with ``mode="json"`` instead. This method + will be removed in version 0.16.0. + Returns: dict: Dictionary representation of the retry event. """ - return { - "timestamp": self.timestamp.isoformat(), - "attempt_number": self.attempt_number, - "function_name": self.function_name, - "exception_type": self.exception_type, - "exception_message": self.exception_message, - "component_role": self.component_role, - "component_name": self.component_name, - "endpoint": self.endpoint, - "elapsed_seconds": self.elapsed_seconds, - } + print_deprecation_message( + old_item=RetryEvent.to_dict, + new_item='RetryEvent.model_dump(mode="json")', + removed_in="0.16.0", + ) + return self.model_dump(mode="json") @classmethod - def from_dict(cls, data: dict) -> "RetryEvent": + def from_dict(cls, data: dict) -> RetryEvent: """ Deserialize from a dictionary. + .. deprecated:: + Use :meth:`model_validate` instead. This method will be removed + in version 0.16.0. + Args: data: Dictionary representation of a retry event. Returns: RetryEvent: Deserialized retry event. """ - return cls( - timestamp=datetime.fromisoformat(data["timestamp"]), - attempt_number=data.get("attempt_number", 0), - function_name=data.get("function_name", ""), - exception_type=data.get("exception_type", ""), - exception_message=data.get("exception_message", ""), - component_role=data.get("component_role", ""), - component_name=data.get("component_name"), - endpoint=data.get("endpoint"), - elapsed_seconds=data.get("elapsed_seconds", 0.0), + print_deprecation_message( + old_item=RetryEvent.from_dict, + new_item="RetryEvent.model_validate", + removed_in="0.16.0", ) + return cls.model_validate(data) diff --git a/pyrit/prompt_target/openai/openai_chat_target.py b/pyrit/prompt_target/openai/openai_chat_target.py index bb5defde9..d0e4b1180 100644 --- a/pyrit/prompt_target/openai/openai_chat_target.py +++ b/pyrit/prompt_target/openai/openai_chat_target.py @@ -681,12 +681,12 @@ def _build_response_format(self, json_config: _JsonResponseConfig) -> Optional[d if not json_config.enabled: return None - if json_config.schema: + if json_config.json_schema: return { "type": "json_schema", "json_schema": { "name": json_config.schema_name, - "schema": json_config.schema, + "schema": json_config.json_schema, "strict": json_config.strict, }, } diff --git a/pyrit/prompt_target/openai/openai_response_target.py b/pyrit/prompt_target/openai/openai_response_target.py index 48b4d7ade..f2e4b19a7 100644 --- a/pyrit/prompt_target/openai/openai_response_target.py +++ b/pyrit/prompt_target/openai/openai_response_target.py @@ -418,12 +418,12 @@ def _build_text_format(self, json_config: _JsonResponseConfig) -> Optional[dict[ if not json_config.enabled: return None - if json_config.schema: + if json_config.json_schema: return { "format": { "type": "json_schema", "name": json_config.schema_name, - "schema": json_config.schema, + "schema": json_config.json_schema, "strict": json_config.strict, } } diff --git a/tests/unit/memory/memory_interface/test_interface_attack_results.py b/tests/unit/memory/memory_interface/test_interface_attack_results.py index d5729cf11..ca96b873c 100644 --- a/tests/unit/memory/memory_interface/test_interface_attack_results.py +++ b/tests/unit/memory/memory_interface/test_interface_attack_results.py @@ -557,8 +557,11 @@ def test_attack_result_with_attack_generation_conversation_ids(sqlite_instance: adversarial_ids = {"adv_conv_1", "adv_conv_2", "adv_conv_3"} related_conversations: set[ConversationReference] = { - *(ConversationReference(cid, ConversationType.PRUNED) for cid in pruned_ids), - *(ConversationReference(cid, ConversationType.ADVERSARIAL) for cid in adversarial_ids), + *(ConversationReference(conversation_id=cid, conversation_type=ConversationType.PRUNED) for cid in pruned_ids), + *( + ConversationReference(conversation_id=cid, conversation_type=ConversationType.ADVERSARIAL) + for cid in adversarial_ids + ), } attack_result = AttackResult( diff --git a/tests/unit/models/test_conversation_reference.py b/tests/unit/models/test_conversation_reference.py index 2f7a559ad..de6263cd9 100644 --- a/tests/unit/models/test_conversation_reference.py +++ b/tests/unit/models/test_conversation_reference.py @@ -2,6 +2,7 @@ # Licensed under the MIT license. import pytest +from pydantic import ValidationError from pyrit.models.conversation_reference import ConversationReference, ConversationType @@ -31,7 +32,7 @@ def test_conversation_reference_with_description(): def test_conversation_reference_is_frozen(): ref = ConversationReference(conversation_id="abc", conversation_type=ConversationType.SCORE) - with pytest.raises(AttributeError): + with pytest.raises(ValidationError): ref.conversation_id = "new_id" @@ -78,11 +79,25 @@ def test_conversation_reference_usable_as_dict_key(): assert d[lookup_ref] == "value" -def test_to_dict_from_dict_roundtrip(): +def test_model_dump_validate_roundtrip(): original = ConversationReference( conversation_id="conv-123", conversation_type=ConversationType.ADVERSARIAL, description="main adversarial conversation", ) - roundtripped = ConversationReference.from_dict(original.to_dict()) - assert original.to_dict() == roundtripped.to_dict() + payload = original.model_dump(mode="json") + roundtripped = ConversationReference.model_validate(payload) + assert original.model_dump(mode="json") == roundtripped.model_dump(mode="json") + + +def test_to_dict_from_dict_deprecated_wrappers_still_work(): + original = ConversationReference( + conversation_id="conv-123", + conversation_type=ConversationType.ADVERSARIAL, + description="main adversarial conversation", + ) + with pytest.warns(DeprecationWarning): + payload = original.to_dict() + with pytest.warns(DeprecationWarning): + roundtripped = ConversationReference.from_dict(payload) + assert original.model_dump(mode="json") == roundtripped.model_dump(mode="json") diff --git a/tests/unit/models/test_conversation_stats.py b/tests/unit/models/test_conversation_stats.py index adeabb174..44a58a82f 100644 --- a/tests/unit/models/test_conversation_stats.py +++ b/tests/unit/models/test_conversation_stats.py @@ -4,6 +4,7 @@ from datetime import datetime, timezone import pytest +from pydantic import ValidationError from pyrit.models.conversation_stats import ConversationStats @@ -32,7 +33,7 @@ def test_conversation_stats_with_values(): def test_conversation_stats_is_frozen(): stats = ConversationStats(message_count=3) - with pytest.raises(AttributeError): + with pytest.raises(ValidationError): stats.message_count = 10 diff --git a/tests/unit/models/test_json_response_config.py b/tests/unit/models/test_json_response_config.py index e555a0db4..0134db033 100644 --- a/tests/unit/models/test_json_response_config.py +++ b/tests/unit/models/test_json_response_config.py @@ -11,7 +11,7 @@ def test_with_none(): config = _JsonResponseConfig.from_metadata(metadata=None) assert config.enabled is False - assert config.schema is None + assert config.json_schema is None assert config.schema_name == "CustomSchema" assert config.strict is True @@ -22,7 +22,7 @@ def test_with_json_object(): } config = _JsonResponseConfig.from_metadata(metadata=metadata) assert config.enabled is True - assert config.schema is None + assert config.json_schema is None assert config.schema_name == "CustomSchema" assert config.strict is True @@ -37,7 +37,7 @@ def test_with_json_string_schema(): } config = _JsonResponseConfig.from_metadata(metadata=metadata) assert config.enabled is True - assert config.schema == schema + assert config.json_schema == schema assert config.schema_name == "TestSchema" assert config.strict is False @@ -50,7 +50,7 @@ def test_with_json_schema_object(): } config = _JsonResponseConfig.from_metadata(metadata=metadata) assert config.enabled is True - assert config.schema == schema + assert config.json_schema == schema assert config.schema_name == "CustomSchema" assert config.strict is True @@ -62,7 +62,7 @@ def test_with_empty_json_schema_object(): } config = _JsonResponseConfig.from_metadata(metadata=metadata) assert config.enabled is True - assert config.schema == {} + assert config.json_schema == {} assert config.schema_name == "CustomSchema" assert config.strict is True @@ -83,6 +83,6 @@ def test_other_response_format(): } config = _JsonResponseConfig.from_metadata(metadata=metadata) assert config.enabled is False - assert config.schema is None + assert config.json_schema is None assert config.schema_name == "CustomSchema" assert config.strict is True diff --git a/tests/unit/models/test_retry_event.py b/tests/unit/models/test_retry_event.py index 09f20dfca..af740d3cc 100644 --- a/tests/unit/models/test_retry_event.py +++ b/tests/unit/models/test_retry_event.py @@ -3,11 +3,13 @@ from datetime import datetime, timezone +import pytest + from pyrit.models.retry_event import RetryEvent class TestRetryEvent: - """Tests for the RetryEvent dataclass.""" + """Tests for the RetryEvent model.""" def test_defaults(self) -> None: """RetryEvent constructed with minimal args gets correct defaults.""" @@ -48,7 +50,7 @@ def test_full_construction(self) -> None: assert evt.timestamp == ts def test_to_dict(self) -> None: - """to_dict returns a JSON-serializable dictionary.""" + """model_dump(mode="json") returns a JSON-serializable dictionary.""" evt = RetryEvent( attempt_number=2, function_name="fn", @@ -59,7 +61,7 @@ def test_to_dict(self) -> None: endpoint="https://example.com", elapsed_seconds=1.5, ) - d = evt.to_dict() + d = evt.model_dump(mode="json") assert d["attempt_number"] == 2 assert d["function_name"] == "fn" assert d["exception_type"] == "ValueError" @@ -71,7 +73,7 @@ def test_to_dict(self) -> None: assert "timestamp" in d def test_from_dict_roundtrip(self) -> None: - """from_dict correctly reconstructs a RetryEvent from to_dict output.""" + """model_validate correctly reconstructs a RetryEvent from model_dump output.""" original = RetryEvent( attempt_number=1, function_name="call_target", @@ -82,8 +84,8 @@ def test_from_dict_roundtrip(self) -> None: endpoint="https://azure.openai.com", elapsed_seconds=10.0, ) - d = original.to_dict() - restored = RetryEvent.from_dict(d) + d = original.model_dump(mode="json") + restored = RetryEvent.model_validate(d) assert restored.attempt_number == original.attempt_number assert restored.function_name == original.function_name @@ -95,13 +97,13 @@ def test_from_dict_roundtrip(self) -> None: assert restored.elapsed_seconds == original.elapsed_seconds def test_from_dict_missing_optional_fields(self) -> None: - """from_dict handles missing optional fields gracefully.""" + """model_validate handles missing optional fields gracefully.""" d = { "attempt_number": 1, "function_name": "fn", "timestamp": "2026-05-07T12:00:00+00:00", } - evt = RetryEvent.from_dict(d) + evt = RetryEvent.model_validate(d) assert evt.attempt_number == 1 assert evt.function_name == "fn" assert evt.exception_type == "" @@ -110,14 +112,25 @@ def test_from_dict_missing_optional_fields(self) -> None: assert evt.elapsed_seconds == 0.0 def test_from_dict_timestamp_parsing(self) -> None: - """from_dict correctly parses ISO format timestamp.""" + """model_validate correctly parses ISO format timestamp.""" d = { "attempt_number": 1, "function_name": "fn", "timestamp": "2026-05-07T12:30:00+00:00", } - evt = RetryEvent.from_dict(d) + evt = RetryEvent.model_validate(d) assert evt.timestamp.year == 2026 assert evt.timestamp.month == 5 assert evt.timestamp.hour == 12 assert evt.timestamp.minute == 30 + + def test_to_dict_from_dict_deprecated_wrappers_still_work(self) -> None: + """Deprecated to_dict / from_dict wrappers still round-trip correctly.""" + original = RetryEvent(attempt_number=4, function_name="fn", exception_type="Boom") + with pytest.warns(DeprecationWarning): + payload = original.to_dict() + with pytest.warns(DeprecationWarning): + restored = RetryEvent.from_dict(payload) + assert restored.attempt_number == 4 + assert restored.function_name == "fn" + assert restored.exception_type == "Boom"