feat:为失败请求留档并提供重试分析

This commit is contained in:
SengokuCola
2026-04-07 15:16:06 +08:00
parent 3b5baf901a
commit 09bce14664
6 changed files with 1304 additions and 91 deletions

View File

@@ -0,0 +1,131 @@
from pathlib import Path
import json
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.model_client.base_client import ResponseRequest
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
from src.llm_models.request_snapshot import (
attach_request_snapshot,
deserialize_messages_snapshot,
format_request_snapshot_log_info,
save_failed_request_snapshot,
serialize_messages_snapshot,
serialize_response_request_snapshot,
)
from src.llm_models import request_snapshot
def _build_api_provider() -> APIProvider:
    """Construct a minimal APIProvider fixture carrying a dummy secret key."""
    provider = APIProvider(
        api_key="secret-token",
        base_url="https://example.com/v1",
        name="test-provider",
    )
    return provider
def _build_model_info() -> ModelInfo:
    """Construct a minimal ModelInfo fixture bound to the test provider."""
    info = ModelInfo(
        api_provider="test-provider",
        model_identifier="demo-model",
        name="demo-model",
    )
    return info
def _build_response_request() -> ResponseRequest:
    """Build a ResponseRequest fixture covering text, image, and tool messages."""
    call = ToolCall(
        args={"query": "MaiBot"},
        call_id="call_1",
        func_name="search_web",
        extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
    )
    user_message = (
        MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build()
    )
    assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([call]).build()
    tool_message = (
        MessageBuilder()
        .set_role(RoleType.Tool)
        .set_tool_call_id("call_1")
        .set_tool_name("search_web")
        .add_text_content('{"ok": true}')
        .build()
    )
    return ResponseRequest(
        extra_params={"trace_id": "trace-123"},
        max_tokens=256,
        message_list=[user_message, assistant_message, tool_message],
        model_info=_build_model_info(),
        response_format=RespFormat(RespFormatType.JSON_OBJ),
        temperature=0.2,
        tool_options=[ToolOption(name="search_web", description="搜索网页")],
    )
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
    """Serialize-then-deserialize must keep every message and tool-call detail."""
    request = _build_response_request()
    restored = deserialize_messages_snapshot(serialize_messages_snapshot(request.message_list))
    assert len(restored) == 3
    user_message, assistant_message, tool_message = restored
    assert user_message.role == RoleType.User
    assert user_message.get_text_content() == "你好"
    assert user_message.parts[1].image_format == "png"
    assert assistant_message.role == RoleType.Assistant
    assert assistant_message.tool_calls is not None
    first_call = assistant_message.tool_calls[0]
    assert first_call.func_name == "search_web"
    assert first_call.args == {"query": "MaiBot"}
    assert first_call.extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}
    assert tool_message.role == RoleType.Tool
    assert tool_message.tool_call_id == "call_1"
    assert tool_message.tool_name == "search_web"
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
    """A saved snapshot must embed replay metadata and must not leak the API key."""
    monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
    request = _build_response_request()
    snapshot_path = save_failed_request_snapshot(
        api_provider=_build_api_provider(),
        client_type="openai",
        error=RuntimeError("boom"),
        internal_request=serialize_response_request_snapshot(request),
        model_info=request.model_info,
        operation="chat.completions.create",
        provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
    )
    assert snapshot_path is not None
    raw_text = snapshot_path.read_text(encoding="utf-8")
    payload = json.loads(raw_text)
    assert payload["internal_request"]["request_kind"] == "response"
    assert payload["api_provider"]["name"] == "test-provider"
    assert payload["replay"]["file_uri"] == snapshot_path.as_uri()
    assert str(snapshot_path) in payload["replay"]["command"]
    assert "secret-token" not in raw_text
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
    """Log info built from an annotated exception must expose path, URI, and replay command."""
    monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
    request = _build_response_request()
    snapshot_path = save_failed_request_snapshot(
        api_provider=_build_api_provider(),
        client_type="openai",
        error=ValueError("invalid"),
        internal_request=serialize_response_request_snapshot(request),
        model_info=request.model_info,
        operation="chat.completions.create",
        provider_request={"request_kwargs": {"messages": []}},
    )
    assert snapshot_path is not None
    wrapped = RuntimeError("wrapped")
    attach_request_snapshot(wrapped, snapshot_path)
    log_info = format_request_snapshot_log_info(wrapped)
    for expected in (str(snapshot_path), snapshot_path.as_uri(), "uv run python scripts/replay_llm_request.py"):
        assert expected in log_info

View File

@@ -0,0 +1,146 @@
# ruff: noqa: E402
import argparse
import asyncio
import json
import sys
from pathlib import Path
from typing import Any
PROJECT_ROOT = Path(__file__).resolve().parent.parent
SRC_ROOT = PROJECT_ROOT / "src"
if str(SRC_ROOT) not in sys.path:
sys.path.insert(0, str(SRC_ROOT))
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(1, str(PROJECT_ROOT))
from src.config.config import config_manager
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, ResponseRequest, client_registry
from src.llm_models.model_client.base_client import EmbeddingRequest
from src.llm_models.request_snapshot import (
deserialize_messages_snapshot,
deserialize_model_info_snapshot,
deserialize_response_format_snapshot,
deserialize_tool_options_snapshot,
)
def _load_snapshot(snapshot_path: Path) -> dict[str, Any]:
    """Load a request snapshot from a UTF-8 encoded JSON file."""
    raw_text = snapshot_path.read_text(encoding="utf-8")
    return json.loads(raw_text)
def _resolve_api_provider(provider_name: str):
    """Look up the named API provider in the current model configuration.

    Raises:
        ValueError: If no provider with that name is configured.
    """
    model_config = config_manager.get_model_config()
    matched = next(
        (provider for provider in model_config.api_providers if provider.name == provider_name),
        None,
    )
    if matched is None:
        raise ValueError(f"当前配置中不存在名为 {provider_name!r} 的 API Provider")
    return matched
def _build_response_request(snapshot: dict[str, Any]) -> ResponseRequest:
    """Rebuild a ResponseRequest object from its serialized snapshot form."""
    model_snapshot = snapshot.get("model_info") or {}
    message_snapshot = snapshot.get("message_list") or []
    return ResponseRequest(
        extra_params=dict(snapshot.get("extra_params") or {}),
        max_tokens=snapshot.get("max_tokens"),
        message_list=deserialize_messages_snapshot(message_snapshot),
        model_info=deserialize_model_info_snapshot(model_snapshot),
        response_format=deserialize_response_format_snapshot(snapshot.get("response_format")),
        temperature=snapshot.get("temperature"),
        tool_options=deserialize_tool_options_snapshot(snapshot.get("tool_options")),
    )
def _build_embedding_request(snapshot: dict[str, Any]) -> EmbeddingRequest:
    """Rebuild an EmbeddingRequest object from its serialized snapshot form."""
    embedding_text = str(snapshot.get("embedding_input") or "")
    return EmbeddingRequest(
        embedding_input=embedding_text,
        extra_params=dict(snapshot.get("extra_params") or {}),
        model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
    )
def _build_audio_request(snapshot: dict[str, Any]) -> AudioTranscriptionRequest:
    """Rebuild an AudioTranscriptionRequest object from its serialized snapshot form."""
    audio_payload = str(snapshot.get("audio_base64") or "")
    return AudioTranscriptionRequest(
        audio_base64=audio_payload,
        extra_params=dict(snapshot.get("extra_params") or {}),
        max_tokens=snapshot.get("max_tokens"),
        model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
    )
async def _replay(snapshot_path: Path) -> int:
    """Replay one failed LLM request snapshot.

    Loads the snapshot JSON, resolves the API provider named in it from the
    live configuration, re-sends the original request through a client, and
    prints a JSON summary of the response to stdout.

    Args:
        snapshot_path: Path to the snapshot JSON file.

    Returns:
        int: Exit code; always 0 when the replay completes.

    Raises:
        ValueError: If required snapshot fields are missing/empty or the
            request_kind is not supported.
    """
    # Provider resolution below requires the model configuration to be loaded.
    config_manager.initialize()
    snapshot = _load_snapshot(snapshot_path)
    internal_request = snapshot.get("internal_request")
    if not isinstance(internal_request, dict):
        raise ValueError("快照缺少 internal_request 字段")
    provider_snapshot = snapshot.get("api_provider")
    if not isinstance(provider_snapshot, dict):
        raise ValueError("快照缺少 api_provider 字段")
    provider_name = str(provider_snapshot.get("name") or "")
    if not provider_name:
        raise ValueError("快照中的 api_provider.name 不能为空")
    api_provider = _resolve_api_provider(provider_name)
    # NOTE(review): force_new presumably avoids reusing a cached client with
    # stale provider settings — confirm against client_registry semantics.
    client = client_registry.get_client_class_instance(api_provider, force_new=True)
    # Dispatch on the snapshot's request kind to rebuild the matching request type.
    request_kind = str(internal_request.get("request_kind") or "").strip()
    if request_kind == "response":
        response = await client.get_response(_build_response_request(internal_request))
    elif request_kind == "embedding":
        response = await client.get_embedding(_build_embedding_request(internal_request))
    elif request_kind == "audio_transcription":
        response = await client.get_audio_transcriptions(_build_audio_request(internal_request))
    else:
        raise ValueError(f"不支持的 request_kind: {request_kind!r}")
    # Summarize the response rather than dumping raw payloads (embeddings are
    # reported only by length; raw_data only by its type name).
    output_payload = {
        "content": response.content,
        "embedding_length": len(response.embedding or []),
        "has_embedding": response.embedding is not None,
        "model_name": response.usage.model_name if response.usage is not None else None,
        "provider_name": response.usage.provider_name if response.usage is not None else None,
        "raw_data_type": type(response.raw_data).__name__ if response.raw_data is not None else None,
        "reasoning_content": response.reasoning_content,
        "tool_calls": [
            {
                "args": tool_call.args,
                "call_id": tool_call.call_id,
                "func_name": tool_call.func_name,
            }
            for tool_call in (response.tool_calls or [])
        ],
        "usage": {
            "completion_tokens": response.usage.completion_tokens,
            "prompt_tokens": response.usage.prompt_tokens,
            "total_tokens": response.usage.total_tokens,
        }
        if response.usage is not None
        else None,
    }
    print(json.dumps(output_payload, ensure_ascii=False, indent=2))
    return 0
def main() -> int:
    """CLI entry point: parse the snapshot path argument and replay it."""
    parser = argparse.ArgumentParser(description="回放失败的 LLM 请求快照。")
    parser.add_argument("snapshot_path", help="请求快照 JSON 文件路径")
    parsed = parser.parse_args()
    target = Path(parsed.snapshot_path).expanduser().resolve()
    if not target.exists():
        raise FileNotFoundError(f"快照文件不存在: {target}")
    return asyncio.run(_replay(target))


if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -1,7 +1,10 @@
# ruff: noqa: B025
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
import asyncio import asyncio
import base64 import base64
import binascii
import io import io
import json import json
@@ -21,6 +24,8 @@ from google.genai.types import (
EmbedContentConfig, EmbedContentConfig,
EmbedContentResponse, EmbedContentResponse,
FunctionDeclaration, FunctionDeclaration,
FunctionCall,
FunctionResponse,
GenerateContentConfig, GenerateContentConfig,
GenerateContentResponse, GenerateContentResponse,
GoogleSearch, GoogleSearch,
@@ -60,6 +65,14 @@ from .base_client import (
UsageTuple, UsageTuple,
client_registry, client_registry,
) )
from ..request_snapshot import (
attach_request_snapshot,
has_request_snapshot,
save_failed_request_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
logger = get_logger("Gemini客户端") logger = get_logger("Gemini客户端")
@@ -112,6 +125,11 @@ EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = {
} }
"""可透传给 `EmbedContentConfig` 的额外参数字段。""" """可透传给 `EmbedContentConfig` 的额外参数字段。"""
GEMINI_EXTRA_CONTENT_PROVIDER_KEY = "google"
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY = "thought_signature"
GEMINI_FALLBACK_THOUGHT_SIGNATURE = b"skip_thought_signature_validator"
"""当历史 function call 没有原始 thought signature 时,使用官方允许的占位签名跳过校验。"""
def _normalize_image_mime_type(image_format: str) -> str: def _normalize_image_mime_type(image_format: str) -> str:
"""将图片格式名称转换为标准 MIME 类型。 """将图片格式名称转换为标准 MIME 类型。
@@ -177,6 +195,62 @@ def _normalize_function_response_payload(message: Message) -> Dict[str, Any]:
return {"result": content} return {"result": content}
def _build_gemini_tool_call_extra_content(thought_signature: bytes | None) -> Dict[str, Any] | None:
    """Encode a Gemini thought signature into the internal tool-call extra payload."""
    if not thought_signature:
        return None
    encoded_signature = base64.b64encode(thought_signature).decode("ascii")
    return {
        GEMINI_EXTRA_CONTENT_PROVIDER_KEY: {
            GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY: encoded_signature,
        },
    }
