feat: 为失败请求留档并提供重试分析

This commit is contained in:
SengokuCola
2026-04-07 15:16:06 +08:00
parent 3b5baf901a
commit 09bce14664
6 changed files with 1304 additions and 91 deletions

View File

@@ -0,0 +1,131 @@
from pathlib import Path
import json
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.model_client.base_client import ResponseRequest
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
from src.llm_models.request_snapshot import (
attach_request_snapshot,
deserialize_messages_snapshot,
format_request_snapshot_log_info,
save_failed_request_snapshot,
serialize_messages_snapshot,
serialize_response_request_snapshot,
)
from src.llm_models import request_snapshot
def _build_api_provider() -> APIProvider:
    """Create a minimal APIProvider fixture used by the snapshot tests."""
    provider = APIProvider(
        api_key="secret-token",
        base_url="https://example.com/v1",
        name="test-provider",
    )
    return provider
def _build_model_info() -> ModelInfo:
    """Create a minimal ModelInfo fixture bound to the test provider."""
    info = ModelInfo(
        api_provider="test-provider",
        model_identifier="demo-model",
        name="demo-model",
    )
    return info
def _build_response_request() -> ResponseRequest:
    """Assemble a ResponseRequest fixture covering text, image and tool messages."""
    call = ToolCall(
        args={"query": "MaiBot"},
        call_id="call_1",
        func_name="search_web",
        extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
    )
    user_message = (
        MessageBuilder()
        .set_role(RoleType.User)
        .add_text_content("你好")
        .add_image_content("png", "ZmFrZQ==")
        .build()
    )
    assistant_message = MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([call]).build()
    tool_message = (
        MessageBuilder()
        .set_role(RoleType.Tool)
        .set_tool_call_id("call_1")
        .set_tool_name("search_web")
        .add_text_content('{"ok": true}')
        .build()
    )
    return ResponseRequest(
        extra_params={"trace_id": "trace-123"},
        max_tokens=256,
        message_list=[user_message, assistant_message, tool_message],
        model_info=_build_model_info(),
        response_format=RespFormat(RespFormatType.JSON_OBJ),
        temperature=0.2,
        tool_options=[ToolOption(name="search_web", description="搜索网页")],
    )
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
    """Serializing then deserializing must keep roles, parts and tool-call data."""
    request = _build_response_request()
    restored = deserialize_messages_snapshot(serialize_messages_snapshot(request.message_list))

    assert len(restored) == 3

    user_message = restored[0]
    assert user_message.role == RoleType.User
    assert user_message.get_text_content() == "你好"
    assert user_message.parts[1].image_format == "png"

    assistant_message = restored[1]
    assert assistant_message.role == RoleType.Assistant
    assert assistant_message.tool_calls is not None
    restored_call = assistant_message.tool_calls[0]
    assert restored_call.func_name == "search_web"
    assert restored_call.args == {"query": "MaiBot"}
    assert restored_call.extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}

    tool_message = restored[2]
    assert tool_message.role == RoleType.Tool
    assert tool_message.tool_call_id == "call_1"
    assert tool_message.tool_name == "search_web"
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
    """A saved snapshot must embed a replay entry and never leak the API key."""
    monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
    request = _build_response_request()
    snapshot_path = save_failed_request_snapshot(
        api_provider=_build_api_provider(),
        client_type="openai",
        error=RuntimeError("boom"),
        internal_request=serialize_response_request_snapshot(request),
        model_info=request.model_info,
        operation="chat.completions.create",
        provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
    )
    assert snapshot_path is not None

    # Read once and reuse for both the structural and the secret-scrubbing checks.
    snapshot_text = snapshot_path.read_text(encoding="utf-8")
    payload = json.loads(snapshot_text)
    assert payload["internal_request"]["request_kind"] == "response"
    assert payload["api_provider"]["name"] == "test-provider"
    assert payload["replay"]["file_uri"] == snapshot_path.as_uri()
    assert str(snapshot_path) in payload["replay"]["command"]
    assert "secret-token" not in snapshot_text
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
    """Log info built from an attached snapshot must expose path, URI and replay command."""
    monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
    request = _build_response_request()
    snapshot_path = save_failed_request_snapshot(
        api_provider=_build_api_provider(),
        client_type="openai",
        error=ValueError("invalid"),
        internal_request=serialize_response_request_snapshot(request),
        model_info=request.model_info,
        operation="chat.completions.create",
        provider_request={"request_kwargs": {"messages": []}},
    )
    assert snapshot_path is not None

    wrapped = RuntimeError("wrapped")
    attach_request_snapshot(wrapped, snapshot_path)

    log_info = format_request_snapshot_log_info(wrapped)
    assert str(snapshot_path) in log_info
    assert snapshot_path.as_uri() in log_info
    assert "uv run python scripts/replay_llm_request.py" in log_info

View File

@@ -0,0 +1,146 @@
# ruff: noqa: E402
import argparse
import asyncio
import json
import sys
from pathlib import Path
from typing import Any
PROJECT_ROOT = Path(__file__).resolve().parent.parent
SRC_ROOT = PROJECT_ROOT / "src"
if str(SRC_ROOT) not in sys.path:
sys.path.insert(0, str(SRC_ROOT))
if str(PROJECT_ROOT) not in sys.path:
sys.path.insert(1, str(PROJECT_ROOT))
from src.config.config import config_manager
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, ResponseRequest, client_registry
from src.llm_models.model_client.base_client import EmbeddingRequest
from src.llm_models.request_snapshot import (
deserialize_messages_snapshot,
deserialize_model_info_snapshot,
deserialize_response_format_snapshot,
deserialize_tool_options_snapshot,
)
def _load_snapshot(snapshot_path: Path) -> dict[str, Any]:
"""加载请求快照。"""
return json.loads(snapshot_path.read_text(encoding="utf-8"))
def _resolve_api_provider(provider_name: str):
    """Look up the API provider with the given name in the current model config.

    Raises:
        ValueError: when no provider with that name is configured.
    """
    model_config = config_manager.get_model_config()
    matched = next(
        (provider for provider in model_config.api_providers if provider.name == provider_name),
        None,
    )
    if matched is None:
        raise ValueError(f"当前配置中不存在名为 {provider_name!r} 的 API Provider")
    return matched
