feat: 为失败请求留档并提供重试分析

This commit is contained in:
SengokuCola
2026-04-07 15:16:06 +08:00
parent 3b5baf901a
commit 09bce14664
6 changed files with 1304 additions and 91 deletions

View File

@@ -1,7 +1,10 @@
# ruff: noqa: B025
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
import asyncio
import base64
import binascii
import io
import json
@@ -21,6 +24,8 @@ from google.genai.types import (
EmbedContentConfig,
EmbedContentResponse,
FunctionDeclaration,
FunctionCall,
FunctionResponse,
GenerateContentConfig,
GenerateContentResponse,
GoogleSearch,
@@ -60,6 +65,14 @@ from .base_client import (
UsageTuple,
client_registry,
)
from ..request_snapshot import (
attach_request_snapshot,
has_request_snapshot,
save_failed_request_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
logger = get_logger("Gemini客户端")
@@ -112,6 +125,11 @@ EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = {
}
"""可透传给 `EmbedContentConfig` 的额外参数字段。"""
GEMINI_EXTRA_CONTENT_PROVIDER_KEY = "google"
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY = "thought_signature"
GEMINI_FALLBACK_THOUGHT_SIGNATURE = b"skip_thought_signature_validator"
"""当历史 function call 没有原始 thought signature 时,使用官方允许的占位签名跳过校验。"""
def _normalize_image_mime_type(image_format: str) -> str:
"""将图片格式名称转换为标准 MIME 类型。
@@ -177,6 +195,62 @@ def _normalize_function_response_payload(message: Message) -> Dict[str, Any]:
return {"result": content}
def _build_gemini_tool_call_extra_content(thought_signature: bytes | None) -> Dict[str, Any] | None:
    """Encode a Gemini thought signature into the internal tool-call extra content.

    Args:
        thought_signature: Raw signature bytes from the provider; ``None`` or
            empty means there is nothing to carry over.

    Returns:
        A nested dict keyed by the Gemini provider key, with the signature
        base64-encoded as ASCII text, or ``None`` when no signature is present.
    """
    if not thought_signature:
        return None
    encoded_signature = base64.b64encode(thought_signature).decode("ascii")
    provider_payload = {GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY: encoded_signature}
    return {GEMINI_EXTRA_CONTENT_PROVIDER_KEY: provider_payload}
def _extract_gemini_thought_signature(tool_call: ToolCall) -> bytes | None:
    """Recover the Gemini thought signature from a tool call's extra content.

    Accepts either raw bytes or a base64-encoded string. A non-empty string
    that fails strict base64 decoding is kept as its UTF-8 byte encoding
    rather than being discarded.

    Args:
        tool_call: Internal tool-call record whose ``extra_content`` may carry
            the provider payload.

    Returns:
        The signature bytes, or ``None`` when no usable signature is stored.
    """
    extra = tool_call.extra_content
    if not extra:
        return None
    payload = extra.get(GEMINI_EXTRA_CONTENT_PROVIDER_KEY)
    if not isinstance(payload, dict):
        return None
    raw = payload.get(GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY)
    if isinstance(raw, bytes):
        return raw
    if not isinstance(raw, str):
        return None
    stripped = raw.strip()
    if not stripped:
        return None
    try:
        # NOTE: UnicodeEncodeError is a ValueError subclass, so non-ASCII
        # strings also fall through to the UTF-8 fallback below.
        return base64.b64decode(stripped.encode("ascii"), validate=True)
    except (binascii.Error, ValueError):
        # Not valid base64 — treat the string itself as the signature bytes.
        return stripped.encode("utf-8")
def _build_gemini_function_call_part(
    tool_call: ToolCall,
    *,
    inject_fallback_signature: bool,
) -> Part:
    """Build a Gemini function-call ``Part`` from an internal tool call.

    Args:
        tool_call: Internal tool-call record (id, name, args, extra content).
        inject_fallback_signature: When True and the call carries no stored
            thought signature, attach the fallback placeholder signature so
            provider-side signature validation is skipped.

    Returns:
        A ``Part`` wrapping the ``FunctionCall`` plus the (possibly fallback)
        thought signature.
    """
    signature = _extract_gemini_thought_signature(tool_call)
    if signature is None and inject_fallback_signature:
        signature = GEMINI_FALLBACK_THOUGHT_SIGNATURE
    function_call = FunctionCall(
        id=tool_call.call_id,
        name=tool_call.func_name,
        args=tool_call.args or {},
    )
    return Part(function_call=function_call, thought_signature=signature)
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
"""安全获取 Gemini 响应中的候选列表。
@@ -235,11 +309,11 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
if message.role == RoleType.Assistant:
assistant_parts = _build_non_tool_parts(message)
if message.tool_calls:
for tool_call in message.tool_calls:
for tool_call_index, tool_call in enumerate(message.tool_calls):
assistant_parts.append(
Part.from_function_call(
name=tool_call.func_name,
args=tool_call.args or {},
_build_gemini_function_call_part(
tool_call,
inject_fallback_signature=tool_call_index == 0,
)
)
tool_name_by_call_id[tool_call.call_id] = tool_call.func_name
@@ -256,9 +330,12 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
"且消息中未携带 tool_name"
)
tool_name_by_call_id[message.tool_call_id] = tool_name
function_response_part = Part.from_function_response(
name=tool_name,
response=_normalize_function_response_payload(message),
function_response_part = Part(
function_response=FunctionResponse(
id=message.tool_call_id,
name=tool_name,
response=_normalize_function_response_payload(message),
)
)
contents.append(Content(role="tool", parts=[function_response_part]))
continue
@@ -368,22 +445,41 @@ def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]
Raises:
RespParseException: 当函数调用结构不合法时抛出。
"""
raw_function_calls = getattr(response, "function_calls", None)
candidates = _get_candidates(response)
if not raw_function_calls and candidates:
raw_function_calls = []
for candidate in candidates:
content = getattr(candidate, "content", None)
parts = getattr(content, "parts", None) or []
for part in parts:
function_call = getattr(part, "function_call", None)
if function_call is not None:
raw_function_calls.append(function_call)
tool_calls: List[ToolCall] = []
for candidate in candidates:
content = getattr(candidate, "content", None)
parts = getattr(content, "parts", None) or []
for part in parts:
function_call = getattr(part, "function_call", None)
if function_call is None:
continue
call_name = getattr(function_call, "name", None)
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{len(tool_calls) + 1}"
call_args = getattr(function_call, "args", None) or {}
if not isinstance(call_name, str) or not call_name:
raise RespParseException(response, "响应解析失败Gemini 工具调用缺少 name 字段")
if not isinstance(call_args, dict):
raise RespParseException(response, "响应解析失败Gemini 工具调用参数无法解析为字典")
tool_calls.append(
ToolCall(
call_id=call_id,
func_name=call_name,
args=call_args,
extra_content=_build_gemini_tool_call_extra_content(getattr(part, "thought_signature", None)),
)
)
if tool_calls:
return tool_calls
raw_function_calls = getattr(response, "function_calls", None)
if not raw_function_calls:
return []
tool_calls: List[ToolCall] = []
for index, function_call in enumerate(raw_function_calls, start=1):
call_name = getattr(function_call, "name", None)
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}"
@@ -808,23 +904,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
model_info = request.model_info
contents, system_instruction = _convert_messages(request.message_list)
model_identifier, enable_google_search = self._resolve_model_identifier(
model_info.model_identifier,
request.extra_params,
)
generation_config = self._build_generation_config(
model_identifier=model_identifier,
system_instruction=system_instruction,
tool_options=request.tool_options,
response_format=request.response_format,
max_tokens=request.max_tokens,
temperature=request.temperature,
extra_params=request.extra_params,
enable_google_search=enable_google_search,
)
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:generateContent",
"method": "POST",
"operation": "models.generate_content",
"request_kwargs": {},
}
try:
contents, system_instruction = _convert_messages(request.message_list)
model_identifier, enable_google_search = self._resolve_model_identifier(
model_info.model_identifier,
request.extra_params,
)
generation_config = self._build_generation_config(
model_identifier=model_identifier,
system_instruction=system_instruction,
tool_options=request.tool_options,
response_format=request.response_format,
max_tokens=request.max_tokens,
temperature=request.temperature,
extra_params=request.extra_params,
enable_google_search=enable_google_search,
)
snapshot_provider_request["request_kwargs"] = {
"config": generation_config,
"contents": contents,
"enable_google_search": enable_google_search,
"model": model_identifier,
"system_instruction": system_instruction,
}
if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
self.client.aio.models.generate_content_stream(
@@ -855,7 +965,62 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
raise
except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
wrapped_error = RespParseException(None, f"Gemini 工具调用参数错误: {exc}")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=wrapped_error,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except EmptyResponseException as exc:
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
except EmptyResponseException:
@@ -878,9 +1043,21 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
model_info = request.model_info
embedding_input = request.embedding_input
extra_params = request.extra_params
embed_config = _build_embed_content_config(extra_params)
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:embedContent",
"method": "POST",
"operation": "models.embed_content",
"request_kwargs": {},
}
try:
embed_config = _build_embed_content_config(extra_params)
snapshot_provider_request["request_kwargs"] = {
"config": embed_config,
"contents": embedding_input,
"model": model_info.model_identifier,
}
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
@@ -888,11 +1065,52 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
)
except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
raise NetworkConnectionError(str(exc)) from exc
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
response = APIResponse(raw_data=raw_response)
if not raw_response.embeddings:
exc = RespParseException(raw_response, "Gemini 嵌入响应解析失败,缺少 embeddings 字段。")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise exc
if raw_response.embeddings:
response.embedding = raw_response.embeddings[0].values
else:
@@ -924,7 +1142,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
audio_base64 = request.audio_base64
max_tokens = request.max_tokens
extra_params = request.extra_params
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:generateContent",
"method": "POST",
"operation": "models.generate_content",
"request_kwargs": {},
}
transcription_prompt = str(
extra_params.get(
@@ -933,27 +1157,36 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
)
)
audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
contents: List[ContentUnion] = [
Content(
role="user",
parts=[
Part.from_text(text=transcription_prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type),
],
)
]
generation_config = self._build_generation_config(
model_identifier=model_identifier,
system_instruction=None,
tool_options=None,
response_format=None,
max_tokens=max_tokens,
temperature=None,
extra_params=extra_params,
enable_google_search=False,
)
try:
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
contents: List[ContentUnion] = [
Content(
role="user",
parts=[
Part.from_text(text=transcription_prompt),
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type),
],
)
]
generation_config = self._build_generation_config(
model_identifier=model_identifier,
system_instruction=None,
tool_options=None,
response_format=None,
max_tokens=max_tokens,
temperature=None,
extra_params=extra_params,
enable_google_search=False,
)
snapshot_provider_request["request_kwargs"] = {
"audio_base64": audio_base64,
"audio_mime_type": audio_mime_type,
"config": generation_config,
"contents": contents,
"model": model_identifier,
"transcription_prompt": transcription_prompt,
}
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
model=model_identifier,
contents=contents,
@@ -962,9 +1195,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
response, usage_record = _default_normal_response_parser(raw_response)
except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
raise NetworkConnectionError(str(exc)) from exc
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
return response, usage_record