Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,61 @@ def test_on_error(self):
self.assertTrue(self.client.server_error)
self.assertEqual(self.client.error_message, error_message)

def test_on_message_transcript_reset_ack(self):
"""Server ACK for reset should be handled gracefully without errors."""
# First make the client think it's connected to a backend
ready_msg = json.dumps({
"uid": self.client.uid,
"message": "SERVER_READY",
"backend": "faster_whisper"
})
self.client.on_message(self.mock_ws_app, ready_msg)

reset_ack_msg = json.dumps({
"uid": self.client.uid,
"message": "TRANSCRIPT_RESET"
})
# Should not raise; just log/print the confirmation
self.client.on_message(self.mock_ws_app, reset_ack_msg)


class TestResetTranscript(BaseTestCase):
def test_reset_clears_local_state(self):
"""reset_transcript() must clear transcript and related fields."""
# Simulate accumulated state
self.client.server_backend = "faster_whisper"
self.client.recording = True
self.client.transcript = [
{"start": "0.000", "end": "1.000", "text": "Hello", "completed": True}
]
self.client.translated_transcript = [
{"start": "0.000", "end": "1.000", "text": "Bonjour", "completed": True}
]
self.client.last_segment = {"text": "Hello"}
self.client.last_received_segment = "Hello"

self.client.reset_transcript()

self.assertEqual(self.client.transcript, [])
self.assertEqual(self.client.translated_transcript, [])
self.assertIsNone(self.client.last_segment)
self.assertIsNone(self.client.last_received_segment)

def test_reset_sends_control_frame_to_server(self):
"""reset_transcript() must send the correct JSON action to the server."""
self.client.reset_transcript()

sent_calls = self.client.client_socket.send.call_args_list
# Find the call that contains RESET_TRANSCRIPT
reset_calls = [
call for call in sent_calls
if "RESET_TRANSCRIPT" in str(call)
]
self.assertTrue(reset_calls, "Expected a RESET_TRANSCRIPT frame to be sent")
payload = json.loads(reset_calls[-1][0][0])
self.assertEqual(payload["action"], "RESET_TRANSCRIPT")
self.assertEqual(payload["uid"], self.client.uid)


class TestAudioResampling(unittest.TestCase):
def test_resample_audio(self):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import subprocess
import time
import json
import threading
import unittest
from unittest import mock

Expand All @@ -10,8 +11,57 @@
from websockets.exceptions import ConnectionClosed
from whisper_live.server import TranscriptionServer, BackendType, ClientManager
from whisper_live.client import Client, TranscriptionClient, TranscriptionTeeClient
from whisper_live.backend.base import ServeClientBase
from whisper.normalizers import EnglishTextNormalizer

class TestResetTranscriptServer(unittest.TestCase):
"""Unit tests for ServeClientBase.reset_transcript()."""

def _make_client(self):
"""Return a minimal ServeClientBase with a mocked websocket."""
mock_ws = mock.MagicMock()
client = ServeClientBase.__new__(ServeClientBase)
client.client_uid = "test-uid"
client.websocket = mock_ws
client.lock = threading.Lock()
client.transcript = [
{"start": "0.000", "end": "1.000", "text": "Hello", "completed": True}
]
client.text = ["Hello"]
client.current_out = "Hello"
client.prev_out = "Hello"
client.same_output_count = 3
client.end_time_for_same_output = 1.0
client.frames_np = np.zeros(16000, dtype=np.float32)
client.frames_offset = 5.0
client.timestamp_offset = 6.0
client.translation_queue = None
return client, mock_ws

def test_reset_clears_transcript_and_state(self):
"""reset_transcript() must zero-out all accumulated state."""
client, _ = self._make_client()
client.reset_transcript()

self.assertEqual(client.transcript, [])
self.assertEqual(client.text, [])
self.assertEqual(client.current_out, '')
self.assertEqual(client.prev_out, '')
self.assertEqual(client.same_output_count, 0)
self.assertIsNone(client.end_time_for_same_output)
self.assertIsNone(client.frames_np)
self.assertEqual(client.frames_offset, 0.0)
self.assertEqual(client.timestamp_offset, 0.0)

def test_reset_sends_ack_to_client(self):
"""reset_transcript() must send the TRANSCRIPT_RESET ACK over WebSocket."""
client, mock_ws = self._make_client()
client.reset_transcript()

mock_ws.send.assert_called_once()
payload = json.loads(mock_ws.send.call_args[0][0])
self.assertEqual(payload["uid"], "test-uid")
self.assertEqual(payload["message"], "TRANSCRIPT_RESET")

