"""Example: Using wait_for_ready for cold start handling.

This example demonstrates how to use the `wait_for_ready` parameter to
automatically handle model cold starts. When a model is not loaded (cold),
the SDK will trigger a warmup request and poll until the model becomes
healthy before making the actual request.

This is useful for:
- Serverless deployments where models may be unloaded after idle time
- First requests to a model that hasn't been used recently
- Ensuring reliable request completion without manual status checking

Usage:
    export KAFEIDO_API_KEY="your-api-key"
    python examples/wait_for_ready.py
"""

import asyncio

from kafeido import OpenAI, AsyncOpenAI, WarmupTimeoutError


def _report_timeout(err):
    """Print the model ID and waited time carried by a WarmupTimeoutError."""
    print(f" Model {err.model} didn't warm up in {err.waited_seconds:.1f}s\n")


def sync_example():
    """Run three blocking chat completions that wait for the model first.

    Covers the default timeout, a custom `warmup_timeout`, and streaming.
    A warmup timeout is reported and the example moves on to the next call.
    """
    print("=== Synchronous Example ===\n")

    client = OpenAI()

    # Basic usage - wait for model to be ready before making request
    print("1. Basic chat completion with wait_for_ready:")
    try:
        response = client.chat.completions.create(
            model="gpt-oss-20b",
            messages=[{"role": "user", "content": "Hello! What is 2+2?"}],
            wait_for_ready=True,  # Automatically handle cold start
        )
        print(f" Response: {response.choices[0].message.content}\n")
    except WarmupTimeoutError as e:
        _report_timeout(e)

    # With custom timeout - useful for models that take longer to load
    print("2. Chat completion with custom timeout (2 minutes):")
    try:
        response = client.chat.completions.create(
            model="gpt-oss-20b",
            messages=[{"role": "user", "content": "Tell me a short joke."}],
            wait_for_ready=True,
            warmup_timeout=120.0,  # Wait up to 2 minutes
        )
        print(f" Response: {response.choices[0].message.content}\n")
    except WarmupTimeoutError as e:
        _report_timeout(e)

    # Streaming also works with wait_for_ready
    print("3. Streaming with wait_for_ready:")
    try:
        stream = client.chat.completions.create(
            model="gpt-oss-20b",
            messages=[{"role": "user", "content": "Count from 1 to 5."}],
            stream=True,
            wait_for_ready=True,
        )
        print(" Response: ", end="")
        for chunk in stream:
            # Skip chunks that carry no choices instead of failing on
            # `choices[0]` with an IndexError.
            if not chunk.choices:
                continue
            text = chunk.choices[0].delta.content
            if text:
                print(text, end="", flush=True)
        print("\n")
    except WarmupTimeoutError as e:
        _report_timeout(e)


async def async_example():
    """Run one awaited chat completion that waits for the model first."""
    print("=== Asynchronous Example ===\n")

    async with AsyncOpenAI() as client:
        # Basic async usage
        print("1. Async chat completion with wait_for_ready:")
        try:
            response = await client.chat.completions.create(
                model="gpt-oss-20b",
                messages=[{"role": "user", "content": "What is the capital of France?"}],
                wait_for_ready=True,
            )
            print(f" Response: {response.choices[0].message.content}\n")
        except WarmupTimeoutError as e:
            _report_timeout(e)


def audio_example():
    """Print a transcription snippet that uses wait_for_ready.

    Nothing is sent to the API: the snippet needs a real audio file, so it
    is only displayed. No client is constructed here, which means this
    function works without credentials.
    """
    print("=== Audio Transcription Example ===\n")

    # Note: You need an actual audio file for this example
    # This is just to show the API usage
    print("Audio transcription with wait_for_ready:")
    print(" (Requires an audio file to run)")
    print("""
    # Example code:
    with open("audio.mp3", "rb") as f:
        transcript = client.audio.transcriptions.create(
            file=f,
            model="whisper-large-v3",
            wait_for_ready=True,  # Wait for whisper model to load
            warmup_timeout=180.0,  # ASR models may take longer
        )
    print(transcript.text)
    """)


def main():
    """Run the sync example, the async example, then show the audio snippet."""
    print("=" * 60)
    print("Kafeido SDK - wait_for_ready Example")
    print("=" * 60)
    print()

    # Run sync example
    sync_example()

    # Run async example
    asyncio.run(async_example())

    # Show audio example code
    audio_example()

    print("=" * 60)
    print("Done!")


if __name__ == "__main__":
    main()
a3625c6..9a62be1 100644 --- a/kafeido/_async_client.py +++ b/kafeido/_async_client.py @@ -7,6 +7,7 @@ from kafeido._auth import get_api_key from kafeido._http_client import AsyncHTTPClient +from kafeido._warmup import AsyncWarmupHelper from kafeido.resources._async_chat import AsyncChat from kafeido.resources._async_audio import AsyncAudio from kafeido.resources._async_models import AsyncModels @@ -76,13 +77,21 @@ def __init__( max_retries=max_retries, ) - # Initialize resources - self._chat = AsyncChat(self._http_client) - self._audio = AsyncAudio(self._http_client) + # Initialize models resource first (needed for warmup helper) self._models = AsyncModels(self._http_client) + + # Initialize warmup helper for cold start handling + self._warmup_helper = AsyncWarmupHelper( + status_fn=self._models.status, + warmup_fn=lambda m: self._models.warmup(model=m), + ) + + # Initialize resources with warmup helper + self._chat = AsyncChat(self._http_client, self._warmup_helper) + self._audio = AsyncAudio(self._http_client, self._warmup_helper) self._files = AsyncFiles(self._http_client) - self._ocr = AsyncOCR(self._http_client) - self._vision = AsyncVision(self._http_client) + self._ocr = AsyncOCR(self._http_client, self._warmup_helper) + self._vision = AsyncVision(self._http_client, self._warmup_helper) self._jobs = AsyncJobs(self._http_client) @property diff --git a/kafeido/_warmup.py b/kafeido/_warmup.py new file mode 100644 index 0000000..756b693 --- /dev/null +++ b/kafeido/_warmup.py @@ -0,0 +1,181 @@ +"""Cold start waiting / warmup helpers. + +This module provides helpers for handling model cold starts by automatically +triggering warmup and polling until the model is ready before making requests. 
import asyncio
import time
from typing import TYPE_CHECKING, Awaitable, Callable, Optional

