feat: add llm cache diagnostics
This commit is contained in:
1507
src/services/llm_cache_stats.py
Normal file
1507
src/services/llm_cache_stats.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,8 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
from src.common.data_models.embedding_service_data_models import EmbeddingResult
|
||||||
@@ -26,6 +28,7 @@ from src.llm_models.payload_content.message import Message, MessageBuilder, Role
|
|||||||
from src.llm_models.payload_content.tool_option import ToolCall
|
from src.llm_models.payload_content.tool_option import ToolCall
|
||||||
from src.llm_models.utils_model import LLMOrchestrator
|
from src.llm_models.utils_model import LLMOrchestrator
|
||||||
from src.services.embedding_service import EmbeddingServiceClient
|
from src.services.embedding_service import EmbeddingServiceClient
|
||||||
|
from src.services.llm_cache_stats import record_llm_cache_usage
|
||||||
from src.services.service_task_resolver import (
|
from src.services.service_task_resolver import (
|
||||||
get_available_models as _get_available_models,
|
get_available_models as _get_available_models,
|
||||||
resolve_task_name as _resolve_task_name,
|
resolve_task_name as _resolve_task_name,
|
||||||
@@ -46,7 +49,7 @@ class LLMServiceClient:
|
|||||||
- `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`)
|
- `embed_text`(兼容入口,推荐改用 `EmbeddingServiceClient`)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, task_name: str, request_type: str = "") -> None:
|
def __init__(self, task_name: str, request_type: str = "", session_id: str = "") -> None:
|
||||||
"""初始化 LLM 服务门面。
|
"""初始化 LLM 服务门面。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -55,6 +58,7 @@ class LLMServiceClient:
|
|||||||
"""
|
"""
|
||||||
self.task_name = _resolve_task_name(task_name)
|
self.task_name = _resolve_task_name(task_name)
|
||||||
self.request_type = request_type
|
self.request_type = request_type
|
||||||
|
self.session_id = str(session_id or "").strip()
|
||||||
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
self._orchestrator = LLMOrchestrator(task_name=self.task_name, request_type=request_type)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -85,6 +89,70 @@ class LLMServiceClient:
|
|||||||
return LLMImageOptions()
|
return LLMImageOptions()
|
||||||
return options
|
return options
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_message_for_cache_stats(message: Message) -> Dict[str, Any]:
|
||||||
|
parts: list[dict[str, Any]] = []
|
||||||
|
for part in message.parts:
|
||||||
|
if hasattr(part, "text"):
|
||||||
|
parts.append({"type": "text", "text": part.text})
|
||||||
|
continue
|
||||||
|
|
||||||
|
image_base64 = getattr(part, "image_base64", "")
|
||||||
|
image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
|
||||||
|
parts.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"format": getattr(part, "image_format", ""),
|
||||||
|
"size": len(image_base64),
|
||||||
|
"sha256": image_digest,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"role": str(message.role.value if hasattr(message.role, "value") else message.role),
|
||||||
|
"parts": parts,
|
||||||
|
"tool_call_id": message.tool_call_id,
|
||||||
|
"tool_name": message.tool_name,
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": tool_call.call_id,
|
||||||
|
"name": tool_call.func_name,
|
||||||
|
"arguments": tool_call.args,
|
||||||
|
"extra_content": tool_call.extra_content,
|
||||||
|
}
|
||||||
|
for tool_call in (message.tool_calls or [])
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _build_cache_stats_prompt_text(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
messages: List[Message],
|
||||||
|
tool_options: Any,
|
||||||
|
response_format: Any,
|
||||||
|
) -> str:
|
||||||
|
payload = {
|
||||||
|
"messages": [cls._serialize_message_for_cache_stats(message) for message in messages],
|
||||||
|
"tool_options": tool_options or [],
|
||||||
|
"response_format": response_format,
|
||||||
|
}
|
||||||
|
return json.dumps(payload, ensure_ascii=False, sort_keys=True, default=str)
|
||||||
|
|
||||||
|
def _record_cache_stats(self, result: LLMResponseResult, prompt_text: str | None = None) -> None:
|
||||||
|
"""记录当前调用的 prompt cache 统计。"""
|
||||||
|
|
||||||
|
record_llm_cache_usage(
|
||||||
|
task_name=self.task_name,
|
||||||
|
request_type=self.request_type,
|
||||||
|
model_name=result.model_name,
|
||||||
|
session_id=self.session_id,
|
||||||
|
prompt_tokens=result.prompt_tokens,
|
||||||
|
prompt_cache_hit_tokens=result.prompt_cache_hit_tokens,
|
||||||
|
prompt_cache_miss_tokens=result.prompt_cache_miss_tokens,
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
)
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -100,7 +168,12 @@ class LLMServiceClient:
|
|||||||
LLMResponseResult: 统一文本生成结果。
|
LLMResponseResult: 统一文本生成结果。
|
||||||
"""
|
"""
|
||||||
active_options = self._normalize_generation_options(options)
|
active_options = self._normalize_generation_options(options)
|
||||||
return await self._orchestrator.generate_response_async(
|
prompt_text = self._build_cache_stats_prompt_text(
|
||||||
|
messages=[MessageBuilder().add_text_content(prompt).build()],
|
||||||
|
tool_options=active_options.tool_options,
|
||||||
|
response_format=active_options.response_format,
|
||||||
|
)
|
||||||
|
result = await self._orchestrator.generate_response_async(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
temperature=active_options.temperature,
|
temperature=active_options.temperature,
|
||||||
max_tokens=active_options.max_tokens,
|
max_tokens=active_options.max_tokens,
|
||||||
@@ -109,6 +182,8 @@ class LLMServiceClient:
|
|||||||
raise_when_empty=active_options.raise_when_empty,
|
raise_when_empty=active_options.raise_when_empty,
|
||||||
interrupt_flag=active_options.interrupt_flag,
|
interrupt_flag=active_options.interrupt_flag,
|
||||||
)
|
)
|
||||||
|
self._record_cache_stats(result, prompt_text=prompt_text)
|
||||||
|
return result
|
||||||
|
|
||||||
async def generate_response_with_messages(
|
async def generate_response_with_messages(
|
||||||
self,
|
self,
|
||||||
@@ -125,8 +200,22 @@ class LLMServiceClient:
|
|||||||
LLMResponseResult: 统一文本生成结果。
|
LLMResponseResult: 统一文本生成结果。
|
||||||
"""
|
"""
|
||||||
active_options = self._normalize_generation_options(options)
|
active_options = self._normalize_generation_options(options)
|
||||||
return await self._orchestrator.generate_response_with_message_async(
|
prompt_text_holder: dict[str, str] = {}
|
||||||
message_factory=message_factory,
|
|
||||||
|
def cache_stats_message_factory(client: BaseClient, model_info: Any = None) -> List[Message]:
|
||||||
|
if len(inspect.signature(message_factory).parameters) >= 2:
|
||||||
|
messages = message_factory(client, model_info)
|
||||||
|
else:
|
||||||
|
messages = message_factory(client)
|
||||||
|
prompt_text_holder["prompt_text"] = self._build_cache_stats_prompt_text(
|
||||||
|
messages=messages,
|
||||||
|
tool_options=active_options.tool_options,
|
||||||
|
response_format=active_options.response_format,
|
||||||
|
)
|
||||||
|
return messages
|
||||||
|
|
||||||
|
result = await self._orchestrator.generate_response_with_message_async(
|
||||||
|
message_factory=cache_stats_message_factory,
|
||||||
temperature=active_options.temperature,
|
temperature=active_options.temperature,
|
||||||
max_tokens=active_options.max_tokens,
|
max_tokens=active_options.max_tokens,
|
||||||
tools=active_options.tool_options,
|
tools=active_options.tool_options,
|
||||||
@@ -134,6 +223,8 @@ class LLMServiceClient:
|
|||||||
raise_when_empty=active_options.raise_when_empty,
|
raise_when_empty=active_options.raise_when_empty,
|
||||||
interrupt_flag=active_options.interrupt_flag,
|
interrupt_flag=active_options.interrupt_flag,
|
||||||
)
|
)
|
||||||
|
self._record_cache_stats(result, prompt_text=prompt_text_holder.get("prompt_text"))
|
||||||
|
return result
|
||||||
|
|
||||||
async def generate_response_for_image(
|
async def generate_response_for_image(
|
||||||
self,
|
self,
|
||||||
@@ -154,7 +245,30 @@ class LLMServiceClient:
|
|||||||
LLMResponseResult: 统一文本生成结果。
|
LLMResponseResult: 统一文本生成结果。
|
||||||
"""
|
"""
|
||||||
active_options = self._normalize_image_options(options)
|
active_options = self._normalize_image_options(options)
|
||||||
return await self._orchestrator.generate_response_for_image(
|
image_digest = hashlib.sha256(image_base64.encode("utf-8")).hexdigest() if image_base64 else ""
|
||||||
|
prompt_text = json.dumps(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"parts": [
|
||||||
|
{"type": "text", "text": prompt},
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"format": image_format,
|
||||||
|
"size": len(image_base64),
|
||||||
|
"sha256": image_digest,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"tool_options": [],
|
||||||
|
"response_format": None,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
sort_keys=True,
|
||||||
|
)
|
||||||
|
result = await self._orchestrator.generate_response_for_image(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
image_base64=image_base64,
|
image_base64=image_base64,
|
||||||
image_format=image_format,
|
image_format=image_format,
|
||||||
@@ -162,6 +276,8 @@ class LLMServiceClient:
|
|||||||
max_tokens=active_options.max_tokens,
|
max_tokens=active_options.max_tokens,
|
||||||
interrupt_flag=active_options.interrupt_flag,
|
interrupt_flag=active_options.interrupt_flag,
|
||||||
)
|
)
|
||||||
|
self._record_cache_stats(result, prompt_text=prompt_text)
|
||||||
|
return result
|
||||||
|
|
||||||
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
|
async def transcribe_audio(self, voice_base64: str) -> LLMAudioTranscriptionResult:
|
||||||
"""执行音频转写请求。
|
"""执行音频转写请求。
|
||||||
|
|||||||
Reference in New Issue
Block a user