def _build_response_request(snapshot: dict[str, Any]) -> ResponseRequest:
    """Rebuild a ResponseRequest from its snapshot dictionary."""
    messages = deserialize_messages_snapshot(snapshot.get("message_list") or [])
    model_info = deserialize_model_info_snapshot(snapshot.get("model_info") or {})
    return ResponseRequest(
        extra_params=dict(snapshot.get("extra_params") or {}),
        max_tokens=snapshot.get("max_tokens"),
        message_list=messages,
        model_info=model_info,
        response_format=deserialize_response_format_snapshot(snapshot.get("response_format")),
        temperature=snapshot.get("temperature"),
        tool_options=deserialize_tool_options_snapshot(snapshot.get("tool_options")),
    )
def _build_embedding_request(snapshot: dict[str, Any]) -> EmbeddingRequest:
    """Rebuild an EmbeddingRequest from its snapshot dictionary."""
    raw_input = snapshot.get("embedding_input")
    raw_params = snapshot.get("extra_params")
    return EmbeddingRequest(
        embedding_input=str(raw_input or ""),
        extra_params=dict(raw_params or {}),
        model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
    )
def _build_audio_request(snapshot: dict[str, Any]) -> AudioTranscriptionRequest:
    """Rebuild an AudioTranscriptionRequest from its snapshot dictionary."""
    raw_audio = snapshot.get("audio_base64")
    raw_params = snapshot.get("extra_params")
    return AudioTranscriptionRequest(
        audio_base64=str(raw_audio or ""),
        extra_params=dict(raw_params or {}),
        max_tokens=snapshot.get("max_tokens"),
        model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
    )
async def _replay(snapshot_path: Path) -> int:
    """Replay one failed-request snapshot against a live client and print the result.

    Returns:
        int: process exit code (0 on success).

    Raises:
        ValueError: when the snapshot is missing required fields or has an
            unsupported ``request_kind``.
    """
    config_manager.initialize()
    snapshot = _load_snapshot(snapshot_path)

    # Validate the two mandatory sections before touching any client.
    internal_request = snapshot.get("internal_request")
    if not isinstance(internal_request, dict):
        raise ValueError("快照缺少 internal_request 字段")
    provider_snapshot = snapshot.get("api_provider")
    if not isinstance(provider_snapshot, dict):
        raise ValueError("快照缺少 api_provider 字段")

    provider_name = str(provider_snapshot.get("name") or "")
    if not provider_name:
        raise ValueError("快照中的 api_provider.name 不能为空")

    api_provider = _resolve_api_provider(provider_name)
    client = client_registry.get_client_class_instance(api_provider, force_new=True)

    # Dispatch on the recorded request kind.
    request_kind = str(internal_request.get("request_kind") or "").strip()
    if request_kind == "response":
        response = await client.get_response(_build_response_request(internal_request))
    elif request_kind == "embedding":
        response = await client.get_embedding(_build_embedding_request(internal_request))
    elif request_kind == "audio_transcription":
        response = await client.get_audio_transcriptions(_build_audio_request(internal_request))
    else:
        raise ValueError(f"不支持的 request_kind: {request_kind!r}")

    usage = response.usage
    output_payload = {
        "content": response.content,
        "embedding_length": len(response.embedding or []),
        "has_embedding": response.embedding is not None,
        "model_name": usage.model_name if usage is not None else None,
        "provider_name": usage.provider_name if usage is not None else None,
        "raw_data_type": type(response.raw_data).__name__ if response.raw_data is not None else None,
        "reasoning_content": response.reasoning_content,
        "tool_calls": [
            {
                "args": call.args,
                "call_id": call.call_id,
                "func_name": call.func_name,
            }
            for call in (response.tool_calls or [])
        ],
        "usage": (
            {
                "completion_tokens": usage.completion_tokens,
                "prompt_tokens": usage.prompt_tokens,
                "total_tokens": usage.total_tokens,
            }
            if usage is not None
            else None
        ),
    }
    print(json.dumps(output_payload, ensure_ascii=False, indent=2))
    return 0
def main() -> int:
    """CLI entry point: parse arguments and replay the given snapshot file."""
    parser = argparse.ArgumentParser(description="回放失败的 LLM 请求快照。")
    parser.add_argument("snapshot_path", help="请求快照 JSON 文件路径")
    cli_args = parser.parse_args()

    snapshot_file = Path(cli_args.snapshot_path).expanduser().resolve()
    if not snapshot_file.exists():
        raise FileNotFoundError(f"快照文件不存在: {snapshot_file}")
    return asyncio.run(_replay(snapshot_file))


if __name__ == "__main__":
    raise SystemExit(main())

View File

