feat: add llm cache diagnostics
src/services/llm_cache_stats.py — 1507 changed lines (diff suppressed because it is too large)
@@ -6,6 +6,8 @@
from typing import Any, Dict, List, Tuple

import hashlib
import inspect
import json

from src.common.data_models.embedding_service_data_models import EmbeddingResult
@@ -26,6 +28,7 @@ from src.llm_models.payload_content.message import Message, MessageBuilder, Role
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMOrchestrator
from src.services.embedding_service import EmbeddingServiceClient
from src.services.llm_cache_stats import record_llm_cache_usage
from src.services.service_task_resolver import (
    get_available_models as _get_available_models,
    resolve_task_name as _resolve_task_name,
@@ -46,7 +49,7 @@ class LLMServiceClient:
    - `embed_text` (compatibility entry point; `EmbeddingServiceClient` is recommended instead)
    """

    def __init__(self, task_name: str, request_type: str = "") -> None:
    def __init__(self, task_name: str, request_type: str = "", session_id: str = "") -> None:
        """Initialize the LLM service facade.

        Args:
@@ -55,6 +58,7 @@ class LLMServiceClient:
        """
        self.task_name = _resolve_task_name(task_name)
        self.request_type = request_type
        self.session_id = str(session_id or "").strip()
        self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)

    @staticmethod
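Note: the new `session_id` argument is optional and defaults to an empty string, so existing call sites are unaffected; when supplied it is attached to every cache-usage record for that conversation. A minimal usage sketch (task name, session id, and prompt are illustrative, assuming the remaining generation options keep their defaults):

    client = LLMServiceClient(task_name="chat", session_id="session-123")
    result = await client.generate_response("Hello")  # cache hit/miss tokens for this call are recorded under "session-123"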
@@ -85,6 +89,70 @@ class LLMServiceClient:
            return LLMImageOptions()
        return options

    @staticmethod
    def _serialize_message_for_cache_stats(message: Message) -> Dict[str, Any]:
        parts: list[dict[str, Any]] = []
        for part in message.parts:
            if hasattr(part, "text"):
                parts.append({"type": "text", "text": part.text})
                continue

            image_base64 = getattr(part, "image_base64", "")
            image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
            parts.append(
                {
                    "type": "image",
                    "format": getattr(part, "image_format", ""),
                    "size": len(image_base64),
                    "sha256": image_digest,
                }
            )

        return {
            "role": str(message.role.value if hasattr(message.role, "value") else message.role),
            "parts": parts,
            "tool_call_id": message.tool_call_id,
            "tool_name": message.tool_name,
            "tool_calls": [
                {
                    "id": tool_call.call_id,
                    "name": tool_call.func_name,
                    "arguments": tool_call.args,
                    "extra_content": tool_call.extra_content,
                }
                for tool_call in (message.tool_calls or [])
            ],
        }

    @classmethod
    def _build_cache_stats_prompt_text(
        cls,
        *,
        messages: List[Message],
        tool_options: Any,
        response_format: Any,
    ) -> str:
        payload = {
            "messages": [cls._serialize_message_for_cache_stats(message) for message in messages],
            "tool_options": tool_options or [],
            "response_format": response_format,
        }
        return json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str)

    def _record_cache_stats(self, result: LLMResponseResult, prompt_text: str | None = None) -> None:
        """Record prompt cache statistics for the current call."""

        record_llm_cache_usage(
            task_name=self.task_name,
            request_type=self.request_type,
            model_name=result.model_name,
            session_id=self.session_id,
            prompt_tokens=result.prompt_tokens,
            prompt_cache_hit_tokens=result.prompt_cache_hit_tokens,
            prompt_cache_miss_tokens=result.prompt_cache_miss_tokens,
            prompt_text=prompt_text,
        )

    async def generate_response(
        self,
        prompt: str,
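For context, `_build_cache_stats_prompt_text` serializes the request deterministically (keys sorted, images reduced to size plus SHA-256 digest), so identical prompts map to identical cache-stats keys. For a single text message built by `MessageBuilder`, the resulting JSON looks roughly like this (assuming a plain user-role message with no tool calls; values are illustrative):

    {"messages": [{"parts": [{"text": "Hello", "type": "text"}], "role": "user", "tool_call_id": null, "tool_calls": [], "tool_name": null}], "response_format": null, "tool_options": []}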
@@ -100,7 +168,12 @@ class LLMServiceClient:
            LLMResponseResult: Unified text generation result.
        """
        active_options = self._normalize_generation_options(options)
        return await self._orchestrator.generate_response_async(
        prompt_text = self._build_cache_stats_prompt_text(
            messages=[MessageBuilder().add_text_content(prompt).build()],
            tool_options=active_options.tool_options,
            response_format=active_options.response_format,
        )
        result = await self._orchestrator.generate_response_async(
            prompt=prompt,
            temperature=active_options.temperature,
            max_tokens=active_options.max_tokens,
@@ -109,6 +182,8 @@ class LLMServiceClient:
            raise_when_empty=active_options.raise_when_empty,
            interrupt_flag=active_options.interrupt_flag,
        )
        self._record_cache_stats(result, prompt_text=prompt_text)
        return result

    async def generate_response_with_messages(
        self,
@@ -125,8 +200,22 @@ class LLMServiceClient:
            LLMResponseResult: Unified text generation result.
        """
        active_options = self._normalize_generation_options(options)
        return await self._orchestrator.generate_response_with_message_async(
            message_factory=message_factory,
        prompt_text_holder: dict[str, str] = {}

        def cache_stats_message_factory(client: BaseClient, model_info: Any = None) -> List[Message]:
            if len(inspect.signature(message_factory).parameters) >= 2:
                messages = message_factory(client, model_info)
            else:
                messages = message_factory(client)
            prompt_text_holder["prompt_text"] = self._build_cache_stats_prompt_text(
                messages=messages,
                tool_options=active_options.tool_options,
                response_format=active_options.response_format,
            )
            return messages

        result = await self._orchestrator.generate_response_with_message_async(
            message_factory=cache_stats_message_factory,
            temperature=active_options.temperature,
            max_tokens=active_options.max_tokens,
            tools=active_options.tool_options,
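The wrapper above keeps both one-argument and two-argument message factories working (checked via `inspect.signature`) while capturing the exact messages that were sent, so the recorded prompt text matches the real request. An existing call site such as the following keeps working unchanged (names and prompt are illustrative, assuming `message_factory` is the first positional argument):

    def build_messages(client: BaseClient) -> List[Message]:
        return [MessageBuilder().add_text_content("Summarize today's log").build()]

    result = await LLMServiceClient("log_summary").generate_response_with_messages(build_messages)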
@@ -134,6 +223,8 @@ class LLMServiceClient:
            raise_when_empty=active_options.raise_when_empty,
            interrupt_flag=active_options.interrupt_flag,
        )
        self._record_cache_stats(result, prompt_text=prompt_text_holder.get("prompt_text"))
        return result

    async def generate_response_for_image(
        self,
@@ -154,7 +245,30 @@ class LLMServiceClient:
            LLMResponseResult: Unified text generation result.
        """
        active_options = self._normalize_image_options(options)
        return await self._orchestrator.generate_response_for_image(
        image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
        prompt_text = json.dumps(
            {
                "messages": [
                    {
                        "role": "user",
                        "parts": [
                            {"type": "text", "text": prompt},
                            {
                                "type": "image",
                                "format": image_format,
                                "size": len(image_base64),
                                "sha256": image_digest,
                            },
                        ],
                    }
                ],
                "tool_options": [],
                "response_format": None,
            },
            ensure_ascii=False,
            sort_keys=True,
        )
        result = await self._orchestrator.generate_response_for_image(
            prompt=prompt,
            image_base64=image_base64,
            image_format=image_format,
@@ -162,6 +276,8 @@ class LLMServiceClient:
            max_tokens=active_options.max_tokens,
            interrupt_flag=active_options.interrupt_flag,
        )
        self._record_cache_stats(result, prompt_text=prompt_text)
        return result

    async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
        """Execute an audio transcription request.
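The 1507-line src/services/llm_cache_stats.py is suppressed above, so only its entry point is visible from the call sites in this diff. Below is a minimal sketch of the interface the client relies on: the function name and keyword arguments are taken from `_record_cache_stats`; the dataclass, in-memory list, lock, and `prompt_sha256` field are assumptions for illustration, not the committed implementation.

    # Hypothetical sketch; the real src/services/llm_cache_stats.py is not shown in this view.
    from __future__ import annotations

    import hashlib
    import threading
    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class LLMCacheUsageRecord:
        task_name: str
        request_type: str
        model_name: str
        session_id: str
        prompt_tokens: int
        prompt_cache_hit_tokens: int
        prompt_cache_miss_tokens: int
        prompt_sha256: str = ""  # assumed: digest of prompt_text rather than the raw prompt


    _records: List[LLMCacheUsageRecord] = []
    _lock = threading.Lock()


    def record_llm_cache_usage(
        *,
        task_name: str,
        request_type: str,
        model_name: str,
        session_id: str,
        prompt_tokens: int,
        prompt_cache_hit_tokens: int,
        prompt_cache_miss_tokens: int,
        prompt_text: Optional[str] = None,
    ) -> None:
        """Append one cache-usage sample; keyword names mirror the call site in LLMServiceClient."""
        digest = hashlib.sha256(prompt_text.encode("utf-8")).hexdigest() if prompt_text else ""
        with _lock:
            _records.append(
                LLMCacheUsageRecord(
                    task_name=task_name,
                    request_type=request_type,
                    model_name=model_name,
                    session_id=session_id,
                    prompt_tokens=prompt_tokens,
                    prompt_cache_hit_tokens=prompt_cache_hit_tokens,
                    prompt_cache_miss_tokens=prompt_cache_miss_tokens,
                    prompt_sha256=digest,
                )
            )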