Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/runpod_flash/core/resources/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ def get_image_name(
f"runpod/flash-lb-cpu:py{DEFAULT_PYTHON_VERSION}-{_RESOLVED_TAG}",
)

# Base images for process injection (no flash-worker baked in).
# Each default below can be overridden by setting the environment variable
# of the same name before this module is imported.
FLASH_GPU_BASE_IMAGE = os.environ.get(
"FLASH_GPU_BASE_IMAGE", "pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime"
)
FLASH_CPU_BASE_IMAGE = os.environ.get("FLASH_CPU_BASE_IMAGE", "python:3.11-slim")

# Worker tarball for process injection.
# FLASH_WORKER_VERSION is the release version substituted into the URL template.
FLASH_WORKER_VERSION = os.environ.get("FLASH_WORKER_VERSION", "1.1.1")
# NOTE: the overriding environment variable is FLASH_WORKER_TARBALL_URL (no
# _TEMPLATE suffix). The value is expanded with str.format(version=...), so it
# may contain ``{version}`` placeholders; the default asset name is pinned to
# py3.11 / linux-x86_64.
FLASH_WORKER_TARBALL_URL_TEMPLATE = os.environ.get(
"FLASH_WORKER_TARBALL_URL",
"https://github.com/runpod-workers/flash/releases/download/"
"v{version}/flash-worker-v{version}-py3.11-linux-x86_64.tar.gz",
)

# Worker configuration defaults
DEFAULT_WORKERS_MIN = 0
DEFAULT_WORKERS_MAX = 1
Expand Down
54 changes: 54 additions & 0 deletions src/runpod_flash/core/resources/injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Process injection utilities for flash-worker tarball delivery."""

from .constants import FLASH_WORKER_TARBALL_URL_TEMPLATE, FLASH_WORKER_VERSION


def build_injection_cmd(
    worker_version: str = FLASH_WORKER_VERSION,
    tarball_url: str | None = None,
) -> str:
    """Build the dockerArgs command that downloads, extracts, and runs flash-worker.

    The returned string is a single ``bash -c '...'`` invocation. Its body must
    therefore never contain a single quote.

    Two shapes are produced:

    * ``file://`` URL (local testing): extract the tarball at that path into
      ``/opt/flash-worker`` and exec ``bootstrap.sh``. No caching, no download.
    * any other URL: try, in order,
        1. the network-volume cache at
           ``/runpod-volume/.flash-worker/v{version}`` (valid only when it
           contains a ``.version`` marker),
        2. a local install whose ``.version`` file equals ``worker_version``,
        3. a fresh download via curl, else wget, else python3 ``urllib``,
      then exec ``bootstrap.sh``.

    Args:
        worker_version: Worker release version. Exported as ``FW_VER`` and used
            to fill the default URL template.
        tarball_url: Explicit tarball location. When ``None`` the URL is built
            from ``FLASH_WORKER_TARBALL_URL_TEMPLATE``.

    Returns:
        The shell command string to place in the template's ``dockerArgs``.
    """
    if tarball_url is None:
        tarball_url = FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version=worker_version)

    if tarball_url.startswith("file://"):
        local_path = tarball_url.removeprefix("file://")
        return (
            "bash -c '"
            "set -e; FW_DIR=/opt/flash-worker; "
            "mkdir -p $FW_DIR; "
            f"tar xzf {local_path} -C $FW_DIR --strip-components=1; "
            "exec $FW_DIR/bootstrap.sh'"
        )

    return (
        "bash -c '"
        f"set -e; FW_DIR=/opt/flash-worker; FW_VER={worker_version}; "
        # Network volume cache check
        'NV_CACHE="/runpod-volume/.flash-worker/v$FW_VER"; '
        'if [ -d "$NV_CACHE" ] && [ -f "$NV_CACHE/.version" ]; then '
        # Copy the *contents* (trailing "/.") so the result is the same whether
        # or not $FW_DIR already exists; "cp -r src dst" would nest src inside
        # an existing dst.
        'mkdir -p "$FW_DIR"; cp -r "$NV_CACHE"/. "$FW_DIR"/; '
        # Local cache check (container disk persistence between restarts)
        'elif [ -f "$FW_DIR/.version" ] && [ "$(cat $FW_DIR/.version)" = "$FW_VER" ]; then '
        "true; "
        "else "
        "mkdir -p $FW_DIR; "
        f'DL_URL="{tarball_url}"; '
        # Exactly one downloader runs. A flat "a && b || c && d || e" chain
        # would run the next tool after a successful download because && and
        # || share precedence in the shell.
        "dl() { "
        'if command -v curl >/dev/null 2>&1; then curl -sSL "$1"; '
        'elif command -v wget >/dev/null 2>&1; then wget -qO- "$1"; '
        'else python3 -c "import urllib.request,sys;sys.stdout.buffer.write(urllib.request.urlopen(sys.argv[1]).read())" "$1"; '
        "fi; "
        "}; "
        'dl "$DL_URL" '
        "| tar xz -C $FW_DIR --strip-components=1; "
        # Cache to network volume if available (best effort). The "*" glob
        # skips dotfiles, so the .version marker is copied explicitly, and
        # last, so the cache only looks valid once the payload is in place.
        # NOTE(review): nothing here writes .version; presumably it ships
        # inside the tarball - confirm.
        "if [ -d /runpod-volume ]; then "
        'mkdir -p "$NV_CACHE" && cp -r "$FW_DIR"/* "$NV_CACHE/" 2>/dev/null '
        '&& cp "$FW_DIR/.version" "$NV_CACHE/.version" 2>/dev/null || true; fi; '
        "fi; "
        "exec $FW_DIR/bootstrap.sh'"
    )
69 changes: 30 additions & 39 deletions src/runpod_flash/core/resources/live_serverless.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,90 @@
# Ship serverless code as you write it. No builds, no deploys — just run.
from typing import ClassVar

from pydantic import model_validator

from .constants import (
GPU_BASE_IMAGE_PYTHON_VERSION,
get_image_name,
local_python_version,
)
from .injection import build_injection_cmd
from .load_balancer_sls_resource import (
CpuLoadBalancerSlsResource,
LoadBalancerSlsResource,
)
from .serverless import ServerlessEndpoint
from .serverless_cpu import CpuServerlessEndpoint
from .template import PodTemplate


class LiveServerlessMixin:
"""Common mixin for live serverless endpoints that locks the image."""

_image_type: ClassVar[str] = (
"" # Override in subclasses: 'gpu', 'cpu', 'lb', 'lb-cpu'
)
_GPU_IMAGE_TYPES: ClassVar[frozenset[str]] = frozenset({"gpu", "lb"})
"""Configures process injection via dockerArgs for any base image.