@@ -1,7 +1,10 @@
# ruff: noqa: B025
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
import asyncio
import base64
import binascii
import io
import json
@@ -21,6 +24,8 @@ from google.genai.types import (
EmbedContentConfig,
EmbedContentResponse,
FunctionDeclaration,
FunctionCall,
FunctionResponse,
GenerateContentConfig,
GenerateContentResponse,
GoogleSearch,
@@ -60,6 +65,14 @@ from .base_client import (
UsageTuple,
client_registry,
)
from ..request_snapshot import (
attach_request_snapshot,
has_request_snapshot,
save_failed_request_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
logger = get_logger("Gemini客户端")
@@ -112,6 +125,11 @@ EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = {
}
"""可透传给 `EmbedContentConfig` 的额外参数字段。"""
GEMINI_EXTRA_CONTENT_PROVIDER_KEY = "google"
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY = "thought_signature"
GEMINI_FALLBACK_THOUGHT_SIGNATURE = b"skip_thought_signature_validator"
"""当历史 function call 没有原始 thought signature 时,使用官方允许的占位签名跳过校验。"""
def _normalize_image_mime_type(image_format: str) -> str:
"""将图片格式名称转换为标准 MIME 类型。
@@ -177,6 +195,62 @@ def _normalize_function_response_payload(message: Message) -> Dict[str, Any]:
return {"result": content}
def _build_gemini_tool_call_extra_content(thought_signature: bytes | None) -> Dict[str, Any] | None:
    """Encode a Gemini thought signature into the internal tool-call extra payload.

    Returns None when there is no signature to record.
    """
    if not thought_signature:
        return None
    encoded_signature = base64.b64encode(thought_signature).decode("ascii")
    provider_payload = {GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY: encoded_signature}
    return {GEMINI_EXTRA_CONTENT_PROVIDER_KEY: provider_payload}
def _extract_gemini_thought_signature(tool_call: ToolCall) -> bytes | None:
    """Recover the Gemini thought signature stored in a tool call's extra content."""
    extra = tool_call.extra_content
    if not extra:
        return None
    provider_payload = extra.get(GEMINI_EXTRA_CONTENT_PROVIDER_KEY)
    if not isinstance(provider_payload, dict):
        return None

    raw_signature = provider_payload.get(GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY)
    if isinstance(raw_signature, bytes):
        # Already raw bytes: nothing to decode.
        return raw_signature
    if not isinstance(raw_signature, str):
        return None

    stripped = raw_signature.strip()
    if not stripped:
        return None
    try:
        return base64.b64decode(stripped.encode("ascii"), validate=True)
    except (binascii.Error, ValueError):
        # Not valid base64 — keep the raw text as bytes so the signature
        # survives round-tripping unchanged.
        return stripped.encode("utf-8")
def _build_gemini_function_call_part(
    tool_call: ToolCall,
    *,
    inject_fallback_signature: bool,
) -> Part:
    """Build a Gemini function-call Part from an internal tool call.

    Args:
        tool_call: internal tool call to convert.
        inject_fallback_signature: when True and no original signature is
            available, use the placeholder that skips signature validation.
    """
    signature = _extract_gemini_thought_signature(tool_call)
    if signature is None and inject_fallback_signature:
        signature = GEMINI_FALLBACK_THOUGHT_SIGNATURE
    call = FunctionCall(
        id=tool_call.call_id,
        name=tool_call.func_name,
        args=tool_call.args or {},
    )
    return Part(function_call=call, thought_signature=signature)
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
"""安全获取 Gemini 响应中的候选列表。
@@ -235,11 +309,11 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
if message.role == RoleType.Assistant:
assistant_parts = _build_non_tool_parts(message)
if message.tool_calls:
for tool_call in message.tool_calls:
for tool_call_index, tool_call in enumerate(message.tool_calls):
assistant_parts.append(
Part.from_function_call(
name=tool_call.func_name,
args=tool_call.args or {},
_build_gemini_function_call_part(
tool_call,
inject_fallback_signature=tool_call_index == 0,
)
)
tool_name_by_call_id[tool_call.call_id] = tool_call.func_name
@@ -256,10 +330,13 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
"且消息中未携带 tool_name"
)
tool_name_by_call_id[message.tool_call_id] = tool_name
function_response_part = Part.from_function_response(
function_response_part = Part(
function_response=FunctionResponse(
id=message.tool_call_id,
name=tool_name,
response=_normalize_function_response_payload(message),
)
)
contents.append(Content(role="tool", parts=[function_response_part]))
continue
@@ -368,22 +445,41 @@ def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]
Raises:
RespParseException: 当函数调用结构不合法时抛出。
"""
raw_function_calls = getattr(response, "function_calls", None)
candidates = _get_candidates(response)
if not raw_function_calls and candidates:
raw_function_calls = []
tool_calls: List[ToolCall] = []
for candidate in candidates:
content = getattr(candidate, "content", None)
parts = getattr(content, "parts", None) or []
for part in parts:
function_call = getattr(part, "function_call", None)
if function_call is not None:
raw_function_calls.append(function_call)
if function_call is None:
continue
call_name = getattr(function_call, "name", None)
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{len(tool_calls) + 1}"
call_args = getattr(function_call, "args", None) or {}
if not isinstance(call_name, str) or not call_name:
raise RespParseException(response, "响应解析失败Gemini 工具调用缺少 name 字段")
if not isinstance(call_args, dict):
raise RespParseException(response, "响应解析失败Gemini 工具调用参数无法解析为字典")
tool_calls.append(
ToolCall(
call_id=call_id,
func_name=call_name,
args=call_args,
extra_content=_build_gemini_tool_call_extra_content(getattr(part, "thought_signature", None)),
)
)
if tool_calls:
return tool_calls
raw_function_calls = getattr(response, "function_calls", None)
if not raw_function_calls:
return []
tool_calls: List[ToolCall] = []
for index, function_call in enumerate(raw_function_calls, start=1):
call_name = getattr(function_call, "name", None)
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}"
@@ -808,6 +904,15 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
model_info = request.model_info
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:generateContent",
"method": "POST",
"operation": "models.generate_content",
"request_kwargs": {},
}
try:
contents, system_instruction = _convert_messages(request.message_list)
model_identifier, enable_google_search = self._resolve_model_identifier(
model_info.model_identifier,
@@ -823,8 +928,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
extra_params=request.extra_params,
enable_google_search=enable_google_search,
)
try:
snapshot_provider_request["request_kwargs"] = {
"config": generation_config,
"contents": contents,
"enable_google_search": enable_google_search,
"model": model_identifier,
"system_instruction": system_instruction,
}
if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
self.client.aio.models.generate_content_stream(
@@ -855,7 +965,62 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
raise
except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
wrapped_error = RespParseException(None, f"Gemini 工具调用参数错误: {exc}")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=wrapped_error,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except EmptyResponseException as exc:
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
except EmptyResponseException:
@@ -878,9 +1043,21 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
model_info = request.model_info
embedding_input = request.embedding_input
extra_params = request.extra_params
embed_config = _build_embed_content_config(extra_params)
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:embedContent",
"method": "POST",
"operation": "models.embed_content",
"request_kwargs": {},
}
try:
embed_config = _build_embed_content_config(extra_params)
snapshot_provider_request["request_kwargs"] = {
"config": embed_config,
"contents": embedding_input,
"model": model_info.model_identifier,
}
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
@@ -888,11 +1065,52 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
)
except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
raise NetworkConnectionError(str(exc)) from exc
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
response = APIResponse(raw_data=raw_response)
if not raw_response.embeddings:
exc = RespParseException(raw_response, "Gemini 嵌入响应解析失败,缺少 embeddings 字段。")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="models.embed_content",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise exc
if raw_response.embeddings:
response.embedding = raw_response.embeddings[0].values
else:
@@ -924,7 +1142,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
audio_base64 = request.audio_base64
max_tokens = request.max_tokens
extra_params = request.extra_params
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/models/{model}:generateContent",
"method": "POST",
"operation": "models.generate_content",
"request_kwargs": {},
}
transcription_prompt = str(
extra_params.get(
@@ -933,6 +1157,9 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
)
)
audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
try:
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
contents: List[ContentUnion] = [
Content(
role="user",
@@ -952,8 +1179,14 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
extra_params=extra_params,
enable_google_search=False,
)
try:
snapshot_provider_request["request_kwargs"] = {
"audio_base64": audio_base64,
"audio_mime_type": audio_mime_type,
"config": generation_config,
"contents": contents,
"model": model_identifier,
"transcription_prompt": transcription_prompt,
}
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
model=model_identifier,
contents=contents,
@@ -962,9 +1195,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
response, usage_record = _default_normal_response_parser(raw_response)
except (ClientError, ServerError) as exc:
status_code = int(getattr(exc, "code", 500) or 500)
raise RespNotOkException(status_code, str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(status_code, str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
raise NetworkConnectionError(str(exc)) from exc
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="gemini",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="models.generate_content",
provider_request=snapshot_provider_request,
)
wrapped_error = (
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
)
attach_request_snapshot(wrapped_error, snapshot_path)
if wrapped_error is exc:
raise
raise wrapped_error from exc
return response, usage_record

View File

@@ -4,7 +4,6 @@ import binascii
import io
import json
import re
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
@@ -60,6 +59,14 @@ from .base_client import (
UsageTuple,
client_registry,
)
from ..request_snapshot import (
attach_request_snapshot,
has_request_snapshot,
save_failed_request_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
logger = get_logger("llm_models")
@@ -533,6 +540,13 @@ def _coerce_openai_argument(value: Any) -> Any | Omit:
return value
def _snapshot_openai_argument(value: Any | Omit) -> Any | None:
    """Convert an OpenAI SDK argument into a plain value suitable for snapshots.

    The SDK's ``omit`` sentinel means "not sent"; record it as None.
    """
    return None if value is omit else value
def _build_api_status_message(error: APIStatusError) -> str:
"""构建更适合记录和展示的状态错误信息。
@@ -939,10 +953,21 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
Returns:
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
"""
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/chat/completions",
"method": "POST",
"operation": "chat.completions.create",
"organization": self.api_provider.organization,
"project": self.api_provider.project,
"request_kwargs": {},
}
model_info = request.model_info
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list)
tools: Iterable[ChatCompletionToolParam] | Omit = (
_convert_tool_options(request.tool_options) if request.tool_options else omit
try:
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request.message_list)
tools_payload: List[ChatCompletionToolParam] | None = (
_convert_tool_options(request.tool_options) if request.tool_options else None
)
openai_response_format = _convert_response_format(request.response_format)
request_overrides = split_openai_request_overrides(
@@ -958,14 +983,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
else _coerce_openai_argument(request.max_tokens)
)
snapshot_provider_request["request_kwargs"] = {
"extra_body": request_overrides.extra_body or None,
"extra_headers": request_overrides.extra_headers or None,
"extra_query": request_overrides.extra_query or None,
"max_tokens": _snapshot_openai_argument(max_tokens_argument),
"messages": messages_payload,
"model": model_info.model_identifier,
"response_format": _snapshot_openai_argument(openai_response_format),
"stream": bool(model_info.force_stream_mode),
"temperature": _snapshot_openai_argument(temperature_argument),
"tools": tools_payload,
}
try:
if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
messages=messages_payload,
tools=tools_payload or omit,
temperature=temperature_argument,
max_tokens=max_tokens_argument,
stream=True,
@@ -984,8 +1020,8 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
messages=messages_payload,
tools=tools_payload or omit,
temperature=temperature_argument,
max_tokens=max_tokens_argument,
stream=False,
@@ -1000,10 +1036,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
await await_task_with_interrupt(completion_task, request.interrupt_flag),
)
return response_parser(raw_response)
except (EmptyResponseException, RespParseException) as exc:
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = NetworkConnectionError(str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except ReqAbortException:
raise
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_response_request_snapshot(request),
model_info=model_info,
operation="chat.completions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
async def _execute_embedding_request(
self,
@@ -1020,9 +1106,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
model_info = request.model_info
embedding_input = request.embedding_input
extra_params = request.extra_params
request_overrides = split_openai_request_overrides(extra_params)
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/embeddings",
"method": "POST",
"operation": "embeddings.create",
"organization": self.api_provider.organization,
"project": self.api_provider.project,
"request_kwargs": {},
}
try:
request_overrides = split_openai_request_overrides(extra_params)
snapshot_provider_request["request_kwargs"] = {
"extra_body": request_overrides.extra_body or None,
"extra_headers": request_overrides.extra_headers or None,
"extra_query": request_overrides.extra_query or None,
"input": embedding_input,
"model": model_info.model_identifier,
}
raw_response = await self.client.embeddings.create(
model=model_info.model_identifier,
input=embedding_input,
@@ -1031,11 +1133,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
extra_body=request_overrides.extra_body or None,
)
except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
wrapped_error = NetworkConnectionError(str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
response = APIResponse()
if not raw_response.data:
exc = RespParseException(raw_response, "嵌入响应解析失败,缺少 embeddings 数据。")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_embedding_request_snapshot(request),
model_info=model_info,
operation="embeddings.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise exc
if raw_response.data:
response.embedding = raw_response.data[0].embedding
else:
@@ -1059,10 +1210,27 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
model_info = request.model_info
audio_base64 = request.audio_base64
extra_params = request.extra_params
request_overrides = split_openai_request_overrides(extra_params)
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
snapshot_provider_request = {
"base_url": self.api_provider.base_url,
"endpoint": "/audio/transcriptions",
"method": "POST",
"operation": "audio.transcriptions.create",
"organization": self.api_provider.organization,
"project": self.api_provider.project,
"request_kwargs": {},
}
try:
request_overrides = split_openai_request_overrides(extra_params)
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
snapshot_provider_request["request_kwargs"] = {
"audio_base64": audio_base64,
"extra_body": request_overrides.extra_body or None,
"extra_headers": request_overrides.extra_headers or None,
"extra_query": request_overrides.extra_query or None,
"file_name": "audio.wav",
"model": model_info.model_identifier,
}
raw_response = await self.client.audio.transcriptions.create(
model=model_info.model_identifier,
file=audio_file,
@@ -1071,12 +1239,61 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
extra_body=request_overrides.extra_body or None,
)
except APIConnectionError as exc:
raise NetworkConnectionError(str(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = NetworkConnectionError(str(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except APIStatusError as exc:
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
attach_request_snapshot(wrapped_error, snapshot_path)
raise wrapped_error from exc
except Exception as exc:
if has_request_snapshot(exc):
raise
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise
response = APIResponse()
transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
if not isinstance(transcription_text, str):
exc = RespParseException(raw_response, "音频转写响应解析失败,缺少文本内容。")
snapshot_path = save_failed_request_snapshot(
api_provider=self.api_provider,
client_type="openai",
error=exc,
internal_request=serialize_audio_request_snapshot(request),
model_info=model_info,
operation="audio.transcriptions.create",
provider_request=snapshot_provider_request,
)
attach_request_snapshot(exc, snapshot_path)
raise exc
if isinstance(transcription_text, str):
response.content = transcription_text
return response, None

View File

@@ -0,0 +1,447 @@
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Mapping, Sequence
import base64
import json
import re
from src.common.logger import get_logger
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, EmbeddingRequest, ResponseRequest
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption, normalize_tool_options
# Repository root; parents[2] assumes this file sits at src/llm_models/<file>.py — TODO confirm if moved.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Directory where failed-request snapshot JSON files are written.
LLM_REQUEST_LOG_DIR = PROJECT_ROOT / "logs" / "llm_request"
# Replay helper script; the relative form is embedded in user-facing replay commands.
REPLAY_SCRIPT_RELATIVE_PATH = Path("scripts") / "replay_llm_request.py"
REPLAY_SCRIPT_PATH = PROJECT_ROOT / REPLAY_SCRIPT_RELATIVE_PATH
# Runs of characters outside this safe set are collapsed to "-" in snapshot filenames.
FILENAME_SAFE_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
# Bump when the snapshot JSON layout changes incompatibly.
SNAPSHOT_VERSION = 1
logger = get_logger("llm_request_snapshot")
def _json_friendly(value: Any) -> Any:
"""将任意对象尽量转换为可写入 JSON 的结构。"""
if value is None or isinstance(value, (bool, float, int, str)):
return value
if isinstance(value, Enum):
return value.value
if isinstance(value, Path):
return str(value)
if isinstance(value, (bytes, bytearray)):
return base64.b64encode(bytes(value)).decode("ascii")
if isinstance(value, Mapping):
return {str(key): _json_friendly(item) for key, item in value.items()}
if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)):
return [_json_friendly(item) for item in value]
model_dump = getattr(value, "model_dump", None)
if callable(model_dump):
try:
return _json_friendly(model_dump(mode="json", exclude_none=True))
except TypeError:
return _json_friendly(model_dump(exclude_none=True))
to_dict = getattr(value, "to_dict", None)
if callable(to_dict):
return _json_friendly(to_dict())
return str(value)
def _sanitize_filename_component(value: str) -> str:
"""将任意字符串转换为适合文件名使用的片段。"""
normalized_value = FILENAME_SAFE_PATTERN.sub("-", value.strip())
normalized_value = normalized_value.strip("-._")
return normalized_value or "unknown"
def _serialize_tool_call(tool_call: ToolCall) -> dict[str, Any]:
    """Serialize one tool call into an OpenAI-style snapshot structure."""
    serialized: dict[str, Any] = {
        "id": tool_call.call_id,
        "function": {
            "name": tool_call.func_name,
            "arguments": _json_friendly(tool_call.args or {}),
        },
    }
    # Provider-specific extras (e.g. thought signatures) ride along when present.
    extra = tool_call.extra_content
    if extra:
        serialized["extra_content"] = _json_friendly(extra)
    return serialized
def serialize_tool_calls_snapshot(tool_calls: Sequence[ToolCall] | None) -> list[dict[str, Any]]:
    """Serialize a sequence of tool calls; an absent/empty input yields []."""
    return [_serialize_tool_call(single_call) for single_call in tool_calls] if tool_calls else []
def deserialize_tool_calls_snapshot(raw_tool_calls: Any) -> list[ToolCall]:
    """Restore a list of tool calls from snapshot data.

    Raises:
        ValueError: when the snapshot structure is malformed.
    """
    if raw_tool_calls in (None, []):
        return []
    if not isinstance(raw_tool_calls, list):
        raise ValueError("快照中的 tool_calls 必须是列表")
    restored: list[ToolCall] = []
    for entry in raw_tool_calls:
        if not isinstance(entry, dict):
            raise ValueError("快照中的 tool_call 项必须是字典")
        function_payload = entry.get("function", {})
        if isinstance(function_payload, dict):
            name = function_payload.get("name")
            arguments = function_payload.get("arguments")
        else:
            # Flat-layout fallback: name/arguments live directly on the entry.
            name = entry.get("name")
            arguments = entry.get("arguments")
        identifier = entry.get("id") or entry.get("call_id")
        if not isinstance(identifier, str) or not isinstance(name, str):
            raise ValueError("快照中的 tool_call 缺少 id 或 function.name")
        extra = entry.get("extra_content")
        restored.append(
            ToolCall(
                call_id=identifier,
                func_name=name,
                args=arguments if isinstance(arguments, dict) else {},
                extra_content=extra if isinstance(extra, dict) else None,
            )
        )
    return restored
def serialize_message_snapshot(message: Message) -> dict[str, Any]:
    """Serialize an internal message object into a replayable snapshot structure."""
    serialized_parts: list[dict[str, Any]] = []
    for message_part in message.parts:
        if isinstance(message_part, TextMessagePart):
            serialized_parts.append({"type": "text", "text": message_part.text})
        elif isinstance(message_part, ImageMessagePart):
            serialized_parts.append(
                {
                    "type": "image",
                    "image_base64": message_part.image_base64,
                    "image_format": message_part.image_format,
                }
            )
        # Any other part type is skipped.
    snapshot: dict[str, Any] = {
        "parts": serialized_parts,
        "role": message.role.value,
    }
    # Tool-related fields are only emitted when set, keeping snapshots compact.
    if message.tool_call_id:
        snapshot["tool_call_id"] = message.tool_call_id
    if message.tool_name:
        snapshot["tool_name"] = message.tool_name
    if message.tool_calls:
        snapshot["tool_calls"] = serialize_tool_calls_snapshot(message.tool_calls)
    return snapshot
def deserialize_message_snapshot(raw_message: Any) -> Message:
    """Restore an internal message object from snapshot data.

    Raises:
        ValueError: when the snapshot structure is malformed or a part type
            is unsupported.
    """
    if not isinstance(raw_message, dict):
        raise ValueError("快照中的 message 必须是字典")
    role_value = raw_message.get("role")
    if not isinstance(role_value, str):
        raise ValueError("快照中的 message 缺少 role")
    message_role = RoleType(role_value)
    message_builder = MessageBuilder().set_role(message_role)
    # Tool calls are only meaningful on assistant messages.
    restored_tool_calls = deserialize_tool_calls_snapshot(raw_message.get("tool_calls"))
    if message_role == RoleType.Assistant and restored_tool_calls:
        message_builder.set_tool_calls(restored_tool_calls)
    # Tool messages carry the originating call id and (optionally) the tool name.
    call_id_value = raw_message.get("tool_call_id")
    if message_role == RoleType.Tool and isinstance(call_id_value, str):
        message_builder.set_tool_call_id(call_id_value)
    name_value = raw_message.get("tool_name")
    if message_role == RoleType.Tool and isinstance(name_value, str) and name_value:
        message_builder.set_tool_name(name_value)
    part_entries = raw_message.get("parts", [])
    if not isinstance(part_entries, list):
        raise ValueError("快照中的 message.parts 必须是列表")
    for part_entry in part_entries:
        if not isinstance(part_entry, dict):
            raise ValueError("快照中的 message part 必须是字典")
        part_kind = str(part_entry.get("type", "")).strip().lower()
        if part_kind == "text":
            text_value = part_entry.get("text")
            if not isinstance(text_value, str):
                raise ValueError("文本 part 缺少 text 字段")
            message_builder.add_text_content(text_value)
        elif part_kind == "image":
            format_value = part_entry.get("image_format")
            base64_value = part_entry.get("image_base64")
            if not isinstance(format_value, str) or not isinstance(base64_value, str):
                raise ValueError("图片 part 缺少 image_format 或 image_base64")
            message_builder.add_image_content(image_format=format_value, image_base64=base64_value)
        else:
            raise ValueError(f"不支持的快照消息 part 类型: {part_kind}")
    return message_builder.build()
def serialize_messages_snapshot(messages: Sequence[Message]) -> list[dict[str, Any]]:
    """Serialize every message in the given sequence, in order."""
    return [serialize_message_snapshot(single_message) for single_message in messages]
def deserialize_messages_snapshot(raw_messages: Any) -> list[Message]:
    """Restore a list of messages from snapshot data; rejects non-list input."""
    if not isinstance(raw_messages, list):
        raise ValueError("快照中的 messages 必须是列表")
    return [deserialize_message_snapshot(raw_entry) for raw_entry in raw_messages]
def serialize_model_info_snapshot(model_info: ModelInfo) -> dict[str, Any]:
    """Serialize model configuration into a snapshot dict (insertion order is preserved in the JSON output)."""
    snapshot: dict[str, Any] = {}
    snapshot["api_provider"] = model_info.api_provider
    snapshot["extra_params"] = _json_friendly(dict(model_info.extra_params))
    snapshot["force_stream_mode"] = model_info.force_stream_mode
    snapshot["max_tokens"] = model_info.max_tokens
    snapshot["model_identifier"] = model_info.model_identifier
    snapshot["name"] = model_info.name
    snapshot["temperature"] = model_info.temperature
    return snapshot
def deserialize_model_info_snapshot(raw_model_info: Any) -> ModelInfo:
    """Restore model configuration from snapshot data; rejects non-dict input."""
    if not isinstance(raw_model_info, dict):
        raise ValueError("快照中的 model_info 必须是字典")
    # Missing string fields degrade to "" and missing mappings to {} so a
    # partially-written snapshot still deserializes.
    return ModelInfo(
        api_provider=str(raw_model_info.get("api_provider") or ""),
        extra_params=dict(raw_model_info.get("extra_params") or {}),
        force_stream_mode=bool(raw_model_info.get("force_stream_mode", False)),
        max_tokens=raw_model_info.get("max_tokens"),
        model_identifier=str(raw_model_info.get("model_identifier") or ""),
        name=str(raw_model_info.get("name") or ""),
        temperature=raw_model_info.get("temperature"),
    )
def serialize_response_format_snapshot(response_format: RespFormat | None) -> dict[str, Any] | None:
    """Serialize the response-format definition, or pass None through."""
    return None if response_format is None else response_format.to_dict()
def deserialize_response_format_snapshot(raw_response_format: Any) -> RespFormat | None:
    """Restore the response-format definition from snapshot data (None passes through)."""
    if raw_response_format is None:
        return None
    if not isinstance(raw_response_format, dict):
        raise ValueError("快照中的 response_format 必须是字典")
    format_type_value = raw_response_format.get("format_type")
    if not isinstance(format_type_value, str):
        raise ValueError("快照中的 response_format 缺少 format_type")
    schema_value = raw_response_format.get("schema")
    return RespFormat(
        format_type=RespFormatType(format_type_value),
        schema=schema_value if isinstance(schema_value, dict) else None,
    )
def serialize_tool_options_snapshot(tool_options: Sequence[ToolOption] | None) -> list[dict[str, Any]]:
    """Serialize tool definitions as OpenAI function schemas; absent input yields []."""
    if not tool_options:
        return []
    return [single_option.to_openai_function_schema() for single_option in tool_options]
def deserialize_tool_options_snapshot(raw_tool_options: Any) -> list[ToolOption] | None:
    """Restore tool definitions from snapshot data; absent/empty input yields None."""
    if raw_tool_options in (None, []):
        return None
    if not isinstance(raw_tool_options, list):
        raise ValueError("快照中的 tool_options 必须是列表")
    return normalize_tool_options(raw_tool_options)
def serialize_response_request_snapshot(request: ResponseRequest) -> dict[str, Any]:
    """Serialize a chat/multimodal request into snapshot form (request_kind "response")."""
    snapshot: dict[str, Any] = {}
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["max_tokens"] = request.max_tokens
    snapshot["message_list"] = serialize_messages_snapshot(request.message_list)
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    snapshot["request_kind"] = "response"
    snapshot["response_format"] = serialize_response_format_snapshot(request.response_format)
    snapshot["temperature"] = request.temperature
    snapshot["tool_options"] = serialize_tool_options_snapshot(request.tool_options)
    return snapshot
def serialize_embedding_request_snapshot(request: EmbeddingRequest) -> dict[str, Any]:
    """Serialize an embedding request into snapshot form (request_kind "embedding")."""
    snapshot: dict[str, Any] = {}
    snapshot["embedding_input"] = request.embedding_input
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    snapshot["request_kind"] = "embedding"
    return snapshot
def serialize_audio_request_snapshot(request: AudioTranscriptionRequest) -> dict[str, Any]:
    """Serialize an audio-transcription request into snapshot form (request_kind "audio_transcription")."""
    snapshot: dict[str, Any] = {}
    snapshot["audio_base64"] = request.audio_base64
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["max_tokens"] = request.max_tokens
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    snapshot["request_kind"] = "audio_transcription"
    return snapshot
def serialize_api_provider_snapshot(api_provider: APIProvider) -> dict[str, Any]:
    """Serialize API provider configuration, deliberately omitting credentials (api_key)."""
    snapshot: dict[str, Any] = {}
    snapshot["auth_header_name"] = api_provider.auth_header_name
    snapshot["auth_header_prefix"] = api_provider.auth_header_prefix
    snapshot["auth_query_name"] = api_provider.auth_query_name
    snapshot["auth_type"] = api_provider.auth_type
    snapshot["base_url"] = api_provider.base_url
    snapshot["client_type"] = api_provider.client_type
    snapshot["default_headers"] = _json_friendly(dict(api_provider.default_headers))
    snapshot["default_query"] = _json_friendly(dict(api_provider.default_query))
    snapshot["model_list_endpoint"] = api_provider.model_list_endpoint
    snapshot["name"] = api_provider.name
    snapshot["organization"] = api_provider.organization
    snapshot["project"] = api_provider.project
    snapshot["retry_interval"] = api_provider.retry_interval
    snapshot["timeout"] = api_provider.timeout
    return snapshot
def build_replay_command(snapshot_path: Path) -> str:
    """Build the shell command that replays the given snapshot file."""
    script_ref = REPLAY_SCRIPT_RELATIVE_PATH.as_posix()
    resolved_target = snapshot_path.resolve()
    return f'uv run python {script_ref} "{resolved_target}"'
def save_failed_request_snapshot(
    *,
    api_provider: APIProvider,
    client_type: str,
    error: Exception,
    internal_request: dict[str, Any],
    model_info: ModelInfo,
    operation: str,
    provider_request: dict[str, Any],
) -> Path | None:
    """Persist a JSON snapshot of a failed LLM request for later analysis/replay.

    Args:
        api_provider: Provider configuration (credentials are excluded from the snapshot).
        client_type: Client implementation name, e.g. "openai".
        error: The exception that triggered the snapshot.
        internal_request: Already-serialized internal request (must contain "request_kind").
        model_info: Model configuration of the failed request.
        operation: Provider operation name, e.g. "chat.completions.create".
        provider_request: Provider-level request description (endpoint, kwargs, ...).

    Returns:
        The path of the written snapshot, or None when persisting fails —
        snapshot persistence must never mask the original request error.
    """
    try:
        LLM_REQUEST_LOG_DIR.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now()
        # Filename layout: timestamp + client + request kind + model, all sanitized.
        file_name = (
            f"{timestamp.strftime('%Y%m%d_%H%M%S_%f')}"
            f"_{_sanitize_filename_component(client_type)}"
            f"_{_sanitize_filename_component(internal_request.get('request_kind', 'request'))}"
            f"_{_sanitize_filename_component(model_info.name or model_info.model_identifier)}.json"
        )
        snapshot_path = (LLM_REQUEST_LOG_DIR / file_name).resolve()
        snapshot_payload: dict[str, Any] = {
            "api_provider": serialize_api_provider_snapshot(api_provider),
            "client_type": client_type,
            "created_at": timestamp.isoformat(timespec="seconds"),
            "error": {
                "message": str(error),
                # Present for HTTP-style errors (e.g. APIStatusError); None otherwise.
                "status_code": getattr(error, "status_code", None),
                "type": type(error).__name__,
            },
            "internal_request": internal_request,
            "model_info": serialize_model_info_snapshot(model_info),
            "operation": operation,
            "provider_request": _json_friendly(provider_request),
            "snapshot_version": SNAPSHOT_VERSION,
        }
        # Replay info references the final resolved path, so it is added last.
        snapshot_payload["replay"] = {
            "command": build_replay_command(snapshot_path),
            "file_uri": snapshot_path.as_uri(),
            "script_path": str(REPLAY_SCRIPT_PATH),
        }
        snapshot_path.write_text(
            json.dumps(snapshot_payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        return snapshot_path
    except Exception:
        # Log message restored from mojibake (UTF-8 Chinese mis-decoded as GBK).
        logger.exception("保存 LLM 失败请求快照时发生异常")
        return None
def attach_request_snapshot(exception: Exception, snapshot_path: Path | None) -> None:
    """Attach snapshot metadata to the exception object; no-op when there is no snapshot."""
    if snapshot_path is None:
        return
    resolved = snapshot_path.resolve()
    exception.request_snapshot_path = str(resolved)
    exception.request_snapshot_uri = resolved.as_uri()
    exception.request_snapshot_replay_command = build_replay_command(snapshot_path)
def has_request_snapshot(exception: Exception) -> bool:
    """Return True when the exception or its direct cause already carries a request snapshot."""
    candidates = (exception, getattr(exception, "__cause__", None))
    return any(
        bool(getattr(candidate, "request_snapshot_path", ""))
        for candidate in candidates
        if candidate is not None
    )
def format_request_snapshot_log_info(exception: Exception) -> str:
    """Format snapshot info attached to an exception (or its direct cause) as a log fragment.

    Returns "" when neither the exception nor its cause carries snapshot metadata.
    """
    for candidate in (exception, getattr(exception, "__cause__", None)):
        if candidate is None:
            continue
        labeled_values = (
            ("请求快照路径", getattr(candidate, "request_snapshot_path", "")),
            ("请求快照链接", getattr(candidate, "request_snapshot_uri", "")),
            ("重放命令", getattr(candidate, "request_snapshot_replay_command", "")),
        )
        fragments = [f"{label}: {value}" for label, value in labeled_values if value]
        if fragments:
            return "\n " + "\n ".join(fragments)
    return ""

View File

@@ -37,6 +37,7 @@ from src.llm_models.model_client.base_client import (
UsageRecord,
client_registry,
)
from src.llm_models.request_snapshot import format_request_snapshot_log_info
from src.llm_models.payload_content.message import Message, MessageBuilder
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import (
@@ -1008,6 +1009,16 @@ class LLMOrchestrator:
Returns:
str: 可直接拼接到日志中的底层异常描述。
"""
detail_lines: List[str] = []
if e.__cause__:
detail_lines.append(f"底层异常类型: {type(e.__cause__).__name__}")
detail_lines.append(f"底层异常信息: {e.__cause__}")
snapshot_info = format_request_snapshot_log_info(e)
if detail_lines or snapshot_info:
detail_text = "\n " + "\n ".join(detail_lines) if detail_lines else ""
return f"{detail_text}{snapshot_info}"
if e.__cause__:
original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__)