feat: add llm cache diagnostics

Author: SengokuCola
Date: 2026-05-01 13:00:27 +08:00
Parent: a37e906862
Commit: badd4988b6
2 changed files with 1628 additions and 5 deletions

File diff suppressed because it is too large.


@@ -6,6 +6,8 @@
from typing import Any, Dict, List, Tuple
import hashlib
import inspect
import json
from src.common.data_models.embedding_service_data_models import EmbeddingResult
@@ -26,6 +28,7 @@ from src.llm_models.payload_content.message import Message, MessageBuilder, Role
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMOrchestrator
from src.services.embedding_service import EmbeddingServiceClient
from src.services.llm_cache_stats import record_llm_cache_usage
from src.services.service_task_resolver import (
get_available_models as _get_available_models,
resolve_task_name as _resolve_task_name,
@@ -46,7 +49,7 @@ class LLMServiceClient:
- `embed_text` (compatibility entry point; prefer `EmbeddingServiceClient` instead)
"""
def __init__(self, task_name: str, request_type: str = "") -> None:
def __init__(self, task_name: str, request_type: str = "", session_id: str = "") -> None:
"""初始化 LLM 服务门面。
Args:
@@ -55,6 +58,7 @@ class LLMServiceClient:
"""
self.task_name = _resolve_task_name(task_name)
self.request_type = request_type
self.session_id = str(session_id or "").strip()
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
@staticmethod
@@ -85,6 +89,70 @@ class LLMServiceClient:
return LLMImageOptions()
return options
@staticmethod
def _serialize_message_for_cache_stats(message: Message) -> Dict[str, Any]:
parts: list[dict[str, Any]] = []
for part in message.parts:
if hasattr(part, "text"):
parts.append({"type": "text", "text": part.text})
continue
image_base64 = getattr(part, "image_base64", "")
image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
parts.append(
{
"type": "image",
"format": getattr(part, "image_format", ""),
"size": len(image_base64),
"sha256": image_digest,
}
)
return {
"role": str(message.role.value if hasattr(message.role, "value") else message.role),
"parts": parts,
"tool_call_id": message.tool_call_id,
"tool_name": message.tool_name,
"tool_calls": [
{
"id": tool_call.call_id,
"name": tool_call.func_name,
"arguments": tool_call.args,
"extra_content": tool_call.extra_content,
}
for tool_call in (message.tool_calls or [])
],
}
@classmethod
def _build_cache_stats_prompt_text(
cls,
*,
messages: List[Message],
tool_options: Any,
response_format: Any,
) -> str:
payload = {
"messages": [cls._serialize_message_for_cache_stats(message) for message in messages],
"tool_options": tool_options or [],
"response_format": response_format,
}
return json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str)
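For a plain text prompt this helper produces a deterministic, key-sorted JSON string. A sketch of the output for a single user message (assuming the default `MessageBuilder` role serializes as "user" and the tool fields are unset):
{"messages": [{"parts": [{"text": "hello", "type": "text"}], "role": "user", "tool_call_id": null, "tool_calls": [], "tool_name": null}], "response_format": null, "tool_options": []}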
def _record_cache_stats(self, result: LLMResponseResult, prompt_text: str | None = None) -> None:
"""记录当前调用的 prompt cache 统计。"""
record_llm_cache_usage(
task_name=self.task_name,
request_type=self.request_type,
model_name=result.model_name,
session_id=self.session_id,
prompt_tokens=result.prompt_tokens,
prompt_cache_hit_tokens=result.prompt_cache_hit_tokens,
prompt_cache_miss_tokens=result.prompt_cache_miss_tokens,
prompt_text=prompt_text,
)
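The keyword arguments above pin down the collector's interface. The real `record_llm_cache_usage` lives in `src/services/llm_cache_stats` (presumably the suppressed file diff above), so the following is only a sketch of the signature implied by this call site:
def record_llm_cache_usage(
    *,
    task_name: str,
    request_type: str,
    model_name: str,
    session_id: str,
    prompt_tokens: int,
    prompt_cache_hit_tokens: int,
    prompt_cache_miss_tokens: int,
    prompt_text: str | None = None,
) -> None:
    # Hypothetical body: aggregate hit/miss token counts keyed by
    # (task_name, request_type, model_name, session_id); prompt_text lets the
    # collector fingerprint identical prompts when diagnosing cache misses.
    ...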
async def generate_response(
self,
prompt: str,
@@ -100,7 +168,12 @@ class LLMServiceClient:
LLMResponseResult: Unified text generation result.
"""
active_options = self._normalize_generation_options(options)
return await self._orchestrator.generate_response_async(
prompt_text = self._build_cache_stats_prompt_text(
messages=[MessageBuilder().add_text_content(prompt).build()],
tool_options=active_options.tool_options,
response_format=active_options.response_format,
)
result = await self._orchestrator.generate_response_async(
prompt=prompt,
temperature=active_options.temperature,
max_tokens=active_options.max_tokens,
@@ -109,6 +182,8 @@ class LLMServiceClient:
raise_when_empty=active_options.raise_when_empty,
interrupt_flag=active_options.interrupt_flag,
)
self._record_cache_stats(result, prompt_text=prompt_text)
return result
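From the caller's side the only new surface is the optional `session_id`. A minimal usage sketch (the task name and session id below are invented for illustration):
client = LLMServiceClient(task_name="replyer", session_id="group-12345")
result = await client.generate_response("Summarize the conversation so far.")
# prompt_tokens / prompt_cache_hit_tokens / prompt_cache_miss_tokens from the
# result are forwarded to record_llm_cache_usage after every call.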
async def generate_response_with_messages(
self,
@@ -125,8 +200,22 @@ class LLMServiceClient:
LLMResponseResult: Unified text generation result.
"""
active_options = self._normalize_generation_options(options)
return await self._orchestrator.generate_response_with_message_async(
message_factory=message_factory,
prompt_text_holder: dict[str, str] = {}
def cache_stats_message_factory(client: BaseClient, model_info: Any = None) -> List[Message]:
if len(inspect.signature(message_factory).parameters) >= 2:
messages = message_factory(client, model_info)
else:
messages = message_factory(client)
prompt_text_holder["prompt_text"] = self._build_cache_stats_prompt_text(
messages=messages,
tool_options=active_options.tool_options,
response_format=active_options.response_format,
)
return messages
result = await self._orchestrator.generate_response_with_message_async(
message_factory=cache_stats_message_factory,
temperature=active_options.temperature,
max_tokens=active_options.max_tokens,
tools=active_options.tool_options,
@@ -134,6 +223,8 @@ class LLMServiceClient:
raise_when_empty=active_options.raise_when_empty,
interrupt_flag=active_options.interrupt_flag,
)
self._record_cache_stats(result, prompt_text=prompt_text_holder.get("prompt_text"))
return result
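The `inspect.signature` check above keeps both existing factory shapes working unchanged. A sketch of the two accepted forms (factory names are illustrative):
def one_arg_factory(client: BaseClient) -> List[Message]:
    return [MessageBuilder().add_text_content("hi").build()]

def two_arg_factory(client: BaseClient, model_info: Any) -> List[Message]:
    # model_info is forwarded by the wrapper only when the factory declares a second parameter.
    return [MessageBuilder().add_text_content("hi").build()]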
async def generate_response_for_image(
self,
@@ -154,7 +245,30 @@ class LLMServiceClient:
LLMResponseResult: Unified text generation result.
"""
active_options = self._normalize_image_options(options)
return await self._orchestrator.generate_response_for_image(
image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
prompt_text = json.dumps(
{
"messages": [
{
"role": "user",
"parts": [
{"type": "text", "text": prompt},
{
"type": "image",
"format": image_format,
"size": len(image_base64),
"sha256": image_digest,
},
],
}
],
"tool_options": [],
"response_format": None,
},
ensure_ascii=False,
sort_keys=True,
)
result = await self._orchestrator.generate_response_for_image(
prompt=prompt,
image_base64=image_base64,
image_format=image_format,
@@ -162,6 +276,8 @@ class LLMServiceClient:
max_tokens=active_options.max_tokens,
interrupt_flag=active_options.interrupt_flag,
)
self._record_cache_stats(result, prompt_text=prompt_text)
return result
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
"""执行音频转写请求。