diff --git a/src/coding/proxy/cli/__init__.py b/src/coding/proxy/cli/__init__.py index 21cbe91..3b479fb 100644 --- a/src/coding/proxy/cli/__init__.py +++ b/src/coding/proxy/cli/__init__.py @@ -30,6 +30,10 @@ # 注册 Auth 子应用 app.add_typer(auth_app, name="auth") +# 注册 Session 子应用 +session_app = typer.Typer(name="session", help="管理 Session-Vendor 运行时绑定") +app.add_typer(session_app, name="session") + def _build_token_store(cfg_path: Path | None = None): """按配置解析 Token Store 路径并完成加载.""" @@ -264,6 +268,97 @@ def reset( console.print("[red]代理服务未运行[/red]") +# ── Session 子命令 ─────────────────────────────────────────────── + + +@session_app.command("bind") +def session_bind( + key: str = typer.Option(..., "--key", "-k", help="Session key"), + vendor: str = typer.Option( + ..., "--vendor", "-v", help="绑定 vendor(逗号分隔多个)" + ), + port: int = typer.Option(3392, "--port", "-p", help="代理服务端口"), +) -> None: + """为指定 Session 绑定 vendor 优先级.""" + import httpx as _httpx + + vendors = [v.strip() for v in vendor.split(",") if v.strip()] + try: + resp = _httpx.put( + f"http://127.0.0.1:{port}/api/session-vendor", + json={"session_key": key, "vendors": vendors}, + timeout=5, + ) + if resp.status_code == 200: + data = resp.json() + console.print( + f"[green]绑定成功:[/] session [cyan]{key[:16]}…[/cyan] → " + + " → ".join(data.get("vendors", vendors)) + ) + else: + try: + err = resp.json() + msg = err.get("error", {}).get("message", resp.text) + except Exception: + msg = resp.text + console.print(f"[red]绑定失败: {msg}[/red]") + except _httpx.ConnectError: + console.print("[red]代理服务未运行[/red]") + + +@session_app.command("unbind") +def session_unbind( + key: str = typer.Option(..., "--key", "-k", help="Session key"), + port: int = typer.Option(3392, "--port", "-p", help="代理服务端口"), +) -> None: + """解除指定 Session 的 vendor 绑定.""" + from urllib.parse import quote + + import httpx as _httpx + + try: + resp = _httpx.delete( + f"http://127.0.0.1:{port}/api/session-vendor/{quote(key, safe='')}", + timeout=5, + ) + if resp.status_code == 200: + console.print(f"[green]已解除绑定:[/] session [cyan]{key[:16]}…[/cyan]") + elif resp.status_code == 404: + console.print(f"[yellow]未找到绑定:[/] session [cyan]{key[:16]}…[/cyan]") + else: + console.print(f"[red]解除失败: {resp.status_code} {resp.text}[/red]") + except _httpx.ConnectError: + console.print("[red]代理服务未运行[/red]") + + +@session_app.command("list") +def session_list( + port: int = typer.Option(3392, "--port", "-p", help="代理服务端口"), +) -> None: + """列出所有运行时 Session-Vendor 绑定.""" + import httpx as _httpx + + try: + resp = _httpx.get( + f"http://127.0.0.1:{port}/api/session-vendor", + timeout=5, + ) + if resp.status_code == 200: + data = resp.json() + bindings = data.get("bindings", []) + if not bindings: + console.print("[dim]当前无运行时绑定[/dim]") + return + for b in bindings: + key = b.get("session_key", "?") + vendors = b.get("vendors", []) + console.print(f" [cyan]{key[:24]}…[/cyan] → " + " → ".join(vendors)) + else: + console.print(f"[red]查询失败: {resp.status_code} {resp.text}[/red]") + except _httpx.ConnectError: + console.print("[red]代理服务未运行[/red]") + + def _resolve_config_path(config: str | Path | None = None) -> Path | None: """标准化配置路径输入.""" if config is None: diff --git a/src/coding/proxy/routing/session_policy.py b/src/coding/proxy/routing/session_policy.py index 1ffcdaf..9102e41 100644 --- a/src/coding/proxy/routing/session_policy.py +++ b/src/coding/proxy/routing/session_policy.py @@ -3,8 +3,9 @@ from __future__ import annotations import logging +import threading -from ..config.session_policy import SessionPolicy +from ..config.session_policy import SessionPolicy, SessionPolicyMatch logger = logging.getLogger(__name__) @@ -16,12 +17,15 @@ class SessionPolicyResolver: - 启动时构建索引,运行时 O(1) 查找 - 精确匹配优先:session_key > client_category > 无策略 - 无侵入性:不匹配时返回 None,路由行为与现有一致 + - 运行时可变:支持 API 动态 upsert/remove session → vendor 绑定 """ def __init__(self, policies: list[SessionPolicy] | None = None) -> None: self._policies = policies or [] self._key_index: dict[str, SessionPolicy] = {} self._category_index: dict[str, SessionPolicy] = {} + self._config_key_backup: dict[str, SessionPolicy] = {} + self._lock = threading.Lock() self._build_index() def _build_index(self) -> None: @@ -49,8 +53,64 @@ def _build_index(self) -> None: def resolve( self, session_key: str, client_category: str = "cc" ) -> SessionPolicy | None: - """返回匹配的策略,优先精确 session_key 匹配,其次 category 匹配.""" - policy = self._key_index.get(session_key) + """返回匹配的策略,优先精确 session_key 匹配,其次 category 匹配. + + 返回的 SessionPolicy 对象应为不可变引用;调用方不应修改其内部属性, + 否则在并发 upsert/remove 场景下可能产生竞态。 + """ + with self._lock: + policy = self._key_index.get(session_key) if policy: return policy return self._category_index.get(client_category) + + # ── 运行时 session → vendor 绑定 ────────────────────────────── + + def upsert(self, session_key: str, tier_names: list[str]) -> SessionPolicy: + """为指定 session key 创建或替换运行时 vendor 绑定. + + 运行时策略使用 ``runtime:`` 名称前缀,与配置文件驱动的策略区分。 + """ + policy = SessionPolicy( + name=f"runtime:{session_key}", + match=SessionPolicyMatch(session_keys=[session_key]), + tiers=tier_names, + ) + with self._lock: + existing = self._key_index.get(session_key) + if existing and not existing.name.startswith("runtime:"): + self._config_key_backup[session_key] = existing + self._key_index[session_key] = policy + logger.info( + "Session vendor binding upserted: session_key=%s → %s", + session_key, + tier_names, + ) + return policy + + def remove(self, session_key: str) -> bool: + """删除指定 session key 的运行时 vendor 绑定. + + Returns: + True 如果找到并删除了绑定,False 如果不存在。 + """ + with self._lock: + policy = self._key_index.get(session_key) + if policy is None or not policy.name.startswith("runtime:"): + return False + del self._key_index[session_key] + # 恢复被运行时绑定覆盖的配置策略 + backup = self._config_key_backup.pop(session_key, None) + if backup is not None: + self._key_index[session_key] = backup + logger.info("Session vendor binding removed: session_key=%s", session_key) + return True + + def list_runtime_bindings(self) -> list[dict[str, str | list[str]]]: + """返回所有运行时注入的绑定快照(仅 API 创建的,不含配置文件驱动的).""" + with self._lock: + return [ + {"session_key": key, "vendors": policy.tiers} + for key, policy in self._key_index.items() + if policy.name.startswith("runtime:") + ] diff --git a/src/coding/proxy/server/dashboard.py b/src/coding/proxy/server/dashboard.py index 7581986..707494e 100644 --- a/src/coding/proxy/server/dashboard.py +++ b/src/coding/proxy/server/dashboard.py @@ -417,6 +417,19 @@ def _build_favicon() -> bytes: } .success-bar { width: 56px; height: 4px; border-radius: 2px; background: rgba(255,255,255,.06); display: inline-block; vertical-align: middle; margin-left: 6px; } .success-bar-fill { height: 100%; border-radius: 2px; } + /* ── Vendor Bind 选择器 ── */ + .bind-select { + padding: 3px 6px; border-radius: 6px; + background: rgba(48,54,61,.6); border: 1px solid rgba(255,255,255,.1); + color: var(--text-secondary); font-size: 12px; + font-family: 'JetBrains Mono', monospace; + cursor: pointer; outline: none; + transition: all .2s ease; + max-width: 120px; + } + .bind-select:hover { border-color: rgba(88,166,255,.4); color: var(--text-primary); } + .bind-select:focus { border-color: rgba(88,166,255,.6); box-shadow: 0 0 0 2px rgba(88,166,255,.1); } + .bind-select option { background: var(--bg-card); color: var(--text-primary); } /* ── 加载态 ── */ .loading { opacity: .4; pointer-events: none; } /* ── 图表标签截断 ── */ @@ -620,11 +633,12 @@ def _build_favicon() -> bytes: Vendors Avg Latency Success + Vendor Bind Client - Loading... + Loading... @@ -1403,16 +1417,31 @@ def _build_favicon() -> bytes: } async function updateSessions() { try { - var data = await fetchJSON('/api/dashboard/sessions?hours=24&limit=20'); + var results = await Promise.allSettled([ + fetchJSON('/api/dashboard/sessions?hours=24&limit=20'), + fetchJSON('/api/session-vendor'), + fetchJSON('/api/status'), + ]); + if (results[0].status === 'rejected') throw results[0].reason; + var data = results[0].value; + var bindData = results[1].status === 'fulfilled' ? results[1].value : {bindings: []}; + var statusData = results[2].status === 'fulfilled' ? results[2].value : {tiers: []}; var sessions = data.sessions || []; + var bindings = bindData.bindings || []; + var availableVendors = (statusData.tiers || []).map(function(t) { return t.name; }); var tbody = document.getElementById('sessions-tbody'); var subtitle = document.getElementById('sessions-subtitle'); if (subtitle) subtitle.textContent = 'Last ' + data.hours + 'h'; if (!sessions.length) { - tbody.innerHTML = '
📭
No session data'; + tbody.innerHTML = '
📭
No session data'; return; } + // Build binding lookup: session_key → vendors list + var bindMap = {}; + bindings.forEach(function(b) { bindMap[b.session_key] = b.vendors; }); tbody.innerHTML = sessions.map(function(s) { + var boundVendors = bindMap[s.session_key]; + var selectHtml = buildBindSelect(s.session_key, boundVendors, availableVendors); return '' + '' + truncateKey(s.session_key, 22) + '' + '' + relativeTime(s.last_active_ts) + '' + @@ -1422,6 +1451,7 @@ def _build_favicon() -> bytes: '' + formatVendorTags(s.vendors) + '' + '' + (s.avg_duration_ms ? Math.round(s.avg_duration_ms) + 'ms' : '–') + '' + '' + successBarHtml(s.success_rate) + '' + + '' + selectHtml + '' + '' + formatCategories(s.client_categories) + '' + ''; }).join(''); @@ -1430,6 +1460,57 @@ def _build_favicon() -> bytes: } } +function buildBindSelect(sessionKey, boundVendors, availableVendors) { + var isBound = boundVendors && boundVendors.length > 0; + var multiBound = isBound && boundVendors.length > 1; + var selected = isBound ? boundVendors[0] : ''; + var html = ''; + return html; +} + +async function handleBindChange(sel) { + var sessionKey = sel.getAttribute('data-session-key'); + var vendor = sel.value; + var previousValue = sel.getAttribute('data-previous') || ''; + try { + var resp; + if (vendor) { + resp = await fetch('/api/session-vendor', { + method: 'PUT', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify({session_key: sessionKey, vendors: [vendor]}), + }); + } else { + resp = await fetch('/api/session-vendor/' + encodeURIComponent(sessionKey), {method: 'DELETE'}); + } + if (!resp.ok) { + sel.value = previousValue; + console.error('Bind change rejected:', resp.status, await resp.text()); + } + } catch (e) { + sel.value = previousValue; + console.error('Bind change failed:', e); + } +} + +var sessionsTbody = document.getElementById('sessions-tbody'); +sessionsTbody.addEventListener('focus', function(e) { + if (e.target.classList.contains('bind-select')) { + e.target.setAttribute('data-previous', e.target.value); + } +}, true); +sessionsTbody.addEventListener('change', function(e) { + if (e.target.classList.contains('bind-select')) { + handleBindChange(e.target); + } +}); + // ── 主刷新逻辑(按 Tab 分发) ────────────────────────────── let refreshing = false; let currentTab = 'overview'; diff --git a/src/coding/proxy/server/routes.py b/src/coding/proxy/server/routes.py index 3ee7588..7f157f0 100644 --- a/src/coding/proxy/server/routes.py +++ b/src/coding/proxy/server/routes.py @@ -327,6 +327,96 @@ async def reset_circuit(request: Request) -> Response: ) +def register_session_vendor_routes(app: Any, router: Any) -> None: + """注册 Session-Vendor 运行时绑定路由.""" + + @app.put("/api/session-vendor") + async def bind_session_vendor(request: Request) -> Response: + """为指定 session key 绑定 vendor 优先级列表.""" + try: + body = await request.json() + except Exception: + return json_error_response( + 400, error_type="invalid_request_error", message="body must be JSON" + ) + if not isinstance(body, dict): + return json_error_response( + 400, + error_type="invalid_request_error", + message="body must be a JSON object", + ) + session_key = body.get("session_key", "").strip() + vendors = body.get("vendors", []) + if not session_key: + return json_error_response( + 400, + error_type="invalid_request_error", + message="session_key is required", + ) + if not isinstance(vendors, list) or not vendors: + return json_error_response( + 400, + error_type="invalid_request_error", + message="vendors must be a non-empty list", + ) + vendors = [str(v).strip() for v in vendors] + available = set(router.get_vendor_names()) + unknown = [v for v in vendors if v not in available] + if unknown: + return json_error_response( + 400, + error_type="invalid_request_error", + message=( + f"unknown vendor(s): {', '.join(unknown)}; " + f"available: {', '.join(sorted(available))}" + ), + ) + resolver = router._executor._policy_resolver # noqa: SLF001 + resolver.upsert(session_key, vendors) + return Response( + content=json.dumps( + {"status": "ok", "session_key": session_key, "vendors": vendors}, + ensure_ascii=False, + ).encode(), + status_code=200, + media_type="application/json", + ) + + @app.delete("/api/session-vendor/{session_key}") + async def unbind_session_vendor(session_key: str) -> Response: + """解除指定 session key 的 vendor 绑定.""" + resolver = router._executor._policy_resolver # noqa: SLF001 + removed = resolver.remove(session_key) + if not removed: + return Response( + content=json.dumps( + {"status": "not_found", "session_key": session_key}, + ensure_ascii=False, + ).encode(), + status_code=404, + media_type="application/json", + ) + return Response( + content=json.dumps( + {"status": "removed", "session_key": session_key}, + ensure_ascii=False, + ).encode(), + status_code=200, + media_type="application/json", + ) + + @app.get("/api/session-vendor") + async def list_session_vendors() -> Response: + """列出所有运行时 session-vendor 绑定.""" + resolver = router._executor._policy_resolver # noqa: SLF001 + bindings = resolver.list_runtime_bindings() + return Response( + content=json.dumps({"bindings": bindings}, ensure_ascii=False).encode(), + status_code=200, + media_type="application/json", + ) + + def register_reauth_routes(app: Any, reauth_coordinator: Any) -> None: """注册重认证路由.""" @@ -368,6 +458,7 @@ def register_all_routes( register_status_route(app, router) register_copilot_routes(app, router) register_admin_routes(app, router) + register_session_vendor_routes(app, router) if reauth_coordinator: register_reauth_routes(app, reauth_coordinator) diff --git a/tests/test_session_aware.py b/tests/test_session_aware.py index bdd69e8..0c08449 100644 --- a/tests/test_session_aware.py +++ b/tests/test_session_aware.py @@ -449,3 +449,144 @@ def test_config_parse(): assert config.policies[0].name == "vip" assert config.policies[0].match.session_keys == ["key-1"] assert config.policies[0].tiers == ["anthropic", "copilot"] + + +# ── 9. SessionPolicyResolver 运行时可变性 ──────────────────────── + + +def test_runtime_upsert_and_resolve(): + resolver = SessionPolicyResolver() + assert resolver.resolve("my-session") is None + + resolver.upsert("my-session", ["anthropic", "copilot"]) + policy = resolver.resolve("my-session") + assert policy is not None + assert policy.tiers == ["anthropic", "copilot"] + assert policy.name.startswith("runtime:") + + +def test_runtime_upsert_overwrites(): + resolver = SessionPolicyResolver() + resolver.upsert("my-session", ["anthropic"]) + resolver.upsert("my-session", ["copilot", "zhipu"]) + policy = resolver.resolve("my-session") + assert policy.tiers == ["copilot", "zhipu"] + + +def test_runtime_remove(): + resolver = SessionPolicyResolver() + resolver.upsert("my-session", ["anthropic"]) + assert resolver.remove("my-session") is True + assert resolver.resolve("my-session") is None + assert resolver.remove("my-session") is False + + +def test_runtime_remove_does_not_affect_config_policy(): + p = _make_policy("config-policy", keys=["config-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + # Cannot remove config-driven policy via runtime API + assert resolver.remove("config-key") is False + assert resolver.resolve("config-key") is p + + +def test_runtime_upsert_overrides_config_policy(): + p = _make_policy("config-policy", keys=["shared-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + resolver.upsert("shared-key", ["copilot"]) + # Runtime binding takes precedence (replaces in key_index) + policy = resolver.resolve("shared-key") + assert policy.tiers == ["copilot"] + assert policy.name.startswith("runtime:") + + +def test_runtime_remove_restores_config_policy(): + p = _make_policy("config-policy", keys=["shared-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + resolver.upsert("shared-key", ["copilot"]) + assert resolver.resolve("shared-key").tiers == ["copilot"] + # 移除运行时绑定后应恢复原配置策略 + assert resolver.remove("shared-key") is True + restored = resolver.resolve("shared-key") + assert restored is p + assert restored.tiers == ["anthropic"] + + +def test_list_runtime_bindings(): + resolver = SessionPolicyResolver() + p = _make_policy("config-policy", keys=["config-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + resolver.upsert("runtime-1", ["copilot"]) + resolver.upsert("runtime-2", ["zhipu", "anthropic"]) + + bindings = resolver.list_runtime_bindings() + assert len(bindings) == 2 + keys = {b["session_key"] for b in bindings} + assert keys == {"runtime-1", "runtime-2"} + # Config-driven policy should not appear + assert "config-key" not in keys + + +def test_runtime_upsert_integrates_with_executor(): + from coding.proxy.routing.executor import _RouteExecutor + from coding.proxy.routing.tier import VendorTier + from coding.proxy.vendors.base import BaseVendor + + class FakeVendor(BaseVendor): + def __init__(self, name): + self._name = name + + def get_name(self): + return self._name + + async def _prepare_request(self, body, headers): + return body, headers + + async def send_message_stream(self, body, headers): + yield b"", "" + + async def send_message(self, body, headers): + return None + + def supports_request(self, caps): + return True, [] + + def map_model(self, model): + return model + + tiers = [ + VendorTier(vendor=FakeVendor("zhipu")), + VendorTier(vendor=FakeVendor("anthropic")), + VendorTier(vendor=FakeVendor("copilot")), + ] + + resolver = SessionPolicyResolver() + executor = _RouteExecutor( + router=None, + tiers=tiers, + usage_recorder=None, + session_manager=None, + session_policy_resolver=resolver, + ) + + # Before binding: default order + assert [t.name for t in executor._resolve_effective_tiers("test")] == [ + "zhipu", + "anthropic", + "copilot", + ] + + # After binding: anthropic first, copilot second, zhipu last + resolver.upsert("test", ["anthropic", "copilot"]) + assert [t.name for t in executor._resolve_effective_tiers("test")] == [ + "anthropic", + "copilot", + "zhipu", + ] + + # After unbind: back to default + resolver.remove("test") + assert [t.name for t in executor._resolve_effective_tiers("test")] == [ + "zhipu", + "anthropic", + "copilot", + ]