diff --git a/src/runpod_flash/core/resources/constants.py b/src/runpod_flash/core/resources/constants.py index 7d826534..c78600bb 100644 --- a/src/runpod_flash/core/resources/constants.py +++ b/src/runpod_flash/core/resources/constants.py @@ -142,6 +142,20 @@ def get_image_name( f"runpod/flash-lb-cpu:py{DEFAULT_PYTHON_VERSION}-{_RESOLVED_TAG}", ) +# Base images for process injection (no flash-worker baked in) +FLASH_GPU_BASE_IMAGE = os.environ.get( + "FLASH_GPU_BASE_IMAGE", "pytorch/pytorch:2.9.1-cuda12.8-cudnn9-runtime" +) +FLASH_CPU_BASE_IMAGE = os.environ.get("FLASH_CPU_BASE_IMAGE", "python:3.11-slim") + +# Worker tarball for process injection +FLASH_WORKER_VERSION = os.environ.get("FLASH_WORKER_VERSION", "1.1.1") +FLASH_WORKER_TARBALL_URL_TEMPLATE = os.environ.get( + "FLASH_WORKER_TARBALL_URL", + "https://github.com/runpod-workers/flash/releases/download/" + "v{version}/flash-worker-v{version}-py3.11-linux-x86_64.tar.gz", +) + # Worker configuration defaults DEFAULT_WORKERS_MIN = 0 DEFAULT_WORKERS_MAX = 1 diff --git a/src/runpod_flash/core/resources/injection.py b/src/runpod_flash/core/resources/injection.py new file mode 100644 index 00000000..9ce72ec2 --- /dev/null +++ b/src/runpod_flash/core/resources/injection.py @@ -0,0 +1,54 @@ +"""Process injection utilities for flash-worker tarball delivery.""" + +from .constants import FLASH_WORKER_TARBALL_URL_TEMPLATE, FLASH_WORKER_VERSION + + +def build_injection_cmd( + worker_version: str = FLASH_WORKER_VERSION, + tarball_url: str | None = None, +) -> str: + """Build the dockerArgs command that downloads, extracts, and runs flash-worker. + + Supports remote URLs (curl/wget) and local file paths (file://) for testing. + Includes version-based caching to skip re-extraction on warm workers. + Network volume caching stores extracted tarball at /runpod-volume/.flash-worker/v{version}. + """ + if tarball_url is None: + tarball_url = FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version=worker_version) + + if tarball_url.startswith("file://"): + local_path = tarball_url[7:] + return ( + "bash -c '" + "set -e; FW_DIR=/opt/flash-worker; " + "mkdir -p $FW_DIR; " + f"tar xzf {local_path} -C $FW_DIR --strip-components=1; " + "exec $FW_DIR/bootstrap.sh'" + ) + + return ( + "bash -c '" + f"set -e; FW_DIR=/opt/flash-worker; FW_VER={worker_version}; " + # Network volume cache check + 'NV_CACHE="/runpod-volume/.flash-worker/v$FW_VER"; ' + 'if [ -d "$NV_CACHE" ] && [ -f "$NV_CACHE/.version" ]; then ' + 'cp -r "$NV_CACHE" "$FW_DIR"; ' + # Local cache check (container disk persistence between restarts) + 'elif [ -f "$FW_DIR/.version" ] && [ "$(cat $FW_DIR/.version)" = "$FW_VER" ]; then ' + "true; " + "else " + "mkdir -p $FW_DIR; " + f'DL_URL="{tarball_url}"; ' + "dl() { " + '(command -v curl >/dev/null 2>&1 && curl -sSL "$1" || ' + 'command -v wget >/dev/null 2>&1 && wget -qO- "$1" || ' + 'python3 -c "import urllib.request,sys;sys.stdout.buffer.write(urllib.request.urlopen(sys.argv[1]).read())" "$1"); ' + "}; " + 'dl "$DL_URL" ' + "| tar xz -C $FW_DIR --strip-components=1; " + # Cache to network volume if available + "if [ -d /runpod-volume ]; then " + 'mkdir -p "$NV_CACHE" && cp -r "$FW_DIR"/* "$NV_CACHE/" 2>/dev/null || true; fi; ' + "fi; " + "exec $FW_DIR/bootstrap.sh'" + ) diff --git a/src/runpod_flash/core/resources/live_serverless.py b/src/runpod_flash/core/resources/live_serverless.py index a7e1930b..71c5ee30 100644 --- a/src/runpod_flash/core/resources/live_serverless.py +++ b/src/runpod_flash/core/resources/live_serverless.py @@ -1,6 +1,4 @@ # Ship serverless code as you write it. No builds, no deploys — just run. -from typing import ClassVar - from pydantic import model_validator from .constants import ( @@ -8,92 +6,85 @@ get_image_name, local_python_version, ) +from .injection import build_injection_cmd from .load_balancer_sls_resource import ( CpuLoadBalancerSlsResource, LoadBalancerSlsResource, ) from .serverless import ServerlessEndpoint from .serverless_cpu import CpuServerlessEndpoint +from .template import PodTemplate class LiveServerlessMixin: - """Common mixin for live serverless endpoints that locks the image.""" - - _image_type: ClassVar[str] = ( - "" # Override in subclasses: 'gpu', 'cpu', 'lb', 'lb-cpu' - ) - _GPU_IMAGE_TYPES: ClassVar[frozenset[str]] = frozenset({"gpu", "lb"}) + """Configures process injection via dockerArgs for any base image. - @property - def _live_image(self) -> str: - python_version = getattr(self, "python_version", None) - if not python_version: - if self._image_type in self._GPU_IMAGE_TYPES: - python_version = GPU_BASE_IMAGE_PYTHON_VERSION - else: - python_version = local_python_version() - return get_image_name(self._image_type, python_version) + Sets a default base image (user can override via imageName) and generates + dockerArgs to download, extract, and run the flash-worker tarball at container + start time. QB vs LB mode is determined by FLASH_ENDPOINT_TYPE env var at + runtime, not by the Docker image. + """ - @property - def imageName(self): - return self._live_image + def _create_new_template(self) -> PodTemplate: + """Create template with dockerArgs for process injection.""" + template = super()._create_new_template() # type: ignore[misc] + template.dockerArgs = build_injection_cmd() + return template - @imageName.setter - def imageName(self, value): - pass + def _configure_existing_template(self) -> None: + """Configure existing template, adding dockerArgs for injection if not user-set.""" + super()._configure_existing_template() # type: ignore[misc] + if self.template is not None and not self.template.dockerArgs: # type: ignore[attr-defined] + self.template.dockerArgs = build_injection_cmd() # type: ignore[attr-defined] class LiveServerless(LiveServerlessMixin, ServerlessEndpoint): """GPU-only live serverless endpoint.""" - _image_type: ClassVar[str] = "gpu" - @model_validator(mode="before") @classmethod def set_live_serverless_template(cls, data: dict): """Set default GPU image for Live Serverless.""" - python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION - data["imageName"] = get_image_name("gpu", python_version) + if "imageName" not in data: + python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION + data["imageName"] = get_image_name("gpu", python_version) return data class CpuLiveServerless(LiveServerlessMixin, CpuServerlessEndpoint): """CPU-only live serverless endpoint with automatic disk sizing.""" - _image_type: ClassVar[str] = "cpu" - @model_validator(mode="before") @classmethod def set_live_serverless_template(cls, data: dict): """Set default CPU image for Live Serverless.""" - python_version = data.get("python_version") or local_python_version() - data["imageName"] = get_image_name("cpu", python_version) + if "imageName" not in data: + python_version = data.get("python_version") or local_python_version() + data["imageName"] = get_image_name("cpu", python_version) return data class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource): """Live load-balanced endpoint.""" - _image_type: ClassVar[str] = "lb" - @model_validator(mode="before") @classmethod def set_live_lb_template(cls, data: dict): """Set default image for Live Load-Balanced endpoint.""" - python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION - data["imageName"] = get_image_name("lb", python_version) + if "imageName" not in data: + python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION + data["imageName"] = get_image_name("lb", python_version) return data class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource): """CPU-only live load-balanced endpoint.""" - _image_type: ClassVar[str] = "lb-cpu" - @model_validator(mode="before") @classmethod def set_live_cpu_lb_template(cls, data: dict): """Set default CPU image for Live Load-Balanced endpoint.""" - python_version = data.get("python_version") or local_python_version() - data["imageName"] = get_image_name("lb-cpu", python_version) + if "imageName" not in data: + python_version = data.get("python_version") or local_python_version() + data["imageName"] = get_image_name("lb-cpu", python_version) return data diff --git a/tests/integration/test_cpu_disk_sizing.py b/tests/integration/test_cpu_disk_sizing.py index e9032850..3b7d7e4f 100644 --- a/tests/integration/test_cpu_disk_sizing.py +++ b/tests/integration/test_cpu_disk_sizing.py @@ -125,11 +125,11 @@ def test_live_serverless_cpu_integration(self): ) # Verify integration: - # 1. Uses CPU image (locked) + # 1. Uses CPU base image (default) # 2. CPU utilities calculate minimum disk size # 3. Template creation with auto-sizing # 4. Validation passes - assert "flash-cpu:" in live_serverless.imageName + assert "runpod/flash-cpu:" in live_serverless.imageName assert live_serverless.instanceIds == [ CpuInstanceType.CPU5C_1_2, CpuInstanceType.CPU5C_2_4, @@ -244,28 +244,24 @@ def test_mixed_cpu_generations_integration(self): assert "cpu5c-1-2: max 15GB" in error_msg -class TestLiveServerlessImageLockingIntegration: - """Test image locking integration in live serverless variants.""" +class TestLiveServerlessImageDefaultsIntegration: + """Test image defaults in live serverless variants.""" - def test_live_serverless_image_consistency(self): - """Test that LiveServerless variants maintain image consistency.""" + def test_live_serverless_image_defaults(self): + """Test that LiveServerless variants use correct base images.""" gpu_live = LiveServerless(name="gpu-live") cpu_live = CpuLiveServerless(name="cpu-live") - # Verify different images are used + # Verify different base images are used assert gpu_live.imageName != cpu_live.imageName - assert "flash:" in gpu_live.imageName - assert "flash-cpu:" in cpu_live.imageName + assert "runpod/flash:" in gpu_live.imageName + assert "runpod/flash-cpu:" in cpu_live.imageName - # Verify images remain locked despite attempts to change - original_gpu_image = gpu_live.imageName - original_cpu_image = cpu_live.imageName - - gpu_live.imageName = "custom/image:latest" - cpu_live.imageName = "custom/image:latest" - - assert gpu_live.imageName == original_gpu_image - assert cpu_live.imageName == original_cpu_image + # Verify images can be overridden (BYOI) + custom_gpu = LiveServerless( + name="custom-gpu", imageName="nvidia/cuda:12.8.0-runtime" + ) + assert custom_gpu.imageName == "nvidia/cuda:12.8.0-runtime" def test_live_serverless_template_integration(self): """Test live serverless template integration with disk sizing.""" diff --git a/tests/integration/test_lb_remote_execution.py b/tests/integration/test_lb_remote_execution.py index 11a3d14a..5426c42f 100644 --- a/tests/integration/test_lb_remote_execution.py +++ b/tests/integration/test_lb_remote_execution.py @@ -114,22 +114,18 @@ async def echo(message: str): # Verify resource is correctly configured # Note: name may have "-fb" appended by flash boot validator assert "test-live-api" in lb.name - assert "flash-lb" in lb.imageName + assert "runpod/flash-lb:" in lb.imageName # GPU LB base image assert echo.__remote_config__["method"] == "POST" - def test_live_load_balancer_image_locked(self): - """Test that LiveLoadBalancer locks the image to Flash LB image.""" + def test_live_load_balancer_default_image(self): + """Test that LiveLoadBalancer uses GPU LB base image by default.""" lb = LiveLoadBalancer(name="test-api") + assert "runpod/flash-lb:" in lb.imageName - # Verify image is locked and cannot be overridden - original_image = lb.imageName - assert "flash-lb" in original_image - - # Try to set a different image (should be ignored due to property) - lb.imageName = "custom-image:latest" - - # Image should still be locked to Flash - assert lb.imageName == original_image + def test_live_load_balancer_allows_custom_image(self): + """Test that LiveLoadBalancer allows user to set custom image (BYOI).""" + lb = LiveLoadBalancer(name="test-api", imageName="custom-image:latest") + assert lb.imageName == "custom-image:latest" def test_load_balancer_vs_queue_based_endpoints(self): """Test that LB and QB endpoints have different characteristics.""" diff --git a/tests/unit/resources/test_injection.py b/tests/unit/resources/test_injection.py new file mode 100644 index 00000000..d8cf05de --- /dev/null +++ b/tests/unit/resources/test_injection.py @@ -0,0 +1,83 @@ +"""Unit tests for process injection utilities.""" + +from runpod_flash.core.resources.injection import build_injection_cmd + + +class TestBuildInjectionCmd: + """Test build_injection_cmd() output format.""" + + def test_default_remote_url(self): + """Test default remote URL generation.""" + cmd = build_injection_cmd(worker_version="1.1.1") + + assert cmd.startswith("bash -c '") + assert "FW_VER=1.1.1" in cmd + assert "runpod-workers/flash/releases/download/v1.1.1/" in cmd + assert "bootstrap.sh'" in cmd + + def test_custom_tarball_url(self): + """Test custom tarball URL.""" + url = "https://example.com/worker.tar.gz" + cmd = build_injection_cmd(worker_version="2.0.0", tarball_url=url) + + assert "FW_VER=2.0.0" in cmd + assert url in cmd + + def test_file_url_for_local_testing(self): + """Test file:// URL generates local extraction command.""" + cmd = build_injection_cmd( + worker_version="1.0.0", + tarball_url="file:///tmp/flash-worker.tar.gz", + ) + + assert "tar xzf /tmp/flash-worker.tar.gz" in cmd + assert "curl" not in cmd + assert "wget" not in cmd + assert "bootstrap.sh'" in cmd + + def test_version_caching_logic(self): + """Test that version-based cache check is included.""" + cmd = build_injection_cmd(worker_version="1.1.1") + + # Should check .version file + assert ".version" in cmd + assert "FW_VER" in cmd + + def test_network_volume_caching(self): + """Test network volume cache path is included.""" + cmd = build_injection_cmd(worker_version="1.1.1") + + assert "/runpod-volume/.flash-worker/" in cmd + assert "NV_CACHE" in cmd + + def test_curl_wget_python_fallback(self): + """Test curl/wget/python3 fallback chain.""" + cmd = build_injection_cmd(worker_version="1.0.0") + + assert "curl -sSL" in cmd + assert "wget -qO-" in cmd + assert "urllib.request" in cmd + + def test_default_uses_constants(self): + """Test that calling with no args uses module-level constants.""" + from runpod_flash.core.resources.constants import FLASH_WORKER_VERSION + + cmd = build_injection_cmd() + + assert f"FW_VER={FLASH_WORKER_VERSION}" in cmd + assert f"v{FLASH_WORKER_VERSION}" in cmd + + def test_strip_components_in_remote_extraction(self): + """Test tar uses --strip-components=1 for remote downloads.""" + cmd = build_injection_cmd(worker_version="1.0.0") + + assert "--strip-components=1" in cmd + + def test_strip_components_in_local_extraction(self): + """Test tar uses --strip-components=1 for local file extraction.""" + cmd = build_injection_cmd( + worker_version="1.0.0", + tarball_url="file:///tmp/fw.tar.gz", + ) + + assert "--strip-components=1" in cmd diff --git a/tests/unit/resources/test_live_load_balancer.py b/tests/unit/resources/test_live_load_balancer.py index b12e13ed..63a924b0 100644 --- a/tests/unit/resources/test_live_load_balancer.py +++ b/tests/unit/resources/test_live_load_balancer.py @@ -1,11 +1,13 @@ -""" -Unit tests for LiveLoadBalancer class and template serialization. -""" +"""Unit tests for LiveLoadBalancer class and template serialization.""" +import importlib import os import pytest - +from runpod_flash.core.resources.constants import ( + GPU_BASE_IMAGE_PYTHON_VERSION, + local_python_version, +) from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.live_serverless import ( CpuLiveLoadBalancer, @@ -23,7 +25,6 @@ def test_live_load_balancer_creation_with_local_tag(self, monkeypatch): """Test LiveLoadBalancer creates with local image tag.""" monkeypatch.setenv("FLASH_IMAGE_TAG", "local") # Need to reload modules to pick up new env var - import importlib import runpod_flash.core.resources.constants as const_module import runpod_flash.core.resources.live_serverless as ls_module @@ -42,22 +43,30 @@ def test_live_load_balancer_default_image_tag(self): os.environ.pop("FLASH_IMAGE_TAG", None) lb = LiveLoadBalancer(name="test-lb") - - assert "runpod/flash-lb:" in lb.imageName + assert f"py{GPU_BASE_IMAGE_PYTHON_VERSION}" in lb.imageName assert lb.template is not None assert lb.template.imageName == lb.imageName + def test_live_load_balancer_user_can_override_image(self): + """Test user can set custom imageName (BYOI).""" + lb = LiveLoadBalancer(name="test-lb", imageName="custom/image:v1") + assert lb.imageName == "custom/image:v1" + def test_live_load_balancer_template_creation(self): """Test LiveLoadBalancer creates proper template from imageName.""" lb = LiveLoadBalancer(name="cpu_processor") - # Should have a template created from imageName assert lb.template is not None assert lb.template.imageName == lb.imageName - # Template name uses resource IDs, not the original name assert "LiveLoadBalancer" in lb.template.name assert "PodTemplate" in lb.template.name + def test_live_load_balancer_template_has_docker_args(self): + """Test LiveLoadBalancer template has process injection dockerArgs.""" + lb = LiveLoadBalancer(name="test-lb") + assert lb.template.dockerArgs + assert "bootstrap.sh" in lb.template.dockerArgs + def test_live_load_balancer_template_env_variables(self): """Test LiveLoadBalancer template includes environment variables.""" lb = LiveLoadBalancer( @@ -69,7 +78,6 @@ def test_live_load_balancer_template_env_variables(self): assert lb.template.env is not None assert len(lb.template.env) > 0 - # Check for custom env var custom_vars = [kv for kv in lb.template.env if kv.key == "CUSTOM_VAR"] assert len(custom_vars) == 1 assert custom_vars[0].value == "custom_value" @@ -78,14 +86,11 @@ def test_live_load_balancer_payload_serialization(self): """Test LiveLoadBalancer serializes correctly for GraphQL deployment.""" lb = LiveLoadBalancer(name="data_processor") - # Generate payload as would be sent to RunPod payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # Template must be in payload (not imageName since that's in _input_only) assert "template" in payload assert "imageName" not in payload - # Template must have all required fields template = payload["template"] assert "imageName" in template assert "name" in template @@ -94,14 +99,11 @@ def test_live_load_balancer_payload_serialization(self): def test_live_load_balancer_type_is_lb(self): """Test LiveLoadBalancer has type=LB.""" lb = LiveLoadBalancer(name="test-lb") - assert lb.type.value == "LB" - assert str(lb.type) == "ServerlessType.LB" def test_live_load_balancer_scaler_is_request_count(self): """Test LiveLoadBalancer uses REQUEST_COUNT scaler.""" lb = LiveLoadBalancer(name="test-lb") - assert lb.scalerType.value == "REQUEST_COUNT" @@ -147,21 +149,15 @@ def test_live_load_balancer_serialization_roundtrip(self): env={"API_KEY": "secret123"}, ) - # Simulate what gets sent to RunPod payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # Verify GraphQL payload has template assert "template" in payload, "Template must be in GraphQL payload" assert payload["template"]["imageName"] is not None assert payload["template"]["name"] is not None - - # Verify imageName is NOT in payload (it's in _input_only) assert "imageName" not in payload - # Verify the template has the correct image - assert "flash-lb:" in payload["template"]["imageName"], ( - "Must have load-balancer image" - ) + # dockerArgs must contain injection command + assert "bootstrap.sh" in payload["template"]["dockerArgs"] def test_template_env_serialization(self): """Test template environment variables serialize correctly.""" @@ -176,7 +172,6 @@ def test_template_env_serialization(self): assert isinstance(template_env, list) assert len(template_env) >= 2 - # Check env vars are serialized as {key, value} objects var_keys = {kv["key"] for kv in template_env} assert "VAR1" in var_keys assert "VAR2" in var_keys @@ -189,7 +184,6 @@ def test_cpu_live_load_balancer_creation_with_local_tag(self, monkeypatch): """Test CpuLiveLoadBalancer creates with local image tag.""" monkeypatch.setenv("FLASH_IMAGE_TAG", "local") # Need to reload modules to pick up new env var - import importlib import runpod_flash.core.resources.constants as const_module import runpod_flash.core.resources.live_serverless as ls_module @@ -209,16 +203,18 @@ def test_cpu_live_load_balancer_default_image_tag(self): os.environ.pop("FLASH_IMAGE_TAG", None) lb = CpuLiveLoadBalancer(name="test-lb") - - assert "runpod/flash-lb-cpu:" in lb.imageName + assert f"py{local_python_version()}" in lb.imageName assert lb.template is not None assert lb.template.imageName == lb.imageName + def test_cpu_live_load_balancer_user_can_override_image(self): + """Test CpuLiveLoadBalancer allows user image override.""" + lb = CpuLiveLoadBalancer(name="test-lb", imageName="python:3.11-slim") + assert lb.imageName == "python:3.11-slim" + def test_cpu_live_load_balancer_defaults(self): """Test CpuLiveLoadBalancer defaults to CPU3G_2_8.""" lb = CpuLiveLoadBalancer(name="test-lb") - - # Should default to CPU3G_2_8 assert lb.instanceIds == [CpuInstanceType.CPU3G_2_8] def test_cpu_live_load_balancer_with_specific_cpu_instances(self): @@ -227,34 +223,27 @@ def test_cpu_live_load_balancer_with_specific_cpu_instances(self): name="test-lb", instanceIds=[CpuInstanceType.CPU3G_1_4], ) - assert lb.instanceIds == [CpuInstanceType.CPU3G_1_4] def test_cpu_live_load_balancer_type_is_lb(self): """Test CpuLiveLoadBalancer has type=LB.""" lb = CpuLiveLoadBalancer(name="test-lb") - assert lb.type.value == "LB" - assert str(lb.type) == "ServerlessType.LB" def test_cpu_live_load_balancer_scaler_is_request_count(self): """Test CpuLiveLoadBalancer uses REQUEST_COUNT scaler.""" lb = CpuLiveLoadBalancer(name="test-lb") - assert lb.scalerType.value == "REQUEST_COUNT" def test_cpu_live_load_balancer_payload_serialization(self): """Test CpuLiveLoadBalancer serializes correctly for GraphQL deployment.""" lb = CpuLiveLoadBalancer(name="data_processor") - # Generate payload as would be sent to RunPod payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # Template must be in payload (not imageName since that's in _input_only) assert "template" in payload assert "imageName" not in payload - # Template must have all required fields template = payload["template"] assert "imageName" in template assert "name" in template @@ -266,7 +255,12 @@ def test_cpu_live_load_balancer_excludes_gpu_fields(self): payload = lb.model_dump(exclude=lb._input_only, exclude_none=True, mode="json") - # GPU-specific fields should not be in payload assert "gpus" not in payload assert "gpuIds" not in payload assert "cudaVersions" not in payload + + def test_cpu_live_load_balancer_template_has_docker_args(self): + """Test CpuLiveLoadBalancer template has process injection dockerArgs.""" + lb = CpuLiveLoadBalancer(name="test-lb") + assert lb.template.dockerArgs + assert "bootstrap.sh" in lb.template.dockerArgs diff --git a/tests/unit/resources/test_live_serverless.py b/tests/unit/resources/test_live_serverless.py index eab01492..5f10513c 100644 --- a/tests/unit/resources/test_live_serverless.py +++ b/tests/unit/resources/test_live_serverless.py @@ -1,6 +1,4 @@ -""" -Unit tests for LiveServerless and CpuLiveServerless classes. -""" +"""Unit tests for LiveServerless, CpuLiveServerless, and LiveServerlessMixin.""" import pytest from runpod_flash.core.resources.constants import ( @@ -9,10 +7,10 @@ ) from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.live_serverless import ( - LiveServerless, + CpuLiveLoadBalancer, CpuLiveServerless, LiveLoadBalancer, - CpuLiveLoadBalancer, + LiveServerless, ) from runpod_flash.core.resources.template import PodTemplate @@ -28,30 +26,19 @@ def test_live_serverless_workers_min_cannot_exceed_workers_max(self): LiveServerless(name="broken", workersMin=5, workersMax=1) def test_live_serverless_gpu_defaults(self): - """Test LiveServerless uses GPU image and defaults.""" - live_serverless = LiveServerless( - name="example_gpu_live_serverless", - ) + """Test LiveServerless uses GPU base image and defaults.""" + live_serverless = LiveServerless(name="example_gpu_live_serverless") - # Should not have CPU instances, uses default 64GB assert live_serverless.instanceIds is None assert live_serverless.template is not None assert live_serverless.template.containerDiskInGb == 64 - assert "flash:" in live_serverless.imageName # GPU image - def test_live_serverless_image_locked(self): - """Test LiveServerless imageName is locked to GPU image.""" + def test_live_serverless_user_can_override_image(self): + """Test user can set custom imageName (BYOI).""" live_serverless = LiveServerless( - name="example_gpu_live_serverless", + name="test", imageName="nvidia/cuda:12.8.0-runtime-ubuntu22.04" ) - - original_image = live_serverless.imageName - - # Attempt to change imageName - should be ignored - live_serverless.imageName = "custom/image:latest" - - assert live_serverless.imageName == original_image - assert "flash:" in live_serverless.imageName # Still GPU image + assert live_serverless.imageName == "nvidia/cuda:12.8.0-runtime-ubuntu22.04" def test_live_serverless_with_custom_template(self): """Test LiveServerless with custom template.""" @@ -60,31 +47,30 @@ def test_live_serverless_with_custom_template(self): imageName="test/image:v1", containerDiskInGb=100, ) - live_serverless = LiveServerless( name="example_gpu_live_serverless", template=template, ) - - # Should preserve custom template settings assert live_serverless.template.containerDiskInGb == 100 + def test_live_serverless_template_has_docker_args(self): + """Test that the template includes dockerArgs for process injection.""" + live_serverless = LiveServerless(name="test") + assert live_serverless.template is not None + assert live_serverless.template.dockerArgs + assert "bootstrap.sh" in live_serverless.template.dockerArgs + class TestCpuLiveServerless: """Test CpuLiveServerless class behavior.""" def test_cpu_live_serverless_defaults(self): """Test CpuLiveServerless uses CPU image and auto-sizing.""" - live_serverless = CpuLiveServerless( - name="example_cpu_live_serverless", - ) + live_serverless = CpuLiveServerless(name="example_cpu_live_serverless") - # Should default to CPU3G_2_8 assert live_serverless.instanceIds == [CpuInstanceType.CPU3G_2_8] assert live_serverless.template is not None - # Default disk size should be 20GB for CPU3G_2_8 assert live_serverless.template.containerDiskInGb == 20 - assert "flash-cpu:" in live_serverless.imageName # CPU image def test_cpu_live_serverless_custom_instances(self): """Test CpuLiveServerless with custom CPU instances.""" @@ -92,7 +78,6 @@ def test_cpu_live_serverless_custom_instances(self): name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4], ) - assert live_serverless.instanceIds == [CpuInstanceType.CPU3G_1_4] assert live_serverless.template is not None assert live_serverless.template.containerDiskInGb == 10 @@ -103,33 +88,21 @@ def test_cpu_live_serverless_multiple_instances(self): name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4, CpuInstanceType.CPU5C_2_4], ) - assert live_serverless.template is not None - assert live_serverless.template.containerDiskInGb == 10 # Min of 10 and 30 - - def test_cpu_live_serverless_image_locked(self): - """Test CpuLiveServerless imageName is locked to CPU image.""" - live_serverless = CpuLiveServerless( - name="example_cpu_live_serverless", - instanceIds=[CpuInstanceType.CPU3G_1_4], - ) - - original_image = live_serverless.imageName - - # Attempt to change imageName - should be ignored - live_serverless.imageName = "custom/image:latest" + assert live_serverless.template.containerDiskInGb == 10 - assert live_serverless.imageName == original_image - assert "flash-cpu:" in live_serverless.imageName # Still CPU image + def test_cpu_live_serverless_user_can_override_image(self): + """Test CpuLiveServerless allows user to set custom image.""" + live_serverless = CpuLiveServerless(name="test", imageName="python:3.11-slim") + assert live_serverless.imageName == "python:3.11-slim" def test_cpu_live_serverless_validation_failure(self): """Test CpuLiveServerless validation fails with excessive disk size.""" template = PodTemplate( name="custom", imageName="test/image:v1", - containerDiskInGb=50, # Exceeds 10GB limit + containerDiskInGb=50, ) - with pytest.raises(ValueError, match="Container disk size 50GB exceeds"): CpuLiveServerless( name="example_cpu_live_serverless", @@ -140,76 +113,93 @@ def test_cpu_live_serverless_validation_failure(self): def test_cpu_live_serverless_with_existing_template_default_size(self): """Test CpuLiveServerless auto-sizes existing template with default disk size.""" template = PodTemplate(name="existing", imageName="test/image:v1") - # Template uses default size - live_serverless = CpuLiveServerless( name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4], template=template, ) - - assert live_serverless.template.containerDiskInGb == 10 # Should be auto-sized + assert live_serverless.template.containerDiskInGb == 10 def test_cpu_live_serverless_preserves_custom_disk_size(self): """Test CpuLiveServerless preserves custom disk size in template.""" template = PodTemplate( name="existing", imageName="test/image:v1", - containerDiskInGb=5, # Custom size within limits + containerDiskInGb=5, ) - live_serverless = CpuLiveServerless( name="example_cpu_live_serverless", instanceIds=[CpuInstanceType.CPU3G_1_4], template=template, ) + assert live_serverless.template.containerDiskInGb == 5 - assert ( - live_serverless.template.containerDiskInGb == 5 - ) # Should preserve custom size + def test_cpu_live_serverless_template_has_docker_args(self): + """Test CpuLiveServerless template includes dockerArgs.""" + live_serverless = CpuLiveServerless(name="test") + assert live_serverless.template is not None + assert live_serverless.template.dockerArgs + assert "bootstrap.sh" in live_serverless.template.dockerArgs class TestLiveServerlessMixin: """Test LiveServerlessMixin functionality.""" - def test_live_image_property_gpu(self): - """Test LiveServerless _live_image property.""" - live_serverless = LiveServerless(name="test") - assert "flash:" in live_serverless._live_image - assert "cpu" not in live_serverless._live_image - - def test_live_image_property_cpu(self): - """Test CpuLiveServerless _live_image property.""" - live_serverless = CpuLiveServerless(name="test") - assert "flash-cpu:" in live_serverless._live_image - - def test_image_name_property_gpu(self): - """Test LiveServerless imageName property returns locked image.""" + def test_docker_args_set_on_new_template(self): + """Test dockerArgs is set when creating a new template.""" live_serverless = LiveServerless(name="test") - assert live_serverless.imageName == live_serverless._live_image + assert live_serverless.template.dockerArgs + assert "bash -c" in live_serverless.template.dockerArgs - def test_image_name_property_cpu(self): - """Test CpuLiveServerless imageName property returns locked image.""" - live_serverless = CpuLiveServerless(name="test") - assert live_serverless.imageName == live_serverless._live_image - - def test_image_name_setter_ignored_gpu(self): - """Test LiveServerless imageName setter is ignored.""" - live_serverless = LiveServerless(name="test") - original_image = live_serverless.imageName - - live_serverless.imageName = "should-be-ignored" - - assert live_serverless.imageName == original_image - - def test_image_name_setter_ignored_cpu(self): - """Test CpuLiveServerless imageName setter is ignored.""" - live_serverless = CpuLiveServerless(name="test") - original_image = live_serverless.imageName - - live_serverless.imageName = "should-be-ignored" + def test_docker_args_set_on_existing_template(self): + """Test dockerArgs is set when configuring an existing template.""" + template = PodTemplate( + name="existing", + imageName="test/image:v1", + ) + live_serverless = LiveServerless(name="test", template=template) + assert live_serverless.template.dockerArgs + assert "bootstrap.sh" in live_serverless.template.dockerArgs + + def test_all_live_classes_have_docker_args(self): + """Test all Live* classes set dockerArgs on their templates.""" + classes_and_kwargs = [ + (LiveServerless, {}), + (CpuLiveServerless, {}), + (LiveLoadBalancer, {}), + (CpuLiveLoadBalancer, {}), + ] + for cls, extra_kwargs in classes_and_kwargs: + resource = cls(name=f"test-{cls.__name__}", **extra_kwargs) + assert resource.template is not None, f"{cls.__name__} has no template" + assert resource.template.dockerArgs, f"{cls.__name__} has no dockerArgs" + assert "bootstrap.sh" in resource.template.dockerArgs, ( + f"{cls.__name__} missing bootstrap.sh in dockerArgs" + ) - assert live_serverless.imageName == original_image + def test_live_load_balancer_defaults(self): + """Test LiveLoadBalancer uses GPU image.""" + lb = LiveLoadBalancer(name="test-lb") + assert lb.imageName is not None + assert lb.template is not None + assert lb.template.dockerArgs + + def test_cpu_live_load_balancer_defaults(self): + """Test CpuLiveLoadBalancer uses CPU image.""" + lb = CpuLiveLoadBalancer(name="test-lb-cpu") + assert lb.imageName is not None + assert lb.template is not None + assert lb.template.dockerArgs + + def test_live_serverless_byoi_gpu(self): + """Test LiveServerless respects user-provided imageName.""" + live_serverless = LiveServerless(name="test", imageName="custom/gpu:v1") + assert live_serverless.imageName == "custom/gpu:v1" + + def test_live_serverless_byoi_cpu(self): + """Test CpuLiveServerless respects user-provided imageName.""" + live_serverless = CpuLiveServerless(name="test", imageName="custom/cpu:v1") + assert live_serverless.imageName == "custom/cpu:v1" class TestLiveServerlessPythonVersion: