feat:为失败请求留档并提供重试分析
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
# ruff: noqa: B025
|
||||
|
||||
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import io
|
||||
import json
|
||||
|
||||
@@ -21,6 +24,8 @@ from google.genai.types import (
|
||||
EmbedContentConfig,
|
||||
EmbedContentResponse,
|
||||
FunctionDeclaration,
|
||||
FunctionCall,
|
||||
FunctionResponse,
|
||||
GenerateContentConfig,
|
||||
GenerateContentResponse,
|
||||
GoogleSearch,
|
||||
@@ -60,6 +65,14 @@ from .base_client import (
|
||||
UsageTuple,
|
||||
client_registry,
|
||||
)
|
||||
from ..request_snapshot import (
|
||||
attach_request_snapshot,
|
||||
has_request_snapshot,
|
||||
save_failed_request_snapshot,
|
||||
serialize_audio_request_snapshot,
|
||||
serialize_embedding_request_snapshot,
|
||||
serialize_response_request_snapshot,
|
||||
)
|
||||
|
||||
logger = get_logger("Gemini客户端")
|
||||
|
||||
@@ -112,6 +125,11 @@ EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = {
|
||||
}
|
||||
"""可透传给 `EmbedContentConfig` 的额外参数字段。"""
|
||||
|
||||
GEMINI_EXTRA_CONTENT_PROVIDER_KEY = "google"
|
||||
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY = "thought_signature"
|
||||
GEMINI_FALLBACK_THOUGHT_SIGNATURE = b"skip_thought_signature_validator"
|
||||
"""当历史 function call 没有原始 thought signature 时,使用官方允许的占位签名跳过校验。"""
|
||||
|
||||
|
||||
def _normalize_image_mime_type(image_format: str) -> str:
|
||||
"""将图片格式名称转换为标准 MIME 类型。
|
||||
@@ -177,6 +195,62 @@ def _normalize_function_response_payload(message: Message) -> Dict[str, Any]:
|
||||
return {"result": content}
|
||||
|
||||
|
||||
def _build_gemini_tool_call_extra_content(thought_signature: bytes | None) -> Dict[str, Any] | None:
|
||||
"""将 Gemini thought signature 编码为内部工具调用附加信息。"""
|
||||
if not thought_signature:
|
||||
return None
|
||||
return {
|
||||
GEMINI_EXTRA_CONTENT_PROVIDER_KEY: {
|
||||
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY: base64.b64encode(thought_signature).decode("ascii")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _extract_gemini_thought_signature(tool_call: ToolCall) -> bytes | None:
|
||||
"""从内部工具调用附加信息中提取 Gemini thought signature。"""
|
||||
if not tool_call.extra_content:
|
||||
return None
|
||||
|
||||
provider_payload = tool_call.extra_content.get(GEMINI_EXTRA_CONTENT_PROVIDER_KEY)
|
||||
if not isinstance(provider_payload, dict):
|
||||
return None
|
||||
|
||||
raw_thought_signature = provider_payload.get(GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY)
|
||||
if isinstance(raw_thought_signature, bytes):
|
||||
return raw_thought_signature
|
||||
if not isinstance(raw_thought_signature, str):
|
||||
return None
|
||||
|
||||
normalized_signature = raw_thought_signature.strip()
|
||||
if not normalized_signature:
|
||||
return None
|
||||
|
||||
try:
|
||||
return base64.b64decode(normalized_signature.encode("ascii"), validate=True)
|
||||
except (binascii.Error, ValueError):
|
||||
return normalized_signature.encode("utf-8")
|
||||
|
||||
|
||||
def _build_gemini_function_call_part(
|
||||
tool_call: ToolCall,
|
||||
*,
|
||||
inject_fallback_signature: bool,
|
||||
) -> Part:
|
||||
"""根据内部工具调用构建 Gemini function call part。"""
|
||||
thought_signature = _extract_gemini_thought_signature(tool_call)
|
||||
if thought_signature is None and inject_fallback_signature:
|
||||
thought_signature = GEMINI_FALLBACK_THOUGHT_SIGNATURE
|
||||
|
||||
return Part(
|
||||
function_call=FunctionCall(
|
||||
id=tool_call.call_id,
|
||||
name=tool_call.func_name,
|
||||
args=tool_call.args or {},
|
||||
),
|
||||
thought_signature=thought_signature,
|
||||
)
|
||||
|
||||
|
||||
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
|
||||
"""安全获取 Gemini 响应中的候选列表。
|
||||
|
||||
@@ -235,11 +309,11 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
||||
if message.role == RoleType.Assistant:
|
||||
assistant_parts = _build_non_tool_parts(message)
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
for tool_call_index, tool_call in enumerate(message.tool_calls):
|
||||
assistant_parts.append(
|
||||
Part.from_function_call(
|
||||
name=tool_call.func_name,
|
||||
args=tool_call.args or {},
|
||||
_build_gemini_function_call_part(
|
||||
tool_call,
|
||||
inject_fallback_signature=tool_call_index == 0,
|
||||
)
|
||||
)
|
||||
tool_name_by_call_id[tool_call.call_id] = tool_call.func_name
|
||||
@@ -256,9 +330,12 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
||||
"且消息中未携带 tool_name"
|
||||
)
|
||||
tool_name_by_call_id[message.tool_call_id] = tool_name
|
||||
function_response_part = Part.from_function_response(
|
||||
name=tool_name,
|
||||
response=_normalize_function_response_payload(message),
|
||||
function_response_part = Part(
|
||||
function_response=FunctionResponse(
|
||||
id=message.tool_call_id,
|
||||
name=tool_name,
|
||||
response=_normalize_function_response_payload(message),
|
||||
)
|
||||
)
|
||||
contents.append(Content(role="tool", parts=[function_response_part]))
|
||||
continue
|
||||
@@ -368,22 +445,41 @@ def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]
|
||||
Raises:
|
||||
RespParseException: 当函数调用结构不合法时抛出。
|
||||
"""
|
||||
raw_function_calls = getattr(response, "function_calls", None)
|
||||
candidates = _get_candidates(response)
|
||||
if not raw_function_calls and candidates:
|
||||
raw_function_calls = []
|
||||
for candidate in candidates:
|
||||
content = getattr(candidate, "content", None)
|
||||
parts = getattr(content, "parts", None) or []
|
||||
for part in parts:
|
||||
function_call = getattr(part, "function_call", None)
|
||||
if function_call is not None:
|
||||
raw_function_calls.append(function_call)
|
||||
tool_calls: List[ToolCall] = []
|
||||
|
||||
for candidate in candidates:
|
||||
content = getattr(candidate, "content", None)
|
||||
parts = getattr(content, "parts", None) or []
|
||||
for part in parts:
|
||||
function_call = getattr(part, "function_call", None)
|
||||
if function_call is None:
|
||||
continue
|
||||
|
||||
call_name = getattr(function_call, "name", None)
|
||||
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{len(tool_calls) + 1}"
|
||||
call_args = getattr(function_call, "args", None) or {}
|
||||
if not isinstance(call_name, str) or not call_name:
|
||||
raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段")
|
||||
if not isinstance(call_args, dict):
|
||||
raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典")
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
func_name=call_name,
|
||||
args=call_args,
|
||||
extra_content=_build_gemini_tool_call_extra_content(getattr(part, "thought_signature", None)),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
return tool_calls
|
||||
|
||||
raw_function_calls = getattr(response, "function_calls", None)
|
||||
if not raw_function_calls:
|
||||
return []
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
for index, function_call in enumerate(raw_function_calls, start=1):
|
||||
call_name = getattr(function_call, "name", None)
|
||||
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}"
|
||||
@@ -808,23 +904,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
model_info = request.model_info
|
||||
contents, system_instruction = _convert_messages(request.message_list)
|
||||
model_identifier, enable_google_search = self._resolve_model_identifier(
|
||||
model_info.model_identifier,
|
||||
request.extra_params,
|
||||
)
|
||||
generation_config = self._build_generation_config(
|
||||
model_identifier=model_identifier,
|
||||
system_instruction=system_instruction,
|
||||
tool_options=request.tool_options,
|
||||
response_format=request.response_format,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
extra_params=request.extra_params,
|
||||
enable_google_search=enable_google_search,
|
||||
)
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/models/{model}:generateContent",
|
||||
"method": "POST",
|
||||
"operation": "models.generate_content",
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
try:
|
||||
contents, system_instruction = _convert_messages(request.message_list)
|
||||
model_identifier, enable_google_search = self._resolve_model_identifier(
|
||||
model_info.model_identifier,
|
||||
request.extra_params,
|
||||
)
|
||||
generation_config = self._build_generation_config(
|
||||
model_identifier=model_identifier,
|
||||
system_instruction=system_instruction,
|
||||
tool_options=request.tool_options,
|
||||
response_format=request.response_format,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
extra_params=request.extra_params,
|
||||
enable_google_search=enable_google_search,
|
||||
)
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"config": generation_config,
|
||||
"contents": contents,
|
||||
"enable_google_search": enable_google_search,
|
||||
"model": model_identifier,
|
||||
"system_instruction": system_instruction,
|
||||
}
|
||||
if model_info.force_stream_mode:
|
||||
stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
|
||||
self.client.aio.models.generate_content_stream(
|
||||
@@ -855,7 +965,62 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
raise
|
||||
except (ClientError, ServerError) as exc:
|
||||
status_code = int(getattr(exc, "code", 500) or 500)
|
||||
raise RespNotOkException(status_code, str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
|
||||
wrapped_error = RespParseException(None, f"Gemini 工具调用参数错误: {exc}")
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=wrapped_error,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except EmptyResponseException as exc:
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise
|
||||
except Exception as exc:
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = (
|
||||
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
if wrapped_error is exc:
|
||||
raise
|
||||
raise wrapped_error from exc
|
||||
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
|
||||
raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
|
||||
except EmptyResponseException:
|
||||
@@ -878,9 +1043,21 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
model_info = request.model_info
|
||||
embedding_input = request.embedding_input
|
||||
extra_params = request.extra_params
|
||||
embed_config = _build_embed_content_config(extra_params)
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/models/{model}:embedContent",
|
||||
"method": "POST",
|
||||
"operation": "models.embed_content",
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
try:
|
||||
embed_config = _build_embed_content_config(extra_params)
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"config": embed_config,
|
||||
"contents": embedding_input,
|
||||
"model": model_info.model_identifier,
|
||||
}
|
||||
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
|
||||
model=model_info.model_identifier,
|
||||
contents=embedding_input,
|
||||
@@ -888,11 +1065,52 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
)
|
||||
except (ClientError, ServerError) as exc:
|
||||
status_code = int(getattr(exc, "code", 500) or 500)
|
||||
raise RespNotOkException(status_code, str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.embed_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except Exception as exc:
|
||||
raise NetworkConnectionError(str(exc)) from exc
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.embed_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = (
|
||||
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
if wrapped_error is exc:
|
||||
raise
|
||||
raise wrapped_error from exc
|
||||
|
||||
response = APIResponse(raw_data=raw_response)
|
||||
if not raw_response.embeddings:
|
||||
exc = RespParseException(raw_response, "Gemini 嵌入响应解析失败,缺少 embeddings 字段。")
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.embed_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise exc
|
||||
if raw_response.embeddings:
|
||||
response.embedding = raw_response.embeddings[0].values
|
||||
else:
|
||||
@@ -924,7 +1142,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
audio_base64 = request.audio_base64
|
||||
max_tokens = request.max_tokens
|
||||
extra_params = request.extra_params
|
||||
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/models/{model}:generateContent",
|
||||
"method": "POST",
|
||||
"operation": "models.generate_content",
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
transcription_prompt = str(
|
||||
extra_params.get(
|
||||
@@ -933,27 +1157,36 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
)
|
||||
)
|
||||
audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
|
||||
contents: List[ContentUnion] = [
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part.from_text(text=transcription_prompt),
|
||||
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type),
|
||||
],
|
||||
)
|
||||
]
|
||||
generation_config = self._build_generation_config(
|
||||
model_identifier=model_identifier,
|
||||
system_instruction=None,
|
||||
tool_options=None,
|
||||
response_format=None,
|
||||
max_tokens=max_tokens,
|
||||
temperature=None,
|
||||
extra_params=extra_params,
|
||||
enable_google_search=False,
|
||||
)
|
||||
|
||||
try:
|
||||
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
|
||||
contents: List[ContentUnion] = [
|
||||
Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part.from_text(text=transcription_prompt),
|
||||
Part.from_bytes(data=base64.b64decode(audio_base64), mime_type=audio_mime_type),
|
||||
],
|
||||
)
|
||||
]
|
||||
generation_config = self._build_generation_config(
|
||||
model_identifier=model_identifier,
|
||||
system_instruction=None,
|
||||
tool_options=None,
|
||||
response_format=None,
|
||||
max_tokens=max_tokens,
|
||||
temperature=None,
|
||||
extra_params=extra_params,
|
||||
enable_google_search=False,
|
||||
)
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"audio_base64": audio_base64,
|
||||
"audio_mime_type": audio_mime_type,
|
||||
"config": generation_config,
|
||||
"contents": contents,
|
||||
"model": model_identifier,
|
||||
"transcription_prompt": transcription_prompt,
|
||||
}
|
||||
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
|
||||
model=model_identifier,
|
||||
contents=contents,
|
||||
@@ -962,9 +1195,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
response, usage_record = _default_normal_response_parser(raw_response)
|
||||
except (ClientError, ServerError) as exc:
|
||||
status_code = int(getattr(exc, "code", 500) or 500)
|
||||
raise RespNotOkException(status_code, str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except Exception as exc:
|
||||
raise NetworkConnectionError(str(exc)) from exc
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = (
|
||||
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
if wrapped_error is exc:
|
||||
raise
|
||||
raise wrapped_error from exc
|
||||
|
||||
return response, usage_record
|
||||
|
||||
|
||||
Reference in New Issue
Block a user