Skip to content

Commit 71794de

Browse files
committed
Fix merged context vars and add test
1 parent 94915ed commit 71794de

File tree

2 files changed

+39
-6
lines changed

2 files changed

+39
-6
lines changed

projects/fal/src/fal/app.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,10 @@ async def open_isolate_channel(address: str) -> async_grpc.Channel | None:
8383
return channel
8484

8585

86-
def merge_contextvars(logger_labels: dict[str, str]) -> dict[str, str]:
87-
ctx = contextvars.copy_context()
88-
for k in ctx:
89-
if k.name.startswith(LOG_CONTEXT_PREFIX) and ctx[k] is not Ellipsis:
90-
logger_labels.setdefault(k.name[len(LOG_CONTEXT_PREFIX) :], ctx[k])
91-
return logger_labels
86+
def merge_contextvars(logger_labels: dict[str, str]) -> None:
87+
for k, v in logger_labels.items():
88+
contextvar = contextvars.ContextVar(f"{LOG_CONTEXT_PREFIX}{k}")
89+
contextvar.set(v)
9290

9391

9492
def clear_contextvars() -> None:

projects/fal/tests/unit/test_app.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
import contextvars
34
import os
45

56
import pytest
67

78
import fal
89
from fal import App, endpoint
10+
from fal.app import merge_contextvars, clear_contextvars, LOG_CONTEXT_PREFIX
911
from fal.container import ContainerImage
1012

1113

@@ -114,6 +116,39 @@ class LeakCheckApp(App):
114116
assert "app_auth" not in hk
115117

116118

119+
def test_merge_context_vars():
120+
labels = {"fal_request_id": "123", "fal_endpoint": "/"}
121+
request_id_var = f"{LOG_CONTEXT_PREFIX}fal_request_id"
122+
endpoint_var = f"{LOG_CONTEXT_PREFIX}fal_endpoint"
123+
unrelated_var = "unrelated_key"
124+
contextvars.ContextVar(unrelated_var).set("value")
125+
126+
# We have to convert to dict and lookup by name because each ContextVar
127+
# is a different object. Since merge_contextvars creates new ContextVars,
128+
# we can't just do direct lookups.
129+
vars = dict((k.name, v) for k, v in contextvars.copy_context().items())
130+
131+
assert vars.get(unrelated_var) == "value"
132+
133+
assert vars.get(request_id_var) is None
134+
assert vars.get(endpoint_var) is None
135+
136+
merge_contextvars(labels)
137+
vars = dict((k.name, v) for k, v in contextvars.copy_context().items())
138+
139+
assert vars.get(request_id_var) == "123"
140+
assert vars.get(endpoint_var) == "/"
141+
142+
clear_contextvars()
143+
vars = dict((k.name, v) for k, v in contextvars.copy_context().items())
144+
145+
# Cleared contextvars are set to Ellipsis
146+
assert vars.get(request_id_var) is Ellipsis
147+
assert vars.get(endpoint_var) is Ellipsis
148+
# Does not clear unrelated contextvars
149+
assert vars.get("unrelated_key") == "value"
150+
151+
117152
@pytest.mark.asyncio
118153
async def test_runner_state_lifecycle_complete():
119154
"""Test that FAL_RUNNER_STATE transitions through all phases correctly"""

0 commit comments

Comments
 (0)