class TestTranscriptionServerInitialization(unittest.TestCase):
def test_initialization(self):
Expand Down
29 changes: 29 additions & 0 deletions whisper_live/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class ServeClientBase(object):
RATE = 16000
SERVER_READY = "SERVER_READY"
DISCONNECT = "DISCONNECT"
TRANSCRIPT_RESET = "TRANSCRIPT_RESET"

client_uid: str
"""A unique identifier for the client."""
Expand Down Expand Up @@ -260,6 +261,34 @@ def disconnect(self):
"message": self.DISCONNECT
}))

def reset_transcript(self):
"""
Reset the accumulated transcript and audio state for this client.

Clears all stored segments, resets audio buffer pointers, and resets
repeated-output detection counters so that the next transcription
starts from a clean state. Sends a TRANSCRIPT_RESET acknowledgment
to the client over WebSocket.
"""
with self.lock:
self.transcript = []
self.text = []
self.current_out = ''
self.prev_out = ''
self.same_output_count = 0
self.end_time_for_same_output = None
self.frames_np = None
self.frames_offset = 0.0
self.timestamp_offset = 0.0
Comment on lines +273 to +282
logging.info(f"[{self.client_uid}] Transcript reset.")
try:
self.websocket.send(json.dumps({
"uid": self.client_uid,
"message": self.TRANSCRIPT_RESET
}))
except Exception as e:
logging.error(f"[ERROR]: Failed to send TRANSCRIPT_RESET ack: {e}")

def cleanup(self):
"""
Perform cleanup tasks before exiting the transcription service.
Expand Down
23 changes: 23 additions & 0 deletions whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ def on_message(self, ws, message):
if "translated_segments" in message.keys():
self.process_segments(message["translated_segments"], translated=True)

if message.get("message") == "TRANSCRIPT_RESET":
print("[INFO]: Server confirmed transcript reset.")

def on_error(self, ws, error):
print(f"[ERROR] WebSocket Error: {error}")
self.server_error = True
Expand Down Expand Up @@ -359,6 +362,26 @@ def write_srt_file(self, output_path="output.srt"):
if self.enable_translation:
utils.create_srt_file(self.translated_transcript, self.translation_srt_file_path)

def reset_transcript(self):
"""
Reset the local transcript and notify the server to reset its state.

Clears the local transcript, translated transcript, last segment pointers,
and sends a RESET_TRANSCRIPT control frame to the server so that both
sides are in sync.
"""
self.transcript = []
self.translated_transcript = []
self.last_segment = None
self.last_received_segment = None
print("[INFO]: Sending RESET_TRANSCRIPT to server...")
try:
self.client_socket.send(
json.dumps({"action": "RESET_TRANSCRIPT", "uid": self.uid})
)
except Exception as e:
print(f"[ERROR]: Failed to send RESET_TRANSCRIPT: {e}")

def wait_before_disconnect(self):
"""Waits a bit before disconnecting in order to process pending responses."""
assert self.last_response_received
Expand Down
23 changes: 22 additions & 1 deletion whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,32 @@ def get_audio_from_websocket(self, websocket):
"""
Receives audio buffer from websocket and creates a numpy array out of it.

Also handles JSON control frames sent by the client. Currently supports:
- b"END_OF_AUDIO" : signals end of the audio stream.
- {"action": "RESET_TRANSCRIPT"} : resets the accumulated transcript.

Args:
websocket: The websocket to receive audio from.

Returns:
A numpy array containing the audio.
A numpy array containing the audio, False on END_OF_AUDIO,
or the string "RESET" for a transcript-reset control frame.
"""
frame_data = websocket.recv()
if frame_data == b"END_OF_AUDIO":
return False
# Handle JSON control frames (sent as text or bytes that decode to JSON)
if isinstance(frame_data, (str, bytes)):
try:
decoded = frame_data if isinstance(frame_data, str) else frame_data.decode("utf-8")
control = json.loads(decoded)
if control.get("action") == "RESET_TRANSCRIPT":
client = self.client_manager.get_client(websocket)
if client:
client.reset_transcript()
return "RESET"
except (UnicodeDecodeError, json.JSONDecodeError, ValueError):
pass # Not a JSON control frame; treat as audio bytes
Comment on lines 309 to +322
return np.frombuffer(frame_data, dtype=np.float32)

def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
Expand Down Expand Up @@ -340,6 +357,10 @@ def process_audio_frames(self, websocket):
client.set_eos(True)
return False

# Control frame: transcript was reset server-side; continue the loop.
if isinstance(frame_np, str) and frame_np == "RESET":
return True

if self.backend.is_tensorrt():
voice_active = self.voice_activity(websocket, frame_np)
if voice_active:
Expand Down
Loading