feat: support model caching and related configuration
@@ -1,7 +1,9 @@
 from dataclasses import dataclass, field
+from datetime import datetime
 from pathlib import Path
 from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple

+import json
 import time

 from rich.console import Group, RenderableType
@@ -48,6 +50,8 @@ from .maisaka_expression_selector import maisaka_expression_selector

 logger = get_logger("replyer")

+DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
+

 @dataclass
 class MaisakaReplyContext:
@@ -404,6 +408,35 @@ class BaseMaisakaReplyGenerator:
             return self.chat_stream.session_id
         return ""

+    @staticmethod
+    def _build_debug_request_filename(stream_id: str, model_name: str) -> str:
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+        raw_name = f"{timestamp}_{stream_id or 'unknown'}_{model_name or 'unknown'}.json"
+        return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
+
+    def _save_debug_reply_request_body(
+        self,
+        *,
+        stream_id: str,
+        model_name: str,
+        messages: List[Message],
+    ) -> None:
+        try:
+            DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+            request_body = {
+                "model": model_name,
+                "request_type": self.request_type,
+                "stream_id": stream_id,
+                "created_at": datetime.now().isoformat(timespec="seconds"),
+                "messages": serialize_prompt_messages(messages),
+            }
+            file_path = DEBUG_REPLY_CACHE_DIR / self._build_debug_request_filename(stream_id, model_name)
+            with file_path.open("w", encoding="utf-8") as file:
+                json.dump(request_body, file, ensure_ascii=False, indent=2)
+            logger.info(f"Replyer request body saved: {file_path.resolve()}")
+        except Exception as exc:
+            logger.warning(f"Failed to save Replyer request body: {exc}")
+
     async def _build_reply_context(
         self,
         chat_history: List[LLMContextMessage],
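Note: the sanitizer above keeps only alphanumerics, "-", "_" and ".", so stream IDs such as "chat:123" cannot escape the cache directory. A standalone trace (the timestamp is illustrative):

    >>> raw = "20240101_120000_000001_chat:123_deepseek-chat.json"
    >>> "".join(c if c.isalnum() or c in ("-", "_", ".") else "_" for c in raw)
    '20240101_120000_000001_chat_123_deepseek-chat.json'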
@@ -590,6 +623,11 @@ class BaseMaisakaReplyGenerator:

         result.completion.request_prompt = prompt_preview
         result.request_messages = serialize_prompt_messages(request_messages)
+        self._save_debug_reply_request_body(
+            stream_id=preview_chat_id,
+            model_name=generation_result.model_name or "",
+            messages=request_messages,
+        )
         llm_ms = round((time.perf_counter() - llm_started_at) * 1000, 2)
         response_text = (generation_result.response or "").strip()
         result.success = bool(response_text)
@@ -612,6 +650,26 @@ class BaseMaisakaReplyGenerator:
                 f"llm: {llm_ms} ms",
             ],
         )
+        prompt_cache_hit_tokens = getattr(generation_result, "prompt_cache_hit_tokens", 0) or 0
+        prompt_cache_miss_tokens = getattr(generation_result, "prompt_cache_miss_tokens", 0) or 0
+        if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
+            prompt_cache_miss_tokens = max(generation_result.prompt_tokens - prompt_cache_hit_tokens, 0)
+        prompt_cache_total_tokens = prompt_cache_hit_tokens + prompt_cache_miss_tokens
+        prompt_cache_hit_rate = (
+            prompt_cache_hit_tokens / prompt_cache_total_tokens * 100
+            if prompt_cache_total_tokens > 0
+            else 0
+        )
+        result.metrics.extra["prompt_cache_hit_tokens"] = prompt_cache_hit_tokens
+        result.metrics.extra["prompt_cache_miss_tokens"] = prompt_cache_miss_tokens
+        result.metrics.extra["prompt_cache_hit_rate"] = round(prompt_cache_hit_rate, 2)
+        logger.info(
+            "Replyer KV cache usage - "
+            f"hit_tokens={prompt_cache_hit_tokens}, "
+            f"miss_tokens={prompt_cache_miss_tokens}, "
+            f"hit_rate={prompt_cache_hit_rate:.2f}%, "
+            f"prompt_tokens={generation_result.prompt_tokens}"
+        )

         if show_replyer_reasoning and result.completion.reasoning_text:
             logger.info(f"Maisaka replyer reasoning:\n{result.completion.reasoning_text}")
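Note: the hit-rate math above, traced with invented numbers. Some providers report hits but omit misses, hence the max() fallback:

    prompt_tokens = 1000
    hit_tokens = 768
    miss_tokens = 0                                        # not reported
    if miss_tokens == 0 and hit_tokens > 0:
        miss_tokens = max(prompt_tokens - hit_tokens, 0)   # 232, derived
    total = hit_tokens + miss_tokens                       # 1000
    hit_rate = hit_tokens / total * 100 if total > 0 else 0
    # hit_rate == 76.8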
@@ -68,6 +68,8 @@ class LLMResponseResult(BaseDataModel):
     prompt_tokens: int = 0
     completion_tokens: int = 0
     total_tokens: int = 0
+    prompt_cache_hit_tokens: int = 0
+    prompt_cache_miss_tokens: int = 0


 @dataclass(slots=True)
@@ -125,6 +127,8 @@ class LLMServiceResult(BaseDataModel):
             "prompt_tokens": self.completion.prompt_tokens,
             "completion_tokens": self.completion.completion_tokens,
             "total_tokens": self.completion.total_tokens,
+            "prompt_cache_hit_tokens": self.completion.prompt_cache_hit_tokens,
+            "prompt_cache_miss_tokens": self.completion.prompt_cache_miss_tokens,
         }
         if self.completion.tool_calls is not None:
             payload["tool_calls"] = [
@@ -32,6 +32,7 @@ MODULE_COLORS: Dict[str, Tuple[str, Optional[str], bool]] = {
     "remote": ("#6c6c6c", None, False),  # dark gray, deliberately unobtrusive
     "planner": ("#008080", None, False),
+    "maisaka_reasoning_engine": ("#008080", None, False),
     "maisaka_chat_loop": ("#0087ff", None, False),
     "maisaka_runtime": ("#ff5fff", None, False),
     "relation": ("#af87af", None, False),  # soft purple, easy on the eyes
     # chat-related modules
@@ -56,7 +56,7 @@ MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute(
 LEGACY_ENV_PATH: Path = (PROJECT_ROOT / ".env").resolve().absolute()
 MMC_VERSION: str = "1.0.0"
 CONFIG_VERSION: str = "8.9.11"
-MODEL_CONFIG_VERSION: str = "1.14.1"
+MODEL_CONFIG_VERSION: str = "1.14.2"

 logger = get_logger("config")

@@ -269,6 +269,26 @@ class ModelInfo(ConfigBase):
     )
     """Input price (for API usage statistics, unit: CNY / M tokens) (optional; defaults to 0 if absent)"""

+    cache: bool = Field(
+        default=False,
+        json_schema_extra={
+            "x-widget": "switch",
+            "x-icon": "database",
+        },
+    )
+    """Whether cached-input billing is enabled for the model. When enabled, cache-hit input tokens are billed at cache_price_in."""
+
+    cache_price_in: float = Field(
+        default=0.0,
+        ge=0,
+        json_schema_extra={
+            "x-widget": "input",
+            "x-icon": "database-zap",
+            "step": 0.001,
+        },
+    )
+    """Cache-hit input price (for API usage statistics, unit: CNY / M tokens). Used only when cache=true."""
+
     price_out: float = Field(
         default=0.0,
         ge=0,
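Note: taking the two docstrings above at face value, input billing splits as sketched below. The prices are invented and the helper is hypothetical, not part of the diff:

    PRICE_IN = 2.0        # normal input price, CNY per million tokens (invented)
    CACHE_PRICE_IN = 0.2  # cache-hit input price, CNY per million tokens (invented)

    def input_cost(hit_tokens: int, miss_tokens: int, cache: bool) -> float:
        if not cache:
            # cache=false: every input token is billed at the normal price
            return (hit_tokens + miss_tokens) / 1_000_000 * PRICE_IN
        # cache=true: hits use cache_price_in, misses use price_in
        return hit_tokens / 1_000_000 * CACHE_PRICE_IN + miss_tokens / 1_000_000 * PRICE_IN

    input_cost(800_000, 200_000, cache=True)   # 0.16 + 0.40 = 0.56
    input_cost(800_000, 200_000, cache=False)  # 2.0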
@@ -163,6 +163,8 @@ class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
             prompt_tokens=usage_record[0],
             completion_tokens=usage_record[1],
             total_tokens=usage_record[2],
+            prompt_cache_hit_tokens=usage_record[3] if len(usage_record) > 3 else 0,
+            prompt_cache_miss_tokens=usage_record[4] if len(usage_record) > 4 else 0,
         )

     def _attach_usage_record(
@@ -35,6 +35,12 @@ class UsageRecord:
     total_tokens: int
     """Total token count"""

+    prompt_cache_hit_tokens: int = 0
+    """Number of input tokens that hit the cache"""
+
+    prompt_cache_miss_tokens: int = 0
+    """Number of input tokens that missed the cache"""
+

 @dataclass
 class APIResponse:
@@ -61,8 +67,8 @@ class APIResponse:
     """Raw response data"""


-UsageTuple = Tuple[int, int, int]
-"""Unified usage triple type, ordered as `(prompt_tokens, completion_tokens, total_tokens)`."""
+UsageTuple = Tuple[int, ...]
+"""Unified usage tuple, ordered as `(prompt_tokens, completion_tokens, total_tokens, prompt_cache_hit_tokens, prompt_cache_miss_tokens)`."""

 StreamResponseHandler = Callable[
     [Any, asyncio.Event | None],
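Note: widening the alias to Tuple[int, ...] keeps old three-element records valid while allowing the new five-element form; consumers index defensively, as the AdapterClient hunk above shows. A minimal sketch:

    from typing import Tuple

    UsageTuple = Tuple[int, ...]

    legacy: UsageTuple = (120, 30, 150)             # old triple still type-checks
    extended: UsageTuple = (120, 30, 150, 100, 20)  # with cache hit/miss counts

    def cache_hit_tokens(record: UsageTuple) -> int:
        # mirrors usage_record[3] if len(usage_record) > 3 else 0
        return record[3] if len(record) > 3 else 0

    cache_hit_tokens(legacy), cache_hit_tokens(extended)  # (0, 100)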
@@ -5,6 +5,8 @@ import io
 import json
 import re
+from dataclasses import dataclass, field
+from datetime import datetime
 from pathlib import Path
 from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
 from uuid import uuid4

@@ -71,6 +73,8 @@ from ..request_snapshot import (

 logger = get_logger("llm_models")

+DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
+
 SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
 """Set of image formats with stable support for OpenAI-compatible image input."""

@@ -120,6 +124,26 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
 """OpenAI non-streaming response parser function type."""


+def _build_debug_provider_request_filename(model_name: str) -> str:
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+    raw_name = f"provider_{timestamp}_{model_name or 'unknown'}.json"
+    return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
+
+
+def _save_debug_provider_request_payload(model_name: str, request_payload: Dict[str, Any]) -> None:
+    if model_name != "deepseek-v4p":
+        return
+
+    try:
+        DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        file_path = DEBUG_REPLY_CACHE_DIR / _build_debug_provider_request_filename(model_name)
+        with file_path.open("w", encoding="utf-8") as file:
+            json.dump(request_payload, file, ensure_ascii=False, indent=2)
+        logger.info(f"DeepSeek provider request body saved: {file_path.resolve()}")
+    except Exception as exc:
+        logger.warning(f"Failed to save DeepSeek provider request body: {exc}")
+
+
 def _build_fallback_tool_call_id(prefix: str) -> str:
     """Generate a unique fallback identifier for tool calls missing an original call ID."""

@@ -492,10 +516,20 @@ def _extract_usage_record(usage: Any) -> UsageTuple | None:
     """
     if usage is None:
         return None
+    prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+    prompt_cache_hit_tokens = getattr(usage, "prompt_cache_hit_tokens", 0) or 0
+    prompt_cache_miss_tokens = getattr(usage, "prompt_cache_miss_tokens", 0) or 0
+    prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
+    if prompt_cache_hit_tokens == 0 and prompt_tokens_details is not None:
+        prompt_cache_hit_tokens = getattr(prompt_tokens_details, "cached_tokens", 0) or 0
+    if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
+        prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
     return (
-        getattr(usage, "prompt_tokens", 0) or 0,
+        prompt_tokens,
         getattr(usage, "completion_tokens", 0) or 0,
         getattr(usage, "total_tokens", 0) or 0,
+        prompt_cache_hit_tokens,
+        prompt_cache_miss_tokens,
     )


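Note: a standalone trace of the fallback chain above, with SimpleNamespace standing in for a provider usage object (values invented). DeepSeek-style responses report prompt_cache_hit_tokens directly, while plain OpenAI-compatible responses only expose prompt_tokens_details.cached_tokens:

    from types import SimpleNamespace

    usage = SimpleNamespace(
        prompt_tokens=1000,
        completion_tokens=50,
        total_tokens=1050,
        prompt_cache_hit_tokens=0,           # not provided by this provider
        prompt_cache_miss_tokens=0,
        prompt_tokens_details=SimpleNamespace(cached_tokens=800),
    )

    hit = usage.prompt_cache_hit_tokens or usage.prompt_tokens_details.cached_tokens
    miss = max(usage.prompt_tokens - hit, 0)
    # _extract_usage_record would return (1000, 50, 1050, 800, 200)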
@@ -1147,6 +1181,17 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
             "temperature": _snapshot_openai_argument(temperature_argument),
             "tools": tools_payload,
         }
+        _save_debug_provider_request_payload(
+            model_info.name,
+            {
+                "base_url": self.api_provider.base_url,
+                "endpoint": "/chat/completions",
+                "model_name": model_info.name,
+                "model_identifier": model_info.model_identifier,
+                "created_at": datetime.now().isoformat(timespec="seconds"),
+                "request_kwargs": snapshot_provider_request["request_kwargs"],
+            },
+        )

         if model_info.force_stream_mode:
             stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
@@ -168,6 +168,25 @@ class LLMUsageRecorder:
     def __init__(self):
         pass

+    @staticmethod
+    def _calculate_input_cost(model_info: ModelInfo, model_usage: UsageRecord) -> float:
+        """Calculate the input token cost based on the model's cache configuration."""
+
+        prompt_tokens = model_usage.prompt_tokens or 0
+        if not model_info.cache:
+            return (prompt_tokens / 1000000) * model_info.price_in
+
+        cache_hit_tokens = model_usage.prompt_cache_hit_tokens or 0
+        cache_miss_tokens = model_usage.prompt_cache_miss_tokens or 0
+        if cache_miss_tokens == 0 and cache_hit_tokens > 0:
+            cache_miss_tokens = max(prompt_tokens - cache_hit_tokens, 0)
+        if cache_hit_tokens + cache_miss_tokens == 0:
+            cache_miss_tokens = prompt_tokens
+
+        cached_cost = (cache_hit_tokens / 1000000) * model_info.cache_price_in
+        uncached_cost = (cache_miss_tokens / 1000000) * model_info.price_in
+        return cached_cost + uncached_cost
+
     def record_usage_to_database(
         self,
         model_info: ModelInfo,
@@ -177,7 +196,7 @@
         endpoint: str,
         time_cost: float = 0.0,
     ):
-        input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
+        input_cost = self._calculate_input_cost(model_info, model_usage)
         output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
         total_cost = round(input_cost + output_cost, 6)
         try:
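Note: _calculate_input_cost has two fallbacks worth tracing by hand (invented numbers):

    # Case 1: the provider reports hits but omits misses.
    prompt_tokens, hit, miss = 1_000_000, 800_000, 0
    miss = max(prompt_tokens - hit, 0)  # 200_000, derived from the total
    # cost = 0.8 * cache_price_in + 0.2 * price_in (per million tokens)

    # Case 2: cache billing is enabled but the provider reports no cache data.
    prompt_tokens, hit, miss = 1_000_000, 0, 0
    if hit + miss == 0:
        miss = prompt_tokens  # bill everything at the normal price_in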
@@ -211,6 +211,8 @@ class LLMOrchestrator:
             prompt_tokens=usage.prompt_tokens if usage is not None else 0,
             completion_tokens=usage.completion_tokens if usage is not None else 0,
             total_tokens=usage.total_tokens if usage is not None else 0,
+            prompt_cache_hit_tokens=usage.prompt_cache_hit_tokens if usage is not None else 0,
+            prompt_cache_miss_tokens=usage.prompt_cache_miss_tokens if usage is not None else 0,
         )

     async def generate_response_for_image(
@@ -248,6 +248,33 @@ class MaisakaChatLoopService:
         except (TypeError, ValueError):
             return default

+    @staticmethod
+    def _log_prompt_cache_usage(
+        *,
+        request_kind: str,
+        prompt_tokens: int,
+        prompt_cache_hit_tokens: int,
+        prompt_cache_miss_tokens: int,
+    ) -> None:
+        """Log the model's KV cache hit statistics."""
+
+        if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
+            prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
+        prompt_cache_total_tokens = prompt_cache_hit_tokens + prompt_cache_miss_tokens
+        prompt_cache_hit_rate = (
+            prompt_cache_hit_tokens / prompt_cache_total_tokens * 100
+            if prompt_cache_total_tokens > 0
+            else 0
+        )
+        logger.info(
+            "Maisaka KV cache usage - "
+            f"request_kind={request_kind}, "
+            f"hit_tokens={prompt_cache_hit_tokens}, "
+            f"miss_tokens={prompt_cache_miss_tokens}, "
+            f"hit_rate={prompt_cache_hit_rate:.2f}%, "
+            f"prompt_tokens={prompt_tokens}"
+        )
+
     def _build_personality_prompt(self) -> str:
         """Construct the personality prompt."""

@@ -554,6 +581,12 @@
                 interrupt_flag=self._interrupt_flag,
             ),
         )
+        self._log_prompt_cache_usage(
+            request_kind=request_kind,
+            prompt_tokens=generation_result.prompt_tokens,
+            prompt_cache_hit_tokens=getattr(generation_result, "prompt_cache_hit_tokens", 0) or 0,
+            prompt_cache_miss_tokens=getattr(generation_result, "prompt_cache_miss_tokens", 0) or 0,
+        )

         final_response = generation_result.response or ""
         final_tool_calls = list(generation_result.tool_calls or [])