feat:为失败请求留档并提供重试分析
This commit is contained in:
131
pytests/utils_test/test_request_snapshot.py
Normal file
131
pytests/utils_test/test_request_snapshot.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from src.config.model_configs import APIProvider, ModelInfo
|
||||||
|
from src.llm_models.model_client.base_client import ResponseRequest
|
||||||
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||||
|
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||||
|
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||||||
|
from src.llm_models.request_snapshot import (
|
||||||
|
attach_request_snapshot,
|
||||||
|
deserialize_messages_snapshot,
|
||||||
|
format_request_snapshot_log_info,
|
||||||
|
save_failed_request_snapshot,
|
||||||
|
serialize_messages_snapshot,
|
||||||
|
serialize_response_request_snapshot,
|
||||||
|
)
|
||||||
|
from src.llm_models import request_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _build_api_provider() -> APIProvider:
|
||||||
|
return APIProvider(
|
||||||
|
api_key="secret-token",
|
||||||
|
base_url="https://example.com/v1",
|
||||||
|
name="test-provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_model_info() -> ModelInfo:
|
||||||
|
return ModelInfo(
|
||||||
|
api_provider="test-provider",
|
||||||
|
model_identifier="demo-model",
|
||||||
|
name="demo-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_response_request() -> ResponseRequest:
|
||||||
|
tool_call = ToolCall(
|
||||||
|
args={"query": "MaiBot"},
|
||||||
|
call_id="call_1",
|
||||||
|
func_name="search_web",
|
||||||
|
extra_content={"google": {"thought_signature": "c2lnbmF0dXJl"}},
|
||||||
|
)
|
||||||
|
message_list = [
|
||||||
|
MessageBuilder().set_role(RoleType.User).add_text_content("你好").add_image_content("png", "ZmFrZQ==").build(),
|
||||||
|
MessageBuilder().set_role(RoleType.Assistant).set_tool_calls([tool_call]).build(),
|
||||||
|
MessageBuilder()
|
||||||
|
.set_role(RoleType.Tool)
|
||||||
|
.set_tool_call_id("call_1")
|
||||||
|
.set_tool_name("search_web")
|
||||||
|
.add_text_content('{"ok": true}')
|
||||||
|
.build(),
|
||||||
|
]
|
||||||
|
return ResponseRequest(
|
||||||
|
extra_params={"trace_id": "trace-123"},
|
||||||
|
max_tokens=256,
|
||||||
|
message_list=message_list,
|
||||||
|
model_info=_build_model_info(),
|
||||||
|
response_format=RespFormat(RespFormatType.JSON_OBJ),
|
||||||
|
temperature=0.2,
|
||||||
|
tool_options=[ToolOption(name="search_web", description="搜索网页")],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_snapshot_roundtrip_preserves_tool_messages() -> None:
|
||||||
|
request = _build_response_request()
|
||||||
|
|
||||||
|
snapshot_messages = serialize_messages_snapshot(request.message_list)
|
||||||
|
restored_messages = deserialize_messages_snapshot(snapshot_messages)
|
||||||
|
|
||||||
|
assert len(restored_messages) == 3
|
||||||
|
assert restored_messages[0].role == RoleType.User
|
||||||
|
assert restored_messages[0].get_text_content() == "你好"
|
||||||
|
assert restored_messages[0].parts[1].image_format == "png"
|
||||||
|
assert restored_messages[1].role == RoleType.Assistant
|
||||||
|
assert restored_messages[1].tool_calls is not None
|
||||||
|
assert restored_messages[1].tool_calls[0].func_name == "search_web"
|
||||||
|
assert restored_messages[1].tool_calls[0].args == {"query": "MaiBot"}
|
||||||
|
assert restored_messages[1].tool_calls[0].extra_content == {"google": {"thought_signature": "c2lnbmF0dXJl"}}
|
||||||
|
assert restored_messages[2].role == RoleType.Tool
|
||||||
|
assert restored_messages[2].tool_call_id == "call_1"
|
||||||
|
assert restored_messages[2].tool_name == "search_web"
|
||||||
|
|
||||||
|
|
||||||
|
def test_failed_request_snapshot_contains_replay_entry(tmp_path: Path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
||||||
|
|
||||||
|
request = _build_response_request()
|
||||||
|
provider = _build_api_provider()
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=RuntimeError("boom"),
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=request.model_info,
|
||||||
|
operation="chat.completions.create",
|
||||||
|
provider_request={"request_kwargs": {"model": request.model_info.model_identifier}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert snapshot_path is not None
|
||||||
|
payload = json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
assert payload["internal_request"]["request_kind"] == "response"
|
||||||
|
assert payload["api_provider"]["name"] == "test-provider"
|
||||||
|
assert payload["replay"]["file_uri"] == snapshot_path.as_uri()
|
||||||
|
assert str(snapshot_path) in payload["replay"]["command"]
|
||||||
|
assert "secret-token" not in snapshot_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_request_snapshot_log_info_includes_path_uri_and_command(tmp_path: Path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(request_snapshot, "LLM_REQUEST_LOG_DIR", tmp_path)
|
||||||
|
|
||||||
|
request = _build_response_request()
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=_build_api_provider(),
|
||||||
|
client_type="openai",
|
||||||
|
error=ValueError("invalid"),
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=request.model_info,
|
||||||
|
operation="chat.completions.create",
|
||||||
|
provider_request={"request_kwargs": {"messages": []}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert snapshot_path is not None
|
||||||
|
exc = RuntimeError("wrapped")
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
|
||||||
|
log_info = format_request_snapshot_log_info(exc)
|
||||||
|
assert str(snapshot_path) in log_info
|
||||||
|
assert snapshot_path.as_uri() in log_info
|
||||||
|
assert "uv run python scripts/replay_llm_request.py" in log_info
|
||||||
146
scripts/replay_llm_request.py
Normal file
146
scripts/replay_llm_request.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
# ruff: noqa: E402
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
SRC_ROOT = PROJECT_ROOT / "src"
|
||||||
|
if str(SRC_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(SRC_ROOT))
|
||||||
|
if str(PROJECT_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(1, str(PROJECT_ROOT))
|
||||||
|
|
||||||
|
from src.config.config import config_manager
|
||||||
|
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, ResponseRequest, client_registry
|
||||||
|
from src.llm_models.model_client.base_client import EmbeddingRequest
|
||||||
|
from src.llm_models.request_snapshot import (
|
||||||
|
deserialize_messages_snapshot,
|
||||||
|
deserialize_model_info_snapshot,
|
||||||
|
deserialize_response_format_snapshot,
|
||||||
|
deserialize_tool_options_snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_snapshot(snapshot_path: Path) -> dict[str, Any]:
|
||||||
|
"""加载请求快照。"""
|
||||||
|
return json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_api_provider(provider_name: str):
|
||||||
|
"""根据名称解析当前配置中的 API Provider。"""
|
||||||
|
model_config = config_manager.get_model_config()
|
||||||
|
for api_provider in model_config.api_providers:
|
||||||
|
if api_provider.name == provider_name:
|
||||||
|
return api_provider
|
||||||
|
raise ValueError(f"当前配置中不存在名为 {provider_name!r} 的 API Provider")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_response_request(snapshot: dict[str, Any]) -> ResponseRequest:
|
||||||
|
"""从快照构建响应请求对象。"""
|
||||||
|
return ResponseRequest(
|
||||||
|
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||||
|
max_tokens=snapshot.get("max_tokens"),
|
||||||
|
message_list=deserialize_messages_snapshot(snapshot.get("message_list") or []),
|
||||||
|
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||||
|
response_format=deserialize_response_format_snapshot(snapshot.get("response_format")),
|
||||||
|
temperature=snapshot.get("temperature"),
|
||||||
|
tool_options=deserialize_tool_options_snapshot(snapshot.get("tool_options")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_embedding_request(snapshot: dict[str, Any]) -> EmbeddingRequest:
|
||||||
|
"""从快照构建嵌入请求对象。"""
|
||||||
|
return EmbeddingRequest(
|
||||||
|
embedding_input=str(snapshot.get("embedding_input") or ""),
|
||||||
|
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||||
|
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_audio_request(snapshot: dict[str, Any]) -> AudioTranscriptionRequest:
|
||||||
|
"""从快照构建音频转写请求对象。"""
|
||||||
|
return AudioTranscriptionRequest(
|
||||||
|
audio_base64=str(snapshot.get("audio_base64") or ""),
|
||||||
|
extra_params=dict(snapshot.get("extra_params") or {}),
|
||||||
|
max_tokens=snapshot.get("max_tokens"),
|
||||||
|
model_info=deserialize_model_info_snapshot(snapshot.get("model_info") or {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _replay(snapshot_path: Path) -> int:
|
||||||
|
"""回放一条失败请求快照。"""
|
||||||
|
config_manager.initialize()
|
||||||
|
snapshot = _load_snapshot(snapshot_path)
|
||||||
|
|
||||||
|
internal_request = snapshot.get("internal_request")
|
||||||
|
if not isinstance(internal_request, dict):
|
||||||
|
raise ValueError("快照缺少 internal_request 字段")
|
||||||
|
|
||||||
|
provider_snapshot = snapshot.get("api_provider")
|
||||||
|
if not isinstance(provider_snapshot, dict):
|
||||||
|
raise ValueError("快照缺少 api_provider 字段")
|
||||||
|
|
||||||
|
provider_name = str(provider_snapshot.get("name") or "")
|
||||||
|
if not provider_name:
|
||||||
|
raise ValueError("快照中的 api_provider.name 不能为空")
|
||||||
|
|
||||||
|
api_provider = _resolve_api_provider(provider_name)
|
||||||
|
client = client_registry.get_client_class_instance(api_provider, force_new=True)
|
||||||
|
|
||||||
|
request_kind = str(internal_request.get("request_kind") or "").strip()
|
||||||
|
if request_kind == "response":
|
||||||
|
response = await client.get_response(_build_response_request(internal_request))
|
||||||
|
elif request_kind == "embedding":
|
||||||
|
response = await client.get_embedding(_build_embedding_request(internal_request))
|
||||||
|
elif request_kind == "audio_transcription":
|
||||||
|
response = await client.get_audio_transcriptions(_build_audio_request(internal_request))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的 request_kind: {request_kind!r}")
|
||||||
|
|
||||||
|
output_payload = {
|
||||||
|
"content": response.content,
|
||||||
|
"embedding_length": len(response.embedding or []),
|
||||||
|
"has_embedding": response.embedding is not None,
|
||||||
|
"model_name": response.usage.model_name if response.usage is not None else None,
|
||||||
|
"provider_name": response.usage.provider_name if response.usage is not None else None,
|
||||||
|
"raw_data_type": type(response.raw_data).__name__ if response.raw_data is not None else None,
|
||||||
|
"reasoning_content": response.reasoning_content,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"args": tool_call.args,
|
||||||
|
"call_id": tool_call.call_id,
|
||||||
|
"func_name": tool_call.func_name,
|
||||||
|
}
|
||||||
|
for tool_call in (response.tool_calls or [])
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": response.usage.completion_tokens,
|
||||||
|
"prompt_tokens": response.usage.prompt_tokens,
|
||||||
|
"total_tokens": response.usage.total_tokens,
|
||||||
|
}
|
||||||
|
if response.usage is not None
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
print(json.dumps(output_payload, ensure_ascii=False, indent=2))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""脚本入口。"""
|
||||||
|
parser = argparse.ArgumentParser(description="回放失败的 LLM 请求快照。")
|
||||||
|
parser.add_argument("snapshot_path", help="请求快照 JSON 文件路径")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
snapshot_path = Path(args.snapshot_path).expanduser().resolve()
|
||||||
|
if not snapshot_path.exists():
|
||||||
|
raise FileNotFoundError(f"快照文件不存在: {snapshot_path}")
|
||||||
|
|
||||||
|
return asyncio.run(_replay(snapshot_path))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
|
# ruff: noqa: B025
|
||||||
|
|
||||||
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
|
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
import binascii
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -21,6 +24,8 @@ from google.genai.types import (
|
|||||||
EmbedContentConfig,
|
EmbedContentConfig,
|
||||||
EmbedContentResponse,
|
EmbedContentResponse,
|
||||||
FunctionDeclaration,
|
FunctionDeclaration,
|
||||||
|
FunctionCall,
|
||||||
|
FunctionResponse,
|
||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
GoogleSearch,
|
GoogleSearch,
|
||||||
@@ -60,6 +65,14 @@ from .base_client import (
|
|||||||
UsageTuple,
|
UsageTuple,
|
||||||
client_registry,
|
client_registry,
|
||||||
)
|
)
|
||||||
|
from ..request_snapshot import (
|
||||||
|
attach_request_snapshot,
|
||||||
|
has_request_snapshot,
|
||||||
|
save_failed_request_snapshot,
|
||||||
|
serialize_audio_request_snapshot,
|
||||||
|
serialize_embedding_request_snapshot,
|
||||||
|
serialize_response_request_snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger("Gemini客户端")
|
logger = get_logger("Gemini客户端")
|
||||||
|
|
||||||
@@ -112,6 +125,11 @@ EMBED_CONFIG_SUPPORTED_EXTRA_PARAMS = {
|
|||||||
}
|
}
|
||||||
"""可透传给 `EmbedContentConfig` 的额外参数字段。"""
|
"""可透传给 `EmbedContentConfig` 的额外参数字段。"""
|
||||||
|
|
||||||
|
GEMINI_EXTRA_CONTENT_PROVIDER_KEY = "google"
|
||||||
|
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY = "thought_signature"
|
||||||
|
GEMINI_FALLBACK_THOUGHT_SIGNATURE = b"skip_thought_signature_validator"
|
||||||
|
"""当历史 function call 没有原始 thought signature 时,使用官方允许的占位签名跳过校验。"""
|
||||||
|
|
||||||
|
|
||||||
def _normalize_image_mime_type(image_format: str) -> str:
|
def _normalize_image_mime_type(image_format: str) -> str:
|
||||||
"""将图片格式名称转换为标准 MIME 类型。
|
"""将图片格式名称转换为标准 MIME 类型。
|
||||||
@@ -177,6 +195,62 @@ def _normalize_function_response_payload(message: Message) -> Dict[str, Any]:
|
|||||||
return {"result": content}
|
return {"result": content}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gemini_tool_call_extra_content(thought_signature: bytes | None) -> Dict[str, Any] | None:
|
||||||
|
"""将 Gemini thought signature 编码为内部工具调用附加信息。"""
|
||||||
|
if not thought_signature:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
GEMINI_EXTRA_CONTENT_PROVIDER_KEY: {
|
||||||
|
GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY: base64.b64encode(thought_signature).decode("ascii")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_gemini_thought_signature(tool_call: ToolCall) -> bytes | None:
|
||||||
|
"""从内部工具调用附加信息中提取 Gemini thought signature。"""
|
||||||
|
if not tool_call.extra_content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
provider_payload = tool_call.extra_content.get(GEMINI_EXTRA_CONTENT_PROVIDER_KEY)
|
||||||
|
if not isinstance(provider_payload, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
raw_thought_signature = provider_payload.get(GEMINI_EXTRA_CONTENT_THOUGHT_SIGNATURE_KEY)
|
||||||
|
if isinstance(raw_thought_signature, bytes):
|
||||||
|
return raw_thought_signature
|
||||||
|
if not isinstance(raw_thought_signature, str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized_signature = raw_thought_signature.strip()
|
||||||
|
if not normalized_signature:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return base64.b64decode(normalized_signature.encode("ascii"), validate=True)
|
||||||
|
except (binascii.Error, ValueError):
|
||||||
|
return normalized_signature.encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gemini_function_call_part(
|
||||||
|
tool_call: ToolCall,
|
||||||
|
*,
|
||||||
|
inject_fallback_signature: bool,
|
||||||
|
) -> Part:
|
||||||
|
"""根据内部工具调用构建 Gemini function call part。"""
|
||||||
|
thought_signature = _extract_gemini_thought_signature(tool_call)
|
||||||
|
if thought_signature is None and inject_fallback_signature:
|
||||||
|
thought_signature = GEMINI_FALLBACK_THOUGHT_SIGNATURE
|
||||||
|
|
||||||
|
return Part(
|
||||||
|
function_call=FunctionCall(
|
||||||
|
id=tool_call.call_id,
|
||||||
|
name=tool_call.func_name,
|
||||||
|
args=tool_call.args or {},
|
||||||
|
),
|
||||||
|
thought_signature=thought_signature,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
|
def _get_candidates(response: GenerateContentResponse) -> List[Candidate]:
|
||||||
"""安全获取 Gemini 响应中的候选列表。
|
"""安全获取 Gemini 响应中的候选列表。
|
||||||
|
|
||||||
@@ -235,11 +309,11 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
|||||||
if message.role == RoleType.Assistant:
|
if message.role == RoleType.Assistant:
|
||||||
assistant_parts = _build_non_tool_parts(message)
|
assistant_parts = _build_non_tool_parts(message)
|
||||||
if message.tool_calls:
|
if message.tool_calls:
|
||||||
for tool_call in message.tool_calls:
|
for tool_call_index, tool_call in enumerate(message.tool_calls):
|
||||||
assistant_parts.append(
|
assistant_parts.append(
|
||||||
Part.from_function_call(
|
_build_gemini_function_call_part(
|
||||||
name=tool_call.func_name,
|
tool_call,
|
||||||
args=tool_call.args or {},
|
inject_fallback_signature=tool_call_index == 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
tool_name_by_call_id[tool_call.call_id] = tool_call.func_name
|
tool_name_by_call_id[tool_call.call_id] = tool_call.func_name
|
||||||
@@ -256,10 +330,13 @@ def _convert_messages(messages: List[Message]) -> Tuple[ContentListUnion, str |
|
|||||||
"且消息中未携带 tool_name"
|
"且消息中未携带 tool_name"
|
||||||
)
|
)
|
||||||
tool_name_by_call_id[message.tool_call_id] = tool_name
|
tool_name_by_call_id[message.tool_call_id] = tool_name
|
||||||
function_response_part = Part.from_function_response(
|
function_response_part = Part(
|
||||||
|
function_response=FunctionResponse(
|
||||||
|
id=message.tool_call_id,
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
response=_normalize_function_response_payload(message),
|
response=_normalize_function_response_payload(message),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
contents.append(Content(role="tool", parts=[function_response_part]))
|
contents.append(Content(role="tool", parts=[function_response_part]))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -368,22 +445,41 @@ def _collect_function_calls(response: GenerateContentResponse) -> List[ToolCall]
|
|||||||
Raises:
|
Raises:
|
||||||
RespParseException: 当函数调用结构不合法时抛出。
|
RespParseException: 当函数调用结构不合法时抛出。
|
||||||
"""
|
"""
|
||||||
raw_function_calls = getattr(response, "function_calls", None)
|
|
||||||
candidates = _get_candidates(response)
|
candidates = _get_candidates(response)
|
||||||
if not raw_function_calls and candidates:
|
tool_calls: List[ToolCall] = []
|
||||||
raw_function_calls = []
|
|
||||||
for candidate in candidates:
|
for candidate in candidates:
|
||||||
content = getattr(candidate, "content", None)
|
content = getattr(candidate, "content", None)
|
||||||
parts = getattr(content, "parts", None) or []
|
parts = getattr(content, "parts", None) or []
|
||||||
for part in parts:
|
for part in parts:
|
||||||
function_call = getattr(part, "function_call", None)
|
function_call = getattr(part, "function_call", None)
|
||||||
if function_call is not None:
|
if function_call is None:
|
||||||
raw_function_calls.append(function_call)
|
continue
|
||||||
|
|
||||||
|
call_name = getattr(function_call, "name", None)
|
||||||
|
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{len(tool_calls) + 1}"
|
||||||
|
call_args = getattr(function_call, "args", None) or {}
|
||||||
|
if not isinstance(call_name, str) or not call_name:
|
||||||
|
raise RespParseException(response, "响应解析失败,Gemini 工具调用缺少 name 字段")
|
||||||
|
if not isinstance(call_args, dict):
|
||||||
|
raise RespParseException(response, "响应解析失败,Gemini 工具调用参数无法解析为字典")
|
||||||
|
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
call_id=call_id,
|
||||||
|
func_name=call_name,
|
||||||
|
args=call_args,
|
||||||
|
extra_content=_build_gemini_tool_call_extra_content(getattr(part, "thought_signature", None)),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if tool_calls:
|
||||||
|
return tool_calls
|
||||||
|
|
||||||
|
raw_function_calls = getattr(response, "function_calls", None)
|
||||||
if not raw_function_calls:
|
if not raw_function_calls:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
tool_calls: List[ToolCall] = []
|
|
||||||
for index, function_call in enumerate(raw_function_calls, start=1):
|
for index, function_call in enumerate(raw_function_calls, start=1):
|
||||||
call_name = getattr(function_call, "name", None)
|
call_name = getattr(function_call, "name", None)
|
||||||
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}"
|
call_id = getattr(function_call, "id", None) or f"gemini-tool-call-{index}"
|
||||||
@@ -808,6 +904,15 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||||
"""
|
"""
|
||||||
model_info = request.model_info
|
model_info = request.model_info
|
||||||
|
snapshot_provider_request = {
|
||||||
|
"base_url": self.api_provider.base_url,
|
||||||
|
"endpoint": "/models/{model}:generateContent",
|
||||||
|
"method": "POST",
|
||||||
|
"operation": "models.generate_content",
|
||||||
|
"request_kwargs": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
contents, system_instruction = _convert_messages(request.message_list)
|
contents, system_instruction = _convert_messages(request.message_list)
|
||||||
model_identifier, enable_google_search = self._resolve_model_identifier(
|
model_identifier, enable_google_search = self._resolve_model_identifier(
|
||||||
model_info.model_identifier,
|
model_info.model_identifier,
|
||||||
@@ -823,8 +928,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
extra_params=request.extra_params,
|
extra_params=request.extra_params,
|
||||||
enable_google_search=enable_google_search,
|
enable_google_search=enable_google_search,
|
||||||
)
|
)
|
||||||
|
snapshot_provider_request["request_kwargs"] = {
|
||||||
try:
|
"config": generation_config,
|
||||||
|
"contents": contents,
|
||||||
|
"enable_google_search": enable_google_search,
|
||||||
|
"model": model_identifier,
|
||||||
|
"system_instruction": system_instruction,
|
||||||
|
}
|
||||||
if model_info.force_stream_mode:
|
if model_info.force_stream_mode:
|
||||||
stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
|
stream_task: asyncio.Task[AsyncIterator[GenerateContentResponse]] = asyncio.create_task(
|
||||||
self.client.aio.models.generate_content_stream(
|
self.client.aio.models.generate_content_stream(
|
||||||
@@ -855,7 +965,62 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
raise
|
raise
|
||||||
except (ClientError, ServerError) as exc:
|
except (ClientError, ServerError) as exc:
|
||||||
status_code = int(getattr(exc, "code", 500) or 500)
|
status_code = int(getattr(exc, "code", 500) or 500)
|
||||||
raise RespNotOkException(status_code, str(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.generate_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
|
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
|
||||||
|
wrapped_error = RespParseException(None, f"Gemini 工具调用参数错误: {exc}")
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=wrapped_error,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.generate_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
|
except EmptyResponseException as exc:
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.generate_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
if has_request_snapshot(exc):
|
||||||
|
raise
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.generate_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = (
|
||||||
|
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||||
|
)
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
if wrapped_error is exc:
|
||||||
|
raise
|
||||||
|
raise wrapped_error from exc
|
||||||
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
|
except (UnknownFunctionCallArgumentError, UnsupportedFunctionError, FunctionInvocationError) as exc:
|
||||||
raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
|
raise RespParseException(None, f"Gemini 工具调用参数错误: {exc}") from exc
|
||||||
except EmptyResponseException:
|
except EmptyResponseException:
|
||||||
@@ -878,9 +1043,21 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
model_info = request.model_info
|
model_info = request.model_info
|
||||||
embedding_input = request.embedding_input
|
embedding_input = request.embedding_input
|
||||||
extra_params = request.extra_params
|
extra_params = request.extra_params
|
||||||
embed_config = _build_embed_content_config(extra_params)
|
snapshot_provider_request = {
|
||||||
|
"base_url": self.api_provider.base_url,
|
||||||
|
"endpoint": "/models/{model}:embedContent",
|
||||||
|
"method": "POST",
|
||||||
|
"operation": "models.embed_content",
|
||||||
|
"request_kwargs": {},
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
embed_config = _build_embed_content_config(extra_params)
|
||||||
|
snapshot_provider_request["request_kwargs"] = {
|
||||||
|
"config": embed_config,
|
||||||
|
"contents": embedding_input,
|
||||||
|
"model": model_info.model_identifier,
|
||||||
|
}
|
||||||
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
|
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
contents=embedding_input,
|
contents=embedding_input,
|
||||||
@@ -888,11 +1065,52 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
)
|
)
|
||||||
except (ClientError, ServerError) as exc:
|
except (ClientError, ServerError) as exc:
|
||||||
status_code = int(getattr(exc, "code", 500) or 500)
|
status_code = int(getattr(exc, "code", 500) or 500)
|
||||||
raise RespNotOkException(status_code, str(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_embedding_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.embed_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise NetworkConnectionError(str(exc)) from exc
|
if has_request_snapshot(exc):
|
||||||
|
raise
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_embedding_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.embed_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = (
|
||||||
|
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||||
|
)
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
if wrapped_error is exc:
|
||||||
|
raise
|
||||||
|
raise wrapped_error from exc
|
||||||
|
|
||||||
response = APIResponse(raw_data=raw_response)
|
response = APIResponse(raw_data=raw_response)
|
||||||
|
if not raw_response.embeddings:
|
||||||
|
exc = RespParseException(raw_response, "Gemini 嵌入响应解析失败,缺少 embeddings 字段。")
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_embedding_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.embed_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise exc
|
||||||
if raw_response.embeddings:
|
if raw_response.embeddings:
|
||||||
response.embedding = raw_response.embeddings[0].values
|
response.embedding = raw_response.embeddings[0].values
|
||||||
else:
|
else:
|
||||||
@@ -924,7 +1142,13 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
audio_base64 = request.audio_base64
|
audio_base64 = request.audio_base64
|
||||||
max_tokens = request.max_tokens
|
max_tokens = request.max_tokens
|
||||||
extra_params = request.extra_params
|
extra_params = request.extra_params
|
||||||
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
|
snapshot_provider_request = {
|
||||||
|
"base_url": self.api_provider.base_url,
|
||||||
|
"endpoint": "/models/{model}:generateContent",
|
||||||
|
"method": "POST",
|
||||||
|
"operation": "models.generate_content",
|
||||||
|
"request_kwargs": {},
|
||||||
|
}
|
||||||
|
|
||||||
transcription_prompt = str(
|
transcription_prompt = str(
|
||||||
extra_params.get(
|
extra_params.get(
|
||||||
@@ -933,6 +1157,9 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
|
audio_mime_type = str(extra_params.get("audio_mime_type", "audio/wav"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_identifier, _ = self._resolve_model_identifier(model_info.model_identifier, extra_params)
|
||||||
contents: List[ContentUnion] = [
|
contents: List[ContentUnion] = [
|
||||||
Content(
|
Content(
|
||||||
role="user",
|
role="user",
|
||||||
@@ -952,8 +1179,14 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
extra_params=extra_params,
|
extra_params=extra_params,
|
||||||
enable_google_search=False,
|
enable_google_search=False,
|
||||||
)
|
)
|
||||||
|
snapshot_provider_request["request_kwargs"] = {
|
||||||
try:
|
"audio_base64": audio_base64,
|
||||||
|
"audio_mime_type": audio_mime_type,
|
||||||
|
"config": generation_config,
|
||||||
|
"contents": contents,
|
||||||
|
"model": model_identifier,
|
||||||
|
"transcription_prompt": transcription_prompt,
|
||||||
|
}
|
||||||
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
|
raw_response: GenerateContentResponse = await self.client.aio.models.generate_content(
|
||||||
model=model_identifier,
|
model=model_identifier,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
@@ -962,9 +1195,37 @@ class GeminiClient(AdapterClient[AsyncIterator[GenerateContentResponse], Generat
|
|||||||
response, usage_record = _default_normal_response_parser(raw_response)
|
response, usage_record = _default_normal_response_parser(raw_response)
|
||||||
except (ClientError, ServerError) as exc:
|
except (ClientError, ServerError) as exc:
|
||||||
status_code = int(getattr(exc, "code", 500) or 500)
|
status_code = int(getattr(exc, "code", 500) or 500)
|
||||||
raise RespNotOkException(status_code, str(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_audio_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.generate_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = RespNotOkException(status_code, str(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise NetworkConnectionError(str(exc)) from exc
|
if has_request_snapshot(exc):
|
||||||
|
raise
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="gemini",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_audio_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="models.generate_content",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = (
|
||||||
|
exc if isinstance(exc, (EmptyResponseException, RespParseException)) else NetworkConnectionError(str(exc))
|
||||||
|
)
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
if wrapped_error is exc:
|
||||||
|
raise
|
||||||
|
raise wrapped_error from exc
|
||||||
|
|
||||||
return response, usage_record
|
return response, usage_record
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import binascii
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from collections.abc import Iterable
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
|
||||||
|
|
||||||
@@ -60,6 +59,14 @@ from .base_client import (
|
|||||||
UsageTuple,
|
UsageTuple,
|
||||||
client_registry,
|
client_registry,
|
||||||
)
|
)
|
||||||
|
from ..request_snapshot import (
|
||||||
|
attach_request_snapshot,
|
||||||
|
has_request_snapshot,
|
||||||
|
save_failed_request_snapshot,
|
||||||
|
serialize_audio_request_snapshot,
|
||||||
|
serialize_embedding_request_snapshot,
|
||||||
|
serialize_response_request_snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger("llm_models")
|
logger = get_logger("llm_models")
|
||||||
|
|
||||||
@@ -533,6 +540,13 @@ def _coerce_openai_argument(value: Any) -> Any | Omit:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _snapshot_openai_argument(value: Any | Omit) -> Any | None:
|
||||||
|
"""将 OpenAI SDK 参数转换为适合写入快照的普通值。"""
|
||||||
|
if value is omit:
|
||||||
|
return None
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _build_api_status_message(error: APIStatusError) -> str:
|
def _build_api_status_message(error: APIStatusError) -> str:
|
||||||
"""构建更适合记录和展示的状态错误信息。
|
"""构建更适合记录和展示的状态错误信息。
|
||||||
|
|
||||||
@@ -939,10 +953,21 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||||
"""
|
"""
|
||||||
|
snapshot_provider_request = {
|
||||||
|
"base_url": self.api_provider.base_url,
|
||||||
|
"endpoint": "/chat/completions",
|
||||||
|
"method": "POST",
|
||||||
|
"operation": "chat.completions.create",
|
||||||
|
"organization": self.api_provider.organization,
|
||||||
|
"project": self.api_provider.project,
|
||||||
|
"request_kwargs": {},
|
||||||
|
}
|
||||||
model_info = request.model_info
|
model_info = request.model_info
|
||||||
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(request.message_list)
|
|
||||||
tools: Iterable[ChatCompletionToolParam] | Omit = (
|
try:
|
||||||
_convert_tool_options(request.tool_options) if request.tool_options else omit
|
messages_payload: List[ChatCompletionMessageParam] = _convert_messages(request.message_list)
|
||||||
|
tools_payload: List[ChatCompletionToolParam] | None = (
|
||||||
|
_convert_tool_options(request.tool_options) if request.tool_options else None
|
||||||
)
|
)
|
||||||
openai_response_format = _convert_response_format(request.response_format)
|
openai_response_format = _convert_response_format(request.response_format)
|
||||||
request_overrides = split_openai_request_overrides(
|
request_overrides = split_openai_request_overrides(
|
||||||
@@ -958,14 +983,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
|
if "max_tokens" in request_overrides.extra_body or "max_completion_tokens" in request_overrides.extra_body
|
||||||
else _coerce_openai_argument(request.max_tokens)
|
else _coerce_openai_argument(request.max_tokens)
|
||||||
)
|
)
|
||||||
|
snapshot_provider_request["request_kwargs"] = {
|
||||||
|
"extra_body": request_overrides.extra_body or None,
|
||||||
|
"extra_headers": request_overrides.extra_headers or None,
|
||||||
|
"extra_query": request_overrides.extra_query or None,
|
||||||
|
"max_tokens": _snapshot_openai_argument(max_tokens_argument),
|
||||||
|
"messages": messages_payload,
|
||||||
|
"model": model_info.model_identifier,
|
||||||
|
"response_format": _snapshot_openai_argument(openai_response_format),
|
||||||
|
"stream": bool(model_info.force_stream_mode),
|
||||||
|
"temperature": _snapshot_openai_argument(temperature_argument),
|
||||||
|
"tools": tools_payload,
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
|
||||||
if model_info.force_stream_mode:
|
if model_info.force_stream_mode:
|
||||||
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
|
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
|
||||||
self.client.chat.completions.create(
|
self.client.chat.completions.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
messages=messages,
|
messages=messages_payload,
|
||||||
tools=tools,
|
tools=tools_payload or omit,
|
||||||
temperature=temperature_argument,
|
temperature=temperature_argument,
|
||||||
max_tokens=max_tokens_argument,
|
max_tokens=max_tokens_argument,
|
||||||
stream=True,
|
stream=True,
|
||||||
@@ -984,8 +1020,8 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
|
completion_task: asyncio.Task[ChatCompletion] = asyncio.create_task(
|
||||||
self.client.chat.completions.create(
|
self.client.chat.completions.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
messages=messages,
|
messages=messages_payload,
|
||||||
tools=tools,
|
tools=tools_payload or omit,
|
||||||
temperature=temperature_argument,
|
temperature=temperature_argument,
|
||||||
max_tokens=max_tokens_argument,
|
max_tokens=max_tokens_argument,
|
||||||
stream=False,
|
stream=False,
|
||||||
@@ -1000,10 +1036,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
await await_task_with_interrupt(completion_task, request.interrupt_flag),
|
await await_task_with_interrupt(completion_task, request.interrupt_flag),
|
||||||
)
|
)
|
||||||
return response_parser(raw_response)
|
return response_parser(raw_response)
|
||||||
|
except (EmptyResponseException, RespParseException) as exc:
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="chat.completions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise
|
||||||
except APIConnectionError as exc:
|
except APIConnectionError as exc:
|
||||||
raise NetworkConnectionError(str(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="chat.completions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = NetworkConnectionError(str(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
except APIStatusError as exc:
|
except APIStatusError as exc:
|
||||||
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="chat.completions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
|
except ReqAbortException:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
if has_request_snapshot(exc):
|
||||||
|
raise
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_response_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="chat.completions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise
|
||||||
|
|
||||||
async def _execute_embedding_request(
|
async def _execute_embedding_request(
|
||||||
self,
|
self,
|
||||||
@@ -1020,9 +1106,25 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
model_info = request.model_info
|
model_info = request.model_info
|
||||||
embedding_input = request.embedding_input
|
embedding_input = request.embedding_input
|
||||||
extra_params = request.extra_params
|
extra_params = request.extra_params
|
||||||
request_overrides = split_openai_request_overrides(extra_params)
|
snapshot_provider_request = {
|
||||||
|
"base_url": self.api_provider.base_url,
|
||||||
|
"endpoint": "/embeddings",
|
||||||
|
"method": "POST",
|
||||||
|
"operation": "embeddings.create",
|
||||||
|
"organization": self.api_provider.organization,
|
||||||
|
"project": self.api_provider.project,
|
||||||
|
"request_kwargs": {},
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
request_overrides = split_openai_request_overrides(extra_params)
|
||||||
|
snapshot_provider_request["request_kwargs"] = {
|
||||||
|
"extra_body": request_overrides.extra_body or None,
|
||||||
|
"extra_headers": request_overrides.extra_headers or None,
|
||||||
|
"extra_query": request_overrides.extra_query or None,
|
||||||
|
"input": embedding_input,
|
||||||
|
"model": model_info.model_identifier,
|
||||||
|
}
|
||||||
raw_response = await self.client.embeddings.create(
|
raw_response = await self.client.embeddings.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
input=embedding_input,
|
input=embedding_input,
|
||||||
@@ -1031,11 +1133,60 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
extra_body=request_overrides.extra_body or None,
|
extra_body=request_overrides.extra_body or None,
|
||||||
)
|
)
|
||||||
except APIConnectionError as exc:
|
except APIConnectionError as exc:
|
||||||
raise NetworkConnectionError(str(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_embedding_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="embeddings.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = NetworkConnectionError(str(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
except APIStatusError as exc:
|
except APIStatusError as exc:
|
||||||
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_embedding_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="embeddings.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
|
except Exception as exc:
|
||||||
|
if has_request_snapshot(exc):
|
||||||
|
raise
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_embedding_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="embeddings.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise
|
||||||
|
|
||||||
response = APIResponse()
|
response = APIResponse()
|
||||||
|
if not raw_response.data:
|
||||||
|
exc = RespParseException(raw_response, "嵌入响应解析失败,缺少 embeddings 数据。")
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_embedding_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="embeddings.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise exc
|
||||||
if raw_response.data:
|
if raw_response.data:
|
||||||
response.embedding = raw_response.data[0].embedding
|
response.embedding = raw_response.data[0].embedding
|
||||||
else:
|
else:
|
||||||
@@ -1059,10 +1210,27 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
model_info = request.model_info
|
model_info = request.model_info
|
||||||
audio_base64 = request.audio_base64
|
audio_base64 = request.audio_base64
|
||||||
extra_params = request.extra_params
|
extra_params = request.extra_params
|
||||||
request_overrides = split_openai_request_overrides(extra_params)
|
snapshot_provider_request = {
|
||||||
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
|
"base_url": self.api_provider.base_url,
|
||||||
|
"endpoint": "/audio/transcriptions",
|
||||||
|
"method": "POST",
|
||||||
|
"operation": "audio.transcriptions.create",
|
||||||
|
"organization": self.api_provider.organization,
|
||||||
|
"project": self.api_provider.project,
|
||||||
|
"request_kwargs": {},
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
request_overrides = split_openai_request_overrides(extra_params)
|
||||||
|
audio_file: FileTypes = ("audio.wav", io.BytesIO(base64.b64decode(audio_base64)))
|
||||||
|
snapshot_provider_request["request_kwargs"] = {
|
||||||
|
"audio_base64": audio_base64,
|
||||||
|
"extra_body": request_overrides.extra_body or None,
|
||||||
|
"extra_headers": request_overrides.extra_headers or None,
|
||||||
|
"extra_query": request_overrides.extra_query or None,
|
||||||
|
"file_name": "audio.wav",
|
||||||
|
"model": model_info.model_identifier,
|
||||||
|
}
|
||||||
raw_response = await self.client.audio.transcriptions.create(
|
raw_response = await self.client.audio.transcriptions.create(
|
||||||
model=model_info.model_identifier,
|
model=model_info.model_identifier,
|
||||||
file=audio_file,
|
file=audio_file,
|
||||||
@@ -1071,12 +1239,61 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
|
|||||||
extra_body=request_overrides.extra_body or None,
|
extra_body=request_overrides.extra_body or None,
|
||||||
)
|
)
|
||||||
except APIConnectionError as exc:
|
except APIConnectionError as exc:
|
||||||
raise NetworkConnectionError(str(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_audio_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="audio.transcriptions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = NetworkConnectionError(str(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
except APIStatusError as exc:
|
except APIStatusError as exc:
|
||||||
raise RespNotOkException(exc.status_code, _build_api_status_message(exc)) from exc
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_audio_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="audio.transcriptions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
wrapped_error = RespNotOkException(exc.status_code, _build_api_status_message(exc))
|
||||||
|
attach_request_snapshot(wrapped_error, snapshot_path)
|
||||||
|
raise wrapped_error from exc
|
||||||
|
except Exception as exc:
|
||||||
|
if has_request_snapshot(exc):
|
||||||
|
raise
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_audio_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="audio.transcriptions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise
|
||||||
|
|
||||||
response = APIResponse()
|
response = APIResponse()
|
||||||
transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
|
transcription_text = raw_response if isinstance(raw_response, str) else getattr(raw_response, "text", None)
|
||||||
|
if not isinstance(transcription_text, str):
|
||||||
|
exc = RespParseException(raw_response, "音频转写响应解析失败,缺少文本内容。")
|
||||||
|
snapshot_path = save_failed_request_snapshot(
|
||||||
|
api_provider=self.api_provider,
|
||||||
|
client_type="openai",
|
||||||
|
error=exc,
|
||||||
|
internal_request=serialize_audio_request_snapshot(request),
|
||||||
|
model_info=model_info,
|
||||||
|
operation="audio.transcriptions.create",
|
||||||
|
provider_request=snapshot_provider_request,
|
||||||
|
)
|
||||||
|
attach_request_snapshot(exc, snapshot_path)
|
||||||
|
raise exc
|
||||||
if isinstance(transcription_text, str):
|
if isinstance(transcription_text, str):
|
||||||
response.content = transcription_text
|
response.content = transcription_text
|
||||||
return response, None
|
return response, None
|
||||||
|
|||||||
447
src/llm_models/request_snapshot.py
Normal file
447
src/llm_models/request_snapshot.py
Normal file
@@ -0,0 +1,447 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Mapping, Sequence
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.model_configs import APIProvider, ModelInfo
|
||||||
|
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, EmbeddingRequest, ResponseRequest
|
||||||
|
from src.llm_models.payload_content.message import ImageMessagePart, Message, MessageBuilder, RoleType, TextMessagePart
|
||||||
|
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
||||||
|
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption, normalize_tool_options
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
LLM_REQUEST_LOG_DIR = PROJECT_ROOT / "logs" / "llm_request"
|
||||||
|
REPLAY_SCRIPT_RELATIVE_PATH = Path("scripts") / "replay_llm_request.py"
|
||||||
|
REPLAY_SCRIPT_PATH = PROJECT_ROOT / REPLAY_SCRIPT_RELATIVE_PATH
|
||||||
|
FILENAME_SAFE_PATTERN = re.compile(r"[^A-Za-z0-9._-]+")
|
||||||
|
SNAPSHOT_VERSION = 1
|
||||||
|
|
||||||
|
logger = get_logger("llm_request_snapshot")
|
||||||
|
|
||||||
|
|
||||||
|
def _json_friendly(value: Any) -> Any:
|
||||||
|
"""将任意对象尽量转换为可写入 JSON 的结构。"""
|
||||||
|
if value is None or isinstance(value, (bool, float, int, str)):
|
||||||
|
return value
|
||||||
|
|
||||||
|
if isinstance(value, Enum):
|
||||||
|
return value.value
|
||||||
|
|
||||||
|
if isinstance(value, Path):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
return base64.b64encode(bytes(value)).decode("ascii")
|
||||||
|
|
||||||
|
if isinstance(value, Mapping):
|
||||||
|
return {str(key): _json_friendly(item) for key, item in value.items()}
|
||||||
|
|
||||||
|
if isinstance(value, Sequence) and not isinstance(value, (bytes, bytearray, str)):
|
||||||
|
return [_json_friendly(item) for item in value]
|
||||||
|
|
||||||
|
model_dump = getattr(value, "model_dump", None)
|
||||||
|
if callable(model_dump):
|
||||||
|
try:
|
||||||
|
return _json_friendly(model_dump(mode="json", exclude_none=True))
|
||||||
|
except TypeError:
|
||||||
|
return _json_friendly(model_dump(exclude_none=True))
|
||||||
|
|
||||||
|
to_dict = getattr(value, "to_dict", None)
|
||||||
|
if callable(to_dict):
|
||||||
|
return _json_friendly(to_dict())
|
||||||
|
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename_component(value: str) -> str:
|
||||||
|
"""将任意字符串转换为适合文件名使用的片段。"""
|
||||||
|
normalized_value = FILENAME_SAFE_PATTERN.sub("-", value.strip())
|
||||||
|
normalized_value = normalized_value.strip("-._")
|
||||||
|
return normalized_value or "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_tool_call(tool_call: ToolCall) -> dict[str, Any]:
|
||||||
|
"""序列化单个工具调用。"""
|
||||||
|
payload = {
|
||||||
|
"id": tool_call.call_id,
|
||||||
|
"function": {
|
||||||
|
"name": tool_call.func_name,
|
||||||
|
"arguments": _json_friendly(tool_call.args or {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if tool_call.extra_content:
|
||||||
|
payload["extra_content"] = _json_friendly(tool_call.extra_content)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_tool_calls_snapshot(tool_calls: Sequence[ToolCall] | None) -> list[dict[str, Any]]:
|
||||||
|
"""序列化工具调用列表。"""
|
||||||
|
if not tool_calls:
|
||||||
|
return []
|
||||||
|
return [_serialize_tool_call(tool_call) for tool_call in tool_calls]
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_tool_calls_snapshot(raw_tool_calls: Any) -> list[ToolCall]:
|
||||||
|
"""从快照恢复工具调用列表。"""
|
||||||
|
if raw_tool_calls in (None, []):
|
||||||
|
return []
|
||||||
|
if not isinstance(raw_tool_calls, list):
|
||||||
|
raise ValueError("快照中的 tool_calls 必须是列表")
|
||||||
|
|
||||||
|
normalized_tool_calls: list[ToolCall] = []
|
||||||
|
for raw_tool_call in raw_tool_calls:
|
||||||
|
if not isinstance(raw_tool_call, dict):
|
||||||
|
raise ValueError("快照中的 tool_call 项必须是字典")
|
||||||
|
|
||||||
|
function_info = raw_tool_call.get("function", {})
|
||||||
|
if isinstance(function_info, dict):
|
||||||
|
function_name = function_info.get("name")
|
||||||
|
function_arguments = function_info.get("arguments")
|
||||||
|
else:
|
||||||
|
function_name = raw_tool_call.get("name")
|
||||||
|
function_arguments = raw_tool_call.get("arguments")
|
||||||
|
|
||||||
|
call_id = raw_tool_call.get("id") or raw_tool_call.get("call_id")
|
||||||
|
if not isinstance(call_id, str) or not isinstance(function_name, str):
|
||||||
|
raise ValueError("快照中的 tool_call 缺少 id 或 function.name")
|
||||||
|
|
||||||
|
extra_content = raw_tool_call.get("extra_content")
|
||||||
|
normalized_tool_calls.append(
|
||||||
|
ToolCall(
|
||||||
|
call_id=call_id,
|
||||||
|
func_name=function_name,
|
||||||
|
args=function_arguments if isinstance(function_arguments, dict) else {},
|
||||||
|
extra_content=extra_content if isinstance(extra_content, dict) else None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return normalized_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_message_snapshot(message: Message) -> dict[str, Any]:
|
||||||
|
"""将内部消息对象序列化为可回放的快照结构。"""
|
||||||
|
parts_payload: list[dict[str, Any]] = []
|
||||||
|
for part in message.parts:
|
||||||
|
if isinstance(part, TextMessagePart):
|
||||||
|
parts_payload.append({"type": "text", "text": part.text})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(part, ImageMessagePart):
|
||||||
|
parts_payload.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image_base64": part.image_base64,
|
||||||
|
"image_format": part.image_format,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"parts": parts_payload,
|
||||||
|
"role": message.role.value,
|
||||||
|
}
|
||||||
|
if message.tool_call_id:
|
||||||
|
payload["tool_call_id"] = message.tool_call_id
|
||||||
|
if message.tool_name:
|
||||||
|
payload["tool_name"] = message.tool_name
|
||||||
|
if message.tool_calls:
|
||||||
|
payload["tool_calls"] = serialize_tool_calls_snapshot(message.tool_calls)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_message_snapshot(raw_message: Any) -> Message:
|
||||||
|
"""从快照恢复内部消息对象。"""
|
||||||
|
if not isinstance(raw_message, dict):
|
||||||
|
raise ValueError("快照中的 message 必须是字典")
|
||||||
|
|
||||||
|
raw_role = raw_message.get("role")
|
||||||
|
if not isinstance(raw_role, str):
|
||||||
|
raise ValueError("快照中的 message 缺少 role")
|
||||||
|
|
||||||
|
role = RoleType(raw_role)
|
||||||
|
builder = MessageBuilder().set_role(role)
|
||||||
|
|
||||||
|
raw_tool_calls = raw_message.get("tool_calls")
|
||||||
|
tool_calls = deserialize_tool_calls_snapshot(raw_tool_calls)
|
||||||
|
if role == RoleType.Assistant and tool_calls:
|
||||||
|
builder.set_tool_calls(tool_calls)
|
||||||
|
|
||||||
|
tool_call_id = raw_message.get("tool_call_id")
|
||||||
|
if role == RoleType.Tool and isinstance(tool_call_id, str):
|
||||||
|
builder.set_tool_call_id(tool_call_id)
|
||||||
|
|
||||||
|
tool_name = raw_message.get("tool_name")
|
||||||
|
if role == RoleType.Tool and isinstance(tool_name, str) and tool_name:
|
||||||
|
builder.set_tool_name(tool_name)
|
||||||
|
|
||||||
|
raw_parts = raw_message.get("parts", [])
|
||||||
|
if not isinstance(raw_parts, list):
|
||||||
|
raise ValueError("快照中的 message.parts 必须是列表")
|
||||||
|
|
||||||
|
for raw_part in raw_parts:
|
||||||
|
if not isinstance(raw_part, dict):
|
||||||
|
raise ValueError("快照中的 message part 必须是字典")
|
||||||
|
|
||||||
|
part_type = str(raw_part.get("type", "")).strip().lower()
|
||||||
|
if part_type == "text":
|
||||||
|
text = raw_part.get("text")
|
||||||
|
if not isinstance(text, str):
|
||||||
|
raise ValueError("文本 part 缺少 text 字段")
|
||||||
|
builder.add_text_content(text)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if part_type == "image":
|
||||||
|
image_format = raw_part.get("image_format")
|
||||||
|
image_base64 = raw_part.get("image_base64")
|
||||||
|
if not isinstance(image_format, str) or not isinstance(image_base64, str):
|
||||||
|
raise ValueError("图片 part 缺少 image_format 或 image_base64")
|
||||||
|
builder.add_image_content(image_format=image_format, image_base64=image_base64)
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError(f"不支持的快照消息 part 类型: {part_type}")
|
||||||
|
|
||||||
|
return builder.build()
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_messages_snapshot(messages: Sequence[Message]) -> list[dict[str, Any]]:
|
||||||
|
"""序列化消息列表。"""
|
||||||
|
return [serialize_message_snapshot(message) for message in messages]
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_messages_snapshot(raw_messages: Any) -> list[Message]:
|
||||||
|
"""从快照恢复消息列表。"""
|
||||||
|
if not isinstance(raw_messages, list):
|
||||||
|
raise ValueError("快照中的 messages 必须是列表")
|
||||||
|
return [deserialize_message_snapshot(raw_message) for raw_message in raw_messages]
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_model_info_snapshot(model_info: ModelInfo) -> dict[str, Any]:
|
||||||
|
"""序列化模型信息。"""
|
||||||
|
return {
|
||||||
|
"api_provider": model_info.api_provider,
|
||||||
|
"extra_params": _json_friendly(dict(model_info.extra_params)),
|
||||||
|
"force_stream_mode": model_info.force_stream_mode,
|
||||||
|
"max_tokens": model_info.max_tokens,
|
||||||
|
"model_identifier": model_info.model_identifier,
|
||||||
|
"name": model_info.name,
|
||||||
|
"temperature": model_info.temperature,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_model_info_snapshot(raw_model_info: Any) -> ModelInfo:
    """Restore a ModelInfo from snapshot data.

    Missing or null string fields fall back to "", mappings to {}.

    Raises:
        ValueError: If the snapshot payload is not a dict.
    """
    if not isinstance(raw_model_info, dict):
        raise ValueError("快照中的 model_info 必须是字典")

    data = raw_model_info
    return ModelInfo(
        api_provider=str(data.get("api_provider") or ""),
        extra_params=dict(data.get("extra_params") or {}),
        force_stream_mode=bool(data.get("force_stream_mode", False)),
        max_tokens=data.get("max_tokens"),
        model_identifier=str(data.get("model_identifier") or ""),
        name=str(data.get("name") or ""),
        temperature=data.get("temperature"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_response_format_snapshot(response_format: RespFormat | None) -> dict[str, Any] | None:
    """Serialize a response-format definition; None passes through unchanged."""
    return None if response_format is None else response_format.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_response_format_snapshot(raw_response_format: Any) -> RespFormat | None:
    """Restore a response-format definition from snapshot data.

    Returns None for a null payload. A non-dict "schema" is treated as absent.

    Raises:
        ValueError: If the payload is neither None nor a dict with a string
            "format_type".
    """
    if raw_response_format is None:
        return None
    if not isinstance(raw_response_format, dict):
        raise ValueError("快照中的 response_format 必须是字典")

    type_value = raw_response_format.get("format_type")
    if not isinstance(type_value, str):
        raise ValueError("快照中的 response_format 缺少 format_type")

    schema_value = raw_response_format.get("schema")
    return RespFormat(
        format_type=RespFormatType(type_value),
        schema=schema_value if isinstance(schema_value, dict) else None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_tool_options_snapshot(tool_options: Sequence[ToolOption] | None) -> list[dict[str, Any]]:
    """Serialize tool definitions; an empty or None input yields an empty list."""
    schemas: list[dict[str, Any]] = []
    for option in tool_options or ():
        schemas.append(option.to_openai_function_schema())
    return schemas
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_tool_options_snapshot(raw_tool_options: Any) -> list[ToolOption] | None:
    """Restore tool definitions from snapshot data; empty input maps to None.

    Raises:
        ValueError: If the payload is neither None nor a list.
    """
    if raw_tool_options is None or raw_tool_options == []:
        return None
    if not isinstance(raw_tool_options, list):
        raise ValueError("快照中的 tool_options 必须是列表")
    return normalize_tool_options(raw_tool_options)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_response_request_snapshot(request: ResponseRequest) -> dict[str, Any]:
    """Serialize a text/multimodal completion request into snapshot form.

    "request_kind" tags the payload so replay tooling can route it.
    """
    snapshot: dict[str, Any] = {}
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["max_tokens"] = request.max_tokens
    snapshot["message_list"] = serialize_messages_snapshot(request.message_list)
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    snapshot["request_kind"] = "response"
    snapshot["response_format"] = serialize_response_format_snapshot(request.response_format)
    snapshot["temperature"] = request.temperature
    snapshot["tool_options"] = serialize_tool_options_snapshot(request.tool_options)
    return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_embedding_request_snapshot(request: EmbeddingRequest) -> dict[str, Any]:
    """Serialize an embedding request into snapshot form."""
    snapshot: dict[str, Any] = {}
    snapshot["embedding_input"] = request.embedding_input
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    snapshot["request_kind"] = "embedding"
    return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_audio_request_snapshot(request: AudioTranscriptionRequest) -> dict[str, Any]:
    """Serialize an audio-transcription request into snapshot form."""
    snapshot: dict[str, Any] = {}
    snapshot["audio_base64"] = request.audio_base64
    snapshot["extra_params"] = _json_friendly(dict(request.extra_params))
    snapshot["max_tokens"] = request.max_tokens
    snapshot["model_info"] = serialize_model_info_snapshot(request.model_info)
    snapshot["request_kind"] = "audio_transcription"
    return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_api_provider_snapshot(api_provider: APIProvider) -> dict[str, Any]:
    """Serialize API provider settings, deliberately excluding credentials.

    The api_key (and any other secret) is never copied into the snapshot;
    only routing/auth *shape* metadata is kept so a request can be replayed
    against a freshly configured provider.
    """
    snapshot: dict[str, Any] = {
        field: getattr(api_provider, field)
        for field in (
            "auth_header_name",
            "auth_header_prefix",
            "auth_query_name",
            "auth_type",
            "base_url",
            "client_type",
        )
    }
    # Header/query maps may hold non-JSON values; normalize them first.
    snapshot["default_headers"] = _json_friendly(dict(api_provider.default_headers))
    snapshot["default_query"] = _json_friendly(dict(api_provider.default_query))
    for field in (
        "model_list_endpoint",
        "name",
        "organization",
        "project",
        "retry_interval",
        "timeout",
    ):
        snapshot[field] = getattr(api_provider, field)
    return snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def build_replay_command(snapshot_path: Path) -> str:
    """Build the shell command that replays the given snapshot file."""
    script = REPLAY_SCRIPT_RELATIVE_PATH.as_posix()
    target = snapshot_path.resolve()
    return f'uv run python {script} "{target}"'
|
||||||
|
|
||||||
|
|
||||||
|
def save_failed_request_snapshot(
    *,
    api_provider: APIProvider,
    client_type: str,
    error: Exception,
    internal_request: dict[str, Any],
    model_info: ModelInfo,
    operation: str,
    provider_request: dict[str, Any],
) -> Path | None:
    """Persist a snapshot of a failed LLM request for later replay/analysis.

    Args:
        api_provider: Provider config; serialized without credentials.
        client_type: Client implementation identifier (used in the file name).
        error: The exception that caused the failure.
        internal_request: Serialized internal request (output of one of the
            serialize_*_request_snapshot helpers; its "request_kind" is used
            in the file name).
        model_info: Model configuration used for the request.
        operation: Logical operation name recorded in the snapshot.
        provider_request: Raw payload that was sent to the provider.

    Returns:
        The resolved snapshot file path, or None if saving failed. This is
        best-effort and never raises: a snapshot failure must not mask the
        original request error.
    """
    try:
        LLM_REQUEST_LOG_DIR.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now()
        # Microsecond timestamp prefix keeps names unique and sortable.
        file_name = (
            f"{timestamp.strftime('%Y%m%d_%H%M%S_%f')}"
            f"_{_sanitize_filename_component(client_type)}"
            f"_{_sanitize_filename_component(internal_request.get('request_kind', 'request'))}"
            f"_{_sanitize_filename_component(model_info.name or model_info.model_identifier)}.json"
        )
        snapshot_path = (LLM_REQUEST_LOG_DIR / file_name).resolve()

        snapshot_payload: dict[str, Any] = {
            "api_provider": serialize_api_provider_snapshot(api_provider),
            "client_type": client_type,
            "created_at": timestamp.isoformat(timespec="seconds"),
            "error": {
                "message": str(error),
                # Present on HTTP-style client exceptions; None otherwise.
                "status_code": getattr(error, "status_code", None),
                "type": type(error).__name__,
            },
            "internal_request": internal_request,
            "model_info": serialize_model_info_snapshot(model_info),
            "operation": operation,
            "provider_request": _json_friendly(provider_request),
            "snapshot_version": SNAPSHOT_VERSION,
        }

        # Replay info is added after the path is known, since the command
        # embeds the snapshot's own location.
        snapshot_payload["replay"] = {
            "command": build_replay_command(snapshot_path),
            "file_uri": snapshot_path.as_uri(),
            "script_path": str(REPLAY_SCRIPT_PATH),
        }

        snapshot_path.write_text(
            json.dumps(snapshot_payload, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        return snapshot_path
    except Exception:
        # Fixed: this log message was mojibake (UTF-8 mis-decoded as GBK).
        logger.exception("保存 LLM 失败请求快照时发生异常")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def attach_request_snapshot(exception: Exception, snapshot_path: Path | None) -> None:
    """Attach snapshot location metadata to an exception for downstream logging.

    A None path is a no-op (no snapshot was written).
    """
    if snapshot_path is None:
        return

    resolved = snapshot_path.resolve()
    exception.request_snapshot_path = str(resolved)
    exception.request_snapshot_uri = resolved.as_uri()
    exception.request_snapshot_replay_command = build_replay_command(snapshot_path)
|
||||||
|
|
||||||
|
|
||||||
|
def has_request_snapshot(exception: Exception) -> bool:
    """Return True if the exception or its direct __cause__ carries a snapshot path."""
    return any(
        getattr(candidate, "request_snapshot_path", "")
        for candidate in (exception, getattr(exception, "__cause__", None))
        if candidate is not None
    )
|
||||||
|
|
||||||
|
|
||||||
|
def format_request_snapshot_log_info(exception: Exception) -> str:
    """Format snapshot metadata on an exception as an indented log suffix.

    The exception itself is checked first, then its direct __cause__; the
    first candidate carrying any snapshot attribute wins. Returns "" when
    no snapshot info is attached.
    """
    for candidate in (exception, getattr(exception, "__cause__", None)):
        if candidate is None:
            continue

        fields = (
            ("请求快照路径", getattr(candidate, "request_snapshot_path", "")),
            ("请求快照链接", getattr(candidate, "request_snapshot_uri", "")),
            ("重放命令", getattr(candidate, "request_snapshot_replay_command", "")),
        )
        lines = [f"{label}: {value}" for label, value in fields if value]
        if lines:
            return "\n " + "\n ".join(lines)

    return ""
|
||||||
@@ -37,6 +37,7 @@ from src.llm_models.model_client.base_client import (
|
|||||||
UsageRecord,
|
UsageRecord,
|
||||||
client_registry,
|
client_registry,
|
||||||
)
|
)
|
||||||
|
from src.llm_models.request_snapshot import format_request_snapshot_log_info
|
||||||
from src.llm_models.payload_content.message import Message, MessageBuilder
|
from src.llm_models.payload_content.message import Message, MessageBuilder
|
||||||
from src.llm_models.payload_content.resp_format import RespFormat
|
from src.llm_models.payload_content.resp_format import RespFormat
|
||||||
from src.llm_models.payload_content.tool_option import (
|
from src.llm_models.payload_content.tool_option import (
|
||||||
@@ -1008,6 +1009,16 @@ class LLMOrchestrator:
|
|||||||
Returns:
|
Returns:
|
||||||
str: 可直接拼接到日志中的底层异常描述。
|
str: 可直接拼接到日志中的底层异常描述。
|
||||||
"""
|
"""
|
||||||
|
detail_lines: List[str] = []
|
||||||
|
if e.__cause__:
|
||||||
|
detail_lines.append(f"底层异常类型: {type(e.__cause__).__name__}")
|
||||||
|
detail_lines.append(f"底层异常信息: {e.__cause__}")
|
||||||
|
|
||||||
|
snapshot_info = format_request_snapshot_log_info(e)
|
||||||
|
if detail_lines or snapshot_info:
|
||||||
|
detail_text = "\n " + "\n ".join(detail_lines) if detail_lines else ""
|
||||||
|
return f"{detail_text}{snapshot_info}"
|
||||||
|
|
||||||
if e.__cause__:
|
if e.__cause__:
|
||||||
original_error_type = type(e.__cause__).__name__
|
original_error_type = type(e.__cause__).__name__
|
||||||
original_error_msg = str(e.__cause__)
|
original_error_msg = str(e.__cause__)
|
||||||
|
|||||||
Reference in New Issue
Block a user