-
Notifications
You must be signed in to change notification settings - Fork 8
feat: process injection via LiveServerlessMixin #260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
945645a
472037f
660efbd
6b223fa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| """Process injection utilities for flash-worker tarball delivery.""" | ||
|
|
||
| from .constants import FLASH_WORKER_TARBALL_URL_TEMPLATE, FLASH_WORKER_VERSION | ||
|
|
||
|
|
||
def build_injection_cmd(
    worker_version: str = FLASH_WORKER_VERSION,
    tarball_url: str | None = None,
) -> str:
    """Build the dockerArgs command that downloads, extracts, and runs flash-worker.

    Two shapes of command are produced:

    * ``file://`` URLs: extract the local tarball straight into
      ``/opt/flash-worker`` and exec ``bootstrap.sh`` (no download, no caching).
    * Any other URL: reuse a cached extraction when one is present, otherwise
      download with curl, falling back to wget, then python3/urllib, extract,
      and (best effort) populate the network-volume cache at
      ``/runpod-volume/.flash-worker/v{version}``.

    Args:
        worker_version: Version string exported to the script as ``FW_VER`` and,
            when ``tarball_url`` is None, substituted into
            ``FLASH_WORKER_TARBALL_URL_TEMPLATE``.
        tarball_url: Explicit tarball location. ``None`` selects the templated
            default URL for ``worker_version``.

    Returns:
        A single ``bash -c '<script>'`` string.
    """
    import shlex  # local import: keeps this block self-contained

    def _dq(value: str) -> str:
        """Escape *value* for embedding inside a double-quoted shell word."""
        # Backslash must be handled first so later escapes are not doubled.
        for ch in ("\\", '"', "$", "`"):
            value = value.replace(ch, "\\" + ch)
        return value

    if tarball_url is None:
        tarball_url = FLASH_WORKER_TARBALL_URL_TEMPLATE.format(version=worker_version)

    if tarball_url.startswith("file://"):
        # shlex.quote leaves ordinary paths untouched and quotes anything
        # containing whitespace or shell metacharacters.
        local_path = shlex.quote(tarball_url.removeprefix("file://"))
        script = (
            "set -e; FW_DIR=/opt/flash-worker; "
            "mkdir -p $FW_DIR; "
            f"tar xzf {local_path} -C $FW_DIR --strip-components=1; "
            "exec $FW_DIR/bootstrap.sh"
        )
    else:
        script = (
            f"set -e; FW_DIR=/opt/flash-worker; FW_VER={shlex.quote(worker_version)}; "
            # Network volume cache check
            'NV_CACHE="/runpod-volume/.flash-worker/v$FW_VER"; '
            'if [ -d "$NV_CACHE" ] && [ -f "$NV_CACHE/.version" ]; then '
            # Copy the cache *contents* ("/.") so the files land directly in
            # $FW_DIR even when that directory already exists; copying the
            # directory itself would nest it as $FW_DIR/v<version>/.
            'mkdir -p "$FW_DIR"; cp -r "$NV_CACHE"/. "$FW_DIR"/; '
            # Local cache check (container disk persistence between restarts)
            'elif [ -f "$FW_DIR/.version" ] && [ "$(cat $FW_DIR/.version)" = "$FW_VER" ]; then '
            "true; "
            "else "
            "mkdir -p $FW_DIR; "
            f'DL_URL="{_dq(tarball_url)}"; '
            "dl() { "
            '(command -v curl >/dev/null 2>&1 && curl -sSL "$1" || '
            'command -v wget >/dev/null 2>&1 && wget -qO- "$1" || '
            'python3 -c "import urllib.request,sys;sys.stdout.buffer.write(urllib.request.urlopen(sys.argv[1]).read())" "$1"); '
            "}; "
            'dl "$DL_URL" '
            "| tar xz -C $FW_DIR --strip-components=1; "
            # Both cache checks above key on a .version marker; write one if
            # the tarball did not ship it so later starts can hit the cache.
            '[ -f "$FW_DIR/.version" ] || echo "$FW_VER" > "$FW_DIR/.version"; '
            # Cache to network volume if available (best effort). "/." rather
            # than "/*" so dotfiles such as .version are copied as well.
            "if [ -d /runpod-volume ]; then "
            'mkdir -p "$NV_CACHE" && cp -r "$FW_DIR"/. "$NV_CACHE/" 2>/dev/null || true; fi; '
            "fi; "
            "exec $FW_DIR/bootstrap.sh"
        )

    # The script is wrapped in single quotes; any single quote it contains
    # (only possible via caller-supplied values) is emitted as '\'' so the
    # wrapper stays a single shell word.
    return "bash -c '" + script.replace("'", "'\\''") + "'"
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,99 +1,90 @@ | ||
| # Ship serverless code as you write it. No builds, no deploys — just run. | ||
| from typing import ClassVar | ||
|
|
||
| from pydantic import model_validator | ||
|
|
||
| from .constants import ( | ||
| GPU_BASE_IMAGE_PYTHON_VERSION, | ||
| get_image_name, | ||
| local_python_version, | ||
| ) | ||
| from .injection import build_injection_cmd | ||
| from .load_balancer_sls_resource import ( | ||
| CpuLoadBalancerSlsResource, | ||
| LoadBalancerSlsResource, | ||
| ) | ||
| from .serverless import ServerlessEndpoint | ||
| from .serverless_cpu import CpuServerlessEndpoint | ||
| from .template import PodTemplate | ||
|
|
||
|
|
||
class LiveServerlessMixin:
    """Configures process injection via dockerArgs for any base image.

    Concrete subclasses supply a default base image (user can override via
    imageName); this mixin generates dockerArgs to download, extract, and run
    the flash-worker tarball at container start time. QB vs LB mode is
    determined by FLASH_ENDPOINT_TYPE env var at runtime, not by the Docker
    image.
    """

    # Image family tag. Override in subclasses: 'gpu', 'cpu', 'lb', 'lb-cpu'.
    _image_type: ClassVar[str] = ""
    # Image types that use the GPU base image. Not read by this class.
    _GPU_IMAGE_TYPES: ClassVar[frozenset[str]] = frozenset({"gpu", "lb"})

    def _create_new_template(self) -> PodTemplate:
        """Create template with dockerArgs for process injection."""
        template = super()._create_new_template()  # type: ignore[misc]
        template.dockerArgs = build_injection_cmd()
        return template

    def _configure_existing_template(self) -> None:
        """Configure existing template, adding dockerArgs for injection if not user-set."""
        super()._configure_existing_template()  # type: ignore[misc]
        template = self.template  # type: ignore[attr-defined]
        # Leave a non-empty dockerArgs alone: it is treated as user-provided.
        if template is not None and not template.dockerArgs:
            template.dockerArgs = build_injection_cmd()
