diff --git a/src/chat/replyer/maisaka_generator_base.py b/src/chat/replyer/maisaka_generator_base.py
index 8fbb8d18..868f2858 100644
--- a/src/chat/replyer/maisaka_generator_base.py
+++ b/src/chat/replyer/maisaka_generator_base.py
@@ -1,7 +1,9 @@
 from dataclasses import dataclass, field
 from datetime import datetime
+from pathlib import Path
 from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple
 
+import json
 import time
 
 from rich.console import Group, RenderableType
@@ -48,6 +50,8 @@ from .maisaka_expression_selector import maisaka_expression_selector
 
 logger = get_logger("replyer")
 
+DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
+
 
 @dataclass
 class MaisakaReplyContext:
@@ -404,6 +408,35 @@ class BaseMaisakaReplyGenerator:
             return self.chat_stream.session_id
         return ""
 
+    @staticmethod
+    def _build_debug_request_filename(stream_id: str, model_name: str) -> str:
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+        raw_name = f"{timestamp}_{stream_id or 'unknown'}_{model_name or 'unknown'}.json"
+        return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
+
+    def _save_debug_reply_request_body(
+        self,
+        *,
+        stream_id: str,
+        model_name: str,
+        messages: List[Message],
+    ) -> None:
+        try:
+            DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+            request_body = {
+                "model": model_name,
+                "request_type": self.request_type,
+                "stream_id": stream_id,
+                "created_at": datetime.now().isoformat(timespec="seconds"),
+                "messages": serialize_prompt_messages(messages),
+            }
+            file_path = DEBUG_REPLY_CACHE_DIR / self._build_debug_request_filename(stream_id, model_name)
+            with file_path.open("w", encoding="utf-8") as file:
+                json.dump(request_body, file, ensure_ascii=False, indent=2)
+            logger.info(f"Replyer 请求体已保存: {file_path.resolve()}")
+        except Exception as exc:
+            logger.warning(f"保存 Replyer 请求体失败: {exc}")
+
     async def _build_reply_context(
         self,
         chat_history: List[LLMContextMessage],
@@ -590,6 +623,11 @@ class BaseMaisakaReplyGenerator:
             result.completion.request_prompt = prompt_preview
             result.request_messages = serialize_prompt_messages(request_messages)
+            self._save_debug_reply_request_body(
+                stream_id=preview_chat_id,
+                model_name=generation_result.model_name or "",
+                messages=request_messages,
+            )
 
             llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
             response_text = (generation_result.response or "").strip()
             result.success = bool(response_text)
@@ -612,6 +650,26 @@ class BaseMaisakaReplyGenerator:
                     f"llm: {llm_ms} ms",
                 ],
             )
