feat:为失败请求留档并提供重试分析
This commit is contained in:
131
pytests/utils_test/test_request_snapshot.py
Normal file
131
pytests/utils_test/test_request_snapshot.py
Normal file
@@ -0,0 +1,131 @@
|
||||
from pathlib import Path
|
||||
|
||||
import json
|
||||
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.model_client.base_client import ResponseRequest
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||||
from src.llm_models.request_snapshot import (
|
||||
attach_request_snapshot,
|
||||
deserialize_messages_snapshot,
|
||||
format_request_snapshot_log_info,
|
||||
save_failed_request_snapshot,
|
||||
serialize_messages_snapshot,
|
||||
serialize_response_request_snapshot,
|
||||
)
|
||||
from src.llm_models import request_snapshot
|
||||
|
||||
|
||||
def _build_api_provider() -> APIProvider:
|
||||
return APIProvider(
|
||||
api_key="secret-token",
|
||||
base_url="https://example.com/v1",
|
||||
name="test-provider",
|
||||
)
|
||||
|
||||
|
||||
def _build_model_info() -> ModelInfo:
|
||||
return ModelInfo(
|
||||
api_provider="test-provider",
|
||||
model_identifier="demo-model",
|
||||
name="demo-model",
|
||||
)
|
||||
|
||||
|
||||
def _build_response_request() -> ResponseRequest:
|
||||
tool_call = ToolCall(
|
||||
args={"query": "MaiBot"},
|
||||
call_id="call_1",
|
||||
func_name="search_web",
|
||||
extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
|
||||
)
|
||||
message_list = [
|
||||
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
|
||||
MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build(),
|
||||
MessageBuilder()
|
||||
.set_role(RoleType.Tool)
|
||||
.set_tool_call_id("call_1")
|
||||
.set_tool_name("search_web")
|
||||
.add_text_content('{"ok": true}')
|
||||
.build(),
|
||||
]
|
||||
return ResponseRequest(
|
||||
extra_params={"trace_id": "trace-123"},
|
||||
max_tokens=256,
|
||||
message_list=message_list,
|
||||
model_info=_build_model_info(),
|
||||
response_format=RespFormat(RespFormatType.JSON_OBJ),
|
||||
temperature=0.2,
|
||||
tool_options=[ToolOption(name="search_web", description="搜索网页")],
|
||||
)
|
||||
|
||||
|
||||
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
|
||||
request = _build_response_request()
|
||||
|
||||
snapshot_messages = serialize_messages_snapshot(request.message_list)
|
||||
restored_messages = deserialize_messages_snapshot(snapshot_messages)
|
||||
|
||||
assert len(restored_messages) == 3
|
||||
assert restored_messages[0].role == RoleType.User
|
||||
assert restored_messages[0].get_text_content() == "你好"
|
||||
assert restored_messages[0].parts[1].image_format == "png"
|
||||
assert restored_messages[1].role == RoleType.Assistant
|
||||
assert restored_messages[1].tool_calls is not None
|
||||
assert restored_messages[1].tool_calls[0].func_name == "search_web"
|
||||
assert restored_messages[1].tool_calls[0].args == {"query": "MaiBot"}
|
||||
assert restored_messages[1].tool_calls[0].extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}
|
||||
assert restored_messages[2].role == RoleType.Tool
|
||||
assert restored_messages[2].tool_call_id == "call_1"
|
||||
assert restored_messages[2].tool_name == "search_web"
|
||||
|
||||
|
||||
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
|
||||
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
||||
|
||||
request = _build_response_request()
|
||||
provider = _build_api_provider()
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=provider,
|
||||
client_type="openai",
|
||||
error=RuntimeError("boom"),
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=request.model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
|
||||
)
|
||||
|
||||
assert snapshot_path is not None
|
||||
payload = json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
|
||||
assert payload["internal_request"]["request_kind"] == "response"
|
||||
assert payload["api_provider"]["name"] == "test-provider"
|
||||
assert payload["replay"]["file_uri"] == snapshot_path.as_uri()
|
||||
assert str(snapshot_path) in payload["replay"]["command"]
|
||||
assert "secret-token" not in snapshot_path.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
|
||||
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
||||
|
||||
request = _build_response_request()
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=_build_api_provider(),
|
||||
client_type="openai",
|
||||
error=ValueError("invalid"),
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=request.model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request={"request_kwargs": {"messages": []}},
|
||||
)
|
||||
|
||||
assert snapshot_path is not None
|
||||
exc = RuntimeError("wrapped")
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
|
||||
log_info = format_request_snapshot_log_info(exc)
|
||||
assert str(snapshot_path) in log_info
|
||||
assert snapshot_path.as_uri() in log_info
|
||||
assert "uv run python scripts/replay_llm_request.py" in log_info
|
||||
146
scripts/replay_llm_request.py
Normal file
146
scripts/replay_llm_request.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# ruff: noqa: E402
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
SRC_ROOT = PROJECT_ROOT / "src"
|
||||
if str(SRC_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(SRC_ROOT))
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(1, str(PROJECT_ROOT))
|
||||
|
||||
from src.config.config import config_manager
|
||||
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, ResponseRequest, client_registry
|
||||
from src.llm_models.model_client.base_client import EmbeddingRequest
|
||||
from src.llm_models.request_snapshot import (
|
||||
deserialize_messages_snapshot,
|
||||
deserialize_model_info_snapshot,
|
||||
deserialize_response_format_snapshot,
|
||||
deserialize_tool_options_snapshot,
|
||||
)
|
||||
|
||||
|
||||
def _load_snapshot(snapshot_path: Path) -> dict[str, Any]:
|
||||
"""加载请求快照。"""
|
||||
return json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _resolve_api_provider(provider_name: str):
|
||||
"""根据名称解析当前配置中的 API Provider。"""
|
||||
model_config = config_manager.get_model_config()
|
||||
for api_provider in model_config.api_providers:
|
||||
if api_provider.name == provider_name:
|
||||
return api_provider
|
||||
raise ValueError(f"当前配置中不存在名为 {provider_name!r} 的 API Provider")
|
||||
|
||||
|
||||
def _build_response_request(snapshot: dict[str, Any]) -> ResponseRequest:
|
||||
"""从快照构建响应请求对象。"""
|
||||
return ResponseRequest(
|
||||
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||
max_tokens=snapshot.get("max_tokens"),
|
||||
message_list=deserialize_messages_snapshot(snapshot.get("message_list") or []),
|
||||
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||
response_format=deserialize_response_format_snapshot(snapshot.get("response_format")),
|
||||
temperature=snapshot.get("temperature"),
|
||||
tool_options=deserialize_tool_options_snapshot(snapshot.get("tool_options")),
|
||||
)
|
||||
|
||||
|
||||
def _build_embedding_request(snapshot: dict[str, Any]) -> EmbeddingRequest:
|
||||
"""从快照构建嵌入请求对象。"""
|
||||
return EmbeddingRequest(
|
||||
embedding_input=str(snapshot.get("embedding_input") or ""),
|
||||
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||
)
|
||||
|
||||
|
||||
def _build_audio_request(snapshot: dict[str, Any]) -> AudioTranscriptionRequest:
|
||||
"""从快照构建音频转写请求对象。"""
|
||||
return AudioTranscriptionRequest(
|
||||
audio_base64=str(snapshot.get("audio_base64") or ""),
|
||||
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||
max_tokens=snapshot.get("max_tokens"),
|
||||
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||
)
|
||||
|
||||
|
||||
async def _replay(snapshot_path: Path) -> int:
|
||||
"""回放一条失败请求快照。"""
|
||||
config_manager.initialize()
|
||||
snapshot = _load_snapshot(snapshot_path)
|
||||
|
||||
internal_request = snapshot.get("internal_request")
|
||||
if not isinstance(internal_request, dict):
|
||||
raise ValueError("快照缺少 internal_request 字段")
|
||||
|
||||
provider_snapshot = snapshot.get("api_provider")
|
||||
if not isinstance(provider_snapshot, dict):
|
||||
raise ValueError("快照缺少 api_provider 字段")
|
||||
|
||||
provider_name = str(provider_snapshot.get("name") or "")
|
||||
if not provider_name:
|
||||
raise ValueError("快照中的 api_provider.name 不能为空")
|
||||
|
||||
api_provider = _resolve_api_provider(provider_name)
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=True)
|
||||
|
||||
request_kind = str(internal_request.get("request_kind") or "").strip()
|
||||
if request_kind == "response":
|
||||
response = await client.get_response(_build_response_request(internal_request))
|
||||
elif request_kind == "embedding":
|
||||
response = await client.get_embedding(_build_embedding_request(internal_request))
|
||||
elif request_kind == "audio_transcription":
|
||||
response = await client.get_audio_transcriptions(_build_audio_request(internal_request))
|
||||
else:
|
||||
raise ValueError(f"不支持的 request_kind: {request_kind!r}")
|
||||
|
||||
output_payload = {
|
||||
"content": response.content,
|
||||
"embedding_length": len(response.embedding or []),
|
||||
"has_embedding": response.embedding is not None,
|
||||
"model_name": response.usage.model_name if response.usage is not None else None,
|
||||
"provider_name": response.usage.provider_name if response.usage is not None else None,
|
||||
"raw_data_type": type(response.raw_data).__name__ if response.raw_data is not None else None,
|
||||
"reasoning_content": response.reasoning_content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"args": tool_call.args,
|
||||
"call_id": tool_call.call_id,
|
||||
"func_name": tool_call.func_name,
|
||||
}
|
||||
for tool_call in (response.tool_calls or [])
|
||||
],
|
||||
"usage": {
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
if response.usage is not None
|
||||
else None,
|
||||
}
|
||||
print(json.dumps(output_payload, ensure_ascii=False, indent=2))
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""脚本入口。"""
|
||||
parser = argparse.ArgumentParser(description="回放失败的 LLM 请求快照。")
|
||||
parser.add_argument("snapshot_path", help="请求快照 JSON 文件路径")
|
||||
args = parser.parse_args()
|
||||
|
||||
snapshot_path = Path(args.snapshot_path).expanduser().resolve()
|
||||
if not snapshot_path.exists():
|
||||
raise FileNotFoundError(f"快照文件不存在: {snapshot_path}")
|
||||
|
||||
return asyncio.run(_replay(snapshot_path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,7 +1,10 @@
|
||||
# ruff: noqa: B025
|
||||
|
||||
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import io
|
||||
import json
|
||||
|
||||
@@ -21,6 +24,8 @@ from google.genai.types import (
|
||||
EmbedContentConfig,
|
||||
EmbedContentResponse,
|
||||
FunctionDeclaration,
|
||||
FunctionCall,
|
||||
FunctionResponse,
|
||||
GenerateContentConfig,
|
||||
GenerateContentResponse,
|
||||
GoogleSearch,
|
||||
@@ -60,6 +65,14 @@ from .base_client import (
|
||||
UsageTuple,
|
||||
client_registry,
|
||||
)
|
||||
from ..request_snapshot import (
|
||||
attach_request_snapshot,
|
||||
has_request_snapshot,
|
||||
save_failed_request_snapshot,
|
||||
serialize_audio_request_snapshot,
|
||||
serialize_embedding_request_snapshot,
|
||||
serialize_response_request_snapshot,
|
||||
)
|
||||
|
||||
logger = get_logger("Gemini客户端")
|
||||
|
||||
@@ -112,6 +125,11 @@ EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = {
|
||||
}
|
||||
"""可透传给 `EmbedContentConfig` 的额外参数字段。"""
|
||||
|
||||
GEMINI_EXTRA_CONTENT_PROVIDER_KEY = "google"
|
||||
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY = "thought_signature"
|
||||
GEMINI_FALLBACK_THOUGHT_SIGNATURE = b"skip_thought_signature_validator"
|
||||
"""当历史 function call 没有原始 thought signature 时,使用官方允许的占位签名跳过校验。"""
|
||||
|
||||
|
||||
def _normalize_image_mime_type(image_format: str) -> str:
|
||||
"""将图片格式名称转换为标准 MIME 类型。
|
||||
@@ -177,6 +195,62 @@ def _normalize_function_response_payload(message: Message) -> Dict[str, Any]:
|
||||
return {"result": content}
|
||||
|
||||
|
||||
def _build_gemini_tool_call_extra_content(thought_signature: bytes | None) -> Dict[str, Any] | None:
|
||||
"""将 Gemini thought signature 编码为内部工具调用附加信息。"""
|
||||
if not thought_signature:
|
||||
return None
|
||||
return {
|
||||
GEMINI_EXTRA_CONTENT_PROVIDER_KEY: {
|
||||
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY: base64.b64encode(thought_signature).decode("ascii")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _extract_gemini_thought_signature(tool_call: ToolCall) -> bytes | None:
|
||||
"""从内部工具调用附加信息中提取 Gemini thought signature。"""
|
||||
if not tool_call.extra_content:
|
||||
return None
|
||||
|
||||
provider_payload = tool_call.extra_content.get(GEMINI_EXTRA_CONTENT_PROVIDER_KEY)
|
||||
if not isinstance(provider_payload, dict):
|
||||
return None
|
||||
|
||||
raw_thought_signature = provider_payload.get(GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY)
|
||||
if isinstance(raw_thought_signature, bytes):
|
||||
return raw_thought_signature
|
||||
if not isinstance(raw_thought_signature, str):
|
||||
return None
|
||||
|
||||
normalized_signature = raw_thought_signature.strip()
|
||||
if not normalized_signature:
|
||||
return None
|
||||
|
||||
try:
|
||||
return base64.b64decode(normalized_signature.encode("ascii"), validate=True)
|
||||
except (binascii.Error, ValueError):
|
||||
return normalized_signature.encode("utf-8")
|
||||
|
||||
|
||||
def _build_gemini_function_call_part(
|
||||
tool_call: ToolCall,
|
||||
*,
|
||||
inject_fallback_signature: bool,
|
||||
) -> Part:
|
||||
"""根据内部工具调用构建 Gemini function call part。"""
|
||||
thought_signature = _extract_gemini_thought_signature(tool_call)
|
||||
if thought_signature is None and inject_fallback_signature:
|
||||
thought_signature = GEMINI_FALLBACK_THOUGHT_SIGNATURE
|
||||
|
||||
return Part(
|
||||
function_call=FunctionCall(
|
||||
id=tool_call.call_id,
|
||||
name=tool_call.func_name,
|
||||
args=tool_call.args or {},
|
||||
),
|
||||
thought_signature=thought_signature,
|
||||
)
|
||||
|
||||
|
||||
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
|
||||
"""安全获取 Gemini 响应中的候选列表。
|
||||
|
||||
@@ -235,11 +309,11 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
||||
if message.role == RoleType.Assistant:
|
||||
assistant_parts = _build_non_tool_parts(message)
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
for tool_call_index, tool_call in enumerate(message.tool_calls):
|
||||
assistant_parts.append(
|
||||
Part.from_function_call(
|
||||
name=tool_call.func_name,
|
||||
args=tool_call.args or {},
|
||||
_build_gemini_function_call_part(
|
||||
tool_call,
|
||||
inject_fallback_signature=tool_call_index == 0,
|
||||
)
|
||||
)
|
||||
tool_name_by_call_id[tool_call.call_id] = tool_call.func_name
|
||||
@@ -256,10 +330,13 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
||||
"且消息中未携带 tool_name"
|
||||
)
|
||||
tool_name_by_call_id[message.tool_call_id] = tool_name
|
||||
function_response_part = Part.from_function_response(
|
||||
function_response_part = Part(
|
||||
function_response=FunctionResponse(
|
||||
id=message.tool_call_id,
|
||||
name=tool_name,
|
||||
response=_normalize_function_response_payload(message),
|
||||
)
|
||||
)
|
||||
contents.append(Content(role="tool", parts=[function_response_part]))
|
||||
continue
|
||||
|
||||
@@ -368,22 +445,41 @@ def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]
|
||||
Raises:
|
||||
RespParseException: 当函数调用结构不合法时抛出。
|
||||
"""
|
||||
raw_function_calls = getattr(response, "function_calls", None)
|
||||
candidates = _get_candidates(response)
|
||||
if not raw_function_calls and candidates:
|
||||
raw_function_calls = []
|
||||
tool_calls: List[ToolCall] = []
|
||||
|
||||
for candidate in candidates:
|
||||
content = getattr(candidate, "content", None)
|
||||
parts = getattr(content, "parts", None) or []
|
||||
for part in parts:
|
||||
function_call = getattr(part, "function_call", None)
|
||||
if function_call is not None:
|
||||
raw_function_calls.append(function_call)
|
||||
if function_call is None:
|
||||
continue
|
||||
|
||||
call_name = getattr(function_call, "name", None)
|
||||
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{len(tool_calls) + 1}"
|
||||
call_args = getattr(function_call, "args", None) or {}
|
||||
if not isinstance(call_name, str) or not call_name:
|
||||
raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段")
|
||||
if not isinstance(call_args, dict):
|
||||
raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典")
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
func_name=call_name,
|
||||
args=call_args,
|
||||
extra_content=_build_gemini_tool_call_extra_content(getattr(part, "thought_signature", None)),
|
||||
)
|
||||
)
|
||||
|
||||
if tool_calls:
|
||||
return tool_calls
|
||||
|
||||
raw_function_calls = getattr(response, "function_calls", None)
|
||||
if not raw_function_calls:
|
||||
return []
|
||||
|
||||
tool_calls: List[ToolCall] = []
|
||||
for index, function_call in enumerate(raw_function_calls, start=1):
|
||||
call_name = getattr(function_call, "name", None)
|
||||
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}"
|
||||
@@ -808,6 +904,15 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
model_info = request.model_info
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/models/{model}:generateContent",
|
||||
"method": "POST",
|
||||
"operation": "models.generate_content",
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
try:
|
||||
contents, system_instruction = _convert_messages(request.message_list)
|
||||
model_identifier, enable_google_search = self._resolve_model_identifier(
|
||||
model_info.model_identifier,
|
||||
@@ -823,8 +928,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
extra_params=request.extra_params,
|
||||
enable_google_search=enable_google_search,
|
||||
)
|
||||
|
||||
try:
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"config": generation_config,
|
||||
"contents": contents,
|
||||
"enable_google_search": enable_google_search,
|
||||
"model": model_identifier,
|
||||
"system_instruction": system_instruction,
|
||||
}
|
||||
if model_info.force_stream_mode:
|
||||
stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
|
||||
self.client.aio.models.generate_content_stream(
|
||||
@@ -855,7 +965,62 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
raise
|
||||
except (ClientError, ServerError) as exc:
|
||||
status_code = int(getattr(exc, "code", 500) or 500)
|
||||
raise RespNotOkException(status_code, str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
|
||||
wrapped_error = RespParseException(None, f"Gemini 工具调用参数错误: {exc}")
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=wrapped_error,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except EmptyResponseException as exc:
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise
|
||||
except Exception as exc:
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = (
|
||||
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
if wrapped_error is exc:
|
||||
raise
|
||||
raise wrapped_error from exc
|
||||
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
|
||||
raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
|
||||
except EmptyResponseException:
|
||||
@@ -878,9 +1043,21 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
model_info = request.model_info
|
||||
embedding_input = request.embedding_input
|
||||
extra_params = request.extra_params
|
||||
embed_config = _build_embed_content_config(extra_params)
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/models/{model}:embedContent",
|
||||
"method": "POST",
|
||||
"operation": "models.embed_content",
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
try:
|
||||
embed_config = _build_embed_content_config(extra_params)
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"config": embed_config,
|
||||
"contents": embedding_input,
|
||||
"model": model_info.model_identifier,
|
||||
}
|
||||
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
|
||||
model=model_info.model_identifier,
|
||||
contents=embedding_input,
|
||||
@@ -888,11 +1065,52 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
)
|
||||
except (ClientError, ServerError) as exc:
|
||||
status_code = int(getattr(exc, "code", 500) or 500)
|
||||
raise RespNotOkException(status_code, str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.embed_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except Exception as exc:
|
||||
raise NetworkConnectionError(str(exc)) from exc
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.embed_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = (
|
||||
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
if wrapped_error is exc:
|
||||
raise
|
||||
raise wrapped_error from exc
|
||||
|
||||
response = APIResponse(raw_data=raw_response)
|
||||
if not raw_response.embeddings:
|
||||
exc = RespParseException(raw_response, "Gemini 嵌入响应解析失败,缺少 embeddings 字段。")
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.embed_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise exc
|
||||
if raw_response.embeddings:
|
||||
response.embedding = raw_response.embeddings[0].values
|
||||
else:
|
||||
@@ -924,7 +1142,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
audio_base64 = request.audio_base64
|
||||
max_tokens = request.max_tokens
|
||||
extra_params = request.extra_params
|
||||
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/models/{model}:generateContent",
|
||||
"method": "POST",
|
||||
"operation": "models.generate_content",
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
transcription_prompt = str(
|
||||
extra_params.get(
|
||||
@@ -933,6 +1157,9 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
)
|
||||
)
|
||||
audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
|
||||
|
||||
try:
|
||||
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
|
||||
contents: List[ContentUnion] = [
|
||||
Content(
|
||||
role="user",
|
||||
@@ -952,8 +1179,14 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
extra_params=extra_params,
|
||||
enable_google_search=False,
|
||||
)
|
||||
|
||||
try:
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"audio_base64": audio_base64,
|
||||
"audio_mime_type": audio_mime_type,
|
||||
"config": generation_config,
|
||||
"contents": contents,
|
||||
"model": model_identifier,
|
||||
"transcription_prompt": transcription_prompt,
|
||||
}
|
||||
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
|
||||
model=model_identifier,
|
||||
contents=contents,
|
||||
@@ -962,9 +1195,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
||||
response, usage_record = _default_normal_response_parser(raw_response)
|
||||
except (ClientError, ServerError) as exc:
|
||||
status_code = int(getattr(exc, "code", 500) or 500)
|
||||
raise RespNotOkException(status_code, str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except Exception as exc:
|
||||
raise NetworkConnectionError(str(exc)) from exc
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="gemini",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="models.generate_content",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = (
|
||||
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||
)
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
if wrapped_error is exc:
|
||||
raise
|
||||
raise wrapped_error from exc
|
||||
|
||||
return response, usage_record
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import binascii
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||||
|
||||
@@ -60,6 +59,14 @@ from .base_client import (
|
||||
UsageTuple,
|
||||
client_registry,
|
||||
)
|
||||
from ..request_snapshot import (
|
||||
attach_request_snapshot,
|
||||
has_request_snapshot,
|
||||
save_failed_request_snapshot,
|
||||
serialize_audio_request_snapshot,
|
||||
serialize_embedding_request_snapshot,
|
||||
serialize_response_request_snapshot,
|
||||
)
|
||||
|
||||
logger = get_logger("llm_models")
|
||||
|
||||
@@ -533,6 +540,13 @@ def _coerce_openai_argument(value: Any) -> Any | Omit:
|
||||
return value
|
||||
|
||||
|
||||
def _snapshot_openai_argument(value: Any | Omit) -> Any | None:
|
||||
"""将 OpenAI SDK 参数转换为适合写入快照的普通值。"""
|
||||
if value is omit:
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def _build_api_status_message(error: APIStatusError) -> str:
|
||||
"""构建更适合记录和展示的状态错误信息。
|
||||
|
||||
@@ -939,10 +953,21 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/chat/completions",
|
||||
"method": "POST",
|
||||
"operation": "chat.completions.create",
|
||||
"organization": self.api_provider.organization,
|
||||
"project": self.api_provider.project,
|
||||
"request_kwargs": {},
|
||||
}
|
||||
model_info = request.model_info
|
||||
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list)
|
||||
tools: Iterable[ChatCompletionToolParam] | Omit = (
|
||||
_convert_tool_options(request.tool_options) if request.tool_options else omit
|
||||
|
||||
try:
|
||||
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request.message_list)
|
||||
tools_payload: List[ChatCompletionToolParam] | None = (
|
||||
_convert_tool_options(request.tool_options) if request.tool_options else None
|
||||
)
|
||||
openai_response_format = _convert_response_format(request.response_format)
|
||||
request_overrides = split_openai_request_overrides(
|
||||
@@ -958,14 +983,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
|
||||
else _coerce_openai_argument(request.max_tokens)
|
||||
)
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"extra_body": request_overrides.extra_body or None,
|
||||
"extra_headers": request_overrides.extra_headers or None,
|
||||
"extra_query": request_overrides.extra_query or None,
|
||||
"max_tokens": _snapshot_openai_argument(max_tokens_argument),
|
||||
"messages": messages_payload,
|
||||
"model": model_info.model_identifier,
|
||||
"response_format": _snapshot_openai_argument(openai_response_format),
|
||||
"stream": bool(model_info.force_stream_mode),
|
||||
"temperature": _snapshot_openai_argument(temperature_argument),
|
||||
"tools": tools_payload,
|
||||
}
|
||||
|
||||
try:
|
||||
if model_info.force_stream_mode:
|
||||
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
|
||||
self.client.chat.completions.create(
|
||||
model=model_info.model_identifier,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
messages=messages_payload,
|
||||
tools=tools_payload or omit,
|
||||
temperature=temperature_argument,
|
||||
max_tokens=max_tokens_argument,
|
||||
stream=True,
|
||||
@@ -984,8 +1020,8 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
|
||||
self.client.chat.completions.create(
|
||||
model=model_info.model_identifier,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
messages=messages_payload,
|
||||
tools=tools_payload or omit,
|
||||
temperature=temperature_argument,
|
||||
max_tokens=max_tokens_argument,
|
||||
stream=False,
|
||||
@@ -1000,10 +1036,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
await await_task_with_interrupt(completion_task, request.interrupt_flag),
|
||||
)
|
||||
return response_parser(raw_response)
|
||||
except (EmptyResponseException, RespParseException) as exc:
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise
|
||||
except APIConnectionError as exc:
|
||||
raise NetworkConnectionError(str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = NetworkConnectionError(str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except APIStatusError as exc:
|
||||
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except ReqAbortException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_response_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="chat.completions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise
|
||||
|
||||
async def _execute_embedding_request(
|
||||
self,
|
||||
@@ -1020,9 +1106,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
model_info = request.model_info
|
||||
embedding_input = request.embedding_input
|
||||
extra_params = request.extra_params
|
||||
request_overrides = split_openai_request_overrides(extra_params)
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/embeddings",
|
||||
"method": "POST",
|
||||
"operation": "embeddings.create",
|
||||
"organization": self.api_provider.organization,
|
||||
"project": self.api_provider.project,
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
try:
|
||||
request_overrides = split_openai_request_overrides(extra_params)
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"extra_body": request_overrides.extra_body or None,
|
||||
"extra_headers": request_overrides.extra_headers or None,
|
||||
"extra_query": request_overrides.extra_query or None,
|
||||
"input": embedding_input,
|
||||
"model": model_info.model_identifier,
|
||||
}
|
||||
raw_response = await self.client.embeddings.create(
|
||||
model=model_info.model_identifier,
|
||||
input=embedding_input,
|
||||
@@ -1031,11 +1133,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
extra_body=request_overrides.extra_body or None,
|
||||
)
|
||||
except APIConnectionError as exc:
|
||||
raise NetworkConnectionError(str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="embeddings.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = NetworkConnectionError(str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except APIStatusError as exc:
|
||||
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="embeddings.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except Exception as exc:
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="embeddings.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise
|
||||
|
||||
response = APIResponse()
|
||||
if not raw_response.data:
|
||||
exc = RespParseException(raw_response, "嵌入响应解析失败,缺少 embeddings 数据。")
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_embedding_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="embeddings.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise exc
|
||||
if raw_response.data:
|
||||
response.embedding = raw_response.data[0].embedding
|
||||
else:
|
||||
@@ -1059,10 +1210,27 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
model_info = request.model_info
|
||||
audio_base64 = request.audio_base64
|
||||
extra_params = request.extra_params
|
||||
request_overrides = split_openai_request_overrides(extra_params)
|
||||
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
|
||||
snapshot_provider_request = {
|
||||
"base_url": self.api_provider.base_url,
|
||||
"endpoint": "/audio/transcriptions",
|
||||
"method": "POST",
|
||||
"operation": "audio.transcriptions.create",
|
||||
"organization": self.api_provider.organization,
|
||||
"project": self.api_provider.project,
|
||||
"request_kwargs": {},
|
||||
}
|
||||
|
||||
try:
|
||||
request_overrides = split_openai_request_overrides(extra_params)
|
||||
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
|
||||
snapshot_provider_request["request_kwargs"] = {
|
||||
"audio_base64": audio_base64,
|
||||
"extra_body": request_overrides.extra_body or None,
|
||||
"extra_headers": request_overrides.extra_headers or None,
|
||||
"extra_query": request_overrides.extra_query or None,
|
||||
"file_name": "audio.wav",
|
||||
"model": model_info.model_identifier,
|
||||
}
|
||||
raw_response = await self.client.audio.transcriptions.create(
|
||||
model=model_info.model_identifier,
|
||||
file=audio_file,
|
||||
@@ -1071,12 +1239,61 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
||||
extra_body=request_overrides.extra_body or None,
|
||||
)
|
||||
except APIConnectionError as exc:
|
||||
raise NetworkConnectionError(str(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="audio.transcriptions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = NetworkConnectionError(str(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except APIStatusError as exc:
|
||||
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="audio.transcriptions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||||
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||
raise wrapped_error from exc
|
||||
except Exception as exc:
|
||||
if has_request_snapshot(exc):
|
||||
raise
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="audio.transcriptions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise
|
||||
|
||||
response = APIResponse()
|
||||
transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
|
||||
if not isinstance(transcription_text, str):
|
||||
exc = RespParseException(raw_response, "音频转写响应解析失败,缺少文本内容。")
|
||||
snapshot_path = save_failed_request_snapshot(
|
||||
api_provider=self.api_provider,
|
||||
client_type="openai",
|
||||
error=exc,
|
||||
internal_request=serialize_audio_request_snapshot(request),
|
||||
model_info=model_info,
|
||||
operation="audio.transcriptions.create",
|
||||
provider_request=snapshot_provider_request,
|
||||
)
|
||||
attach_request_snapshot(exc, snapshot_path)
|
||||
raise exc
|
||||
if isinstance(transcription_text, str):
|
||||
response.content = transcription_text
|
||||
return response, None
|
||||
|
||||
447
src/llm_models/request_snapshot.py
Normal file
447
src/llm_models/request_snapshot.py
Normal file
@@ -0,0 +1,447 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, EmbeddingRequest, ResponseRequest
|
||||
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption, normalize_tool_options
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
LLM_REQUEST_LOG_DIR = PROJECT_ROOT / "logs" / "llm_request"
|
||||
REPLAY_SCRIPT_RELATIVE_PATH = Path("scripts") / "replay_llm_request.py"
|
||||
REPLAY_SCRIPT_PATH = PROJECT_ROOT / REPLAY_SCRIPT_RELATIVE_PATH
|
||||
FILENAME_SAFE_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
|
||||
SNAPSHOT_VERSION = 1
|
||||
|
||||
logger = get_logger("llm_request_snapshot")
|
||||
|
||||
|
||||
def _json_friendly(value: Any) -> Any:
|
||||
"""将任意对象尽量转换为可写入 JSON 的结构。"""
|
||||
if value is None or isinstance(value, (bool, float, int, str)):
|
||||
return value
|
||||
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
|
||||
if isinstance(value, Path):
|
||||
return str(value)
|
||||
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
return base64.b64encode(bytes(value)).decode("ascii")
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
return {str(key): _json_friendly(item) for key, item in value.items()}
|
||||
|
||||
if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)):
|
||||
return [_json_friendly(item) for item in value]
|
||||
|
||||
model_dump = getattr(value, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
try:
|
||||
return _json_friendly(model_dump(mode="json", exclude_none=True))
|
||||
except TypeError:
|
||||
return _json_friendly(model_dump(exclude_none=True))
|
||||
|
||||
to_dict = getattr(value, "to_dict", None)
|
||||
if callable(to_dict):
|
||||
return _json_friendly(to_dict())
|
||||
|
||||
return str(value)
|
||||
|
||||
|
||||
def _sanitize_filename_component(value: str) -> str:
|
||||
"""将任意字符串转换为适合文件名使用的片段。"""
|
||||
normalized_value = FILENAME_SAFE_PATTERN.sub("-", value.strip())
|
||||
normalized_value = normalized_value.strip("-._")
|
||||
return normalized_value or "unknown"
|
||||
|
||||
|
||||
def _serialize_tool_call(tool_call: ToolCall) -> dict[str, Any]:
|
||||
"""序列化单个工具调用。"""
|
||||
payload = {
|
||||
"id": tool_call.call_id,
|
||||
"function": {
|
||||
"name": tool_call.func_name,
|
||||
"arguments": _json_friendly(tool_call.args or {}),
|
||||
},
|
||||
}
|
||||
if tool_call.extra_content:
|
||||
payload["extra_content"] = _json_friendly(tool_call.extra_content)
|
||||
return payload
|
||||
|
||||
|
||||
def serialize_tool_calls_snapshot(tool_calls: Sequence[ToolCall] | None) -> list[dict[str, Any]]:
|
||||
"""序列化工具调用列表。"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
return [_serialize_tool_call(tool_call) for tool_call in tool_calls]
|
||||
|
||||
|
||||
def deserialize_tool_calls_snapshot(raw_tool_calls: Any) -> list[ToolCall]:
|
||||
"""从快照恢复工具调用列表。"""
|
||||
if raw_tool_calls in (None, []):
|
||||
return []
|
||||
if not isinstance(raw_tool_calls, list):
|
||||
raise ValueError("快照中的 tool_calls 必须是列表")
|
||||
|
||||
normalized_tool_calls: list[ToolCall] = []
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
raise ValueError("快照中的 tool_call 项必须是字典")
|
||||
|
||||
function_info = raw_tool_call.get("function", {})
|
||||
if isinstance(function_info, dict):
|
||||
function_name = function_info.get("name")
|
||||
function_arguments = function_info.get("arguments")
|
||||
else:
|
||||
function_name = raw_tool_call.get("name")
|
||||
function_arguments = raw_tool_call.get("arguments")
|
||||
|
||||
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
|
||||
if not isinstance(call_id, str) or not isinstance(function_name, str):
|
||||
raise ValueError("快照中的 tool_call 缺少 id 或 function.name")
|
||||
|
||||
extra_content = raw_tool_call.get("extra_content")
|
||||
normalized_tool_calls.append(
|
||||
ToolCall(
|
||||
call_id=call_id,
|
||||
func_name=function_name,
|
||||
args=function_arguments if isinstance(function_arguments, dict) else {},
|
||||
extra_content=extra_content if isinstance(extra_content, dict) else None,
|
||||
)
|
||||
)
|
||||
return normalized_tool_calls
|
||||
|
||||
|
||||
def serialize_message_snapshot(message: Message) -> dict[str, Any]:
|
||||
"""将内部消息对象序列化为可回放的快照结构。"""
|
||||
parts_payload: list[dict[str, Any]] = []
|
||||
for part in message.parts:
|
||||
if isinstance(part, TextMessagePart):
|
||||
parts_payload.append({"type": "text", "text": part.text})
|
||||
continue
|
||||
|
||||
if isinstance(part, ImageMessagePart):
|
||||
parts_payload.append(
|
||||
{
|
||||
"type": "image",
|
||||
"image_base64": part.image_base64,
|
||||
"image_format": part.image_format,
|
||||
}
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"parts": parts_payload,
|
||||
"role": message.role.value,
|
||||
}
|
||||
if message.tool_call_id:
|
||||
payload["tool_call_id"] = message.tool_call_id
|
||||
if message.tool_name:
|
||||
payload["tool_name"] = message.tool_name
|
||||
if message.tool_calls:
|
||||
payload["tool_calls"] = serialize_tool_calls_snapshot(message.tool_calls)
|
||||
return payload
|
||||
|
||||
|
||||
def deserialize_message_snapshot(raw_message: Any) -> Message:
|
||||
"""从快照恢复内部消息对象。"""
|
||||
if not isinstance(raw_message, dict):
|
||||
raise ValueError("快照中的 message 必须是字典")
|
||||
|
||||
raw_role = raw_message.get("role")
|
||||
if not isinstance(raw_role, str):
|
||||
raise ValueError("快照中的 message 缺少 role")
|
||||
|
||||
role = RoleType(raw_role)
|
||||
builder = MessageBuilder().set_role(role)
|
||||
|
||||
raw_tool_calls = raw_message.get("tool_calls")
|
||||
tool_calls = deserialize_tool_calls_snapshot(raw_tool_calls)
|
||||
if role == RoleType.Assistant and tool_calls:
|
||||
builder.set_tool_calls(tool_calls)
|
||||
|
||||
tool_call_id = raw_message.get("tool_call_id")
|
||||
if role == RoleType.Tool and isinstance(tool_call_id, str):
|
||||
builder.set_tool_call_id(tool_call_id)
|
||||
|
||||
tool_name = raw_message.get("tool_name")
|
||||
if role == RoleType.Tool and isinstance(tool_name, str) and tool_name:
|
||||
builder.set_tool_name(tool_name)
|
||||
|
||||
raw_parts = raw_message.get("parts", [])
|
||||
if not isinstance(raw_parts, list):
|
||||
raise ValueError("快照中的 message.parts 必须是列表")
|
||||
|
||||
for raw_part in raw_parts:
|
||||
if not isinstance(raw_part, dict):
|
||||
raise ValueError("快照中的 message part 必须是字典")
|
||||
|
||||
part_type = str(raw_part.get("type", "")).strip().lower()
|
||||
if part_type == "text":
|
||||
text = raw_part.get("text")
|
||||
if not isinstance(text, str):
|
||||
raise ValueError("文本 part 缺少 text 字段")
|
||||
builder.add_text_content(text)
|
||||
continue
|
||||
|
||||
if part_type == "image":
|
||||
image_format = raw_part.get("image_format")
|
||||
image_base64 = raw_part.get("image_base64")
|
||||
if not isinstance(image_format, str) or not isinstance(image_base64, str):
|
||||
raise ValueError("图片 part 缺少 image_format 或 image_base64")
|
||||
builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||
continue
|
||||
|
||||
raise ValueError(f"不支持的快照消息 part 类型: {part_type}")
|
||||
|
||||
return builder.build()
|
||||
|
||||
|
||||
def serialize_messages_snapshot(messages: Sequence[Message]) -> list[dict[str, Any]]:
|
||||
"""序列化消息列表。"""
|
||||
return [serialize_message_snapshot(message) for message in messages]
|
||||
|
||||
|
||||
def deserialize_messages_snapshot(raw_messages: Any) -> list[Message]:
|
||||
"""从快照恢复消息列表。"""
|
||||
if not isinstance(raw_messages, list):
|
||||
raise ValueError("快照中的 messages 必须是列表")
|
||||
return [deserialize_message_snapshot(raw_message) for raw_message in raw_messages]
|
||||
|
||||
|
||||
def serialize_model_info_snapshot(model_info: ModelInfo) -> dict[str, Any]:
|
||||
"""序列化模型信息。"""
|
||||
return {
|
||||
"api_provider": model_info.api_provider,
|
||||
"extra_params": _json_friendly(dict(model_info.extra_params)),
|
||||
"force_stream_mode": model_info.force_stream_mode,
|
||||
"max_tokens": model_info.max_tokens,
|
||||
"model_identifier": model_info.model_identifier,
|
||||
"name": model_info.name,
|
||||
"temperature": model_info.temperature,
|
||||
}
|
||||
|
||||
|
||||
def deserialize_model_info_snapshot(raw_model_info: Any) -> ModelInfo:
|
||||
"""从快照恢复模型信息。"""
|
||||
if not isinstance(raw_model_info, dict):
|
||||
raise ValueError("快照中的 model_info 必须是字典")
|
||||
|
||||
return ModelInfo(
|
||||
api_provider=str(raw_model_info.get("api_provider") or ""),
|
||||
extra_params=dict(raw_model_info.get("extra_params") or {}),
|
||||
force_stream_mode=bool(raw_model_info.get("force_stream_mode", False)),
|
||||
max_tokens=raw_model_info.get("max_tokens"),
|
||||
model_identifier=str(raw_model_info.get("model_identifier") or ""),
|
||||
name=str(raw_model_info.get("name") or ""),
|
||||
temperature=raw_model_info.get("temperature"),
|
||||
)
|
||||
|
||||
|
||||
def serialize_response_format_snapshot(response_format: RespFormat | None) -> dict[str, Any] | None:
|
||||
"""序列化响应格式定义。"""
|
||||
if response_format is None:
|
||||
return None
|
||||
return response_format.to_dict()
|
||||
|
||||
|
||||
def deserialize_response_format_snapshot(raw_response_format: Any) -> RespFormat | None:
|
||||
"""从快照恢复响应格式定义。"""
|
||||
if raw_response_format is None:
|
||||
return None
|
||||
if not isinstance(raw_response_format, dict):
|
||||
raise ValueError("快照中的 response_format 必须是字典")
|
||||
|
||||
raw_format_type = raw_response_format.get("format_type")
|
||||
if not isinstance(raw_format_type, str):
|
||||
raise ValueError("快照中的 response_format 缺少 format_type")
|
||||
|
||||
format_type = RespFormatType(raw_format_type)
|
||||
raw_schema = raw_response_format.get("schema")
|
||||
schema = raw_schema if isinstance(raw_schema, dict) else None
|
||||
return RespFormat(format_type=format_type, schema=schema)
|
||||
|
||||
|
||||
def serialize_tool_options_snapshot(tool_options: Sequence[ToolOption] | None) -> list[dict[str, Any]]:
|
||||
"""序列化工具定义列表。"""
|
||||
if not tool_options:
|
||||
return []
|
||||
return [tool_option.to_openai_function_schema() for tool_option in tool_options]
|
||||
|
||||
|
||||
def deserialize_tool_options_snapshot(raw_tool_options: Any) -> list[ToolOption] | None:
|
||||
"""从快照恢复工具定义列表。"""
|
||||
if raw_tool_options in (None, []):
|
||||
return None
|
||||
if not isinstance(raw_tool_options, list):
|
||||
raise ValueError("快照中的 tool_options 必须是列表")
|
||||
return normalize_tool_options(raw_tool_options)
|
||||
|
||||
|
||||
def serialize_response_request_snapshot(request: ResponseRequest) -> dict[str, Any]:
|
||||
"""序列化文本/多模态请求。"""
|
||||
return {
|
||||
"extra_params": _json_friendly(dict(request.extra_params)),
|
||||
"max_tokens": request.max_tokens,
|
||||
"message_list": serialize_messages_snapshot(request.message_list),
|
||||
"model_info": serialize_model_info_snapshot(request.model_info),
|
||||
"request_kind": "response",
|
||||
"response_format": serialize_response_format_snapshot(request.response_format),
|
||||
"temperature": request.temperature,
|
||||
"tool_options": serialize_tool_options_snapshot(request.tool_options),
|
||||
}
|
||||
|
||||
|
||||
def serialize_embedding_request_snapshot(request: EmbeddingRequest) -> dict[str, Any]:
|
||||
"""序列化嵌入请求。"""
|
||||
return {
|
||||
"embedding_input": request.embedding_input,
|
||||
"extra_params": _json_friendly(dict(request.extra_params)),
|
||||
"model_info": serialize_model_info_snapshot(request.model_info),
|
||||
"request_kind": "embedding",
|
||||
}
|
||||
|
||||
|
||||
def serialize_audio_request_snapshot(request: AudioTranscriptionRequest) -> dict[str, Any]:
|
||||
"""序列化音频转写请求。"""
|
||||
return {
|
||||
"audio_base64": request.audio_base64,
|
||||
"extra_params": _json_friendly(dict(request.extra_params)),
|
||||
"max_tokens": request.max_tokens,
|
||||
"model_info": serialize_model_info_snapshot(request.model_info),
|
||||
"request_kind": "audio_transcription",
|
||||
}
|
||||
|
||||
|
||||
def serialize_api_provider_snapshot(api_provider: APIProvider) -> dict[str, Any]:
|
||||
"""序列化 API Provider 配置,排除敏感认证信息。"""
|
||||
return {
|
||||
"auth_header_name": api_provider.auth_header_name,
|
||||
"auth_header_prefix": api_provider.auth_header_prefix,
|
||||
"auth_query_name": api_provider.auth_query_name,
|
||||
"auth_type": api_provider.auth_type,
|
||||
"base_url": api_provider.base_url,
|
||||
"client_type": api_provider.client_type,
|
||||
"default_headers": _json_friendly(dict(api_provider.default_headers)),
|
||||
"default_query": _json_friendly(dict(api_provider.default_query)),
|
||||
"model_list_endpoint": api_provider.model_list_endpoint,
|
||||
"name": api_provider.name,
|
||||
"organization": api_provider.organization,
|
||||
"project": api_provider.project,
|
||||
"retry_interval": api_provider.retry_interval,
|
||||
"timeout": api_provider.timeout,
|
||||
}
|
||||
|
||||
|
||||
def build_replay_command(snapshot_path: Path) -> str:
|
||||
"""构建回放当前快照的命令。"""
|
||||
return f'uv run python {REPLAY_SCRIPT_RELATIVE_PATH.as_posix()} "{snapshot_path.resolve()}"'
|
||||
|
||||
|
||||
def save_failed_request_snapshot(
|
||||
*,
|
||||
api_provider: APIProvider,
|
||||
client_type: str,
|
||||
error: Exception,
|
||||
internal_request: dict[str, Any],
|
||||
model_info: ModelInfo,
|
||||
operation: str,
|
||||
provider_request: dict[str, Any],
|
||||
) -> Path | None:
|
||||
"""保存失败请求快照。"""
|
||||
try:
|
||||
LLM_REQUEST_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now()
|
||||
file_name = (
|
||||
f"{timestamp.strftime('%Y%m%d_%H%M%S_%f')}"
|
||||
f"_{_sanitize_filename_component(client_type)}"
|
||||
f"_{_sanitize_filename_component(internal_request.get('request_kind', 'request'))}"
|
||||
f"_{_sanitize_filename_component(model_info.name or model_info.model_identifier)}.json"
|
||||
)
|
||||
snapshot_path = (LLM_REQUEST_LOG_DIR / file_name).resolve()
|
||||
|
||||
snapshot_payload: dict[str, Any] = {
|
||||
"api_provider": serialize_api_provider_snapshot(api_provider),
|
||||
"client_type": client_type,
|
||||
"created_at": timestamp.isoformat(timespec="seconds"),
|
||||
"error": {
|
||||
"message": str(error),
|
||||
"status_code": getattr(error, "status_code", None),
|
||||
"type": type(error).__name__,
|
||||
},
|
||||
"internal_request": internal_request,
|
||||
"model_info": serialize_model_info_snapshot(model_info),
|
||||
"operation": operation,
|
||||
"provider_request": _json_friendly(provider_request),
|
||||
"snapshot_version": SNAPSHOT_VERSION,
|
||||
}
|
||||
|
||||
snapshot_payload["replay"] = {
|
||||
"command": build_replay_command(snapshot_path),
|
||||
"file_uri": snapshot_path.as_uri(),
|
||||
"script_path": str(REPLAY_SCRIPT_PATH),
|
||||
}
|
||||
|
||||
snapshot_path.write_text(
|
||||
json.dumps(snapshot_payload, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return snapshot_path
|
||||
except Exception:
|
||||
logger.exception("淇濆瓨 LLM 澶辫触璇锋眰蹇収鏃跺彂鐢熷紓甯?")
|
||||
return None
|
||||
|
||||
|
||||
def attach_request_snapshot(exception: Exception, snapshot_path: Path | None) -> None:
|
||||
"""将请求快照信息挂载到异常对象上。"""
|
||||
if snapshot_path is None:
|
||||
return
|
||||
|
||||
exception.request_snapshot_path = str(snapshot_path.resolve())
|
||||
exception.request_snapshot_uri = snapshot_path.resolve().as_uri()
|
||||
exception.request_snapshot_replay_command = build_replay_command(snapshot_path)
|
||||
|
||||
|
||||
def has_request_snapshot(exception: Exception) -> bool:
|
||||
"""鍒ゆ柇寮傚父鏄惁宸插叧鑱斾簡璇锋眰蹇収銆?"""
|
||||
for candidate in (exception, getattr(exception, "__cause__", None)):
|
||||
if candidate is None:
|
||||
continue
|
||||
if getattr(candidate, "request_snapshot_path", ""):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def format_request_snapshot_log_info(exception: Exception) -> str:
|
||||
"""将异常上的快照信息格式化为日志片段。"""
|
||||
for candidate in (exception, getattr(exception, "__cause__", None)):
|
||||
if candidate is None:
|
||||
continue
|
||||
|
||||
snapshot_path = getattr(candidate, "request_snapshot_path", "")
|
||||
snapshot_uri = getattr(candidate, "request_snapshot_uri", "")
|
||||
replay_command = getattr(candidate, "request_snapshot_replay_command", "")
|
||||
if not any([snapshot_path, snapshot_uri, replay_command]):
|
||||
continue
|
||||
|
||||
lines: list[str] = []
|
||||
if snapshot_path:
|
||||
lines.append(f"请求快照路径: {snapshot_path}")
|
||||
if snapshot_uri:
|
||||
lines.append(f"请求快照链接: {snapshot_uri}")
|
||||
if replay_command:
|
||||
lines.append(f"重放命令: {replay_command}")
|
||||
if lines:
|
||||
return "\n " + "\n ".join(lines)
|
||||
|
||||
return ""
|
||||
@@ -37,6 +37,7 @@ from src.llm_models.model_client.base_client import (
|
||||
UsageRecord,
|
||||
client_registry,
|
||||
)
|
||||
from src.llm_models.request_snapshot import format_request_snapshot_log_info
|
||||
from src.llm_models.payload_content.message import Message, MessageBuilder
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import (
|
||||
@@ -1008,6 +1009,16 @@ class LLMOrchestrator:
|
||||
Returns:
|
||||
str: 可直接拼接到日志中的底层异常描述。
|
||||
"""
|
||||
detail_lines: List[str] = []
|
||||
if e.__cause__:
|
||||
detail_lines.append(f"底层异常类型: {type(e.__cause__).__name__}")
|
||||
detail_lines.append(f"底层异常信息: {e.__cause__}")
|
||||
|
||||
snapshot_info = format_request_snapshot_log_info(e)
|
||||
if detail_lines or snapshot_info:
|
||||
detail_text = "\n " + "\n ".join(detail_lines) if detail_lines else ""
|
||||
return f"{detail_text}{snapshot_info}"
|
||||
|
||||
if e.__cause__:
|
||||
original_error_type = type(e.__cause__).__name__
|
||||
original_error_msg = str(e.__cause__)
|
||||
|
||||
Reference in New Issue
Block a user