def _extract_gemini_thought_signature(tool_call: ToolCall) -> bytes | None:
    """Recover a Gemini thought signature from a tool call's extra content.

    Accepts either raw bytes or a base64-encoded string; a non-empty string
    that is not valid base64 is returned as its UTF-8 encoding.
    """
    extra = tool_call.extra_content
    if not extra:
        return None
    provider_payload = extra.get(GEMINI_EXTRA_CONTENT_PROVIDER_KEY)
    if not isinstance(provider_payload, dict):
        return None
    raw_signature = provider_payload.get(GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY)
    if isinstance(raw_signature, bytes):
        return raw_signature
    if not isinstance(raw_signature, str):
        return None
    stripped = raw_signature.strip()
    if not stripped:
        return None
    try:
        return base64.b64decode(stripped.encode("ascii"), validate=True)
    except (binascii.Error, ValueError):
        # Not valid base64 (or not ASCII): fall back to the raw UTF-8 bytes.
        return stripped.encode("utf-8")
def _build_gemini_function_call_part(
    tool_call: ToolCall,
    *,
    inject_fallback_signature: bool,
) -> Part:
    """Build a Gemini function-call Part from an internal tool call.

    When the call carries no stored thought signature and injection is
    requested, the placeholder signature is used so validation is skipped.
    """
    signature = _extract_gemini_thought_signature(tool_call)
    if signature is None and inject_fallback_signature:
        signature = GEMINI_FALLBACK_THOUGHT_SIGNATURE
    function_call = FunctionCall(
        id=tool_call.call_id,
        name=tool_call.func_name,
        args=tool_call.args or {},
    )
    return Part(function_call=function_call, thought_signature=signature)
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]: def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
"""安全获取 Gemini 响应中的候选列表。 """安全获取 Gemini 响应中的候选列表。
@@ -235,11 +309,11 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
if message.role == RoleType.Assistant: if message.role == RoleType.Assistant:
assistant_parts = _build_non_tool_parts(message) assistant_parts = _build_non_tool_parts(message)
if message.tool_calls: if message.tool_calls:
for tool_call in message.tool_calls: for tool_call_index, tool_call in enumerate(message.tool_calls):
assistant_parts.append( assistant_parts.append(
Part.from_function_call( _build_gemini_function_call_part(
name=tool_call.func_name, tool_call,
args=tool_call.args or {}, inject_fallback_signature=tool_call_index == 0,
) )
) )
tool_name_by_call_id[tool_call.call_id] = tool_call.func_name tool_name_by_call_id[tool_call.call_id] = tool_call.func_name
@@ -256,10 +330,13 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
"且消息中未携带 tool_name" "且消息中未携带 tool_name"
) )
tool_name_by_call_id[message.tool_call_id] = tool_name tool_name_by_call_id[message.tool_call_id] = tool_name
function_response_part = Part.from_function_response( function_response_part = Part(
function_response=FunctionResponse(
id=message.tool_call_id,
name=tool_name, name=tool_name,
response=_normalize_function_response_payload(message), response=_normalize_function_response_payload(message),
) )
)
contents.append(Content(role="tool", parts=[function_response_part])) contents.append(Content(role="tool", parts=[function_response_part]))
continue continue
@@ -368,22 +445,41 @@ def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]
Raises: Raises:
RespParseException: 当函数调用结构不合法时抛出。 RespParseException: 当函数调用结构不合法时抛出。
""" """
raw_function_calls = getattr(response, "function_calls", None)
candidates = _get_candidates(response) candidates = _get_candidates(response)
if not raw_function_calls and candidates: tool_calls: List[ToolCall] = []
raw_function_calls = []
for candidate in candidates: for candidate in candidates:
content = getattr(candidate, "content", None) content = getattr(candidate, "content", None)
parts = getattr(content, "parts", None) or [] parts = getattr(content, "parts", None) or []
for part in parts: for part in parts:
function_call = getattr(part, "function_call", None) function_call = getattr(part, "function_call", None)
if function_call is not None: if function_call is None:
raw_function_calls.append(function_call) continue
call_name = getattr(function_call, "name", None)
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{len(tool_calls) + 1}"
call_args = getattr(function_call, "args", None) or {}
if not isinstance(call_name, str) or not call_name:
raise RespParseException(response, "响应解析失败Gemini 工具调用缺少 name 字段")
if not isinstance(call_args, dict):
raise RespParseException(response, "响应解析失败Gemini 工具调用参数无法解析为字典")
tool_calls.append(
ToolCall(
call_id=call_id,
func_name=call_name,
args=call_args,
extra_content=_build_gemini_tool_call_extra_content(getattr(part, "thought_signature", None)),
)
)
if tool_calls:
return tool_calls
raw_function_calls = getattr(response, "function_calls", None)
if not raw_function_calls: if not raw_function_calls:
return [] return []
tool_calls: List[ToolCall] = []
for index, function_call in enumerate(raw_function_calls, start=1): for index, function_call in enumerate(raw_function_calls, start=1):
call_name = getattr(function_call, "name", None) call_name = getattr(function_call, "name", None)
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}" call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}"
@@ -808,6 +904,15 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
""" """
model_info = request.model_info model_info = request.model_info
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:generateContent",
"method": "POST",
"operation": "models.generate_content",
"request_kwargs": {},
}
try:
contents, system_instruction = _convert_messages(request.message_list) contents, system_instruction = _convert_messages(request.message_list)
model_identifier, enable_google_search = self._resolve_model_identifier( model_identifier, enable_google_search = self._resolve_model_identifier(
model_info.model_identifier, model_info.model_identifier,
@@ -823,8 +928,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
extra_params=request.extra_params, extra_params=request.extra_params,
enable_google_search=enable_google_search, enable_google_search=enable_google_search,
) )
snapshot_provider_request["request_kwargs"] = {
try: "config": generation_config,
"contents": contents,
"enable_google_search": enable_google_search,
"model": model_identifier,
"system_instruction": system_instruction,
}
if model_info.force_stream_mode: if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task( stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
self.client.aio.models.generate_content_stream( self.client.aio.models.generate_content_stream(
@@ -855,7 +965,62 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
raise raise
except (ClientError, ServerError) as exc: except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500) status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
wrapped_error = RespParseException(None, f"Gemini 工具调用参数错误: {exc}")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=wrapped_error,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except EmptyResponseException as exc:
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc: except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
except EmptyResponseException: except EmptyResponseException:
@@ -878,9 +1043,21 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
model_info = request.model_info model_info = request.model_info
embedding_input = request.embedding_input embedding_input = request.embedding_input
extra_params = request.extra_params extra_params = request.extra_params
embed_config = _build_embed_content_config(extra_params) snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:embedContent",
"method": "POST",
"operation": "models.embed_content",
"request_kwargs": {},
}
try: try:
embed_config = _build_embed_content_config(extra_params)
snapshot_provider_request["request_kwargs"] = {
"config": embed_config,
"contents": embedding_input,
"model": model_info.model_identifier,
}
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content( raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
model=model_info.model_identifier, model=model_info.model_identifier,
contents=embedding_input, contents=embedding_input,
@@ -888,11 +1065,52 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
) )
except (ClientError, ServerError) as exc: except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500) status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc: except Exception as exc:
raise NetworkConnectionError(str(exc)) from exc if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
response = APIResponse(raw_data=raw_response) response = APIResponse(raw_data=raw_response)
if not raw_response.embeddings:
exc = RespParseException(raw_response, "Gemini 嵌入响应解析失败,缺少 embeddings 字段。")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise exc
if raw_response.embeddings: if raw_response.embeddings:
response.embedding = raw_response.embeddings[0].values response.embedding = raw_response.embeddings[0].values
else: else:
@@ -924,7 +1142,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
audio_base64 = request.audio_base64 audio_base64 = request.audio_base64
max_tokens = request.max_tokens max_tokens = request.max_tokens
extra_params = request.extra_params extra_params = request.extra_params
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params) snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:generateContent",
"method": "POST",
"operation": "models.generate_content",
"request_kwargs": {},
}
transcription_prompt = str( transcription_prompt = str(
extra_params.get( extra_params.get(
@@ -933,6 +1157,9 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
) )
) )
audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav")) audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
try:
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
contents: List[ContentUnion] = [ contents: List[ContentUnion] = [
Content( Content(
role="user", role="user",
@@ -952,8 +1179,14 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
extra_params=extra_params, extra_params=extra_params,
enable_google_search=False, enable_google_search=False,
) )
snapshot_provider_request["request_kwargs"] = {
try: "audio_base64": audio_base64,
"audio_mime_type": audio_mime_type,
"config": generation_config,
"contents": contents,
"model": model_identifier,
"transcription_prompt": transcription_prompt,
}
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content( raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
model=model_identifier, model=model_identifier,
contents=contents, contents=contents,
@@ -962,9 +1195,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
response, usage_record = _default_normal_response_parser(raw_response) response, usage_record = _default_normal_response_parser(raw_response)
except (ClientError, ServerError) as exc: except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500) status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc: except Exception as exc:
raise NetworkConnectionError(str(exc)) from exc if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
return response, usage_record return response, usage_record

View File

@@ -4,7 +4,6 @@ import binascii
import io import io
import json import json
import re import re
from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
@@ -60,6 +59,14 @@ from .base_client import (
UsageTuple, UsageTuple,
client_registry, client_registry,
) )
from ..request_snapshot import (
attach_request_snapshot,
has_request_snapshot,
save_failed_request_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
logger = get_logger("llm_models") logger = get_logger("llm_models")
@@ -533,6 +540,13 @@ def _coerce_openai_argument(value: Any) -> Any | Omit:
return value return value
def _snapshot_openai_argument(value: Any | Omit) -> Any | None:
    """Map an OpenAI SDK argument to a snapshot-friendly plain value (omit -> None)."""
    return None if value is omit else value
def _build_api_status_message(error: APIStatusError) -> str: def _build_api_status_message(error: APIStatusError) -> str:
"""构建更适合记录和展示的状态错误信息。 """构建更适合记录和展示的状态错误信息。
@@ -939,10 +953,21 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
Returns: Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。 Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
""" """
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/chat/completions",
"method": "POST",
"operation": "chat.completions.create",
"organization": self.api_provider.organization,
"project": self.api_provider.project,
"request_kwargs": {},
}
model_info = request.model_info model_info = request.model_info
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list)
tools: Iterable[ChatCompletionToolParam] | Omit = ( try:
_convert_tool_options(request.tool_options) if request.tool_options else omit messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request.message_list)
tools_payload: List[ChatCompletionToolParam] | None = (
_convert_tool_options(request.tool_options) if request.tool_options else None
) )
openai_response_format = _convert_response_format(request.response_format) openai_response_format = _convert_response_format(request.response_format)
request_overrides = split_openai_request_overrides( request_overrides = split_openai_request_overrides(
@@ -958,14 +983,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
else _coerce_openai_argument(request.max_tokens) else _coerce_openai_argument(request.max_tokens)
) )
snapshot_provider_request["request_kwargs"] = {
"extra_body": request_overrides.extra_body or None,
"extra_headers": request_overrides.extra_headers or None,
"extra_query": request_overrides.extra_query or None,
"max_tokens": _snapshot_openai_argument(max_tokens_argument),
"messages": messages_payload,
"model": model_info.model_identifier,
"response_format": _snapshot_openai_argument(openai_response_format),
"stream": bool(model_info.force_stream_mode),
"temperature": _snapshot_openai_argument(temperature_argument),
"tools": tools_payload,
}
try:
if model_info.force_stream_mode: if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task( stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
self.client.chat.completions.create( self.client.chat.completions.create(
model=model_info.model_identifier, model=model_info.model_identifier,
messages=messages, messages=messages_payload,
tools=tools, tools=tools_payload or omit,
temperature=temperature_argument, temperature=temperature_argument,
max_tokens=max_tokens_argument, max_tokens=max_tokens_argument,
stream=True, stream=True,
@@ -984,8 +1020,8 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task( completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
self.client.chat.completions.create( self.client.chat.completions.create(
model=model_info.model_identifier, model=model_info.model_identifier,
messages=messages, messages=messages_payload,
tools=tools, tools=tools_payload or omit,
temperature=temperature_argument, temperature=temperature_argument,
max_tokens=max_tokens_argument, max_tokens=max_tokens_argument,
stream=False, stream=False,
@@ -1000,10 +1036,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
await await_task_with_interrupt(completion_task, request.interrupt_flag), await await_task_with_interrupt(completion_task, request.interrupt_flag),
) )
return response_parser(raw_response) return response_parser(raw_response)
except (EmptyResponseException, RespParseException) as exc:
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
except APIConnectionError as exc: except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = NetworkConnectionError(str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except APIStatusError as exc: except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except ReqAbortException:
raise
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
async def _execute_embedding_request( async def _execute_embedding_request(
self, self,
@@ -1020,9 +1106,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
model_info = request.model_info model_info = request.model_info
embedding_input = request.embedding_input embedding_input = request.embedding_input
extra_params = request.extra_params extra_params = request.extra_params
request_overrides = split_openai_request_overrides(extra_params) snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/embeddings",
"method": "POST",
"operation": "embeddings.create",
"organization": self.api_provider.organization,
"project": self.api_provider.project,
"request_kwargs": {},
}
try: try:
request_overrides = split_openai_request_overrides(extra_params)
snapshot_provider_request["request_kwargs"] = {
"extra_body": request_overrides.extra_body or None,
"extra_headers": request_overrides.extra_headers or None,
"extra_query": request_overrides.extra_query or None,
"input": embedding_input,
"model": model_info.model_identifier,
}
raw_response = await self.client.embeddings.create( raw_response = await self.client.embeddings.create(
model=model_info.model_identifier, model=model_info.model_identifier,
input=embedding_input, input=embedding_input,
@@ -1031,11 +1133,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
extra_body=request_overrides.extra_body or None, extra_body=request_overrides.extra_body or None,
) )
except APIConnectionError as exc: except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
wrapped_error = NetworkConnectionError(str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except APIStatusError as exc: except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
response = APIResponse() response = APIResponse()
if not raw_response.data:
exc = RespParseException(raw_response, "嵌入响应解析失败,缺少 embeddings 数据。")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise exc
if raw_response.data: if raw_response.data:
response.embedding = raw_response.data[0].embedding response.embedding = raw_response.data[0].embedding
else: else:
@@ -1059,10 +1210,27 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
model_info = request.model_info model_info = request.model_info
audio_base64 = request.audio_base64 audio_base64 = request.audio_base64
extra_params = request.extra_params extra_params = request.extra_params
request_overrides = split_openai_request_overrides(extra_params) snapshot_provider_request = {
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64))) "base_url": self.api_provider.base_url,
"endpoint": "/audio/transcriptions",
"method": "POST",
"operation": "audio.transcriptions.create",
"organization": self.api_provider.organization,
"project": self.api_provider.project,
"request_kwargs": {},
}
try: try:
request_overrides = split_openai_request_overrides(extra_params)
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
snapshot_provider_request["request_kwargs"] = {
"audio_base64": audio_base64,
"extra_body": request_overrides.extra_body or None,
"extra_headers": request_overrides.extra_headers or None,
"extra_query": request_overrides.extra_query or None,
"file_name": "audio.wav",
"model": model_info.model_identifier,
}
raw_response = await self.client.audio.transcriptions.create( raw_response = await self.client.audio.transcriptions.create(
model=model_info.model_identifier, model=model_info.model_identifier,
file=audio_file, file=audio_file,
@@ -1071,12 +1239,61 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
extra_body=request_overrides.extra_body or None, extra_body=request_overrides.extra_body or None,
) )
except APIConnectionError as exc: except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = NetworkConnectionError(str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except APIStatusError as exc: except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
response = APIResponse() response = APIResponse()
transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None) transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
if not isinstance(transcription_text, str):
exc = RespParseException(raw_response, "音频转写响应解析失败,缺少文本内容。")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise exc
if isinstance(transcription_text, str): if isinstance(transcription_text, str):
response.content = transcription_text response.content = transcription_text
return response, None return response, None

View File

@@ -0,0 +1,447 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Mapping, Sequence
import base64
import json
import re
from src.common.logger import get_logger
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, EmbeddingRequest, ResponseRequest
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption, normalize_tool_options
# Repository root: this file lives at <root>/src/llm_models/, so two parents up.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# All failed-request snapshots are written beneath logs/llm_request.
LLM_REQUEST_LOG_DIR = PROJECT_ROOT / "logs" / "llm_request"
# Replay helper script, kept both as a repo-relative path (for the shell command)
# and as an absolute path (recorded inside the snapshot file).
REPLAY_SCRIPT_RELATIVE_PATH = Path("scripts") / "replay_llm_request.py"
REPLAY_SCRIPT_PATH = PROJECT_ROOT / REPLAY_SCRIPT_RELATIVE_PATH
# Any run of characters outside [A-Za-z0-9._-] is collapsed when building file names.
FILENAME_SAFE_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
# Bump when the snapshot JSON layout changes incompatibly.
SNAPSHOT_VERSION = 1
logger = get_logger("llm_request_snapshot")
def _json_friendly(value: Any) -> Any:
"""将任意对象尽量转换为可写入 JSON 的结构。"""
if value is None or isinstance(value, (bool, float, int, str)):
return value
if isinstance(value, Enum):
return value.value
if isinstance(value, Path):
return str(value)
if isinstance(value, (bytes, bytearray)):
return base64.b64encode(bytes(value)).decode("ascii")
if isinstance(value, Mapping):
return {str(key): _json_friendly(item) for key, item in value.items()}
if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)):
return [_json_friendly(item) for item in value]
model_dump = getattr(value, "model_dump", None)
if callable(model_dump):
try:
return _json_friendly(model_dump(mode="json", exclude_none=True))
except TypeError:
return _json_friendly(model_dump(exclude_none=True))
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return _json_friendly(to_dict())
return str(value)
def _sanitize_filename_component(value: str) -> str:
"""将任意字符串转换为适合文件名使用的片段。"""
normalized_value = FILENAME_SAFE_PATTERN.sub("-", value.strip())
normalized_value = normalized_value.strip("-._")
return normalized_value or "unknown"
def _serialize_tool_call(tool_call: ToolCall) -> dict[str, Any]:
    """Serialize one tool call into its snapshot form (OpenAI-style layout)."""
    serialized: dict[str, Any] = {
        "id": tool_call.call_id,
        "function": {
            "name": tool_call.func_name,
            "arguments": _json_friendly(tool_call.args or {}),
        },
    }
    # Provider-specific extras (e.g. thought signatures) ride along when present.
    if tool_call.extra_content:
        serialized["extra_content"] = _json_friendly(tool_call.extra_content)
    return serialized
def serialize_tool_calls_snapshot(tool_calls: Sequence[ToolCall] | None) -> list[dict[str, Any]]:
    """Serialize a list of tool calls; an absent or empty list becomes []."""
    return [_serialize_tool_call(call) for call in tool_calls] if tool_calls else []
def deserialize_tool_calls_snapshot(raw_tool_calls: Any) -> list[ToolCall]:
    """Restore a list of ToolCall objects from snapshot data.

    Raises:
        ValueError: if the snapshot structure is malformed.
    """
    if raw_tool_calls in (None, []):
        return []
    if not isinstance(raw_tool_calls, list):
        raise ValueError("快照中的 tool_calls 必须是列表")
    restored: list[ToolCall] = []
    for entry in raw_tool_calls:
        if not isinstance(entry, dict):
            raise ValueError("快照中的 tool_call 项必须是字典")
        # Prefer the nested OpenAI-style "function" object; fall back to flat keys.
        function_info = entry.get("function", {})
        if isinstance(function_info, dict):
            name = function_info.get("name")
            arguments = function_info.get("arguments")
        else:
            name = entry.get("name")
            arguments = entry.get("arguments")
        call_id = entry.get("id") or entry.get("call_id")
        if not isinstance(call_id, str) or not isinstance(name, str):
            raise ValueError("快照中的 tool_call 缺少 id 或 function.name")
        extra = entry.get("extra_content")
        restored.append(
            ToolCall(
                call_id=call_id,
                func_name=name,
                args=arguments if isinstance(arguments, dict) else {},
                extra_content=extra if isinstance(extra, dict) else None,
            )
        )
    return restored
def serialize_message_snapshot(message: Message) -> dict[str, Any]:
    """Serialize an internal message object into a replayable snapshot structure."""
    serialized_parts: list[dict[str, Any]] = []
    for part in message.parts:
        if isinstance(part, TextMessagePart):
            serialized_parts.append({"type": "text", "text": part.text})
        elif isinstance(part, ImageMessagePart):
            serialized_parts.append(
                {
                    "type": "image",
                    "image_base64": part.image_base64,
                    "image_format": part.image_format,
                }
            )
        # NOTE(review): any other part type is silently dropped — confirm intended.
    snapshot: dict[str, Any] = {
        "parts": serialized_parts,
        "role": message.role.value,
    }
    # Tool metadata only exists on tool/assistant messages; empty fields are omitted.
    if message.tool_call_id:
        snapshot["tool_call_id"] = message.tool_call_id
    if message.tool_name:
        snapshot["tool_name"] = message.tool_name
    if message.tool_calls:
        snapshot["tool_calls"] = serialize_tool_calls_snapshot(message.tool_calls)
    return snapshot
def deserialize_message_snapshot(raw_message: Any) -> Message:
    """Restore an internal Message object from its snapshot dict.

    Raises:
        ValueError: if the snapshot is malformed (wrong container types,
            missing role, or an unsupported part type).
    """
    if not isinstance(raw_message, dict):
        raise ValueError("快照中的 message 必须是字典")
    raw_role = raw_message.get("role")
    if not isinstance(raw_role, str):
        raise ValueError("快照中的 message 缺少 role")
    # RoleType(...) raises ValueError for unknown role strings — desired failure mode.
    role = RoleType(raw_role)
    builder = MessageBuilder().set_role(role)
    raw_tool_calls = raw_message.get("tool_calls")
    tool_calls = deserialize_tool_calls_snapshot(raw_tool_calls)
    # Tool calls are only attached to assistant messages; other roles drop them.
    if role == RoleType.Assistant and tool_calls:
        builder.set_tool_calls(tool_calls)
    tool_call_id = raw_message.get("tool_call_id")
    if role == RoleType.Tool and isinstance(tool_call_id, str):
        builder.set_tool_call_id(tool_call_id)
    tool_name = raw_message.get("tool_name")
    if role == RoleType.Tool and isinstance(tool_name, str) and tool_name:
        builder.set_tool_name(tool_name)
    raw_parts = raw_message.get("parts", [])
    if not isinstance(raw_parts, list):
        raise ValueError("快照中的 message.parts 必须是列表")
    for raw_part in raw_parts:
        if not isinstance(raw_part, dict):
            raise ValueError("快照中的 message part 必须是字典")
        # Normalize the type tag so variants like "Text" / " text " all match.
        part_type = str(raw_part.get("type", "")).strip().lower()
        if part_type == "text":
            text = raw_part.get("text")
            if not isinstance(text, str):
                raise ValueError("文本 part 缺少 text 字段")
            builder.add_text_content(text)
            continue
        if part_type == "image":
            image_format = raw_part.get("image_format")
            image_base64 = raw_part.get("image_base64")
            if not isinstance(image_format, str) or not isinstance(image_base64, str):
                raise ValueError("图片 part 缺少 image_format 或 image_base64")
            builder.add_image_content(image_format=image_format, image_base64=image_base64)
            continue
        raise ValueError(f"不支持的快照消息 part 类型: {part_type}")
    return builder.build()
def serialize_messages_snapshot(messages: Sequence[Message]) -> list[dict[str, Any]]:
    """Serialize every message in *messages* into its snapshot form."""
    return [serialize_message_snapshot(item) for item in messages]
def deserialize_messages_snapshot(raw_messages: Any) -> list[Message]:
    """Restore a list of Message objects from snapshot data.

    Raises:
        ValueError: if *raw_messages* is not a list.
    """
    if not isinstance(raw_messages, list):
        raise ValueError("快照中的 messages 必须是列表")
    return [deserialize_message_snapshot(item) for item in raw_messages]
def serialize_model_info_snapshot(model_info: ModelInfo) -> dict[str, Any]:
    """Serialize the model configuration into the snapshot structure."""
    # extra_params may hold arbitrary values, so run it through the JSON normalizer.
    normalized_extra = _json_friendly(dict(model_info.extra_params))
    return {
        "api_provider": model_info.api_provider,
        "extra_params": normalized_extra,
        "force_stream_mode": model_info.force_stream_mode,
        "max_tokens": model_info.max_tokens,
        "model_identifier": model_info.model_identifier,
        "name": model_info.name,
        "temperature": model_info.temperature,
    }
def deserialize_model_info_snapshot(raw_model_info: Any) -> ModelInfo:
    """Restore a ModelInfo object from snapshot data.

    Raises:
        ValueError: if *raw_model_info* is not a dict.
    """
    if not isinstance(raw_model_info, dict):
        raise ValueError("快照中的 model_info 必须是字典")
    getter = raw_model_info.get
    # String fields fall back to "" and mapping fields to {} so a sparse
    # snapshot still yields a constructible ModelInfo.
    return ModelInfo(
        api_provider=str(getter("api_provider") or ""),
        extra_params=dict(getter("extra_params") or {}),
        force_stream_mode=bool(getter("force_stream_mode", False)),
        max_tokens=getter("max_tokens"),
        model_identifier=str(getter("model_identifier") or ""),
        name=str(getter("name") or ""),
        temperature=getter("temperature"),
    )
def serialize_response_format_snapshot(response_format: RespFormat | None) -> dict[str, Any] | None:
    """Serialize the response-format definition; an absent format stays None."""
    return None if response_format is None else response_format.to_dict()
def deserialize_response_format_snapshot(raw_response_format: Any) -> RespFormat | None:
    """Restore the response-format definition from snapshot data.

    Raises:
        ValueError: if the snapshot entry is present but malformed.
    """
    if raw_response_format is None:
        return None
    if not isinstance(raw_response_format, dict):
        raise ValueError("快照中的 response_format 必须是字典")
    format_type_value = raw_response_format.get("format_type")
    if not isinstance(format_type_value, str):
        raise ValueError("快照中的 response_format 缺少 format_type")
    schema_value = raw_response_format.get("schema")
    return RespFormat(
        format_type=RespFormatType(format_type_value),
        schema=schema_value if isinstance(schema_value, dict) else None,
    )
def serialize_tool_options_snapshot(tool_options: Sequence[ToolOption] | None) -> list[dict[str, Any]]:
    """Serialize tool definitions as OpenAI function schemas; absent -> []."""
    if not tool_options:
        return []
    return [option.to_openai_function_schema() for option in tool_options]
def deserialize_tool_options_snapshot(raw_tool_options: Any) -> list[ToolOption] | None:
    """Restore tool definitions from snapshot data; an absent/empty list stays None.

    Raises:
        ValueError: if the entry is present but not a list.
    """
    if raw_tool_options in (None, []):
        return None
    if not isinstance(raw_tool_options, list):
        raise ValueError("快照中的 tool_options 必须是列表")
    return normalize_tool_options(raw_tool_options)
def serialize_response_request_snapshot(request: ResponseRequest) -> dict[str, Any]:
    """Serialize a text/multimodal generation request into snapshot form."""
    snapshot: dict[str, Any] = {}
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["max_tokens"] = request.max_tokens
    snapshot["message_list"] = serialize_messages_snapshot(request.message_list)
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    # request_kind tags the snapshot so replay tooling knows which API to call.
    snapshot["request_kind"] = "response"
    snapshot["response_format"] = serialize_response_format_snapshot(request.response_format)
    snapshot["temperature"] = request.temperature
    snapshot["tool_options"] = serialize_tool_options_snapshot(request.tool_options)
    return snapshot
def serialize_embedding_request_snapshot(request: EmbeddingRequest) -> dict[str, Any]:
    """Serialize an embedding request into snapshot form."""
    snapshot: dict[str, Any] = {}
    snapshot["embedding_input"] = request.embedding_input
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    # request_kind tags the snapshot so replay tooling knows which API to call.
    snapshot["request_kind"] = "embedding"
    return snapshot
def serialize_audio_request_snapshot(request: AudioTranscriptionRequest) -> dict[str, Any]:
    """Serialize an audio transcription request into snapshot form."""
    snapshot: dict[str, Any] = {}
    snapshot["audio_base64"] = request.audio_base64
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["max_tokens"] = request.max_tokens
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    # request_kind tags the snapshot so replay tooling knows which API to call.
    snapshot["request_kind"] = "audio_transcription"
    return snapshot
def serialize_api_provider_snapshot(api_provider: APIProvider) -> dict[str, Any]:
    """Serialize the API provider configuration for the snapshot.

    The api_key is deliberately not included so snapshot files never persist
    credential material.
    """
    provider_snapshot: dict[str, Any] = {
        "auth_header_name": api_provider.auth_header_name,
        "auth_header_prefix": api_provider.auth_header_prefix,
        "auth_query_name": api_provider.auth_query_name,
        "auth_type": api_provider.auth_type,
        "base_url": api_provider.base_url,
        "client_type": api_provider.client_type,
        # Header/query maps may contain arbitrary values; normalize for JSON.
        "default_headers": _json_friendly(dict(api_provider.default_headers)),
        "default_query": _json_friendly(dict(api_provider.default_query)),
        "model_list_endpoint": api_provider.model_list_endpoint,
        "name": api_provider.name,
        "organization": api_provider.organization,
        "project": api_provider.project,
        "retry_interval": api_provider.retry_interval,
        "timeout": api_provider.timeout,
    }
    return provider_snapshot
def build_replay_command(snapshot_path: Path) -> str:
    """Build the shell command that replays the given snapshot file."""
    script = REPLAY_SCRIPT_RELATIVE_PATH.as_posix()
    # Quote the absolute snapshot path so spaces in the path survive the shell.
    return f'uv run python {script} "{snapshot_path.resolve()}"'
def save_failed_request_snapshot(
    *,
    api_provider: APIProvider,
    client_type: str,
    error: Exception,
    internal_request: dict[str, Any],
    model_info: ModelInfo,
    operation: str,
    provider_request: dict[str, Any],
) -> Path | None:
    """Persist a failed-request snapshot JSON file under logs/llm_request.

    Snapshotting is best-effort: any failure here is logged and swallowed so it
    never masks the original request error.

    Returns:
        Path | None: the written snapshot path, or None if writing failed.
    """
    try:
        LLM_REQUEST_LOG_DIR.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now()
        # Microsecond-resolution timestamp keeps concurrent failures from
        # colliding on the same file name.
        file_name = (
            f"{timestamp.strftime('%Y%m%d_%H%M%S_%f')}"
            f"_{_sanitize_filename_component(client_type)}"
            f"_{_sanitize_filename_component(internal_request.get('request_kind', 'request'))}"
            f"_{_sanitize_filename_component(model_info.name or model_info.model_identifier)}.json"
        )
        snapshot_path = (LLM_REQUEST_LOG_DIR / file_name).resolve()
        snapshot_payload: dict[str, Any] = {
            "api_provider": serialize_api_provider_snapshot(api_provider),
            "client_type": client_type,
            "created_at": timestamp.isoformat(timespec="seconds"),
            "error": {
                "message": str(error),
                # Present for HTTP-style errors (e.g. APIStatusError); None otherwise.
                "status_code": getattr(error, "status_code", None),
                "type": type(error).__name__,
            },
            "internal_request": internal_request,
            "model_info": serialize_model_info_snapshot(model_info),
            "operation": operation,
            "provider_request": _json_friendly(provider_request),
            "snapshot_version": SNAPSHOT_VERSION,
        }
        # Replay info depends on the resolved path, so it is added after the fact.
        snapshot_payload["replay"] = {
            "command": build_replay_command(snapshot_path),
            "file_uri": snapshot_path.as_uri(),
            "script_path": str(REPLAY_SCRIPT_PATH),
        }
        snapshot_path.write_text(
            json.dumps(snapshot_payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        return snapshot_path
    except Exception:
        # Fixed: the original message was mojibake (UTF-8 text mis-decoded as GBK).
        logger.exception("保存 LLM 失败请求快照时发生异常")
        return None
def attach_request_snapshot(exception: Exception, snapshot_path: Path | None) -> None:
    """Attach snapshot location info to *exception* as extra attributes.

    A None path (snapshot saving failed) leaves the exception untouched.
    """
    if snapshot_path is None:
        return
    resolved = snapshot_path.resolve()
    # Downstream log formatting reads these attributes off the exception.
    exception.request_snapshot_path = str(resolved)
    exception.request_snapshot_uri = resolved.as_uri()
    exception.request_snapshot_replay_command = build_replay_command(snapshot_path)
def has_request_snapshot(exception: Exception) -> bool:
    """Return True if *exception* or its direct cause already carries snapshot info.

    The original docstring was mojibake; intent restored from the code: the
    check looks for a non-empty ``request_snapshot_path`` attribute.
    """
    candidates = (exception, getattr(exception, "__cause__", None))
    return any(
        candidate is not None and getattr(candidate, "request_snapshot_path", "")
        for candidate in candidates
    )
def format_request_snapshot_log_info(exception: Exception) -> str:
    """Format snapshot info attached to *exception* (or its direct cause) for logs.

    Returns an indented multi-line fragment suitable for appending to a log
    message, or "" when neither the exception nor its cause carries snapshot info.
    """
    for candidate in (exception, getattr(exception, "__cause__", None)):
        if candidate is None:
            continue
        labeled_values = (
            ("请求快照路径: ", getattr(candidate, "request_snapshot_path", "")),
            ("请求快照链接: ", getattr(candidate, "request_snapshot_uri", "")),
            ("重放命令: ", getattr(candidate, "request_snapshot_replay_command", "")),
        )
        lines = [f"{label}{value}" for label, value in labeled_values if value]
        if lines:
            return "\n " + "\n ".join(lines)
    return ""

View File

@@ -37,6 +37,7 @@ from src.llm_models.model_client.base_client import (
UsageRecord, UsageRecord,
client_registry, client_registry,
) )
from src.llm_models.request_snapshot import format_request_snapshot_log_info
from src.llm_models.payload_content.message import Message, MessageBuilder from src.llm_models.payload_content.message import Message, MessageBuilder
from src.llm_models.payload_content.resp_format import RespFormat from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ( from src.llm_models.payload_content.tool_option import (
@@ -1008,6 +1009,16 @@ class LLMOrchestrator:
Returns: Returns:
str: 可直接拼接到日志中的底层异常描述。 str: 可直接拼接到日志中的底层异常描述。
""" """
detail_lines: List[str] = []
if e.__cause__:
detail_lines.append(f"底层异常类型: {type(e.__cause__).__name__}")
detail_lines.append(f"底层异常信息: {e.__cause__}")
snapshot_info = format_request_snapshot_log_info(e)
if detail_lines or snapshot_info:
detail_text = "\n " + "\n ".join(detail_lines) if detail_lines else ""
return f"{detail_text}{snapshot_info}"
if e.__cause__: if e.__cause__:
original_error_type = type(e.__cause__).__name__ original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__) original_error_msg = str(e.__cause__)