if TYPE_CHECKING:
    from kafeido.types.models import ModelStatus, WarmupResponse


# Default configuration
DEFAULT_POLL_INTERVAL = 2.0  # seconds between status checks
DEFAULT_MAX_WAIT_TIME = 300.0  # 5 minutes max wait
HEALTHY_STATUS = "healthy"


class WarmupTimeoutError(Exception):
    """Raised when model warmup times out.

    Attributes:
        model: The model ID that timed out.
        waited_seconds: How long polling lasted before timing out.
    """

    def __init__(self, model: str, waited_seconds: float) -> None:
        super().__init__(
            f"Model '{model}' did not become ready within {waited_seconds:.1f}s"
        )
        self.model = model
        self.waited_seconds = waited_seconds

    def __reduce__(self):
        # The default Exception reduction replays ``cls(*self.args)``, and
        # ``args`` holds only the formatted message, so unpickling or copying
        # would fail with a missing ``waited_seconds`` argument. Rebuild from
        # the real constructor arguments instead.
        return (type(self), (self.model, self.waited_seconds))


def _is_healthy(status: "ModelStatus") -> bool:
    """Return True when the status payload reports ``HEALTHY_STATUS``."""
    detail = status.status
    return bool(detail) and detail.status == HEALTHY_STATUS


def _next_sleep(
    model: str, started_at: float, max_wait: float, poll_interval: float
) -> float:
    """Return the delay before the next poll, or raise once time is up.

    The delay is capped at the time remaining so that the final status check
    happens at the deadline rather than up to one poll interval after it.

    Raises:
        WarmupTimeoutError: If ``max_wait`` seconds have already elapsed
            since ``started_at`` (a ``time.monotonic()`` reading).
    """
    elapsed = time.monotonic() - started_at
    remaining = max_wait - elapsed
    if remaining <= 0:
        raise WarmupTimeoutError(model, elapsed)
    return min(poll_interval, remaining)


class WarmupHelper:
    """Synchronous warmup helper for cold start waiting.

    This helper triggers model warmup and polls until the model becomes healthy.
    """

    def __init__(
        self,
        status_fn: Callable[[str], "ModelStatus"],
        warmup_fn: Callable[[str], "WarmupResponse"],
        poll_interval: float = DEFAULT_POLL_INTERVAL,
        max_wait_time: float = DEFAULT_MAX_WAIT_TIME,
    ) -> None:
        """Initialize warmup helper.

        Args:
            status_fn: Function to get model status (typically models.status).
            warmup_fn: Function to trigger warmup (typically models.warmup).
            poll_interval: Seconds between status checks.
            max_wait_time: Maximum seconds to wait before timeout.
        """
        self._status_fn = status_fn
        self._warmup_fn = warmup_fn
        self._poll_interval = poll_interval
        self._max_wait_time = max_wait_time

    def wait_for_ready(
        self, model: str, timeout: Optional[float] = None
    ) -> None:
        """Wait for model to be ready, triggering warmup if needed.

        This method will:
        1. Trigger a warmup request to start loading the model
        2. If the warmup response says the model is already warm, return
        3. Otherwise, poll the status endpoint until the model is healthy
        4. Raise WarmupTimeoutError if the model doesn't become ready in time

        The status is checked at least once after the warmup request, and
        once more when the time budget runs out, before giving up.

        Args:
            model: The model ID to wait for.
            timeout: Optional timeout override in seconds. If None, uses
                the default max_wait_time from initialization. The clock
                starts after the warmup request returns.

        Raises:
            WarmupTimeoutError: If model doesn't become ready within timeout.
        """
        max_wait = self._max_wait_time if timeout is None else timeout

        # First, trigger warmup; an already-warm model needs no polling.
        if self._warmup_fn(model).already_warm:
            return

        started_at = time.monotonic()
        while True:
            # Check status before the deadline so a model that became healthy
            # during the last sleep is not reported as timed out.
            if _is_healthy(self._status_fn(model)):
                return
            time.sleep(
                _next_sleep(model, started_at, max_wait, self._poll_interval)
            )


