Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions src/tetra_rp/core/resources/load_balancer_sls_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@

import asyncio
import logging
import os
from typing import List, Optional

import httpx
from pydantic import model_validator

from tetra_rp.core.utils.http import get_authenticated_httpx_client
from .cpu import CpuInstanceType
from .serverless import ServerlessResource, ServerlessType, ServerlessScalerType
from .serverless_cpu import CpuEndpointMixin
Expand Down Expand Up @@ -168,16 +167,10 @@ async def _check_ping_endpoint(self) -> bool:

ping_url = f"{self.endpoint_url}/ping"

# Add authentication header if API key is available
headers = {}
api_key = os.environ.get("RUNPOD_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

async with httpx.AsyncClient(
async with get_authenticated_httpx_client(
timeout=DEFAULT_PING_REQUEST_TIMEOUT
) as client:
response = await client.get(ping_url, headers=headers)
response = await client.get(ping_url)
return response.status_code in HEALTHY_STATUS_CODES
except Exception as e:
log.debug(f"Ping check failed for {self.name}: {e}")
Expand Down
18 changes: 11 additions & 7 deletions src/tetra_rp/core/resources/template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import requests
from typing import Dict, List, Optional, Any
from pydantic import BaseModel, model_validator
from tetra_rp.core.utils.http import get_authenticated_requests_session
from .base import BaseResource


Expand Down Expand Up @@ -38,7 +38,7 @@ def sync_input_fields(self):


def update_system_dependencies(
template_id, token, system_dependencies, base_entry_cmd=None
template_id, system_dependencies, base_entry_cmd=None, token=None
Copy link

Copilot AI Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function signature introduces a breaking change by reordering parameters. The token parameter was previously in position 2 but is now in position 4 (after base_entry_cmd). This will break any existing calls that use positional arguments like update_system_dependencies(template_id, token, deps). Consider deprecating this function and creating a new one, or keeping the original parameter order while adding a deprecation warning for the token parameter.

Copilot uses AI. Check for mistakes.
):
"""
Updates Runpod template with system dependencies installed via apt-get,
Expand Down Expand Up @@ -83,12 +83,16 @@ def update_system_dependencies(
"volumeMountPath": "/workspace",
}

headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}

url = f"https://rest.runpod.io/v1/templates/{template_id}/update"
response = requests.post(url, json=payload, headers=headers)

# Use centralized auth utility instead of manual header setup
# Note: token parameter is deprecated; uses RUNPOD_API_KEY environment variable
Copy link

Copilot AI Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment states that the token parameter is deprecated, but there is no deprecation warning raised when it's used. Consider adding a warnings.warn() call if the token parameter is provided to properly communicate the deprecation to users.

Copilot uses AI. Check for mistakes.
session = get_authenticated_requests_session()
try:
response = session.post(url, json=payload)
response.raise_for_status()
return response.json()
except Exception:
return {"error": "Invalid JSON response", "text": response.text}
except Exception as e:
return {"error": "Failed to update template", "details": str(e)}
finally:
session.close()
67 changes: 67 additions & 0 deletions src/tetra_rp/core/utils/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""HTTP utilities for RunPod API communication."""

import os
from typing import Optional

import httpx
import requests


def get_authenticated_httpx_client(
timeout: Optional[float] = None,
) -> httpx.AsyncClient:
"""Create httpx AsyncClient with RunPod authentication.

Automatically includes Authorization header if RUNPOD_API_KEY is set.
This provides a centralized place to manage authentication headers for
all RunPod HTTP requests, avoiding repetitive manual header addition.

Args:
timeout: Request timeout in seconds. Defaults to 30.0.

Returns:
Configured httpx.AsyncClient with Authorization header

Example:
async with get_authenticated_httpx_client() as client:
response = await client.post(url, json=data)

# With custom timeout
async with get_authenticated_httpx_client(timeout=60.0) as client:
response = await client.get(url)
"""
headers = {}
api_key = os.environ.get("RUNPOD_API_KEY")
if api_key:
headers["Authorization"] = f"Bearer {api_key}"

timeout_config = timeout if timeout is not None else 30.0
return httpx.AsyncClient(timeout=timeout_config, headers=headers)


def get_authenticated_requests_session() -> requests.Session:
"""Create requests Session with RunPod authentication.

Automatically includes Authorization header if RUNPOD_API_KEY is set.
Provides a centralized place to manage authentication headers for
synchronous RunPod HTTP requests.

Returns:
Configured requests.Session with Authorization header

Example:
session = get_authenticated_requests_session()
response = session.post(url, json=data, timeout=30.0)
# Remember to close: session.close()

# Or use as context manager
import contextlib
with contextlib.closing(get_authenticated_requests_session()) as session:
response = session.post(url, json=data)
"""
session = requests.Session()
api_key = os.environ.get("RUNPOD_API_KEY")
if api_key:
session.headers["Authorization"] = f"Bearer {api_key}"

return session
5 changes: 3 additions & 2 deletions src/tetra_rp/stubs/load_balancer_sls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import httpx
import cloudpickle

from tetra_rp.core.utils.http import get_authenticated_httpx_client
from .live_serverless import get_function_source

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -227,7 +228,7 @@ async def _execute_function(self, request: Dict[str, Any]) -> Dict[str, Any]:
execute_url = f"{self.server.endpoint_url}/execute"

try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
async with get_authenticated_httpx_client(timeout=self.timeout) as client:
response = await client.post(execute_url, json=request)
response.raise_for_status()
return response.json()
Expand Down Expand Up @@ -299,7 +300,7 @@ async def _execute_via_user_route(
log.debug(f"Executing via user route: {method} {url}")

try:
async with httpx.AsyncClient(timeout=self.timeout) as client:
async with get_authenticated_httpx_client(timeout=self.timeout) as client:
response = await client.request(method, url, json=body)
response.raise_for_status()
result = response.json()
Expand Down
125 changes: 125 additions & 0 deletions tests/unit/core/utils/test_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Tests for HTTP utilities for RunPod API communication."""

import requests
from tetra_rp.core.utils.http import (
get_authenticated_httpx_client,
get_authenticated_requests_session,
)


class TestGetAuthenticatedHttpxClient:
"""Test the get_authenticated_httpx_client utility function."""

def test_get_authenticated_httpx_client_with_api_key(self, monkeypatch):
"""Test client includes auth header when API key is set."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-api-key-123")

client = get_authenticated_httpx_client()

assert client is not None
assert "Authorization" in client.headers
assert client.headers["Authorization"] == "Bearer test-api-key-123"

def test_get_authenticated_httpx_client_without_api_key(self, monkeypatch):
"""Test client works without API key (no auth header)."""
monkeypatch.delenv("RUNPOD_API_KEY", raising=False)

client = get_authenticated_httpx_client()

assert client is not None
assert "Authorization" not in client.headers

def test_get_authenticated_httpx_client_custom_timeout(self, monkeypatch):
"""Test client respects custom timeout."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-key")

client = get_authenticated_httpx_client(timeout=60.0)

assert client is not None
assert client.timeout.read == 60.0

def test_get_authenticated_httpx_client_default_timeout(self, monkeypatch):
"""Test client uses default timeout when not specified."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-key")

client = get_authenticated_httpx_client()

assert client is not None
assert client.timeout.read == 30.0

def test_get_authenticated_httpx_client_timeout_none_uses_default(
self, monkeypatch
):
"""Test client uses default timeout when explicitly passed None."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-key")

client = get_authenticated_httpx_client(timeout=None)

assert client is not None
assert client.timeout.read == 30.0

def test_get_authenticated_httpx_client_empty_api_key_no_header(self, monkeypatch):
"""Test that empty API key doesn't add Authorization header."""
monkeypatch.setenv("RUNPOD_API_KEY", "")

client = get_authenticated_httpx_client()

assert client is not None
# Empty string is falsy, so no auth header should be added
assert "Authorization" not in client.headers

def test_get_authenticated_httpx_client_zero_timeout(self, monkeypatch):
"""Test client handles zero timeout correctly."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-key")

client = get_authenticated_httpx_client(timeout=0.0)

assert client is not None
assert client.timeout.read == 0.0


class TestGetAuthenticatedRequestsSession:
"""Test the get_authenticated_requests_session utility function."""

def test_get_authenticated_requests_session_with_api_key(self, monkeypatch):
"""Test session includes auth header when API key is set."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-api-key-123")

session = get_authenticated_requests_session()

assert session is not None
assert "Authorization" in session.headers
assert session.headers["Authorization"] == "Bearer test-api-key-123"
session.close()

def test_get_authenticated_requests_session_without_api_key(self, monkeypatch):
"""Test session works without API key (no auth header)."""
monkeypatch.delenv("RUNPOD_API_KEY", raising=False)

session = get_authenticated_requests_session()

assert session is not None
assert "Authorization" not in session.headers
session.close()

def test_get_authenticated_requests_session_empty_api_key_no_header(
self, monkeypatch
):
"""Test that empty API key doesn't add Authorization header."""
monkeypatch.setenv("RUNPOD_API_KEY", "")

session = get_authenticated_requests_session()

assert session is not None
# Empty string is falsy, so no auth header should be added
assert "Authorization" not in session.headers
session.close()

def test_get_authenticated_requests_session_is_valid_session(self, monkeypatch):
"""Test returned object is a valid requests.Session."""
monkeypatch.setenv("RUNPOD_API_KEY", "test-key")

session = get_authenticated_requests_session()

assert isinstance(session, requests.Session)
session.close()
Loading
Loading