|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import contextvars |
3 | 4 | import os |
4 | 5 |
|
5 | 6 | import pytest |
6 | 7 |
|
7 | 8 | import fal |
8 | 9 | from fal import App, endpoint |
| 10 | +from fal.app import merge_contextvars, clear_contextvars, LOG_CONTEXT_PREFIX |
9 | 11 | from fal.container import ContainerImage |
10 | 12 |
|
11 | 13 |
|
@@ -114,6 +116,39 @@ class LeakCheckApp(App): |
114 | 116 | assert "app_auth" not in hk |
115 | 117 |
|
116 | 118 |
|
| 119 | +def test_merge_context_vars(): |
| 120 | + labels = {"fal_request_id": "123", "fal_endpoint": "/"} |
| 121 | + request_id_var = f"{LOG_CONTEXT_PREFIX}fal_request_id" |
| 122 | + endpoint_var = f"{LOG_CONTEXT_PREFIX}fal_endpoint" |
| 123 | + unrelated_var = "unrelated_key" |
| 124 | + contextvars.ContextVar(unrelated_var).set("value") |
| 125 | + |
| 126 | + # We have to convert to dict and lookup by name because each ContextVar |
| 127 | + # is a different object. Since merge_contextvars creates new ContextVars, |
| 128 | + # we can't just do direct lookups. |
| 129 | + vars = dict((k.name, v) for k, v in contextvars.copy_context().items()) |
| 130 | + |
| 131 | + assert vars.get(unrelated_var) == "value" |
| 132 | + |
| 133 | + assert vars.get(request_id_var) is None |
| 134 | + assert vars.get(endpoint_var) is None |
| 135 | + |
| 136 | + merge_contextvars(labels) |
| 137 | + vars = dict((k.name, v) for k, v in contextvars.copy_context().items()) |
| 138 | + |
| 139 | + assert vars.get(request_id_var) == "123" |
| 140 | + assert vars.get(endpoint_var) == "/" |
| 141 | + |
| 142 | + clear_contextvars() |
| 143 | + vars = dict((k.name, v) for k, v in contextvars.copy_context().items()) |
| 144 | + |
| 145 | + # Cleared contextvars are set to Ellipsis |
| 146 | + assert vars.get(request_id_var) is Ellipsis |
| 147 | + assert vars.get(endpoint_var) is Ellipsis |
| 148 | + # Does not clear unrelated contextvars |
| 149 | + assert vars.get("unrelated_key") == "value" |
| 150 | + |
| 151 | + |
117 | 152 | @pytest.mark.asyncio |
118 | 153 | async def test_runner_state_lifecycle_complete(): |
119 | 154 | """Test that FAL_RUNNER_STATE transitions through all phases correctly""" |
|
0 commit comments