@property
def _live_image(self) -> str:
python_version = getattr(self, "python_version", None)
if not python_version:
if self._image_type in self._GPU_IMAGE_TYPES:
python_version = GPU_BASE_IMAGE_PYTHON_VERSION
else:
python_version = local_python_version()
return get_image_name(self._image_type, python_version)
Sets a default base image (user can override via imageName) and generates
dockerArgs to download, extract, and run the flash-worker tarball at container
start time. QB vs LB mode is determined by FLASH_ENDPOINT_TYPE env var at
runtime, not by the Docker image.
"""

@property
def imageName(self):
return self._live_image
def _create_new_template(self) -> PodTemplate:
"""Create template with dockerArgs for process injection."""
template = super()._create_new_template() # type: ignore[misc]
template.dockerArgs = build_injection_cmd()
return template

@imageName.setter
def imageName(self, value):
pass
def _configure_existing_template(self) -> None:
"""Configure existing template, adding dockerArgs for injection if not user-set."""
super()._configure_existing_template() # type: ignore[misc]
if self.template is not None and not self.template.dockerArgs: # type: ignore[attr-defined]
self.template.dockerArgs = build_injection_cmd() # type: ignore[attr-defined]
Comment on lines +20 to +38
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These Live* resources now rely on template.dockerArgs to perform injection at container start, but local preview (flash deploy --preview) starts containers via docker run <image> and does not apply template dockerArgs. With the new default base images, preview containers may not start the flash worker at all unless the preview path explicitly uses the legacy images or executes the injection command. Please ensure preview mode uses _legacy_image (or otherwise applies the injection command) before switching Live* defaults to base images.

Copilot uses AI. Check for mistakes.


class LiveServerless(LiveServerlessMixin, ServerlessEndpoint):
"""GPU-only live serverless endpoint."""

_image_type: ClassVar[str] = "gpu"

@model_validator(mode="before")
@classmethod
def set_live_serverless_template(cls, data: dict):
"""Set default GPU image for Live Serverless."""
python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION
data["imageName"] = get_image_name("gpu", python_version)
if "imageName" not in data:
python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION
data["imageName"] = get_image_name("gpu", python_version)
return data


class CpuLiveServerless(LiveServerlessMixin, CpuServerlessEndpoint):
"""CPU-only live serverless endpoint with automatic disk sizing."""

_image_type: ClassVar[str] = "cpu"

@model_validator(mode="before")
@classmethod
def set_live_serverless_template(cls, data: dict):
"""Set default CPU image for Live Serverless."""
python_version = data.get("python_version") or local_python_version()
data["imageName"] = get_image_name("cpu", python_version)
if "imageName" not in data:
python_version = data.get("python_version") or local_python_version()
data["imageName"] = get_image_name("cpu", python_version)
return data


class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource):
"""Live load-balanced endpoint."""

_image_type: ClassVar[str] = "lb"

@model_validator(mode="before")
@classmethod
def set_live_lb_template(cls, data: dict):
"""Set default image for Live Load-Balanced endpoint."""
python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION
data["imageName"] = get_image_name("lb", python_version)
if "imageName" not in data:
python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION
data["imageName"] = get_image_name("lb", python_version)
return data


class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource):
"""CPU-only live load-balanced endpoint."""

_image_type: ClassVar[str] = "lb-cpu"

@model_validator(mode="before")
@classmethod
def set_live_cpu_lb_template(cls, data: dict):
"""Set default CPU image for Live Load-Balanced endpoint."""
python_version = data.get("python_version") or local_python_version()
data["imageName"] = get_image_name("lb-cpu", python_version)
if "imageName" not in data:
python_version = data.get("python_version") or local_python_version()
data["imageName"] = get_image_name("lb-cpu", python_version)
return data
32 changes: 14 additions & 18 deletions tests/integration/test_cpu_disk_sizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ def test_live_serverless_cpu_integration(self):
)

# Verify integration:
# 1. Uses CPU image (locked)
# 1. Uses CPU base image (default)
# 2. CPU utilities calculate minimum disk size
# 3. Template creation with auto-sizing
# 4. Validation passes
assert "flash-cpu:" in live_serverless.imageName
assert "runpod/flash-cpu:" in live_serverless.imageName
assert live_serverless.instanceIds == [
CpuInstanceType.CPU5C_1_2,
CpuInstanceType.CPU5C_2_4,
Expand Down Expand Up @@ -244,28 +244,24 @@ def test_mixed_cpu_generations_integration(self):
assert "cpu5c-1-2: max 15GB" in error_msg


class TestLiveServerlessImageLockingIntegration:
"""Test image locking integration in live serverless variants."""
class TestLiveServerlessImageDefaultsIntegration:
"""Test image defaults in live serverless variants."""

def test_live_serverless_image_consistency(self):
"""Test that LiveServerless variants maintain image consistency."""
def test_live_serverless_image_defaults(self):
"""Test that LiveServerless variants use correct base images."""
gpu_live = LiveServerless(name="gpu-live")
cpu_live = CpuLiveServerless(name="cpu-live")

# Verify different images are used
# Verify different base images are used
assert gpu_live.imageName != cpu_live.imageName
assert "flash:" in gpu_live.imageName
assert "flash-cpu:" in cpu_live.imageName
assert "runpod/flash:" in gpu_live.imageName
assert "runpod/flash-cpu:" in cpu_live.imageName

# Verify images remain locked despite attempts to change
original_gpu_image = gpu_live.imageName
original_cpu_image = cpu_live.imageName

gpu_live.imageName = "custom/image:latest"
cpu_live.imageName = "custom/image:latest"

assert gpu_live.imageName == original_gpu_image
assert cpu_live.imageName == original_cpu_image
# Verify images can be overridden (BYOI)
custom_gpu = LiveServerless(
name="custom-gpu", imageName="nvidia/cuda:12.8.0-runtime"
)
assert custom_gpu.imageName == "nvidia/cuda:12.8.0-runtime"

def test_live_serverless_template_integration(self):
"""Test live serverless template integration with disk sizing."""
Expand Down
20 changes: 8 additions & 12 deletions tests/integration/test_lb_remote_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,18 @@ async def echo(message: str):
# Verify resource is correctly configured
# Note: name may have "-fb" appended by flash boot validator
assert "test-live-api" in lb.name
assert "flash-lb" in lb.imageName
assert "runpod/flash-lb:" in lb.imageName # GPU LB base image
assert echo.__remote_config__["method"] == "POST"

def test_live_load_balancer_image_locked(self):
"""Test that LiveLoadBalancer locks the image to Flash LB image."""
def test_live_load_balancer_default_image(self):
"""Test that LiveLoadBalancer uses GPU LB base image by default."""
lb = LiveLoadBalancer(name="test-api")
assert "runpod/flash-lb:" in lb.imageName

# Verify image is locked and cannot be overridden
original_image = lb.imageName
assert "flash-lb" in original_image

# Try to set a different image (should be ignored due to property)
lb.imageName = "custom-image:latest"

# Image should still be locked to Flash
assert lb.imageName == original_image
def test_live_load_balancer_allows_custom_image(self):
"""Test that LiveLoadBalancer allows user to set custom image (BYOI)."""
lb = LiveLoadBalancer(name="test-api", imageName="custom-image:latest")
assert lb.imageName == "custom-image:latest"

def test_load_balancer_vs_queue_based_endpoints(self):
"""Test that LB and QB endpoints have different characteristics."""
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/resources/test_injection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Unit tests for process injection utilities."""