class AsyncWarmupHelper:
    """Asynchronous warmup helper for cold start waiting.

    This helper triggers model warmup and polls until the model becomes healthy,
    using async/await for non-blocking operation.
    """

    def __init__(
        self,
        status_fn: Callable[[str], Awaitable["ModelStatus"]],
        warmup_fn: Callable[[str], Awaitable["WarmupResponse"]],
        poll_interval: float = DEFAULT_POLL_INTERVAL,
        max_wait_time: float = DEFAULT_MAX_WAIT_TIME,
    ) -> None:
        """Initialize async warmup helper.

        Args:
            status_fn: Async function to get model status.
            warmup_fn: Async function to trigger warmup.
            poll_interval: Seconds between status checks.
            max_wait_time: Maximum seconds to wait before timeout.
        """
        self._status_fn = status_fn
        self._warmup_fn = warmup_fn
        self._poll_interval = poll_interval
        self._max_wait_time = max_wait_time

    async def wait_for_ready(
        self, model: str, timeout: Optional[float] = None
    ) -> None:
        """Wait for model to be ready asynchronously.

        This method will:
        1. Trigger a warmup request to start loading the model
        2. If the warmup response says the model is already warm, return
        3. Otherwise, poll the status endpoint until the model is healthy
        4. Raise WarmupTimeoutError if the model doesn't become ready in time

        The status is checked at least once after the warmup request, and
        once more when the time budget runs out, before giving up.

        Args:
            model: The model ID to wait for.
            timeout: Optional timeout override in seconds. If None, uses
                the default max_wait_time from initialization. The clock
                starts after the warmup request returns.

        Raises:
            WarmupTimeoutError: If model doesn't become ready within timeout.
        """
        max_wait = self._max_wait_time if timeout is None else timeout

        # First, trigger warmup; an already-warm model needs no polling.
        warmup_response = await self._warmup_fn(model)
        if warmup_response.already_warm:
            return

        started_at = time.monotonic()
        while True:
            # Check status before the deadline so a model that became healthy
            # during the last sleep is not reported as timed out.
            if _is_healthy(await self._status_fn(model)):
                return
            await asyncio.sleep(
                _next_sleep(model, started_at, max_wait, self._poll_interval)
            )
