diff --git a/.gitignore b/.gitignore index 75fec81..475b250 100644 --- a/.gitignore +++ b/.gitignore @@ -23,5 +23,8 @@ config.yaml .claude/.prompts.md .python-version +# Playwright MCP +.playwright-mcp/ + # Log files (dual-write logging) coding-proxy.log* diff --git a/AGENTS.md b/AGENTS.md index 20c1143..30d9d7a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -52,6 +52,7 @@ 1. **Python**: 严禁使用 pip/poetry,**必须**统一使用 `uv` 进行包管理与脚本执行(如 `uv run`); 2. **JavaScript/TypeScript**: 严禁使用 npm/yarn,**必须**统一使用 `pnpm` 进行包管理与脚本执行。 - **Database Management**: 谨慎操作,数据迁移、测试等操作严禁将现有数据删除,谨慎操作数据迁移的回滚,防止数据被清理。 +- **In-depth and close to the facts**:系统且全面地进行问题的分析,深入贴近事实,如有疑问,需先发问,不要乱做决定。 ## Documentation Standards (文档规范) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b974ed..0fb0f1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,15 +4,27 @@ ## [Unreleased] -- fix(vendor-channels): 新增 zhipu 同 vendor 自清理通道,修复 GLM-5 自循环 400 + tool_results 偶发降级; -- fix(vendor-channels): 修复 `_rewrite_srvtoolu_ids` 块顺序敏感性导致 inline tool_result 漏改名,进而 enforce 阶段 dict key 与 tool_use_ids 错位、anthropic 报 `tool_use ids without tool_result blocks immediately after` 的 cascade failover 问题(改为两遍扫描:先收集 id_map,再统一改写所有 tool_result.tool_use_id 引用); -- fix(vendor-channels): `enforce_anthropic_tool_pairing` 增加全局 sanity check pass,主循环边角错位让 dangling tool_use 漏过校验时兜底合成 is_error 占位并打 `pairing_sanity_repaired` 标签,避免 anthropic 二次报错; +## [v0.4.0](https://github.com/ThreeFish-AI/coding-proxy/releases/tag/v0.4.0) — 2026-05-01 -### Bug Fixes +> [!IMPORTANT] +> +> **🚀 Session 级专属路由策略!** +> +> 给每个 Session 指定专属的 vendor,动态调节不同 vendors 间的 LLM 流量。 + +![session](assets/session-v0.4.0.png) + +### ✨ 核心亮点 + +- feat(session-policy): 新增 Session 级专属路由策略 (#219) +- feat(dashboard): 新增会话活动面板 (#222) + +### 🔧 更多特性 -- fix(vendor-channels): 新增 `anthropic → zhipu` 跨供应商转换通道,修复 Anthropic beta 功能(web search, computer use)产生的 `server_tool_use` 块导致 zhipu 400 错误的问题; -- fix(error-classifier): 增强语义拒绝检测,识别 zhipu 等供应商返回的中文错误消息(如「API 调用参数有误」code=1210),确保正确触发故障转移; -- 
fix(vendor-channels): `_remove_vendor_blocks` 增加空内容占位保护,防止内容块全部剥离后消息结构不合法。 +- refactor(logging): 移除已被 ModelCall 汇总行覆盖的冗余 DEBUG 日志 (#203) +- style(dashboard): 加宽图表 tooltip 令模型名称与用量值单行显示 (#211) +- fix(usage-parser): 补充 OpenAI/Gemini SSE 流式分支的 model_served 提取 (#214) +- fix(usage-parser): 兼容 SSE chunk 中 usage 字段为 null 的极端格式 (#212) ## [v0.3.0](https://github.com/ThreeFish-AI/coding-proxy/releases/tag/v0.3.0) — 2026-04-20 diff --git a/assets/session-v0.4.0.png b/assets/session-v0.4.0.png new file mode 100644 index 0000000..dd6ff36 Binary files /dev/null and b/assets/session-v0.4.0.png differ diff --git a/docs/issue.md b/docs/issue.md index 1fc6428..c8f9765 100644 --- a/docs/issue.md +++ b/docs/issue.md @@ -45,99 +45,3 @@ if "usage" in data: # 仅判断 key 存在 - 本仓库内 `parse_usage_from_chunk` 的 Gemini `usageMetadata` 分支 (line ~219) 已经使用 `isinstance(um, dict)` 防御, 不受影响, 可作为参考实现。 - 检查其他解析器 (如 routing / vendor adapter 层) 是否还有 `if "key" in data: v = data["key"]; v.get(...)` 这种模式, 必要时同步加固。 - ---- - -## zhipu 自循环 400 + tool_results 偶发降级 - -**问题描述** - -生产日志反复出现下述链路: 请求一开始命中 zhipu 主 tier, 但在含 `tool_results` 的多轮工具调用场景下偶发返回 400, 触发到 copilot 二级 tier。具体日志特征: - -``` -WARNING Tier zhipu likely format incompatibility (400 + tool_results), trying next tier without recording failure -WARNING Tier zhipu semantic rejection (400), trying next tier without recording failure -DEBUG Applied transition channel zhipu → copilot: rewritten_38_srvtoolu_ids, stripped_16_thinking_blocks, removed_3_cache_control_fields, misplaced_tool_result_relocated -``` - -zhipu → copilot 通道的 adaptations 列表暴露了上一轮 zhipu 响应中存在的非标准产物 (`srvtoolu_*` ID、自签 thinking、错位的 `tool_result`)。 - -**表因** - -zhipu 自身偶发返回 400, 但错误体非 JSON 结构, 由 `_is_likely_request_format_error()` 判定为「格式不兼容」并跳过当前 tier。 - -**根因** - -1. zhipu 是 `NativeAnthropicVendor` 薄透传供应商, **不做任何请求体预处理**。 -2. `executor._determine_source_vendor` 三条优先级路径均以 `source != target_name` 过滤掉了同 vendor 自转换。 -3. 
`VENDOR_TRANSITIONS` 注册表中无 `("zhipu", "zhipu")` 条目。 - -后果: GLM-5 偶发产出非标准产物 (assistant 内联 `tool_result`、`server_tool_use_delta` 流式残块) 后, Claude Code 把这些产物原样回送下一轮请求时, **没有任何清洗发生**, 直接被转发到 zhipu 自身, 命中 zhipu 端的输入校验返回 400。 - -**处理方式** - -- 在 `vendor_channels.py` 新增 `prepare_zhipu_self_cleanup` 函数, 仅修复 zhipu 自身拒绝的两类产物: - 1. 剥离 `server_tool_use_delta` 流式残块 - 2. `enforce_anthropic_tool_pairing` 把 assistant 内联 `tool_result` 重定位到紧随 user 消息 -- 显式 **保留** zhipu 原生支持的特性: `srvtoolu_*` ID、`server_tool_use` 类型、自签 thinking signature、`cache_control` (cache_read 已在生产实证)、顶层 `thinking` 参数。 -- 在 `VENDOR_TRANSITIONS` 注册 `("zhipu", "zhipu") = prepare_zhipu_self_cleanup`。 -- 在 `executor._determine_source_vendor` 三条优先级路径中, 把「`source != target`」过滤替换为「通道已注册」门控 (`get_transition_channel(...) is not None`), 让自转换通道在显式注册时启用, 未注册时退化为原行为。 - -**后续防范** - -- 新增 `NativeAnthropicVendor` 子类 (minimax / kimi / doubao / xiaomi / alibaba 等) 时, 若上游 vendor 偶发产出违反 Anthropic 规范的产物, 可按需注册同名自清理通道, executor 无需任何额外改动。 -- 同 vendor 自转换通道应**精确剪裁**: 仅修复 vendor 自身拒绝的产物, 不要套用跨 vendor 通道的全量清理 (会误伤 vendor 原生支持的特性, 如 cache_control 损失带来 cache_read miss)。 - -**同类问题影响与处理注意事项** - -- `enforce_anthropic_tool_pairing` 仅识别 `tool_use` 类型 (不含 `server_tool_use`), 因为 `server_tool_use` 由 vendor 自身执行, 不需要客户端 `tool_result`。构造测试或类似清洗逻辑时需注意此差别。 -- `_is_likely_request_format_error()` 把「400 + tool_results + 非结构化错误体」一律标记为格式不兼容并跳过 tier 不计熔断器, 这层兜底虽能维持可用性但会**掩盖** vendor 自身的间歇性问题, 让根因更难发现。处理类似 400 偶发时, 应优先看 `Applied transition channel` 日志中的 adaptations 列表, 它能精确暴露上游响应中的非标准产物。 - ---- - -## anthropic 报 messages.X tool_use 缺 tool_result (zhipu→anthropic 故障转移失败) - -**问题描述** - -zhipu 完成响应后, executor 故障转移至 anthropic 时反复失败 (HTTP 400): - -``` -DEBUG Applied transition channel zhipu → anthropic: rewritten_86_srvtoolu_ids, misplaced_tool_result_relocated, stripped_18_thinking_blocks -WARNING anthropic stream error: status=400 ... messages.3: `tool_use` ids were found without `tool_result` blocks immediately after: toolu_normalized_3. 
-INFO Failover: anthropic → zhipu (reason: HTTP 400) -``` - -adaptations 列表显示 `misplaced_tool_result_relocated` 但**没有** `orphaned_tool_use_repaired`, 即 enforce 单遍扫描视角下认为所有 tool_use 都已配对; 但 anthropic 仍报 messages.X 缺 tool_result, 导致请求级 cascade failover 反复回到 zhipu。 - -**表因** - -`prepare_zhipu_to_anthropic` 链路输出的请求体中, 某个 assistant 的 `tool_use` 在紧邻的 user 消息中没有匹配的 `tool_result` 块。 - -**根因** - -`_rewrite_srvtoolu_ids` 采用单遍正向扫描: 在同一次循环中一边收集 srvtoolu_* → toolu_normalized_* 的 id_map, 一边改写遇到的 `tool_result.tool_use_id`。GLM-5 流式偶发将 inline tool_result 输出在本消息 `server_tool_use` 之前 (block 顺序异常), 导致: - -1. 处理 inline tool_result 时, id_map 尚未填入对应 srvtoolu_* → 漏改名, inline 仍保留 `srvtoolu_X` -2. 处理本消息 server_tool_use 时, 填入 id_map 并把 tool_use 改名为 `toolu_normalized_X` -3. 进入 `enforce_anthropic_tool_pairing` 时: - - A 步 extracted dict key = `srvtoolu_X` (inline 保留的旧 ID) - - B 步 tool_use_ids = `[toolu_normalized_X]` (已改名) - - F 步 `uid in extracted` 检查失败 (key 错位), 但若 next user 已含其他 stale tool_result 让 existing_result_ids "巧合" 命中, F 步会跳过 synth → 不触发 orphan 标签 - - 最终 anthropic 看到 messages.X 真的缺 toolu_normalized_X 的 tool_result → 400 - -**处理方式** - -- `_rewrite_srvtoolu_ids` 改为**两遍扫描**: Pass 1 仅遍历 assistant 消息, 收集 id_map 并改写 tool_use 自身的 id 与 type; Pass 2 全量遍历所有消息 (含 user / 异常 assistant 内联), 统一改写所有 `tool_result.tool_use_id` 引用。彻底消除 block 顺序敏感性。 -- `enforce_anthropic_tool_pairing` 主循环结束后追加**全局 sanity check pass**: 重新遍历每条 assistant, 验证其 tool_use_ids 全部在 next user 的 tool_result 中存在; 发现遗漏直接合成 is_error 占位并打 `pairing_sanity_repaired` 标签。作为防御深度抵御未来其他主循环边角错位。 -- A 步对 `tool_use_id` 缺失的破损 inline tool_result 也计入 relocated_count (避免 silent drop 影响 adaptations 标签可观测性)。 - -**后续防范** - -- 任何"按出现顺序填充字典 + 同遍引用查询"的两阶段操作都应警惕**顺序耦合**问题。两遍扫描 (collect → apply) 是消除此类 bug 的标准 pattern。 -- 关键校验函数应有**主循环 + 全局 sanity check** 的双层结构, 单层校验在边角场景下容易被绕过。 -- 处理 anthropic `tool_use ids without tool_result blocks immediately after` 类 cascade failover 时, **adaptations 标签能否复现日志**是定位 root cause 的强信号: 若 enforce 视角与 anthropic 视角不一致 (有 
misplaced 但无 orphan, anthropic 仍报错), 必有上游 _rewrite / id 改写阶段的隐藏漏洞。 - -**同类问题影响与处理注意事项** - -- 任何对 messages 进行 ID 重写的转换链 (如 `_rewrite_srvtoolu_ids`、`anthropic_to_openai`、`anthropic_to_gemini`) 都应使用两遍扫描或一次性收集后再批量改写, 以保证 block 顺序无关性。 -- enforce 类校验函数若依赖 dict key 与 list 元素的**等同性**, 必须先确保两者在同一参考系下 (改名前 vs 改名后); 否则错位会以 "看起来 OK 实际有漏" 的方式静默泄漏到下游。 diff --git a/pyproject.toml b/pyproject.toml index 4cc7854..24630e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "coding-proxy" -version = "0.3.1a3" +version = "0.4.0" description = "A High-Availability, Transparent, and Smart Multi-Vendor Proxy for Claude Code. Support Claude Plans, GitHub Copilot, Google Antigravity, ZAI/GLM, MiniMax, Qwen, Xiaomi, Kimi, Doubao..." readme = "README.md" requires-python = ">=3.12" diff --git a/src/coding/proxy/cli/__init__.py b/src/coding/proxy/cli/__init__.py index 21cbe91..3b479fb 100644 --- a/src/coding/proxy/cli/__init__.py +++ b/src/coding/proxy/cli/__init__.py @@ -30,6 +30,10 @@ # 注册 Auth 子应用 app.add_typer(auth_app, name="auth") +# 注册 Session 子应用 +session_app = typer.Typer(name="session", help="管理 Session-Vendor 运行时绑定") +app.add_typer(session_app, name="session") + def _build_token_store(cfg_path: Path | None = None): """按配置解析 Token Store 路径并完成加载.""" @@ -264,6 +268,97 @@ def reset( console.print("[red]代理服务未运行[/red]") +# ── Session 子命令 ─────────────────────────────────────────────── + + +@session_app.command("bind") +def session_bind( + key: str = typer.Option(..., "--key", "-k", help="Session key"), + vendor: str = typer.Option( + ..., "--vendor", "-v", help="绑定 vendor(逗号分隔多个)" + ), + port: int = typer.Option(3392, "--port", "-p", help="代理服务端口"), +) -> None: + """为指定 Session 绑定 vendor 优先级.""" + import httpx as _httpx + + vendors = [v.strip() for v in vendor.split(",") if v.strip()] + try: + resp = _httpx.put( + f"http://127.0.0.1:{port}/api/session-vendor", + json={"session_key": key, "vendors": vendors}, + timeout=5, + ) + if resp.status_code == 200: + 
data = resp.json() + console.print( + f"[green]绑定成功:[/] session [cyan]{key[:16]}…[/cyan] → " + + " → ".join(data.get("vendors", vendors)) + ) + else: + try: + err = resp.json() + msg = err.get("error", {}).get("message", resp.text) + except Exception: + msg = resp.text + console.print(f"[red]绑定失败: {msg}[/red]") + except _httpx.ConnectError: + console.print("[red]代理服务未运行[/red]") + + +@session_app.command("unbind") +def session_unbind( + key: str = typer.Option(..., "--key", "-k", help="Session key"), + port: int = typer.Option(3392, "--port", "-p", help="代理服务端口"), +) -> None: + """解除指定 Session 的 vendor 绑定.""" + from urllib.parse import quote + + import httpx as _httpx + + try: + resp = _httpx.delete( + f"http://127.0.0.1:{port}/api/session-vendor/{quote(key, safe='')}", + timeout=5, + ) + if resp.status_code == 200: + console.print(f"[green]已解除绑定:[/] session [cyan]{key[:16]}…[/cyan]") + elif resp.status_code == 404: + console.print(f"[yellow]未找到绑定:[/] session [cyan]{key[:16]}…[/cyan]") + else: + console.print(f"[red]解除失败: {resp.status_code} {resp.text}[/red]") + except _httpx.ConnectError: + console.print("[red]代理服务未运行[/red]") + + +@session_app.command("list") +def session_list( + port: int = typer.Option(3392, "--port", "-p", help="代理服务端口"), +) -> None: + """列出所有运行时 Session-Vendor 绑定.""" + import httpx as _httpx + + try: + resp = _httpx.get( + f"http://127.0.0.1:{port}/api/session-vendor", + timeout=5, + ) + if resp.status_code == 200: + data = resp.json() + bindings = data.get("bindings", []) + if not bindings: + console.print("[dim]当前无运行时绑定[/dim]") + return + for b in bindings: + key = b.get("session_key", "?") + vendors = b.get("vendors", []) + console.print(f" [cyan]{key[:24]}…[/cyan] → " + " → ".join(vendors)) + else: + console.print(f"[red]查询失败: {resp.status_code} {resp.text}[/red]") + except _httpx.ConnectError: + console.print("[red]代理服务未运行[/red]") + + def _resolve_config_path(config: str | Path | None = None) -> Path | None: """标准化配置路径输入.""" if config is 
None: diff --git a/src/coding/proxy/config/config.default.yaml b/src/coding/proxy/config/config.default.yaml index f511e8b..40808fd 100644 --- a/src/coding/proxy/config/config.default.yaml +++ b/src/coding/proxy/config/config.default.yaml @@ -644,3 +644,31 @@ native_api: base_url: "https://api.anthropic.com" timeout_ms: 300000 connect_timeout_ms: 15000 + +# === Session 级别路由策略(可选)=== +# +# 为特定 Session 或客户端类别定制 vendor 优先级顺序。 +# 匹配策略按定义顺序求值,首次匹配生效。 +# +# 支持的匹配条件(OR 语义,满足任一即匹配): +# session_keys: 精确匹配的 session key 列表 +# client_category: 按客户端类别匹配(⚠️ 预留字段,当前版本暂未生效,后续版本支持) +# +# tiers: 覆盖全局 tier 顺序的供应商优先级列表(未提及的 vendor 保持在末尾) +# +# 示例 1:为特定 session 绑定专属 vendor 组合 +# session_policies: +# - name: "vip-session" +# match: +# session_keys: ["my-specific-session-id"] +# tiers: ["anthropic", "copilot", "zhipu"] +# +# 示例 2:Claude Code 请求优先走 Copilot +# session_policies: +# - name: "claude-code-preferred" +# match: +# client_category: "cc" +# tiers: ["copilot", "anthropic", "zhipu"] +# +# 未配置时(默认),所有 Session 使用全局 tiers 顺序。 +session_policies: [] diff --git a/src/coding/proxy/config/schema.py b/src/coding/proxy/config/schema.py index 5441979..ee21ee7 100644 --- a/src/coding/proxy/config/schema.py +++ b/src/coding/proxy/config/schema.py @@ -44,6 +44,7 @@ # ── 子模块 re-export ──────────────────────────────────────────── from .server import DatabaseConfig, LoggingConfig, ServerConfig # noqa: F401 +from .session_policy import SessionPoliciesConfig # noqa: F401 from .vendors import ( # noqa: F401 AlibabaConfig, AnthropicConfig, @@ -152,11 +153,19 @@ class ProxyConfig(BaseModel): "三个 provider 默认 enabled=False,显式启用才暴露 /api/{provider}/* 端点。" ), ) + # Session 级别路由策略 + session_policies: SessionPoliciesConfig = Field( + default_factory=SessionPoliciesConfig, + description=( + "Session 级别的路由策略配置。" + "可为特定 Session 或客户端类别定制 vendor 优先级顺序。" + ), + ) @model_validator(mode="before") @classmethod def _migrate_legacy_fields(cls, data: Any) -> Any: - """向后兼容迁移(legacy flat 格式 → vendors 列表格式). 
+ """向后兼容迁移(legacy flat 格式 → vendors 列表格式)+ session_policies 规范化. 迁移规则: 1. ``anthropic`` / ``zhipu`` 字段名自动映射为 ``primary`` / ``fallback`` @@ -165,6 +174,12 @@ def _migrate_legacy_fields(cls, data: Any) -> Any: if not isinstance(data, dict): return data + # session_policies 规范化:YAML 中 session_policies: [] 解析为 list, + # 需转为 dict 包装以匹配 SessionPoliciesConfig 模型 + sp = data.get("session_policies") + if isinstance(sp, list): + data["session_policies"] = {"policies": sp} + # 1. 字段别名迁移 if "anthropic" in data and "primary" not in data: data["primary"] = data.pop("anthropic") @@ -331,4 +346,6 @@ def compat_state_path(self) -> Path: "AlibabaConfig", # native api passthrough "NativeApiConfig", + # session policy + "SessionPoliciesConfig", ] diff --git a/src/coding/proxy/config/session_policy.py b/src/coding/proxy/config/session_policy.py new file mode 100644 index 0000000..cb2c512 --- /dev/null +++ b/src/coding/proxy/config/session_policy.py @@ -0,0 +1,59 @@ +"""Session Policy 配置模型 — 为特定 Session 定制路由行为.""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class SessionPolicyMatch(BaseModel): + """策略匹配条件 — 满足任一条件即匹配(OR 语义).""" + + session_keys: list[str] = Field( + default_factory=list, + description="精确匹配的 session key 列表", + ) + client_category: str | None = Field( + default=None, + description=( + "按客户端类别匹配('cc' 或 'api')。" + "⚠️ 预留字段,当前路由执行链路未传入 client_category," + "配置此条件不会生效。后续版本将支持。" + ), + ) + + +class SessionQuotaConfig(BaseModel): + """Per-session 资源配额(架构预留).""" + + token_budget: int = Field( + default=0, + description="时间窗口内的 token 预算上限", + ) + window_hours: float = Field( + default=24.0, + description="滚动时间窗口(小时)", + ) + + +class SessionPolicy(BaseModel): + """单条 Session 路由策略.""" + + name: str = Field(description="策略名称(用于日志与排障)") + match: SessionPolicyMatch = Field(description="匹配条件") + tiers: list[str] = Field( + default_factory=list, + description="覆盖全局 tier 顺序的供应商优先级列表", + ) + quota: SessionQuotaConfig | None = Field( + default=None, 
+ description="Per-session 资源配额(预留)", + ) + + +class SessionPoliciesConfig(BaseModel): + """顶层 Session 策略配置容器.""" + + policies: list[SessionPolicy] = Field( + default_factory=list, + description="Session 路由策略列表,按定义顺序求值,首次匹配生效", + ) diff --git a/src/coding/proxy/convert/anthropic_to_openai.py b/src/coding/proxy/convert/anthropic_to_openai.py index 82baeba..81bd163 100644 --- a/src/coding/proxy/convert/anthropic_to_openai.py +++ b/src/coding/proxy/convert/anthropic_to_openai.py @@ -284,24 +284,34 @@ def _translate_assistant_message(message: dict[str, Any]) -> list[dict[str, Any] final_text_parts = text_parts if tool_uses: + tool_calls: list[dict[str, Any]] = [] + for block in tool_uses: + raw_input = block.get("input") + if not isinstance(raw_input, dict): + logger.debug( + "copilot: tool_use id=%s name=%s has non-dict input (type=%s), " + "defaulting to empty dict", + block.get("id", ""), + block.get("name", ""), + type(raw_input).__name__, + ) + raw_input = {} + tool_calls.append( + { + "id": block.get("id", ""), + "type": "function", + "function": { + "name": block.get("name", ""), + "arguments": json.dumps(raw_input, ensure_ascii=False), + }, + } + ) return [ { "role": "assistant", "content": "\n\n".join(part for part in final_text_parts if part) or None, - "tool_calls": [ - { - "id": block.get("id", ""), - "type": "function", - "function": { - "name": block.get("name", ""), - "arguments": json.dumps( - block.get("input", {}), ensure_ascii=False - ), - }, - } - for block in tool_uses - ], + "tool_calls": tool_calls, } ] diff --git a/src/coding/proxy/convert/vendor_channels.py b/src/coding/proxy/convert/vendor_channels.py index 0b1d511..bec46f7 100644 --- a/src/coding/proxy/convert/vendor_channels.py +++ b/src/coding/proxy/convert/vendor_channels.py @@ -10,8 +10,6 @@ zhipu → anthropic : prepare_zhipu_to_anthropic (剥离 thinking + tool pairing) zhipu → copilot : prepare_zhipu_to_copilot (剥离 thinking + cache_control + tool pairing) copilot → zhipu : 
prepare_copilot_to_zhipu (剥离 thinking + cache_control + 移除 thinking 参数 + tool pairing) - zhipu → zhipu : prepare_zhipu_self_cleanup (剥离 server_tool_use_delta + tool pairing) - anthropic → zhipu : prepare_anthropic_to_zhipu (剥离 server_tool_use + thinking + cache_control + 移除 thinking 参数 + tool pairing) """ from __future__ import annotations @@ -110,10 +108,6 @@ def enforce_anthropic_tool_pairing( 此函数是一个**自包含的单遍处理**,不依赖 Phase 1 收集的 misplaced 信息。 - 最终在主循环之后执行一次幂等的全局 sanity check pass, 防御主循环的边角 - 错位 (如 inline tool_result 引用未在本消息出现的 tool_use_id, 导致 extracted - 字典 key 与 tool_use_ids 集合错位) 让 dangling tool_use 漏过校验。 - Args: messages_list: 消息列表(就地修改)。 @@ -145,13 +139,10 @@ def enforce_anthropic_tool_pairing( if tid: extracted_tool_results[tid] = block relocated_count += 1 - else: - # 缺 tool_use_id 的破损 tool_result 也视作错位剥离 - relocated_count += 1 else: retained_content.append(block) - if extracted_tool_results or len(retained_content) != len(content): + if extracted_tool_results: msg["content"] = retained_content # B. 收集所有 tool_use ID @@ -216,17 +207,10 @@ def enforce_anthropic_tool_pairing( i += 1 - # G. 最终全局 sanity check pass (抽出为独立函数便于单测验证正向兜底路径). - sanity_synthesized = _enforce_pairing_sanity_pass(messages_list) - if relocated_count: adaptations.append("misplaced_tool_result_relocated") - if synthesized_ids or sanity_synthesized: - adaptations.append("orphaned_tool_use_repaired") - - # 主循环 F 段与 sanity G 段分别打日志, 避免 main=0/sanity=N 时把 sanity - # 兜底误归因为主循环工作 (运维在线日志聚合时易混淆 cross-pass id-map drift). if synthesized_ids: + adaptations.append("orphaned_tool_use_repaired") logger.warning( "Vendor degradation adaptation: synthesized %d tool_result block(s) " "for orphaned tool_use to satisfy Anthropic pairing constraint. 
" @@ -234,94 +218,10 @@ def enforce_anthropic_tool_pairing( len(synthesized_ids), ", ".join(synthesized_ids), ) - if sanity_synthesized: - adaptations.append("pairing_sanity_repaired") - logger.warning( - "Pairing sanity check repaired %d dangling tool_use(s) missed by " - "main pass (likely cross-pass id-map drift). Affected tool_use_ids: %s", - len(sanity_synthesized), - ", ".join(sanity_synthesized), - ) return adaptations -def _enforce_pairing_sanity_pass(messages_list: list[Any]) -> list[str]: - """全局 sanity check pass: 防御主循环边角错位让 dangling tool_use 漏过. - - 例如: extracted dict key 与 _rewrite 后的 tool_use_ids 错位、user_msg - 中已有 stale tool_result 让 F 步误判 existing 命中等场景。 - - 扫描所有 assistant 消息, 验证每个 ``tool_use`` block ID 在紧随的 user 消息 - 中均存在对应 ``tool_result``; 漏掉的合成 ``is_error`` 占位。 - - 抽取为独立函数的目的: 主循环 F 步在当前实现下能覆盖所有 dangling tool_use, - 导致 sanity 实际兜底分支在公开 API 测试中无法被触发; 独立函数便于直接 - 构造「绕过主循环」的输入, 对兜底合成路径建立正向回归保护。 - - Args: - messages_list: 消息列表 (就地修改, 必要时插入空 user 消息). - - Returns: - sanity 兜底合成的 tool_use_id 列表 (空表示主循环已完成所有配对). 
- """ - sanity_synthesized: list[str] = [] - j = 0 - while j < len(messages_list): - msg_j = messages_list[j] - if not isinstance(msg_j, dict) or msg_j.get("role") != "assistant": - j += 1 - continue - content_j = msg_j.get("content") - if not isinstance(content_j, list): - j += 1 - continue - tu_ids = [ - b["id"] - for b in content_j - if isinstance(b, dict) and b.get("type") == "tool_use" and b.get("id") - ] - if not tu_ids: - j += 1 - continue - next_j = j + 1 - if ( - next_j < len(messages_list) - and isinstance(messages_list[next_j], dict) - and messages_list[next_j].get("role") == "user" - ): - next_user = messages_list[next_j] - else: - next_user = {"role": "user", "content": []} - messages_list.insert(next_j, next_user) - nu_content = next_user.get("content") - if isinstance(nu_content, str): - next_user["content"] = [{"type": "text", "text": nu_content}] - elif not isinstance(nu_content, list): - next_user["content"] = [] - nu_result_ids = { - b["tool_use_id"] - for b in next_user["content"] - if isinstance(b, dict) - and b.get("type") == "tool_result" - and b.get("tool_use_id") - } - for uid in tu_ids: - if uid in nu_result_ids: - continue - next_user["content"].append( - { - "type": "tool_result", - "tool_use_id": uid, - "content": "", - "is_error": True, - } - ) - sanity_synthesized.append(uid) - j += 1 - return sanity_synthesized - - def _strip_cache_control(body: dict[str, Any]) -> int: """从 system/messages/tools 中移除 cache_control 字段(就地). 
@@ -384,13 +284,7 @@ def _remove_vendor_blocks(body: dict[str, Any], block_types: set[str]) -> int: removed += 1 continue new_content.append(block) - if content != new_content: - if not new_content: - new_content = [{"type": "text", "text": "[vendor_block_removed]"}] - logger.info( - "Inserted placeholder text block after stripping " - "vendor blocks to avoid empty message content", - ) + if removed: message["content"] = new_content return removed @@ -400,12 +294,8 @@ def _rewrite_srvtoolu_ids(body: dict[str, Any]) -> tuple[int, dict[str, str]]: Anthropic API 要求 tool_use 类型与 ``toolu_*`` 格式的 ID。Zhipu 的 ``server_tool_use`` + ``srvtoolu_*`` 在上游 Anthropic 兼容端点可用,但无法 - 透传至其他供应商;同时还需重写所有 ``tool_result.tool_use_id`` 引用, - 保持配对关系。 - - 采用**两遍扫描**避免块顺序敏感性: GLM-5 偶发将 inline tool_result 输出在 - 本消息 tool_use 之前, 单遍扫描会因 id_map 尚未填入而漏改 inline tool_result - 的 tool_use_id, 导致后续 enforce 步骤无法将其与 tool_use 配对。 + 透传至其他供应商;同时还需重写紧随其后 user 消息中 ``tool_result.tool_use_id`` + 引用,保持配对关系。 Returns: (rewritten_count, id_map) — 重写次数与 {原 ID: 新 ID} 映射。 @@ -418,59 +308,45 @@ def next_id() -> str: counter += 1 return f"toolu_normalized_{counter}" - # Pass 1: 收集所有 assistant tool_use / server_tool_use 的 ID 映射 - # 不修改 tool_result, 仅建立 id_map; 同时改写 tool_use 自身的 id 与 type for message in body.get("messages", []): if not isinstance(message, dict): continue content = message.get("content") if not isinstance(content, list): continue - if message.get("role") != "assistant": - continue + role = message.get("role") for block in content: if not isinstance(block, dict): continue block_type = block.get("type") block_id = block.get("id") - if block_type not in {"tool_use", "server_tool_use"}: - continue - if isinstance(block_id, str) and _ANTHROPIC_SERVER_TOOL_USE_ID_RE.match( - block_id - ): - new_id = next_id() - id_map[block_id] = new_id - block["id"] = new_id - block["type"] = "tool_use" - elif ( - isinstance(block_id, str) - and block_id - and not _ANTHROPIC_TOOL_USE_ID_RE.match(block_id) - and block.get("name") 
- ): - # 非标准 ID(非 toolu_ / srvtoolu_),且具备 name 可改写 - new_id = next_id() - id_map[block_id] = new_id - block["id"] = new_id - block["type"] = "tool_use" - elif block_type == "server_tool_use" and isinstance(block_id, str): - # 兜底: 类型是 server_tool_use 但 ID 已是标准 toolu_ 形式,仅纠正类型 - block["type"] = "tool_use" - - # Pass 2: 全量同步所有 tool_result.tool_use_id 引用 (含 user/assistant 内联) - if id_map: - for message in body.get("messages", []): - if not isinstance(message, dict): - continue - content = message.get("content") - if not isinstance(content, list): - continue - for block in content: - if not isinstance(block, dict): - continue - if block.get("type") != "tool_result": - continue + # Case A: assistant 消息里的 server_tool_use / srvtoolu_* → 改写 + if role == "assistant" and block_type in {"tool_use", "server_tool_use"}: + if isinstance(block_id, str) and _ANTHROPIC_SERVER_TOOL_USE_ID_RE.match( + block_id + ): + new_id = next_id() + id_map[block_id] = new_id + block["id"] = new_id + block["type"] = "tool_use" + elif ( + isinstance(block_id, str) + and block_id + and not _ANTHROPIC_TOOL_USE_ID_RE.match(block_id) + and block.get("name") + ): + # 非标准 ID(非 toolu_ / srvtoolu_),且具备 name 可改写 + new_id = next_id() + id_map[block_id] = new_id + block["id"] = new_id + block["type"] = "tool_use" + elif block_type == "server_tool_use" and isinstance(block_id, str): + # 兜底: 类型是 server_tool_use 但 ID 已是标准 toolu_ 形式,仅纠正类型 + block["type"] = "tool_use" + + # Case B: user 消息里的 tool_result.tool_use_id 同步重写 + if block_type == "tool_result": tool_use_id = block.get("tool_use_id") if isinstance(tool_use_id, str) and tool_use_id in id_map: block["tool_use_id"] = id_map[tool_use_id] @@ -482,9 +358,8 @@ def infer_source_vendor_from_body(body: dict[str, Any]) -> str | None: """从请求 body 内容推断源供应商(仅在无会话上下文时作为兜底). 
启发式(按置信度排序): - - 出现 ``srvtoolu_*`` 格式的 ID → zhipu - - 出现 ``server_tool_use_delta`` 类型的 content block → zhipu - - 出现 ``server_tool_use`` 块 + ``toolu_*`` ID → anthropic(beta 功能产物) + - 出现 ``srvtoolu_*`` 格式的 ``tool_use.id`` → zhipu + - 出现 ``server_tool_use`` / ``server_tool_use_delta`` 类型的 content block → zhipu 原则: 只读扫描不修改 body;无匹配返回 None(视作纯净无需跨供应商清洗)。 @@ -492,7 +367,7 @@ def infer_source_vendor_from_body(body: dict[str, Any]) -> str | None: body: Anthropic Messages 请求体。 Returns: - 推断的源供应商名称(``"zhipu"`` 或 ``"anthropic"``),无法推断返回 None。 + 推断的源供应商名称(当前仅支持 ``"zhipu"``),无法推断返回 None。 """ for message in body.get("messages", []): if not isinstance(message, dict): @@ -504,35 +379,18 @@ def infer_source_vendor_from_body(body: dict[str, Any]) -> str | None: if not isinstance(block, dict): continue block_type = block.get("type") - block_id = block.get("id") - tool_use_id = block.get("tool_use_id") - - # Zhipu: server_tool_use_delta 是 zhipu 私有流式块(无歧义) - if block_type == "server_tool_use_delta": + if block_type in _ZHIPU_SERVER_TOOL_USE_TYPES: return "zhipu" - - # srvtoolu_* ID(无论 block type)→ zhipu + block_id = block.get("id") if isinstance(block_id, str) and _ANTHROPIC_SERVER_TOOL_USE_ID_RE.match( block_id ): return "zhipu" + tool_use_id = block.get("tool_use_id") if isinstance(tool_use_id, str) and _ANTHROPIC_SERVER_TOOL_USE_ID_RE.match( tool_use_id ): return "zhipu" - - # server_tool_use 块 + toolu_* ID → Anthropic beta 功能 - if ( - block_type == "server_tool_use" - and isinstance(block_id, str) - and _ANTHROPIC_TOOL_USE_ID_RE.match(block_id) - ): - return "anthropic" - - # server_tool_use 块 + 非 toolu_/srvtoolu_ ID → 按类型兜底归 zhipu - if block_type == "server_tool_use": - return "zhipu" - return None @@ -580,61 +438,6 @@ def prepare_copilot_to_zhipu( return prepared, adaptations -# ── anthropic → zhipu 转换通道 ──────────────────────────────────── - -# Anthropic beta 特有的 server_tool_use 块类型(web search, computer use 等). 
-# 这些块在 Anthropic API 中有效,但 zhipu GLM-5 的兼容端点不支持。 -# 注意: 这与 zhipu 自己的 server_tool_use(使用 srvtoolu_* ID)是不同的概念, -# 但它们共用同一个 type 名称 "server_tool_use"。 -_ANTHROPIC_BETA_BLOCK_TYPES = {"server_tool_use"} - - -def prepare_anthropic_to_zhipu( - body: dict[str, Any], -) -> tuple[dict[str, Any], list[str]]: - """anthropic → zhipu 转换: 清理 anthropic 产物以适配 GLM-5. - - Anthropic API 可能产生的非兼容产物: - - ``server_tool_use`` blocks(web search / computer use 等 beta 功能) - - ``thinking`` / ``redacted_thinking`` blocks(含 Anthropic 签发的 signature) - - ``cache_control`` 字段 - - 顶层 ``thinking`` / ``extended_thinking`` 参数 - - Returns: - (prepared_body, adaptations) — adaptations 为应用的变换描述列表。 - """ - prepared = copy.deepcopy(body) - adaptations: list[str] = [] - - # Step 1: 剥离 anthropic 的 server_tool_use blocks(web search, computer use 等) - removed_stu = _remove_vendor_blocks(prepared, _ANTHROPIC_BETA_BLOCK_TYPES) - if removed_stu: - adaptations.append(f"removed_{removed_stu}_server_tool_use_blocks") - - # Step 2: 剥离 thinking/redacted_thinking blocks - stripped = strip_thinking_blocks(prepared) - if stripped: - adaptations.append(f"stripped_{stripped}_thinking_blocks") - - # Step 3: 移除 cache_control 字段 - removed_cc = _strip_cache_control(prepared) - if removed_cc: - adaptations.append(f"removed_{removed_cc}_cache_control_fields") - - # Step 4: 移除顶层 thinking/extended_thinking 参数(GLM-5 不支持) - for param in ("thinking", "extended_thinking"): - if param in prepared: - del prepared[param] - adaptations.append(f"removed_{param}_param") - - # Step 5: 强制 tool_use/tool_result 配对 - pairing_fixes = enforce_anthropic_tool_pairing(prepared.get("messages", [])) - if pairing_fixes: - adaptations.extend(pairing_fixes) - - return prepared, adaptations - - # ── zhipu → copilot 转换通道 ───────────────────────────────────── @@ -736,54 +539,8 @@ def prepare_zhipu_to_anthropic( return prepared, adaptations -# ── zhipu → zhipu 自清理通道 ────────────────────────────────────── - - -def prepare_zhipu_self_cleanup( - body: 
dict[str, Any], -) -> tuple[dict[str, Any], list[str]]: - """zhipu → zhipu 自清理: 仅修复 zhipu 自身无法消化的产物. - - GLM-5 偶发地在 assistant 消息中输出 ``tool_result`` 块(违反 Anthropic 规范), - 或在流式响应中暴露 ``server_tool_use_delta`` 私有块。当 Claude Code 将这些 - 产物原样回送下一轮请求时,zhipu 的 Anthropic 兼容端点会以 400 拒绝 - (表现为 "400 + tool_results" 偶发,进而触发到 copilot 的降级)。 - - 本通道仅修复 zhipu 自身拒绝的两类产物,**保留** 所有 zhipu 原生支持的特性: - - - ✓ ``srvtoolu_*`` ID 与 ``server_tool_use`` 类型(zhipu 原生) - - ✓ thinking blocks 的 zhipu 自签 signature - - ✓ ``cache_control`` 字段(GLM Anthropic 端点支持,cache_read 已实证) - - ✓ 顶层 ``thinking`` / ``extended_thinking`` 参数 - - 清理操作(顺序、就地、幂等): - 1. 剥离 ``server_tool_use_delta`` 流式残块 - 2. 强制 tool_use/tool_result 配对(关键: 把 assistant 内联的 tool_result - 搬迁到紧随的 user 消息) - - Returns: - (prepared_body, adaptations) — adaptations 为应用的变换描述列表。 - """ - prepared = copy.deepcopy(body) - adaptations: list[str] = [] - - # Step 1: 剥离 zhipu 私有流式块类型(input 中不应出现) - removed_vendor_blocks = _remove_vendor_blocks(prepared, _ZHIPU_VENDOR_BLOCK_TYPES) - if removed_vendor_blocks: - adaptations.append(f"removed_{removed_vendor_blocks}_zhipu_vendor_blocks") - - # Step 2: 强制 tool_use/tool_result 配对 - pairing_fixes = enforce_anthropic_tool_pairing(prepared.get("messages", [])) - if pairing_fixes: - adaptations.extend(pairing_fixes) - - return prepared, adaptations - - # ── 注册所有转换通道 ────────────────────────────────────────────── VENDOR_TRANSITIONS[("zhipu", "anthropic")] = prepare_zhipu_to_anthropic VENDOR_TRANSITIONS[("zhipu", "copilot")] = prepare_zhipu_to_copilot VENDOR_TRANSITIONS[("copilot", "zhipu")] = prepare_copilot_to_zhipu -VENDOR_TRANSITIONS[("zhipu", "zhipu")] = prepare_zhipu_self_cleanup -VENDOR_TRANSITIONS[("anthropic", "zhipu")] = prepare_anthropic_to_zhipu diff --git a/src/coding/proxy/logging/db.py b/src/coding/proxy/logging/db.py index 3c52e66..ffe9b2c 100644 --- a/src/coding/proxy/logging/db.py +++ b/src/coding/proxy/logging/db.py @@ -170,7 +170,8 @@ def _local_month_udf(ts_str: str) -> str: client_category TEXT NOT 
NULL DEFAULT 'cc', operation TEXT NOT NULL DEFAULT '', endpoint TEXT NOT NULL DEFAULT '', - extra_usage_json TEXT NOT NULL DEFAULT '{}' + extra_usage_json TEXT NOT NULL DEFAULT '{}', + session_key TEXT NOT NULL DEFAULT '' ); CREATE TABLE IF NOT EXISTS usage_evidence ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -194,6 +195,7 @@ def _local_month_udf(ts_str: str) -> str: CREATE INDEX IF NOT EXISTS idx_usage_vendor ON usage_log(vendor); CREATE INDEX IF NOT EXISTS idx_usage_client_category ON usage_log(client_category); CREATE INDEX IF NOT EXISTS idx_usage_operation ON usage_log(operation); +CREATE INDEX IF NOT EXISTS idx_usage_session_key ON usage_log(session_key); CREATE INDEX IF NOT EXISTS idx_usage_evidence_request_id ON usage_evidence(request_id); CREATE INDEX IF NOT EXISTS idx_usage_evidence_vendor ON usage_evidence(vendor); """ @@ -247,6 +249,7 @@ async def init(self) -> None: await self._migrate_rename_backend_to_vendor() await self._migrate_add_failover_from() await self._migrate_add_native_columns() + await self._migrate_add_session_key() await self._db.executescript(_CREATE_INDEXES) # 注册时区感知的日期函数:将 UTC 时间戳转为本地时间维度 await self._db.create_function("local_date", 1, _local_date_udf) @@ -286,6 +289,18 @@ async def _migrate_add_native_columns(self) -> None: await self._db.execute(f"ALTER TABLE usage_log ADD COLUMN {name} {ddl}") logger.info("Migration: added %s column to usage_log", name) + async def _migrate_add_session_key(self) -> None: + """幂等迁移:为已有数据库添加 session_key 列.""" + if not self._db: + return + cursor = await self._db.execute("PRAGMA table_info(usage_log)") + columns = {row["name"] for row in await cursor.fetchall()} + if "session_key" not in columns: + await self._db.execute( + "ALTER TABLE usage_log ADD COLUMN session_key TEXT NOT NULL DEFAULT ''" + ) + logger.info("Migration: added session_key column to usage_log") + async def _migrate_rename_backend_to_vendor(self) -> None: """幂等迁移:重命名 backend 列为 vendor.""" if not self._db: @@ -319,6 +334,7 @@ async def 
log( operation: str = "", endpoint: str = "", extra_usage_json: str = "{}", + session_key: str = "", ) -> None: if not self._db: return @@ -328,8 +344,8 @@ async def log( input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, duration_ms, success, failover, failover_from, request_id, - client_category, operation, endpoint, extra_usage_json) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + client_category, operation, endpoint, extra_usage_json, session_key) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( vendor, model_requested, @@ -347,6 +363,7 @@ async def log( operation, endpoint, extra_usage_json, + session_key, ), ) await self._db.commit() @@ -573,6 +590,63 @@ async def query_window_total( row = await cursor.fetchone() return row["total"] if row else 0 + async def query_recent_sessions( + self, + limit: int = 20, + hours: float = 24.0, + ) -> list[dict]: + """按 session_key 聚合近期活跃会话统计.""" + if not self._db: + return [] + cutoff_iso = _hours_ago_utc_iso(hours) + cursor = await self._db.execute( + """SELECT session_key, + MIN(ts) AS first_seen_ts, + MAX(ts) AS last_active_ts, + COUNT(*) AS total_requests, + SUM(input_tokens + output_tokens) AS total_tokens, + SUM(input_tokens) AS total_input, + SUM(output_tokens) AS total_output, + GROUP_CONCAT(DISTINCT model_served) AS models, + GROUP_CONCAT(DISTINCT vendor) AS vendors, + AVG(duration_ms) AS avg_duration_ms, + SUM(CASE WHEN success THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS success_rate, + GROUP_CONCAT(DISTINCT client_category) AS client_categories + FROM usage_log + WHERE session_key != '' AND ts >= ? 
+ GROUP BY session_key + ORDER BY last_active_ts DESC + LIMIT ?""", + (cutoff_iso, limit), + ) + rows = await cursor.fetchall() + return [dict(row) for row in rows] + + async def query_session_profile(self, session_key: str) -> dict | None: + """查询单个会话的完整聚合数据.""" + if not self._db: + return None + cursor = await self._db.execute( + """SELECT session_key, + MIN(ts) AS first_seen_ts, + MAX(ts) AS last_active_ts, + COUNT(*) AS total_requests, + SUM(input_tokens + output_tokens) AS total_tokens, + SUM(input_tokens) AS total_input, + SUM(output_tokens) AS total_output, + GROUP_CONCAT(DISTINCT model_served) AS models, + GROUP_CONCAT(DISTINCT vendor) AS vendors, + AVG(duration_ms) AS avg_duration_ms, + SUM(CASE WHEN success THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS success_rate, + GROUP_CONCAT(DISTINCT client_category) AS client_categories + FROM usage_log + WHERE session_key = ? + GROUP BY session_key""", + (session_key,), + ) + row = await cursor.fetchone() + return dict(row) if row else None + async def close(self) -> None: if self._db: await self._db.close() diff --git a/src/coding/proxy/routing/executor.py b/src/coding/proxy/routing/executor.py index 2a27665..9d33ca9 100644 --- a/src/coding/proxy/routing/executor.py +++ b/src/coding/proxy/routing/executor.py @@ -31,6 +31,7 @@ parse_rate_limit_headers, ) from .session_manager import RouteSessionManager +from .session_policy import SessionPolicyResolver from .tier import VendorTier from .usage_parser import ( build_usage_evidence_records, @@ -135,6 +136,10 @@ def _is_likely_request_format_error( # 非结构化响应体(非 JSON) if not trimmed.startswith("{") and len(trimmed) < 200: return True + # 结构化 JSON 400 但含 tool_call 格式错误码 → 格式不兼容 + # (如 Copilot 返回 {"error":{"code":"invalid_tool_call_format",...}}) + if "invalid_tool_call_format" in trimmed: + return True return False @@ -207,12 +212,14 @@ def __init__( usage_recorder: UsageRecorder, session_manager: RouteSessionManager, reauth_coordinator: Any | None = None, + 
session_policy_resolver: SessionPolicyResolver | None = None, ) -> None: self._router = router self._tiers = tiers self._recorder = usage_recorder self._session_mgr = session_manager self._reauth_coordinator = reauth_coordinator + self._policy_resolver = session_policy_resolver or SessionPolicyResolver() # Tier 名称 → OAuth provider 名称的映射 self._tier_provider_map: dict[str, str] = { @@ -222,6 +229,30 @@ def __init__( # ── 公开执行入口 ────────────────────────────────────── + def _resolve_effective_tiers(self, session_key: str) -> list[VendorTier]: + """根据 Session Policy 解析生效的 tier 顺序. + + 策略指定的 vendor 按其顺序排列在头部,未提及的保持在末尾。 + 无策略时返回全局默认顺序。 + """ + policy = self._policy_resolver.resolve(session_key) + if not policy or not policy.tiers: + return self._tiers + + name_to_tier = {t.name: t for t in self._tiers} + ordered: list[VendorTier] = [] + seen: set[str] = set() + for name in policy.tiers: + tier = name_to_tier.get(name) + if tier and name not in seen: + ordered.append(tier) + seen.add(name) + for tier in self._tiers: + if tier.name not in seen: + ordered.append(tier) + seen.add(tier.name) + return ordered + def _prepare_body_for_tier( self, body: dict[str, Any], @@ -264,34 +295,38 @@ def _determine_source_vendor( Priority 1: failed_tier_name(请求内故障转移,最可靠)。 Priority 2: session_record.provider_state 中有已注册转换的 vendor(跨请求)。 Priority 3: 从 body 内容推断(兜底首次请求无会话状态场景)。 - - 同 vendor 自转换(source == target)仅在 ``VENDOR_TRANSITIONS`` 显式注册 - 了对应通道时启用(如 ``("zhipu","zhipu")`` 修复 zhipu 自身不接受的产物), - 否则退化到无源行为。 """ from ..convert.vendor_channels import ( get_transition_channel, infer_source_vendor_from_body, ) - # 请求内:刚失败的 tier 就是源 - # 同 vendor 自转换仅在显式注册通道时生效 - if failed_tier_name and ( - failed_tier_name != target_name - or get_transition_channel(failed_tier_name, target_name) is not None + # 请求内:刚失败的 tier 就是源(仅当存在已注册的转换通道时) + # 修复:原逻辑仅检查 failed_tier != target 就无条件返回, + # 导致无注册通道的 failed_tier(如 copilot→anthropic)阻断降级到 + # Priority 2/3,原始 body 中的 server_tool_use 等非标准块未被清理。 + if ( + 
failed_tier_name + and get_transition_channel(failed_tier_name, target_name) is not None ): return failed_tier_name - # 跨请求:从会话历史找有注册转换的源(含已注册自转换) + # 跨请求:从会话历史找有注册转换的源 if session_record is not None and session_record.provider_state: for source in session_record.provider_state: - if get_transition_channel(source, target_name): + if source != target_name and get_transition_channel( + source, target_name + ): return source - # 首次请求兜底:从 body 内容推断(识别 zhipu 产物等,含已注册自转换) + # 首次请求兜底:从 body 内容推断(识别 zhipu 产物等) if body is not None: inferred = infer_source_vendor_from_body(body) - if inferred and get_transition_channel(inferred, target_name): + if ( + inferred + and inferred != target_name + and get_transition_channel(inferred, target_name) + ): return inferred return None @@ -302,7 +337,6 @@ async def execute_stream( headers: dict[str, str], ) -> AsyncIterator[tuple[bytes, str]]: """路由流式请求,按优先级尝试各层级.""" - last_idx = len(self._tiers) - 1 last_exc: Exception | None = None failed_tier_name: str | None = None request_caps = build_request_capabilities(body) @@ -312,8 +346,10 @@ async def execute_stream( canonical_request.trace_id, ) incompatible_reasons: list[str] = [] + effective_tiers = self._resolve_effective_tiers(canonical_request.session_key) + last_idx = len(effective_tiers) - 1 - for i, tier in enumerate(self._tiers): + for i, tier in enumerate(effective_tiers): is_last = i == last_idx gate = await self._try_gate_tier( @@ -396,6 +432,7 @@ async def execute_stream( model_served=model_served, request_id=info.request_id, ), + session_key=canonical_request.session_key, ) self._router._active_vendor_name = tier.name # 更新活跃供应商 return @@ -471,7 +508,6 @@ async def execute_message( headers: dict[str, str], ) -> VendorResponse: """路由非流式请求,按优先级尝试各层级.""" - last_idx = len(self._tiers) - 1 start = time.monotonic() failed_tier_name: str | None = None request_caps = build_request_capabilities(body) @@ -481,8 +517,10 @@ async def execute_message( canonical_request.trace_id, ) 
incompatible_reasons: list[str] = [] + effective_tiers = self._resolve_effective_tiers(canonical_request.session_key) + last_idx = len(effective_tiers) - 1 - for i, tier in enumerate(self._tiers): + for i, tier in enumerate(effective_tiers): is_last = i == last_idx gate = await self._try_gate_tier( @@ -537,6 +575,7 @@ async def execute_message( model_served=model_served, usage=resp.usage, ), + session_key=canonical_request.session_key, ) self._router._active_vendor_name = tier.name # 更新活跃供应商 return resp @@ -620,6 +659,7 @@ async def execute_message( evidence_records=self._recorder.build_nonstream_evidence_records( vendor=tier.name, model_served=model_served, usage=resp.usage ), + session_key=canonical_request.session_key, ) return resp diff --git a/src/coding/proxy/routing/router.py b/src/coding/proxy/routing/router.py index 3a65cd6..32757a8 100644 --- a/src/coding/proxy/routing/router.py +++ b/src/coding/proxy/routing/router.py @@ -18,6 +18,7 @@ from .executor import _RouteExecutor from .session_manager import RouteSessionManager +from .session_policy import SessionPolicyResolver from .tier import VendorTier # 向后兼容别名 @@ -36,6 +37,7 @@ def __init__( token_logger: TokenLogger | None = None, reauth_coordinator: Any | None = None, compat_session_store: CompatSessionStore | None = None, + session_policy_resolver: SessionPolicyResolver | None = None, ) -> None: if not tiers: raise ValueError("至少需要一个供应商层级") @@ -53,6 +55,7 @@ def __init__( usage_recorder=self._recorder, session_manager=self._session_mgr, reauth_coordinator=reauth_coordinator, + session_policy_resolver=session_policy_resolver, ) def set_pricing_table(self, table: PricingTable) -> None: diff --git a/src/coding/proxy/routing/session_policy.py b/src/coding/proxy/routing/session_policy.py new file mode 100644 index 0000000..9102e41 --- /dev/null +++ b/src/coding/proxy/routing/session_policy.py @@ -0,0 +1,116 @@ +"""Session Policy 解析引擎 — 根据 session_key + client_category 解析适用的路由策略.""" + +from __future__ import 
annotations + +import logging +import threading + +from ..config.session_policy import SessionPolicy, SessionPolicyMatch + +logger = logging.getLogger(__name__) + + +class SessionPolicyResolver: + """根据 session_key + client_category 解析适用的 SessionPolicy. + + 设计要点: + - 启动时构建索引,运行时 O(1) 查找 + - 精确匹配优先:session_key > client_category > 无策略 + - 无侵入性:不匹配时返回 None,路由行为与现有一致 + - 运行时可变:支持 API 动态 upsert/remove session → vendor 绑定 + """ + + def __init__(self, policies: list[SessionPolicy] | None = None) -> None: + self._policies = policies or [] + self._key_index: dict[str, SessionPolicy] = {} + self._category_index: dict[str, SessionPolicy] = {} + self._config_key_backup: dict[str, SessionPolicy] = {} + self._lock = threading.Lock() + self._build_index() + + def _build_index(self) -> None: + """构建 session_key / client_category → SessionPolicy 的查找索引. + + 按定义顺序遍历,首次出现的 key/category 获得最高优先级。 + """ + for policy in self._policies: + for key in policy.match.session_keys: + if key not in self._key_index: + self._key_index[key] = policy + if ( + policy.match.client_category + and policy.match.client_category not in self._category_index + ): + self._category_index[policy.match.client_category] = policy + + if self._key_index or self._category_index: + logger.info( + "SessionPolicyResolver initialized: %d key rules, %d category rules", + len(self._key_index), + len(self._category_index), + ) + + def resolve( + self, session_key: str, client_category: str = "cc" + ) -> SessionPolicy | None: + """返回匹配的策略,优先精确 session_key 匹配,其次 category 匹配. + + 返回的 SessionPolicy 对象应为不可变引用;调用方不应修改其内部属性, + 否则在并发 upsert/remove 场景下可能产生竞态。 + """ + with self._lock: + policy = self._key_index.get(session_key) + if policy: + return policy + return self._category_index.get(client_category) + + # ── 运行时 session → vendor 绑定 ────────────────────────────── + + def upsert(self, session_key: str, tier_names: list[str]) -> SessionPolicy: + """为指定 session key 创建或替换运行时 vendor 绑定. 
+ + 运行时策略使用 ``runtime:`` 名称前缀,与配置文件驱动的策略区分。 + """ + policy = SessionPolicy( + name=f"runtime:{session_key}", + match=SessionPolicyMatch(session_keys=[session_key]), + tiers=tier_names, + ) + with self._lock: + existing = self._key_index.get(session_key) + if existing and not existing.name.startswith("runtime:"): + self._config_key_backup[session_key] = existing + self._key_index[session_key] = policy + logger.info( + "Session vendor binding upserted: session_key=%s → %s", + session_key, + tier_names, + ) + return policy + + def remove(self, session_key: str) -> bool: + """删除指定 session key 的运行时 vendor 绑定. + + Returns: + True 如果找到并删除了绑定,False 如果不存在。 + """ + with self._lock: + policy = self._key_index.get(session_key) + if policy is None or not policy.name.startswith("runtime:"): + return False + del self._key_index[session_key] + # 恢复被运行时绑定覆盖的配置策略 + backup = self._config_key_backup.pop(session_key, None) + if backup is not None: + self._key_index[session_key] = backup + logger.info("Session vendor binding removed: session_key=%s", session_key) + return True + + def list_runtime_bindings(self) -> list[dict[str, str | list[str]]]: + """返回所有运行时注入的绑定快照(仅 API 创建的,不含配置文件驱动的).""" + with self._lock: + return [ + {"session_key": key, "vendors": policy.tiers} + for key, policy in self._key_index.items() + if policy.name.startswith("runtime:") + ] diff --git a/src/coding/proxy/routing/usage_parser.py b/src/coding/proxy/routing/usage_parser.py index e07b187..0d6509c 100644 --- a/src/coding/proxy/routing/usage_parser.py +++ b/src/coding/proxy/routing/usage_parser.py @@ -210,6 +210,9 @@ def parse_usage_from_chunk( request_id=data.get("id"), model_served=data.get("model"), ) + model_name = data.get("model") + if model_name: + usage["model_served"] = model_name # Gemini SSE 格式: data.usageMetadata.{promptTokenCount, candidatesTokenCount, cachedContentTokenCount, thoughtsTokenCount, toolUsePromptTokenCount} # Gemini 的流式响应在最后一帧(或每一帧)携带 usageMetadata;字段命名与 @@ -243,6 +246,9 @@ def 
parse_usage_from_chunk( request_id=data.get("responseId") or data.get("id"), model_served=data.get("modelVersion") or data.get("model"), ) + model_name = data.get("modelVersion") or data.get("model") + if model_name: + usage["model_served"] = model_name # request_id fallback (OpenAI 格式下 id 在顶层, Gemini 顶层为 responseId) if not usage.get("request_id"): diff --git a/src/coding/proxy/routing/usage_recorder.py b/src/coding/proxy/routing/usage_recorder.py index da66978..525a6c1 100644 --- a/src/coding/proxy/routing/usage_recorder.py +++ b/src/coding/proxy/routing/usage_recorder.py @@ -97,6 +97,7 @@ async def record( operation: str = "", endpoint: str = "", extra_usage: dict[str, Any] | None = None, + session_key: str = "", ) -> None: """记录用量到 TokenLogger. @@ -141,6 +142,7 @@ async def record( operation=operation, endpoint=endpoint, extra_usage_json=extra_usage_json, + session_key=session_key, ) if not evidence_records: return diff --git a/src/coding/proxy/server/app.py b/src/coding/proxy/server/app.py index ec1f1e4..5ce8011 100644 --- a/src/coding/proxy/server/app.py +++ b/src/coding/proxy/server/app.py @@ -23,6 +23,7 @@ from ..logging.db import TokenLogger from ..native_api import NativeProxyHandler from ..routing.router import RequestRouter +from ..routing.session_policy import SessionPolicyResolver from ..routing.tier import VendorTier from ..routing.usage_recorder import UsageRecorder from ..vendors.antigravity import AntigravityVendor @@ -155,7 +156,11 @@ def create_app(config: ProxyConfig | None = None) -> FastAPI: ) router = RequestRouter( - tiers, token_logger, reauth_coordinator, compat_session_store + tiers, + token_logger, + reauth_coordinator, + compat_session_store, + session_policy_resolver=SessionPolicyResolver(config.session_policies.policies), ) app = FastAPI(title="coding-proxy", version=__version__, lifespan=lifespan) diff --git a/src/coding/proxy/server/dashboard.py b/src/coding/proxy/server/dashboard.py index d0afcd4..07bd6a3 100644 --- 
a/src/coding/proxy/server/dashboard.py +++ b/src/coding/proxy/server/dashboard.py @@ -159,7 +159,7 @@ def _build_favicon() -> bytes: .kpi-grid { display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); - gap: 16px; + gap: 5px; margin-bottom: 24px; } .kpi-card { @@ -310,6 +310,7 @@ def _build_favicon() -> bytes: } .vendor-name { font-weight: 600; font-size: 14px; } .vendor-badges { display: flex; gap: 5px; flex-wrap: wrap; align-items: center; } + .quota-group { display: flex; align-items: center; gap: 6px; } .status-badge { font-size: 11px; padding: 2px 7px; border-radius: 10px; @@ -319,7 +320,7 @@ def _build_favicon() -> bytes: .sb-warn { background: rgba(210,153,34,.12); color: var(--accent-yellow); border: 1px solid rgba(210,153,34,.2); } .sb-err { background: rgba(248,81,73,.12); color: var(--accent-red); border: 1px solid rgba(248,81,73,.2); } .sb-info { background: rgba(88,166,255,.12); color: var(--accent-blue); border: 1px solid rgba(88,166,255,.2); } - .quota-bar-wrap { flex: 1; margin: 0 10px; max-width: 100px; } + .quota-bar-wrap { flex: 1; min-width: 40px; max-width: 100px; } .quota-bar-bg { height: 4px; border-radius: 2px; background: rgba(255,255,255,.06); @@ -395,6 +396,98 @@ def _build_favicon() -> bytes: color: var(--text-tertiary); font-size: 14px; } .empty-icon { font-size: 32px; margin-bottom: 8px; opacity: .5; } + /* ── Sessions Panel ── */ + .sessions-card { grid-column: 1 / -1; animation-delay: .1s; } + .session-table-wrap { overflow: hidden; } + .session-table { width: 100%; border-collapse: collapse; font-size: 13px; table-layout: fixed; } + .session-table th { + position: sticky; top: 0; z-index: 1; + background: var(--bg-card); padding: 10px 12px; + text-align: left; font-weight: 600; font-size: 12px; + color: var(--text-secondary); text-transform: uppercase; letter-spacing: .5px; + border-bottom: 1px solid var(--border); + } + .session-table td { padding: 8px 12px; border-bottom: 1px solid var(--border-subtle); 
white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } + .session-table td.cell-tags { white-space: normal; overflow: visible; text-overflow: clip; line-height: 1.8; vertical-align: middle; } + .session-table tr:hover td { background: var(--bg-card-hover); } + .session-table .session-key { font-family: 'JetBrains Mono', monospace; font-size: 12px; color: var(--accent-blue); cursor: default; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } + .session-id { display: flex; align-items: center; gap: 4px; } + .session-id-text { overflow: hidden; text-overflow: ellipsis; } + .copy-btn { background: none; border: none; color: var(--text-tertiary); cursor: pointer; padding: 2px; border-radius: 4px; font-size: 12px; line-height: 1; opacity: .5; flex-shrink: 0; } + .copy-btn:hover { opacity: 1; color: var(--accent-blue); background: rgba(88,166,255,.1); } + .copy-btn.copied { color: var(--accent-green); opacity: 1; } + .session-meta { font-size: 10px; color: var(--text-tertiary); line-height: 1.2; margin-top: 2px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } + .session-tag { + display: inline-block; font-size: 11px; padding: 2px 7px; + border-radius: 8px; margin: 1px 2px; + background: rgba(88,166,255,.08); border: 1px solid rgba(88,166,255,.15); + color: var(--text-secondary); + } + .session-tag-cc { + background: rgba(63,185,80,.08); border-color: rgba(63,185,80,.15); + } + .session-table td.cell-success { overflow: visible; text-overflow: clip; } + /* ── 展开行 ── */ + .session-table tr.row-detail { display: none; } + .session-table tr.row-detail.open { display: table-row; } + .session-table tr.row-detail td { padding: 0; } + .detail-card { + padding: 16px 24px; margin: 6px 0; + background: linear-gradient(135deg, rgba(30,37,54,.95), rgba(22,28,40,.95)); + border: 1px solid rgba(88,166,255,.15); border-radius: 12px; + font-size: 13px; + white-space: normal; overflow: hidden; + box-shadow: 0 4px 16px rgba(0,0,0,.3); + } + 
.detail-card .detail-item { display: flex; flex-direction: column; gap: 2px; min-width: 0; } + .detail-card .detail-label { font-size: 11px; color: var(--text-tertiary); text-transform: uppercase; letter-spacing: .3px; } + .detail-card .detail-value { color: var(--text-primary); line-height: 1.4; word-break: break-all; overflow-wrap: break-word; } + .detail-identity-row { + display: flex; gap: 16px; + padding-bottom: 10px; margin-bottom: 10px; + border-bottom: 1px solid var(--border); + } + .detail-identity-row .detail-item { flex: 3 1 0; } + .detail-identity-row .detail-item:first-child { flex: 2 1 0; } + .detail-identity-row .detail-value { font-family: 'JetBrains Mono', monospace; font-size: 12px; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; word-break: normal; } + .detail-metrics-grid { + display: grid; + grid-template-columns: repeat(8, 1fr); + gap: 10px 16px; + } + .detail-inline-pair { display: flex; gap: 16px; } + .detail-inline-pair > div { flex: 1; display: flex; flex-direction: column; gap: 2px; min-width: 0; } + .session-table tbody tr[data-row]:not(.row-detail) { cursor: pointer; } + .success-bar { width: 56px; height: 4px; border-radius: 2px; background: rgba(255,255,255,.12); display: inline-block; vertical-align: middle; margin-left: 6px; } + .success-bar-fill { height: 100%; border-radius: 2px; display: block; } + /* ── Vendor Bind 选择器 ── */ + .bind-select { + padding: 3px 6px; border-radius: 6px; + background: rgba(48,54,61,.6); border: 1px solid rgba(255,255,255,.1); + color: var(--text-secondary); font-size: 12px; + font-family: 'JetBrains Mono', monospace; + cursor: pointer; outline: none; + transition: all .2s ease; + max-width: 120px; + } + .bind-select:hover { border-color: rgba(88,166,255,.4); color: var(--text-primary); } + .bind-select:focus { border-color: rgba(88,166,255,.6); box-shadow: 0 0 0 2px rgba(88,166,255,.1); } + .bind-select option { background: var(--bg-card); color: var(--text-primary); } + /* ── 分页 ── */ 
+ .session-pagination { + display: flex; align-items: center; justify-content: space-between; + padding: 10px 12px; border-top: 1px solid var(--border-subtle); + font-size: 12px; color: var(--text-secondary); + } + .page-btn { + padding: 4px 10px; border-radius: 6px; + background: rgba(48,54,61,.4); border: 1px solid rgba(255,255,255,.08); + color: var(--text-secondary); font-size: 12px; cursor: pointer; + transition: all .15s ease; + } + .page-btn:hover:not(:disabled) { background: var(--bg-card-hover); color: var(--text-primary); border-color: rgba(88,166,255,.3); } + .page-btn:disabled { opacity: .35; cursor: default; } + .page-info { font-family: 'JetBrains Mono', monospace; font-size: 12px; } /* ── 加载态 ── */ .loading { opacity: .4; pointer-events: none; } /* ── 图表标签截断 ── */ @@ -435,6 +528,34 @@ def _build_favicon() -> bytes: margin-top: 6px; padding-top: 6px; border-top: 1px solid var(--border-subtle); font-weight: 500; font-size: 12px; color: var(--text-secondary); } + /* ── Tabs ─────────────────────────────────────────────────── */ + .tabs { + display: flex; + gap: 2px; + padding: 0; + } + .tab-btn { + appearance: none; + background: transparent; + border: 1px solid transparent; + color: var(--text-secondary); + cursor: pointer; + font-family: inherit; + font-size: 13px; + font-weight: 500; + padding: 4px 12px; + transition: color .15s ease, background .15s ease, border-color .15s ease; + border-radius: var(--radius-sm); + } + .tab-btn:hover { color: var(--text-primary); background: var(--bg-card-hover); } + .tab-btn.active { + color: var(--text-primary); + background: rgba(88,166,255,.1); + border-color: rgba(88,166,255,.2); + } + .tab-btn:focus-visible { outline: 2px solid var(--accent-blue); outline-offset: 2px; } + .tab-pane { display: none; } + .tab-pane.active { display: block; } @@ -445,12 +566,18 @@ def _build_favicon() -> bytes: v-.-.-
+ 正在加载…
+ +
时间区间 @@ -540,6 +667,55 @@ def _build_favicon() -> bytes:
+
+ + +
+ +
+
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Session ID | Last Active | Requests | Tokens | Models | Vendors | Avg Latency | Success | Vendor Bind | Client
Loading...
+
+ +
+ + + +
+
+
+
+
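The Sessions table above is fed by the per-session aggregates that `query_recent_sessions` (added earlier in this diff) computes over `usage_log`. As a rough illustration only, the sketch below runs a trimmed version of that query with the standard-library `sqlite3` module rather than the project's async connection. The reduced schema, the `recent_sessions` and `demo` helper names, and the sample rows are assumptions made for this example and are not part of the patch.

```python
import sqlite3

# Hypothetical, reduced subset of the usage_log columns touched by this diff.
SCHEMA = """
CREATE TABLE usage_log (
    ts TEXT NOT NULL,
    vendor TEXT NOT NULL,
    input_tokens INTEGER NOT NULL DEFAULT 0,
    output_tokens INTEGER NOT NULL DEFAULT 0,
    success INTEGER NOT NULL DEFAULT 1,
    session_key TEXT NOT NULL DEFAULT ''
);
"""


def recent_sessions(conn: sqlite3.Connection, cutoff_iso: str, limit: int = 20) -> list[dict]:
    """Aggregate rows per session_key, newest session first (trimmed column set)."""
    conn.row_factory = sqlite3.Row
    cursor = conn.execute(
        """SELECT session_key,
                  MAX(ts) AS last_active_ts,
                  COUNT(*) AS total_requests,
                  SUM(input_tokens + output_tokens) AS total_tokens,
                  GROUP_CONCAT(DISTINCT vendor) AS vendors,
                  SUM(CASE WHEN success THEN 1 ELSE 0 END) * 100.0 / COUNT(*) AS success_rate
           FROM usage_log
           WHERE session_key != '' AND ts >= ?
           GROUP BY session_key
           ORDER BY last_active_ts DESC
           LIMIT ?""",
        (cutoff_iso, limit),
    )
    return [dict(row) for row in cursor.fetchall()]


def demo() -> list[dict]:
    """Build an in-memory table with made-up rows and run the aggregation."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA)
    conn.executemany(
        "INSERT INTO usage_log (ts, vendor, input_tokens, output_tokens, success, session_key)"
        " VALUES (?, ?, ?, ?, ?, ?)",
        [
            ("2026-05-01T10:00:00", "zhipu", 100, 50, 1, "s1"),
            ("2026-05-01T10:05:00", "copilot", 10, 5, 0, "s1"),
            ("2026-05-01T09:00:00", "zhipu", 1, 1, 1, "s2"),
            ("2026-05-01T10:10:00", "zhipu", 7, 7, 1, ""),     # empty key: filtered out
            ("2026-04-01T10:10:00", "zhipu", 7, 7, 1, "old"),  # before the cutoff: filtered out
        ],
    )
    return recent_sessions(conn, "2026-05-01T00:00:00")


if __name__ == "__main__":
    for row in demo():
        print(row)
```

Rows with an empty `session_key` never reach the panel, which is why the column can default to `''` for pre-migration data without polluting the session list.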
@@ -580,7 +756,32 @@ def _build_favicon() -> bytes: return String(n); } function fmtNum(n) { return n == null ? '–' : n.toLocaleString(); } +function copyFromParent(btn) { + var text = btn.parentElement.getAttribute('data-key') || btn.parentElement.getAttribute('title') || ''; + navigator.clipboard.writeText(text).then(function() { + btn.classList.add('copied'); + btn.textContent = '✓'; + setTimeout(function() { btn.classList.remove('copied'); btn.textContent = '⧉'; }, 1500); + }); +} +function toggleRow(tr) { + var detail = tr.nextElementSibling; + if (!detail || !detail.classList.contains('row-detail')) return; + var wasOpen = detail.classList.contains('open'); + // close all open rows first + document.querySelectorAll('.session-table tr.row-detail.open').forEach(function(r) { r.classList.remove('open'); }); + if (!wasOpen) detail.classList.add('open'); +} function isValidLabel(s) { return typeof s === 'string' && s !== 'undefined' && s !== 'null' && s.trim() !== ''; } +function fmtDuration(ms) { + if (ms == null) return '–'; + var s = ms / 1000; + if (s < 1) return Math.round(ms) + 'ms'; + if (s < 60) return s.toFixed(1).replace(/\\.0$/, '') + 's'; + var m = Math.floor(s / 60); + var sec = Math.round(s % 60); + return sec > 0 ? m + 'min ' + sec + 's' : m + 'min'; +} function now() { return new Date().toLocaleTimeString('zh-CN', {hour:'2-digit',minute:'2-digit',second:'2-digit'}); } @@ -590,12 +791,15 @@ def _build_favicon() -> bytes: // _API_VENDORS 需与后端 native_api/handler.py::_VENDOR_LABEL 对齐, // 新增无 -native 后缀的 native vendor 时同步更新本集合。 const _API_VENDORS = new Set(['anthropic-native', 'openai', 'gemini']); +function isApiVendor(v) { return _API_VENDORS.has(v); } +function vendorShortName(v) { + if (!isValidLabel(v)) return v; + if (isApiVendor(v)) return v.endsWith('-native') ? v.slice(0, -'-native'.length) : v; + return v; +} function formatVendorLabel(v) { if (!isValidLabel(v)) return v; - if (_API_VENDORS.has(v)) { - const name = v.endsWith('-native') ? 
v.slice(0, -'-native'.length) : v; - return 'api | ' + name; - } + if (isApiVendor(v)) return 'api | ' + vendorShortName(v); return 'cc | ' + v; } @@ -886,10 +1090,11 @@ def _build_favicon() -> bytes: if (!qg || qg.usage_percent == null) return ''; const pct = Math.round(qg.usage_percent); const label = quotaWindowLabel(qg.window_hours); - return `${label} ${pct}%` + + return `
` + + `${label} ${pct}%` + `
` + `
` + - `
`; + `
`; } function updateVendorStatus(status) { @@ -1263,47 +1468,340 @@ def _build_favicon() -> bytes: if (mt) mt.textContent = label + ' Token 用量(按 Vendor / 模型)'; } -// ── 主刷新逻辑 ──────────────────────────────────────────── -let refreshing = false; -async function refresh() { - if (refreshing) return; - refreshing = true; - document.getElementById('refresh-time').textContent = '刷新中…'; +// ── Sessions Panel ────────────────────────────────────────────── +function relativeTime(tsStr) { + if (!tsStr) return '–'; + var d = new Date(tsStr.replace('Z', '+00:00')); + var diff = (Date.now() - d.getTime()) / 1000; + if (diff < 60) return 'just now'; + if (diff < 3600) return Math.floor(diff / 60) + 'm ago'; + if (diff < 86400) return Math.floor(diff / 3600) + 'h ago'; + return Math.floor(diff / 86400) + 'd ago'; +} +function escapeHtml(s) { + if (!s) return ''; + return s.replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;').replace(/"/g,'&quot;').replace(/'/g,'&#39;'); +} +function truncateKey(key, maxLen) { + if (!key || key.length <= maxLen) return escapeHtml(key) || '–'; + return escapeHtml(key.slice(0, maxLen - 3)) + '…'; +} +function parseSessionKey(raw) { + try { var o = JSON.parse(raw); return { device_id: o.device_id||'', account_uuid: o.account_uuid||'', session_id: o.session_id||'' }; } + catch(e) { return { device_id:'', account_uuid:'', session_id: raw || '' }; } +} +function shortId(s, n) { return s ? (s.length <= n ? s : s.slice(0, n) + '…') : ''; } +function successBarHtml(pct) { + if (pct == null) return '–'; + var p = Math.round(pct); + var color = p >= 95 ? 'var(--accent-green)' : (p >= 80 ? 
'var(--accent-yellow)' : 'var(--accent-red)'); + return '' + p + '%' + + ''; +} +function formatSessionTags(str, max) { + if (!str) return '–'; + var list = str.split(','); + var html = list.slice(0, max).map(function(c) { + return '' + escapeHtml(c.trim()) + ''; + }).join(''); + if (list.length > max) { + var fullList = list.map(function(c) { return c.trim(); }).join(', '); + html += '+' + (list.length - max) + ''; + } + return html; +} +function formatCategories(cats) { + if (!cats) return '–'; + return cats.split(',').map(function(c) { + var t = c.trim(); + var label = t === 'cc' ? 'Claude Code' : (t === 'api' ? 'API' : escapeHtml(t)); + return '' + label + ''; + }).join(''); +} +function formatVendorTags(vendors) { + if (!vendors) return '–'; + var list = vendors.split(','); + var max = 4; + var html = list.slice(0, max).map(function(v) { + var vt = v.trim(); + var name = vendorShortName(vt); + var fullLabel = formatVendorLabel(vt); + var cls = isApiVendor(vt) ? 'session-tag' : 'session-tag session-tag-cc'; + return '' + escapeHtml(name) + ''; + }).join(''); + if (list.length > max) { + var fullList = list.map(function(v) { return formatVendorLabel(v.trim()); }).join(', '); + html += '+' + (list.length - max) + ''; + } + return html; +} +// ── Sessions Pagination State ── +var allSessions = []; +var sessionPage = 0; +var sessionPageSize = 30; +var sessionBindMap = {}; +var sessionAvailableVendors = []; + +async function updateSessions() { try { - const days = currentDays > 0 ? 
currentDays : 7; - const [summary, timeline, status] = await Promise.all([ - fetchJSON('/api/dashboard/summary?days=' + days), - fetchJSON('/api/dashboard/timeline?days=' + days), + var results = await Promise.allSettled([ + fetchJSON('/api/dashboard/sessions?hours=24&limit=200'), + fetchJSON('/api/session-vendor'), fetchJSON('/api/status'), ]); + if (results[0].status === 'rejected') throw results[0].reason; + var data = results[0].value; + var bindData = results[1].status === 'fulfilled' ? results[1].value : {bindings: []}; + var statusData = results[2].status === 'fulfilled' ? results[2].value : {tiers: []}; + allSessions = data.sessions || []; + sessionBindMap = {}; + (bindData.bindings || []).forEach(function(b) { sessionBindMap[b.session_key] = b.vendors; }); + sessionAvailableVendors = (statusData.tiers || []).map(function(t) { return t.name; }); + sessionPage = 0; + renderSessionPage(); + } catch (e) { + console.error('Sessions refresh error:', e); + } +} - if (summary.version) { - document.getElementById('version-badge').textContent = 'v' + summary.version; - } +function renderSessionPage() { + var total = allSessions.length; + var totalPages = Math.max(1, Math.ceil(total / sessionPageSize)); + if (sessionPage >= totalPages) sessionPage = totalPages - 1; + var start = sessionPage * sessionPageSize; + var page = allSessions.slice(start, start + sessionPageSize); + var tbody = document.getElementById('sessions-tbody'); - updateKPI(summary); - updateVendorStatus(status); - updateChartTitles(days); + if (!total) { + tbody.innerHTML = '
📭
No session data'; + } else { + tbody.innerHTML = page.map(function(s) { + var parsed = parseSessionKey(s.session_key); + var boundVendors = sessionBindMap[s.session_key]; + var selectHtml = buildBindSelect(s.session_key, boundVendors, sessionAvailableVendors); + var modelsFull = (s.models || '').split(',').map(function(c){return c.trim();}); + var vendorsFull = (s.vendors || '').split(',').map(function(v){return formatVendorLabel(v.trim());}); + var sr = s.success_rate != null ? Math.round(s.success_rate) : null; + return '' + + '' + + '
' + + '' + escapeHtml(parsed.session_id || s.session_key) + '' + + '' + + '
' + + '
' + + 'dev:' + escapeHtml(shortId(parsed.device_id, 8)) + ' · acct:' + escapeHtml(shortId(parsed.account_uuid, 8)) + + '
' + + '' + + '' + relativeTime(s.last_active_ts) + '' + + '' + fmtNum(s.total_requests) + '' + + '' + fmtTokens(s.total_tokens) + '' + + '' + formatSessionTags(s.models, 3) + '' + + '' + formatVendorTags(s.vendors) + '' + + '' + fmtDuration(s.avg_duration_ms) + '' + + '' + successBarHtml(s.success_rate) + '' + + '' + selectHtml + '' + + '' + formatCategories(s.client_categories) + '' + + '' + + '
' + + '
' + + '
Session ID
' + escapeHtml(parsed.session_id || s.session_key) + '
' + + '
Device
' + (parsed.device_id ? escapeHtml(parsed.device_id) : '–') + '
' + + '
Account
' + (parsed.account_uuid ? escapeHtml(parsed.account_uuid) : '–') + '
' + + '
' + + '
' + + '
Last Active
' + relativeTime(s.last_active_ts) + '
' + + '
Requests
' + fmtNum(s.total_requests) + '
' + + '
Tokens
' + fmtTokens(s.total_tokens) + '
' + + '
Models
' + (modelsFull.length ? modelsFull.map(function(m){return '' + escapeHtml(m) + '';}).join(' ') : '–') + '
' + + '
Vendors
' + (vendorsFull.length ? vendorsFull.map(function(v){return '' + escapeHtml(v) + '';}).join(' ') : '–') + '
' + + '
Avg Latency
' + fmtDuration(s.avg_duration_ms) + '
' + + '
' + + '
Success Rate
' + (sr != null ? sr + '%' : '–') + '
' + + '
Client
' + escapeHtml(s.client_categories || '–') + '
' + + '
' + + '
' + + '
'; + }).join(''); + } - const rows = timeline.rows || []; - const tierOrder = (status.tiers || []).map(t => t.name); - buildTimeline(rows, tierOrder); - buildVendorDist(rows, tierOrder); - buildTokenTimeline(rows, tierOrder); - buildModelTokenTimeline(rows); + document.getElementById('page-info').textContent = total + ' sessions'; + document.getElementById('page-num').textContent = (sessionPage + 1) + ' / ' + totalPages; + document.getElementById('btn-prev').disabled = (sessionPage === 0); + document.getElementById('btn-next').disabled = (sessionPage >= totalPages - 1); +} - document.getElementById('refresh-time').textContent = '上次刷新: ' + now(); +function changePage(delta) { + var totalPages = Math.max(1, Math.ceil(allSessions.length / sessionPageSize)); + sessionPage = Math.max(0, Math.min(totalPages - 1, sessionPage + delta)); + renderSessionPage(); +} + +function buildBindSelect(sessionKey, boundVendors, availableVendors) { + var isBound = boundVendors && boundVendors.length > 0; + var multiBound = isBound && boundVendors.length > 1; + var selected = isBound ? 
boundVendors[0] : ''; + var html = ''; + return html; +} + +async function handleBindChange(sel) { + var sessionKey = sel.getAttribute('data-session-key'); + var vendor = sel.value; + var previousValue = sel.getAttribute('data-previous') || ''; + try { + var resp; + if (vendor) { + resp = await fetch('/api/session-vendor', { + method: 'PUT', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify({session_key: sessionKey, vendors: [vendor]}), + }); + } else { + resp = await fetch('/api/session-vendor/' + encodeURIComponent(sessionKey), {method: 'DELETE'}); + } + if (!resp.ok) { + sel.value = previousValue; + console.error('Bind change rejected:', resp.status, await resp.text()); + } } catch (e) { - console.error('Dashboard refresh error:', e); - document.getElementById('refresh-time').textContent = '刷新失败 ' + now(); + sel.value = previousValue; + console.error('Bind change failed:', e); + } +} + +var sessionsTbody = document.getElementById('sessions-tbody'); +sessionsTbody.addEventListener('focus', function(e) { + if (e.target.classList.contains('bind-select')) { + e.target.setAttribute('data-previous', e.target.value); + } +}, true); +sessionsTbody.addEventListener('change', function(e) { + if (e.target.classList.contains('bind-select')) { + handleBindChange(e.target); + } +}); + +// ── 主刷新逻辑(按 Tab 分发) ────────────────────────────── +let refreshing = false; +let currentTab = 'overview'; +const tabLoaded = { overview: false, sessions: false }; +const TAB_LABELS = { overview: 'Overview', sessions: 'Sessions' }; + +async function refreshOverview() { + const days = currentDays > 0 ? 
currentDays : 7; + const [summary, timeline, status] = await Promise.all([ + fetchJSON('/api/dashboard/summary?days=' + days), + fetchJSON('/api/dashboard/timeline?days=' + days), + fetchJSON('/api/status'), + ]); + + if (summary.version) { + document.getElementById('version-badge').textContent = 'v' + summary.version; + } + + updateKPI(summary); + updateVendorStatus(status); + updateChartTitles(days); + + const rows = timeline.rows || []; + const tierOrder = (status.tiers || []).map(t => t.name); + buildTimeline(rows, tierOrder); + buildVendorDist(rows, tierOrder); + buildTokenTimeline(rows, tierOrder); + buildModelTokenTimeline(rows); +} + +async function refreshSessions() { + await updateSessions(); +} + +async function refresh() { + if (refreshing) return; + refreshing = true; + try { + // Loop: if the user switched to another tab that has not loaded yet during the await, refresh once more so tabLoaded does not fall out of sync. + while (true) { + const tab = currentTab; + document.getElementById('refresh-time').textContent = '刷新中…'; + try { + if (tab === 'sessions') { + await refreshSessions(); + } else { + await refreshOverview(); + } + tabLoaded[tab] = true; + if (tab === currentTab) { + document.getElementById('refresh-time').textContent = + '上次刷新: ' + now() + '(' + TAB_LABELS[tab] + ')'; + } + } catch (e) { + console.error('Dashboard refresh error:', e); + document.getElementById('refresh-time').textContent = '刷新失败 ' + now(); + } + if (currentTab !== tab && !tabLoaded[currentTab]) continue; + break; + } } finally { refreshing = false; } } -// Page load + auto-refresh every 30 seconds -refresh(); -setInterval(refresh, 600000); +// ── Tab switching (lazy loading + URL sync) ───────────────────────── +function syncTabUrl(name) { + try { + const url = new URL(window.location.href); + if (url.searchParams.get('tab') === name) return; + url.searchParams.set('tab', name); + window.history.replaceState({}, '', url); + } catch (e) { /* no-op */ } +} + +function applyTabState(name) { + document.querySelectorAll('.tab-btn').forEach(function (b) { + const active = b.getAttribute('data-tab') === name; 
+ b.classList.toggle('active', active); + b.setAttribute('aria-selected', active ? 'true' : 'false'); + }); + document.querySelectorAll('.tab-pane').forEach(function (p) { + p.classList.toggle('active', p.getAttribute('data-tab') === name); + }); +} + +function switchTab(name) { + if (name !== 'overview' && name !== 'sessions') name = 'overview'; + if (name === currentTab) { + syncTabUrl(name); + return; + } + currentTab = name; + applyTabState(name); + syncTabUrl(name); + refresh(); +} + +// ── Initialization ──────────────────────────────────────── +(function bootstrap() { + let initial = 'overview'; + try { + const t = new URL(window.location.href).searchParams.get('tab'); + if (t === 'sessions') initial = 'sessions'; + } catch (e) { /* no-op */ } + currentTab = initial; + applyTabState(initial); + syncTabUrl(initial); + // Load version immediately regardless of active tab + fetchJSON('/api/dashboard/summary?days=7').then(function(s) { + if (s && s.version) document.getElementById('version-badge').textContent = 'v' + s.version; + }).catch(function(){}); + refresh(); // Load data for the initial tab only + setInterval(refresh, 600000); // Refresh the current tab every 10 minutes +})(); @@ -1473,3 +1971,35 @@ async def dashboard_timeline(request: Request, days: int = 7) -> Response: status_code=200, media_type="application/json", ) + + @app.get("/api/dashboard/sessions") + async def dashboard_sessions( + request: Request, hours: float = 24.0, limit: int = 20 + ) -> Response: + """Return aggregated data for recently active sessions.""" + token_logger = getattr(request.app.state, "token_logger", None) + if token_logger is None: + return Response( + content=b'{"error":"token_logger not available"}', + status_code=503, + media_type="application/json", + ) + hours = max(1.0, min(hours, 168.0)) + limit = max(1, min(limit, 200)) + try: + sessions = await token_logger.query_recent_sessions( + limit=limit, hours=hours + ) + except Exception as exc: + logger.error("dashboard_sessions query error: %s", exc, exc_info=True) + return Response( + 
content=b'{"error":"query failed"}', + status_code=500, + media_type="application/json", + ) + result = {"sessions": sessions, "hours": hours} + return Response( + content=json.dumps(result, ensure_ascii=False).encode(), + status_code=200, + media_type="application/json", + ) diff --git a/src/coding/proxy/server/routes.py b/src/coding/proxy/server/routes.py index 3ee7588..7f157f0 100644 --- a/src/coding/proxy/server/routes.py +++ b/src/coding/proxy/server/routes.py @@ -327,6 +327,96 @@ async def reset_circuit(request: Request) -> Response: ) +def register_session_vendor_routes(app: Any, router: Any) -> None: + """Register runtime Session-Vendor binding routes.""" + + @app.put("/api/session-vendor") + async def bind_session_vendor(request: Request) -> Response: + """Bind a vendor priority list to the given session key.""" + try: + body = await request.json() + except Exception: + return json_error_response( + 400, error_type="invalid_request_error", message="body must be JSON" + ) + if not isinstance(body, dict): + return json_error_response( + 400, + error_type="invalid_request_error", + message="body must be a JSON object", + ) + session_key = body.get("session_key", "").strip() + vendors = body.get("vendors", []) + if not session_key: + return json_error_response( + 400, + error_type="invalid_request_error", + message="session_key is required", + ) + if not isinstance(vendors, list) or not vendors: + return json_error_response( + 400, + error_type="invalid_request_error", + message="vendors must be a non-empty list", + ) + vendors = [str(v).strip() for v in vendors] + available = set(router.get_vendor_names()) + unknown = [v for v in vendors if v not in available] + if unknown: + return json_error_response( + 400, + error_type="invalid_request_error", + message=( + f"unknown vendor(s): {', '.join(unknown)}; " + f"available: {', '.join(sorted(available))}" + ), + ) + resolver = router._executor._policy_resolver # noqa: SLF001 + resolver.upsert(session_key, vendors) + return Response( + content=json.dumps( 
+ {"status": "ok", "session_key": session_key, "vendors": vendors}, + ensure_ascii=False, + ).encode(), + status_code=200, + media_type="application/json", + ) + + @app.delete("/api/session-vendor/{session_key}") + async def unbind_session_vendor(session_key: str) -> Response: + """Remove the vendor binding for the given session key.""" + resolver = router._executor._policy_resolver # noqa: SLF001 + removed = resolver.remove(session_key) + if not removed: + return Response( + content=json.dumps( + {"status": "not_found", "session_key": session_key}, + ensure_ascii=False, + ).encode(), + status_code=404, + media_type="application/json", + ) + return Response( + content=json.dumps( + {"status": "removed", "session_key": session_key}, + ensure_ascii=False, + ).encode(), + status_code=200, + media_type="application/json", + ) + + @app.get("/api/session-vendor") + async def list_session_vendors() -> Response: + """List all runtime session-vendor bindings.""" + resolver = router._executor._policy_resolver # noqa: SLF001 + bindings = resolver.list_runtime_bindings() + return Response( + content=json.dumps({"bindings": bindings}, ensure_ascii=False).encode(), + status_code=200, + media_type="application/json", + ) + + def register_reauth_routes(app: Any, reauth_coordinator: Any) -> None: """Register re-authentication routes.""" @@ -368,6 +458,7 @@ def register_all_routes( register_status_route(app, router) register_copilot_routes(app, router) register_admin_routes(app, router) + register_session_vendor_routes(app, router) if reauth_coordinator: register_reauth_routes(app, reauth_coordinator) diff --git a/tests/test_copilot_convert_request.py b/tests/test_copilot_convert_request.py index 0a4dbce..c4718cc 100644 --- a/tests/test_copilot_convert_request.py +++ b/tests/test_copilot_convert_request.py @@ -472,3 +472,132 @@ def test_image_block_converted_to_image_url(): image_part = [p for p in user_msg["content"] if p.get("type") == "image_url"] assert len(image_part) == 1 assert "data:image/png;base64,abc123" in 
image_part[0]["image_url"]["url"] + + +# === Defensive tool_use.input serialization === + + +def test_tool_use_input_none_defaults_to_empty_dict(): + """input=None 应被降级为 {} 而非序列化为 'null'.""" + body = { + "model": "claude-sonnet-4-20250514", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_001", + "name": "read_file", + "input": None, + } + ], + } + ], + } + result = convert_request(body) + assistant_msgs = [m for m in result["messages"] if m["role"] == "assistant"] + assert len(assistant_msgs) == 1 + assert "tool_calls" in assistant_msgs[0] + tc = assistant_msgs[0]["tool_calls"][0] + assert tc["function"]["arguments"] == "{}" + + +def test_tool_use_input_string_defaults_to_empty_dict(): + """input='some string' 应被降级为 {} 而非序列化为 '"some string"'.""" + body = { + "model": "claude-sonnet-4-20250514", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_002", + "name": "run_cmd", + "input": "not a dict", + } + ], + } + ], + } + result = convert_request(body) + assistant_msgs = [m for m in result["messages"] if m["role"] == "assistant"] + tc = assistant_msgs[0]["tool_calls"][0] + assert tc["function"]["arguments"] == "{}" + + +def test_tool_use_input_missing_defaults_to_empty_dict(): + """input key 不存在时,block.get('input') 返回 None,应降级为 {}.""" + body = { + "model": "claude-sonnet-4-20250514", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_003", + "name": "search", + } + ], + } + ], + } + result = convert_request(body) + assistant_msgs = [m for m in result["messages"] if m["role"] == "assistant"] + tc = assistant_msgs[0]["tool_calls"][0] + assert tc["function"]["arguments"] == "{}" + + +def test_tool_use_input_int_defaults_to_empty_dict(): + """input=42 应被降级为 {} 而非序列化为 '42'.""" + body = { + "model": "claude-sonnet-4-20250514", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_004", 
+ "name": "calc", + "input": 42, + } + ], + } + ], + } + result = convert_request(body) + assistant_msgs = [m for m in result["messages"] if m["role"] == "assistant"] + tc = assistant_msgs[0]["tool_calls"][0] + assert tc["function"]["arguments"] == "{}" + + +def test_tool_use_valid_dict_input_preserved(): + """正常 dict input 应保持原样.""" + body = { + "model": "claude-sonnet-4-20250514", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_005", + "name": "read_file", + "input": {"path": "/tmp/test.txt", "offset": 10}, + } + ], + } + ], + } + result = convert_request(body) + assistant_msgs = [m for m in result["messages"] if m["role"] == "assistant"] + tc = assistant_msgs[0]["tool_calls"][0] + import json + + assert json.loads(tc["function"]["arguments"]) == { + "path": "/tmp/test.txt", + "offset": 10, + } diff --git a/tests/test_model_vendor.py b/tests/test_model_vendor.py index 14895e2..bdc1a87 100644 --- a/tests/test_model_vendor.py +++ b/tests/test_model_vendor.py @@ -115,13 +115,19 @@ def test_all_member_values(self): assert CapabilityLossReason.METADATA.value == "metadata" def test_member_count(self): - """枚举成员数量固定为 5.""" + """枚举成员数量.""" assert len(CapabilityLossReason) == 5 def test_can_iterate(self): """可遍历所有成员.""" names = {m.name for m in CapabilityLossReason} - assert names == {"TOOLS", "THINKING", "IMAGES", "VENDOR_TOOLS", "METADATA"} + assert names == { + "TOOLS", + "THINKING", + "IMAGES", + "VENDOR_TOOLS", + "METADATA", + } def test_lookup_by_value(self): """可通过 value 反查成员.""" diff --git a/tests/test_parse_usage.py b/tests/test_parse_usage.py index d3502b7..251d165 100644 --- a/tests/test_parse_usage.py +++ b/tests/test_parse_usage.py @@ -117,6 +117,34 @@ def test_openai_zhipu_final_chunk(): assert usage["input_tokens"] == 200 assert usage["output_tokens"] == 80 assert usage["request_id"] == "chatcmpl-1" + assert usage["model_served"] == "glm-5.1" + + +def test_openai_final_chunk_with_model(): + """OpenAI 最终 
chunk 有 model 字段时应提取到 model_served.""" + usage: dict = {} + parse_usage_from_chunk( + _sse( + '{"id":"chatcmpl-2","model":"gpt-4o-2024-08-06",' + '"usage":{"prompt_tokens":50,"completion_tokens":20}}' + ), + usage, + ) + assert usage["input_tokens"] == 50 + assert usage["output_tokens"] == 20 + assert usage["model_served"] == "gpt-4o-2024-08-06" + + +def test_openai_final_chunk_without_model(): + """OpenAI 最终 chunk 无 model 字段时不应设置 model_served.""" + usage: dict = {} + parse_usage_from_chunk( + _sse('{"id":"chatcmpl-3","usage":{"prompt_tokens":30,"completion_tokens":10}}'), + usage, + ) + assert usage["input_tokens"] == 30 + assert usage["output_tokens"] == 10 + assert "model_served" not in usage def test_openai_final_chunk_with_cache_tokens(): diff --git a/tests/test_parse_usage_gemini.py b/tests/test_parse_usage_gemini.py index cec2b79..01e7334 100644 --- a/tests/test_parse_usage_gemini.py +++ b/tests/test_parse_usage_gemini.py @@ -35,6 +35,7 @@ def test_gemini_usage_metadata_basic_fields(): assert usage["output_tokens"] == 42 assert usage.get("cache_read_tokens", 0) == 0 assert usage["request_id"] == "resp_abc" + assert usage["model_served"] == "gemini-2.0-flash" def test_gemini_usage_metadata_with_cached_content(): @@ -194,3 +195,18 @@ def test_gemini_partial_fields_ok(): ) assert usage["input_tokens"] == 77 assert "output_tokens" not in usage + + +def test_gemini_model_fallback_to_data_model(): + """当 modelVersion 不存在时,应回退到 data.model.""" + usage: dict = {} + parse_usage_from_chunk( + _sse( + '{"usageMetadata":{"promptTokenCount":80,"candidatesTokenCount":20},' + '"model":"gemini-1.5-flash"}' + ), + usage, + ) + assert usage["input_tokens"] == 80 + assert usage["output_tokens"] == 20 + assert usage["model_served"] == "gemini-1.5-flash" diff --git a/tests/test_router_executor.py b/tests/test_router_executor.py index bc0201f..1e40ea6 100644 --- a/tests/test_router_executor.py +++ b/tests/test_router_executor.py @@ -11,6 +11,7 @@ from unittest.mock import 
AsyncMock, MagicMock +import httpx import pytest from coding.proxy.compat.canonical import ( @@ -347,7 +348,6 @@ async def test_raises_no_compatible_vendor(self): @pytest.mark.asyncio async def test_last_tier_propagates_http_error(self): """最后一层的 HTTP 错误直接抛出.""" - import httpx vendor = _mock_vendor() vendor.send_message.side_effect = httpx.ConnectError("unreachable") @@ -369,7 +369,6 @@ async def test_last_tier_propagates_token_error(self): @pytest.mark.asyncio async def test_non_last_tier_continues_on_connect_error(self): """非最后一层连接失败时继续尝试下一层.""" - import httpx bad = _mock_vendor("bad") bad.send_message.side_effect = httpx.ConnectError("down") @@ -460,7 +459,6 @@ async def _raise_token(*a, **kw): @pytest.mark.asyncio async def test_stream_http_error_raises_on_last_tier(self): """最后一层流式 HTTP 错误直接抛出.""" - import httpx vendor = _mock_vendor() @@ -1229,6 +1227,31 @@ def test_returns_false_for_empty_body(self): is False ) + def test_returns_true_for_invalid_tool_call_format(self): + """400 + 结构化 JSON 含 invalid_tool_call_format + tool_result → 格式不兼容.""" + json_body = '{"error":{"message":"Invalid JSON format in tool call arguments","code":"invalid_tool_call_format"}}' + assert ( + _is_likely_request_format_error( + status_code=400, + error_body_text=json_body, + body=self._body_with_tool_results(), + ) + is True + ) + + def test_returns_false_for_invalid_tool_call_format_without_tool_results(self): + """invalid_tool_call_format 但无 tool_result → 不应匹配.""" + json_body = '{"error":{"message":"Invalid JSON format in tool call arguments","code":"invalid_tool_call_format"}}' + body = {"model": "test", "messages": [{"role": "user", "content": "hi"}]} + assert ( + _is_likely_request_format_error( + status_code=400, + error_body_text=json_body, + body=body, + ) + is False + ) + # ── TokenAcquireError 永久性凭证错误测试 ──────────────────── @@ -1610,13 +1633,13 @@ def test_returns_none_for_no_source(self): ) def 
test_returns_session_vendor_with_registered_transition_anthropic_to_zhipu(self): - """anthropic → zhipu 已注册转换,应返回 anthropic 作为源 vendor.""" + """anthropic → zhipu 未注册转换,应回退到无源行为.""" session_record = MagicMock() session_record.provider_state = {"anthropic": {}} assert ( _RouteExecutor._determine_source_vendor("zhipu", None, session_record) - == "anthropic" + is None ) def test_returns_none_when_session_is_none(self): @@ -1760,8 +1783,8 @@ def test_priority3_skips_when_no_registered_transition(self): ) def test_priority1_overrides_priority3(self): - """Priority 1 (failed_tier) 优先于 Priority 3 (body inference).""" - # body 内有 zhipu 产物,但 failed_tier 显式指定 copilot + """Priority 1 (failed_tier) 优先于 Priority 3 (body inference) — 仅当通道已注册.""" + # body 内有 zhipu 产物,failed_tier=zhipu, target=copilot → zhipu→copilot 已注册 body = { "messages": [ { @@ -1777,10 +1800,26 @@ def test_priority1_overrides_priority3(self): }, ], } - # failed_tier=copilot → 应返回 copilot,不看 body + # failed_tier=zhipu, target=copilot → (zhipu,copilot) 已注册 → 返回 zhipu assert ( - _RouteExecutor._determine_source_vendor("zhipu", "copilot", None, body) - == "copilot" + _RouteExecutor._determine_source_vendor("copilot", "zhipu", None, body) + == "zhipu" + ) + + def test_priority1_falls_through_when_no_registered_transition(self): + """Priority 1: failed_tier 无已注册转换通道 → 降级到 Priority 2. + + 核心修复场景: copilot 失败后降级到 anthropic 时, (copilot,anthropic) + 未注册 → 通过 session history 找到 zhipu → 应用 zhipu→anthropic 转换. 
+ """ + session_record = MagicMock() + session_record.provider_state = {"zhipu": {}} + # copilot → anthropic 未注册, 但 session 中有 zhipu → zhipu→anthropic 已注册 + assert ( + _RouteExecutor._determine_source_vendor( + "anthropic", "copilot", session_record + ) + == "zhipu" ) def test_priority2_overrides_priority3(self): @@ -1818,79 +1857,6 @@ def test_body_parameter_is_optional(self): ) -# ── _determine_source_vendor 自转换通道测试 ───────────────────────── - - -class TestDetermineSourceVendorSelfTransition: - """验证已注册的同 vendor 自转换 (如 zhipu → zhipu) 在三条优先级中均能命中. - - 自转换通道用于修复 vendor 自身无法消化的产物 (如 zhipu 不接受输入中的 - server_tool_use_delta 与 assistant 内联 tool_result). - """ - - def test_priority1_self_transition_when_registered(self): - """Priority 1: failed_tier == target 且通道已注册 → 返回 target 作为源.""" - # zhipu 自转换通道已在 vendor_channels 注册 - assert ( - _RouteExecutor._determine_source_vendor("zhipu", "zhipu", None) == "zhipu" - ) - - def test_priority1_self_transition_blocked_when_unregistered(self): - """Priority 1: failed_tier == target 但通道未注册 → 返回 None. - - anthropic 未注册自转换通道, 保持原有「同 vendor 无源」行为. - """ - assert ( - _RouteExecutor._determine_source_vendor("anthropic", "anthropic", None) - is None - ) - - def test_priority2_self_transition_via_session(self): - """Priority 2: 会话历史中只有目标 vendor, 但其自转换通道已注册 → 命中.""" - session_record = MagicMock() - session_record.provider_state = {"zhipu": {}} - assert ( - _RouteExecutor._determine_source_vendor("zhipu", None, session_record) - == "zhipu" - ) - - def test_priority2_session_unregistered_self_returns_none(self): - """Priority 2: 会话只有未注册自转换的 vendor → None.""" - session_record = MagicMock() - session_record.provider_state = {"anthropic": {}} - assert ( - _RouteExecutor._determine_source_vendor("anthropic", None, session_record) - is None - ) - - def test_priority3_self_transition_when_registered(self): - """Priority 3: 首次请求 body 含 zhipu 产物且目标也是 zhipu → 命中自清理. 
- - 这是修复 「zhipu 400 + tool_results 偶发」 的核心兜底场景: - Claude Code 把上一轮 zhipu 响应原样回送, 命中 zhipu 主 tier 时 - 可识别并应用自清理通道。 - """ - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "srvtoolu_x", - "name": "bash", - "input": {}, - }, - ], - }, - ], - } - assert ( - _RouteExecutor._determine_source_vendor("zhipu", None, None, body) - == "zhipu" - ) - - # ── _prepare_body_for_tier 转换通道应用测试 ──────────────────────── @@ -1972,23 +1938,6 @@ def test_returns_body_when_no_source_vendor(self): assert result is body assert len(result["messages"][0]["content"]) == 2 - def test_applies_anthropic_to_zhipu_transition(self): - """anthropic → zhipu 已注册转换,应清理 thinking blocks.""" - tier = MagicMock() - tier.name = "zhipu" - - exec_inst = _executor([]) - body = self._body_with_thinking() - result = exec_inst._prepare_body_for_tier(body, tier, source_vendor="anthropic") - - # thinking blocks 应被剥离 - assert result is not body - assert all( - b.get("type") not in ("thinking", "redacted_thinking") - for b in result["messages"][0]["content"] - ) - assert len(result["messages"][0]["content"]) >= 1 - def test_returns_body_for_unknown_tier(self): """未知 tier(无注册转换)→ 原样返回.""" tier = MagicMock() @@ -1999,105 +1948,3 @@ def test_returns_body_for_unknown_tier(self): result = exec_inst._prepare_body_for_tier(body, tier, source_vendor="zhipu") assert result is body - - -# ── _prepare_body_for_tier 自转换通道测试 ─────────────────────────── - - -class TestPrepareBodyForTierSelfTransition: - """验证 zhipu → zhipu 自转换通道在 _prepare_body_for_tier 中的应用行为.""" - - def test_applies_zhipu_self_cleanup(self): - """source=zhipu, target=zhipu → 剥离 server_tool_use_delta + tool pairing.""" - tier = MagicMock() - tier.name = "zhipu" - - body = { - "model": "claude-opus-4-6", - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "server_tool_use_delta", "partial_json": "{}"}, - { - "type": "tool_use", - "id": "srvtoolu_a", - "name": "bash", - "input": {}, - }, - { 
- "type": "tool_result", - "tool_use_id": "srvtoolu_a", - "content": "ok", - }, - ], - }, - ], - } - exec_inst = _executor([]) - result = exec_inst._prepare_body_for_tier(body, tier, source_vendor="zhipu") - - # 深拷贝(不修改原始 body) - assert result is not body - assert len(body["messages"][0]["content"]) == 3 - - # delta 块被剥离, tool_result 被搬迁出 assistant - assistant_content = result["messages"][0]["content"] - assert all( - b.get("type") not in ("server_tool_use_delta", "tool_result") - for b in assistant_content - ) - # tool_result 已搬到下一个 user 消息 - assert result["messages"][1]["role"] == "user" - assert any( - b.get("type") == "tool_result" and b.get("tool_use_id") == "srvtoolu_a" - for b in result["messages"][1]["content"] - ) - - def test_self_cleanup_preserves_srvtoolu_ids(self): - """回归保护: 自清理通道不得改写 zhipu 原生 srvtoolu_* ID.""" - tier = MagicMock() - tier.name = "zhipu" - - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "srvtoolu_keep_me", - "name": "bash", - "input": {}, - }, - { - "type": "thinking", - "thinking": "...", - "signature": "zhipu_sig", - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "srvtoolu_keep_me", - "content": "ok", - }, - ], - }, - ], - } - exec_inst = _executor([]) - result = exec_inst._prepare_body_for_tier(body, tier, source_vendor="zhipu") - - # ID 与 server_tool_use 类型必须保留 - first_block = result["messages"][0]["content"][0] - assert first_block["id"] == "srvtoolu_keep_me" - assert first_block["type"] == "server_tool_use" - # thinking signature 也必须保留 - thinking_block = next( - b for b in result["messages"][0]["content"] if b.get("type") == "thinking" - ) - assert thinking_block["signature"] == "zhipu_sig" diff --git a/tests/test_session_aware.py b/tests/test_session_aware.py new file mode 100644 index 0000000..0c08449 --- /dev/null +++ b/tests/test_session_aware.py @@ -0,0 +1,592 @@ +"""Session-Aware Architecture 测试. + +覆盖矩阵: +1. 
新装库 schema 包含 session_key 列与索引; +2. 旧库增量迁移幂等 (重复 init 不抛错); +3. log() 写入 session_key 可回读; +4. query_recent_sessions() 聚合/排序/过滤正确性; +5. query_session_profile() 单会话查询; +6. SessionPolicyResolver 精确匹配/通配匹配/无匹配; +7. _resolve_effective_tiers 策略 tier 重排逻辑. +""" + +from __future__ import annotations + +import aiosqlite +import pytest +import pytest_asyncio + +from coding.proxy.config.session_policy import ( + SessionPoliciesConfig, + SessionPolicy, + SessionPolicyMatch, +) +from coding.proxy.logging.db import TokenLogger +from coding.proxy.routing.session_policy import SessionPolicyResolver + +# ── Fixture ────────────────────────────────────────────────── + + +@pytest_asyncio.fixture +async def logger(tmp_path): + tl = TokenLogger(tmp_path / "test.db") + await tl.init() + yield tl + await tl.close() + + +# ── 1. 新装库 schema ───────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_fresh_schema_contains_session_key(logger): + cursor = await logger._db.execute("PRAGMA table_info(usage_log)") + rows = await cursor.fetchall() + columns = {row["name"] for row in rows} + assert "session_key" in columns + + +@pytest.mark.asyncio +async def test_fresh_schema_contains_session_key_index(logger): + cursor = await logger._db.execute( + "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='usage_log'" + ) + rows = await cursor.fetchall() + names = {row["name"] for row in rows} + assert "idx_usage_session_key" in names + + +# ── 2. 
旧库迁移幂等 ───────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_migration_from_legacy_db_adds_session_key(tmp_path): + db_path = tmp_path / "legacy.db" + async with aiosqlite.connect(str(db_path)) as db: + await db.executescript( + """ + CREATE TABLE usage_log ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + ts TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), + vendor TEXT NOT NULL, + model_requested TEXT NOT NULL, + model_served TEXT NOT NULL, + input_tokens INTEGER DEFAULT 0, + output_tokens INTEGER DEFAULT 0, + cache_creation_tokens INTEGER DEFAULT 0, + cache_read_tokens INTEGER DEFAULT 0, + duration_ms INTEGER DEFAULT 0, + success BOOLEAN NOT NULL DEFAULT 1, + failover BOOLEAN NOT NULL DEFAULT 0, + failover_from TEXT DEFAULT NULL, + request_id TEXT DEFAULT '', + client_category TEXT NOT NULL DEFAULT 'cc', + operation TEXT NOT NULL DEFAULT '', + endpoint TEXT NOT NULL DEFAULT '', + extra_usage_json TEXT NOT NULL DEFAULT '{}' + ); + """ + ) + await db.execute( + "INSERT INTO usage_log (vendor, model_requested, model_served) " + "VALUES ('anthropic','claude-sonnet','claude-sonnet')" + ) + await db.commit() + + tl = TokenLogger(db_path) + await tl.init() + try: + cursor = await tl._db.execute("PRAGMA table_info(usage_log)") + rows = await cursor.fetchall() + columns = {row["name"] for row in rows} + assert "session_key" in columns + + cursor = await tl._db.execute( + "SELECT session_key FROM usage_log WHERE vendor='anthropic'" + ) + row = await cursor.fetchone() + assert row["session_key"] == "" + finally: + await tl.close() + + tl2 = TokenLogger(db_path) + await tl2.init() + await tl2.close() + + +# ── 3. 
log() 写入 session_key ─────────────────────────────── + + +@pytest.mark.asyncio +async def test_log_persists_session_key(logger): + await logger.log( + vendor="anthropic", + model_requested="claude-sonnet", + model_served="claude-sonnet", + input_tokens=100, + output_tokens=50, + session_key="test-session-123", + ) + cursor = await logger._db.execute( + "SELECT session_key FROM usage_log WHERE vendor='anthropic'" + ) + row = await cursor.fetchone() + assert row["session_key"] == "test-session-123" + + +@pytest.mark.asyncio +async def test_log_default_session_key_empty(logger): + await logger.log( + vendor="anthropic", + model_requested="claude-sonnet", + model_served="claude-sonnet", + input_tokens=50, + output_tokens=10, + ) + cursor = await logger._db.execute("SELECT session_key FROM usage_log") + row = await cursor.fetchone() + assert row["session_key"] == "" + + +# ── 4. query_recent_sessions ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_query_recent_sessions_basic(logger): + for i in range(3): + await logger.log( + vendor="anthropic", + model_requested="claude-sonnet", + model_served="claude-sonnet", + input_tokens=100 * (i + 1), + output_tokens=50 * (i + 1), + session_key="session-alpha", + duration_ms=100 + i * 50, + ) + await logger.log( + vendor="copilot", + model_requested="claude-sonnet", + model_served="gpt-4o", + input_tokens=200, + output_tokens=80, + session_key="session-beta", + duration_ms=150, + ) + await logger.log( + vendor="zhipu", + model_requested="claude-sonnet", + model_served="glm-5v-turbo", + input_tokens=50, + output_tokens=20, + session_key="", # 空 key,应被排除 + ) + + sessions = await logger.query_recent_sessions(limit=10, hours=1) + assert len(sessions) == 2 + + alpha = next(s for s in sessions if s["session_key"] == "session-alpha") + assert alpha["total_requests"] == 3 + assert alpha["total_tokens"] == (100 + 200 + 300) + (50 + 100 + 150) + assert alpha["total_input"] == 100 + 200 + 300 + assert 
alpha["total_output"] == 50 + 100 + 150 + assert "claude-sonnet" in alpha["models"] + assert "anthropic" in alpha["vendors"] + assert alpha["success_rate"] == 100.0 + assert "cc" in alpha["client_categories"] + + +@pytest.mark.asyncio +async def test_query_recent_sessions_excludes_empty_key(logger): + await logger.log( + vendor="anthropic", + model_requested="claude-sonnet", + model_served="claude-sonnet", + input_tokens=100, + session_key="", + ) + await logger.log( + vendor="anthropic", + model_requested="claude-sonnet", + model_served="claude-sonnet", + input_tokens=100, + session_key="valid-session", + ) + sessions = await logger.query_recent_sessions(limit=10, hours=1) + assert len(sessions) == 1 + assert sessions[0]["session_key"] == "valid-session" + + +@pytest.mark.asyncio +async def test_query_recent_sessions_limit(logger): + for i in range(5): + await logger.log( + vendor="anthropic", + model_requested="claude-sonnet", + model_served="claude-sonnet", + input_tokens=100, + session_key=f"session-{i}", + ) + sessions = await logger.query_recent_sessions(limit=3, hours=1) + assert len(sessions) == 3 + + +@pytest.mark.asyncio +async def test_query_recent_sessions_success_rate(logger): + await logger.log( + vendor="anthropic", + model_requested="m", + model_served="m", + session_key="s1", + success=True, + ) + await logger.log( + vendor="anthropic", + model_requested="m", + model_served="m", + session_key="s1", + success=False, + ) + await logger.log( + vendor="anthropic", + model_requested="m", + model_served="m", + session_key="s1", + success=True, + ) + sessions = await logger.query_recent_sessions(limit=10, hours=1) + assert len(sessions) == 1 + assert abs(sessions[0]["success_rate"] - (2 / 3 * 100)) < 0.01 + + +# ── 5. 
query_session_profile ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_query_session_profile_found(logger): + await logger.log( + vendor="anthropic", + model_requested="m", + model_served="m", + input_tokens=100, + output_tokens=50, + session_key="profile-test", + ) + profile = await logger.query_session_profile("profile-test") + assert profile is not None + assert profile["session_key"] == "profile-test" + assert profile["total_requests"] == 1 + + +@pytest.mark.asyncio +async def test_query_session_profile_not_found(logger): + profile = await logger.query_session_profile("nonexistent") + assert profile is None + + +# ── 6. SessionPolicyResolver ──────────────────────────────── + + +def _make_policy(name, keys=None, category=None, tiers=None): + return SessionPolicy( + name=name, + match=SessionPolicyMatch(session_keys=keys or [], client_category=category), + tiers=tiers or [], + ) + + +def test_resolve_by_session_key(): + p1 = _make_policy("vip", keys=["key-1", "key-2"], tiers=["anthropic"]) + p2 = _make_policy("cc-default", category="cc", tiers=["copilot"]) + resolver = SessionPolicyResolver([p1, p2]) + + assert resolver.resolve("key-1") is p1 + assert resolver.resolve("key-2") is p1 + assert resolver.resolve("unknown-key", "cc") is p2 + assert resolver.resolve("unknown-key", "api") is None + + +def test_resolve_key_priority_over_category(): + p1 = _make_policy("cc-default", category="cc", tiers=["copilot"]) + p2 = _make_policy("vip", keys=["vip-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p1, p2]) + + result = resolver.resolve("vip-key", "cc") + assert result is p2 # exact key match takes priority + + +def test_resolve_no_match(): + resolver = SessionPolicyResolver([]) + assert resolver.resolve("any-key") is None + + +def test_resolve_first_match_wins(): + p1 = _make_policy("first", keys=["dup-key"], tiers=["anthropic"]) + p2 = _make_policy("second", keys=["dup-key"], tiers=["zhipu"]) + resolver = SessionPolicyResolver([p1, p2]) + +
assert resolver.resolve("dup-key") is p1 + + +def test_empty_resolver(): + resolver = SessionPolicyResolver() + assert resolver.resolve("any") is None + + +# ── 7. _resolve_effective_tiers (via executor) ────────────── + + +def test_resolve_effective_tiers_with_policy(): + from coding.proxy.routing.executor import _RouteExecutor + from coding.proxy.routing.tier import VendorTier + from coding.proxy.vendors.base import BaseVendor + + class FakeVendor(BaseVendor): + def __init__(self, name): + self._name = name + + def get_name(self): + return self._name + + async def _prepare_request(self, body, headers): + return body, headers + + async def send_message_stream(self, body, headers): + yield b"", "" + + async def send_message(self, body, headers): + return None + + def supports_request(self, caps): + return True, [] + + def map_model(self, model): + return model + + tiers = [ + VendorTier(vendor=FakeVendor("zhipu")), + VendorTier(vendor=FakeVendor("anthropic")), + VendorTier(vendor=FakeVendor("copilot")), + ] + + policy = _make_policy("vip", keys=["vip-key"], tiers=["anthropic", "copilot"]) + resolver = SessionPolicyResolver([policy]) + + executor = _RouteExecutor( + router=None, + tiers=tiers, + usage_recorder=None, + session_manager=None, + session_policy_resolver=resolver, + ) + + effective = executor._resolve_effective_tiers("vip-key") + names = [t.name for t in effective] + assert names == ["anthropic", "copilot", "zhipu"] + + +def test_resolve_effective_tiers_no_policy(): + from coding.proxy.routing.executor import _RouteExecutor + from coding.proxy.routing.tier import VendorTier + from coding.proxy.vendors.base import BaseVendor + + class FakeVendor(BaseVendor): + def __init__(self, name): + self._name = name + + def get_name(self): + return self._name + + async def _prepare_request(self, body, headers): + return body, headers + + async def send_message_stream(self, body, headers): + yield b"", "" + + async def send_message(self, body, headers): + return None 
+ + def supports_request(self, caps): + return True, [] + + def map_model(self, model): + return model + + tiers = [ + VendorTier(vendor=FakeVendor("zhipu")), + VendorTier(vendor=FakeVendor("anthropic")), + ] + + executor = _RouteExecutor( + router=None, + tiers=tiers, + usage_recorder=None, + session_manager=None, + ) + + effective = executor._resolve_effective_tiers("unknown-key") + assert effective is tiers # returns the same list reference + + +# ── 8. SessionPoliciesConfig integration ─────────────────────────── + + +def test_config_default_empty(): + config = SessionPoliciesConfig() + assert config.policies == [] + + +def test_config_parse(): + config = SessionPoliciesConfig( + policies=[ + { + "name": "vip", + "match": {"session_keys": ["key-1"]}, + "tiers": ["anthropic", "copilot"], + } + ] + ) + assert len(config.policies) == 1 + assert config.policies[0].name == "vip" + assert config.policies[0].match.session_keys == ["key-1"] + assert config.policies[0].tiers == ["anthropic", "copilot"] + + +# ── 9. SessionPolicyResolver runtime mutability ──────────────────────── + + +def test_runtime_upsert_and_resolve(): + resolver = SessionPolicyResolver() + assert resolver.resolve("my-session") is None + + resolver.upsert("my-session", ["anthropic", "copilot"]) + policy = resolver.resolve("my-session") + assert policy is not None + assert policy.tiers == ["anthropic", "copilot"] + assert policy.name.startswith("runtime:") + + +def test_runtime_upsert_overwrites(): + resolver = SessionPolicyResolver() + resolver.upsert("my-session", ["anthropic"]) + resolver.upsert("my-session", ["copilot", "zhipu"]) + policy = resolver.resolve("my-session") + assert policy.tiers == ["copilot", "zhipu"] + + +def test_runtime_remove(): + resolver = SessionPolicyResolver() + resolver.upsert("my-session", ["anthropic"]) + assert resolver.remove("my-session") is True + assert resolver.resolve("my-session") is None + assert resolver.remove("my-session") is False + + +def test_runtime_remove_does_not_affect_config_policy(): + p =
_make_policy("config-policy", keys=["config-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + # Cannot remove config-driven policy via runtime API + assert resolver.remove("config-key") is False + assert resolver.resolve("config-key") is p + + +def test_runtime_upsert_overrides_config_policy(): + p = _make_policy("config-policy", keys=["shared-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + resolver.upsert("shared-key", ["copilot"]) + # Runtime binding takes precedence (replaces in key_index) + policy = resolver.resolve("shared-key") + assert policy.tiers == ["copilot"] + assert policy.name.startswith("runtime:") + + +def test_runtime_remove_restores_config_policy(): + p = _make_policy("config-policy", keys=["shared-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + resolver.upsert("shared-key", ["copilot"]) + assert resolver.resolve("shared-key").tiers == ["copilot"] + # Removing the runtime binding should restore the original config policy + assert resolver.remove("shared-key") is True + restored = resolver.resolve("shared-key") + assert restored is p + assert restored.tiers == ["anthropic"] + + +def test_list_runtime_bindings(): + resolver = SessionPolicyResolver() + p = _make_policy("config-policy", keys=["config-key"], tiers=["anthropic"]) + resolver = SessionPolicyResolver([p]) + resolver.upsert("runtime-1", ["copilot"]) + resolver.upsert("runtime-2", ["zhipu", "anthropic"]) + + bindings = resolver.list_runtime_bindings() + assert len(bindings) == 2 + keys = {b["session_key"] for b in bindings} + assert keys == {"runtime-1", "runtime-2"} + # Config-driven policy should not appear + assert "config-key" not in keys + + +def test_runtime_upsert_integrates_with_executor(): + from coding.proxy.routing.executor import _RouteExecutor + from coding.proxy.routing.tier import VendorTier + from coding.proxy.vendors.base import BaseVendor + + class FakeVendor(BaseVendor): + def __init__(self, name): + self._name = name + + def get_name(self): + return self._name +
+ async def _prepare_request(self, body, headers): + return body, headers + + async def send_message_stream(self, body, headers): + yield b"", "" + + async def send_message(self, body, headers): + return None + + def supports_request(self, caps): + return True, [] + + def map_model(self, model): + return model + + tiers = [ + VendorTier(vendor=FakeVendor("zhipu")), + VendorTier(vendor=FakeVendor("anthropic")), + VendorTier(vendor=FakeVendor("copilot")), + ] + + resolver = SessionPolicyResolver() + executor = _RouteExecutor( + router=None, + tiers=tiers, + usage_recorder=None, + session_manager=None, + session_policy_resolver=resolver, + ) + + # Before binding: default order + assert [t.name for t in executor._resolve_effective_tiers("test")] == [ + "zhipu", + "anthropic", + "copilot", + ] + + # After binding: anthropic first, copilot second, zhipu last + resolver.upsert("test", ["anthropic", "copilot"]) + assert [t.name for t in executor._resolve_effective_tiers("test")] == [ + "anthropic", + "copilot", + "zhipu", + ] + + # After unbind: back to default + resolver.remove("test") + assert [t.name for t in executor._resolve_effective_tiers("test")] == [ + "zhipu", + "anthropic", + "copilot", + ] diff --git a/tests/test_vendor_channels.py b/tests/test_vendor_channels.py index a105b42..774b85a 100644 --- a/tests/test_vendor_channels.py +++ b/tests/test_vendor_channels.py @@ -4,8 +4,6 @@ - zhipu → anthropic 转换 (prepare_zhipu_to_anthropic) - zhipu → copilot 转换 (prepare_zhipu_to_copilot) - copilot → zhipu 转换 (prepare_copilot_to_zhipu) -- zhipu → zhipu 自清理 (prepare_zhipu_self_cleanup) -- anthropic → zhipu 转换 (prepare_anthropic_to_zhipu) - 共享辅助函数 (strip_thinking_blocks, _strip_cache_control, _remove_vendor_blocks, _rewrite_srvtoolu_ids, enforce_anthropic_tool_pairing, infer_source_vendor_from_body) - 转换注册表 (VENDOR_TRANSITIONS, get_transition_channel) @@ -17,16 +15,13 @@ from coding.proxy.convert.vendor_channels import ( VENDOR_TRANSITIONS, - _enforce_pairing_sanity_pass, 
_remove_vendor_blocks, _rewrite_srvtoolu_ids, _strip_cache_control, enforce_anthropic_tool_pairing, get_transition_channel, infer_source_vendor_from_body, - prepare_anthropic_to_zhipu, prepare_copilot_to_zhipu, - prepare_zhipu_self_cleanup, prepare_zhipu_to_anthropic, prepare_zhipu_to_copilot, strip_thinking_blocks, @@ -698,302 +693,6 @@ def test_strips_redacted_thinking(self): ] -# ── zhipu → zhipu 自清理通道测试 ───────────────────────────────── - - -class TestZhipuSelfCleanupChannel: - """prepare_zhipu_self_cleanup 单元测试. - - 自清理通道的核心契约: **仅** 修复 zhipu 自身拒绝的产物 - (server_tool_use_delta, 错位 tool_result), 保留所有 zhipu 原生支持 - 的特性 (srvtoolu_* ID, thinking signature, cache_control, 顶层 thinking). - """ - - def test_strips_server_tool_use_delta(self): - body = { - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "text", "text": "thinking..."}, - {"type": "server_tool_use_delta", "partial_json": "{}"}, - ], - }, - ], - } - prepared, adaptations = prepare_zhipu_self_cleanup(body) - content = prepared["messages"][0]["content"] - assert all(b.get("type") != "server_tool_use_delta" for b in content) - assert any("zhipu_vendor_blocks" in a for a in adaptations) - - def test_relocates_misplaced_tool_result(self): - """assistant 内联 tool_result 应被搬迁到下一个 user 消息.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": "srvtoolu_a", - "name": "bash", - "input": {}, - }, - { - "type": "tool_result", - "tool_use_id": "srvtoolu_a", - "content": "ok", - }, - ], - }, - {"role": "user", "content": []}, - ], - } - prepared, adaptations = prepare_zhipu_self_cleanup(body) - - # assistant 消息中应不再包含 tool_result - assistant_content = prepared["messages"][0]["content"] - assert all(b.get("type") != "tool_result" for b in assistant_content) - # tool_result 已搬到下一个 user 消息 - user_content = prepared["messages"][1]["content"] - assert any( - b.get("type") == "tool_result" and b.get("tool_use_id") == "srvtoolu_a" - for b in user_content - ) - 
assert "misplaced_tool_result_relocated" in adaptations - - def test_preserves_srvtoolu_ids(self): - """zhipu 原生 srvtoolu_* ID 与 server_tool_use 类型必须保留.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "srvtoolu_xyz", - "name": "bash", - "input": {}, - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "srvtoolu_xyz", - "content": "ok", - }, - ], - }, - ], - } - prepared, adaptations = prepare_zhipu_self_cleanup(body) - - block = prepared["messages"][0]["content"][0] - assert block["id"] == "srvtoolu_xyz" - assert block["type"] == "server_tool_use" - # 无任何 srvtoolu 改写或 server_tool_use 类型纠正 - assert not any("srvtoolu_ids" in a for a in adaptations) - - def test_preserves_thinking_blocks(self): - """zhipu 自签 thinking signature 必须保留.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "thinking", - "thinking": "let me think", - "signature": "zhipu_sig_abc", - }, - {"type": "text", "text": "answer"}, - ], - }, - ], - } - prepared, adaptations = prepare_zhipu_self_cleanup(body) - content = prepared["messages"][0]["content"] - assert any(b.get("type") == "thinking" for b in content) - assert not any("thinking_blocks" in a for a in adaptations) - - def test_preserves_cache_control(self): - """cache_control 字段必须保留 (GLM 原生支持, 已实证 cache_read).""" - body = { - "system": [ - { - "type": "text", - "text": "system prompt", - "cache_control": {"type": "ephemeral"}, - }, - ], - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": "hi", - "cache_control": {"type": "ephemeral"}, - }, - ], - }, - ], - "tools": [ - { - "name": "bash", - "description": "", - "input_schema": {"type": "object"}, - "cache_control": {"type": "ephemeral"}, - }, - ], - } - prepared, adaptations = prepare_zhipu_self_cleanup(body) - assert prepared["system"][0].get("cache_control") == {"type": "ephemeral"} - assert 
prepared["messages"][0]["content"][0].get("cache_control") == { - "type": "ephemeral" - } - assert prepared["tools"][0].get("cache_control") == {"type": "ephemeral"} - assert not any("cache_control" in a for a in adaptations) - - def test_preserves_thinking_param(self): - """顶层 thinking / extended_thinking 参数必须保留.""" - body = { - "messages": [], - "thinking": {"type": "enabled", "budget_tokens": 5000}, - "extended_thinking": {"foo": "bar"}, - } - prepared, _ = prepare_zhipu_self_cleanup(body) - assert prepared["thinking"] == { - "type": "enabled", - "budget_tokens": 5000, - } - assert prepared["extended_thinking"] == {"foo": "bar"} - - def test_idempotency(self): - """二次调用幂等: 已清洗的 body 不再产生新 adaptations.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": "srvtoolu_a", - "name": "bash", - "input": {}, - }, - { - "type": "tool_result", - "tool_use_id": "srvtoolu_a", - "content": "ok", - }, - {"type": "server_tool_use_delta", "partial_json": "{}"}, - ], - }, - ], - } - first_pass, first_adapt = prepare_zhipu_self_cleanup(body) - assert first_adapt # 首次调用应产生变换 - _, second_adapt = prepare_zhipu_self_cleanup(first_pass) - assert second_adapt == [] - - def test_noop_when_clean(self): - """纯净 body (无 zhipu 产物) 应不产生任何 adaptations.""" - body = { - "messages": [ - {"role": "user", "content": "hello"}, - { - "role": "assistant", - "content": [{"type": "text", "text": "hi"}], - }, - ], - } - original = copy.deepcopy(body) - prepared, adaptations = prepare_zhipu_self_cleanup(body) - assert adaptations == [] - assert prepared == original - - def test_does_not_mutate_input(self): - """通道返回深拷贝, 输入 body 必须保持原状.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "server_tool_use_delta", "partial_json": "{}"}, - ], - }, - ], - } - original = copy.deepcopy(body) - prepare_zhipu_self_cleanup(body) - assert body == original - - def test_combined_artifacts(self): - """端到端: server_tool_use_delta 被剥, 
server_tool_use 保留, 错位 tool_result 搬迁. - - 典型场景: Claude Code 的客户端工具 (Bash/Read 等) 以 ``tool_use`` 形式 - emit, 其错位的 ``tool_result`` 应被重定位; zhipu 原生 ``server_tool_use`` - 块不需要客户端 tool_result, 仅需保留原状. - """ - body = { - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "server_tool_use_delta", "partial_json": "{}"}, - { - "type": "server_tool_use", - "id": "srvtoolu_native", - "name": "web_search", - "input": {}, - }, - { - "type": "tool_use", - "id": "toolu_bash_001", - "name": "bash", - "input": {"command": "ls"}, - }, - { - "type": "tool_result", - "tool_use_id": "toolu_bash_001", - "content": "ok", - }, - ], - }, - ], - } - prepared, adaptations = prepare_zhipu_self_cleanup(body) - - assistant_content = prepared["messages"][0]["content"] - # delta 被剥离 - assert all(b.get("type") != "server_tool_use_delta" for b in assistant_content) - # 错位 tool_result 被搬出 assistant - assert all(b.get("type") != "tool_result" for b in assistant_content) - # server_tool_use 与其 srvtoolu_* ID 完整保留 - srv_block = next( - b for b in assistant_content if b.get("type") == "server_tool_use" - ) - assert srv_block["id"] == "srvtoolu_native" - # tool_use ID 同样保留 - tool_use_block = next( - b for b in assistant_content if b.get("type") == "tool_use" - ) - assert tool_use_block["id"] == "toolu_bash_001" - # 后续 user 消息已被插入并包含 tool_result - assert prepared["messages"][1]["role"] == "user" - assert any( - b.get("type") == "tool_result" and b.get("tool_use_id") == "toolu_bash_001" - for b in prepared["messages"][1]["content"] - ) - # 关键 adaptation 标签均出现 - assert any("zhipu_vendor_blocks" in a for a in adaptations) - assert "misplaced_tool_result_relocated" in adaptations - - # ── 转换注册表测试 ──────────────────────────────────────────── @@ -1004,9 +703,7 @@ def test_all_transitions_registered(self): assert ("zhipu", "anthropic") in VENDOR_TRANSITIONS assert ("zhipu", "copilot") in VENDOR_TRANSITIONS assert ("copilot", "zhipu") in VENDOR_TRANSITIONS - assert ("zhipu", "zhipu") in 
VENDOR_TRANSITIONS - assert ("anthropic", "zhipu") in VENDOR_TRANSITIONS - assert len(VENDOR_TRANSITIONS) == 5 + assert len(VENDOR_TRANSITIONS) == 3 def test_get_transition_channel_returns_function(self): assert ( @@ -1014,18 +711,12 @@ def test_get_transition_channel_returns_function(self): ) assert get_transition_channel("zhipu", "copilot") is prepare_zhipu_to_copilot assert get_transition_channel("copilot", "zhipu") is prepare_copilot_to_zhipu - assert get_transition_channel("zhipu", "zhipu") is prepare_zhipu_self_cleanup - assert ( - get_transition_channel("anthropic", "zhipu") is prepare_anthropic_to_zhipu - ) def test_get_transition_channel_returns_none_for_unregistered(self): + assert get_transition_channel("anthropic", "zhipu") is None assert get_transition_channel("copilot", "anthropic") is None assert get_transition_channel("unknown", "target") is None assert get_transition_channel("antigravity", "copilot") is None - # 未注册的同 vendor 自转换仍返回 None - assert get_transition_channel("anthropic", "anthropic") is None - assert get_transition_channel("copilot", "copilot") is None def test_transition_functions_share_signature(self): body = {"messages": []} @@ -1056,8 +747,7 @@ def test_copilot_to_zhipu_removes_thinking_param_zhipu_to_copilot_preserves(self assert "thinking" in zhipu_to_copilot_result assert "removed_thinking_param" not in zhipu_to_copilot_adapt - def test_cross_vendor_transitions_strip_thinking_blocks(self): - """跨 vendor 通道一律剥离 thinking blocks(自清理通道刻意保留,故排除).""" + def test_all_transitions_strip_thinking_blocks(self): body = { "messages": [ { @@ -1070,9 +760,6 @@ def test_cross_vendor_transitions_strip_thinking_blocks(self): ], } for key, fn in VENDOR_TRANSITIONS.items(): - if key[0] == key[1]: - # 自转换通道(如 zhipu→zhipu)保留 thinking signature,跳过 - continue result, adaptations = fn(body) assert result["messages"][0]["content"] == [ {"type": "text", "text": "hi"} @@ -1140,48 +827,6 @@ def test_handles_string_content(self): removed = 
_remove_vendor_blocks(body, {"whatever"}) assert removed == 0 - def test_inserts_placeholder_when_all_blocks_stripped(self): - """assistant 消息仅含 vendor 块时插入占位 text block.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "toolu_1", - "name": "ws", - "input": {}, - }, - ], - }, - ], - } - removed = _remove_vendor_blocks(body, {"server_tool_use"}) - assert removed == 1 - assert body["messages"][0]["content"] == [ - {"type": "text", "text": "[vendor_block_removed]"}, - ] - - def test_does_not_mutate_unrelated_messages(self): - """仅含 vendor 块的消息被修改,其他消息不受影响.""" - body = { - "messages": [ - {"role": "user", "content": [{"type": "text", "text": "hi"}]}, - { - "role": "assistant", - "content": [ - {"type": "server_tool_use_delta", "partial_json": "{}"}, - ], - }, - ], - } - _remove_vendor_blocks(body, {"server_tool_use_delta"}) - assert body["messages"][0]["content"] == [{"type": "text", "text": "hi"}] - assert body["messages"][1]["content"] == [ - {"type": "text", "text": "[vendor_block_removed]"}, - ] - # ── _rewrite_srvtoolu_ids 单元测试 ───────────────────────────────── @@ -1363,32 +1008,22 @@ def test_skips_non_matching_user_tool_result(self): assert count == 0 assert body["messages"][0]["content"][0]["tool_use_id"] == "toolu_other" - def test_rewrites_inline_tool_result_before_tool_use(self): - """块顺序鲁棒性回归保护: inline tool_result 在 tool_use 之前时仍正确改名. - GLM-5 偶发将 inline tool_result 输出在本消息 tool_use 之前 (流式断片). - 若 _rewrite 用单遍扫描, 处理 inline tool_result 时 id_map 尚未填入对应 - srvtoolu_* → 漏改名 → enforce 阶段 extracted dict key 与 tool_use_ids - 错位 → dangling tool_use 漏报 → anthropic 报 'tool_use ids without - tool_result blocks immediately after'. 
+# ── infer_source_vendor_from_body 单元测试 ───────────────────────── + + +class TestInferSourceVendorFromBody: + """infer_source_vendor_from_body 内容感知启发式推断.""" - 修复后采用两遍扫描: 先全量收集 id_map (仅处理 tool_use), 再统一改写 - 所有 tool_result.tool_use_id 引用。 - """ + def test_detects_zhipu_by_srvtoolu_id(self): body = { "messages": [ { "role": "assistant", "content": [ - # inline tool_result 在 server_tool_use 之前! - { - "type": "tool_result", - "tool_use_id": "srvtoolu_X", - "content": "inline-X", - }, { - "type": "server_tool_use", - "id": "srvtoolu_X", + "type": "tool_use", + "id": "srvtoolu_abc", "name": "bash", "input": {}, }, @@ -1396,21 +1031,9 @@ def test_rewrites_inline_tool_result_before_tool_use(self): }, ], } - count, id_map = _rewrite_srvtoolu_ids(body) - assert count == 1 - new_id = id_map["srvtoolu_X"] - assert new_id.startswith("toolu_normalized_") - # 关键断言: inline tool_result 也被改名 (即使在 tool_use 之前) - inline_result = body["messages"][0]["content"][0] - assert inline_result["type"] == "tool_result" - assert inline_result["tool_use_id"] == new_id - # tool_use 也被改名 - tool_use_block = body["messages"][0]["content"][1] - assert tool_use_block["type"] == "tool_use" - assert tool_use_block["id"] == new_id - - def test_rewrites_tool_result_in_assistant_role(self): - """assistant role 内的 tool_result 也应被改名 (Pass 2 全量扫描所有消息).""" + assert infer_source_vendor_from_body(body) == "zhipu" + + def test_detects_zhipu_by_server_tool_use_type(self): body = { "messages": [ { @@ -1418,82 +1041,25 @@ def test_rewrites_tool_result_in_assistant_role(self): "content": [ { "type": "server_tool_use", - "id": "srvtoolu_M", + "id": "toolu_any", "name": "bash", "input": {}, }, ], }, + ], + } + assert infer_source_vendor_from_body(body) == "zhipu" + + def test_detects_zhipu_by_server_tool_use_delta(self): + body = { + "messages": [ { - "role": "assistant", # 异常: 连续 assistant + "role": "assistant", "content": [ { - "type": "tool_result", - "tool_use_id": "srvtoolu_M", - "content": "M-result", - }, - ], - 
}, - ], - } - count, id_map = _rewrite_srvtoolu_ids(body) - new_id = id_map["srvtoolu_M"] - # 后续 assistant 内的 tool_result 也被改名 - assert body["messages"][1]["content"][0]["tool_use_id"] == new_id - - -# ── infer_source_vendor_from_body 单元测试 ───────────────────────── - - -class TestInferSourceVendorFromBody: - """infer_source_vendor_from_body 内容感知启发式推断.""" - - def test_detects_zhipu_by_srvtoolu_id(self): - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": "srvtoolu_abc", - "name": "bash", - "input": {}, - }, - ], - }, - ], - } - assert infer_source_vendor_from_body(body) == "zhipu" - - def test_detects_zhipu_by_server_tool_use_with_non_standard_id(self): - """server_tool_use + 非 toolu_/srvtoolu_ ID → 兜底归 zhipu.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "custom_non_standard", - "name": "bash", - "input": {}, - }, - ], - }, - ], - } - assert infer_source_vendor_from_body(body) == "zhipu" - - def test_detects_zhipu_by_server_tool_use_delta(self): - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use_delta", - "partial_json": "{}", + "type": "server_tool_use_delta", + "partial_json": "{}", }, ], }, @@ -1580,44 +1146,6 @@ def test_handles_string_content(self): } assert infer_source_vendor_from_body(body) is None - def test_detects_anthropic_by_server_tool_use_with_toolu_id(self): - """server_tool_use + toolu_* ID(Anthropic beta 功能产物)→ anthropic.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "toolu_web_search_1", - "name": "web_search", - "input": {"query": "test"}, - }, - ], - }, - ], - } - assert infer_source_vendor_from_body(body) == "anthropic" - - def test_zhipu_srvtoolu_takes_priority_over_anthropic_detection(self): - """srvtoolu_* ID 优先识别为 zhipu(即使 block type 为 server_tool_use).""" - body = { - "messages": [ - { - "role": "assistant", 
- "content": [ - { - "type": "server_tool_use", - "id": "srvtoolu_x", - "name": "bash", - "input": {}, - }, - ], - }, - ], - } - assert infer_source_vendor_from_body(body) == "zhipu" - # ── enforce_anthropic_tool_pairing 单元测试(从 test_request_normalizer.py 迁入) ─ @@ -2053,212 +1581,6 @@ def test_next_message_is_assistant_inserts_user(self): assert messages[1]["content"][0]["type"] == "tool_result" assert messages[2]["role"] == "assistant" - def test_sanity_check_does_not_false_fire_on_correctly_paired_messages(self): - """正常配对消息走完主循环后, sanity G 段不应误触发. - - 主循环 F 步已正确合成/搬迁所有 tool_result 时, sanity 视角下 next_user - 的 nu_result_ids 已覆盖全部 tool_use_ids, 走 ``if uid in nu_result_ids: - continue`` 分支, 不会重复合成占位、也不应打 ``pairing_sanity_repaired`` - 标签 → 验证 sanity 的幂等性 / 不重复合成保证。 - """ - messages = [ - {"role": "user", "content": "task"}, - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "toolu_a", "name": "bash", "input": {}}, - ], - }, - { - "role": "user", - "content": [ - {"type": "tool_result", "tool_use_id": "toolu_a", "content": "ok"}, - ], - }, - ] - _, fixes = _enforce_pairing(messages) - # 一切正常, sanity 不应介入 - assert "pairing_sanity_repaired" not in fixes - assert "orphaned_tool_use_repaired" not in fixes - assert "misplaced_tool_result_relocated" not in fixes - - -class TestEnforcePairingSanityPass: - """_enforce_pairing_sanity_pass 正向兜底路径单元测试. - - 主循环 F 步在当前实现下能覆盖所有 dangling tool_use, 因此 sanity 在公开 - ``enforce_anthropic_tool_pairing`` API 调用中不会被实际触发. 
抽出为独立 - helper 后可绕过主循环, 直接对兜底合成路径建立正向回归保护, 防止 G 段 - 被未来重构「优化掉」时静默失效。 - """ - - def test_synthesizes_is_error_for_dangling_tool_use(self): - """next_user 缺对应 tool_result 时, sanity 直接合成 is_error 占位.""" - messages = [ - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": "toolu_dangling", - "name": "bash", - "input": {}, - }, - ], - }, - {"role": "user", "content": []}, - ] - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - - assert sanity_synthesized == ["toolu_dangling"] - # next_user 已被注入 is_error 占位 - user_content = messages[1]["content"] - assert len(user_content) == 1 - placeholder = user_content[0] - assert placeholder["type"] == "tool_result" - assert placeholder["tool_use_id"] == "toolu_dangling" - assert placeholder["is_error"] is True - assert placeholder["content"] == "" - - def test_inserts_user_message_when_next_is_not_user(self): - """assistant 后无 user 消息时, sanity 应当插入空 user 再合成占位.""" - messages = [ - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "toolu_x", "name": "bash", "input": {}}, - ], - }, - # 没有后续消息 → sanity 应插入空 user 并合成占位 - ] - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - - assert sanity_synthesized == ["toolu_x"] - assert len(messages) == 2 - assert messages[1]["role"] == "user" - results = messages[1]["content"] - assert len(results) == 1 - assert results[0]["tool_use_id"] == "toolu_x" - assert results[0]["is_error"] is True - - def test_inserts_user_message_when_next_is_assistant(self): - """assistant 后紧跟另一个 assistant (非 user) 时, sanity 应插入空 user.""" - messages = [ - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "toolu_a", "name": "bash", "input": {}}, - ], - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "stray"}], - }, - ] - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - - assert sanity_synthesized == ["toolu_a"] - assert messages[1]["role"] == "user" # 新插入的空 user - assert messages[2]["role"] == 
"assistant" # 原 stray 后移 - assert messages[1]["content"][0]["tool_use_id"] == "toolu_a" - - def test_skips_when_tool_result_already_present(self): - """next_user 已含对应 tool_result 时不应重复合成.""" - messages = [ - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "toolu_a", "name": "bash", "input": {}}, - ], - }, - { - "role": "user", - "content": [ - {"type": "tool_result", "tool_use_id": "toolu_a", "content": "ok"}, - ], - }, - ] - original_user_content = list(messages[1]["content"]) - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - - assert sanity_synthesized == [] - assert messages[1]["content"] == original_user_content # 未被改动 - - def test_handles_string_content_in_next_user(self): - """next_user.content 是字符串时, sanity 先转为 text 块再合成占位.""" - messages = [ - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "toolu_a", "name": "bash", "input": {}}, - ], - }, - {"role": "user", "content": "free text"}, - ] - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - - assert sanity_synthesized == ["toolu_a"] - user_content = messages[1]["content"] - assert isinstance(user_content, list) - # 原字符串保留为 text 块, 占位追加在末尾 - assert user_content[0] == {"type": "text", "text": "free text"} - assert user_content[-1]["type"] == "tool_result" - assert user_content[-1]["tool_use_id"] == "toolu_a" - - def test_partial_repair_only_synthesizes_missing_uids(self): - """next_user 已含部分 tool_result 时, sanity 仅为缺失的 uid 合成占位.""" - messages = [ - { - "role": "assistant", - "content": [ - {"type": "tool_use", "id": "toolu_a", "name": "bash", "input": {}}, - {"type": "tool_use", "id": "toolu_b", "name": "bash", "input": {}}, - ], - }, - { - "role": "user", - "content": [ - {"type": "tool_result", "tool_use_id": "toolu_a", "content": "ok"}, - # toolu_b 缺失 - ], - }, - ] - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - - assert sanity_synthesized == ["toolu_b"] - results = messages[1]["content"] - assert len(results) == 2 - # 原 
toolu_a 不变 - assert results[0]["tool_use_id"] == "toolu_a" - assert results[0].get("is_error") is not True - # 新合成 toolu_b is_error 占位 - assert results[1]["tool_use_id"] == "toolu_b" - assert results[1]["is_error"] is True - - def test_skips_assistant_without_tool_use(self): - """assistant 不含 tool_use 时 sanity 应当短路, 不插入空 user.""" - messages = [ - { - "role": "assistant", - "content": [{"type": "text", "text": "just talking"}], - }, - ] - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - - assert sanity_synthesized == [] - # 不应插入空 user - assert len(messages) == 1 - - def test_skips_non_assistant_messages(self): - """非 assistant 消息 (user/system) 不参与 sanity 检查.""" - messages = [ - {"role": "user", "content": "hi"}, - {"role": "system", "content": "rules"}, - ] - sanity_synthesized = _enforce_pairing_sanity_pass(messages) - assert sanity_synthesized == [] - assert len(messages) == 2 # 不被改动 - # ── 通道层端到端集成(zhipu 产物全量清洗) ─────────────────────────── @@ -2365,100 +1687,6 @@ def test_full_zhipu_artifacts_combined(self): assert relocated[0]["tool_use_id"] == new_id assert any("misplaced_tool_result_relocated" in a for a in adaptations) - def test_inline_tool_result_before_tool_use_pairs_correctly(self): - """日志现象回归保护: GLM-5 输出 [inline tool_result, tool_use] 块顺序时, - 修复前 _rewrite 单遍扫描漏改 inline.tool_use_id, enforce 阶段 dict key - 与 tool_use_ids 错位, 导致最终 anthropic 报 'tool_use ids without - tool_result blocks immediately after'. - - 修复后两遍扫描确保 inline tool_result 与 tool_use 同步改名, enforce 能 - 正确将 inline 搬迁到 next user, 不需合成 is_error 占位 (无 orphan 标签). 
- """ - body = { - "messages": [ - {"role": "user", "content": "task"}, - # previous turn completed - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "srvtoolu_A", - "name": "bash", - "input": {}, - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "srvtoolu_A", - "content": "A-ok", - }, - ], - }, - # current turn: inline tool_result precedes server_tool_use (streaming fragment) - { - "role": "assistant", - "content": [ - { - "type": "thinking", - "thinking": "...", - "signature": "zhipu_sig", - }, - # inline tool_result comes before tool_use! - { - "type": "tool_result", - "tool_use_id": "srvtoolu_B", - "content": "B-inline", - }, - { - "type": "server_tool_use", - "id": "srvtoolu_B", - "name": "bash", - "input": {}, - }, - ], - }, - # the client sent no tool_result for B (it was already inlined) - {"role": "user", "content": []}, - ], - } - prepared, adaptations = prepare_zhipu_to_anthropic(body) - - # key assertion: only misplaced_tool_result_relocated, no orphaned_tool_use_repaired - # (the real inline content is relocated correctly, so no is_error placeholder is synthesized) - assert "misplaced_tool_result_relocated" in adaptations - assert "orphaned_tool_use_repaired" not in adaptations - assert "pairing_sanity_repaired" not in adaptations - - # verify the tool_use in messages[3] has a matching tool_result in messages[4] - m3 = prepared["messages"][3] - m4 = prepared["messages"][4] - m3_tool_uses = [ - b["id"] - for b in m3["content"] - if isinstance(b, dict) and b.get("type") == "tool_use" - ] - m4_results = { - b.get("tool_use_id") - for b in m4["content"] - if isinstance(b, dict) and b.get("type") == "tool_result" - } - assert len(m3_tool_uses) == 1 - assert m3_tool_uses[0] in m4_results - - # the relocated tool_result should keep the original content ("B-inline") rather than a synthesized empty is_error - relocated = next( - b - for b in m4["content"] - if isinstance(b, dict) and b.get("type") == "tool_result" - ) - assert relocated["content"] == "B-inline" - assert relocated.get("is_error") is not True - class TestZhipuToCopilotChannelFullCleanup: """Verify prepare_zhipu_to_copilot's full cleanup of zhipu
artifacts.""" @@ -2501,253 +1729,3 @@ def test_rewrites_srvtoolu_and_strips_vendor_delta(self): assert prepared["messages"][1]["content"][0]["tool_use_id"] == new_id assert any("zhipu_vendor_blocks" in a for a in adaptations) assert any("srvtoolu_ids" in a for a in adaptations) - - -# ── anthropic → zhipu transition channel tests ────────────────────────────── - - -class TestAnthropicToZhipuChannel: - """Unit tests for the prepare_anthropic_to_zhipu transition channel.""" - - def test_strips_server_tool_use_blocks(self): - """Anthropic's server_tool_use blocks (web search, computer use) should be stripped.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "text", "text": "Let me search..."}, - { - "type": "server_tool_use", - "id": "toolu_web_search_123", - "name": "web_search", - "input": {"query": "python async"}, - }, - ], - }, - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "toolu_web_search_123", - "content": "search results", - }, - ], - }, - ], - } - prepared, adaptations = prepare_anthropic_to_zhipu(body) - assert any("server_tool_use" in a for a in adaptations) - assistant_content = prepared["messages"][0]["content"] - assert all(b.get("type") != "server_tool_use" for b in assistant_content) - assert assistant_content == [{"type": "text", "text": "Let me search..."}] - - def test_strips_thinking_blocks(self): - """Anthropic-issued thinking blocks should be stripped (zhipu may be unable to verify the signature).""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "thought", "signature": "sig"}, - {"type": "text", "text": "response"}, - ], - }, - ], - } - prepared, adaptations = prepare_anthropic_to_zhipu(body) - assert any("thinking_blocks" in a for a in adaptations) - assert prepared["messages"][0]["content"] == [ - {"type": "text", "text": "response"}, - ] - - def test_removes_cache_control(self): - body = { - "system": [ - {"type": "text", "text": "sys", "cache_control": {"type": "ephemeral"}}, - ], - "messages": [], - } - prepared,
adaptations = prepare_anthropic_to_zhipu(body) - assert any("cache_control" in a for a in adaptations) - assert "cache_control" not in prepared["system"][0] - - def test_removes_thinking_params(self): - body = { - "messages": [], - "thinking": {"type": "enabled", "budget_tokens": 10000}, - "extended_thinking": {"type": "enabled"}, - } - prepared, adaptations = prepare_anthropic_to_zhipu(body) - assert "thinking" not in prepared - assert "extended_thinking" not in prepared - assert "removed_thinking_param" in adaptations - assert "removed_extended_thinking_param" in adaptations - - def test_enforces_tool_pairing(self): - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "tool_use", - "id": "toolu_1", - "name": "bash", - "input": {}, - }, - ], - }, - {"role": "user", "content": "next"}, - ], - } - prepared, adaptations = prepare_anthropic_to_zhipu(body) - assert "orphaned_tool_use_repaired" in adaptations - user_results = [ - b - for b in prepared["messages"][1]["content"] - if isinstance(b, dict) and b.get("type") == "tool_result" - ] - assert len(user_results) == 1 - assert user_results[0]["tool_use_id"] == "toolu_1" - - def test_preserves_original_body(self): - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "toolu_1", - "name": "web_search", - "input": {}, - }, - {"type": "text", "text": "hi"}, - ], - }, - ], - } - original = copy.deepcopy(body) - prepare_anthropic_to_zhipu(body) - assert body == original - - def test_noop_when_clean(self): - body = { - "messages": [ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": [{"type": "text", "text": "hi"}]}, - ], - } - prepared, adaptations = prepare_anthropic_to_zhipu(body) - assert adaptations == [] - assert prepared == body - - def test_idempotency(self): - body = { - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "t"}, - { - "type": "server_tool_use", - "id": 
"toolu_1", - "name": "web_search", - "input": {}, - }, - ], - }, - ], - "thinking": {"type": "enabled"}, - } - prepared1, adaptations1 = prepare_anthropic_to_zhipu(body) - prepared2, adaptations2 = prepare_anthropic_to_zhipu(prepared1) - assert prepared2 == prepared1 - assert adaptations2 == [] - - def test_strips_multiple_server_tool_use_blocks(self): - """多个 server_tool_use 块(web search + computer use)全部剥离.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "toolu_ws_1", - "name": "web_search", - "input": {"query": "test"}, - }, - { - "type": "server_tool_use", - "id": "toolu_cu_1", - "name": "computer", - "input": {"action": "click"}, - }, - ], - }, - ], - } - prepared, adaptations = prepare_anthropic_to_zhipu(body) - assert not any( - b.get("type") == "server_tool_use" - for b in prepared["messages"][0]["content"] - ) - assert "removed_2_server_tool_use" in adaptations[0] - - def test_inserts_placeholder_when_all_blocks_stripped(self): - """assistant 消息仅含 server_tool_use 时插入占位 text block.""" - body = { - "messages": [ - { - "role": "assistant", - "content": [ - { - "type": "server_tool_use", - "id": "toolu_1", - "name": "web_search", - "input": {}, - }, - ], - }, - ], - } - prepared, _ = prepare_anthropic_to_zhipu(body) - assert prepared["messages"][0]["content"] == [ - {"type": "text", "text": "[vendor_block_removed]"}, - ] - - def test_combined_server_tool_use_and_thinking(self): - """server_tool_use + thinking + cache_control 的组合清洗.""" - body = { - "system": [ - {"type": "text", "text": "sys", "cache_control": {"type": "ephemeral"}}, - ], - "messages": [ - { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "t", "signature": "s"}, - { - "type": "server_tool_use", - "id": "toolu_cu_1", - "name": "computer", - "input": {}, - }, - {"type": "text", "text": "done"}, - ], - }, - ], - "thinking": {"type": "enabled"}, - } - prepared, adaptations = prepare_anthropic_to_zhipu(body) - 
assert all( - b.get("type") not in ("thinking", "redacted_thinking", "server_tool_use") - for b in prepared["messages"][0]["content"] - ) - assert "cache_control" not in prepared["system"][0] - assert "thinking" not in prepared - assert any("server_tool_use" in a for a in adaptations) - assert any("thinking_blocks" in a for a in adaptations) diff --git a/tests/test_vendors.py b/tests/test_vendors.py index 578b0b0..f771e9b 100644 --- a/tests/test_vendors.py +++ b/tests/test_vendors.py @@ -340,7 +340,6 @@ def test_zhipu_never_triggers_failover(): def test_zhipu_supports_tools_and_thinking(): """ZhipuVendor should declare all capabilities as NATIVE (native Anthropic-compatible endpoint).""" - from coding.proxy.compat.canonical import CompatibilityStatus from coding.proxy.vendors.base import RequestCapabilities mapper = ModelMapper([]) @@ -361,7 +360,14 @@ def test_zhipu_supports_tools_and_thinking(): ) assert supported is True assert reasons == [] - # the compatibility profile should be NATIVE across the board + + +# the compatibility profile should be NATIVE across the board +def test_zhipu_compatibility_profile_native(): + from coding.proxy.compat.canonical import CompatibilityStatus + + mapper = ModelMapper([]) + zhipu_vendor = ZhipuVendor(ZhipuConfig(), mapper) profile = zhipu_vendor.get_compatibility_profile() assert profile.thinking is CompatibilityStatus.NATIVE assert profile.tool_calling is CompatibilityStatus.NATIVE diff --git a/uv.lock b/uv.lock index 7258f33..79995a3 100644 --- a/uv.lock +++ b/uv.lock @@ -74,7 +74,7 @@ wheels = [ [[package]] name = "coding-proxy" -version = "0.3.1a3" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "aiosqlite" },