from runpod_flash.core.resources.injection import build_injection_cmd


class TestBuildInjectionCmd:
"""Test build_injection_cmd() output format."""

def test_default_remote_url(self):
    """Test that the default URL is built from the configured template.

    The template can be overridden through the FLASH_WORKER_TARBALL_URL
    environment variable, so the expected URL is derived from the constant
    instead of hard-coding the GitHub releases path.
    """
    from runpod_flash.core.resources.constants import (
        FLASH_WORKER_TARBALL_URL_TEMPLATE,
    )

    cmd = build_injection_cmd(worker_version="1.1.1")

    assert cmd.startswith("bash -c '")
    assert "FW_VER=1.1.1" in cmd
    assert FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version="1.1.1") in cmd
    assert "bootstrap.sh'" in cmd
Comment on lines +9 to +16
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test assumes the default tarball URL template is the GitHub releases URL (asserting `runpod-workers/flash/releases/download/v1.1.1/`). Since `FLASH_WORKER_TARBALL_URL` is configurable via environment variable, the default template may differ in some test environments, causing a false failure. Consider asserting against `FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version=...)` or passing an explicit `tarball_url` in the test.

Copilot uses AI. Check for mistakes.

def test_custom_tarball_url(self):
    """An explicit tarball_url is embedded in the command next to the version."""
    custom_url = "https://example.com/worker.tar.gz"

    command = build_injection_cmd(tarball_url=custom_url, worker_version="2.0.0")

    for expected in ("FW_VER=2.0.0", custom_url):
        assert expected in command

