feat: support model caching and related configuration

SengokuCola
2026-04-25 13:53:30 +08:00
parent 4b1bc2aba8
commit 9759018a0c
11 changed files with 195 additions and 5 deletions

View File

@@ -1,7 +1,9 @@
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
import json
import time
from rich.console import Group, RenderableType
@@ -48,6 +50,8 @@ from .maisaka_expression_selector import maisaka_expression_selector
logger = get_logger("replyer")
DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
@dataclass
class MaisakaReplyContext:
@@ -404,6 +408,35 @@ class BaseMaisakaReplyGenerator:
return self.chat_stream.session_id
return ""
@staticmethod
def _build_debug_request_filename(stream_id: str, model_name: str) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
raw_name = f"{timestamp}_{stream_id or 'unknown'}_{model_name or 'unknown'}.json"
return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
def _save_debug_reply_request_body(
self,
*,
stream_id: str,
model_name: str,
messages: List[Message],
) -> None:
try:
DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
request_body = {
"model": model_name,
"request_type": self.request_type,
"stream_id": stream_id,
"created_at": datetime.now().isoformat(timespec="seconds"),
"messages": serialize_prompt_messages(messages),
}
file_path = DEBUG_REPLY_CACHE_DIR / self._build_debug_request_filename(stream_id, model_name)
with file_path.open("w", encoding="utf-8") as file:
json.dump(request_body, file, ensure_ascii=False, indent=2)
logger.info(f"Replyer 请求体已保存: {file_path.resolve()}")
except Exception as exc:
logger.warning(f"保存 Replyer 请求体失败: {exc}")
async def _build_reply_context(
self,
chat_history: List[LLMContextMessage],
@@ -590,6 +623,11 @@ class BaseMaisakaReplyGenerator:
result.completion.request_prompt = prompt_preview
result.request_messages = serialize_prompt_messages(request_messages)
self._save_debug_reply_request_body(
stream_id=preview_chat_id,
model_name=generation_result.model_name or "",
messages=request_messages,
)
llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
response_text = (generation_result.response or "").strip()
result.success = bool(response_text)
@@ -612,6 +650,26 @@ class BaseMaisakaReplyGenerator:
f"llm: {llm_ms} ms",
],
)
prompt_cache_hit_tokens = getattr(generation_result, "prompt_cache_hit_tokens", 0) or 0
prompt_cache_miss_tokens = getattr(generation_result, "prompt_cache_miss_tokens", 0) or 0
if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
prompt_cache_miss_tokens = max(generation_result.prompt_tokens - prompt_cache_hit_tokens, 0)
prompt_cache_total_tokens = prompt_cache_hit_tokens + prompt_cache_miss_tokens
prompt_cache_hit_rate = (
prompt_cache_hit_tokens / prompt_cache_total_tokens * 100
if prompt_cache_total_tokens > 0
else 0
)
result.metrics.extra["prompt_cache_hit_tokens"] = prompt_cache_hit_tokens
result.metrics.extra["prompt_cache_miss_tokens"] = prompt_cache_miss_tokens
result.metrics.extra["prompt_cache_hit_rate"] = round(prompt_cache_hit_rate, 2)
logger.info(
"Replyer KV cache usage - "
f"hit_tokens={prompt_cache_hit_tokens}, "
f"miss_tokens={prompt_cache_miss_tokens}, "
f"hit_rate={prompt_cache_hit_rate:.2f}%, "
f"prompt_tokens={generation_result.prompt_tokens}"
)
if show_replyer_reasoning and result.completion.reasoning_text:
logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}")

View File

@@ -68,6 +68,8 @@ class LLMResponseResult(BaseDataModel):
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
prompt_cache_hit_tokens: int = 0
prompt_cache_miss_tokens: int = 0
@dataclass(slots=True)
@@ -125,6 +127,8 @@ class LLMServiceResult(BaseDataModel):
"prompt_tokens": self.completion.prompt_tokens,
"completion_tokens": self.completion.completion_tokens,
"total_tokens": self.completion.total_tokens,
"prompt_cache_hit_tokens": self.completion.prompt_cache_hit_tokens,
"prompt_cache_miss_tokens": self.completion.prompt_cache_miss_tokens,
}
if self.completion.tool_calls is not None:
payload["tool_calls"] = [

View File

@@ -32,6 +32,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
"remote": ("#6c6c6c", None, False), # 深灰色,更不显眼
"planner": ("#008080", None, False),
"maisaka_reasoning_engine": ("#008080", None, False),
"maisaka_chat_loop": ("#0087ff", None, False),
"maisaka_runtime": ("#ff5fff", None, False),
"relation": ("#af87af", None, False), # 柔和的紫色,不刺眼
# 聊天相关模块

View File

@@ -56,7 +56,7 @@ MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute(
LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
MMC_VERSION: str = "1.0.0"
CONFIG_VERSION: str = "8.9.11"
MODEL_CONFIG_VERSION: str = "1.14.1"
MODEL_CONFIG_VERSION: str = "1.14.2"
logger = get_logger("config")

View File

@@ -269,6 +269,26 @@ class ModelInfo(ConfigBase):
)
"""输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
cache: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "database",
},
)
"""是否启用模型输入缓存计费。开启后命中缓存的输入 token 使用 cache_price_in 计费。"""
cache_price_in: float = Field(
default=0.0,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "database-zap",
"step": 0.001,
},
)
"""缓存命中输入价格 (用于API调用统计, 单位:元/ M token)。仅当 cache=true 时使用。"""
price_out: float = Field(
default=0.0,
ge=0,

View File

@@ -163,6 +163,8 @@ class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
prompt_tokens=usage_record[0],
completion_tokens=usage_record[1],
total_tokens=usage_record[2],
prompt_cache_hit_tokens=usage_record[3] if len(usage_record) > 3 else 0,
prompt_cache_miss_tokens=usage_record[4] if len(usage_record) > 4 else 0,
)
def _attach_usage_record(

View File

@@ -35,6 +35,12 @@ class UsageRecord:
total_tokens: int
"""总token数"""
prompt_cache_hit_tokens: int = 0
"""输入中缓存命中的 token 数"""
prompt_cache_miss_tokens: int = 0
"""输入中缓存未命中的 token 数"""
@dataclass
class APIResponse:
@@ -61,8 +67,8 @@ class APIResponse:
"""响应原始数据"""
UsageTuple = Tuple[int, int, int]
"""统一的使用量元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""
UsageTuple = Tuple[int, ...]
"""统一的使用量元组,顺序为 `(prompt_tokens, completion_tokens, total_tokens, prompt_cache_hit_tokens, prompt_cache_miss_tokens)`。"""
StreamResponseHandler = Callable[
[Any, asyncio.Event | None],
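Because the widened UsageTuple may carry either the legacy three fields or the new five, consumers index into it defensively, as the AdapterClient change earlier in this commit does. A minimal sketch of that unpacking, with a hypothetical unpack_usage helper and made-up numbers:

# Illustrative only: defensive unpacking of a UsageTuple with 3 or 5 elements.
from typing import Dict, Tuple

UsageTuple = Tuple[int, ...]

def unpack_usage(record: UsageTuple) -> Dict[str, int]:
    return {
        "prompt_tokens": record[0],
        "completion_tokens": record[1],
        "total_tokens": record[2],
        # Older adapters still return three elements; default the cache fields to 0.
        "prompt_cache_hit_tokens": record[3] if len(record) > 3 else 0,
        "prompt_cache_miss_tokens": record[4] if len(record) > 4 else 0,
    }

print(unpack_usage((120, 30, 150)))           # legacy 3-tuple, cache fields default to 0
print(unpack_usage((120, 30, 150, 100, 20)))  # new 5-tuple with cache hit/miss tokens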

View File

@@ -5,6 +5,8 @@ import io
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
from uuid import uuid4
@@ -71,6 +73,8 @@ from ..request_snapshot import (
logger = get_logger("llm_models")
DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
"""OpenAI 兼容图片输入稳定支持的格式集合。"""
@@ -120,6 +124,26 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
"""OpenAI 非流式响应解析函数类型。"""
def _build_debug_provider_request_filename(model_name: str) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
raw_name = f"provider_{timestamp}_{model_name or 'unknown'}.json"
return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
def _save_debug_provider_request_payload(model_name: str, request_payload: Dict[str, Any]) -> None:
if model_name != "deepseek-v4p":
return
try:
DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = DEBUG_REPLY_CACHE_DIR / _build_debug_provider_request_filename(model_name)
with file_path.open("w", encoding="utf-8") as file:
json.dump(request_payload, file, ensure_ascii=False, indent=2)
logger.info(f"DeepSeek provider 请求体已保存: {file_path.resolve()}")
except Exception as exc:
logger.warning(f"保存 DeepSeek provider 请求体失败: {exc}")
def _build_fallback_tool_call_id(prefix: str) -> str:
"""为缺失原始调用 ID 的工具调用生成唯一兜底标识。"""
@@ -492,10 +516,20 @@ def _extract_usage_record(usage: Any) -> UsageTuple | None:
"""
if usage is None:
return None
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
prompt_cache_hit_tokens = getattr(usage, "prompt_cache_hit_tokens", 0) or 0
prompt_cache_miss_tokens = getattr(usage, "prompt_cache_miss_tokens", 0) or 0
prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
if prompt_cache_hit_tokens == 0 and prompt_tokens_details is not None:
prompt_cache_hit_tokens = getattr(prompt_tokens_details, "cached_tokens", 0) or 0
if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
return (
getattr(usage, "prompt_tokens", 0) or 0,
prompt_tokens,
getattr(usage, "completion_tokens", 0) or 0,
getattr(usage, "total_tokens", 0) or 0,
prompt_cache_hit_tokens,
prompt_cache_miss_tokens,
)
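The extraction above handles two reporting styles: usage objects that expose prompt_cache_hit_tokens and prompt_cache_miss_tokens directly, as DeepSeek's API does, and OpenAI-style usage that only nests cached_tokens under prompt_tokens_details. A rough standalone mirror of that normalization, with invented numbers and a hypothetical extract_usage helper:

# Illustrative only: normalizing two provider usage shapes into one 5-element tuple.
from types import SimpleNamespace
from typing import Tuple

def extract_usage(usage) -> Tuple[int, ...]:
    prompt = getattr(usage, "prompt_tokens", 0) or 0
    hit = getattr(usage, "prompt_cache_hit_tokens", 0) or 0
    miss = getattr(usage, "prompt_cache_miss_tokens", 0) or 0
    details = getattr(usage, "prompt_tokens_details", None)
    if hit == 0 and details is not None:
        hit = getattr(details, "cached_tokens", 0) or 0  # OpenAI-style nested counter
    if miss == 0 and hit > 0:
        miss = max(prompt - hit, 0)  # derive misses when the provider omits them
    completion = getattr(usage, "completion_tokens", 0) or 0
    total = getattr(usage, "total_tokens", 0) or 0
    return (prompt, completion, total, hit, miss)

deepseek_like = SimpleNamespace(prompt_tokens=1000, completion_tokens=200, total_tokens=1200,
                                prompt_cache_hit_tokens=600, prompt_cache_miss_tokens=400)
openai_like = SimpleNamespace(prompt_tokens=1000, completion_tokens=200, total_tokens=1200,
                              prompt_tokens_details=SimpleNamespace(cached_tokens=600))
print(extract_usage(deepseek_like))  # (1000, 200, 1200, 600, 400)
print(extract_usage(openai_like))    # (1000, 200, 1200, 600, 400)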
@@ -1147,6 +1181,17 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
"temperature": _snapshot_openai_argument(temperature_argument),
"tools": tools_payload,
}
_save_debug_provider_request_payload(
model_info.name,
{
"base_url": self.api_provider.base_url,
"endpoint": "/chat/completions",
"model_name": model_info.name,
"model_identifier": model_info.model_identifier,
"created_at": datetime.now().isoformat(timespec="seconds"),
"request_kwargs": snapshot_provider_request["request_kwargs"],
},
)
if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(

View File

@@ -168,6 +168,25 @@ class LLMUsageRecorder:
def __init__(self):
pass
@staticmethod
def _calculate_input_cost(model_info: ModelInfo, model_usage: UsageRecord) -> float:
"""根据模型缓存配置计算输入 token 费用。"""
prompt_tokens = model_usage.prompt_tokens or 0
if not model_info.cache:
return (prompt_tokens / 1000000) * model_info.price_in
cache_hit_tokens = model_usage.prompt_cache_hit_tokens or 0
cache_miss_tokens = model_usage.prompt_cache_miss_tokens or 0
if cache_miss_tokens == 0 and cache_hit_tokens > 0:
cache_miss_tokens = max(prompt_tokens - cache_hit_tokens, 0)
if cache_hit_tokens + cache_miss_tokens == 0:
cache_miss_tokens = prompt_tokens
cached_cost = (cache_hit_tokens / 1000000) * model_info.cache_price_in
uncached_cost = (cache_miss_tokens / 1000000) * model_info.price_in
return cached_cost + uncached_cost
def record_usage_to_database(
self,
model_info: ModelInfo,
@@ -177,7 +196,7 @@ class LLMUsageRecorder:
endpoint: str,
time_cost: float = 0.0,
):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
input_cost = self._calculate_input_cost(model_info, model_usage)
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
total_cost = round(input_cost + output_cost, 6)
try:
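As a sanity check on the new billing split, a worked example with invented prices and token counts (not taken from any real config): price_in = 2.0 yuan per million tokens, cache_price_in = 0.5, and 1,000,000 prompt tokens of which 600,000 hit the cache. The cached share costs 0.6 * 0.5 = 0.3 yuan, the uncached share 0.4 * 2.0 = 0.8 yuan, so 1.1 yuan total instead of 2.0 yuan without caching. A standalone sketch of the same split:

# Illustrative only: standalone mirror of the cache-aware input-cost calculation.
def input_cost(prompt_tokens: int, hit_tokens: int, miss_tokens: int,
               cache_enabled: bool, price_in: float, cache_price_in: float) -> float:
    if not cache_enabled:
        return prompt_tokens / 1_000_000 * price_in
    if miss_tokens == 0 and hit_tokens > 0:
        miss_tokens = max(prompt_tokens - hit_tokens, 0)
    if hit_tokens + miss_tokens == 0:
        miss_tokens = prompt_tokens  # no cache counters reported: bill everything at full price
    return hit_tokens / 1_000_000 * cache_price_in + miss_tokens / 1_000_000 * price_in

print(input_cost(1_000_000, 600_000, 0, True, 2.0, 0.5))   # 1.1
print(input_cost(1_000_000, 0, 0, False, 2.0, 0.5))        # 2.0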

View File

@@ -211,6 +211,8 @@ class LLMOrchestrator:
prompt_tokens=usage.prompt_tokens if usage is not None else 0,
completion_tokens=usage.completion_tokens if usage is not None else 0,
total_tokens=usage.total_tokens if usage is not None else 0,
prompt_cache_hit_tokens=usage.prompt_cache_hit_tokens if usage is not None else 0,
prompt_cache_miss_tokens=usage.prompt_cache_miss_tokens if usage is not None else 0,
)
async def generate_response_for_image(

View File

@@ -248,6 +248,33 @@ class MaisakaChatLoopService:
except (TypeError, ValueError):
return default
@staticmethod
def _log_prompt_cache_usage(
*,
request_kind: str,
prompt_tokens: int,
prompt_cache_hit_tokens: int,
prompt_cache_miss_tokens: int,
) -> None:
"""记录模型 KV cache 命中情况。"""
if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
prompt_cache_total_tokens = prompt_cache_hit_tokens + prompt_cache_miss_tokens
prompt_cache_hit_rate = (
prompt_cache_hit_tokens / prompt_cache_total_tokens * 100
if prompt_cache_total_tokens > 0
else 0
)
logger.info(
"Maisaka KV cache usage - "
f"request_kind={request_kind}, "
f"hit_tokens={prompt_cache_hit_tokens}, "
f"miss_tokens={prompt_cache_miss_tokens}, "
f"hit_rate={prompt_cache_hit_rate:.2f}%, "
f"prompt_tokens={prompt_tokens}"
)
def _build_personality_prompt(self) -> str:
"""构造人格提示词。"""
@@ -554,6 +581,12 @@ class MaisakaChatLoopService:
interrupt_flag=self._interrupt_flag,
),
)
self._log_prompt_cache_usage(
request_kind=request_kind,
prompt_tokens=generation_result.prompt_tokens,
prompt_cache_hit_tokens=getattr(generation_result, "prompt_cache_hit_tokens", 0) or 0,
prompt_cache_miss_tokens=getattr(generation_result, "prompt_cache_miss_tokens", 0) or 0,
)
final_response = generation_result.response or ""
final_tool_calls = list(generation_result.tool_calls or [])