diff --git a/docs/issue.md b/docs/issue.md index 66fd61f..6da58f6 100644 --- a/docs/issue.md +++ b/docs/issue.md @@ -132,3 +132,57 @@ AttributeError: 'ZhipuVendor' object has no attribute 'name' - 已 `grep -rn "vendor\.name\b" src/` 全仓扫描,确认 `target_vendor.name | vendor.name` 误用仅 routes.py 的这两处,已随本次修复一并消除。`/v1/messages` 主链路在 executor 中调用 `tier.name`(`Tier` 对象的合法 dataclass 属性),与 vendor 实例 `name` 无关,不受影响。 - 若未来新增 Vendor 子类,仍只需实现 `get_name()` 抽象方法;外部调用方应遵循同一契约,本档案的修复模式可作为参考。 + +--- + +## Gemini embedding 透传至 Vertex AI 上游返回 `request body doesn't contain valid prompts` + +**问题描述** + +通过本代理调用 Gemini embedding 模型时,上游返回 400: + +``` +litellm.BadRequestError: GeminiException BadRequestError - +{"error":{"message":"request body doesn't contain valid prompts"}} +POST /api/gemini/v1beta/models/gemini-embedding-001%3AbatchEmbedContents 400 +``` + +litellm 报错日志中 URL 路径是 `:batchEmbedContents`,调用端疑似格式不兼容。 + +**表因** + +litellm 按 Google AI Studio 格式构造请求: +- 路径:`POST {api_base}/v1beta/models/{model}:batchEmbedContents` +- Body:`{"requests": [{"model": "models/...", "content": {"parts": [{"text": "..."}]}}]}` + +但实际上游(如 `llms.as-in.io` 这类 Vertex AI 风格网关)只接受 Vertex AI 格式: +- 路径:`POST {api_base}/v1beta1/publishers/google/models/{model}:embedContent` +- Body:`{"content": {"parts": [{"text": "..."}]}}` + +且无 `batchEmbedContents` 端点。 + +**根因** + +1. 代理 `NativeProxyHandler.dispatch()` 是字节级透传,对 embedding 端点未做协议适配,直接把 Google AI Studio 格式的 URL/Body 转给 Vertex AI 上游,路由不匹配。 +2. litellm `_check_custom_proxy()` 在自定义 `api_base` 场景下会丢失 `v1beta/` 版本前缀,发送 `{api_base}/models/{model}:verb`,使代理原有的 `OperationClassifier` 正则(要求 `v1beta/` 前缀)失配,进而走原始透传分支再次失败。 + +**处理方式** + +1. `src/coding/proxy/native_api/operation.py`:放宽 Gemini 路径正则中的 `v1(?:beta1?)?/` 段为可选,兼容 litellm 丢失版本前缀的异常路径。 +2. 
`src/coding/proxy/native_api/handler.py`:在 `dispatch()` 中新增 Gemini embedding Vertex AI 适配分支: + - 仅当 `provider == "gemini"`、`operation in {"embedding", "embedding.batch"}`、且 `base_url` 非官方 `generativelanguage.googleapis.com` 时启用; + - `embedContent` → 重写路径为 `v1beta1/publishers/google/models/{model}:embedContent`,剥离 body 中的 `model` 字段; + - `batchEmbedContents` → 拆分为多次并发 `embedContent` 调用(`asyncio.gather`),聚合响应为 `{"embeddings": [...]}` 返回; + - 用量抽取累加各子请求的 `usageMetadata`。 +3. `tests/test_native_api_handler.py`:新增 3 个回归测试覆盖单次 / 批量 / 官方上游透传不变三类场景。 + +**后续防范** + +- 协议适配层只对**非官方上游**生效,官方 `generativelanguage.googleapis.com` 仍走字节级透传,避免引入不必要的转换开销与协议偏差。 +- 上游路径分支的判定优先用 base_url 域名而非依赖网关行为特征,便于后续扩展(如 Vertex Express、其他 LLM gateway)时的精确匹配。 +- 真实链路验证:使用 litellm `embedding(api_base=..., api_key=...)` 单输入 / 多输入分别调用,确认返回 3072 维向量及正确批量计数。 + +**同类问题影响与处理注意事项** + +- litellm 在 Gemini 其他端点(`generateContent` / `countTokens`)同样存在 `_check_custom_proxy` 丢失 `v1beta/` 前缀的 bug;本次仅放宽了 `operation.py` 中的路径正则(让分类器能识别此类异常路径),未对这些端点做格式转换,因为非 embedding 端点的 Google AI Studio / Vertex AI 请求体差异较小,多数上游兼容。如未来出现类似失配再做针对性适配。 +- 若上游网关同时支持 OpenAI `/v1/embeddings` 与 Vertex AI 路径,建议优先在客户端配置 OpenAI 兼容路径,减少协议转换链路。 diff --git a/src/coding/proxy/native_api/handler.py b/src/coding/proxy/native_api/handler.py index d468f3b..ab7b344 100644 --- a/src/coding/proxy/native_api/handler.py +++ b/src/coding/proxy/native_api/handler.py @@ -13,8 +13,10 @@ from __future__ import annotations +import asyncio import json import logging +import re import time from collections.abc import AsyncIterator from typing import TYPE_CHECKING @@ -194,6 +196,28 @@ async def dispatch( start_ts = time.perf_counter() client = self._get_client(provider) + # ── Gemini embedding Vertex AI 格式转换 ────────────────── + # 当上游非官方 Google AI Studio(generativelanguage.googleapis.com)时, + # litellm 发送的 Google AI Studio 格式(v1beta/models/{model}:batchEmbedContents) + # 需转换为 Vertex AI 格式(v1beta1/publishers/google/models/{model}:embedContent)。 + vertex_rewrite = ( + 
provider == "gemini" +        and operation in ("embedding", "embedding.batch") +        and cfg.base_url +        and "generativelanguage.googleapis.com" not in cfg.base_url +    ) +    if vertex_rewrite: +        return await self._dispatch_gemini_vertex_embedding( +            client=client, +            operation=operation, +            endpoint=endpoint, +            body_bytes=body_bytes, +            upstream_headers=upstream_headers, +            query_string=query_string, +            provider=provider, +            start_ts=start_ts, +        ) +     # 构造上游 URL(保留 query)     upstream_url = endpoint     if query_string: @@ -295,6 +319,313 @@ async def dispatch(             media_type=content_type or None,         ) +    # ── Gemini embedding → Vertex AI 格式转换 ────────────────── + +    # Google AI Studio 路径正则:[v1beta/]models/{model}:{verb} +    # 版本段允许缺失以兼容 litellm `_check_custom_proxy` 丢失 v1beta 前缀的 bug。 +    _GEMINI_EMBED_PATH_RE = re.compile( +        r"^/?(?:v1(?:beta1?)?/)?models/(?P<model>[^/:]+)(?::|%3A)(?P<verb>embedContent|batchEmbedContents)/?$" +    ) + +    async def _dispatch_gemini_vertex_embedding( +        self, +        *, +        client: httpx.AsyncClient, +        operation: str, +        endpoint: str, +        body_bytes: bytes, +        upstream_headers: dict[str, str], +        query_string: str, +        provider: str, +        start_ts: float, +    ) -> StarletteResponse: +        """将 Google AI Studio 格式的 embedding 请求转换为 Vertex AI 格式. 
+ + Google AI Studio: + POST v1beta/models/{model}:batchEmbedContents + Body: {"requests": [{"model": "models/{model}", "content": {...}}]} + + Vertex AI: + POST v1beta1/publishers/google/models/{model}:embedContent + Body: {"content": {...}} + """ + from fastapi.responses import Response as FastAPIResponse + + match = self._GEMINI_EMBED_PATH_RE.match(endpoint) + if not match: + return FastAPIResponse( + content=json.dumps( + { + "error": { + "message": f"unrecognized gemini embedding path: {endpoint}" + } + } + ).encode(), + status_code=400, + media_type="application/json", + ) + + model_name = match.group("model") + verb = match.group("verb") + + # 解析原始请求体 + try: + body = json.loads(body_bytes) if body_bytes else {} + except (json.JSONDecodeError, UnicodeDecodeError): + return FastAPIResponse( + content=json.dumps( + {"error": {"message": "invalid JSON body for embedding request"}} + ).encode(), + status_code=400, + media_type="application/json", + ) + + if verb == "batchEmbedContents": + return await self._vertex_batch_embed( + client=client, + model_name=model_name, + body=body, + upstream_headers=upstream_headers, + query_string=query_string, + provider=provider, + operation=operation, + endpoint=endpoint, + start_ts=start_ts, + ) + + # 单次 embedContent:直接转换 + content = body.get("content", body) + return await self._vertex_single_embed( + client=client, + model_name=model_name, + content=content, + upstream_headers=upstream_headers, + query_string=query_string, + provider=provider, + operation=operation, + endpoint=endpoint, + start_ts=start_ts, + ) + + async def _vertex_single_embed( + self, + *, + client: httpx.AsyncClient, + model_name: str, + content: dict, + upstream_headers: dict[str, str], + query_string: str, + provider: str, + operation: str, + endpoint: str, + start_ts: float, + ) -> StarletteResponse: + """发送单次 Vertex AI embedContent 请求.""" + from fastapi.responses import Response as FastAPIResponse + + vertex_path = 
f"/v1beta1/publishers/google/models/{model_name}:embedContent" + vertex_url = vertex_path + if query_string: + vertex_url = f"{vertex_path}?{query_string}" + + vertex_body = json.dumps({"content": content}).encode() + + req = client.build_request( + method="POST", + url=vertex_url, + content=vertex_body, + headers=upstream_headers, + ) + + try: + upstream_resp = await client.send(req, stream=True) + except ( + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + httpx.RemoteProtocolError, + ) as exc: + duration_ms = int((time.perf_counter() - start_ts) * 1000) + await self._record_failure( + provider=provider, + operation=operation, + endpoint=endpoint, + duration_ms=duration_ms, + reason=str(exc), + ) + return FastAPIResponse( + content=json.dumps( + { + "error": { + "message": f"upstream unreachable: {exc}", + "type": "api_error", + } + } + ).encode(), + status_code=502, + media_type="application/json", + ) + + try: + raw_body = await upstream_resp.aread() + finally: + await upstream_resp.aclose() + + duration_ms = int((time.perf_counter() - start_ts) * 1000) + status = upstream_resp.status_code + content_type = upstream_resp.headers.get("content-type", "").lower() + resp_headers = _filter_response_headers(dict(upstream_resp.headers)) + + # 用量抽取 + extraction = ExtractionResult() + if "application/json" in content_type and raw_body: + try: + parsed = json.loads(raw_body.decode("utf-8", errors="replace")) + if isinstance(parsed, dict): + extraction = extract_usage( + provider, operation, parsed, status, dict(upstream_resp.headers) + ) + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + vendor_label = _VENDOR_LABEL[provider] + await self._record_usage( + provider=provider, + operation=operation, + endpoint=endpoint, + duration_ms=duration_ms, + status=status, + extraction=extraction, + evidence_records=_build_nonstream_evidence( + vendor=vendor_label, extraction=extraction + ), + ) + + return FastAPIResponse( + content=raw_body, + 
status_code=status, + headers=resp_headers, + media_type=content_type or None, + ) + + async def _vertex_batch_embed( + self, + *, + client: httpx.AsyncClient, + model_name: str, + body: dict, + upstream_headers: dict[str, str], + query_string: str, + provider: str, + operation: str, + endpoint: str, + start_ts: float, + ) -> StarletteResponse: + """将 batchEmbedContents 拆分为多次 embedContent 调用并聚合响应.""" + from fastapi.responses import Response as FastAPIResponse + + requests_list = body.get("requests", []) + if not requests_list: + return FastAPIResponse( + content=json.dumps( + { + "error": { + "message": "batchEmbedContents requires non-empty 'requests' field" + } + } + ).encode(), + status_code=400, + media_type="application/json", + ) + + vertex_path = f"/v1beta1/publishers/google/models/{model_name}:embedContent" + vertex_url = vertex_path + if query_string: + vertex_url = f"{vertex_path}?{query_string}" + + # 并发发送所有 embedContent 请求 + async def _single(req_body: dict) -> tuple[dict, int]: + content = req_body.get("content", req_body) + vertex_body = json.dumps({"content": content}).encode() + req = client.build_request( + method="POST", + url=vertex_url, + content=vertex_body, + headers=upstream_headers, + ) + try: + resp = await client.send(req, stream=False) + except ( + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError, + httpx.RemoteProtocolError, + ) as exc: + return {"error": {"message": f"upstream unreachable: {exc}"}}, 502 + try: + return resp.json(), resp.status_code + except Exception: + return {"error": {"message": resp.text[:200]}}, resp.status_code + + results = await asyncio.gather(*[_single(r) for r in requests_list]) + + # 检查是否有失败的请求 + embeddings = [] + for resp_json, resp_status in results: + if resp_status != 200: + # 返回第一个错误 + return FastAPIResponse( + content=json.dumps(resp_json).encode(), + status_code=resp_status, + media_type="application/json", + ) + embedding_data = resp_json.get("embedding", {}) + 
embeddings.append(embedding_data) + + # 聚合为 batchEmbedContents 响应格式 + batch_response = {"embeddings": embeddings} + duration_ms = int((time.perf_counter() - start_ts) * 1000) + + # 用量抽取 + extraction = ExtractionResult() + for resp_json, _ in results: + if isinstance(resp_json, dict): + ext = extract_usage(provider, operation, resp_json, 200, {}) + extraction = ExtractionResult( + input_tokens=extraction.input_tokens + ext.input_tokens, + output_tokens=extraction.output_tokens + ext.output_tokens, + cache_creation_tokens=extraction.cache_creation_tokens + + ext.cache_creation_tokens, + cache_read_tokens=extraction.cache_read_tokens + + ext.cache_read_tokens, + request_id=ext.request_id or extraction.request_id, + model_served=ext.model_served or extraction.model_served, + raw_usage=ext.raw_usage or extraction.raw_usage, + source_field_map=ext.source_field_map + or extraction.source_field_map, + evidence_kind=ext.evidence_kind or extraction.evidence_kind, + extra_usage=ext.extra_usage or extraction.extra_usage, + ) + + vendor_label = _VENDOR_LABEL[provider] + await self._record_usage( + provider=provider, + operation=operation, + endpoint=endpoint, + duration_ms=duration_ms, + status=200, + extraction=extraction, + evidence_records=_build_nonstream_evidence( + vendor=vendor_label, extraction=extraction + ), + ) + + return FastAPIResponse( + content=json.dumps(batch_response).encode(), + status_code=200, + media_type="application/json", + ) + # ── SSE 流式转发(同时累加 usage) ───────────────────────── async def _stream_and_accumulate( diff --git a/src/coding/proxy/native_api/operation.py b/src/coding/proxy/native_api/operation.py index 1ff82bd..2080b6c 100644 --- a/src/coding/proxy/native_api/operation.py +++ b/src/coding/proxy/native_api/operation.py @@ -48,30 +48,34 @@ class _Rule: ) # ── Gemini ──────────────────────────────────────────────────────── -# Gemini 的方法动词作为路径后缀(``:generateContent``),通过正则提取 +# Gemini 的方法动词作为路径后缀(``:generateContent``),通过正则提取。 +# ``v1(?:beta1?)?/`` 
前缀允许缺失,以兼容 litellm `_check_custom_proxy` 在 +# 自定义 ``api_base`` 场景下丢失版本段的 bug(参考 litellm issue #17759)。 _GEMINI_RULES: tuple[_Rule, ...] = ( _Rule( - re.compile(r"^/?v1(?:beta)?/models/[^/]+(?:%3A|:)streamGenerateContent/?$"), + re.compile( + r"^/?(?:v1(?:beta1?)?/)?models/[^/]+(?:%3A|:)streamGenerateContent/?$" + ), "generate_content", ), _Rule( - re.compile(r"^/?v1(?:beta)?/models/[^/]+(?:%3A|:)generateContent/?$"), + re.compile(r"^/?(?:v1(?:beta1?)?/)?models/[^/]+(?:%3A|:)generateContent/?$"), "generate_content", ), _Rule( - re.compile(r"^/?v1(?:beta)?/models/[^/]+(?:%3A|:)countTokens/?$"), + re.compile(r"^/?(?:v1(?:beta1?)?/)?models/[^/]+(?:%3A|:)countTokens/?$"), "count_tokens", ), _Rule( - re.compile(r"^/?v1(?:beta)?/models/[^/]+(?:%3A|:)embedContent/?$"), + re.compile(r"^/?(?:v1(?:beta1?)?/)?models/[^/]+(?:%3A|:)embedContent/?$"), "embedding", ), _Rule( - re.compile(r"^/?v1(?:beta)?/models/[^/]+(?:%3A|:)batchEmbedContents/?$"), + re.compile(r"^/?(?:v1(?:beta1?)?/)?models/[^/]+(?:%3A|:)batchEmbedContents/?$"), "embedding.batch", ), _Rule( - re.compile(r"^/?v1(?:beta)?/models/[^/]+(?:%3A|:)predict/?$"), + re.compile(r"^/?(?:v1(?:beta1?)?/)?models/[^/]+(?:%3A|:)predict/?$"), "predict", ), _Rule( diff --git a/tests/test_native_api_handler.py b/tests/test_native_api_handler.py index d00a054..14be66c 100644 --- a/tests/test_native_api_handler.py +++ b/tests/test_native_api_handler.py @@ -14,6 +14,7 @@ from __future__ import annotations +import json from collections.abc import Iterator import httpx @@ -443,3 +444,120 @@ def factory(make_transport): # 上游 URL 必须含字面冒号,不含 %3A assert "%3A" not in upstream_str assert ":batchEmbedContents" in upstream_str + + +# ── Gemini embedding Vertex AI 格式转换 ───────────────────────── + + +def test_gemini_vertex_embed_content_single() -> None: + """非官方上游时,embedContent 转为 Vertex AI 格式.""" + + def route(request: httpx.Request) -> httpx.Response: + body = json.loads(request.content) + assert "content" in body + assert "model" not in body 
+ assert "requests" not in body + assert ":embedContent" in str(request.url) + assert "v1beta1/publishers/google/models" in str(request.url) + return httpx.Response(200, json={"embedding": {"values": [0.1, 0.2]}}) + + def factory(make_transport): + cfg = NativeApiConfig( + gemini=NativeProviderConfig(enabled=True, base_url="http://llms.as-in.io"), + ) + transport = make_transport(route) + return NativeProxyHandler(cfg, transport=transport), transport + + for client, captured in _make_app(factory): + r = client.post( + "/api/gemini/v1beta/models/gemini-embedding-2-preview:embedContent", + json={ + "model": "models/gemini-embedding-2-preview", + "content": {"parts": [{"text": "hello"}]}, + }, + ) + assert r.status_code == 200 + assert "embedding" in r.json() + + +def test_gemini_vertex_batch_embed_contents() -> None: + """非官方上游时,batchEmbedContents 拆分为多次 embedContent 并聚合.""" + + call_count = 0 + + def route(request: httpx.Request) -> httpx.Response: + nonlocal call_count + call_count += 1 + body = json.loads(request.content) + assert "content" in body + assert ":embedContent" in str(request.url) + assert "v1beta1/publishers/google/models" in str(request.url) + return httpx.Response( + 200, + json={"embedding": {"values": [float(call_count), 0.5]}}, + ) + + def factory(make_transport): + cfg = NativeApiConfig( + gemini=NativeProviderConfig(enabled=True, base_url="http://llms.as-in.io"), + ) + transport = make_transport(route) + return NativeProxyHandler(cfg, transport=transport), transport + + for client, captured in _make_app(factory): + r = client.post( + "/api/gemini/v1beta/models/gemini-embedding-2-preview:batchEmbedContents", + json={ + "requests": [ + { + "model": "models/gemini-embedding-2-preview", + "content": {"parts": [{"text": "hello"}]}, + }, + { + "model": "models/gemini-embedding-2-preview", + "content": {"parts": [{"text": "world"}]}, + }, + ] + }, + ) + assert r.status_code == 200 + data = r.json() + assert "embeddings" in data + assert 
len(data["embeddings"]) == 2 + assert data["embeddings"][0]["values"] == [1.0, 0.5] + assert data["embeddings"][1]["values"] == [2.0, 0.5] + assert call_count == 2 + + +def test_gemini_vertex_embed_official_upstream_unchanged() -> None: + """官方上游时,batchEmbedContents 走原始透传路径,不做格式转换.""" + + def route(request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"embeddings": [{"values": [0.1, 0.2]}]}) + + def factory(make_transport): + cfg = NativeApiConfig( + gemini=NativeProviderConfig( + enabled=True, base_url="https://generativelanguage.googleapis.com" + ), + ) + transport = make_transport(route) + return NativeProxyHandler(cfg, transport=transport), transport + + for client, captured in _make_app(factory): + r = client.post( + "/api/gemini/v1beta/models/gemini-embedding-001:batchEmbedContents?key=k", + json={ + "requests": [ + { + "model": "models/gemini-embedding-001", + "content": {"parts": [{"text": "hello"}]}, + } + ] + }, + ) + assert r.status_code == 200 + # 官方上游走原始路径,URL 保持 v1beta/models/ 格式 + upstream = captured[0] + assert "v1beta/models" in str(upstream.url) + assert "v1beta1/publishers" not in str(upstream.url)