def test_file_url_for_local_testing(self):
    """A file:// URL yields a local tar extraction without download tooling."""
    command = build_injection_cmd(
        tarball_url="file:///tmp/flash-worker.tar.gz",
        worker_version="1.0.0",
    )

    assert "tar xzf /tmp/flash-worker.tar.gz" in command
    # Neither downloader may appear anywhere in the local-file command.
    assert not any(tool in command for tool in ("curl", "wget"))
    assert "bootstrap.sh'" in command

def test_version_caching_logic(self):
    """The remote command carries the version-based cache check."""
    command = build_injection_cmd(worker_version="1.1.1")

    # The cache check reads the .version marker and compares it with FW_VER.
    assert all(token in command for token in (".version", "FW_VER"))

def test_network_volume_caching(self):
    """The remote command references the network-volume cache location."""
    command = build_injection_cmd(worker_version="1.1.1")

    for expected in ("/runpod-volume/.flash-worker/", "NV_CACHE"):
        assert expected in command

def test_curl_wget_python_fallback(self):
    """All three downloaders (curl, wget, python3 urllib) appear in the command."""
    command = build_injection_cmd(worker_version="1.0.0")

    downloaders = ("curl -sSL", "wget -qO-", "urllib.request")
    for invocation in downloaders:
        assert invocation in command

def test_default_uses_constants(self):
    """Test that calling with no args uses module-level constants.

    The expected URL is derived from FLASH_WORKER_TARBALL_URL_TEMPLATE rather
    than assumed to contain ``v{version}``, because the template can be
    overridden through the environment.
    """
    from runpod_flash.core.resources.constants import (
        FLASH_WORKER_TARBALL_URL_TEMPLATE,
        FLASH_WORKER_VERSION,
    )

    cmd = build_injection_cmd()

    assert f"FW_VER={FLASH_WORKER_VERSION}" in cmd
    assert (
        FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version=FLASH_WORKER_VERSION) in cmd
    )

def test_strip_components_in_remote_extraction(self):
    """Remote downloads are extracted with --strip-components=1."""
    remote_command = build_injection_cmd(worker_version="1.0.0")

    strip_flag = "--strip-components=1"
    assert strip_flag in remote_command

def test_strip_components_in_local_extraction(self):
    """Local file:// tarballs are extracted with --strip-components=1."""
    local_command = build_injection_cmd(
        tarball_url="file:///tmp/fw.tar.gz",
        worker_version="1.0.0",
    )

    strip_flag = "--strip-components=1"
    assert strip_flag in local_command
Loading
Loading