|
Comment on lines
+20
to
+38
|
||
|
|
||
|
|
||
class LiveServerless(LiveServerlessMixin, ServerlessEndpoint):
    """GPU-only live serverless endpoint."""

    _image_type: ClassVar[str] = "gpu"

    @model_validator(mode="before")
    @classmethod
    def set_live_serverless_template(cls, data: dict):
        """Set default GPU image for Live Serverless.

        The default is applied only when the caller did not pass ``imageName``,
        so an explicit image is never overwritten. The Python version comes
        from ``python_version`` in ``data`` when truthy, otherwise from
        ``GPU_BASE_IMAGE_PYTHON_VERSION``.

        Returns:
            The same ``data`` mapping, possibly with ``imageName`` filled in.
        """
        if "imageName" not in data:
            python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION
            data["imageName"] = get_image_name("gpu", python_version)
        return data
|
|
||
|
|
||
class CpuLiveServerless(LiveServerlessMixin, CpuServerlessEndpoint):
    """CPU-only live serverless endpoint with automatic disk sizing."""

    _image_type: ClassVar[str] = "cpu"

    @model_validator(mode="before")
    @classmethod
    def set_live_serverless_template(cls, data: dict):
        """Set default CPU image for Live Serverless.

        The default is applied only when the caller did not pass ``imageName``,
        so an explicit image is never overwritten. The Python version comes
        from ``python_version`` in ``data`` when truthy, otherwise from
        ``local_python_version()``.

        Returns:
            The same ``data`` mapping, possibly with ``imageName`` filled in.
        """
        if "imageName" not in data:
            python_version = data.get("python_version") or local_python_version()
            data["imageName"] = get_image_name("cpu", python_version)
        return data
