Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions src/runpod_flash/execute_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
prevent memory leaks through LRU eviction.
"""

import asyncio
import hashlib
import inspect
import logging
Expand Down Expand Up @@ -215,6 +216,7 @@ def __init__(self, *args, **kwargs):
f"{cls.__name__}_{uuid.uuid4().hex[:UUID_FALLBACK_LENGTH]}"
)
self._initialized = False
self._init_lock = asyncio.Lock()

# Generate cache key and get class code
self._cache_key = get_class_cache_key(cls, args, kwargs)
Expand All @@ -223,20 +225,23 @@ def __init__(self, *args, **kwargs):
)

async def _ensure_initialized(self):
    """Ensure the remote instance is created exactly once, even under concurrent calls.

    The first caller obtains the remote resource through
    ``ResourceManager.get_or_deploy_resource`` and builds ``self._stub`` with
    ``stub_resource``. Callers that arrive while that is in progress wait on
    ``self._init_lock`` and return without deploying again.

    Raises:
        Any exception raised while deploying the resource or building the
        stub propagates to the caller. ``self._initialized`` stays ``False``
        in that case, so a later call retries the deployment.
    """
    # Fast path: already initialized, no lock needed.
    if self._initialized:
        return

    # Slow path: acquire the lock and re-check, so coroutines that raced past
    # the fast-path check do not trigger a second deployment.
    async with self._init_lock:
        if self._initialized:
            return

        resource_manager = ResourceManager()
        remote_resource = await resource_manager.get_or_deploy_resource(
            self._resource_config
        )
        self._stub = stub_resource(remote_resource, **self._extra)

        # Set the flag last: it must only become True once the stub exists,
        # and a failure above must leave the wrapper retryable.
        self._initialized = True

def __getattr__(self, name):
"""Dynamically create method proxies for all class methods."""
Expand Down
Empty file added tests/bug_probes/__init__.py
Empty file.
145 changes: 145 additions & 0 deletions tests/bug_probes/test_class_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Bug probe tests for execute_class.py race conditions."""

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


class TestNEW1_EnsureInitializedRace:
    """AE-2370: _ensure_initialized must deploy exactly once under concurrent calls.

    Without an async lock, two concurrent calls to _ensure_initialized both get
    past the `if self._initialized` fast-path check and both call
    get_or_deploy_resource, causing a double deploy and orphaning one stub.

    The tests below cover the single-deploy guarantee, the state of the
    `_initialized` flag, the skip-on-second-call path, and recovery after a
    failed deploy.
    """

    @pytest.fixture
    def wrapper_instance(self):
        """Create a RemoteClassWrapper instance with mocked dependencies."""

        class FakeModel:
            __name__ = "FakeModel"

            def predict(self, x):
                return x

        resource_config = MagicMock()

        # The cache helpers are patched only while the wrapper is constructed.
        with (
            patch("runpod_flash.execute_class.get_class_cache_key", return_value="key"),
            patch(
                "runpod_flash.execute_class.get_or_cache_class_data",
                return_value="code",
            ),
        ):
            from runpod_flash.execute_class import create_remote_class

            wrapper_cls = create_remote_class(
                cls=FakeModel,
                resource_config=resource_config,
                dependencies=None,
                system_dependencies=None,
                accelerate_downloads=False,
                extra={},
            )
            instance = wrapper_cls()

        return instance

    @pytest.mark.asyncio
    async def test_concurrent_calls_deploy_only_once(self, wrapper_instance):
        """Two concurrent _ensure_initialized calls must call get_or_deploy_resource exactly once."""
        deploy_call_count = 0
        deploy_entered = asyncio.Event()
        gate = asyncio.Event()

        async def slow_deploy(config):
            nonlocal deploy_call_count
            deploy_call_count += 1
            deploy_entered.set()
            # Hold the deploy open so the second task has a chance to race.
            await gate.wait()
            return MagicMock()

        with (
            patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
            patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
        ):
            mock_rm = MagicMock()
            mock_rm.get_or_deploy_resource = slow_deploy
            mock_rm_cls.return_value = mock_rm

            task1 = asyncio.create_task(wrapper_instance._ensure_initialized())
            task2 = asyncio.create_task(wrapper_instance._ensure_initialized())

            # Synchronise on the deploy actually starting instead of sleeping
            # for a fixed wall-clock interval; the timeout turns a hang into
            # a test failure.
            await asyncio.wait_for(deploy_entered.wait(), timeout=1)
            # Yield to the loop a few times so the other task can reach either
            # the lock (expected) or a second deploy (regression).
            for _ in range(5):
                await asyncio.sleep(0)
            gate.set()

            await asyncio.gather(task1, task2)

        assert deploy_call_count == 1, (
            f"get_or_deploy_resource called {deploy_call_count} times, expected 1. "
            "Race condition: concurrent calls both passed the initialized check."
        )

    @pytest.mark.asyncio
    async def test_initialized_flag_set_after_deploy(self, wrapper_instance):
        """After _ensure_initialized completes, _initialized must be True."""
        with (
            patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
            patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
        ):
            mock_rm = MagicMock()
            mock_rm.get_or_deploy_resource = AsyncMock(return_value=MagicMock())
            mock_rm_cls.return_value = mock_rm

            await wrapper_instance._ensure_initialized()

        assert wrapper_instance._initialized is True

    @pytest.mark.asyncio
    async def test_second_call_skips_deploy(self, wrapper_instance):
        """Once initialized, subsequent calls must not call get_or_deploy_resource."""
        with (
            patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
            patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
        ):
            mock_rm = MagicMock()
            mock_rm.get_or_deploy_resource = AsyncMock(return_value=MagicMock())
            mock_rm_cls.return_value = mock_rm

            await wrapper_instance._ensure_initialized()
            mock_rm.get_or_deploy_resource.assert_awaited_once()

            # The await count must still be one after the second call.
            await wrapper_instance._ensure_initialized()
            mock_rm.get_or_deploy_resource.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_deploy_failure_releases_lock_and_allows_retry(
        self, wrapper_instance
    ):
        """If deploy fails, the lock must be released and a subsequent call must retry."""
        call_count = 0

        async def failing_then_succeeding_deploy(config):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise ConnectionError("transient failure")
            return MagicMock()

        with (
            patch("runpod_flash.execute_class.ResourceManager") as mock_rm_cls,
            patch("runpod_flash.execute_class.stub_resource", return_value=MagicMock()),
        ):
            mock_rm = MagicMock()
            mock_rm.get_or_deploy_resource = failing_then_succeeding_deploy
            mock_rm_cls.return_value = mock_rm

            with pytest.raises(ConnectionError, match="transient failure"):
                await wrapper_instance._ensure_initialized()

            assert not wrapper_instance._initialized
            # Check the lock state directly so a leaked lock fails here
            # instead of hanging the retry below.
            assert not wrapper_instance._init_lock.locked()

            # Retry should succeed
            await wrapper_instance._ensure_initialized()
            assert wrapper_instance._initialized
            assert call_count == 2
Loading