+            prompt_cache_hit_tokens = getattr(generation_result, "prompt_cache_hit_tokens", 0) or 0
+            prompt_cache_miss_tokens = getattr(generation_result, "prompt_cache_miss_tokens", 0) or 0
+            if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
+                prompt_cache_miss_tokens = max(generation_result.prompt_tokens - prompt_cache_hit_tokens, 0)
+            prompt_cache_total_tokens = prompt_cache_hit_tokens + prompt_cache_miss_tokens
+            prompt_cache_hit_rate = (
+                prompt_cache_hit_tokens / prompt_cache_total_tokens * 100
+                if prompt_cache_total_tokens > 0
+                else 0
+            )
+            result.metrics.extra["prompt_cache_hit_tokens"] = prompt_cache_hit_tokens
+            result.metrics.extra["prompt_cache_miss_tokens"] = prompt_cache_miss_tokens
+            result.metrics.extra["prompt_cache_hit_rate"] = round(prompt_cache_hit_rate, 2)
+            logger.info(
+                "Replyer KV cache usage - "
+                f"hit_tokens={prompt_cache_hit_tokens}, "
+                f"miss_tokens={prompt_cache_miss_tokens}, "
+                f"hit_rate={prompt_cache_hit_rate:.2f}%, "
f"prompt_tokens={generation_result.prompt_tokens}" + ) if show_replyer_reasoning and result.completion.reasoning_text: logger.info(f"Maisaka 回复器思考内容:\n{result.completion.reasoning_text}") diff --git a/src/common/data_models/llm_service_data_models.py b/src/common/data_models/llm_service_data_models.py index 4326d410..4023e165 100644 --- a/src/common/data_models/llm_service_data_models.py +++ b/src/common/data_models/llm_service_data_models.py @@ -68,6 +68,8 @@ class LLMResponseResult(BaseDataModel): prompt_tokens: int = 0 completion_tokens: int = 0 total_tokens: int = 0 + prompt_cache_hit_tokens: int = 0 + prompt_cache_miss_tokens: int = 0 @dataclass(slots=True) @@ -125,6 +127,8 @@ class LLMServiceResult(BaseDataModel): "prompt_tokens": self.completion.prompt_tokens, "completion_tokens": self.completion.completion_tokens, "total_tokens": self.completion.total_tokens, + "prompt_cache_hit_tokens": self.completion.prompt_cache_hit_tokens, + "prompt_cache_miss_tokens": self.completion.prompt_cache_miss_tokens, } if self.completion.tool_calls is not None: payload["tool_calls"] = [ diff --git a/src/common/logger_color_and_mapping.py b/src/common/logger_color_and_mapping.py index 41b356df..22c99d34 100644 --- a/src/common/logger_color_and_mapping.py +++ b/src/common/logger_color_and_mapping.py @@ -32,6 +32,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = { "remote": ("#6c6c6c", None, False), # 深灰色,更不显眼 "planner": ("#008080", None, False), "maisaka_reasoning_engine": ("#008080", None, False), + "maisaka_chat_loop": ("#0087ff", None, False), "maisaka_runtime": ("#ff5fff", None, False), "relation": ("#af87af", None, False), # 柔和的紫色,不刺眼 # 聊天相关模块 diff --git a/src/config/config.py b/src/config/config.py index 4ed6d7ab..be9d45c2 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -56,7 +56,7 @@ MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute( LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute() MMC_VERSION: str = "1.0.0" CONFIG_VERSION: str = "8.9.11" -MODEL_CONFIG_VERSION: str = "1.14.1" +MODEL_CONFIG_VERSION: str = "1.14.2" logger = get_logger("config") diff --git a/src/config/model_configs.py b/src/config/model_configs.py index 2d436c77..b2a50dde 100644 --- a/src/config/model_configs.py +++ b/src/config/model_configs.py @@ -269,6 +269,26 @@ class ModelInfo(ConfigBase): ) """输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)""" + cache: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "database", + }, + ) + """是否启用模型输入缓存计费。开启后命中缓存的输入 token 使用 cache_price_in 计费。""" + + cache_price_in: float = Field( + default=0.0, + ge=0, + json_schema_extra={ + "x-widget": "input", + "x-icon": "database-zap", + "step": 0.001, + }, + ) + """缓存命中输入价格 (用于API调用统计, 单位:元/ M token)。仅当 cache=true 时使用。""" + price_out: float = Field( default=0.0, ge=0, diff --git a/src/llm_models/model_client/adapter_base.py b/src/llm_models/model_client/adapter_base.py index 660a286d..407a0910 100644 --- a/src/llm_models/model_client/adapter_base.py +++ b/src/llm_models/model_client/adapter_base.py @@ -163,6 +163,8 @@ class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]): prompt_tokens=usage_record[0], completion_tokens=usage_record[1], total_tokens=usage_record[2], + prompt_cache_hit_tokens=usage_record[3] if len(usage_record) > 3 else 0, + prompt_cache_miss_tokens=usage_record[4] if len(usage_record) > 4 else 0, ) def _attach_usage_record( diff --git a/src/llm_models/model_client/base_client.py 
index fc03ac02..6e16e759 100644
--- a/src/llm_models/model_client/base_client.py
+++ b/src/llm_models/model_client/base_client.py
@@ -35,6 +35,12 @@ class UsageRecord:
     total_tokens: int
     """总token数"""
 
+    prompt_cache_hit_tokens: int = 0
+    """输入中缓存命中的 token 数"""
+
+    prompt_cache_miss_tokens: int = 0
+    """输入中缓存未命中的 token 数"""
+
 
 @dataclass
 class APIResponse:
@@ -61,8 +67,8 @@ class APIResponse:
     """响应原始数据"""
 
 
-UsageTuple = Tuple[int, int, int]
-"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""
+UsageTuple = Tuple[int, ...]
+"""统一的使用量元组,顺序为 `(prompt_tokens, completion_tokens, total_tokens, prompt_cache_hit_tokens, prompt_cache_miss_tokens)`。"""
 
 StreamResponseHandler = Callable[
     [Any, asyncio.Event | None],
diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
index 8a02e37c..490036c5 100644
--- a/src/llm_models/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -5,6 +5,8 @@ import io
 import json
 import re
 from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
 from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
 from uuid import uuid4
 
@@ -71,6 +73,8 @@ from ..request_snapshot import (
 
 logger = get_logger("llm_models")
 
+DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
+
 SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
 """OpenAI 兼容图片输入稳定支持的格式集合。"""
 
@@ -120,6 +124,26 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
 """OpenAI 非流式响应解析函数类型。"""
 
 
+def _build_debug_provider_request_filename(model_name: str) -> str:
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+    raw_name = f"provider_{timestamp}_{model_name or 'unknown'}.json"
+    return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
+
+
+def _save_debug_provider_request_payload(model_name: str, request_payload: Dict[str, Any]) -> None:
+    if model_name != "deepseek-v4p":
+        return
+
+    try:
+        DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        file_path = DEBUG_REPLY_CACHE_DIR / _build_debug_provider_request_filename(model_name)
+        with file_path.open("w", encoding="utf-8") as file:
+            json.dump(request_payload, file, ensure_ascii=False, indent=2)
+        logger.info(f"DeepSeek provider 请求体已保存: {file_path.resolve()}")
+    except Exception as exc:
+        logger.warning(f"保存 DeepSeek provider 请求体失败: {exc}")
+
+
 def _build_fallback_tool_call_id(prefix: str) -> str:
     """为缺失原始调用 ID 的工具调用生成唯一兜底标识。"""
 
@@ -492,10 +516,20 @@ def _extract_usage_record(usage: Any) -> UsageTuple | None:
     """
     if usage is None:
         return None
+    prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+    prompt_cache_hit_tokens = getattr(usage, "prompt_cache_hit_tokens", 0) or 0
+    prompt_cache_miss_tokens = getattr(usage, "prompt_cache_miss_tokens", 0) or 0
+    prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
+    if prompt_cache_hit_tokens == 0 and prompt_tokens_details is not None:
+        prompt_cache_hit_tokens = getattr(prompt_tokens_details, "cached_tokens", 0) or 0
+    if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
+        prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
     return (
-        getattr(usage, "prompt_tokens", 0) or 0,
+        prompt_tokens,
         getattr(usage, "completion_tokens", 0) or 0,
         getattr(usage, "total_tokens", 0) or 0,
+        prompt_cache_hit_tokens,
+        prompt_cache_miss_tokens,
     )
 
 
@@ -1147,6 +1181,17 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
             "temperature": _snapshot_openai_argument(temperature_argument),
             "tools": tools_payload,
         }
+        _save_debug_provider_request_payload(
+            model_info.name,
+            {
+                "base_url": self.api_provider.base_url,
+                "endpoint": "/chat/completions",
+                "model_name": model_info.name,
+                "model_identifier": model_info.model_identifier,
+                "created_at": datetime.now().isoformat(timespec="seconds"),
+                "request_kwargs": snapshot_provider_request["request_kwargs"],
+            },
+        )
 
         if model_info.force_stream_mode:
             stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py
index 2d164fe8..847e21f3 100644
--- a/src/llm_models/utils.py
+++ b/src/llm_models/utils.py
@@ -168,6 +168,25 @@ class LLMUsageRecorder:
     def __init__(self):
        pass
 
+    @staticmethod
+    def _calculate_input_cost(model_info: ModelInfo, model_usage: UsageRecord) -> float:
+        """根据模型缓存配置计算输入 token 费用。"""
+
+        prompt_tokens = model_usage.prompt_tokens or 0
+        if not model_info.cache:
+            return (prompt_tokens / 1000000) * model_info.price_in
+
+        cache_hit_tokens = model_usage.prompt_cache_hit_tokens or 0
+        cache_miss_tokens = model_usage.prompt_cache_miss_tokens or 0
+        if cache_miss_tokens == 0 and cache_hit_tokens > 0:
+            cache_miss_tokens = max(prompt_tokens - cache_hit_tokens, 0)
+        if cache_hit_tokens + cache_miss_tokens == 0:
+            cache_miss_tokens = prompt_tokens
+
+        cached_cost = (cache_hit_tokens / 1000000) * model_info.cache_price_in
+        uncached_cost = (cache_miss_tokens / 1000000) * model_info.price_in
+        return cached_cost + uncached_cost
+
     def record_usage_to_database(
         self,
         model_info: ModelInfo,
@@ -177,7 +196,7 @@ class LLMUsageRecorder:
         endpoint: str,
         time_cost: float = 0.0,
     ):
-        input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
+        input_cost = self._calculate_input_cost(model_info, model_usage)
         output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
         total_cost = round(input_cost + output_cost, 6)
         try:
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index 78cd5cae..e80fe796 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -211,6 +211,8 @@ class LLMOrchestrator:
             prompt_tokens=usage.prompt_tokens if usage is not None else 0,
             completion_tokens=usage.completion_tokens if usage is not None else 0,
             total_tokens=usage.total_tokens if usage is not None else 0,
+            prompt_cache_hit_tokens=usage.prompt_cache_hit_tokens if usage is not None else 0,
+            prompt_cache_miss_tokens=usage.prompt_cache_miss_tokens if usage is not None else 0,
         )
 
     async def generate_response_for_image(
diff --git a/src/maisaka/chat_loop_service.py b/src/maisaka/chat_loop_service.py
index dd8ced5d..9a6801e6 100644
--- a/src/maisaka/chat_loop_service.py
+++ b/src/maisaka/chat_loop_service.py
@@ -248,6 +248,33 @@ class MaisakaChatLoopService:
         except (TypeError, ValueError):
             return default
 
+    @staticmethod
+    def _log_prompt_cache_usage(
+        *,
+        request_kind: str,
+        prompt_tokens: int,
+        prompt_cache_hit_tokens: int,
+        prompt_cache_miss_tokens: int,
+    ) -> None:
+        """记录模型 KV cache 命中情况。"""
+
+        if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
+            prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
+        prompt_cache_total_tokens = prompt_cache_hit_tokens + prompt_cache_miss_tokens
+        prompt_cache_hit_rate = (
+            prompt_cache_hit_tokens / prompt_cache_total_tokens * 100
+            if prompt_cache_total_tokens > 0
+            else 0
+        )
+        logger.info(
+            "Maisaka KV cache usage - "
+            f"request_kind={request_kind}, "
+            f"hit_tokens={prompt_cache_hit_tokens}, "
+            f"miss_tokens={prompt_cache_miss_tokens}, "
+            f"hit_rate={prompt_cache_hit_rate:.2f}%, "
+            f"prompt_tokens={prompt_tokens}"
+        )
+
     def _build_personality_prompt(self) -> str:
         """构造人格提示词。"""
 
@@ -554,6 +581,12 @@ class MaisakaChatLoopService:
                 interrupt_flag=self._interrupt_flag,
             ),
         )
+        self._log_prompt_cache_usage(
+            request_kind=request_kind,
+            prompt_tokens=generation_result.prompt_tokens,
+            prompt_cache_hit_tokens=getattr(generation_result, "prompt_cache_hit_tokens", 0) or 0,
+            prompt_cache_miss_tokens=getattr(generation_result, "prompt_cache_miss_tokens", 0) or 0,
+        )
 
         final_response = generation_result.response or ""
         final_tool_calls = list(generation_result.tool_calls or [])