Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 135 additions & 31 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import random
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import cast

from google import genai
Expand Down Expand Up @@ -32,6 +33,21 @@ def filter(self, record):
# Attach a SuppressNonTextPartsWarning filter to the "google_genai.types" logger.
# NOTE(review): judging by the filter's name this silences the SDK's
# "non-text parts" warning; the filter class itself is defined outside this view.
logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning())


@dataclass
class ChunkView:
    """Structured view of a single streaming-response chunk.

    Gives callers one uniform object to read from a Gemini streaming chunk
    so that the same defensive attribute checks do not have to be repeated
    at every use site.

    Every field has a default, so ``ChunkView()`` is an "empty" view (no
    candidate, no text, no function call).
    """

    # First candidate of the chunk, or None when the chunk has no candidates.
    candidate: types.Candidate | None = None
    # Parts of the candidate's content, or None when unavailable.
    parts: list[types.Part] | None = None
    # Text taken from parts flagged as "thought", or None when there is none.
    reasoning_text: str | None = None
    # Text taken from non-thought parts, or None when there is none.
    visible_text: str | None = None
    # True when at least one part carries a function call.
    has_function_call: bool = False
    # Finish reason reported on the candidate, or None.
    finish_reason: types.FinishReason | None = None


@register_provider_adapter(
"googlegenai_chat_completion",
"Google Gemini Chat Completion 提供商适配器",
Expand Down Expand Up @@ -398,15 +414,85 @@ def append_or_extend(

return gemini_contents

def _extract_reasoning_content(self, candidate: types.Candidate) -> str:
    """Extract reasoning ("thought") text from a candidate's parts.

    Args:
        candidate: A response candidate.

    Returns:
        The concatenated text of every part whose ``thought`` flag is
        truthy, stripped of surrounding whitespace. Returns ``""`` when the
        candidate has no content or no parts.
    """
    if not candidate.content or not candidate.content.parts:
        return ""

    thought_buf: list[str] = [
        (p.text or "") for p in candidate.content.parts if p.thought
    ]
    return "".join(thought_buf).strip()


def _split_chunk_content(
    self, chunk: types.GenerateContentResponse
) -> ChunkView:
    """Build a structured ``ChunkView`` from one streaming-response chunk.

    Attributes that may be missing are read with ``getattr`` and a default,
    so this method does not raise on chunks lacking candidates, content,
    or parts.

    Args:
        chunk: A streaming response chunk returned by the Gemini API.

    Returns:
        A ``ChunkView`` whose fields are filled as follows:

        * no candidates: every field is ``None`` / ``False``;
        * candidate without content: only ``candidate`` and
          ``finish_reason`` are set;
        * content without (or with empty) parts: ``visible_text`` falls
          back to ``chunk.text`` (``None`` if that is empty or absent);
        * otherwise: ``reasoning_text`` joins the text of "thought" parts,
          ``visible_text`` joins the text of the remaining parts (each is
          ``None`` when nothing was collected) and ``has_function_call``
          tells whether any part carries a function call.
    """
    # Guard: no candidates at all -> empty view, so callers can test
    # ``candidate is None`` instead of indexing into the chunk.
    if not chunk.candidates:
        return ChunkView(
            candidate=None,
            parts=None,
            reasoning_text=None,
            visible_text=None,
            has_function_call=False,
            finish_reason=None,
        )

    candidate = chunk.candidates[0]
    finish_reason = getattr(candidate, "finish_reason", None)

    # Guard: candidate without content -> nothing to split.
    content = getattr(candidate, "content", None)
    if content is None:
        return ChunkView(
            candidate=candidate,
            parts=None,
            reasoning_text=None,
            visible_text=None,
            has_function_call=False,
            finish_reason=finish_reason,
        )

    parts = getattr(content, "parts", None)
    if not parts:
        # Fallback: parts is missing or empty, but the chunk may still
        # expose text through ``chunk.text``.
        return ChunkView(
            candidate=candidate,
            parts=parts,
            reasoning_text=None,
            visible_text=getattr(chunk, "text", None) or None,
            has_function_call=False,
            finish_reason=finish_reason,
        )

    reasoning_parts: list[str] = []
    text_parts: list[str] = []
    has_function_call = False

    for part in parts:
        if getattr(part, "function_call", None):
            has_function_call = True

        part_text = getattr(part, "text", None)
        if not part_text:
            continue

        # Route the text by the part's "thought" flag: reasoning vs. output.
        if getattr(part, "thought", False):
            reasoning_parts.append(part_text)
        else:
            text_parts.append(part_text)

    return ChunkView(
        candidate=candidate,
        parts=parts,
        reasoning_text="".join(reasoning_parts) if reasoning_parts else None,
        visible_text="".join(text_parts) if text_parts else None,
        has_function_call=has_function_call,
        finish_reason=finish_reason,
    )

def _extract_usage(
self, usage_metadata: types.GenerateContentResponseUsageMetadata
Expand Down Expand Up @@ -451,7 +537,12 @@ def _process_content_parts(
raise Exception("API 返回的 candidate.content.parts 为空。")

# 提取 reasoning content
reasoning = self._extract_reasoning_content(candidate)
thought_buf: list[str] = [
(p.text or "")
for p in result_parts
if getattr(p, "thought", False)
]
reasoning = "".join(thought_buf).strip()
if reasoning:
llm_response.reasoning_content = reasoning

Expand All @@ -467,6 +558,9 @@ def _process_content_parts(
):
chain.append(Comp.Plain("这是图片"))
for part in result_parts:
# 跳过思考内容(thought=True),只处理实际输出
if getattr(part, "thought", False):
continue
if part.text:
chain.append(Comp.Plain(part.text))

Expand Down Expand Up @@ -635,20 +729,25 @@ async def _query_stream(
async for chunk in result:
llm_response = LLMResponse("assistant", is_chunk=True)

if not chunk.candidates:
# 使用辅助函数进行防御性检查和内容提取
chunk_view = self._split_chunk_content(chunk)

# 如果 candidate 为空,跳过
if chunk_view.candidate is None:
logger.warning(f"收到的 chunk 中 candidates 为空: {chunk}")
continue
if not chunk.candidates[0].content:

# 如果 parts 为空(content 为空),跳过
if chunk_view.parts is None:
logger.warning(f"收到的 chunk 中 content 为空: {chunk}")
continue

if chunk.candidates[0].content.parts and any(
part.function_call for part in chunk.candidates[0].content.parts
):
# 检查是否包含函数调用
if chunk_view.has_function_call:
llm_response = LLMResponse("assistant", is_chunk=False)
llm_response.raw_completion = chunk
llm_response.result_chain = self._process_content_parts(
chunk.candidates[0],
chunk_view.candidate,
llm_response,
)
llm_response.id = chunk.response_id
Expand All @@ -657,28 +756,33 @@ async def _query_stream(
yield llm_response
return

_f = False

# 提取 reasoning content
reasoning = self._extract_reasoning_content(chunk.candidates[0])
if reasoning:
_f = True
accumulated_reasoning += reasoning
llm_response.reasoning_content = reasoning
if chunk.text:
_f = True
accumulated_text += chunk.text
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
if _f:
has_content = False

# 处理思维链内容
if chunk_view.reasoning_text:
has_content = True
accumulated_reasoning += chunk_view.reasoning_text
llm_response.reasoning_content = chunk_view.reasoning_text

# 处理实际输出内容
if chunk_view.visible_text:
has_content = True
accumulated_text += chunk_view.visible_text
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(chunk_view.visible_text)]
)

if has_content:
yield llm_response

if chunk.candidates[0].finish_reason:
# 检查是否为最终 chunk
if chunk_view.finish_reason:
# Process the final chunk for potential tool calls or other content
if chunk.candidates[0].content.parts:
if chunk_view.parts:
final_response = LLMResponse("assistant", is_chunk=False)
final_response.raw_completion = chunk
final_response.result_chain = self._process_content_parts(
chunk.candidates[0],
chunk_view.candidate,
final_response,
)
final_response.id = chunk.response_id
Expand Down