Files
mai-bot/scripts/replay_llm_request.py
2026-04-07 15:16:06 +08:00

147 lines
5.7 KiB
Python

# ruff: noqa: E402
import argparse
import asyncio
import json
import sys
from pathlib import Path
from typing import Any
PROJECT_ROOT = Path(__file__).resolve().parent.parent
SRC_ROOT = PROJECT_ROOT / "src"

# Make both ``src/`` and the repository root importable before the project
# imports below run. Each root is inserted only if it is not already on
# sys.path: ``src/`` goes to the front, the repository root to index 1.
for _insert_at, _import_root in ((0, SRC_ROOT), (1, PROJECT_ROOT)):
    if str(_import_root) not in sys.path:
        sys.path.insert(_insert_at, str(_import_root))
# Drop the loop temporaries so the module namespace matches the plain version.
del _insert_at, _import_root
from src.config.config import config_manager
from src.llm_models.model_client.base_client import AudioTranscriptionRequest, ResponseRequest, client_registry
from src.llm_models.model_client.base_client import EmbeddingRequest
from src.llm_models.request_snapshot import (
deserialize_messages_snapshot,
deserialize_model_info_snapshot,
deserialize_response_format_snapshot,
deserialize_tool_options_snapshot,
)
def _load_snapshot(snapshot_path: Path) -> dict[str, Any]:
"""加载请求快照。"""
return json.loads(snapshot_path.read_text(encoding="utf-8"))
def _resolve_api_provider(provider_name: str):
    """Return the first API provider in the current model config named ``provider_name``.

    Raises:
        ValueError: If no configured API provider has that name.
    """
    providers = config_manager.get_model_config().api_providers
    matched = next(
        (provider for provider in providers if provider.name == provider_name),
        None,
    )
    if matched is None:
        raise ValueError(f"当前配置中不存在名为 {provider_name!r} 的 API Provider")
    return matched
def _build_response_request(snapshot: dict[str, Any]) -> ResponseRequest:
    """Rebuild a ``ResponseRequest`` from the ``internal_request`` part of a snapshot.

    Missing or null ``extra_params`` / ``message_list`` / ``model_info`` fall back
    to empty containers; the remaining fields are passed through as-is (``None``
    when absent) or via their respective deserializers.
    """
    extra_params = dict(snapshot.get("extra_params") or {})
    max_tokens = snapshot.get("max_tokens")
    messages = deserialize_messages_snapshot(snapshot.get("message_list") or [])
    model_info = deserialize_model_info_snapshot(snapshot.get("model_info") or {})
    response_format = deserialize_response_format_snapshot(snapshot.get("response_format"))
    temperature = snapshot.get("temperature")
    tool_options = deserialize_tool_options_snapshot(snapshot.get("tool_options"))
    return ResponseRequest(
        extra_params=extra_params,
        max_tokens=max_tokens,
        message_list=messages,
        model_info=model_info,
        response_format=response_format,
        temperature=temperature,
        tool_options=tool_options,
    )
def _build_embedding_request(snapshot: dict[str, Any]) -> EmbeddingRequest:
    """Rebuild an ``EmbeddingRequest`` from the ``internal_request`` part of a snapshot.

    A missing or falsy ``embedding_input`` becomes ``""``; missing ``extra_params``
    and ``model_info`` fall back to empty dicts.
    """
    text_input = str(snapshot.get("embedding_input") or "")
    extra_params = dict(snapshot.get("extra_params") or {})
    model_info = deserialize_model_info_snapshot(snapshot.get("model_info") or {})
    return EmbeddingRequest(
        embedding_input=text_input,
        extra_params=extra_params,
        model_info=model_info,
    )
def _build_audio_request(snapshot: dict[str, Any]) -> AudioTranscriptionRequest:
    """Rebuild an ``AudioTranscriptionRequest`` from the ``internal_request`` part of a snapshot.

    A missing or falsy ``audio_base64`` becomes ``""``; ``max_tokens`` is passed
    through (``None`` when absent); missing ``extra_params`` and ``model_info``
    fall back to empty dicts.
    """
    encoded_audio = str(snapshot.get("audio_base64") or "")
    extra_params = dict(snapshot.get("extra_params") or {})
    token_limit = snapshot.get("max_tokens")
    model_info = deserialize_model_info_snapshot(snapshot.get("model_info") or {})
    return AudioTranscriptionRequest(
        audio_base64=encoded_audio,
        extra_params=extra_params,
        max_tokens=token_limit,
        model_info=model_info,
    )