warmup helper + self._chat = Chat(self._http_client, self._warmup_helper) + self._audio = Audio(self._http_client, self._warmup_helper) self._files = Files(self._http_client) - self._ocr = OCR(self._http_client) - self._vision = Vision(self._http_client) + self._ocr = OCR(self._http_client, self._warmup_helper) + self._vision = Vision(self._http_client, self._warmup_helper) self._jobs = Jobs(self._http_client) @property diff --git a/kafeido/resources/_async_audio.py b/kafeido/resources/_async_audio.py index aed82ab..7a4e53c 100644 --- a/kafeido/resources/_async_audio.py +++ b/kafeido/resources/_async_audio.py @@ -1,6 +1,6 @@ """Async audio transcription and translation resources.""" -from typing import BinaryIO, Literal, Optional, Union +from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, Union from kafeido._http_client import AsyncHTTPClient from kafeido.types.audio import ( @@ -11,6 +11,9 @@ ) from kafeido.types.tts import CreateSpeechAsyncResponse, GetSpeechResultResponse +if TYPE_CHECKING: + from kafeido._warmup import AsyncWarmupHelper + # Type alias for file inputs FileTypes = Union[BinaryIO, bytes] @@ -19,13 +22,19 @@ class AsyncTranscriptions: """Async audio transcriptions endpoint.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: """Initialize async transcriptions resource. Args: http_client: The async HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. 
""" self._client = http_client + self._warmup_helper = warmup_helper async def create( self, @@ -37,6 +46,8 @@ async def create( response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] = "json", temperature: Optional[float] = None, timestamp_granularities: Optional[list[str]] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Transcription: """Transcribe audio to text asynchronously. @@ -48,19 +59,31 @@ async def create( response_format: Format of the response. temperature: Sampling temperature (0-1). timestamp_granularities: Granularity of timestamps. + wait_for_ready: If True, wait for the model to be ready before + making the request. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: Transcription with text and optional segments. + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + Example: >>> client = AsyncOpenAI(api_key="sk-...") >>> with open("audio.mp3", "rb") as f: ... transcript = await client.audio.transcriptions.create( ... file=f, - ... model="whisper-large-v3" + ... model="whisper-large-v3", + ... wait_for_ready=True, ... ) >>> print(transcript.text) """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + await self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + # Prepare multipart upload files = {"file": file} data = { @@ -127,13 +150,19 @@ async def get_result(self, *, job_id: str) -> AsyncTranscriptionResult: class AsyncTranslations: """Async audio translations endpoint.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: """Initialize async translations resource. Args: http_client: The async HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. 
""" self._client = http_client + self._warmup_helper = warmup_helper async def create( self, @@ -143,6 +172,8 @@ async def create( prompt: Optional[str] = None, response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] = "json", temperature: Optional[float] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Translation: """Translate audio to English asynchronously. @@ -152,19 +183,31 @@ async def create( prompt: Optional text to guide the model's style. response_format: Format of the response. temperature: Sampling temperature (0-1). + wait_for_ready: If True, wait for the model to be ready before + making the request. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: Translation with English text. + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + Example: >>> client = AsyncOpenAI(api_key="sk-...") >>> with open("audio_spanish.mp3", "rb") as f: ... translation = await client.audio.translations.create( ... file=f, - ... model="whisper-large-v3" + ... model="whisper-large-v3", + ... wait_for_ready=True, ... 
) >>> print(translation.text) # English translation """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + await self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + # Prepare multipart upload files = {"file": file} data = { @@ -188,8 +231,13 @@ async def create( class AsyncSpeech: """Async text-to-speech endpoint.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper async def create( self, @@ -207,8 +255,39 @@ async def create( top_p: Optional[float] = None, top_k: Optional[int] = None, max_tokens: Optional[int] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> CreateSpeechAsyncResponse: - """Create a text-to-speech job asynchronously.""" + """Create a text-to-speech job asynchronously. + + Args: + model: TTS model ID. + input: Text to synthesize. + voice: Voice preset. + response_format: Audio format. + speed: Speech speed. + reference_audio_id: File ID for voice cloning. + reference_audio_key: Storage key for voice cloning. + language: Language code. + system_prompt: Scene description. + temperature: Sampling temperature. + top_p: Top-p sampling. + top_k: Top-k sampling. + max_tokens: Maximum tokens. + wait_for_ready: If True, wait for the model to be ready. + warmup_timeout: Maximum seconds to wait for warmup. + + Returns: + CreateSpeechAsyncResponse with job_id. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. 
+ """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + await self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + body: dict = { "model": model, "input": input, @@ -248,16 +327,21 @@ async def get_result(self, *, job_id: str) -> GetSpeechResultResponse: class AsyncAudio: """Async audio resource.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: """Initialize async audio resource. Args: http_client: The async HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. """ self._client = http_client - self._transcriptions = AsyncTranscriptions(http_client) - self._translations = AsyncTranslations(http_client) - self._speech = AsyncSpeech(http_client) + self._transcriptions = AsyncTranscriptions(http_client, warmup_helper) + self._translations = AsyncTranslations(http_client, warmup_helper) + self._speech = AsyncSpeech(http_client, warmup_helper) @property def transcriptions(self) -> AsyncTranscriptions: diff --git a/kafeido/resources/_async_chat.py b/kafeido/resources/_async_chat.py index 6618eeb..4f8f226 100644 --- a/kafeido/resources/_async_chat.py +++ b/kafeido/resources/_async_chat.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import httpx @@ -14,17 +14,26 @@ ChatCompletionMessageParam, ) +if TYPE_CHECKING: + from kafeido._warmup import AsyncWarmupHelper + class AsyncCompletions: """Async chat completions endpoint.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: """Initialize async completions resource. Args: http_client: The async HTTP client to use for requests. 
+ warmup_helper: Optional warmup helper for cold start handling. """ self._client = http_client + self._warmup_helper = warmup_helper async def create( self, @@ -47,6 +56,8 @@ async def create( tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[str, Dict[str, Any]]] = None, user: Optional[str] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Union[ChatCompletion, AsyncStream[ChatCompletionChunk]]: """Create a chat completion asynchronously. @@ -69,18 +80,30 @@ async def create( tools: List of tools the model can call. tool_choice: Controls which tool is called. user: Unique identifier for the end-user. + wait_for_ready: If True, wait for the model to be ready before + making the request. This handles cold starts automatically. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: ChatCompletion or AsyncStream[ChatCompletionChunk] if streaming. + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + Example: >>> client = AsyncOpenAI(api_key="sk-...") >>> response = await client.chat.completions.create( ... model="gpt-oss-20b", - ... messages=[{"role": "user", "content": "Hello!"}] + ... messages=[{"role": "user", "content": "Hello!"}], + ... wait_for_ready=True, ... ) >>> print(response.choices[0].message.content) """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + await self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + # Build request body body: Dict[str, Any] = { "model": model, @@ -143,14 +166,19 @@ async def create( class AsyncChat: """Async chat resource.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: """Initialize async chat resource. Args: http_client: The async HTTP client to use for requests. 
+ warmup_helper: Optional warmup helper for cold start handling. """ self._client = http_client - self._completions = AsyncCompletions(http_client) + self._completions = AsyncCompletions(http_client, warmup_helper) @property def completions(self) -> AsyncCompletions: diff --git a/kafeido/resources/_async_ocr.py b/kafeido/resources/_async_ocr.py index 23d68b3..c64c9a5 100644 --- a/kafeido/resources/_async_ocr.py +++ b/kafeido/resources/_async_ocr.py @@ -1,6 +1,6 @@ """Async OCR resource.""" -from typing import Optional +from typing import TYPE_CHECKING, Optional from kafeido._http_client import AsyncHTTPClient from kafeido.types.ocr import ( @@ -9,12 +9,20 @@ GetOCRResultResponse, ) +if TYPE_CHECKING: + from kafeido._warmup import AsyncWarmupHelper + class AsyncOCRExtractions: """Async OCR extraction endpoint.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper async def create( self, @@ -27,8 +35,34 @@ async def create( language: Optional[str] = None, custom_prompt: Optional[str] = None, max_tokens: Optional[int] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> CreateOCRResponse: - """Extract text from an image asynchronously.""" + """Extract text from an image asynchronously. + + Args: + model_id: OCR model ID. + file_id: ID of a previously uploaded file. + storage_key: Storage key from upload service. + mode: OCR mode. + resolution: Resolution setting. + language: Language hint. + custom_prompt: Custom prompt. + max_tokens: Maximum tokens. + wait_for_ready: If True, wait for the model to be ready. + warmup_timeout: Maximum seconds to wait for warmup. + + Returns: + CreateOCRResponse with extracted text. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. 
+ """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + await self._warmup_helper.wait_for_ready(model_id, timeout=warmup_timeout) + body = {"model_id": model_id} if file_id is not None: @@ -91,9 +125,13 @@ async def get_result(self, *, job_id: str) -> GetOCRResultResponse: class AsyncOCR: """Async OCR resource.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: self._client = http_client - self._extractions = AsyncOCRExtractions(http_client) + self._extractions = AsyncOCRExtractions(http_client, warmup_helper) @property def extractions(self) -> AsyncOCRExtractions: diff --git a/kafeido/resources/_async_vision.py b/kafeido/resources/_async_vision.py index a9890b5..18d134c 100644 --- a/kafeido/resources/_async_vision.py +++ b/kafeido/resources/_async_vision.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from kafeido._http_client import AsyncHTTPClient from kafeido._streaming import AsyncStream @@ -13,12 +13,20 @@ GetVisionResultResponse, ) +if TYPE_CHECKING: + from kafeido._warmup import AsyncWarmupHelper + class AsyncVisionAnalysis: """Async vision analysis endpoint.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper async def create( self, @@ -34,8 +42,37 @@ async def create( top_p: Optional[float] = None, top_k: Optional[int] = None, repetition_penalty: Optional[float] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> CreateVisionResponse: - """Analyze an image asynchronously.""" + """Analyze an image asynchronously. 
+ + Args: + model_id: Vision model ID. + storage_key: Storage key of the image. + image_base64: Base64-encoded image. + image_url: URL of the image. + prompt: Analysis prompt. + mode: Analysis mode. + temperature: Sampling temperature. + max_tokens: Maximum tokens. + top_p: Top-p sampling. + top_k: Top-k sampling. + repetition_penalty: Repetition penalty. + wait_for_ready: If True, wait for the model to be ready. + warmup_timeout: Maximum seconds to wait for warmup. + + Returns: + CreateVisionResponse with analysis text. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + await self._warmup_helper.wait_for_ready(model_id, timeout=warmup_timeout) + body: Dict[str, Any] = {"model_id": model_id} if storage_key is not None: @@ -113,8 +150,13 @@ async def get_result(self, *, job_id: str) -> GetVisionResultResponse: class AsyncVisionChat: """Async vision chat endpoint with streaming support.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper async def create( self, @@ -128,8 +170,35 @@ async def create( top_p: Optional[float] = None, top_k: Optional[int] = None, repetition_penalty: Optional[float] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Union[CreateVisionChatResponse, AsyncStream[CreateVisionChatResponse]]: - """Chat with images asynchronously.""" + """Chat with images asynchronously. + + Args: + messages: List of vision chat messages. + model_id: Vision model ID. + stream: Whether to stream the response. + conversation_id: Conversation ID for multi-turn. + temperature: Sampling temperature. + max_tokens: Maximum tokens. + top_p: Top-p sampling. 
+ top_k: Top-k sampling. + repetition_penalty: Repetition penalty. + wait_for_ready: If True, wait for the model to be ready. + warmup_timeout: Maximum seconds to wait for warmup. + + Returns: + Stream or CreateVisionChatResponse. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + await self._warmup_helper.wait_for_ready(model_id, timeout=warmup_timeout) + body: Dict[str, Any] = { "messages": messages, "model_id": model_id, @@ -162,10 +231,14 @@ async def create( class AsyncVision: """Async vision resource.""" - def __init__(self, http_client: AsyncHTTPClient) -> None: + def __init__( + self, + http_client: AsyncHTTPClient, + warmup_helper: Optional["AsyncWarmupHelper"] = None, + ) -> None: self._client = http_client - self._analyze = AsyncVisionAnalysis(http_client) - self._chat = AsyncVisionChat(http_client) + self._analyze = AsyncVisionAnalysis(http_client, warmup_helper) + self._chat = AsyncVisionChat(http_client, warmup_helper) @property def analyze(self) -> AsyncVisionAnalysis: diff --git a/kafeido/resources/audio.py b/kafeido/resources/audio.py index 09b6dd1..6ee89af 100644 --- a/kafeido/resources/audio.py +++ b/kafeido/resources/audio.py @@ -1,6 +1,6 @@ """Audio transcription and translation resources.""" -from typing import BinaryIO, Literal, Optional, Union +from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, Union from kafeido._http_client import HTTPClient from kafeido.types.audio import ( @@ -11,6 +11,9 @@ ) from kafeido.types.tts import CreateSpeechAsyncResponse, GetSpeechResultResponse +if TYPE_CHECKING: + from kafeido._warmup import WarmupHelper + # Type alias for file inputs FileTypes = Union[BinaryIO, bytes] @@ -19,13 +22,19 @@ class Transcriptions: """Audio transcriptions endpoint.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + 
http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: """Initialize transcriptions resource. Args: http_client: The HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. """ self._client = http_client + self._warmup_helper = warmup_helper def create( self, @@ -37,6 +46,8 @@ def create( response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] = "json", temperature: Optional[float] = None, timestamp_granularities: Optional[list[str]] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Transcription: """Transcribe audio to text. @@ -48,19 +59,31 @@ def create( response_format: Format of the response. temperature: Sampling temperature (0-1). timestamp_granularities: Granularity of timestamps. + wait_for_ready: If True, wait for the model to be ready before + making the request. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: Transcription with text and optional segments. + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + Example: >>> client = OpenAI(api_key="sk-...") >>> with open("audio.mp3", "rb") as f: ... transcript = client.audio.transcriptions.create( ... file=f, - ... model="whisper-large-v3" + ... model="whisper-large-v3", + ... wait_for_ready=True, ... ) >>> print(transcript.text) """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + # Prepare multipart upload files = {"file": file} data = { @@ -147,13 +170,19 @@ def get_result(self, *, job_id: str) -> AsyncTranscriptionResult: class Translations: """Audio translations endpoint.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: """Initialize translations resource. 
Args: http_client: The HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. """ self._client = http_client + self._warmup_helper = warmup_helper def create( self, @@ -163,6 +192,8 @@ def create( prompt: Optional[str] = None, response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] = "json", temperature: Optional[float] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Translation: """Translate audio to English. @@ -172,19 +203,31 @@ def create( prompt: Optional text to guide the model's style. response_format: Format of the response. temperature: Sampling temperature (0-1). + wait_for_ready: If True, wait for the model to be ready before + making the request. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: Translation with English text. + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + Example: >>> client = OpenAI(api_key="sk-...") >>> with open("audio_spanish.mp3", "rb") as f: ... translation = client.audio.translations.create( ... file=f, - ... model="whisper-large-v3" + ... model="whisper-large-v3", + ... wait_for_ready=True, ... 
) >>> print(translation.text) # English translation """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + # Prepare multipart upload files = {"file": file} data = { @@ -208,8 +251,13 @@ def create( class Speech: """Text-to-speech endpoint.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper def create( self, @@ -227,6 +275,8 @@ def create( top_p: Optional[float] = None, top_k: Optional[int] = None, max_tokens: Optional[int] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> CreateSpeechAsyncResponse: """Create a text-to-speech job. @@ -244,10 +294,21 @@ def create( top_p: Top-p sampling. top_k: Top-k sampling. max_tokens: Maximum tokens. + wait_for_ready: If True, wait for the model to be ready before + making the request. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: CreateSpeechAsyncResponse with job_id for polling. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + body: dict = { "model": model, "input": input, @@ -294,16 +355,21 @@ def get_result(self, *, job_id: str) -> GetSpeechResultResponse: class Audio: """Audio resource.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: """Initialize audio resource. Args: http_client: The HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. 
""" self._client = http_client - self._transcriptions = Transcriptions(http_client) - self._translations = Translations(http_client) - self._speech = Speech(http_client) + self._transcriptions = Transcriptions(http_client, warmup_helper) + self._translations = Translations(http_client, warmup_helper) + self._speech = Speech(http_client, warmup_helper) @property def transcriptions(self) -> Transcriptions: diff --git a/kafeido/resources/chat.py b/kafeido/resources/chat.py index 1d80ebb..82324d4 100644 --- a/kafeido/resources/chat.py +++ b/kafeido/resources/chat.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union import httpx @@ -14,17 +14,26 @@ ChatCompletionMessageParam, ) +if TYPE_CHECKING: + from kafeido._warmup import WarmupHelper + class Completions: """Chat completions endpoint.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: """Initialize completions resource. Args: http_client: The HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. """ self._client = http_client + self._warmup_helper = warmup_helper def create( self, @@ -47,6 +56,8 @@ def create( tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[Union[str, Dict[str, Any]]] = None, user: Optional[str] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]: """Create a chat completion. @@ -69,18 +80,32 @@ def create( tools: List of tools the model can call. tool_choice: Controls which tool is called. user: Unique identifier for the end-user. + wait_for_ready: If True, wait for the model to be ready before + making the request. 
This handles cold starts automatically + by triggering warmup and polling until the model is healthy. + warmup_timeout: Maximum seconds to wait for model warmup. + Defaults to 300 seconds (5 minutes) if not specified. Returns: ChatCompletion or Stream[ChatCompletionChunk] if streaming. + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. + Example: >>> client = OpenAI(api_key="sk-...") >>> response = client.chat.completions.create( ... model="gpt-oss-20b", - ... messages=[{"role": "user", "content": "Hello!"}] + ... messages=[{"role": "user", "content": "Hello!"}], + ... wait_for_ready=True, # Handle cold start automatically ... ) >>> print(response.choices[0].message.content) """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + self._warmup_helper.wait_for_ready(model, timeout=warmup_timeout) + # Build request body body: Dict[str, Any] = { "model": model, @@ -143,14 +168,19 @@ def create( class Chat: """Chat resource.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: """Initialize chat resource. Args: http_client: The HTTP client to use for requests. + warmup_helper: Optional warmup helper for cold start handling. 
""" self._client = http_client - self._completions = Completions(http_client) + self._completions = Completions(http_client, warmup_helper) @property def completions(self) -> Completions: diff --git a/kafeido/resources/ocr.py b/kafeido/resources/ocr.py index 984996b..c395708 100644 --- a/kafeido/resources/ocr.py +++ b/kafeido/resources/ocr.py @@ -1,6 +1,6 @@ """OCR resource.""" -from typing import Optional +from typing import TYPE_CHECKING, Optional from kafeido._http_client import HTTPClient from kafeido.types.ocr import ( @@ -9,12 +9,20 @@ GetOCRResultResponse, ) +if TYPE_CHECKING: + from kafeido._warmup import WarmupHelper + class OCRExtractions: """OCR extraction endpoint.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper def create( self, @@ -27,6 +35,8 @@ def create( language: Optional[str] = None, custom_prompt: Optional[str] = None, max_tokens: Optional[int] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> CreateOCRResponse: """Extract text from an image. @@ -39,10 +49,21 @@ def create( language: Language hint (ISO-639-1 code). custom_prompt: Custom prompt for the OCR model. max_tokens: Maximum tokens in the response. + wait_for_ready: If True, wait for the model to be ready before + making the request. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: CreateOCRResponse with extracted text and optional regions. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. 
""" + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + self._warmup_helper.wait_for_ready(model_id, timeout=warmup_timeout) + body = {"model_id": model_id} if file_id is not None: @@ -126,9 +147,13 @@ def get_result(self, *, job_id: str) -> GetOCRResultResponse: class OCR: """OCR resource.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: self._client = http_client - self._extractions = OCRExtractions(http_client) + self._extractions = OCRExtractions(http_client, warmup_helper) @property def extractions(self) -> OCRExtractions: diff --git a/kafeido/resources/vision.py b/kafeido/resources/vision.py index 95b2d06..67afea5 100644 --- a/kafeido/resources/vision.py +++ b/kafeido/resources/vision.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from kafeido._http_client import HTTPClient from kafeido._streaming import Stream @@ -14,12 +14,20 @@ VisionChatMessage, ) +if TYPE_CHECKING: + from kafeido._warmup import WarmupHelper + class VisionAnalysis: """Vision analysis endpoint.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper def create( self, @@ -35,6 +43,8 @@ def create( top_p: Optional[float] = None, top_k: Optional[int] = None, repetition_penalty: Optional[float] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> CreateVisionResponse: """Analyze an image. @@ -50,10 +60,21 @@ def create( top_p: Top-p sampling. top_k: Top-k sampling. repetition_penalty: Repetition penalty. + wait_for_ready: If True, wait for the model to be ready before + making the request. 
+ warmup_timeout: Maximum seconds to wait for model warmup. Returns: CreateVisionResponse with analysis text. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. """ + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + self._warmup_helper.wait_for_ready(model_id, timeout=warmup_timeout) + body: Dict[str, Any] = {"model_id": model_id} if storage_key is not None: @@ -155,8 +176,13 @@ def get_result(self, *, job_id: str) -> GetVisionResultResponse: class VisionChat: """Vision chat endpoint with streaming support.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: self._client = http_client + self._warmup_helper = warmup_helper def create( self, @@ -170,6 +196,8 @@ def create( top_p: Optional[float] = None, top_k: Optional[int] = None, repetition_penalty: Optional[float] = None, + wait_for_ready: bool = False, + warmup_timeout: Optional[float] = None, ) -> Union[CreateVisionChatResponse, Stream[CreateVisionChatResponse]]: """Chat with images. @@ -183,11 +211,22 @@ def create( top_p: Top-p sampling. top_k: Top-k sampling. repetition_penalty: Repetition penalty. + wait_for_ready: If True, wait for the model to be ready before + making the request. + warmup_timeout: Maximum seconds to wait for model warmup. Returns: Stream of CreateVisionChatResponse chunks if streaming, or a single CreateVisionChatResponse if not. + + Raises: + WarmupTimeoutError: If wait_for_ready is True and the model + doesn't become ready within the timeout period. 
""" + # Handle cold start waiting if enabled + if wait_for_ready and self._warmup_helper: + self._warmup_helper.wait_for_ready(model_id, timeout=warmup_timeout) + body: Dict[str, Any] = { "messages": messages, "model_id": model_id, @@ -220,10 +259,14 @@ def create( class Vision: """Vision resource.""" - def __init__(self, http_client: HTTPClient) -> None: + def __init__( + self, + http_client: HTTPClient, + warmup_helper: Optional["WarmupHelper"] = None, + ) -> None: self._client = http_client - self._analyze = VisionAnalysis(http_client) - self._chat = VisionChat(http_client) + self._analyze = VisionAnalysis(http_client, warmup_helper) + self._chat = VisionChat(http_client, warmup_helper) @property def analyze(self) -> VisionAnalysis: diff --git a/tests/test_warmup.py b/tests/test_warmup.py new file mode 100644 index 0000000..14bdb28 --- /dev/null +++ b/tests/test_warmup.py @@ -0,0 +1,266 @@ +"""Tests for cold start waiting / warmup helpers.""" + +import pytest +import httpx +import respx +from unittest.mock import Mock, AsyncMock + +from kafeido import OpenAI, AsyncOpenAI, WarmupTimeoutError +from kafeido._warmup import ( + WarmupHelper, + AsyncWarmupHelper, + DEFAULT_POLL_INTERVAL, + DEFAULT_MAX_WAIT_TIME, + HEALTHY_STATUS, +) +from kafeido.types.models import ModelStatus, ModelStatusInfo, WarmupResponse + + +class TestWarmupHelper: + """Tests for synchronous WarmupHelper.""" + + def test_already_warm_returns_immediately(self): + """When model is already warm, should return immediately.""" + warmup_fn = Mock(return_value=WarmupResponse(already_warm=True)) + status_fn = Mock() + + helper = WarmupHelper(status_fn, warmup_fn) + helper.wait_for_ready("test-model") + + warmup_fn.assert_called_once_with("test-model") + status_fn.assert_not_called() + + def test_polls_until_healthy(self): + """Should poll until model becomes healthy.""" + warmup_fn = Mock( + return_value=WarmupResponse(already_warm=False, estimated_seconds=10.0) + ) + + call_count = [0] + + def 
status_fn(model): + call_count[0] += 1 + if call_count[0] >= 3: + return ModelStatus( + model_id=model, status=ModelStatusInfo(status=HEALTHY_STATUS) + ) + return ModelStatus( + model_id=model, status=ModelStatusInfo(status="loading") + ) + + helper = WarmupHelper( + status_fn, warmup_fn, poll_interval=0.01 # Fast for testing + ) + helper.wait_for_ready("test-model") + + assert call_count[0] == 3 + + def test_timeout_raises_error(self): + """Should raise WarmupTimeoutError after max_wait_time.""" + warmup_fn = Mock( + return_value=WarmupResponse(already_warm=False, estimated_seconds=60.0) + ) + status_fn = Mock( + return_value=ModelStatus( + model_id="test", status=ModelStatusInfo(status="loading") + ) + ) + + helper = WarmupHelper( + status_fn, + warmup_fn, + poll_interval=0.01, + max_wait_time=0.05, + ) + + with pytest.raises(WarmupTimeoutError) as exc_info: + helper.wait_for_ready("test-model") + + assert exc_info.value.model == "test-model" + assert exc_info.value.waited_seconds >= 0.05 + + def test_custom_timeout_override(self): + """Should respect custom timeout parameter.""" + warmup_fn = Mock(return_value=WarmupResponse(already_warm=False)) + status_fn = Mock( + return_value=ModelStatus( + model_id="test", status=ModelStatusInfo(status="loading") + ) + ) + + helper = WarmupHelper( + status_fn, + warmup_fn, + poll_interval=0.01, + max_wait_time=10.0, # Helper-level limit for this test: 10s + ) + + with pytest.raises(WarmupTimeoutError) as exc_info: + helper.wait_for_ready("test-model", timeout=0.03) # Per-call override: 0.03s + + assert exc_info.value.waited_seconds < 0.1 # Should time out quickly + + +class TestAsyncWarmupHelper: + """Tests for asynchronous AsyncWarmupHelper.""" + + @pytest.mark.asyncio + async def test_already_warm_returns_immediately(self): + """Async: When model is already warm, should return immediately.""" + warmup_fn = AsyncMock(return_value=WarmupResponse(already_warm=True)) + status_fn = AsyncMock() + + helper = AsyncWarmupHelper(status_fn, warmup_fn) + await
helper.wait_for_ready("test-model") + + warmup_fn.assert_called_once_with("test-model") + status_fn.assert_not_called() + + @pytest.mark.asyncio + async def test_polls_until_healthy(self): + """Async: Should poll until model becomes healthy.""" + warmup_fn = AsyncMock( + return_value=WarmupResponse(already_warm=False, estimated_seconds=10.0) + ) + + call_count = [0] + + async def status_fn(model): + call_count[0] += 1 + if call_count[0] >= 3: + return ModelStatus( + model_id=model, status=ModelStatusInfo(status=HEALTHY_STATUS) + ) + return ModelStatus( + model_id=model, status=ModelStatusInfo(status="loading") + ) + + helper = AsyncWarmupHelper( + status_fn, warmup_fn, poll_interval=0.01 # Fast for testing + ) + await helper.wait_for_ready("test-model") + + assert call_count[0] == 3 + + @pytest.mark.asyncio + async def test_timeout_raises_error(self): + """Async: Should raise WarmupTimeoutError after max_wait_time.""" + warmup_fn = AsyncMock( + return_value=WarmupResponse(already_warm=False, estimated_seconds=60.0) + ) + status_fn = AsyncMock( + return_value=ModelStatus( + model_id="test", status=ModelStatusInfo(status="loading") + ) + ) + + helper = AsyncWarmupHelper( + status_fn, + warmup_fn, + poll_interval=0.01, + max_wait_time=0.05, + ) + + with pytest.raises(WarmupTimeoutError) as exc_info: + await helper.wait_for_ready("test-model") + + assert exc_info.value.model == "test-model" + + +class TestChatCompletionWithWaitForReady: + """Integration tests for chat completion with wait_for_ready.""" + + @respx.mock + def test_chat_completion_with_warm_model(self, client, base_url, mock_chat_response): + """Test chat completion with wait_for_ready when model is already warm.""" + # Mock warmup endpoint - model already warm + respx.post(f"{base_url}/v1/models/warmup").mock( + return_value=httpx.Response(200, json={"already_warm": True}) + ) + + # Mock chat endpoint + respx.post(f"{base_url}/v1/chat/completions").mock( + return_value=httpx.Response(200, 
json=mock_chat_response) + ) + + response = client.chat.completions.create( + model="gpt-oss-20b", + messages=[{"role": "user", "content": "Hello!"}], + wait_for_ready=True, + ) + + assert response.choices[0].message.content == "Hello! How can I help you today?" + + @respx.mock + def test_chat_completion_warmup_polling(self, client, base_url, mock_chat_response): + """Test that warmup polls until ready.""" + status_call_count = [0] + + def status_response(request): + status_call_count[0] += 1 + if status_call_count[0] >= 2: + return httpx.Response( + 200, + json={"model_id": "gpt-oss-20b", "status": {"status": "healthy"}}, + ) + return httpx.Response( + 200, + json={ + "model_id": "gpt-oss-20b", + "status": { + "status": "loading", + "cold_start_progress": {"stage": "loading", "progress": 0.5}, + }, + }, + ) + + # Mock warmup - model not warm + respx.post(f"{base_url}/v1/models/warmup").mock( + return_value=httpx.Response( + 200, json={"already_warm": False, "estimated_seconds": 10.0} + ) + ) + + # Mock status endpoint with polling + respx.get(f"{base_url}/v1/models/gpt-oss-20b/status").mock( + side_effect=status_response + ) + + # Mock chat endpoint + respx.post(f"{base_url}/v1/chat/completions").mock( + return_value=httpx.Response(200, json=mock_chat_response) + ) + + # Create a client with fast polling for testing + from kafeido._warmup import WarmupHelper + + fast_client = OpenAI(api_key="sk-test123_dGVzdGtleQ==", base_url=base_url) + # Replace warmup helper with fast polling version + fast_client._warmup_helper = WarmupHelper( + status_fn=fast_client._models.status, + warmup_fn=lambda m: fast_client._models.warmup(model=m), + poll_interval=0.01, # Fast for testing + ) + + response = fast_client.chat.completions.create( + model="gpt-oss-20b", + messages=[{"role": "user", "content": "Hello!"}], + wait_for_ready=True, + ) + + assert status_call_count[0] >= 2 + assert response.choices[0].message.content is not None + + +class TestWarmupTimeoutErrorExport: + 
"""Test that WarmupTimeoutError is properly exported.""" + + def test_warmup_timeout_error_exported(self): + """WarmupTimeoutError should be importable from kafeido.""" + from kafeido import WarmupTimeoutError + + error = WarmupTimeoutError("test-model", 30.5) + assert error.model == "test-model" + assert error.waited_seconds == 30.5 + assert "test-model" in str(error) + assert "30.5" in str(error)