|
|
||
|
|
||
class LiveLoadBalancer(LiveServerlessMixin, LoadBalancerSlsResource):
    """Live load-balanced endpoint."""

    _image_type: ClassVar[str] = "lb"

    @model_validator(mode="before")
    @classmethod
    def set_live_lb_template(cls, data: dict):
        """Set default image for Live Load-Balanced endpoint.

        The default is applied only when the caller did not pass ``imageName``,
        so an explicit image is never overwritten. The Python version comes
        from ``python_version`` in ``data`` when truthy, otherwise from
        ``GPU_BASE_IMAGE_PYTHON_VERSION``.

        Returns:
            The same ``data`` mapping, possibly with ``imageName`` filled in.
        """
        if "imageName" not in data:
            python_version = data.get("python_version") or GPU_BASE_IMAGE_PYTHON_VERSION
            data["imageName"] = get_image_name("lb", python_version)
        return data
|
|
||
|
|
||
class CpuLiveLoadBalancer(LiveServerlessMixin, CpuLoadBalancerSlsResource):
    """CPU-only live load-balanced endpoint."""

    _image_type: ClassVar[str] = "lb-cpu"

    @model_validator(mode="before")
    @classmethod
    def set_live_cpu_lb_template(cls, data: dict):
        """Set default CPU image for Live Load-Balanced endpoint.

        The default is applied only when the caller did not pass ``imageName``,
        so an explicit image is never overwritten. The Python version comes
        from ``python_version`` in ``data`` when truthy, otherwise from
        ``local_python_version()``.

        Returns:
            The same ``data`` mapping, possibly with ``imageName`` filled in.
        """
        if "imageName" not in data:
            python_version = data.get("python_version") or local_python_version()
            data["imageName"] = get_image_name("lb-cpu", python_version)
        return data
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| """Unit tests for process injection utilities.""" | ||
|
|
||
| from runpod_flash.core.resources.injection import build_injection_cmd | ||
|
|
||
|
|
||
class TestBuildInjectionCmd:
    """Checks on the command text returned by build_injection_cmd()."""

    def test_default_remote_url(self):
        """A bare version yields the bash wrapper, FW_VER and the release URL."""
        command = build_injection_cmd(worker_version="1.1.1")

        assert command.startswith("bash -c '")
        for fragment in (
            "FW_VER=1.1.1",
            "runpod-workers/flash/releases/download/v1.1.1/",
            "bootstrap.sh'",
        ):
            assert fragment in command

    def test_custom_tarball_url(self):
        """An explicit tarball URL is embedded alongside the given version."""
        custom_url = "https://example.com/worker.tar.gz"
        command = build_injection_cmd(tarball_url=custom_url, worker_version="2.0.0")

        assert "FW_VER=2.0.0" in command
        assert custom_url in command

    def test_file_url_for_local_testing(self):
        """A file:// URL produces a local extraction with no downloader."""
        command = build_injection_cmd(
            tarball_url="file:///tmp/flash-worker.tar.gz",
            worker_version="1.0.0",
        )

        assert "tar xzf /tmp/flash-worker.tar.gz" in command
        for downloader in ("curl", "wget"):
            assert downloader not in command
        assert "bootstrap.sh'" in command

    def test_version_caching_logic(self):
        """The remote command references the .version marker and FW_VER."""
        command = build_injection_cmd(worker_version="1.1.1")

        # Should check .version file
        for marker in (".version", "FW_VER"):
            assert marker in command

    def test_network_volume_caching(self):
        """The remote command references the network volume cache path."""
        command = build_injection_cmd(worker_version="1.1.1")

        for marker in ("/runpod-volume/.flash-worker/", "NV_CACHE"):
            assert marker in command

    def test_curl_wget_python_fallback(self):
        """All three download mechanisms appear in the remote command."""
        command = build_injection_cmd(worker_version="1.0.0")

        for tool_call in ("curl -sSL", "wget -qO-", "urllib.request"):
            assert tool_call in command

    def test_default_uses_constants(self):
        """Calling with no arguments picks up the module-level version constant."""
        from runpod_flash.core.resources.constants import FLASH_WORKER_VERSION

        command = build_injection_cmd()

        expected_assignment = f"FW_VER={FLASH_WORKER_VERSION}"
        expected_tag = f"v{FLASH_WORKER_VERSION}"
        assert expected_assignment in command
        assert expected_tag in command

    def test_strip_components_in_remote_extraction(self):
        """Remote extraction passes --strip-components=1 to tar."""
        command = build_injection_cmd(worker_version="1.0.0")

        assert "--strip-components=1" in command

    def test_strip_components_in_local_extraction(self):
        """Local file extraction passes --strip-components=1 to tar."""
        command = build_injection_cmd(
            tarball_url="file:///tmp/fw.tar.gz",
            worker_version="1.0.0",
        )

        assert "--strip-components=1" in command
Uh oh!
There was an error while loading. Please reload this page.