diff --git a/src/tetra_rp/core/resources/load_balancer_sls_resource.py b/src/tetra_rp/core/resources/load_balancer_sls_resource.py index cf93d7e4..1ea5085b 100644 --- a/src/tetra_rp/core/resources/load_balancer_sls_resource.py +++ b/src/tetra_rp/core/resources/load_balancer_sls_resource.py @@ -15,12 +15,11 @@ import asyncio import logging -import os from typing import List, Optional -import httpx from pydantic import model_validator +from tetra_rp.core.utils.http import get_authenticated_httpx_client from .cpu import CpuInstanceType from .serverless import ServerlessResource, ServerlessType, ServerlessScalerType from .serverless_cpu import CpuEndpointMixin @@ -168,16 +167,10 @@ async def _check_ping_endpoint(self) -> bool: ping_url = f"{self.endpoint_url}/ping" - # Add authentication header if API key is available - headers = {} - api_key = os.environ.get("RUNPOD_API_KEY") - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - async with httpx.AsyncClient( + async with get_authenticated_httpx_client( timeout=DEFAULT_PING_REQUEST_TIMEOUT ) as client: - response = await client.get(ping_url, headers=headers) + response = await client.get(ping_url) return response.status_code in HEALTHY_STATUS_CODES except Exception as e: log.debug(f"Ping check failed for {self.name}: {e}") diff --git a/src/tetra_rp/core/resources/template.py b/src/tetra_rp/core/resources/template.py index a4c0a254..8b9e9de5 100644 --- a/src/tetra_rp/core/resources/template.py +++ b/src/tetra_rp/core/resources/template.py @@ -1,6 +1,6 @@ -import requests from typing import Dict, List, Optional, Any from pydantic import BaseModel, model_validator +from tetra_rp.core.utils.http import get_authenticated_requests_session from .base import BaseResource @@ -38,7 +38,7 @@ def sync_input_fields(self): def update_system_dependencies( - template_id, token, system_dependencies, base_entry_cmd=None + template_id, system_dependencies, base_entry_cmd=None, token=None ): """ Updates Runpod template with system dependencies installed via apt-get, @@ -83,12 +83,16 @@ def update_system_dependencies( "volumeMountPath": "/workspace", } - headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} - url = f"https://rest.runpod.io/v1/templates/{template_id}/update" - response = requests.post(url, json=payload, headers=headers) + # Use centralized auth utility instead of manual header setup + # Note: token parameter is deprecated; uses RUNPOD_API_KEY environment variable + session = get_authenticated_requests_session() try: + response = session.post(url, json=payload) + response.raise_for_status() return response.json() - except Exception: - return {"error": "Invalid JSON response", "text": response.text} + except Exception as e: + return {"error": "Failed to update template", "details": str(e)} + finally: + session.close() diff --git a/src/tetra_rp/core/utils/http.py b/src/tetra_rp/core/utils/http.py new file mode 100644 index 00000000..ac6ac01e --- /dev/null +++ b/src/tetra_rp/core/utils/http.py @@ -0,0 +1,67 @@ +"""HTTP utilities for RunPod API communication.""" + +import os +from typing import Optional + +import httpx +import requests + + +def get_authenticated_httpx_client( + timeout: Optional[float] = None, +) -> httpx.AsyncClient: + """Create httpx AsyncClient with RunPod authentication. + + Automatically includes Authorization header if RUNPOD_API_KEY is set. + This provides a centralized place to manage authentication headers for + all RunPod HTTP requests, avoiding repetitive manual header addition. + + Args: + timeout: Request timeout in seconds. Defaults to 30.0. + + Returns: + Configured httpx.AsyncClient with Authorization header + + Example: + async with get_authenticated_httpx_client() as client: + response = await client.post(url, json=data) + + # With custom timeout + async with get_authenticated_httpx_client(timeout=60.0) as client: + response = await client.get(url) + """ + headers = {} + api_key = os.environ.get("RUNPOD_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + timeout_config = timeout if timeout is not None else 30.0 + return httpx.AsyncClient(timeout=timeout_config, headers=headers) + + +def get_authenticated_requests_session() -> requests.Session: + """Create requests Session with RunPod authentication. + + Automatically includes Authorization header if RUNPOD_API_KEY is set. + Provides a centralized place to manage authentication headers for + synchronous RunPod HTTP requests. + + Returns: + Configured requests.Session with Authorization header + + Example: + session = get_authenticated_requests_session() + response = session.post(url, json=data, timeout=30.0) + # Remember to close: session.close() + + # Or use as context manager + import contextlib + with contextlib.closing(get_authenticated_requests_session()) as session: + response = session.post(url, json=data) + """ + session = requests.Session() + api_key = os.environ.get("RUNPOD_API_KEY") + if api_key: + session.headers["Authorization"] = f"Bearer {api_key}" + + return session diff --git a/src/tetra_rp/stubs/load_balancer_sls.py b/src/tetra_rp/stubs/load_balancer_sls.py index ee08e542..b9090e6c 100644 --- a/src/tetra_rp/stubs/load_balancer_sls.py +++ b/src/tetra_rp/stubs/load_balancer_sls.py @@ -12,6 +12,7 @@ import httpx import cloudpickle +from tetra_rp.core.utils.http import get_authenticated_httpx_client from .live_serverless import get_function_source log = logging.getLogger(__name__) @@ -227,7 +228,7 @@ async def _execute_function(self, request: Dict[str, Any]) -> Dict[str, Any]: execute_url = f"{self.server.endpoint_url}/execute" try: - async with httpx.AsyncClient(timeout=self.timeout) as client: + async with get_authenticated_httpx_client(timeout=self.timeout) as client: response = await client.post(execute_url, json=request) response.raise_for_status() return response.json() @@ -299,7 +300,7 @@ async def _execute_via_user_route( log.debug(f"Executing via user route: {method} {url}") try: - async with httpx.AsyncClient(timeout=self.timeout) as client: + async with get_authenticated_httpx_client(timeout=self.timeout) as client: response = await client.request(method, url, json=body) response.raise_for_status() result = response.json() diff --git a/tests/unit/core/utils/test_http.py b/tests/unit/core/utils/test_http.py new file mode 100644 index 00000000..d26c0954 --- /dev/null +++ b/tests/unit/core/utils/test_http.py @@ -0,0 +1,125 @@ +"""Tests for HTTP utilities for RunPod API communication.""" + +import requests +from tetra_rp.core.utils.http import ( + get_authenticated_httpx_client, + get_authenticated_requests_session, +) + + +class TestGetAuthenticatedHttpxClient: + """Test the get_authenticated_httpx_client utility function.""" + + def test_get_authenticated_httpx_client_with_api_key(self, monkeypatch): + """Test client includes auth header when API key is set.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-api-key-123") + + client = get_authenticated_httpx_client() + + assert client is not None + assert "Authorization" in client.headers + assert client.headers["Authorization"] == "Bearer test-api-key-123" + + def test_get_authenticated_httpx_client_without_api_key(self, monkeypatch): + """Test client works without API key (no auth header).""" + monkeypatch.delenv("RUNPOD_API_KEY", raising=False) + + client = get_authenticated_httpx_client() + + assert client is not None + assert "Authorization" not in client.headers + + def test_get_authenticated_httpx_client_custom_timeout(self, monkeypatch): + """Test client respects custom timeout.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client(timeout=60.0) + + assert client is not None + assert client.timeout.read == 60.0 + + def test_get_authenticated_httpx_client_default_timeout(self, monkeypatch): + """Test client uses default timeout when not specified.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client() + + assert client is not None + assert client.timeout.read == 30.0 + + def test_get_authenticated_httpx_client_timeout_none_uses_default( + self, monkeypatch + ): + """Test client uses default timeout when explicitly passed None.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client(timeout=None) + + assert client is not None + assert client.timeout.read == 30.0 + + def test_get_authenticated_httpx_client_empty_api_key_no_header(self, monkeypatch): + """Test that empty API key doesn't add Authorization header.""" + monkeypatch.setenv("RUNPOD_API_KEY", "") + + client = get_authenticated_httpx_client() + + assert client is not None + # Empty string is falsy, so no auth header should be added + assert "Authorization" not in client.headers + + def test_get_authenticated_httpx_client_zero_timeout(self, monkeypatch): + """Test client handles zero timeout correctly.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + client = get_authenticated_httpx_client(timeout=0.0) + + assert client is not None + assert client.timeout.read == 0.0 + + +class TestGetAuthenticatedRequestsSession: + """Test the get_authenticated_requests_session utility function.""" + + def test_get_authenticated_requests_session_with_api_key(self, monkeypatch): + """Test session includes auth header when API key is set.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-api-key-123") + + session = get_authenticated_requests_session() + + assert session is not None + assert "Authorization" in session.headers + assert session.headers["Authorization"] == "Bearer test-api-key-123" + session.close() + + def test_get_authenticated_requests_session_without_api_key(self, monkeypatch): + """Test session works without API key (no auth header).""" + monkeypatch.delenv("RUNPOD_API_KEY", raising=False) + + session = get_authenticated_requests_session() + + assert session is not None + assert "Authorization" not in session.headers + session.close() + + def test_get_authenticated_requests_session_empty_api_key_no_header( + self, monkeypatch + ): + """Test that empty API key doesn't add Authorization header.""" + monkeypatch.setenv("RUNPOD_API_KEY", "") + + session = get_authenticated_requests_session() + + assert session is not None + # Empty string is falsy, so no auth header should be added + assert "Authorization" not in session.headers + session.close() + + def test_get_authenticated_requests_session_is_valid_session(self, monkeypatch): + """Test returned object is a valid requests.Session.""" + monkeypatch.setenv("RUNPOD_API_KEY", "test-key") + + session = get_authenticated_requests_session() + + assert isinstance(session, requests.Session) + session.close() diff --git a/tests/unit/test_load_balancer_sls_resource.py b/tests/unit/test_load_balancer_sls_resource.py index 709c2ed7..ab2fbacb 100644 --- a/tests/unit/test_load_balancer_sls_resource.py +++ b/tests/unit/test_load_balancer_sls_resource.py @@ -120,6 +120,22 @@ def test_endpoint_url_raises_without_id(self): class TestLoadBalancerSlsResourceHealthCheck: """Test health check functionality.""" + @staticmethod + def _create_mock_client( + status_code: int = 200, error: Exception = None + ) -> MagicMock: + """Create properly configured async context manager mock client.""" + mock_response = AsyncMock() + mock_response.status_code = status_code + mock_client = MagicMock() + if error: + mock_client.get = AsyncMock(side_effect=error) + else: + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + return mock_client + @pytest.mark.asyncio async def test_check_ping_endpoint_success(self): """Test successful ping endpoint check with ID set.""" @@ -129,6 +145,7 @@ async def test_check_ping_endpoint_success(self): id="test-endpoint-id", ) + mock_client = self._create_mock_client(200) with ( patch.object( LoadBalancerSlsResource, @@ -136,15 +153,10 @@ async def test_check_ping_endpoint_success(self): new_callable=lambda: property(lambda self: "https://test-endpoint.com"), ), patch( - "tetra_rp.core.resources.load_balancer_sls_resource.httpx.AsyncClient" - ) as mock_client, + "tetra_rp.core.utils.http.httpx.AsyncClient", + return_value=mock_client, + ), ): - mock_response = AsyncMock() - mock_response.status_code = 200 - mock_client.return_value.__aenter__.return_value.get = AsyncMock( - return_value=mock_response - ) - result = await resource._check_ping_endpoint() assert result is True @@ -158,6 +170,7 @@ async def test_check_ping_endpoint_initializing(self): id="test-endpoint-id", ) + mock_client = self._create_mock_client(204) with ( patch.object( LoadBalancerSlsResource, @@ -165,15 +178,10 @@ async def test_check_ping_endpoint_initializing(self): new_callable=lambda: property(lambda self: "https://test-endpoint.com"), ), patch( - "tetra_rp.core.resources.load_balancer_sls_resource.httpx.AsyncClient" - ) as mock_client, + "tetra_rp.core.utils.http.httpx.AsyncClient", + return_value=mock_client, + ), ): - mock_response = AsyncMock() - mock_response.status_code = 204 - mock_client.return_value.__aenter__.return_value.get = AsyncMock( - return_value=mock_response - ) - result = await resource._check_ping_endpoint() assert result is True @@ -194,15 +202,10 @@ async def test_check_ping_endpoint_failure(self): new_callable=lambda: property(lambda self: "https://test-endpoint.com"), ), patch( - "tetra_rp.core.resources.load_balancer_sls_resource.httpx.AsyncClient" - ) as mock_client, + "tetra_rp.core.resources.load_balancer_sls_resource.get_authenticated_httpx_client", + side_effect=lambda **kwargs: self._create_mock_client(503), + ), ): - mock_response = AsyncMock() - mock_response.status_code = 503 # Service unavailable - mock_client.return_value.__aenter__.return_value.get = AsyncMock( - return_value=mock_response - ) - result = await resource._check_ping_endpoint() assert result is False @@ -223,13 +226,12 @@ async def test_check_ping_endpoint_connection_error(self): new_callable=lambda: property(lambda self: "https://test-endpoint.com"), ), patch( - "tetra_rp.core.resources.load_balancer_sls_resource.httpx.AsyncClient" - ) as mock_client, + "tetra_rp.core.resources.load_balancer_sls_resource.get_authenticated_httpx_client", + side_effect=lambda **kwargs: self._create_mock_client( + error=ConnectionError("Connection refused") + ), + ), ): - mock_client.return_value.__aenter__.return_value.get = AsyncMock( - side_effect=ConnectionError("Connection refused") - ) - result = await resource._check_ping_endpoint() assert result is False