async def _dispatch_request(client, internal_request: dict[str, Any]):
    """Send ``internal_request`` through the client method matching its ``request_kind``.

    Raises:
        ValueError: If ``request_kind`` is not one of ``response``, ``embedding``
            or ``audio_transcription``.
    """
    request_kind = str(internal_request.get("request_kind") or "").strip()
    if request_kind == "response":
        return await client.get_response(_build_response_request(internal_request))
    if request_kind == "embedding":
        return await client.get_embedding(_build_embedding_request(internal_request))
    if request_kind == "audio_transcription":
        return await client.get_audio_transcriptions(_build_audio_request(internal_request))
    raise ValueError(f"不支持的 request_kind: {request_kind!r}")


def _summarize_response(response) -> dict[str, Any]:
    """Build the summary dict printed for a replayed response.

    Key order matches the printed JSON. Usage-derived fields are ``None`` when
    the response carries no ``usage``.
    """
    usage = response.usage
    return {
        "content": response.content,
        "embedding_length": len(response.embedding or []),
        "has_embedding": response.embedding is not None,
        "model_name": usage.model_name if usage is not None else None,
        "provider_name": usage.provider_name if usage is not None else None,
        "raw_data_type": type(response.raw_data).__name__ if response.raw_data is not None else None,
        "reasoning_content": response.reasoning_content,
        "tool_calls": [
            {
                "args": tool_call.args,
                "call_id": tool_call.call_id,
                "func_name": tool_call.func_name,
            }
            for tool_call in (response.tool_calls or [])
        ],
        "usage": {
            "completion_tokens": usage.completion_tokens,
            "prompt_tokens": usage.prompt_tokens,
            "total_tokens": usage.total_tokens,
        }
        if usage is not None
        else None,
    }


async def _replay(snapshot_path: Path) -> int:
    """Replay one failed-request snapshot and print a JSON summary of the response.

    Initializes the config manager, loads the snapshot, and resolves the API
    provider named in it from the current configuration. It then re-issues the
    recorded request with a freshly created client and prints the response
    summary to stdout.

    Args:
        snapshot_path: Path to the snapshot JSON file.

    Returns:
        Exit code ``0`` once the summary has been printed.

    Raises:
        ValueError: If ``internal_request`` / ``api_provider`` are missing or not
            objects, ``api_provider.name`` is empty or unknown, or the
            ``request_kind`` is unsupported.
    """
    config_manager.initialize()
    snapshot = _load_snapshot(snapshot_path)

    internal_request = snapshot.get("internal_request")
    if not isinstance(internal_request, dict):
        raise ValueError("快照缺少 internal_request 字段")
    provider_snapshot = snapshot.get("api_provider")
    if not isinstance(provider_snapshot, dict):
        raise ValueError("快照缺少 api_provider 字段")
    provider_name = str(provider_snapshot.get("name") or "")
    if not provider_name:
        raise ValueError("快照中的 api_provider.name 不能为空")

    api_provider = _resolve_api_provider(provider_name)
    # NOTE(review): force_new=True presumably bypasses any cached client instance
    # in the registry — confirm against client_registry.
    client = client_registry.get_client_class_instance(api_provider, force_new=True)
    response = await _dispatch_request(client, internal_request)

    print(json.dumps(_summarize_response(response), ensure_ascii=False, indent=2))
    return 0
def main(argv: "list[str] | None" = None) -> int:
    """Command-line entry point: replay the snapshot file given as the sole argument.

    Args:
        argv: Arguments to parse (without the program name). ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        The exit code produced by the replay.

    Raises:
        FileNotFoundError: If the snapshot path does not refer to a regular file.
    """
    parser = argparse.ArgumentParser(description="回放失败的 LLM 请求快照。")
    parser.add_argument("snapshot_path", help="请求快照 JSON 文件路径")
    args = parser.parse_args(argv)
    snapshot_path = Path(args.snapshot_path).expanduser().resolve()
    # is_file() rather than exists(): a directory would otherwise pass this check
    # and only fail later (after config initialization) with IsADirectoryError.
    if not snapshot_path.is_file():
        raise FileNotFoundError(f"快照文件不存在: {snapshot_path}")
    return asyncio.run(_replay(snapshot_path))
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())