diff --git a/assets/photosynthesis_question.wav b/assets/photosynthesis_question.wav new file mode 100644 index 0000000000..44f9e54141 Binary files /dev/null and b/assets/photosynthesis_question.wav differ diff --git a/doc/code/executor/attack/barge_in_attack.ipynb b/doc/code/executor/attack/barge_in_attack.ipynb new file mode 100644 index 0000000000..f9a92ba9dc --- /dev/null +++ b/doc/code/executor/attack/barge_in_attack.ipynb @@ -0,0 +1,400 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Barge-In Attack (Streaming Audio)\n", + "\n", + "`BargeInAttack` streams user audio to a `RealtimeTarget` and uses server-side voice-activity\n", + "detection (VAD) to detect turn boundaries. When the user speaks while the assistant is still\n", + "responding, server VAD cancels the in-flight response (barge-in). Interrupted turns are\n", + "persisted with `prompt_metadata[\"interrupted\"] = True`.\n", + "\n", + "Audio converters are applied per turn after VAD commits. The raw audio drives interruption\n", + "timing while the model responds to the converted version.\n", + "\n", + "> **Note:** Memory must be initialized via `initialize_pyrit_async`. See the\n", + "> [Memory Configuration Guide](../../memory/0_memory.md)." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "`BargeInAttack` requires a `RealtimeTarget` with `server_vad=True` (or a `ServerVadConfig`\n", + "for custom tuning)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found default environment files: ['./.pyrit/.env']\n", + "Loaded environment file: ./.pyrit/.env\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No new upgrade operations detected.\n" + ] + } + ], + "source": [ + "import asyncio\n", + "import wave\n", + "from pathlib import Path\n", + "\n", + "from pyrit.executor.attack import (\n", + " AttackConverterConfig,\n", + " BargeInAttack,\n", + " BargeInAttackContext,\n", + " ConsoleAttackResultPrinter,\n", + ")\n", + "from pyrit.executor.attack.core import AttackParameters\n", + "from pyrit.memory import CentralMemory\n", + "from pyrit.prompt_converter import AudioFrequencyConverter\n", + "from pyrit.prompt_normalizer import PromptConverterConfiguration\n", + "from pyrit.prompt_target import RealtimeTarget\n", + "from pyrit.setup import IN_MEMORY, initialize_pyrit_async\n", + "\n", + "await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Shared setup\n", + "\n", + "Both sections use a pre-recorded 24 kHz mono PCM16 question about photosynthesis. The\n", + "format matches what the OpenAI Realtime API expects. Any async generator yielding 24 kHz\n", + "PCM16 bytes works as a chunk source (live mic, TTS, etc.)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded question: 3.94s @ 24 kHz\n" + ] + } + ], + "source": [ + "CHUNK_MS = 100\n", + "CHUNK_SIZE = CHUNK_MS * 48 # PCM16 @ 24 kHz mono = 48 bytes per millisecond.\n", + "SILENCE_CHUNK = b\"\\x00\" * CHUNK_SIZE\n", + "audio_path = Path(\"../../../../assets/photosynthesis_question.wav\").resolve()\n", + "\n", + "\n", + "def _load_pcm(path: Path) -> bytes:\n", + " \"\"\"Read a WAV at 24 kHz / mono / PCM16 into raw PCM bytes.\"\"\"\n", + " with wave.open(str(path), \"rb\") as wav:\n", + " assert wav.getframerate() == 24000 and wav.getnchannels() == 1 and wav.getsampwidth() == 2\n", + " return wav.readframes(wav.getnframes())\n", + "\n", + "\n", + "async def _yield_chunks(pcm: bytes, real_time: bool = True):\n", + " \"\"\"Yield PCM in 100ms slices, optionally pacing at real-time.\"\"\"\n", + " for offset in range(0, len(pcm), CHUNK_SIZE):\n", + " yield pcm[offset : offset + CHUNK_SIZE]\n", + " if real_time:\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + "\n", + "question_pcm_24k = _load_pcm(audio_path)\n", + "print(f\"Loaded question: {len(question_pcm_24k) / 48 / 1000:.2f}s @ 24 kHz\")\n", + "\n", + "converters = PromptConverterConfiguration.from_converters(converters=[AudioFrequencyConverter(shift_value=200)])" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## Section 1: Single-turn streaming with a converter\n", + "\n", + "Streams one user statement, applies a frequency-shift converter after VAD commits the turn,\n", + "and gets the model's response. Exercises the full pipeline (chunk push, convert-on-commit,\n", + "item swap, response trigger, memory persistence) without barge-in." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "executed_turns: 1\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[34m🔹 Turn 1 - USER\u001b[0m\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[34m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294332341158.mp3\u001b[0m\n", + "\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[33m Sure! Photosynthesis is the process plants use to convert light energy into chemical energy, which they store as sugars. It mainly takes place in the chloroplasts of leaf cells. Here's how it works:\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 1. Light absorption: Chlorophyll, the green pigment, captures sunlight. This energy excites electrons within the chlorophyll.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 2. Water splitting: The plant takes in water (H₂O) from the roots and transfers it to the leaves. The light energy splits the water molecules into oxygen, protons, and electrons. 
The oxygen is\u001b[0m\n", + "\u001b[33m released as a byproduct.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 3. Conversion of energy: The excited electrons move through a chain of proteins, creating ATP and NADPH, which are energy carriers.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 4. Carbon fixation: Using that stored energy, the plant takes in carbon dioxide (CO₂) from the air. Through the Calvin cycle, it combines the CO₂ with the energy carriers to form glucose.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m This glucose feeds the plant and can be stored as starch. In essence, photosynthesis fuels plant growth and provides oxygen for us.\u001b[0m\n", + "\u001b[33m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294332344158.mp3\u001b[0m\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n" + ] + } + ], + "source": [ + "async def single_turn_source():\n", + " async for chunk in _yield_chunks(question_pcm_24k):\n", + " yield chunk\n", + " # Trailing silence helps server VAD recognize end-of-turn.\n", + " for _ in range(25): # 2.5s trailing silence, above the 1.5s VAD threshold\n", + " yield SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + "\n", + "target = RealtimeTarget(server_vad=True)\n", + "attack = BargeInAttack(\n", + " objective_target=target,\n", + " attack_converter_config=AttackConverterConfig(request_converters=converters),\n", + ")\n", + "\n", + "context = BargeInAttackContext(\n", + " params=AttackParameters(objective=\"Observe a single converted user turn end-to-end\"),\n", + " audio_chunks=single_turn_source(),\n", + ")\n", + "\n", + "result = await attack.execute_with_context_async(context=context) # type: ignore\n", + "print(f\"executed_turns: {result.executed_turns}\")\n", + "await 
ConsoleAttackResultPrinter(width=200).print_conversation_async(result=result) # type: ignore\n", + "await target.cleanup_target() # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Section 2: Barge-in (interrupting the assistant mid-response)\n", + "\n", + "Plays the question twice with timing arranged so turn 2's speech arrives during turn 1's\n", + "response. Server VAD detects the new speech, cancels turn 1's response, and resolves it\n", + "with `interrupted=True`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "executed_turns: 2\n", + "\n", + "Persisted pieces (4 messages):\n", + " user audio_path: 1779294342848770.mp3\n", + " assistant text [INTERRUPTED]: Sure! Photosynthesis is the process plants use to convert light energy into chem...\n", + " assistant audio_path [INTERRUPTED]: 1779294342850774.mp3\n", + " user audio_path: 1779294366566679.mp3\n", + " assistant text: Absolutely! Let’s break it down step by step.\n", + "\n", + "1. 
**Where it happens**: Photosyn...\n", + " assistant audio_path: 1779294366569687.mp3\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[34m🔹 Turn 1 - USER\u001b[0m\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[34m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294342848770.mp3\u001b[0m\n", + "\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[33m Sure! 
Photosynthesis is the process plants use to convert light energy into chemical energy they can use as\u001b[0m\n", + "\u001b[33m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294342850774.mp3\u001b[0m\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[34m🔹 Turn 2 - USER\u001b[0m\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[34m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294366566679.mp3\u001b[0m\n", + "\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n", + "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n", + "\u001b[33m Absolutely! Let’s break it down step by step.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 1. **Where it happens**: Photosynthesis takes place in chloroplasts, which are specialized structures inside plant cells. These contain chlorophyll, the green pigment that captures light energy from\u001b[0m\n", + "\u001b[33m the sun.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 2. **The raw materials**: Plants use carbon dioxide from the air (taken in through tiny pores called stomata) and water from the soil (absorbed through their roots).\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 3. 
**The light-dependent reactions**: Inside the chloroplasts, chlorophyll absorbs sunlight, which excites electrons. This energy splits water molecules into oxygen, protons, and electrons. Oxygen\u001b[0m\n", + "\u001b[33m is released as a byproduct (that’s the oxygen we breathe!). The electrons and protons help generate energy-rich molecules called ATP and NADPH.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 4. **The Calvin cycle (light-independent reactions)**: Using the ATP and NADPH, plants convert carbon dioxide into glucose through a series of enzyme-driven steps. Glucose is a simple sugar that\u001b[0m\n", + "\u001b[33m plants use to build more complex carbohydrates like starch and cellulose, fueling growth and development.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m 5. **Energy storage and use**: The glucose can be used immediately for energy, or it can be stored as starch. This stored energy supports the plant’s metabolism, growth, and reproduction.\u001b[0m\n", + "\u001b[33m \u001b[0m\n", + "\u001b[33m In short, plants take in sunlight, water, and carbon dioxide, and through photosynthesis they produce oxygen and energy-rich sugars that sustain both themselves and, ultimately, life on Earth.\u001b[0m\n", + "\u001b[33m ./repos/PyRIT-internal/PyRIT/dbdata/prompt-memory-entries/audio/1779294366569687.mp3\u001b[0m\n", + "\n", + "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n" + ] + } + ], + "source": [ + "TURN1_RESPONSE_WAIT_S = 0.2 # how long to let the model start speaking before barging in\n", + "\n", + "\n", + "async def barge_in_source():\n", + " # Turn 1: speak the question, then 1.5s of silence so VAD commits.\n", + " async for chunk in _yield_chunks(question_pcm_24k):\n", + " yield chunk\n", + " for _ in range(25): # 2.5s trailing silence\n", + " yield 
SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + " # Let the model get partway into its response before we interrupt.\n", + " for _ in range(int(TURN1_RESPONSE_WAIT_S * 10)):\n", + " yield SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + " # Turn 2: speak the question again. VAD's speech_started fires while turn 1's response\n", + " # is still streaming → server cancels + truncates turn 1.\n", + " async for chunk in _yield_chunks(question_pcm_24k):\n", + " yield chunk\n", + " for _ in range(25): # 2.5s trailing silence\n", + " yield SILENCE_CHUNK\n", + " await asyncio.sleep(CHUNK_MS / 1000)\n", + "\n", + "\n", + "target2 = RealtimeTarget(server_vad=True)\n", + "attack2 = BargeInAttack(\n", + " objective_target=target2,\n", + " attack_converter_config=AttackConverterConfig(request_converters=converters),\n", + ")\n", + "\n", + "barge_in_context = BargeInAttackContext(\n", + " params=AttackParameters(objective=\"Demonstrate barge-in by interrupting a benign answer\"),\n", + " audio_chunks=barge_in_source(),\n", + ")\n", + "\n", + "barge_in_result = await attack2.execute_with_context_async(context=barge_in_context) # type: ignore\n", + "print(f\"executed_turns: {barge_in_result.executed_turns}\")\n", + "\n", + "# Inspect memory to verify the barge-in landed in metadata.\n", + "memory = CentralMemory.get_memory_instance()\n", + "turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id)\n", + "print(f\"\\nPersisted pieces ({len(turns)} messages):\")\n", + "for message in turns:\n", + " for piece in message.message_pieces:\n", + " interrupted = piece.prompt_metadata.get(\"interrupted\")\n", + " marker = \" [INTERRUPTED]\" if interrupted else \"\"\n", + " val = piece.converted_value\n", + " if piece.converted_value_data_type == \"audio_path\":\n", + " val = Path(val).name\n", + " value_preview = (val[:80] + \"...\") if len(val) > 80 else val\n", + " print(f\" {piece._role} 
{piece.converted_value_data_type}{marker}: {value_preview}\")\n", + "\n", + "await ConsoleAttackResultPrinter(width=200).print_conversation_async(result=barge_in_result) # type: ignore\n", + "await target2.cleanup_target() # type: ignore" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "### Reading the barge-in output\n", + "\n", + "If barge-in fired successfully:\n", + "- `executed_turns: 2` (two VAD-detected user turns)\n", + "- First assistant turn shows `[INTERRUPTED]` with a truncated transcript\n", + "- Second assistant turn completes normally\n", + "\n", + "If you don't see `[INTERRUPTED]`, decrease `TURN1_RESPONSE_WAIT_S` so turn 2's audio\n", + "arrives earlier in turn 1's response window." + ] + }, + { + "cell_type": "markdown", + "id": "10", + "metadata": {}, + "source": [ + "## Alternate chunk sources\n", + "\n", + "The chunk source is the main strategy hook:\n", + "\n", + "- **Pre-recorded WAV** (this notebook): most common starting point\n", + "- **TTS converter**: generate audio from text prompts dynamically\n", + "- **Live microphone**: use `sounddevice` or similar; yield what the mic produces\n", + "\n", + "For feedback-driven attacks — for example, scoring each assistant turn and choosing\n", + "to barge in with follow-up audio only when the response shows incomplete refusal —\n", + "subclass `BargeInAttack` and override `_perform_async` to interleave turn observation\n", + "with chunk generation." 
+ ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "main_language": "python" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/code/executor/attack/barge_in_attack.py b/doc/code/executor/attack/barge_in_attack.py new file mode 100644 index 0000000000..b3dff07b60 --- /dev/null +++ b/doc/code/executor/attack/barge_in_attack.py @@ -0,0 +1,208 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# --- + +# %% [markdown] +# # Barge-In Attack (Streaming Audio) +# +# `BargeInAttack` streams user audio to a `RealtimeTarget` and uses server-side voice-activity +# detection (VAD) to detect turn boundaries. When the user speaks while the assistant is still +# responding, server VAD cancels the in-flight response (barge-in). Interrupted turns are +# persisted with `prompt_metadata["interrupted"] = True`. +# +# Audio converters are applied per turn after VAD commits. The raw audio drives interruption +# timing while the model responds to the converted version. +# +# > **Note:** Memory must be initialized via `initialize_pyrit_async`. See the +# > [Memory Configuration Guide](../../memory/0_memory.md). + +# %% [markdown] +# ## Setup +# +# `BargeInAttack` requires a `RealtimeTarget` with `server_vad=True` (or a `ServerVadConfig` +# for custom tuning). 
+ +# %% +import asyncio +import wave +from pathlib import Path + +from pyrit.executor.attack import ( + AttackConverterConfig, + BargeInAttack, + BargeInAttackContext, + ConsoleAttackResultPrinter, +) +from pyrit.executor.attack.core import AttackParameters +from pyrit.memory import CentralMemory +from pyrit.prompt_converter import AudioFrequencyConverter +from pyrit.prompt_normalizer import PromptConverterConfiguration +from pyrit.prompt_target import RealtimeTarget +from pyrit.setup import IN_MEMORY, initialize_pyrit_async + +await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore + +# %% [markdown] +# ## Shared setup +# +# Both sections use a pre-recorded 24 kHz mono PCM16 question about photosynthesis. The +# format matches what the OpenAI Realtime API expects. Any async generator yielding 24 kHz +# PCM16 bytes works as a chunk source (live mic, TTS, etc.). + +# %% +CHUNK_MS = 100 +CHUNK_SIZE = CHUNK_MS * 48 # PCM16 @ 24 kHz mono = 48 bytes per millisecond. +SILENCE_CHUNK = b"\x00" * CHUNK_SIZE +audio_path = Path("../../../../assets/photosynthesis_question.wav").resolve() + + +def _load_pcm(path: Path) -> bytes: + """Read a WAV at 24 kHz / mono / PCM16 into raw PCM bytes.""" + with wave.open(str(path), "rb") as wav: + assert wav.getframerate() == 24000 and wav.getnchannels() == 1 and wav.getsampwidth() == 2 + return wav.readframes(wav.getnframes()) + + +async def _yield_chunks(pcm: bytes, real_time: bool = True): + """Yield PCM in 100ms slices, optionally pacing at real-time.""" + for offset in range(0, len(pcm), CHUNK_SIZE): + yield pcm[offset : offset + CHUNK_SIZE] + if real_time: + await asyncio.sleep(CHUNK_MS / 1000) + + +question_pcm_24k = _load_pcm(audio_path) +print(f"Loaded question: {len(question_pcm_24k) / 48 / 1000:.2f}s @ 24 kHz") + +converters = PromptConverterConfiguration.from_converters(converters=[AudioFrequencyConverter(shift_value=200)]) + + +# %% [markdown] +# ## Section 1: Single-turn streaming with a converter +# +# Streams 
async def barge_in_source():
    """Yield the question twice, timed so turn 2 interrupts turn 1's response.

    Sequence of audio produced:

    1. The question, followed by 2.5s of trailing silence so server VAD commits turn 1.
    2. ``TURN1_RESPONSE_WAIT_S`` seconds of silence, letting the model get partway
       into its response.
    3. The question again (the barge-in), followed by another 2.5s of trailing
       silence so turn 2 commits.

    Yields:
        Raw PCM chunks from ``_yield_chunks`` interleaved with ``SILENCE_CHUNK`` padding;
        every silence chunk is followed by a ``CHUNK_MS`` sleep to pace it at real time.
    """

    async def _silence(duration_ms: float):
        # Derive the chunk count from CHUNK_MS so the timing stays correct if the
        # chunk size is ever changed (previously hard-coded as 10 chunks per second).
        for _ in range(round(duration_ms / CHUNK_MS)):
            yield SILENCE_CHUNK
            await asyncio.sleep(CHUNK_MS / 1000)

    # Turn 1: speak the question, then 2.5s of trailing silence so VAD commits.
    async for chunk in _yield_chunks(question_pcm_24k):
        yield chunk
    async for chunk in _silence(2500):
        yield chunk

    # Let the model get partway into its response before we interrupt.
    async for chunk in _silence(TURN1_RESPONSE_WAIT_S * 1000):
        yield chunk

    # Turn 2: speak the question again. VAD's speech_started fires while turn 1's response
    # is still streaming → server cancels + truncates turn 1.
    async for chunk in _yield_chunks(question_pcm_24k):
        yield chunk
    async for chunk in _silence(2500):
        yield chunk
+memory = CentralMemory.get_memory_instance() +turns = memory.get_conversation(conversation_id=barge_in_result.conversation_id) +print(f"\nPersisted pieces ({len(turns)} messages):") +for message in turns: + for piece in message.message_pieces: + interrupted = piece.prompt_metadata.get("interrupted") + marker = " [INTERRUPTED]" if interrupted else "" + val = piece.converted_value + if piece.converted_value_data_type == "audio_path": + val = Path(val).name + value_preview = (val[:80] + "...") if len(val) > 80 else val + print(f" {piece._role} {piece.converted_value_data_type}{marker}: {value_preview}") + +await ConsoleAttackResultPrinter(width=200).print_conversation_async(result=barge_in_result) # type: ignore +await target2.cleanup_target() # type: ignore + +# %% [markdown] +# ### Reading the barge-in output +# +# If barge-in fired successfully: +# - `executed_turns: 2` (two VAD-detected user turns) +# - First assistant turn shows `[INTERRUPTED]` with a truncated transcript +# - Second assistant turn completes normally +# +# If you don't see `[INTERRUPTED]`, decrease `TURN1_RESPONSE_WAIT_S` so turn 2's audio +# arrives earlier in turn 1's response window. + +# %% [markdown] +# ## Alternate chunk sources +# +# The chunk source is the main strategy hook: +# +# - **Pre-recorded WAV** (this notebook): most common starting point +# - **TTS converter**: generate audio from text prompts dynamically +# - **Live microphone**: use `sounddevice` or similar; yield what the mic produces +# +# For feedback-driven attacks — for example, scoring each assistant turn and choosing +# to barge in with follow-up audio only when the response shows incomplete refusal — +# subclass `BargeInAttack` and override `_perform_async` to interleave turn observation +# with chunk generation. 
diff --git a/doc/myst.yml b/doc/myst.yml index 580b36a33a..5468672fbf 100644 --- a/doc/myst.yml +++ b/doc/myst.yml @@ -112,6 +112,7 @@ project: - file: code/executor/attack/role_play_attack.ipynb - file: code/executor/attack/skeleton_key_attack.ipynb - file: code/executor/attack/tap_attack.ipynb + - file: code/executor/attack/barge_in_attack.ipynb - file: code/executor/attack/violent_durian_attack.ipynb - file: code/executor/workflow/0_workflow.md children: diff --git a/pyrit/executor/attack/__init__.py b/pyrit/executor/attack/__init__.py index e0c4f44fc6..dc1589a1da 100644 --- a/pyrit/executor/attack/__init__.py +++ b/pyrit/executor/attack/__init__.py @@ -49,6 +49,7 @@ SingleTurnAttackStrategy, SkeletonKeyAttack, ) +from pyrit.executor.attack.streaming import BargeInAttack, BargeInAttackContext # Backward-compatibility aliases — import from pyrit.output.attack_result directly. # TODO: Remove these re-exports in two releases (target removal: 0.16.0). @@ -96,6 +97,8 @@ "ConversationState", "AttackExecutor", "AttackExecutorResult", + "BargeInAttack", + "BargeInAttackContext", "PrependedConversationConfig", "generate_simulated_conversation_async", ] diff --git a/pyrit/executor/attack/streaming/__init__.py b/pyrit/executor/attack/streaming/__init__.py new file mode 100644 index 0000000000..b743ea7961 --- /dev/null +++ b/pyrit/executor/attack/streaming/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Streaming attack strategies (barge-in over realtime audio targets).""" + +from pyrit.executor.attack.streaming.barge_in import BargeInAttack, BargeInAttackContext + +__all__ = [ + "BargeInAttack", + "BargeInAttackContext", +] diff --git a/pyrit/executor/attack/streaming/barge_in.py b/pyrit/executor/attack/streaming/barge_in.py new file mode 100644 index 0000000000..af4a7deb3b --- /dev/null +++ b/pyrit/executor/attack/streaming/barge_in.py @@ -0,0 +1,435 @@ +# Copyright (c) Microsoft Corporation. 
+# Licensed under the MIT license. + +"""Streaming barge-in attack over realtime audio targets.""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar, cast + +from pyrit.common.apply_defaults import REQUIRED_VALUE, apply_defaults +from pyrit.executor.attack.component.conversation_manager import ConversationManager +from pyrit.executor.attack.core.attack_config import AttackConverterConfig +from pyrit.executor.attack.core.attack_parameters import AttackParameters, AttackParamsT +from pyrit.executor.attack.core.attack_strategy import AttackContext, AttackStrategy +from pyrit.identifiers.atomic_attack_identifier import build_atomic_attack_identifier +from pyrit.models import ( + AttackOutcome, + AttackResult, + Message, + MessagePiece, +) +from pyrit.prompt_normalizer import PromptNormalizer +from pyrit.prompt_target.common.realtime_audio import ( + REALTIME_COMMITTED_ITEM_ID_KEY, + StreamingBargeInTarget, +) +from pyrit.prompt_target.common.target_capabilities import CapabilityName +from pyrit.prompt_target.common.target_requirements import TargetRequirements + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from pyrit.prompt_target import PromptTarget + from pyrit.prompt_target.common.realtime_audio import ( + CommittedEvent, + ) + +logger = logging.getLogger(__name__) + + +@dataclass +class BargeInAttackContext(AttackContext[AttackParamsT]): + """ + Context for a streaming barge-in attack with an audio chunk source. + + ``prepended_conversation`` (inherited from ``AttackContext``) is persisted to memory + on setup, but only the leading system message is propagated to the live realtime + session as session instructions. User / assistant turns from the prepended history + are not (yet) pushed through ``conversation.item.create``, so the model conditions + only on the system prompt plus live audio chunks. 
See follow-up issue for full + realtime-session injection. + """ + + conversation_id: str = field(default_factory=lambda: str(uuid.uuid4())) + audio_chunks: AsyncIterator[bytes] | None = None + + +@dataclass +class _BargeInRunState: + """Mutable per-session state shared between ``_perform_async`` and ``on_committed``.""" + + raw_buffer: bytearray = field(default_factory=bytearray) + turn_tasks: list[asyncio.Task[None]] = field(default_factory=list) + # Session-time (in ms) at which the current buffer started accumulating. Used to + # convert the server's session-relative ``audio_start_ms`` into a buffer-relative + # offset for trimming. 0 at session start; advances by ``audio_end_ms`` of each + # commit, but since the server omits ``audio_end_ms`` we approximate it as + # ``audio_start_ms + buffer_speech_duration``. In practice we just track the most + # recent commit's reported start so the next turn's trim is relative to it. + buffer_start_session_ms: int = 0 + + +def _trim_snapshot_to_speech( + *, + raw_buffer: bytes, + sample_rate_hz: int, + audio_start_ms: int | None, + prefix_padding_ms: int, + sample_width_bytes: int = 2, + channels: int = 1, +) -> bytes: + """ + Trim leading pre-speech silence from a raw mic snapshot. + + Server VAD reports where speech began via ``audio_start_ms``. The local + accumulator captures every chunk pushed since the last commit — including + seconds of pre-speech silence — so without a trim the converted audio that + gets swapped into the server's committed item would be much longer than + what the server actually committed, causing the model to hear leading silence. + + Args: + raw_buffer: PCM16 mono audio for the current buffer (all bytes pushed since the last commit). + sample_rate_hz: PCM sample rate in Hz. + audio_start_ms: Server's ``audio_start_ms`` offset, or None when unknown. 
+ prefix_padding_ms: Bytes to keep before ``audio_start_ms`` so we don't chop the speech onset + (typically matches server VAD's ``prefix_padding_ms``). + sample_width_bytes: Bytes per sample (2 for PCM16). + channels: Audio channels (1 for mono). + + Returns: + The trimmed buffer; returns ``raw_buffer`` unchanged when ``audio_start_ms`` + is None or 0, or when the computed trim would leave nothing. + + Raises: + ValueError: If ``audio_start_ms`` is negative, or if ``sample_rate_hz``, + ``sample_width_bytes``, or ``channels`` is not positive. + """ + if sample_rate_hz <= 0 or sample_width_bytes <= 0 or channels <= 0: + raise ValueError( + f"sample_rate_hz, sample_width_bytes, and channels must all be positive; " + f"got sample_rate_hz={sample_rate_hz}, sample_width_bytes={sample_width_bytes}, channels={channels}" + ) + if audio_start_ms is None or audio_start_ms == 0: + return raw_buffer + if audio_start_ms < 0: + raise ValueError(f"audio_start_ms must be >= 0, got {audio_start_ms}") + bytes_per_ms = sample_rate_hz * sample_width_bytes * channels // 1000 + start_ms = max(0, audio_start_ms - prefix_padding_ms) + start_byte = start_ms * bytes_per_ms + # Align to sample frame boundary so the trimmed buffer doesn't start mid-sample. + frame_bytes = sample_width_bytes * channels + start_byte -= start_byte % frame_bytes + if start_byte >= len(raw_buffer): + return raw_buffer + return raw_buffer[start_byte:] + + +class BargeInAttack(AttackStrategy["BargeInAttackContext[Any]", AttackResult]): + """ + Streaming attack that drives a Realtime API session with server VAD + barge-in. + + The attack pushes user audio chunks through the target, lets server VAD detect + turn boundaries, manually fires ``response.create`` after each commit, and + observes assistant turns (including interrupted ones) via per-turn futures + returned by the target's ``request_response_async``. 
+ """ + + TARGET_REQUIREMENTS: ClassVar[TargetRequirements] = TargetRequirements( + required=frozenset({CapabilityName.STREAMING_BARGE_IN}), + ) + + #: Default maximum time to wait after the chunk source exhausts for any in-flight + #: VAD-committed turn to finish (commit → convert → response.create → response.done + #: → persist). Acts as a safety cap; the attack returns as soon as the last turn + #: actually completes. Overridable per-instance via ``max_post_stream_wait_seconds``. + DEFAULT_MAX_POST_STREAM_WAIT_SECONDS: ClassVar[float] = 60.0 + + @apply_defaults + def __init__( + self, + *, + objective_target: PromptTarget = REQUIRED_VALUE, # type: ignore[ty:invalid-parameter-default] + attack_converter_config: AttackConverterConfig | None = None, + prompt_normalizer: PromptNormalizer | None = None, + max_post_stream_wait_seconds: float = DEFAULT_MAX_POST_STREAM_WAIT_SECONDS, + params_type: type[AttackParamsT] = AttackParameters, # type: ignore[ty:invalid-parameter-default] + ) -> None: + """ + Initialize the streaming barge-in attack. + + Args: + objective_target: Target to attack. Must support ``STREAMING_BARGE_IN`` capability. + attack_converter_config: Converters applied to each committed user turn. + prompt_normalizer: Normalizer used to apply converters and persist messages. + Defaults to a fresh ``PromptNormalizer``. + max_post_stream_wait_seconds: Safety cap on the wait between the chunk source + exhausting and the last in-flight turn finishing. Defaults to 60 seconds. + Bump if a long realtime response is being cancelled at teardown. + params_type: Attack parameter dataclass type. 
+ """ + super().__init__( + objective_target=objective_target, + context_type=BargeInAttackContext, + params_type=params_type, + logger=logger, + ) + attack_converter_config = attack_converter_config or AttackConverterConfig() + self._request_converters = attack_converter_config.request_converters + self._response_converters = attack_converter_config.response_converters + self._prompt_normalizer = prompt_normalizer or PromptNormalizer() + self._conversation_manager = ConversationManager( + attack_identifier=self.get_identifier(), + prompt_normalizer=self._prompt_normalizer, + ) + self._max_post_stream_wait_seconds = max_post_stream_wait_seconds + + def _validate_context(self, *, context: BargeInAttackContext[Any]) -> None: + """ + Validate the context before executing. + + Args: + context: The streaming attack context. + + Raises: + ValueError: If the context is missing required fields. + """ + if not context.objective or context.objective.isspace(): + raise ValueError("Attack objective must be provided and non-empty in the context") + if context.audio_chunks is None: + raise ValueError("BargeInAttackContext.audio_chunks must be set to an async iterator of PCM bytes") + + async def _setup_async(self, *, context: BargeInAttackContext[Any]) -> None: + """ + Set up the attack: ensure a conversation id and initialize prepended conversation. + + Merges memory labels and persists ``context.prepended_conversation`` to memory via + ``ConversationManager`` so streaming attacks share the same memory contract as + non-streaming attacks. Note: prepended messages are recorded in memory but are NOT + pushed into the live realtime session beyond the system prompt — the model only + conditions on the system message and live audio chunks. Pushing prepended user / + assistant turns into the websocket session via ``conversation.item.create`` is + tracked as a follow-up. 
+ """ + if not context.conversation_id: + context.conversation_id = str(uuid.uuid4()) + await self._conversation_manager.initialize_context_async( + context=context, + target=self._objective_target, + conversation_id=context.conversation_id, + request_converters=self._request_converters, + ) + + async def _teardown_async(self, *, context: BargeInAttackContext[Any]) -> None: + """No-op teardown — connection / dispatcher are closed inside ``_perform_async``.""" + return + + async def _perform_async(self, *, context: BargeInAttackContext[Any]) -> AttackResult: + """ + Run the streaming session: connect, subscribe, push chunks, await final turn, tear down. + + Args: + context: Streaming attack context with ``audio_chunks`` source. + + Returns: + An ``AttackResult`` capturing the last assistant turn (if any) and the + number of completed turns. + + Raises: + ValueError: If ``context.audio_chunks`` is ``None``. + """ + target = cast("StreamingBargeInTarget", self._objective_target) + if context.audio_chunks is None: + raise ValueError("BargeInAttackContext.audio_chunks must be set before executing the attack.") + + connection = await target.connect_async(conversation_id=context.conversation_id) + state = _BargeInRunState() + last_response: Message | None = None + executed_turns = 0 + + async def on_committed(event: CommittedEvent) -> None: + nonlocal last_response, executed_turns + current_task = asyncio.current_task() + if current_task is not None: + state.turn_tasks.append(current_task) + try: + response = await self._handle_committed_turn_async( + event=event, + context=context, + state=state, + target=target, + ) + last_response = response + executed_turns += 1 + except Exception: + logger.exception("BargeInAttack turn failed in convert-on-commit handler.") + + await target.subscribe_events_async( + connection=connection, + conversation_id=context.conversation_id, + on_user_audio_committed=on_committed, + ) + + try: + await target.send_streaming_session_config_async( 
+ connection=connection, conversation=context.prepended_conversation + ) + + async for chunk in context.audio_chunks: + if chunk: + state.raw_buffer.extend(chunk) + await target.push_audio_chunk_async(connection=connection, pcm_bytes=chunk) + + # Wait for any in-flight committed-turn tasks to finish, capped by a safety timeout. + # The chunk source must end with enough trailing silence for server VAD's silence + # threshold to fire commit — otherwise the last turn never enters the pipeline. + await self._wait_for_pending_turns_async(state.turn_tasks) + finally: + await target.cleanup_conversation(context.conversation_id) + + return self._build_result( + last_response=last_response, + executed_turns=executed_turns, + context=context, + ) + + async def _handle_committed_turn_async( + self, + *, + event: CommittedEvent, + context: BargeInAttackContext[Any], + state: _BargeInRunState, + target: StreamingBargeInTarget, + ) -> Message: + """ + Run one convert-and-respond turn for a VAD-committed user audio buffer. + + Snapshots the locally-accumulated raw PCM, persists it as a durable WAV, + wraps it in a Message with the server's committed item id stashed in + ``prompt_metadata`` so the target's streaming branch can swap raw audio + for converter-transformed audio, then drives ``send_prompt_async``. + + Returns: + The assistant Message returned by ``send_prompt_async`` for this turn. + """ + # Snapshot the locally-accumulated raw PCM and reset for the next turn. + snapshot = bytes(state.raw_buffer) + state.raw_buffer.clear() + + # Convert the server's session-relative audio_start_ms into a buffer-relative + # offset, then trim leading pre-speech silence. Without this, the converted + # audio that gets swapped into the server's committed item is several seconds + # longer than what server VAD actually committed, and the model hears the + # leading silence (often dominant) when converters are active. 
+ bytes_per_ms = target.SAMPLE_RATE_HZ * 2 // 1000 # PCM16 mono + original_buffer_duration_ms = len(snapshot) // bytes_per_ms if bytes_per_ms else 0 + + buffer_relative_audio_start_ms: int | None = None + if event.audio_start_ms is not None: + buffer_relative_audio_start_ms = event.audio_start_ms - state.buffer_start_session_ms + + server_vad = target.server_vad_config + prefix_padding_ms = server_vad.prefix_padding_ms if server_vad is not None else 0 + snapshot = _trim_snapshot_to_speech( + raw_buffer=snapshot, + sample_rate_hz=target.SAMPLE_RATE_HZ, + audio_start_ms=buffer_relative_audio_start_ms, + prefix_padding_ms=prefix_padding_ms, + ) + + # Advance session-time bookkeeping for the next turn. Uses the ORIGINAL (pre-trim) + # buffer duration since the server saw every byte we pushed. + state.buffer_start_session_ms += original_buffer_duration_ms + + # PromptNormalizer.send_prompt_async needs an audio_path-shaped Message, + # so persist the snapshot to a durable WAV before wrapping. + snapshot_path = await target.save_audio( + snapshot, + num_channels=1, + sample_width=2, + sample_rate=target.SAMPLE_RATE_HZ, + ) + # Stash the server-assigned item id so the target's streaming branch + # can swap the raw buffer for converter-transformed audio. 
+ piece = MessagePiece( + role="user", + original_value=snapshot_path, + original_value_data_type="audio_path", + converted_value=snapshot_path, + converted_value_data_type="audio_path", + conversation_id=context.conversation_id, + prompt_metadata={REALTIME_COMMITTED_ITEM_ID_KEY: event.item_id}, + ) + message = Message(message_pieces=[piece]) + + return await self._prompt_normalizer.send_prompt_async( + message=message, + target=self._objective_target, + request_converter_configurations=self._request_converters, + response_converter_configurations=self._response_converters, + conversation_id=context.conversation_id, + attack_identifier=self.get_identifier(), + ) + + def _build_result( + self, + *, + last_response: Message | None, + executed_turns: int, + context: BargeInAttackContext[Any], + ) -> AttackResult: + """ + Assemble the final ``AttackResult`` from accumulated turn outcomes. + + Returns: + ``AttackResult`` with the last assistant message, executed turn count, + and outcome reason. + """ + if executed_turns == 0: + outcome_reason: str | None = "No assistant turns completed (server VAD did not commit any user audio)" + else: + outcome_reason = f"{executed_turns} assistant turn(s) completed; no scorer configured" + + return AttackResult( + conversation_id=context.conversation_id, + objective=context.objective, + atomic_attack_identifier=build_atomic_attack_identifier(attack_identifier=self.get_identifier()), + last_response=(last_response.message_pieces[0] if last_response else None), + last_score=None, + related_conversations=context.related_conversations, + outcome=AttackOutcome.UNDETERMINED, + outcome_reason=outcome_reason, + executed_turns=executed_turns, + labels=context.memory_labels, + ) + + async def _wait_for_pending_turns_async(self, turn_tasks: list[asyncio.Task[None]]) -> None: + """ + Wait for any in-flight VAD-committed turn tasks to finish, with a safety timeout. 
+ + Returns as soon as all known turn tasks complete (or the cap elapses, whichever + comes first). The timeout is a safety net for stuck turns; the common case is to + return immediately once the last turn's persistence finishes. + + Args: + turn_tasks: Task handles for every ``on_committed`` invocation launched so far. + Tasks added after this method starts are not waited on; the dispatcher + callback machinery makes this race vanishingly unlikely in practice. + """ + if not turn_tasks: + return + try: + await asyncio.wait_for( + asyncio.gather(*turn_tasks, return_exceptions=True), + timeout=self._max_post_stream_wait_seconds, + ) + except asyncio.TimeoutError: + logger.warning( + f"Timed out after {self._max_post_stream_wait_seconds}s waiting for in-flight turn tasks to " + "finish; teardown will cancel them. Raise max_post_stream_wait_seconds on the attack " + "constructor if responses regularly take longer." + ) diff --git a/pyrit/prompt_normalizer/__init__.py b/pyrit/prompt_normalizer/__init__.py index fa030605f7..04980a08d7 100644 --- a/pyrit/prompt_normalizer/__init__.py +++ b/pyrit/prompt_normalizer/__init__.py @@ -13,7 +13,7 @@ from pyrit.prompt_normalizer.prompt_normalizer import PromptNormalizer __all__ = [ - "PromptNormalizer", - "PromptConverterConfiguration", "NormalizerRequest", + "PromptConverterConfiguration", + "PromptNormalizer", ] diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index b4dbc0b0a8..b0d42c9a76 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -19,6 +19,7 @@ ) from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.realtime_audio import ServerVadConfig from pyrit.prompt_target.common.target_capabilities import ( CapabilityHandlingPolicy, CapabilityName, @@ -102,6 +103,7 @@ def __getattr__(name: str) -> object: "PromptShieldTarget", "PromptTarget", 
"RealtimeTarget", + "ServerVadConfig", "RoundRobinTarget", "TargetCapabilities", "TargetConfiguration", diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py index 45600e6009..f872e2b880 100644 --- a/pyrit/prompt_target/common/discover_target_capabilities.py +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -149,6 +149,7 @@ def _permissive_configuration( supports_json_output=True, supports_editable_history=True, supports_system_prompt=True, + supports_streaming_barge_in=True, input_modalities=merged_modalities, ) # Rebuild a fresh configuration from the instance's native capabilities so diff --git a/pyrit/prompt_target/common/realtime_audio.py b/pyrit/prompt_target/common/realtime_audio.py new file mode 100644 index 0000000000..e99d1061bb --- /dev/null +++ b/pyrit/prompt_target/common/realtime_audio.py @@ -0,0 +1,304 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Shared types for realtime audio prompt targets.""" + +import asyncio +import contextlib +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +logger = logging.getLogger(__name__) + + +#: Key under which streaming attacks stash the server-assigned item id of the most +#: recently committed user audio buffer (in ``MessagePiece.prompt_metadata``). Realtime +#: targets read this key in their streaming branch to identify which committed item +#: to delete when swapping in converter-transformed audio. Exposed as a public +#: constant so attacks can reference it without reaching into target internals. 
@dataclass(frozen=True)
class ServerVadConfig:
    """
    Immutable tuning knobs for server-side voice activity detection (VAD).

    All three values are validated on construction.
    """

    # Detection threshold; must lie in the closed interval [0.0, 1.0].
    threshold: float = 0.4
    # Milliseconds of padding; must be non-negative.
    prefix_padding_ms: int = 200
    # Milliseconds of silence; must be non-negative.
    silence_duration_ms: int = 1500

    def __post_init__(self) -> None:
        """
        Reject out-of-range tuning values.

        Raises:
            ValueError: If ``threshold`` is outside [0.0, 1.0] or either millisecond
                field is negative.
        """
        in_unit_interval = 0.0 <= self.threshold <= 1.0
        if not in_unit_interval:
            raise ValueError(f"threshold must be in [0.0, 1.0], got {self.threshold}")
        # Both duration fields share the same non-negativity rule and message shape.
        for field_name in ("prefix_padding_ms", "silence_duration_ms"):
            value = getattr(self, field_name)
            if value < 0:
                raise ValueError(f"{field_name} must be non-negative, got {value}")
+ + Provider-specific event routing and cancel logic are isolated to the abstract methods. + """ + + def __init__( + self, + *, + connection: Any, + on_user_audio_committed: Callable[[CommittedEvent], Coroutine[Any, Any, None]] | None = None, + ) -> None: + """ + Args: + connection: An open realtime connection exposing an async iterator + of server events. The dispatcher owns reading from it. + on_user_audio_committed: Optional callback fired when the server + commits a user audio buffer (e.g. server VAD finalizing a turn). + Invoked as a background task so converter work in the callback + does not block the dispatch loop. Default None disables it. + """ + self._connection = connection + self._on_user_audio_committed = on_user_audio_committed + self._current_turn: RealtimeTurnState | None = None + self._task: asyncio.Task[None] | None = None + self._callback_tasks: set[asyncio.Task[None]] = set() + self._failure: BaseException | None = None + # Server VAD reports audio_start_ms on speech_started but omits it from + # input_audio_buffer.committed. Concrete subclasses capture it here when + # speech_started fires and read it back on commit. + self._pending_speech_start_ms: int | None = None + + @property + def failure(self) -> BaseException | None: + """ + The exception that killed the dispatch loop, or None if it is still healthy. + + Set when the outer event iterator raises. Callers (e.g. ``BargeInAttack``) + poll this between operations to detect a dead connection without needing a + callback. Once set, ``stop()`` should be called and the attack torn down. + """ + return self._failure + + async def start(self) -> None: + """Start the background dispatch task. Idempotent.""" + if self._task is None: + self._task = asyncio.create_task(self._dispatch_loop()) + + async def stop(self) -> None: + """ + Cancel the background dispatch task and release the reference. 
+ + In-flight callback tasks are cancelled and awaited (with exception + suppression) so they don't deadlock waiting on the turn future that the + now-dead dispatch loop would have resolved. + """ + if self._task is not None: + self._task.cancel() + with contextlib.suppress(asyncio.CancelledError, Exception): + await self._task + self._task = None + if self._callback_tasks: + pending = list(self._callback_tasks) + self._callback_tasks.clear() + for task in pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + def register_turn(self, state: RealtimeTurnState) -> None: + """ + Bind a new turn as the active turn. + + Args: + state (RealtimeTurnState): The turn whose completion future will be + resolved when this turn ends. + + Raises: + RuntimeError: If another turn is already active on this dispatcher. + """ + if self._current_turn is not None and not self._current_turn.completion.done(): + raise RuntimeError("Another turn is already active on this dispatcher") + self._current_turn = state + + async def _dispatch_loop(self) -> None: + """ + Consume events from the connection and route each to the active turn. + + The router is called for every event with the current turn (which may + be None during the gap between turns). Concrete routers are expected to + handle ``state is None`` for input-side events that need no turn state + and return early on output-side events when no turn is registered. + + Raises: + asyncio.CancelledError: Propagated when ``stop()`` cancels the task. 
+ """ + try: + async for event in self._connection: + turn = self._current_turn + if turn is not None and turn.completion.done(): + turn = None + try: + await self._route_event(event=event, state=turn) + except Exception as e: + logger.exception(f"Realtime event router raised: {e}") + if turn is not None and not turn.completion.done(): + turn.completion.set_exception(e) + except asyncio.CancelledError: + raise + except Exception as e: + logger.exception(f"Realtime dispatch loop crashed: {e}") + self._failure = e + turn = self._current_turn + if turn is not None and not turn.completion.done(): + turn.completion.set_exception(e) + + def _fire_committed_callback(self, event: CommittedEvent) -> None: + """ + Schedule the ``on_user_audio_committed`` callback as a background task. + + Tracks the resulting task so ``stop()`` can wait for it to finish. + """ + if self._on_user_audio_committed is None: + return + task = asyncio.create_task(self._on_user_audio_committed(event)) + self._callback_tasks.add(task) + task.add_done_callback(self._callback_tasks.discard) + + @abstractmethod + async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None: + """ + Route a single provider-specific event. + + Concrete implementations: + - When the event is output-side (response lifecycle, audio/transcript + deltas, etc.) and ``state`` is non-None, mutate ``state`` and resolve + ``state.completion`` at end-of-turn or on interruption. + - When ``state`` is None (no active turn) or + ``state.completion.done()``, output-side events should be dropped. + - When the event is input-side (e.g. ``input_audio_buffer.committed``), + fire any subscribed callback via ``self._fire_committed_callback(...)``. + These callbacks may run regardless of ``state``. + - On error events, resolve ``state.completion`` via ``set_exception`` + when a turn is active. + + Args: + event: A single provider-specific event from the connection iterator. 
@runtime_checkable
class StreamingBargeInTarget(Protocol):
    """
    Provider-agnostic surface a streaming barge-in attack requires of its target.

    Captures the methods and attributes ``BargeInAttack`` reads from its objective
    target so the attack can be typed against this Protocol rather than a concrete
    provider class (e.g. ``RealtimeTarget``). A second realtime provider could
    implement this Protocol without subclassing ``RealtimeTarget``.

    The class is ``@runtime_checkable``, so ``isinstance`` checks against it are
    permitted; such checks verify member presence only, not signatures.
    """

    #: PCM sample rate in Hz negotiated by the provider's realtime protocol.
    SAMPLE_RATE_HZ: int

    @property
    def server_vad_config(self) -> "ServerVadConfig | None":
        """Server VAD configuration in effect, or None if server VAD is disabled."""
        ...

    async def connect_async(self, conversation_id: str) -> Any:
        """Open the realtime connection for ``conversation_id`` and return the connection handle."""
        ...

    async def subscribe_events_async(
        self,
        *,
        connection: Any,
        conversation_id: str,
        on_user_audio_committed: Callable[[CommittedEvent], Coroutine[Any, Any, None]] | None = None,
    ) -> "RealtimeEventDispatcher":
        """
        Spawn a background reader that routes server events and returns the dispatcher.

        Args:
            connection: Handle previously returned by ``connect_async``.
            conversation_id: Conversation the subscription belongs to.
            on_user_audio_committed: Optional coroutine callback that receives a
                ``CommittedEvent``; None disables it.
        """
        ...

    async def send_streaming_session_config_async(
        self, *, connection: Any, conversation: list[Any] | None = None
    ) -> None:
        """Send the initial ``session.update`` over the wire (system prompt, VAD config, etc.)."""
        ...

    async def push_audio_chunk_async(self, *, connection: Any, pcm_bytes: bytes) -> None:
        """Push a PCM16 audio chunk into the server's input buffer."""
        ...

    # NOTE: the ``sample_rate`` default below (16000) is independent of ``SAMPLE_RATE_HZ``;
    # callers that save audio captured at the negotiated rate should pass it explicitly.
    async def save_audio(
        self,
        audio_bytes: bytes,
        num_channels: int = 1,
        sample_width: int = 2,
        sample_rate: int = 16000,
        output_filename: str | None = None,
    ) -> str:
        """Persist a PCM buffer to disk and return the file path."""
        ...

    async def cleanup_conversation(self, conversation_id: str) -> None:
        """Tear down any per-conversation state held by the target."""
        ...
+ supports_streaming_barge_in: bool = False + # The input modalities supported by the target (e.g., "text", "image"). input_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])}) diff --git a/pyrit/prompt_target/openai/openai_realtime_target.py b/pyrit/prompt_target/openai/openai_realtime_target.py index 3deffe6287..b88dceb11e 100644 --- a/pyrit/prompt_target/openai/openai_realtime_target.py +++ b/pyrit/prompt_target/openai/openai_realtime_target.py @@ -6,8 +6,9 @@ import logging import re import wave +from collections.abc import Callable, Coroutine from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Any, ClassVar, Literal, Optional from openai import AsyncOpenAI @@ -22,6 +23,14 @@ data_serializer_factory, ) from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.realtime_audio import ( + REALTIME_COMMITTED_ITEM_ID_KEY, + CommittedEvent, + RealtimeEventDispatcher, + RealtimeTargetResult, + RealtimeTurnState, + ServerVadConfig, +) from pyrit.prompt_target.common.target_capabilities import TargetCapabilities from pyrit.prompt_target.common.target_configuration import TargetConfiguration from pyrit.prompt_target.common.utils import limit_requests_per_minute @@ -36,26 +45,19 @@ @dataclass -class RealtimeTargetResult: +class _StreamingConversationState: """ - Represents the result of a Realtime API request, containing audio data and transcripts. + Per-conversation streaming-mode bookkeeping for :class:`RealtimeTarget`. - Attributes: - audio_bytes: Raw audio data returned by the API - transcripts: List of text transcripts generated from the audio + Presence in :attr:`RealtimeTarget._streaming_state` is the signal that a + conversation should take the streaming swap-and-respond path inside + :meth:`RealtimeTarget._send_prompt_to_target_async` rather than the atomic + send_audio / send_text path. 
The lock serializes per-turn work so back-to-back + VAD commits cannot race on the dispatcher's single active-turn slot. """ - audio_bytes: bytes = field(default_factory=lambda: b"") - transcripts: list[str] = field(default_factory=list) - - def flatten_transcripts(self) -> str: - """ - Flattens the list of transcripts into a single string. - - Returns: - A single string containing all transcripts concatenated together. - """ - return "".join(self.transcripts) + dispatcher: RealtimeEventDispatcher + turn_lock: asyncio.Lock = field(default_factory=asyncio.Lock) class RealtimeTarget(OpenAITarget, PromptTarget): @@ -75,6 +77,7 @@ class RealtimeTarget(OpenAITarget, PromptTarget): supports_editable_history=True, supports_multi_message_pieces=True, supports_system_prompt=True, + supports_streaming_barge_in=True, input_modalities=frozenset( { frozenset(["text"]), @@ -91,12 +94,18 @@ class RealtimeTarget(OpenAITarget, PromptTarget): ) ) + #: Sample rate (Hz) for all PCM16 audio exchanged with the Realtime API. + #: The Realtime API negotiates 24 kHz; callers (streaming attacks, audio + #: helpers, normalizers) should read this rather than hard-coding 24000. + SAMPLE_RATE_HZ: ClassVar[int] = 24000 + def __init__( self, *, voice: Optional[RealTimeVoice] = None, existing_convo: Optional[dict[str, Any]] = None, custom_configuration: Optional[TargetConfiguration] = None, + server_vad: bool | ServerVadConfig = False, **kwargs: Any, ) -> None: """ @@ -120,6 +129,11 @@ def __init__( existing_convo (dict[str, websockets.WebSocketClientProtocol], Optional): Existing conversations. custom_configuration (TargetConfiguration, Optional): Override the default configuration for this target instance. Defaults to None. + server_vad (bool | ServerVadConfig): Server-side voice activity detection (VAD). + ``False`` (default) keeps the existing atomic send/receive behavior. + ``True`` enables VAD with default tuning. + Pass a ``ServerVadConfig`` to enable with custom tuning. 
Streaming/interruption plumbing + arrives in subsequent changes; this currently only affects the emitted session config. **kwargs: Additional keyword arguments passed to the parent OpenAITarget class. httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the ``httpx.AsyncClient()`` constructor. For example, to specify a 3 minute timeout: ``httpx_client_kwargs={"timeout": 180}`` @@ -130,6 +144,24 @@ def __init__( self._existing_conversation = existing_convo if existing_convo is not None else {} self._realtime_client: Optional[AsyncOpenAI] = None + if isinstance(server_vad, ServerVadConfig): + self._server_vad: Optional[ServerVadConfig] = server_vad + elif server_vad: + self._server_vad = ServerVadConfig() + else: + self._server_vad = None + + # Streaming-mode bookkeeping. Entries are added by ``subscribe_events_async`` and + # consumed by the streaming branch of ``_send_prompt_to_target_async``. The + # presence of a conversation_id key signals "this conversation is in streaming + # mode" so the target can route requests to the swap-and-respond path. + self._streaming_state: dict[str, _StreamingConversationState] = {} + + @property + def server_vad_config(self) -> ServerVadConfig | None: + """Server VAD configuration in effect for this target, or None if server VAD is disabled.""" + return self._server_vad + def _set_openai_env_configuration_vars(self) -> None: self.model_name_environment_variable = "OPENAI_REALTIME_MODEL" self.endpoint_environment_variable = "OPENAI_REALTIME_ENDPOINT" @@ -241,7 +273,7 @@ def _get_openai_client(self) -> AsyncOpenAI: return self._realtime_client - async def connect(self, conversation_id: str) -> Any: + async def connect_async(self, conversation_id: str) -> Any: """ Connect to Realtime API using AsyncOpenAI client and return the realtime connection. 
@@ -278,18 +310,28 @@ def _set_system_prompt_and_config_vars(self, system_prompt: str) -> dict[str, An }, "format": { "type": "audio/pcm", - "rate": 24000, + "rate": self.SAMPLE_RATE_HZ, }, }, "output": { "format": { "type": "audio/pcm", - "rate": 24000, + "rate": self.SAMPLE_RATE_HZ, } }, }, } + if self._server_vad is not None: + session_config["audio"]["input"]["turn_detection"] = { # type: ignore[ty:invalid-assignment] + "type": "server_vad", + "threshold": self._server_vad.threshold, + "prefix_padding_ms": self._server_vad.prefix_padding_ms, + "silence_duration_ms": self._server_vad.silence_duration_ms, + "create_response": True, + "interrupt_response": True, + } + if self.voice: session_config["audio"]["output"]["voice"] = self.voice # type: ignore[ty:invalid-assignment] @@ -345,6 +387,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me """ Asynchronously send a message to the OpenAI realtime target. + Routes to the streaming swap-and-respond path when streaming state is + registered for the conversation; otherwise dispatches to the atomic + send_audio / send_text path. 
+ Args: normalized_conversation (list[Message]): The full conversation (history + current message) after running the normalization @@ -358,32 +404,85 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me """ message = normalized_conversation[-1] conversation_id = message.message_pieces[0].conversation_id - if conversation_id not in self._existing_conversation: - connection = await self.connect(conversation_id=conversation_id) - self._existing_conversation[conversation_id] = connection - - # Only send config when creating a new connection - await self.send_config(conversation_id=conversation_id, conversation=normalized_conversation) - # Give the server a moment to process the session update - await asyncio.sleep(0.5) - request = message.message_pieces[0] - response_type = request.converted_value_data_type - # Order of messages sent varies based on the data format of the prompt - if response_type == "audio_path": - output_audio_path, result = await self.send_audio_async( - filename=request.converted_value, - conversation_id=conversation_id, - ) + streaming = self._streaming_state.get(conversation_id) + if streaming is not None: + # Streaming swap-and-respond path. The lock serializes per-turn work so + # back-to-back VAD commits cannot race on the dispatcher's single turn slot. + async with streaming.turn_lock: + if request.converted_value_data_type != "audio_path": + raise ValueError( + f"Streaming realtime requests must carry audio_path, got {request.converted_value_data_type!r}." 
+ ) - elif response_type == "text": - output_audio_path, result = await self.send_text_async( - text=request.converted_value, - conversation_id=conversation_id, - ) + connection = self._existing_conversation[conversation_id] + + with wave.open(request.converted_value, "rb") as wav_in: + if ( + wav_in.getnchannels() != 1 + or wav_in.getsampwidth() != 2 + or wav_in.getframerate() != self.SAMPLE_RATE_HZ + ): + raise ValueError( + f"Streaming audio must be mono PCM16 at {self.SAMPLE_RATE_HZ} Hz, got " + f"channels={wav_in.getnchannels()} sampwidth={wav_in.getsampwidth()} " + f"rate={wav_in.getframerate()}." + ) + pcm_bytes = wav_in.readframes(wav_in.getnframes()) + + # Only swap when converters ran. Otherwise the server's raw committed + # buffer is what we want and a swap would be wasted work. + if request.converter_identifiers: + item_id = request.prompt_metadata.get(REALTIME_COMMITTED_ITEM_ID_KEY) + if not item_id: + raise ValueError( + "Streaming request with converters requires the server's committed " + f"item id in piece.prompt_metadata[{REALTIME_COMMITTED_ITEM_ID_KEY!r}]." 
+ ) + await self.swap_user_audio_async( + connection=connection, + committed_event=CommittedEvent(item_id=str(item_id)), + converted_pcm=pcm_bytes, + ) + + turn_future = await self.request_response_async( + connection=connection, + dispatcher=streaming.dispatcher, + ) + result: RealtimeTargetResult = await turn_future + output_audio_path = await self.save_audio( + result.audio_bytes, + num_channels=1, + sample_width=2, + sample_rate=self.SAMPLE_RATE_HZ, + ) else: - raise ValueError(f"Unsupported response type: {response_type}") + if conversation_id not in self._existing_conversation: + connection = await self.connect_async(conversation_id=conversation_id) + self._existing_conversation[conversation_id] = connection + + # Only send config when creating a new connection + await self.send_config(conversation_id=conversation_id, conversation=normalized_conversation) + # Give the server a moment to process the session update + await asyncio.sleep(0.5) + + response_type = request.converted_value_data_type + + # Order of messages sent varies based on the data format of the prompt + if response_type == "audio_path": + output_audio_path, result = await self.send_audio_async( + filename=request.converted_value, + conversation_id=conversation_id, + ) + + elif response_type == "text": + output_audio_path, result = await self.send_text_async( + text=request.converted_value, + conversation_id=conversation_id, + ) + else: + raise ValueError(f"Unsupported response type: {response_type}") text_response_piece = construct_response_from_request( request=request, response_text_pieces=[result.flatten_transcripts()], response_type="text" @@ -393,6 +492,10 @@ async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Me request=request, response_text_pieces=[output_audio_path], response_type="audio_path" ).message_pieces[0] + if result.interrupted: + text_response_piece.prompt_metadata["interrupted"] = True + audio_response_piece.prompt_metadata["interrupted"] = True + 
response_entry = Message(message_pieces=[text_response_piece, audio_response_piece]) return [response_entry] @@ -432,7 +535,18 @@ async def save_audio( async def cleanup_target(self) -> None: """ Disconnects from the Realtime API connections. + + Stops any active streaming dispatchers before closing their underlying + websocket connections so the dispatch loops do not race with connection + shutdown. Safe to call multiple times. """ + for cid, streaming in list(self._streaming_state.items()): + try: + await streaming.dispatcher.stop() + except Exception as e: + logger.warning(f"Error stopping dispatcher for {cid}: {e}") + self._streaming_state = {} + for conversation_id, connection in list(self._existing_conversation.items()): if connection: try: @@ -453,10 +567,19 @@ async def cleanup_conversation(self, conversation_id: str) -> None: """ Disconnects from the Realtime API for a specific conversation. + Stops any active streaming dispatcher for the conversation before closing + the underlying connection. Safe to call when no streaming state exists. + Args: conversation_id (str): The conversation ID to disconnect from. - """ + streaming = self._streaming_state.pop(conversation_id, None) + if streaming is not None: + try: + await streaming.dispatcher.stop() + except Exception as e: + logger.warning(f"Error stopping dispatcher for {conversation_id}: {e}") + connection = self._existing_conversation.get(conversation_id) if connection: try: @@ -476,6 +599,245 @@ async def send_response_create(self, conversation_id: str) -> None: connection = self._get_connection(conversation_id=conversation_id) await connection.response.create() + async def push_audio_chunk_async(self, *, connection: Any, pcm_bytes: bytes) -> None: + """ + Append a single PCM16 mono @ 24 kHz audio chunk to the server's input buffer. + + Used by streaming-style callers (e.g. ``BargeInAttack``) that source chunks + from an iterator and want to control commit timing externally. 
Server VAD, + when enabled on the session, decides when to commit and fire response logic. + Empty buffers are accepted as no-ops. + + Args: + connection: Active Realtime API connection from ``self.connect_async()``. + pcm_bytes: Raw PCM16 mono audio for this chunk. + """ + if not pcm_bytes: + return + audio_b64 = base64.b64encode(pcm_bytes).decode("ascii") + await connection.input_audio_buffer.append(audio=audio_b64) + + async def insert_user_audio_async(self, *, connection: Any, pcm_bytes: bytes) -> None: + """ + Insert a user message containing the given PCM16 mono @ 24 kHz audio into the conversation. + + Use for the convert-on-commit dance — the attack inserts the converted audio via + this method, best-effort deletes the server's raw user item, and then manually triggers + ``response.create``. + + Args: + connection: Active Realtime API connection. + pcm_bytes: Converted PCM16 mono audio. + """ + audio_b64 = base64.b64encode(pcm_bytes).decode("ascii") + await connection.conversation.item.create( + item={ + "type": "message", + "role": "user", + "content": [{"type": "input_audio", "audio": audio_b64}], + } + ) + + async def insert_user_text_async(self, *, connection: Any, text: str) -> None: + """ + Insert a user message containing the given text into the conversation. + + Lets streaming attacks mix text turns into an otherwise audio-driven session. + The caller is responsible for triggering ``response.create`` after insertion. + + Args: + connection: Active Realtime API connection. + text: User-side text content. + """ + await connection.conversation.item.create( + item={ + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}], + } + ) + + async def delete_conversation_item_async(self, *, connection: Any, item_id: str) -> None: + """ + Delete a conversation item by id (e.g. the server's raw user audio item). + + Used during convert-on-commit to remove the raw audio item once the converted + replacement has been inserted. 
Errors are propagated; callers that want best-effort + deletion should wrap with ``contextlib.suppress``. + + Args: + connection: Active Realtime API connection. + item_id: Server-assigned item id to delete. + """ + await connection.conversation.item.delete(item_id=item_id) + + async def swap_user_audio_async( + self, + *, + connection: Any, + committed_event: CommittedEvent, + converted_pcm: bytes, + ) -> None: + """ + Replace the server's just-committed user audio with converted PCM. + + Inserts ``converted_pcm`` as a new user item and best-effort deletes the original + item identified by ``committed_event``. Hides OpenAI's item-id concept from + callers so streaming attacks can stay provider-agnostic. + + Args: + connection: Active Realtime API connection. + committed_event: Payload received in the on-committed callback. + converted_pcm: PCM16 mono @ 24 kHz audio to insert in place of the original. + """ + await self.insert_user_audio_async(connection=connection, pcm_bytes=converted_pcm) + try: + await self.delete_conversation_item_async(connection=connection, item_id=committed_event.item_id) + except Exception as e: + logger.warning(f"conversation.item.delete failed for {committed_event.item_id}: {e}") + + async def subscribe_events_async( + self, + *, + connection: Any, + conversation_id: str, + on_user_audio_committed: (Callable[[CommittedEvent], Coroutine[Any, Any, None]] | None) = None, + ) -> RealtimeEventDispatcher: + """ + Start consuming events from the connection and route them via the OpenAI dispatcher. + + Also registers per-conversation streaming state so requests routed through + ``send_prompt_async`` for ``conversation_id`` take the streaming swap-and-respond + path inside ``_send_prompt_to_target_async`` instead of the atomic send_audio / + send_text path. 
+ + The returned dispatcher exposes ``stop()`` to tear down the background task and + drain in-flight callback tasks, and a ``failure`` property that callers can poll + between operations to detect a dead dispatch loop (e.g. websocket closed). + + Args: + connection: Active Realtime API connection from ``self.connect_async()``. + conversation_id: Conversation id for the realtime session. Used as the key + under which streaming state is registered. + on_user_audio_committed: Async callback fired when server VAD finalizes + a user audio buffer. Called as a background task. + + Returns: + The started dispatcher. Pass it to ``request_response_async`` for turn + futures, poll ``failure`` for dispatch-loop errors, and call ``stop()`` + (or ``cleanup_conversation``) to tear it down. + """ + dispatcher = _OpenAIRealtimeDispatcher( + connection=connection, + on_user_audio_committed=on_user_audio_committed, + ) + self._streaming_state[conversation_id] = _StreamingConversationState(dispatcher=dispatcher) + # Register the connection under the same key so cleanup_conversation / + # cleanup_target can find and close it without callers reaching into + # private state. + self._existing_conversation[conversation_id] = connection + await dispatcher.start() + return dispatcher + + async def request_response_async( + self, + *, + connection: Any, + dispatcher: RealtimeEventDispatcher, + ) -> asyncio.Future[RealtimeTargetResult]: + """ + Trigger ``response.create`` and return a future that resolves when the turn ends. + + Constructs a fresh ``RealtimeTurnState``, binds it to the dispatcher as the + active turn, then sends ``response.create``. The dispatcher resolves the + returned future via ``response.done`` (with ``interrupted=False``) or via + the barge-in cancel path (with ``interrupted=True``). + + Args: + connection: Active Realtime API connection. + dispatcher: Subscription handle previously returned by + ``subscribe_events_async``. Must not have another turn pending. 
+ + Returns: + Future resolved with the assembled ``RealtimeTargetResult`` when this + turn ends (normally or via barge-in). + + Raises: + RuntimeError: If another turn is already pending on the dispatcher. + """ + state = RealtimeTurnState(completion=asyncio.get_running_loop().create_future()) + dispatcher.register_turn(state) + await connection.response.create() + return state.completion + + async def send_streaming_session_config_async( + self, *, connection: Any, conversation: list[Message] | None = None + ) -> None: + """ + Configure the realtime session for streaming use: server VAD with manual response creation. + + Emits the same session config as the atomic path except ``turn_detection.create_response`` + is forced to False so the streaming attack can swap the raw user audio item for converted + audio before triggering ``response.create``. + + Args: + connection: Active Realtime API connection. + conversation: Optional conversation history; if its first message is a system + message, its text becomes the session's instructions. Defaults to None, + in which case the default system prompt is used. + + Raises: + ValueError: If the target was constructed without server VAD. + """ + if self._server_vad is None: + raise ValueError( + "send_streaming_session_config_async requires server VAD; " + "construct RealtimeTarget(server_vad=True) or pass a ServerVadConfig." 
+ ) + system_prompt = self._get_system_prompt_from_conversation(conversation=conversation or []) + config = self._set_system_prompt_and_config_vars(system_prompt=system_prompt) + turn_detection = config.get("audio", {}).get("input", {}).get("turn_detection") + if turn_detection is not None: + turn_detection["create_response"] = False + await connection.session.update(session=config) + + async def _stream_pcm_async( + self, + *, + connection: Any, + pcm_bytes: bytes, + commit: bool, + chunk_ms: int = 100, + sample_rate: int = 24000, + ) -> None: + """ + Stream raw PCM16 audio to the Realtime API as ``input_audio_buffer.append`` chunks. + + Operates on raw PCM bytes (not WAV) so this helper can back both the + WAV-file path and future per-frame streaming consumers (e.g. browser audio + forwarded by a GUI backend). Caller decides whether to manually commit; + server VAD commits automatically when enabled. + + Args: + connection: Active Realtime API connection from ``self.connect_async()``. + pcm_bytes (bytes): Raw PCM16 mono audio. Empty buffers are accepted + and result in zero appends. + commit (bool): When True, sends ``input_audio_buffer.commit`` after the + final chunk. Pass False when server VAD is committing automatically. + chunk_ms (int): Milliseconds of audio per chunk. Defaults to 100. + sample_rate (int): PCM sample rate in Hz. Defaults to 24000. + """ + bytes_per_sample = 2 # PCM16 + chunk_size = (chunk_ms * sample_rate * bytes_per_sample) // 1000 + + for offset in range(0, len(pcm_bytes), chunk_size): + chunk = pcm_bytes[offset : offset + chunk_size] + audio_b64 = base64.b64encode(chunk).decode("ascii") + await connection.input_audio_buffer.append(audio=audio_b64) + + if commit: + await connection.input_audio_buffer.commit() + async def receive_events(self, conversation_id: str) -> RealtimeTargetResult: """ Continuously receive events from the OpenAI Realtime API connection. 
@@ -806,3 +1168,134 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> This implementation exists to satisfy the abstract base class requirement. """ raise NotImplementedError("RealtimeTarget uses receive_events for message construction") + + +class _OpenAIRealtimeDispatcher(RealtimeEventDispatcher): + """ + Concrete ``RealtimeEventDispatcher`` for the OpenAI Realtime API. + + Routes OpenAI server events into the active ``RealtimeTurnState`` and issues + ``conversation.item.truncate`` when interrupted (the server cancels the response itself). + """ + + async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None: + """Route an OpenAI Realtime event to the active turn or to an input-side callback.""" + event_type = getattr(event, "type", "") + + # Capture audio_start_ms from speech_started for the next committed event. + # The server reports it reliably here but omits it from the commit event itself. + # Do not return — the downstream state-aware branch still needs to fire the + # barge-in cancel when speech starts mid-response. + if event_type == "input_audio_buffer.speech_started": + speech_start = getattr(event, "audio_start_ms", None) + if speech_start is not None: + self._pending_speech_start_ms = speech_start + + # Input-side events fire callbacks regardless of whether a turn is registered. + if event_type == "input_audio_buffer.committed": + item_id = getattr(event, "item_id", None) + if item_id is None: + return + audio_start_ms = getattr(event, "audio_start_ms", None) + if audio_start_ms is None: + audio_start_ms = self._pending_speech_start_ms + self._pending_speech_start_ms = None + self._fire_committed_callback( + CommittedEvent( + item_id=item_id, + audio_start_ms=audio_start_ms, + ) + ) + # Committed events carry no per-turn bookkeeping, so stop here. + return + + # Remaining events are output-side and mutate per-turn state; drop if no turn. 
+ if state is None or state.completion.done(): + return + + if event_type == "response.created": + state.is_responding = True + response = getattr(event, "response", None) + if response is not None: + state.last_response_id = getattr(response, "id", None) + return + + if event_type in ("response.output_item.added", "response.output_item.created"): + item = getattr(event, "item", None) + if item is not None: + state.current_item_id = getattr(item, "id", None) + return + + if event_type in ("response.audio.delta", "response.output_audio.delta"): + delta = getattr(event, "delta", "") + if delta: + state.delivered_audio.extend(base64.b64decode(delta)) + return + + if event_type in ("response.audio_transcript.delta", "response.output_audio_transcript.delta"): + delta = getattr(event, "delta", "") + if delta: + state.delivered_transcripts.append(delta) + return + + if event_type == "response.done": + response = getattr(event, "response", None) + done_response_id = getattr(response, "id", None) if response is not None else None + if state.last_response_id is not None and done_response_id != state.last_response_id: + # Stale event from a cancelled response; drop without resolving. 
+ return + state.is_responding = False + state.completion.set_result( + RealtimeTargetResult( + audio_bytes=bytes(state.delivered_audio), + transcripts=list(state.delivered_transcripts), + ) + ) + return + + if event_type == "input_audio_buffer.speech_started" and state.is_responding: + await self._cancel(state=state) + state.is_responding = False + state.completion.set_result( + RealtimeTargetResult( + audio_bytes=bytes(state.delivered_audio), + transcripts=list(state.delivered_transcripts), + interrupted=True, + ) + ) + return + + if event_type == "error": + error = getattr(event, "error", None) + message = getattr(error, "message", "unknown") if error is not None else "unknown" + state.completion.set_exception(RuntimeError(f"Realtime API error: {message}")) + return + + async def _cancel(self, *, state: RealtimeTurnState) -> None: + """ + Truncate the in-flight response's conversation item to what was actually delivered. + + The server auto-cancels the response when it detects new speech, so we only need to + trim the conversation history to match the audio we received. + + Marks ``state.interrupted = True`` even when the truncate call fails. + Does not resolve ``state.completion``; the caller (``_route_event``) does that. + + Args: + state (RealtimeTurnState): The turn whose response should be cancelled. + """ + if state.current_item_id is not None: + # PCM16 @ 24 kHz: 48 bytes per millisecond. 
+ audio_end_ms = len(state.delivered_audio) // 48 + try: + await self._connection.conversation.item.truncate( + item_id=state.current_item_id, + content_index=0, + audio_end_ms=audio_end_ms, + ) + except Exception as e: + logger.warning( + f"conversation.item.truncate failed for item {state.current_item_id} " + f"(audio_end_ms={audio_end_ms}): {e}" + ) + state.interrupted = True diff --git a/tests/unit/executor/attack/streaming/__init__.py b/tests/unit/executor/attack/streaming/__init__.py new file mode 100644 index 0000000000..9a0454564d --- /dev/null +++ b/tests/unit/executor/attack/streaming/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/executor/attack/streaming/test_barge_in.py b/tests/unit/executor/attack/streaming/test_barge_in.py new file mode 100644 index 0000000000..825ed952c9 --- /dev/null +++ b/tests/unit/executor/attack/streaming/test_barge_in.py @@ -0,0 +1,717 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +"""Unit tests for ``BargeInAttack`` and supporting helpers.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from pyrit.executor.attack import BargeInAttack, BargeInAttackContext +from pyrit.executor.attack.core import AttackConverterConfig, AttackParameters +from pyrit.executor.attack.streaming.barge_in import _trim_snapshot_to_speech +from pyrit.models import AttackOutcome, Message, MessagePiece +from pyrit.prompt_normalizer import PromptConverterConfiguration +from pyrit.prompt_target import RealtimeTarget +from pyrit.prompt_target.common.realtime_audio import REALTIME_COMMITTED_ITEM_ID_KEY, CommittedEvent + +if TYPE_CHECKING: + from collections.abc import AsyncIterator + +_CLEAN_ENV = {"OPENAI_REALTIME_UNDERLYING_MODEL": ""} + + +@pytest.fixture +@patch.dict("os.environ", _CLEAN_ENV) +def vad_target(sqlite_instance): + return RealtimeTarget(api_key="test_key", endpoint="wss://test_url", model_name="test", server_vad=True) + + +async def _aiter(chunks: list[bytes]) -> AsyncIterator[bytes]: + for c in chunks: + yield c + + +def _attack_context(*, audio_chunks: AsyncIterator[bytes], objective: str = "obj") -> BargeInAttackContext[Any]: + return BargeInAttackContext( + params=AttackParameters(objective=objective), + audio_chunks=audio_chunks, + ) + + +def _mock_connection() -> AsyncMock: + connection = AsyncMock() + connection.input_audio_buffer.append = AsyncMock() + connection.conversation.item.create = AsyncMock() + connection.conversation.item.delete = AsyncMock() + connection.response.create = AsyncMock() + connection.session.update = AsyncMock() + connection.close = AsyncMock() + return connection + + +# ---- Construction validation ----------------------------------------------------------------- + + +@patch.dict("os.environ", _CLEAN_ENV) +def test_constructor_rejects_target_without_streaming_capability(sqlite_instance): + """A 
target whose capabilities lack STREAMING_BARGE_IN must be rejected at construction.""" + from pyrit.prompt_target import OpenAIChatTarget + + no_streaming = OpenAIChatTarget(api_key="k", endpoint="https://x", model_name="m") + with pytest.raises(Exception, match="streaming_barge_in"): + BargeInAttack(objective_target=no_streaming) + + +def test_constructor_succeeds_with_vad_target(vad_target): + """A RealtimeTarget declares STREAMING_BARGE_IN — construction succeeds.""" + attack = BargeInAttack(objective_target=vad_target) + assert attack.get_objective_target() is vad_target + + +def test_constructor_succeeds_even_without_server_vad_enabled(sqlite_instance): + """Capability check passes; server VAD is a runtime config concern surfaced when used.""" + with patch.dict("os.environ", _CLEAN_ENV): + no_vad = RealtimeTarget(api_key="k", endpoint="wss://test_url", model_name="test") + # Construction succeeds — capability is about the target type, not server_vad config. + attack = BargeInAttack(objective_target=no_vad) + assert attack.get_objective_target() is no_vad + + +def test_constructor_default_max_post_stream_wait_seconds(vad_target): + """When not passed, max_post_stream_wait_seconds takes the class default.""" + attack = BargeInAttack(objective_target=vad_target) + assert attack._max_post_stream_wait_seconds == BargeInAttack.DEFAULT_MAX_POST_STREAM_WAIT_SECONDS + + +def test_constructor_accepts_custom_max_post_stream_wait_seconds(vad_target): + """max_post_stream_wait_seconds is configurable per-instance.""" + attack = BargeInAttack(objective_target=vad_target, max_post_stream_wait_seconds=120.0) + assert attack._max_post_stream_wait_seconds == 120.0 + + +def test_realtime_target_satisfies_streaming_barge_in_protocol(vad_target): + """RealtimeTarget structurally implements StreamingBargeInTarget so the cast is safe.""" + from pyrit.prompt_target.common.realtime_audio import StreamingBargeInTarget + + assert isinstance(vad_target, StreamingBargeInTarget) + + +# 
---- Context validation ---------------------------------------------------------------------- + + +async def test_validate_context_requires_objective(vad_target): + attack = BargeInAttack(objective_target=vad_target) + ctx = BargeInAttackContext( + params=AttackParameters(objective=""), + audio_chunks=_aiter([b"\x00" * 96]), + ) + with pytest.raises(ValueError, match="objective"): + attack._validate_context(context=ctx) + + +async def test_validate_context_requires_audio_chunks(vad_target): + attack = BargeInAttack(objective_target=vad_target) + ctx = BargeInAttackContext( + params=AttackParameters(objective="o"), + audio_chunks=None, + ) + with pytest.raises(ValueError, match="audio_chunks"): + attack._validate_context(context=ctx) + + +# ---- _setup_async + prepended_conversation persistence --------------------------------------- + + +async def test_setup_async_persists_prepended_conversation_to_memory(vad_target): + """Prepended_conversation messages must be written to memory on setup like other attacks do.""" + attack = BargeInAttack(objective_target=vad_target) + sys_msg = Message( + message_pieces=[ + MessagePiece( + role="system", + original_value="You are a strict assistant.", + original_value_data_type="text", + converted_value="You are a strict assistant.", + converted_value_data_type="text", + conversation_id="ignored-by-setup", + ) + ] + ) + user_msg = Message( + message_pieces=[ + MessagePiece( + role="user", + original_value="prior user turn", + original_value_data_type="text", + converted_value="prior user turn", + converted_value_data_type="text", + conversation_id="ignored-by-setup", + ) + ] + ) + assistant_msg = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value="prior assistant turn", + original_value_data_type="text", + converted_value="prior assistant turn", + converted_value_data_type="text", + conversation_id="ignored-by-setup", + ) + ] + ) + + ctx = BargeInAttackContext( + params=AttackParameters( + 
objective="o", + prepended_conversation=[sys_msg, user_msg, assistant_msg], + ), + audio_chunks=_aiter([b"\x00" * 96]), + ) + + add_calls: list[Any] = [] + with patch.object(attack._conversation_manager._memory, "add_message_to_memory") as mock_add: + mock_add.side_effect = lambda **kw: add_calls.append(kw["request"]) + await attack._setup_async(context=ctx) + + # All three prepended messages should have been written to memory under the + # attack's conversation_id; assistant role becomes simulated_assistant on storage. + assert len(add_calls) == 3 + storage_roles = [m.message_pieces[0].get_role_for_storage() for m in add_calls] + assert storage_roles == ["system", "user", "simulated_assistant"] + # All three messages share the context's conversation_id post-setup. + for m in add_calls: + assert m.message_pieces[0].conversation_id == ctx.conversation_id + + +async def test_setup_async_no_op_when_prepended_conversation_empty(vad_target): + """Empty prepended_conversation: no memory writes, no crash.""" + attack = BargeInAttack(objective_target=vad_target) + ctx = BargeInAttackContext( + params=AttackParameters(objective="o"), # no prepended_conversation + audio_chunks=_aiter([b"\x00" * 96]), + ) + + add_calls: list[Any] = [] + with patch.object(attack._conversation_manager._memory, "add_message_to_memory") as mock_add: + mock_add.side_effect = lambda **kw: add_calls.append(kw["request"]) + await attack._setup_async(context=ctx) + + assert add_calls == [] + + +# ---- Streaming loop end-to-end --------------------------------------------------------------- + + +def _setup_streaming_target(vad_target, *, future_response: Message | None = None) -> AsyncMock: + """ + Mock the streaming-mode surface on ``vad_target`` and return the connection mock. 
+ + Stubs ``connect_async``, ``send_streaming_session_config_async``, ``push_audio_chunk_async``, + ``save_audio``, and ``cleanup_conversation`` (each test stubs ``subscribe_events_async`` itself) so a callback can + be invoked mid-stream without exercising the real target machinery. + """ + connection = _mock_connection() + vad_target.connect_async = AsyncMock(return_value=connection) + vad_target.send_streaming_session_config_async = AsyncMock() + vad_target.push_audio_chunk_async = AsyncMock() + vad_target.save_audio = AsyncMock(return_value="/tmp/snapshot.wav") + vad_target.cleanup_conversation = AsyncMock() + return connection + + +def _capture_committed_callback(vad_target, captured: dict[str, Any]) -> None: + """Wire ``subscribe_events_async`` to capture the registered ``on_user_audio_committed``.""" + + async def fake_subscribe(*, connection, conversation_id, on_user_audio_committed): + captured["on_committed"] = on_user_audio_committed + return AsyncMock() + + vad_target.subscribe_events_async = AsyncMock(side_effect=fake_subscribe) + + +def _stub_send_prompt(attack: BargeInAttack, return_value: Message | None = None) -> AsyncMock: + """Replace the attack's prompt_normalizer.send_prompt_async with an AsyncMock and return it.""" + if return_value is None: + return_value = Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value="ok", + original_value_data_type="text", + converted_value="ok", + converted_value_data_type="text", + conversation_id="any", + ) + ] + ) + send_mock = AsyncMock(return_value=return_value) + attack._prompt_normalizer.send_prompt_async = send_mock + return send_mock + + +async def test_perform_async_streams_chunks_and_tears_down(vad_target): + """Happy path: connect, send config, subscribe, push chunks, then cleanup_conversation — no commits.""" + attack = BargeInAttack(objective_target=vad_target) + connection = _setup_streaming_target(vad_target) + dispatcher = AsyncMock() + vad_target.subscribe_events_async = 
AsyncMock(return_value=dispatcher) + + chunks = [b"\x11" * 480, b"\x22" * 480, b"\x33" * 240] + ctx = _attack_context(audio_chunks=_aiter(chunks)) + + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + result = await attack._perform_async(context=ctx) + + vad_target.connect_async.assert_awaited_once_with(conversation_id=ctx.conversation_id) + vad_target.send_streaming_session_config_async.assert_awaited_once() + vad_target.subscribe_events_async.assert_awaited_once() + assert vad_target.push_audio_chunk_async.await_count == len(chunks) + pushed = [call.kwargs["pcm_bytes"] for call in vad_target.push_audio_chunk_async.await_args_list] + assert pushed == chunks + vad_target.cleanup_conversation.assert_awaited_once_with(ctx.conversation_id) + assert result.executed_turns == 0 + assert result.outcome == AttackOutcome.UNDETERMINED + + +async def test_perform_async_calls_send_prompt_async_on_commit(vad_target): + """A commit must invoke prompt_normalizer.send_prompt_async with an audio_path Message.""" + bump = MagicMock() + bump.get_identifier = MagicMock(return_value=MagicMock()) + converter_config = AttackConverterConfig( + request_converters=PromptConverterConfiguration.from_converters(converters=[bump]), + ) + attack = BargeInAttack(objective_target=vad_target, attack_converter_config=converter_config) + send_mock = _stub_send_prompt(attack) + _setup_streaming_target(vad_target) + captured: dict[str, Any] = {} + _capture_committed_callback(vad_target, captured) + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield b"\x05" * 480 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="item_42"))) + + ctx = _attack_context(audio_chunks=chunks_then_commit()) + + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + result = await attack._perform_async(context=ctx) + + send_mock.assert_awaited_once() + kwargs = send_mock.call_args.kwargs + sent_message = kwargs["message"] + assert 
sent_message.message_pieces[0].converted_value_data_type == "audio_path" + assert sent_message.message_pieces[0].conversation_id == ctx.conversation_id + assert sent_message.message_pieces[0].prompt_metadata[REALTIME_COMMITTED_ITEM_ID_KEY] == "item_42" + assert kwargs["target"] is vad_target + assert kwargs["request_converter_configurations"] == attack._request_converters + assert kwargs["conversation_id"] == ctx.conversation_id + assert result.executed_turns == 1 + + +async def test_perform_async_message_carries_snapshot_audio_path(vad_target): + """The audio_path on the user piece must point at the persisted snapshot WAV.""" + attack = BargeInAttack(objective_target=vad_target) + send_mock = _stub_send_prompt(attack) + connection = _setup_streaming_target(vad_target) + vad_target.save_audio = AsyncMock(return_value="/tmp/persisted_snapshot.wav") + captured: dict[str, Any] = {} + _capture_committed_callback(vad_target, captured) + + raw_chunk = b"\x07" * 96 + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield raw_chunk + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="i"))) + + ctx = _attack_context(audio_chunks=chunks_then_commit()) + + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + await attack._perform_async(context=ctx) + + # save_audio called with the snapshot PCM; the resulting path lands on the message piece. 
+ save_kwargs_or_args = vad_target.save_audio.call_args + saved_pcm = ( + save_kwargs_or_args.args[0] if save_kwargs_or_args.args else save_kwargs_or_args.kwargs.get("audio_bytes") + ) + assert saved_pcm == raw_chunk + piece = send_mock.call_args.kwargs["message"].message_pieces[0] + assert piece.original_value == "/tmp/persisted_snapshot.wav" + assert piece.converted_value == "/tmp/persisted_snapshot.wav" + + +async def test_perform_async_clears_raw_buffer_between_commits(vad_target): + """Each commit gets fresh PCM: the snapshot saved for turn 2 has no carryover from turn 1.""" + attack = BargeInAttack(objective_target=vad_target) + _stub_send_prompt(attack) + _setup_streaming_target(vad_target) + saved_pcm: list[bytes] = [] + + async def fake_save_audio(audio_bytes, **_): + saved_pcm.append(audio_bytes) + return f"/tmp/snap_{len(saved_pcm)}.wav" + + vad_target.save_audio = AsyncMock(side_effect=fake_save_audio) + captured: dict[str, Any] = {} + _capture_committed_callback(vad_target, captured) + + async def chunks_two_commits() -> AsyncIterator[bytes]: + yield b"\x01" * 96 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="i1"))) + yield b"\x02" * 96 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="i2"))) + + ctx = _attack_context(audio_chunks=chunks_two_commits()) + + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + await attack._perform_async(context=ctx) + + assert saved_pcm == [b"\x01" * 96, b"\x02" * 96] + + +async def test_perform_async_tracks_last_response_and_turn_count(vad_target): + """AttackResult.last_response is the last Message from send_prompt_async; count matches commits.""" + attack = BargeInAttack(objective_target=vad_target) + responses_in_order = [ + Message( + message_pieces=[ + MessagePiece( + role="assistant", + original_value=text, + original_value_data_type="text", + converted_value=text, + converted_value_data_type="text", + conversation_id="x", + ) + ] + ) + for text 
in ("first", "second", "final") + ] + send_mock = AsyncMock(side_effect=responses_in_order) + attack._prompt_normalizer.send_prompt_async = send_mock + _setup_streaming_target(vad_target) + captured: dict[str, Any] = {} + _capture_committed_callback(vad_target, captured) + + async def chunks_three_commits() -> AsyncIterator[bytes]: + for i in range(3): + yield bytes([i + 1]) * 96 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id=f"i{i}"))) + + ctx = _attack_context(audio_chunks=chunks_three_commits()) + + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + result = await attack._perform_async(context=ctx) + + assert result.executed_turns == 3 + assert result.last_response is not None + assert result.last_response.converted_value == "final" + + +async def test_perform_async_cleans_up_even_on_exception(vad_target): + """If the chunk loop raises, cleanup_conversation still fires.""" + attack = BargeInAttack(objective_target=vad_target) + _setup_streaming_target(vad_target) + vad_target.push_audio_chunk_async = AsyncMock(side_effect=RuntimeError("push exploded")) + vad_target.subscribe_events_async = AsyncMock(return_value=AsyncMock()) + + ctx = _attack_context(audio_chunks=_aiter([b"\x00" * 96])) + + with pytest.raises(RuntimeError, match="push exploded"): + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + await attack._perform_async(context=ctx) + + vad_target.cleanup_conversation.assert_awaited_once_with(ctx.conversation_id) + + +async def test_perform_async_swallows_callback_exception(vad_target): + """If send_prompt_async raises mid-turn, the session keeps going (no executed turn).""" + attack = BargeInAttack(objective_target=vad_target) + attack._prompt_normalizer.send_prompt_async = AsyncMock(side_effect=RuntimeError("converter blew up")) + _setup_streaming_target(vad_target) + captured: dict[str, Any] = {} + _capture_committed_callback(vad_target, captured) + + async def chunks_then_commit() -> 
AsyncIterator[bytes]: + yield b"\x00" * 96 + await asyncio.create_task(captured["on_committed"](CommittedEvent(item_id="i"))) + + ctx = _attack_context(audio_chunks=chunks_then_commit()) + + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + result = await attack._perform_async(context=ctx) + + # The callback caught the exception; no turn counted as successful. + assert result.executed_turns == 0 + + +# ---- send_streaming_session_config_async (target-side helper added in R4a) ------------------- + + +async def test_send_streaming_session_config_async_emits_create_response_false(vad_target): + """The streaming session config must flip create_response to False on turn_detection.""" + connection = _mock_connection() + await vad_target.send_streaming_session_config_async(connection=connection) + connection.session.update.assert_awaited_once() + config = connection.session.update.call_args.kwargs["session"] + assert config["audio"]["input"]["turn_detection"]["create_response"] is False + + +@patch.dict("os.environ", _CLEAN_ENV) +async def test_send_streaming_session_config_async_requires_server_vad(sqlite_instance): + """Without server VAD, sending streaming session config must raise.""" + no_vad = RealtimeTarget(api_key="k", endpoint="wss://test_url", model_name="test") + connection = _mock_connection() + with pytest.raises(ValueError, match="server VAD"): + await no_vad.send_streaming_session_config_async(connection=connection) + + +async def test_send_streaming_session_config_async_uses_system_message_from_conversation(vad_target): + """If the prepended conversation begins with a system message, it becomes session instructions.""" + connection = _mock_connection() + system_msg = Message( + message_pieces=[ + MessagePiece( + role="system", + original_value="You are a strict assistant.", + original_value_data_type="text", + converted_value="You are a strict assistant.", + converted_value_data_type="text", + conversation_id="x", + ) + ] + ) + await 
vad_target.send_streaming_session_config_async(connection=connection, conversation=[system_msg]) + config = connection.session.update.call_args.kwargs["session"] + assert config["instructions"] == "You are a strict assistant." + + +# ---- _trim_snapshot_to_speech (pre-speech silence trim) ------------------------------------- + + +def test_trim_drops_leading_silence_using_audio_start_ms(): + """When audio_start_ms is set, everything before (audio_start_ms - prefix_padding_ms) is trimmed.""" + # 24 kHz mono PCM16 → 48 bytes per ms. 1000 ms of silence + 100 ms of "speech". + silence = b"\x00" * (1000 * 48) + speech = b"\x11" * (100 * 48) + buffer = silence + speech + + trimmed = _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=24000, + audio_start_ms=1000, # speech starts at 1000 ms + prefix_padding_ms=200, # keep 200 ms before speech + ) + + # Expect: dropped 800 ms (1000 - 200) of silence; kept 200 ms silence + 100 ms speech. + assert len(trimmed) == (200 + 100) * 48 + assert trimmed[-len(speech) :] == speech + + +def test_trim_passes_through_when_audio_start_ms_missing(): + """If the server didn't report audio_start_ms, no trim happens.""" + buffer = b"\xff" * 480 + assert ( + _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=24000, + audio_start_ms=None, + prefix_padding_ms=300, + ) + is buffer + ) + + +def test_trim_passes_through_when_audio_start_ms_zero(): + """audio_start_ms == 0 means speech started immediately; no trim.""" + buffer = b"\xff" * 480 + assert ( + _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=24000, + audio_start_ms=0, + prefix_padding_ms=300, + ) + is buffer + ) + + +def test_trim_raises_on_negative_audio_start_ms(): + """A negative audio_start_ms is a server contract violation, not 'unknown'.""" + buffer = b"\xff" * 480 + with pytest.raises(ValueError, match="audio_start_ms must be >= 0"): + _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=24000, + audio_start_ms=-100, + 
prefix_padding_ms=300, + ) + + +@pytest.mark.parametrize( + "sample_rate_hz, sample_width_bytes, channels", + [(0, 2, 1), (24000, 0, 1), (24000, 2, 0), (-100, 2, 1), (24000, -1, 1)], +) +def test_trim_raises_on_nonpositive_format_args(sample_rate_hz, sample_width_bytes, channels): + """Non-positive sample rate, width, or channel count signals a misconfiguration; raise.""" + buffer = b"\xff" * 480 + with pytest.raises(ValueError, match="must all be positive"): + _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=sample_rate_hz, + audio_start_ms=100, + prefix_padding_ms=0, + sample_width_bytes=sample_width_bytes, + channels=channels, + ) + + +def test_trim_clamps_when_audio_start_ms_less_than_prefix_padding(): + """audio_start_ms - prefix_padding_ms shouldn't go negative.""" + buffer = b"\xab" * (500 * 48) + trimmed = _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=24000, + audio_start_ms=100, + prefix_padding_ms=300, + ) + # max(0, 100 - 300) = 0 → no bytes dropped. + assert trimmed == buffer + + +def test_trim_aligns_to_sample_boundary(): + """Trim must land on a sample-frame boundary (2 bytes for PCM16 mono) so playback isn't garbled.""" + # Sample rate 8000 Hz → 16 bytes/ms; audio_start_ms=3, prefix=0 → start_byte=48 (aligned). + buffer = bytes(range(256)) * 4 # arbitrary bytes + trimmed = _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=8000, + audio_start_ms=3, + prefix_padding_ms=0, + sample_width_bytes=2, + channels=1, + ) + # 48 bytes is already a frame boundary (48 % 2 == 0). + assert len(trimmed) == len(buffer) - 48 + # Sanity: the trim point is sample-aligned. 
+ assert (len(buffer) - len(trimmed)) % 2 == 0 + + +def test_trim_passes_through_when_computed_start_exceeds_buffer(): + """Safety: if audio_start_ms points past the buffer, return the buffer unchanged.""" + buffer = b"\x00" * 480 # 10 ms at 24 kHz + trimmed = _trim_snapshot_to_speech( + raw_buffer=buffer, + sample_rate_hz=24000, + audio_start_ms=10_000, + prefix_padding_ms=0, + ) + assert trimmed is buffer + + +async def test_perform_async_trims_first_turn_using_audio_start_ms(vad_target): + """Turn 1: buffer_start_session_ms=0, so audio_start_ms is already buffer-relative.""" + from pyrit.prompt_target.common.realtime_audio import ServerVadConfig + + # Pin prefix_padding_ms to a known value so the expected byte count is unambiguous. + vad_target._server_vad = ServerVadConfig(prefix_padding_ms=300, silence_duration_ms=500) + + attack = BargeInAttack(objective_target=vad_target) + send_mock = _stub_send_prompt(attack) + _setup_streaming_target(vad_target) + saved_pcm: list[bytes] = [] + + async def fake_save_audio(audio_bytes, **_): + saved_pcm.append(audio_bytes) + return "/tmp/snap.wav" + + vad_target.save_audio = AsyncMock(side_effect=fake_save_audio) + captured: dict[str, Any] = {} + _capture_committed_callback(vad_target, captured) + + # 1000 ms of leading silence + 100 ms speech-like payload at 24 kHz mono PCM16 → 48 bytes/ms. + silence = b"\x00" * (1000 * 48) + speech = b"\x11" * (100 * 48) + + async def chunks_then_commit() -> AsyncIterator[bytes]: + yield silence + speech + # Server says speech started at 1000 ms (session-relative); with prefix_padding_ms=300, drop 700 ms. 
+ await asyncio.create_task( + captured["on_committed"](CommittedEvent(item_id="i", audio_start_ms=1000)), + ) + + ctx = _attack_context(audio_chunks=chunks_then_commit()) + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + await attack._perform_async(context=ctx) + + # Expect save_audio to receive the trimmed snapshot: + # max(0, 1000 - 300) = 700 ms dropped; remaining = 300 ms silence + 100 ms speech = 400 ms. + assert len(saved_pcm) == 1 + assert len(saved_pcm[0]) == 400 * 48 + assert saved_pcm[0].endswith(speech) + send_mock.assert_awaited_once() + + +async def test_perform_async_trims_second_turn_with_session_relative_offset(vad_target): + """Turn 2: audio_start_ms is session-relative; the attack converts it to buffer-relative. + + Without the conversion, a session-relative audio_start_ms larger than the local buffer + would skip the trim (passthrough on out-of-range), letting silence reach the model. + """ + from pyrit.prompt_target.common.realtime_audio import ServerVadConfig + + vad_target._server_vad = ServerVadConfig(prefix_padding_ms=300, silence_duration_ms=500) + + attack = BargeInAttack(objective_target=vad_target) + _stub_send_prompt(attack) + _setup_streaming_target(vad_target) + saved_pcm: list[bytes] = [] + + async def fake_save_audio(audio_bytes, **_): + saved_pcm.append(audio_bytes) + return "/tmp/snap.wav" + + vad_target.save_audio = AsyncMock(side_effect=fake_save_audio) + captured: dict[str, Any] = {} + _capture_committed_callback(vad_target, captured) + + silence_500 = b"\x00" * (500 * 48) # 500 ms silence + speech_short = b"\x11" * (100 * 48) # 100 ms speech-like + silence_2000 = b"\x00" * (2000 * 48) # 2000 ms silence (between turns) + speech_long = b"\x22" * (300 * 48) # 300 ms speech-like (turn 2) + + async def two_turns() -> AsyncIterator[bytes]: + # Turn 1: 500 ms silence + 100 ms speech; total local buffer = 600 ms. 
+ yield silence_500 + speech_short + # Server VAD fires commit at session_ms ≈ 600 with audio_start_ms = 500 (session-relative). + await asyncio.create_task( + captured["on_committed"](CommittedEvent(item_id="i1", audio_start_ms=500)), + ) + # Turn 2: 2000 ms silence (since turn 1's commit) + 300 ms speech. + # session_ms_at_speech_start ≈ 600 + 2000 = 2600. + yield silence_2000 + speech_long + await asyncio.create_task( + captured["on_committed"](CommittedEvent(item_id="i2", audio_start_ms=2600)), + ) + + ctx = _attack_context(audio_chunks=two_turns()) + with patch.object(attack, "_max_post_stream_wait_seconds", 0): + await attack._perform_async(context=ctx) + + assert len(saved_pcm) == 2 + + # Turn 1: buffer_relative_start = 500 - 0 = 500; trim = max(0, 500 - 300) = 200 ms; + # remaining = 300 ms pre-speech-padding + 100 ms speech = 400 ms. + assert len(saved_pcm[0]) == 400 * 48 + assert saved_pcm[0].endswith(speech_short) + + # Turn 2: buffer_start_session_ms advanced by 600 ms (turn 1's full buffer duration). + # buffer_relative_start = 2600 - 600 = 2000; trim = max(0, 2000 - 300) = 1700 ms; + # remaining = 300 ms pre-speech-padding + 300 ms speech = 600 ms. + assert len(saved_pcm[1]) == 600 * 48 + assert saved_pcm[1].endswith(speech_long) diff --git a/tests/unit/prompt_target/target/test_realtime_audio.py b/tests/unit/prompt_target/target/test_realtime_audio.py new file mode 100644 index 0000000000..005814a5e4 --- /dev/null +++ b/tests/unit/prompt_target/target/test_realtime_audio.py @@ -0,0 +1,234 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from pyrit.prompt_target.common.realtime_audio import ( + CommittedEvent, + RealtimeEventDispatcher, + RealtimeTargetResult, + RealtimeTurnState, +) + + +async def test_realtime_turn_state_defaults(): + """Newly constructed turn state must be empty: no audio, no transcripts, not responding, not interrupted.""" + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + assert state.is_responding is False + assert state.interrupted is False + assert bytes(state.delivered_audio) == b"" + assert state.delivered_transcripts == [] + assert state.current_item_id is None + assert state.last_response_id is None + + +def test_realtime_target_result_interrupted_defaults_false(): + """RealtimeTargetResult must default interrupted=False so atomic callers see no change.""" + result = RealtimeTargetResult() + assert result.interrupted is False + assert result.audio_bytes == b"" + assert result.transcripts == [] + + +def test_realtime_target_result_carries_interrupted_when_set(): + """The interrupted flag round-trips through construction.""" + result = RealtimeTargetResult(audio_bytes=b"partial", transcripts=["hi"], interrupted=True) + assert result.interrupted is True + + +class _RecordingDispatcher(RealtimeEventDispatcher): + """Minimal concrete dispatcher for testing the generic base class behavior.""" + + def __init__(self, *, connection: Any) -> None: + super().__init__(connection=connection) + self.routed_events: list[Any] = [] + self.cancel_calls: int = 0 + + async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None: + self.routed_events.append(event) + # End the turn on a sentinel event so tests can drain the loop. 
+ if state is not None and getattr(event, "_finish", False): + state.completion.set_result(RealtimeTargetResult()) + + async def _cancel(self, *, state: RealtimeTurnState) -> None: + self.cancel_calls += 1 + state.interrupted = True + + +class _ScriptedConnection: + """Async-iterable connection that yields a fixed event list once registered.""" + + def __init__(self, events: list[Any]) -> None: + self._events = events + + async def __aiter__(self): + for event in self._events: + yield event + + +def _sentinel_event(*, finish: bool = False) -> AsyncMock: + event = AsyncMock() + event._finish = finish + return event + + +async def test_dispatcher_start_is_idempotent(): + """Calling start twice must not spawn two tasks.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + await dispatcher.start() + first_task = dispatcher._task + await dispatcher.start() + assert dispatcher._task is first_task + await dispatcher.stop() + + +async def test_dispatcher_stop_releases_task(): + """stop must cancel the task and clear the reference.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + await dispatcher.start() + await dispatcher.stop() + assert dispatcher._task is None + + +async def test_dispatcher_register_turn_rejects_concurrent_active_turn(): + """Registering a turn while another is active and unresolved must raise.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + first = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + second = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + dispatcher.register_turn(first) + with pytest.raises(RuntimeError, match="already active"): + dispatcher.register_turn(second) + + +async def test_dispatcher_register_turn_allows_replacement_after_completion(): + """Once the active turn's future is done, register_turn may bind a new turn.""" + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([])) + first = 
RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + first.completion.set_result(RealtimeTargetResult()) + second = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + dispatcher.register_turn(first) + dispatcher.register_turn(second) + assert dispatcher._current_turn is second + + +async def test_dispatcher_loop_routes_events_to_active_turn(): + """The dispatch loop must forward events from the connection to _route_event.""" + finish = _sentinel_event(finish=True) + other = _sentinel_event() + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([other, finish])) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + dispatcher.register_turn(state) + + await dispatcher.start() + await asyncio.wait_for(state.completion, timeout=1.0) + await dispatcher.stop() + + assert dispatcher.routed_events == [other, finish] + + +async def test_dispatcher_loop_routes_events_with_no_turn_as_state_none(): + """When no turn is registered, events still reach _route_event so input callbacks can fire; state is None.""" + finish = _sentinel_event(finish=True) + other = _sentinel_event() + dispatcher = _RecordingDispatcher(connection=_ScriptedConnection([other, finish])) + + # No register_turn called. + await dispatcher.start() + await asyncio.sleep(0.05) + await dispatcher.stop() + + # Both events were routed but no turn was completed (state was None, sentinel branch skipped). 
+ assert dispatcher.routed_events == [other, finish] + + +async def test_dispatcher_loop_sets_exception_on_router_failure(): + """A router exception must propagate to the active turn's completion future.""" + + class _ExplodingDispatcher(_RecordingDispatcher): + async def _route_event(self, *, event: Any, state: RealtimeTurnState | None) -> None: + raise ValueError("router boom") + + event = _sentinel_event() + dispatcher = _ExplodingDispatcher(connection=_ScriptedConnection([event])) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + dispatcher.register_turn(state) + + await dispatcher.start() + with pytest.raises(ValueError, match="router boom"): + await asyncio.wait_for(state.completion, timeout=1.0) + await dispatcher.stop() + + +async def test_dispatcher_fires_committed_callback_as_background_task(): + """The on_user_audio_committed callback must be invoked and awaited via background tasks.""" + + received: list[Any] = [] + blocked = asyncio.Event() + release = asyncio.Event() + + async def slow_callback(event): + received.append(event) + blocked.set() + # Block until the test releases us; this proves the dispatch loop did not wait. + await release.wait() + + class _CallbackDispatcher(RealtimeEventDispatcher): + async def _route_event(self, *, event, state): + # Synthesize a committed callback fire on every event for the test. + self._fire_committed_callback(event) + + async def _cancel(self, *, state): # pragma: no cover - not exercised here + return + + fake_event_1 = MagicMock(spec=CommittedEvent) + fake_event_2 = MagicMock(spec=CommittedEvent) + dispatcher = _CallbackDispatcher( + connection=_ScriptedConnection([fake_event_1, fake_event_2]), + on_user_audio_committed=slow_callback, + ) + + await dispatcher.start() + # Both events should reach the slow callback even though the first is "blocked" awaiting release. 
+ await asyncio.wait_for(blocked.wait(), timeout=1.0) + # Give the loop a tick to process the second event despite the first callback still running. + await asyncio.sleep(0.05) + release.set() + await dispatcher.stop() + + # Both events fired the callback; the loop did not serialize behind the slow first call. + assert len(received) == 2 + + +async def test_dispatcher_records_failure_on_iterator_crash(): + """When the connection iterator raises, the dispatcher's failure property captures the exception.""" + + class _NoopDispatcher(RealtimeEventDispatcher): + async def _route_event(self, *, event, state): # pragma: no cover - never called + return + + async def _cancel(self, *, state): # pragma: no cover + return + + class _ExplodingConnection: + def __aiter__(self): + return self + + async def __anext__(self): + raise RuntimeError("iterator died") + + dispatcher = _NoopDispatcher(connection=_ExplodingConnection()) + await dispatcher.start() + for _ in range(50): + if dispatcher.failure is not None: + break + await asyncio.sleep(0.01) + await dispatcher.stop() + + assert isinstance(dispatcher.failure, RuntimeError) and str(dispatcher.failure) == "iterator died" diff --git a/tests/unit/prompt_target/target/test_realtime_target.py b/tests/unit/prompt_target/target/test_realtime_target.py index d0aa9cc5e2..fa01aa3995 100644 --- a/tests/unit/prompt_target/target/test_realtime_target.py +++ b/tests/unit/prompt_target/target/test_realtime_target.py @@ -1,14 +1,28 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from unittest.mock import AsyncMock, MagicMock, patch +import asyncio +import base64 +import wave +from typing import Any +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest from pyrit.exceptions.exception_classes import ServerErrorException +from pyrit.identifiers import ComponentIdentifier from pyrit.models import Message, MessagePiece -from pyrit.prompt_target import RealtimeTarget -from pyrit.prompt_target.openai.openai_realtime_target import RealtimeTargetResult +from pyrit.prompt_target import RealtimeTarget, ServerVadConfig +from pyrit.prompt_target.common.realtime_audio import ( + REALTIME_COMMITTED_ITEM_ID_KEY, + CommittedEvent, + RealtimeTargetResult, + RealtimeTurnState, +) +from pyrit.prompt_target.openai.openai_realtime_target import ( + _OpenAIRealtimeDispatcher, + _StreamingConversationState, +) # Env vars that may leak from .env files loaded by other tests in parallel workers. _CLEAN_UNDERLYING_MODEL_ENV = { @@ -29,7 +43,7 @@ async def test_connect_success(target): mock_client.realtime.connect.return_value.__aenter__ = AsyncMock(return_value=mock_connection) with patch.object(target, "_get_openai_client", return_value=mock_client): - connection = await target.connect(conversation_id="test_conv") + connection = await target.connect_async(conversation_id="test_conv") assert connection == mock_connection mock_client.realtime.connect.assert_called_once_with(model="test") await target.cleanup_target() @@ -37,7 +51,7 @@ async def test_connect_success(target): async def test_send_prompt_async(target): # Mock the necessary methods - target.connect = AsyncMock(return_value=AsyncMock()) + target.connect_async = AsyncMock(return_value=AsyncMock()) target.send_config = AsyncMock() result = RealtimeTargetResult(audio_bytes=b"file", transcripts=["hello"]) target.send_text_async = AsyncMock(return_value=("output.wav", result)) @@ -70,6 +84,58 @@ async def test_send_prompt_async(target): await target.cleanup_target() +async def 
test_send_prompt_async_propagates_interrupted_to_metadata(target): + """When a turn result carries interrupted=True, both response pieces' metadata must reflect it.""" + target.connect_async = AsyncMock(return_value=AsyncMock()) + target.send_config = AsyncMock() + interrupted_result = RealtimeTargetResult(audio_bytes=b"partial", transcripts=["hi"], interrupted=True) + target.send_text_async = AsyncMock(return_value=("partial.wav", interrupted_result)) + + message_piece = MessagePiece( + original_value="Hello", + original_value_data_type="text", + converted_value="Hello", + converted_value_data_type="text", + role="user", + conversation_id="test_conv", + ) + message = Message(message_pieces=[message_piece]) + + response = await target.send_prompt_async(message=message) + + text_piece, audio_piece = response[0].message_pieces + assert text_piece.prompt_metadata.get("interrupted") is True + assert audio_piece.prompt_metadata.get("interrupted") is True + + await target.cleanup_target() + + +async def test_send_prompt_async_omits_interrupted_metadata_when_not_set(target): + """A non-interrupted result must not write an interrupted key to MessagePiece metadata.""" + target.connect_async = AsyncMock(return_value=AsyncMock()) + target.send_config = AsyncMock() + normal_result = RealtimeTargetResult(audio_bytes=b"full", transcripts=["hi"]) + target.send_text_async = AsyncMock(return_value=("full.wav", normal_result)) + + message_piece = MessagePiece( + original_value="Hello", + original_value_data_type="text", + converted_value="Hello", + converted_value_data_type="text", + role="user", + conversation_id="test_conv", + ) + message = Message(message_pieces=[message_piece]) + + response = await target.send_prompt_async(message=message) + + text_piece, audio_piece = response[0].message_pieces + assert "interrupted" not in text_piece.prompt_metadata + assert "interrupted" not in audio_piece.prompt_metadata + + await target.cleanup_target() + + async def 
test_get_system_prompt_from_conversation_with_system_message(target): """Test that system prompt is extracted from conversation history when present.""" @@ -123,7 +189,7 @@ async def test_get_system_prompt_empty_conversation(target): async def test_multiple_websockets_created_for_multiple_conversations(target): # Mock the necessary methods - target.connect = AsyncMock(return_value=AsyncMock()) + target.connect_async = AsyncMock(return_value=AsyncMock()) target.send_config = AsyncMock() result = RealtimeTargetResult(audio_bytes=b"event1", transcripts=["event2"]) target.send_text_async = AsyncMock(return_value=("output_audio_path", result)) @@ -346,7 +412,7 @@ async def test_multi_turn_reuses_connection(target): This ensures that the server-side conversation context is preserved. """ mock_connection = AsyncMock() - target.connect = AsyncMock(return_value=mock_connection) + target.connect_async = AsyncMock(return_value=mock_connection) target.send_config = AsyncMock() result = RealtimeTargetResult(audio_bytes=b"audio", transcripts=["response"]) target.send_text_async = AsyncMock(return_value=("output.wav", result)) @@ -376,7 +442,7 @@ async def test_multi_turn_reuses_connection(target): await target.send_prompt_async(message=Message(message_pieces=[message_piece_2])) # Connection should only be created once for the conversation - target.connect.assert_called_once_with(conversation_id=conversation_id) + target.connect_async.assert_called_once_with(conversation_id=conversation_id) target.send_config.assert_called_once() # Both turns should use the same connection @@ -430,3 +496,990 @@ async def test_receive_events_skips_stale_response_done(target): # Should have processed through to the real response.done with actual audio assert result.audio_bytes == b"dummyaudio" assert result.transcripts == ["hello"] + + +# --------------------------------------------------------------------------- +# Chunk 1 — ServerVadConfig + session config +# 
--------------------------------------------------------------------------- + + +def test_session_config_omits_turn_detection_when_vad_disabled(target): + """Default construction must not emit a turn_detection block; pins atomic flow.""" + config = target._set_system_prompt_and_config_vars(system_prompt="test prompt") + + assert "turn_detection" not in config["audio"]["input"] + assert config["instructions"] == "test prompt" + + +@patch.dict("os.environ", _CLEAN_UNDERLYING_MODEL_ENV) +def test_session_config_emits_server_vad_block_with_defaults(sqlite_instance): + """server_vad=True must emit defaults.""" + vad_target = RealtimeTarget( + api_key="test_key", + endpoint="wss://test_url", + model_name="test", + server_vad=True, + ) + + config = vad_target._set_system_prompt_and_config_vars(system_prompt="test prompt") + + turn_detection = config["audio"]["input"]["turn_detection"] + assert turn_detection == { + "type": "server_vad", + "threshold": 0.4, + "prefix_padding_ms": 200, + "silence_duration_ms": 1500, + "create_response": True, + "interrupt_response": True, + } + + +@patch.dict("os.environ", _CLEAN_UNDERLYING_MODEL_ENV) +def test_session_config_honors_custom_vad_tuning(sqlite_instance): + """Passing a ServerVadConfig must flow through to the emitted turn_detection block.""" + vad_target = RealtimeTarget( + api_key="test_key", + endpoint="wss://test_url", + model_name="test", + server_vad=ServerVadConfig(threshold=0.7, prefix_padding_ms=350, silence_duration_ms=800), + ) + + turn_detection = vad_target._set_system_prompt_and_config_vars(system_prompt="x")["audio"]["input"][ + "turn_detection" + ] + + assert turn_detection["threshold"] == 0.7 + assert turn_detection["prefix_padding_ms"] == 350 + assert turn_detection["silence_duration_ms"] == 800 + + +@pytest.mark.parametrize( + "kwargs", + [ + {"threshold": -0.1}, + {"threshold": 1.5}, + {"prefix_padding_ms": -1}, + {"silence_duration_ms": -1}, + ], +) +def 
test_server_vad_config_rejects_invalid_values(kwargs): + """ServerVadConfig must reject out-of-range tuning values at construction.""" + with pytest.raises(ValueError): + ServerVadConfig(**kwargs) + + +# --------------------------------------------------------------------------- +# Chunk 2 — _stream_pcm_async helper +# --------------------------------------------------------------------------- + + +def _make_mock_connection(): + """Return an AsyncMock connection with input_audio_buffer wired up.""" + connection = AsyncMock() + connection.input_audio_buffer.append = AsyncMock() + connection.input_audio_buffer.commit = AsyncMock() + return connection + + +async def test_stream_pcm_even_split_no_commit(target): + """A buffer that divides evenly into chunks emits N appends and no commit when commit=False.""" + connection = _make_mock_connection() + # 100ms @ 24kHz @ 2 bytes/sample = 4800 bytes per chunk. 9600 bytes = 2 chunks. + pcm = b"\x00" * 9600 + + await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=False) + + assert connection.input_audio_buffer.append.call_count == 2 + connection.input_audio_buffer.commit.assert_not_called() + + +async def test_stream_pcm_partial_final_chunk(target): + """A buffer not a clean multiple of chunk size sends the final partial chunk as-is.""" + connection = _make_mock_connection() + # 5000 bytes => one full 4800-byte chunk + one 200-byte tail. + pcm = b"\x01" * 5000 + + await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=False) + + assert connection.input_audio_buffer.append.call_count == 2 + # Inspect the second call's chunk size by base64-decoding its audio kwarg. + second_call_audio_b64 = connection.input_audio_buffer.append.call_args_list[1].kwargs["audio"] + assert len(base64.b64decode(second_call_audio_b64)) == 200 + + +async def test_stream_pcm_empty_buffer(target): + """An empty buffer yields zero appends. 
commit=False produces no commit either.""" + connection = _make_mock_connection() + + await target._stream_pcm_async(connection=connection, pcm_bytes=b"", commit=False) + + connection.input_audio_buffer.append.assert_not_called() + connection.input_audio_buffer.commit.assert_not_called() + + +async def test_stream_pcm_commits_when_asked(target): + """commit=True triggers exactly one input_audio_buffer.commit after all appends.""" + connection = _make_mock_connection() + pcm = b"\x02" * 4800 + + await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=True) + + assert connection.input_audio_buffer.append.call_count == 1 + connection.input_audio_buffer.commit.assert_awaited_once_with() + + +async def test_stream_pcm_empty_buffer_still_commits_when_asked(target): + """commit=True with an empty buffer should still fire commit (e.g. to flush an existing buffer).""" + connection = _make_mock_connection() + + await target._stream_pcm_async(connection=connection, pcm_bytes=b"", commit=True) + + connection.input_audio_buffer.append.assert_not_called() + connection.input_audio_buffer.commit.assert_awaited_once_with() + + +async def test_stream_pcm_appends_base64_encoded_chunks(target): + """Each append's audio kwarg must be the base64 encoding of the corresponding PCM chunk.""" + connection = _make_mock_connection() + # Build a recognizable buffer: 4800 bytes of 0xAA then 4800 bytes of 0xBB. 
+ pcm = (b"\xaa" * 4800) + (b"\xbb" * 4800) + + await target._stream_pcm_async(connection=connection, pcm_bytes=pcm, commit=False) + + first_audio = connection.input_audio_buffer.append.call_args_list[0].kwargs["audio"] + second_audio = connection.input_audio_buffer.append.call_args_list[1].kwargs["audio"] + assert base64.b64decode(first_audio) == b"\xaa" * 4800 + assert base64.b64decode(second_audio) == b"\xbb" * 4800 + + +# ---- Wire primitives for streaming attacks --------------------------------------------------- + + +async def test_push_audio_chunk_async_base64_encodes_and_appends(target): + connection = _make_mock_connection() + pcm = b"\x33" * 480 + + await target.push_audio_chunk_async(connection=connection, pcm_bytes=pcm) + + connection.input_audio_buffer.append.assert_awaited_once() + audio_b64 = connection.input_audio_buffer.append.call_args.kwargs["audio"] + assert base64.b64decode(audio_b64) == pcm + + +async def test_push_audio_chunk_async_empty_is_noop(target): + connection = _make_mock_connection() + await target.push_audio_chunk_async(connection=connection, pcm_bytes=b"") + connection.input_audio_buffer.append.assert_not_called() + + +async def test_insert_user_audio_async_creates_input_audio_item(target): + connection = AsyncMock() + pcm = b"\x44" * 480 + + await target.insert_user_audio_async(connection=connection, pcm_bytes=pcm) + + connection.conversation.item.create.assert_awaited_once() + item = connection.conversation.item.create.call_args.kwargs["item"] + assert item["type"] == "message" + assert item["role"] == "user" + assert item["content"][0]["type"] == "input_audio" + assert base64.b64decode(item["content"][0]["audio"]) == pcm + + +async def test_insert_user_text_async_creates_input_text_item(target): + connection = AsyncMock() + + await target.insert_user_text_async(connection=connection, text="hello model") + + connection.conversation.item.create.assert_awaited_once() + item = 
connection.conversation.item.create.call_args.kwargs["item"] + assert item["role"] == "user" + assert item["content"][0] == {"type": "input_text", "text": "hello model"} + + +async def test_delete_conversation_item_async_forwards_item_id(target): + connection = AsyncMock() + + await target.delete_conversation_item_async(connection=connection, item_id="raw_item_99") + + connection.conversation.item.delete.assert_awaited_once_with(item_id="raw_item_99") + + +async def test_swap_user_audio_async_inserts_converted_then_deletes_original(target): + """``swap_user_audio_async`` must insert the converted PCM then delete the original item.""" + connection = AsyncMock() + event = CommittedEvent(item_id="raw_swap_1") + + await target.swap_user_audio_async( + connection=connection, + committed_event=event, + converted_pcm=b"\xab" * 96, + ) + + connection.conversation.item.create.assert_awaited_once() + connection.conversation.item.delete.assert_awaited_once_with(item_id="raw_swap_1") + # Insert must precede delete: any future refactor that swaps the order or runs them + # concurrently would corrupt the streaming session — pin the ordering here. 
+ create_index = connection.method_calls.index(call.conversation.item.create(item=ANY)) + delete_index = connection.method_calls.index(call.conversation.item.delete(item_id="raw_swap_1")) + assert create_index < delete_index + + +async def test_swap_user_audio_async_logs_and_swallows_delete_failure(target, caplog): + """Best-effort delete: if ``delete`` raises, ``swap`` logs a warning and returns normally.""" + connection = AsyncMock() + connection.conversation.item.delete.side_effect = RuntimeError("delete blew up") + event = CommittedEvent(item_id="raw_swap_fail") + + with caplog.at_level("WARNING"): + await target.swap_user_audio_async( + connection=connection, + committed_event=event, + converted_pcm=b"\x01" * 96, + ) + + connection.conversation.item.create.assert_awaited_once() + connection.conversation.item.delete.assert_awaited_once_with(item_id="raw_swap_fail") + # Even on delete failure, insert must have happened first. + create_index = connection.method_calls.index(call.conversation.item.create(item=ANY)) + delete_index = connection.method_calls.index(call.conversation.item.delete(item_id="raw_swap_fail")) + assert create_index < delete_index + assert any("delete failed for raw_swap_fail" in record.message for record in caplog.records) + + +def _turn_state(*, response_id: str | None = "resp_abc", item_id: str | None = "item_xyz") -> RealtimeTurnState: + """Build a turn state with the named ids preset; completion future is unused by cancel tests.""" + return RealtimeTurnState( + completion=asyncio.get_event_loop().create_future(), + is_responding=True, + last_response_id=response_id, + current_item_id=item_id, + ) + + +def _make_dispatcher(connection): + """Build an _OpenAIRealtimeDispatcher around the given mock connection.""" + return _OpenAIRealtimeDispatcher(connection=connection) + + +async def test_cancel_does_not_send_response_cancel(): + """_cancel must NOT send response.cancel (server auto-cancels on speech detection).""" + connection = 
AsyncMock() + dispatcher = _make_dispatcher(connection) + state = _turn_state(response_id="resp_42") + state.delivered_audio.extend(b"\x00" * 4800) + + await dispatcher._cancel(state=state) + + connection.response.cancel.assert_not_awaited() + + +async def test_cancel_truncates_to_delivered_audio_ms(): + """Truncate must be called with audio_end_ms computed from delivered_audio length.""" + connection = AsyncMock() + dispatcher = _make_dispatcher(connection) + state = _turn_state(item_id="item_99") + # 4800 delivered bytes / 48 bytes-per-ms = 100ms + state.delivered_audio.extend(b"\x00" * 4800) + + await dispatcher._cancel(state=state) + + connection.conversation.item.truncate.assert_awaited_once_with( + item_id="item_99", + content_index=0, + audio_end_ms=100, + ) + assert state.interrupted is True + + +async def test_cancel_only_truncates_no_response_cancel(caplog): + """_cancel must only truncate, not send response.cancel (server handles cancellation).""" + connection = AsyncMock() + dispatcher = _make_dispatcher(connection) + state = _turn_state(item_id="item_1") + state.delivered_audio.extend(b"\x00" * 4800) + + await dispatcher._cancel(state=state) + + assert state.interrupted is True + connection.conversation.item.truncate.assert_awaited_once() + connection.response.cancel.assert_not_awaited() + + +async def test_cancel_marks_interrupted_when_truncate_raises(caplog): + """A failed conversation.item.truncate must log a warning and still flip state.interrupted.""" + connection = AsyncMock() + connection.conversation.item.truncate.side_effect = RuntimeError("boom") + dispatcher = _make_dispatcher(connection) + state = _turn_state() + + await dispatcher._cancel(state=state) + + assert state.interrupted is True + assert any( + "conversation.item.truncate failed" in record.message and record.levelname == "WARNING" + for record in caplog.records + ) + + +def _scripted_event(event_type, **fields): + """Build a MagicMock event with the named type plus any extra 
attribute paths.""" + event = MagicMock() + event.type = event_type + for path, value in fields.items(): + # Allow dotted attribute paths like "response.id" by walking nested MagicMocks. + parts = path.split(".") + target_attr = event + for part in parts[:-1]: + target_attr = getattr(target_attr, part) + setattr(target_attr, parts[-1], value) + return event + + +async def test_route_event_happy_path_resolves_completion_with_assembled_result(): + """response.created -> output_item.added -> audio.delta -> transcript.delta -> response.done.""" + connection = AsyncMock() + dispatcher = _make_dispatcher(connection) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + await dispatcher._route_event(event=_scripted_event("response.created", **{"response.id": "r1"}), state=state) + await dispatcher._route_event(event=_scripted_event("response.output_item.added", **{"item.id": "i1"}), state=state) + await dispatcher._route_event( + event=_scripted_event("response.audio.delta", delta=base64.b64encode(b"\xaa" * 4800).decode("ascii")), + state=state, + ) + await dispatcher._route_event(event=_scripted_event("response.audio_transcript.delta", delta="hello "), state=state) + await dispatcher._route_event(event=_scripted_event("response.audio_transcript.delta", delta="world"), state=state) + await dispatcher._route_event(event=_scripted_event("response.done", **{"response.id": "r1"}), state=state) + + assert state.completion.done() + result = state.completion.result() + assert result.audio_bytes == b"\xaa" * 4800 + assert result.transcripts == ["hello ", "world"] + assert state.interrupted is False + + +async def test_route_event_speech_started_while_responding_cancels_and_resolves_interrupted(): + """speech_started during a response triggers cancel and resolves with interrupted=True.""" + connection = AsyncMock() + dispatcher = _make_dispatcher(connection) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + await 
dispatcher._route_event(event=_scripted_event("response.created", **{"response.id": "r1"}), state=state) + await dispatcher._route_event(event=_scripted_event("response.output_item.added", **{"item.id": "i1"}), state=state) + await dispatcher._route_event( + event=_scripted_event("response.audio.delta", delta=base64.b64encode(b"\xbb" * 2400).decode("ascii")), + state=state, + ) + await dispatcher._route_event(event=_scripted_event("input_audio_buffer.speech_started"), state=state) + + connection.response.cancel.assert_not_awaited() + connection.conversation.item.truncate.assert_awaited_once_with( + item_id="i1", + content_index=0, + audio_end_ms=50, # 2400 / 48 + ) + result = state.completion.result() + assert result.audio_bytes == b"\xbb" * 2400 + assert result.interrupted is True + assert state.interrupted is True + + +async def test_route_event_stale_response_done_after_cancel_is_dropped(): + """A response.done with a stale response_id must not re-resolve a completed future.""" + connection = AsyncMock() + dispatcher = _make_dispatcher(connection) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + # Pretend a turn just resolved as interrupted on response_id r1. + state.last_response_id = "r1" + state.completion.set_result(RealtimeTargetResult()) + + # Late response.done for r1 arrives; router must not raise InvalidStateError. 
+ await dispatcher._route_event(event=_scripted_event("response.done", **{"response.id": "r1"}), state=state) + + +async def test_route_event_error_resolves_with_exception(): + """error events resolve the completion future via set_exception.""" + connection = AsyncMock() + dispatcher = _make_dispatcher(connection) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + await dispatcher._route_event(event=_scripted_event("error", **{"error.message": "rate limited"}), state=state) + + with pytest.raises(RuntimeError, match="rate limited"): + state.completion.result() + + +async def test_route_event_speech_started_without_responding_is_noop(): + """speech_started before a response is in flight does not call cancel or resolve.""" + connection = AsyncMock() + dispatcher = _make_dispatcher(connection) + state = RealtimeTurnState(completion=asyncio.get_event_loop().create_future()) + + await dispatcher._route_event(event=_scripted_event("input_audio_buffer.speech_started"), state=state) + + connection.response.cancel.assert_not_awaited() + connection.conversation.item.truncate.assert_not_awaited() + assert not state.completion.done() + assert state.interrupted is False + + +async def test_route_event_committed_event_fires_user_audio_callback(): + """input_audio_buffer.committed must fire the registered on_user_audio_committed callback.""" + connection = AsyncMock() + received: list[Any] = [] + + async def on_committed(event): + received.append(event) + + dispatcher = _OpenAIRealtimeDispatcher(connection=connection, on_user_audio_committed=on_committed) + + await dispatcher._route_event( + event=_scripted_event("input_audio_buffer.committed", item_id="raw_item_42", audio_start_ms=1234), + state=None, + ) + # Background callback task may not have run yet; yield until it does. 
+ for _ in range(20): + if received: + break + await asyncio.sleep(0.01) + + assert len(received) == 1 + assert received[0].item_id == "raw_item_42" + assert received[0].audio_start_ms == 1234 + + +async def test_route_event_committed_event_without_callback_is_noop(): + """A committed event with no callback configured must be ignored quietly.""" + connection = AsyncMock() + dispatcher = _OpenAIRealtimeDispatcher(connection=connection) # no callback + + # Must not raise. + await dispatcher._route_event( + event=_scripted_event("input_audio_buffer.committed", item_id="raw_item_99"), + state=None, + ) + + +async def test_route_event_speech_started_audio_start_propagates_to_commit(): + """speech_started's audio_start_ms is captured and attached to the next CommittedEvent. + + The OpenAI Realtime server omits audio_start_ms from the input_audio_buffer.committed + event but reports it on speech_started. The dispatcher bridges the two so callbacks + receive the value reliably. + """ + received: list[CommittedEvent] = [] + + async def on_committed(event: CommittedEvent) -> None: + received.append(event) + + connection = AsyncMock() + dispatcher = _OpenAIRealtimeDispatcher(connection=connection, on_user_audio_committed=on_committed) + + await dispatcher._route_event( + event=_scripted_event("input_audio_buffer.speech_started", audio_start_ms=8536), + state=None, + ) + await dispatcher._route_event( + event=_scripted_event("input_audio_buffer.committed", item_id="raw_99", audio_start_ms=None), + state=None, + ) + for _ in range(20): + if received: + break + await asyncio.sleep(0.01) + + assert len(received) == 1 + assert received[0].item_id == "raw_99" + assert received[0].audio_start_ms == 8536 + + +async def test_route_event_pending_speech_start_resets_after_commit(): + """After commit fires, the dispatcher clears its captured speech_start so a later + commit (e.g. 
for a turn whose speech_started never fired) doesn't see stale data.""" + received: list[CommittedEvent] = [] + + async def on_committed(event: CommittedEvent) -> None: + received.append(event) + + connection = AsyncMock() + dispatcher = _OpenAIRealtimeDispatcher(connection=connection, on_user_audio_committed=on_committed) + + await dispatcher._route_event( + event=_scripted_event("input_audio_buffer.speech_started", audio_start_ms=500), + state=None, + ) + await dispatcher._route_event( + event=_scripted_event("input_audio_buffer.committed", item_id="i1", audio_start_ms=None), + state=None, + ) + # Second commit without a prior speech_started: must NOT reuse the 500 captured above. + await dispatcher._route_event( + event=_scripted_event("input_audio_buffer.committed", item_id="i2", audio_start_ms=None), + state=None, + ) + for _ in range(20): + if len(received) >= 2: + break + await asyncio.sleep(0.01) + + assert len(received) == 2 + assert received[0].audio_start_ms == 500 + assert received[1].audio_start_ms is None + + +# Placeholder for R2 tests + + +# ---- subscribe_events_async + request_response_async (R2) ------------------------------------ + + +async def test_subscribe_events_async_returns_started_dispatcher(target): + """Subscription handle must be a started dispatcher; closing tears the task down.""" + events = [_scripted_event("input_audio_buffer.committed", item_id="i_1")] + + async def event_iter(): + for e in events: + yield e + # Keep the iterator alive briefly so the dispatch task can run. + await asyncio.sleep(0.01) + + connection = MagicMock() + connection.__aiter__ = lambda self_: event_iter() + + received: list[CommittedEvent] = [] + + async def on_committed(event): + received.append(event) + + dispatcher = await target.subscribe_events_async( + connection=connection, conversation_id="test_conv", on_user_audio_committed=on_committed + ) + try: + # Yield until the dispatch loop processes the scripted event. 
+ for _ in range(20): + if received: + break + await asyncio.sleep(0.01) + assert len(received) == 1 and received[0].item_id == "i_1" + finally: + await dispatcher.stop() + + +async def test_subscribe_events_async_records_loop_failure_on_dispatcher(target): + """A dispatcher loop crash must be reachable via the dispatcher's ``failure`` property.""" + + async def boom_iter(): + raise RuntimeError("loop kaboom") + yield # pragma: no cover # makes it a generator + + connection = MagicMock() + connection.__aiter__ = lambda self_: boom_iter() + + dispatcher = await target.subscribe_events_async(connection=connection, conversation_id="test_conv") + try: + for _ in range(50): + if dispatcher.failure is not None: + break + await asyncio.sleep(0.01) + assert isinstance(dispatcher.failure, RuntimeError) + finally: + await dispatcher.stop() + + +async def test_request_response_async_registers_turn_and_sends_response_create(target): + """request_response_async must register a fresh turn and call response.create.""" + connection = AsyncMock() + dispatcher = MagicMock() + dispatcher.register_turn = MagicMock() + + future = await target.request_response_async(connection=connection, dispatcher=dispatcher) + + dispatcher.register_turn.assert_called_once() + registered_state = dispatcher.register_turn.call_args.args[0] + assert isinstance(registered_state, RealtimeTurnState) + assert registered_state.completion is future + connection.response.create.assert_awaited_once_with() + + +async def test_request_response_async_future_resolves_with_dispatcher_result(target): + """The future returned by request_response_async resolves when the turn ends.""" + connection = AsyncMock() + dispatcher = MagicMock() + expected_result = RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["ok"]) + + def _register(state): + state.completion.set_result(expected_result) + + dispatcher.register_turn = MagicMock(side_effect=_register) + + future = await 
target.request_response_async(connection=connection, dispatcher=dispatcher) + result = await future + assert result is expected_result + + +async def test_request_response_async_propagates_register_turn_failure(target): + """If another turn is already pending, register_turn raises and request_response_async surfaces it.""" + connection = AsyncMock() + dispatcher = MagicMock() + dispatcher.register_turn = MagicMock(side_effect=RuntimeError("turn already pending")) + + with pytest.raises(RuntimeError, match="turn already pending"): + await target.request_response_async(connection=connection, dispatcher=dispatcher) + + connection.response.create.assert_not_called() + + +# ---- streaming-mode state lifecycle --------------------------------------------- + + +def test_sample_rate_hz_class_constant(): + """SAMPLE_RATE_HZ is the single source of truth for the realtime PCM sample rate.""" + assert RealtimeTarget.SAMPLE_RATE_HZ == 24000 + + +def test_server_vad_config_returns_config_when_enabled(target): + """server_vad_config exposes the underlying ServerVadConfig when server VAD is enabled.""" + target._server_vad = ServerVadConfig(prefix_padding_ms=250, silence_duration_ms=400) + cfg = target.server_vad_config + assert cfg is not None + assert cfg.prefix_padding_ms == 250 + assert cfg.silence_duration_ms == 400 + + +def test_server_vad_config_returns_none_when_disabled(target): + """server_vad_config is None when server VAD is disabled.""" + target._server_vad = None + assert target.server_vad_config is None + + +async def test_subscribe_events_async_registers_streaming_state(target): + """Subscription must register per-conversation streaming state keyed by conversation_id.""" + + async def event_iter(): + await asyncio.sleep(0) + return + yield # pragma: no cover + + connection = MagicMock() + connection.__aiter__ = lambda self_: event_iter() + + assert "conv-A" not in target._streaming_state + + dispatcher = await target.subscribe_events_async(connection=connection, 
conversation_id="conv-A") + try: + assert "conv-A" in target._streaming_state + assert target._streaming_state["conv-A"].dispatcher is dispatcher + # The lock is created lazily-but-default; just verify it's an asyncio.Lock. + assert isinstance(target._streaming_state["conv-A"].turn_lock, asyncio.Lock) + finally: + await dispatcher.stop() + + +async def test_cleanup_conversation_clears_streaming_state(target): + """cleanup_conversation must stop the dispatcher and pop the streaming state entry.""" + dispatcher = AsyncMock() + dispatcher.stop = AsyncMock() + target._streaming_state["conv-B"] = _StreamingConversationState(dispatcher=dispatcher) + target._existing_conversation["conv-B"] = AsyncMock() + + await target.cleanup_conversation("conv-B") + + dispatcher.stop.assert_awaited_once() + assert "conv-B" not in target._streaming_state + assert "conv-B" not in target._existing_conversation + + +async def test_cleanup_conversation_is_safe_without_streaming_state(target): + """cleanup_conversation must not fail when no streaming state was registered.""" + target._existing_conversation["conv-C"] = AsyncMock() + + # Should not raise even though _streaming_state has no entry for conv-C. 
+ await target.cleanup_conversation("conv-C") + assert "conv-C" not in target._existing_conversation + + +async def test_cleanup_target_clears_all_streaming_state(target): + """cleanup_target must stop every active dispatcher before closing connections.""" + dispatcher_a = AsyncMock() + dispatcher_a.stop = AsyncMock() + dispatcher_b = AsyncMock() + dispatcher_b.stop = AsyncMock() + target._streaming_state["conv-X"] = _StreamingConversationState(dispatcher=dispatcher_a) + target._streaming_state["conv-Y"] = _StreamingConversationState(dispatcher=dispatcher_b) + target._existing_conversation["conv-X"] = AsyncMock() + target._existing_conversation["conv-Y"] = AsyncMock() + + await target.cleanup_target() + + dispatcher_a.stop.assert_awaited_once() + dispatcher_b.stop.assert_awaited_once() + assert target._streaming_state == {} + assert target._existing_conversation == {} + + +async def test_cleanup_target_swallows_dispatcher_stop_errors(target): + """A failing dispatcher.stop() must not prevent cleanup_target from proceeding.""" + bad_dispatcher = AsyncMock() + bad_dispatcher.stop = AsyncMock(side_effect=RuntimeError("already stopped")) + target._streaming_state["conv-Z"] = _StreamingConversationState(dispatcher=bad_dispatcher) + connection = AsyncMock() + target._existing_conversation["conv-Z"] = connection + + await target.cleanup_target() # must not raise + + bad_dispatcher.stop.assert_awaited_once() + connection.close.assert_awaited_once() + assert target._streaming_state == {} + assert target._existing_conversation == {} + + +# ---- _send_streaming_turn_async / send_prompt routing ----------------------- + + +def _write_wav( + path: Any, + *, + rate: int = 24000, + channels: int = 1, + sampwidth: int = 2, + pcm: bytes = b"\x00" * 96, +) -> str: + """Write a small WAV file at ``path`` and return the path as a string.""" + with wave.open(str(path), "wb") as w: + w.setnchannels(channels) + w.setsampwidth(sampwidth) + w.setframerate(rate) + w.writeframes(pcm) + 
return str(path) + + +def _make_streaming_request( + *, + conversation_id: str, + wav_path: str, + converter_identifiers: list | None = None, + committed_item_id: str | None = None, +) -> Message: + """Construct a streaming-mode request Message matching the attack's contract.""" + metadata: dict[str, Any] = {} + if committed_item_id is not None: + metadata[REALTIME_COMMITTED_ITEM_ID_KEY] = committed_item_id + piece = MessagePiece( + role="user", + original_value=wav_path, + original_value_data_type="audio_path", + converted_value=wav_path, + converted_value_data_type="audio_path", + conversation_id=conversation_id, + prompt_metadata=metadata or None, + converter_identifiers=converter_identifiers or [], + ) + return Message(message_pieces=[piece]) + + +def _register_streaming(target, conversation_id: str) -> tuple[MagicMock, AsyncMock]: + """Register streaming state for a conversation; return (dispatcher, connection).""" + dispatcher = MagicMock() + connection = AsyncMock() + target._streaming_state[conversation_id] = _StreamingConversationState(dispatcher=dispatcher) + target._existing_conversation[conversation_id] = connection + return dispatcher, connection + + +async def test_send_prompt_routes_streaming_when_state_registered(target, tmp_path): + """When streaming state is registered, the streaming branch runs (no atomic send).""" + _register_streaming(target, "conv-R") + wav_path = _write_wav(tmp_path / "in.wav") + message = _make_streaming_request(conversation_id="conv-R", wav_path=wav_path) + + target.swap_user_audio_async = AsyncMock() + target.save_audio = AsyncMock(return_value="/tmp/resp.wav") + completed_future: asyncio.Future = asyncio.get_running_loop().create_future() + completed_future.set_result(RealtimeTargetResult(audio_bytes=b"", transcripts=["ok"])) + target.request_response_async = AsyncMock(return_value=completed_future) + target.send_audio_async = AsyncMock() + target.send_text_async = AsyncMock() + + responses = await 
target._send_prompt_to_target_async(normalized_conversation=[message]) + + target.request_response_async.assert_awaited_once() + target.send_audio_async.assert_not_called() + target.send_text_async.assert_not_called() + assert len(responses) == 1 + + +async def test_send_prompt_uses_atomic_path_when_no_streaming_state(target, tmp_path): + """Without streaming state, the atomic send_audio_async path runs as before.""" + wav_path = _write_wav(tmp_path / "in.wav") + piece = MessagePiece( + role="user", + original_value=wav_path, + original_value_data_type="audio_path", + converted_value=wav_path, + converted_value_data_type="audio_path", + conversation_id="conv-A", + ) + message = Message(message_pieces=[piece]) + + target.swap_user_audio_async = AsyncMock() + target.request_response_async = AsyncMock() + target.connect_async = AsyncMock(return_value=AsyncMock()) + target.send_config = AsyncMock() + target.send_audio_async = AsyncMock( + return_value=("/tmp/out.wav", RealtimeTargetResult(audio_bytes=b"", transcripts=["hi"])), + ) + + await target._send_prompt_to_target_async(normalized_conversation=[message]) + + target.send_audio_async.assert_awaited_once() + target.swap_user_audio_async.assert_not_called() + target.request_response_async.assert_not_called() + + +async def test_streaming_send_swaps_when_converters_ran(target, tmp_path): + """When converter_identifiers is non-empty, the swap call fires with the converted PCM.""" + _, connection = _register_streaming(target, "conv-S") + converted_pcm = b"\x11\x22" * 48 + wav_path = _write_wav(tmp_path / "converted.wav", pcm=converted_pcm) + + converter_id = ComponentIdentifier(class_name="FakeConverter", class_module="tests.fake") + message = _make_streaming_request( + conversation_id="conv-S", + wav_path=wav_path, + converter_identifiers=[converter_id], + committed_item_id="item_xyz", + ) + + target.swap_user_audio_async = AsyncMock() + target.save_audio = AsyncMock(return_value="/tmp/resp.wav") + completed_future: 
asyncio.Future = asyncio.get_running_loop().create_future() + completed_future.set_result(RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["hello"])) + target.request_response_async = AsyncMock(return_value=completed_future) + + responses = await target._send_prompt_to_target_async(normalized_conversation=[message]) + + target.swap_user_audio_async.assert_awaited_once() + swap_kwargs = target.swap_user_audio_async.call_args.kwargs + assert swap_kwargs["connection"] is connection + assert swap_kwargs["committed_event"].item_id == "item_xyz" + assert swap_kwargs["converted_pcm"] == converted_pcm + assert responses[0].message_pieces[0].converted_value == "hello" + assert responses[0].message_pieces[1].converted_value == "/tmp/resp.wav" + + +async def test_streaming_send_skips_swap_when_no_converters(target, tmp_path): + """With no converters, the server's raw committed buffer is used: no swap.""" + _register_streaming(target, "conv-N") + wav_path = _write_wav(tmp_path / "raw.wav") + message = _make_streaming_request( + conversation_id="conv-N", + wav_path=wav_path, + converter_identifiers=[], + ) + + target.swap_user_audio_async = AsyncMock() + target.save_audio = AsyncMock(return_value="/tmp/resp.wav") + completed_future: asyncio.Future = asyncio.get_running_loop().create_future() + completed_future.set_result(RealtimeTargetResult(audio_bytes=b"\xaa" * 96, transcripts=["ok"])) + target.request_response_async = AsyncMock(return_value=completed_future) + + await target._send_prompt_to_target_async(normalized_conversation=[message]) + + target.swap_user_audio_async.assert_not_called() + target.request_response_async.assert_awaited_once() + + +async def test_streaming_send_requires_item_id_when_converters_ran(target, tmp_path): + """A streaming request with converters but no committed item id is a contract violation.""" + _register_streaming(target, "conv-X") + wav_path = _write_wav(tmp_path / "in.wav") + converter_id = 
ComponentIdentifier(class_name="FakeConverter", class_module="tests.fake") + message = _make_streaming_request( + conversation_id="conv-X", + wav_path=wav_path, + converter_identifiers=[converter_id], + committed_item_id=None, + ) + + with pytest.raises(ValueError, match="committed item id"): + await target._send_prompt_to_target_async(normalized_conversation=[message]) + + +async def test_streaming_send_rejects_wrong_audio_format(target, tmp_path): + """Converters must preserve mono PCM16 @ SAMPLE_RATE_HZ; mismatches raise.""" + _register_streaming(target, "conv-F") + wav_path = _write_wav(tmp_path / "bad.wav", rate=16000) # wrong rate + message = _make_streaming_request(conversation_id="conv-F", wav_path=wav_path) + + with pytest.raises(ValueError, match="mono PCM16"): + await target._send_prompt_to_target_async(normalized_conversation=[message]) + + +async def test_send_prompt_propagates_interrupted_metadata_for_streaming(target, tmp_path): + """When the realtime turn future resolves with interrupted=True, both response pieces gain the flag.""" + _register_streaming(target, "conv-I") + wav_path = _write_wav(tmp_path / "in.wav") + message = _make_streaming_request(conversation_id="conv-I", wav_path=wav_path) + + target.save_audio = AsyncMock(return_value="/tmp/partial.wav") + completed_future: asyncio.Future = asyncio.get_running_loop().create_future() + completed_future.set_result( + RealtimeTargetResult(audio_bytes=b"\xaa" * 32, transcripts=["partial"], interrupted=True), + ) + target.request_response_async = AsyncMock(return_value=completed_future) + + responses = await target._send_prompt_to_target_async(normalized_conversation=[message]) + + assert len(responses) == 1 + text_piece, audio_piece = responses[0].message_pieces + assert text_piece.prompt_metadata.get("interrupted") is True + assert audio_piece.prompt_metadata.get("interrupted") is True + + +async def test_streaming_send_rejects_non_audio_piece(target, tmp_path): + """A text-typed piece routed to 
the streaming branch must surface a clear error.""" + _register_streaming(target, "conv-T") + piece = MessagePiece( + role="user", + original_value="hello", + original_value_data_type="text", + converted_value="hello", + converted_value_data_type="text", + conversation_id="conv-T", + ) + message = Message(message_pieces=[piece]) + + with pytest.raises(ValueError, match="audio_path"): + await target._send_prompt_to_target_async(normalized_conversation=[message]) + + +async def test_streaming_send_serializes_via_turn_lock(target, tmp_path): + """Two concurrent turns on the same conversation must run sequentially under the lock.""" + _register_streaming(target, "conv-L") + wav_path = _write_wav(tmp_path / "in.wav") + message = _make_streaming_request(conversation_id="conv-L", wav_path=wav_path) + + target.save_audio = AsyncMock(return_value="/tmp/r.wav") + active = 0 + max_concurrent = 0 + + async def fake_request_response(*, connection, dispatcher): + nonlocal active, max_concurrent + active += 1 + max_concurrent = max(max_concurrent, active) + # Yield control so a second turn would interleave if the lock weren't held. + await asyncio.sleep(0.01) + active -= 1 + fut: asyncio.Future = asyncio.get_running_loop().create_future() + fut.set_result(RealtimeTargetResult(audio_bytes=b"\xaa" * 32, transcripts=["ok"])) + return fut + + target.request_response_async = AsyncMock(side_effect=fake_request_response) + + await asyncio.gather( + target._send_prompt_to_target_async(normalized_conversation=[message]), + target._send_prompt_to_target_async(normalized_conversation=[message]), + ) + + assert max_concurrent == 1