feat: support model cache and related configuration

SengokuCola
2026-04-25 13:53:30 +08:00
parent 4b1bc2aba8
commit 9759018a0c
11 changed files with 195 additions and 5 deletions

View File

@@ -163,6 +163,8 @@ class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
            prompt_tokens=usage_record[0],
            completion_tokens=usage_record[1],
            total_tokens=usage_record[2],
            prompt_cache_hit_tokens=usage_record[3] if len(usage_record) > 3 else 0,
            prompt_cache_miss_tokens=usage_record[4] if len(usage_record) > 4 else 0,
        )

    def _attach_usage_record(

View File

@@ -35,6 +35,12 @@ class UsageRecord:
    total_tokens: int
    """Total number of tokens."""
    prompt_cache_hit_tokens: int = 0
    """Number of prompt tokens that hit the cache."""
    prompt_cache_miss_tokens: int = 0
    """Number of prompt tokens that missed the cache."""

@dataclass
class APIResponse:
@@ -61,8 +67,8 @@ class APIResponse:
    """Raw response data."""

UsageTuple = Tuple[int, int, int]
"""Unified usage tuple type, in the order `(prompt_tokens, completion_tokens, total_tokens)`."""
UsageTuple = Tuple[int, ...]
"""Unified usage tuple, in the order `(prompt_tokens, completion_tokens, total_tokens, prompt_cache_hit_tokens, prompt_cache_miss_tokens)`."""
StreamResponseHandler = Callable[
[Any, asyncio.Event | None],
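To illustrate the backward compatibility of the widened alias, here is a minimal sketch (the helper name is hypothetical, not part of this change): a legacy 3-tuple still yields zero cache counts, while the new 5-tuple carries them through, mirroring the defensive indexing used in AdapterClient above.

def pad_usage_tuple(usage: tuple[int, ...]) -> tuple[int, int, int, int, int]:
    # Pad a shorter legacy tuple so the cache fields default to 0.
    prompt, completion, total = usage[0], usage[1], usage[2]
    cache_hit = usage[3] if len(usage) > 3 else 0
    cache_miss = usage[4] if len(usage) > 4 else 0
    return prompt, completion, total, cache_hit, cache_miss

assert pad_usage_tuple((100, 20, 120)) == (100, 20, 120, 0, 0)
assert pad_usage_tuple((100, 20, 120, 80, 20)) == (100, 20, 120, 80, 20)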

View File

@@ -5,6 +5,8 @@ import io
import json
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
from uuid import uuid4
@@ -71,6 +73,8 @@ from ..request_snapshot import (
logger = get_logger("llm_models")
DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
"""OpenAI 兼容图片输入稳定支持的格式集合。"""
@@ -120,6 +124,26 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
"""OpenAI 非流式响应解析函数类型。"""
def _build_debug_provider_request_filename(model_name: str) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
raw_name = f"provider_{timestamp}_{model_name or 'unknown'}.json"
return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
def _save_debug_provider_request_payload(model_name: str, request_payload: Dict[str, Any]) -> None:
    if model_name != "deepseek-v4p":
        return
    try:
        DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        file_path = DEBUG_REPLY_CACHE_DIR / _build_debug_provider_request_filename(model_name)
        with file_path.open("w", encoding="utf-8") as file:
            json.dump(request_payload, file, ensure_ascii=False, indent=2)
        logger.info(f"DeepSeek provider request payload saved: {file_path.resolve()}")
    except Exception as exc:
        logger.warning(f"Failed to save DeepSeek provider request payload: {exc}")
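A usage sketch of the debug dump (illustrative, not part of the diff): filenames follow the pattern provider_<timestamp>_<model_name>.json with unsafe characters replaced by underscores, and the latest dump can be read back with the standard library only.

import json
from pathlib import Path

# Read back the most recent debug dump, if any were written.
dumps = sorted(Path("logs/debug_reply_cache").glob("provider_*.json"))
if dumps:
    payload = json.loads(dumps[-1].read_text(encoding="utf-8"))
    print(payload["model_name"], payload["endpoint"], payload["created_at"])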
def _build_fallback_tool_call_id(prefix: str) -> str:
    """Generate a unique fallback identifier for tool calls that are missing an original call ID."""
@@ -492,10 +516,20 @@ def _extract_usage_record(usage: Any) -> UsageTuple | None:
"""
if usage is None:
return None
prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
prompt_cache_hit_tokens = getattr(usage, "prompt_cache_hit_tokens", 0) or 0
prompt_cache_miss_tokens = getattr(usage, "prompt_cache_miss_tokens", 0) or 0
prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
if prompt_cache_hit_tokens == 0 and prompt_tokens_details is not None:
prompt_cache_hit_tokens = getattr(prompt_tokens_details, "cached_tokens", 0) or 0
if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
return (
getattr(usage, "prompt_tokens", 0) or 0,
prompt_tokens,
getattr(usage, "completion_tokens", 0) or 0,
getattr(usage, "total_tokens", 0) or 0,
prompt_cache_hit_tokens,
prompt_cache_miss_tokens,
)
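A minimal sketch of the fallback path, using SimpleNamespace to stand in for an SDK usage object (illustrative, not part of the diff): when a provider only reports cached_tokens inside prompt_tokens_details, the hit count is taken from there and the miss count is derived from prompt_tokens.

from types import SimpleNamespace

usage = SimpleNamespace(
    prompt_tokens=1000,
    completion_tokens=200,
    total_tokens=1200,
    prompt_cache_hit_tokens=0,
    prompt_cache_miss_tokens=0,
    prompt_tokens_details=SimpleNamespace(cached_tokens=600),
)
# Hit count comes from prompt_tokens_details; miss count is 1000 - 600 = 400.
assert _extract_usage_record(usage) == (1000, 200, 1200, 600, 400)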
@@ -1147,6 +1181,17 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
"temperature": _snapshot_openai_argument(temperature_argument),
"tools": tools_payload,
}
_save_debug_provider_request_payload(
model_info.name,
{
"base_url": self.api_provider.base_url,
"endpoint": "/chat/completions",
"model_name": model_info.name,
"model_identifier": model_info.model_identifier,
"created_at": datetime.now().isoformat(timespec="seconds"),
"request_kwargs": snapshot_provider_request["request_kwargs"],
},
)
if model_info.force_stream_mode:
stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(

View File

@@ -168,6 +168,25 @@ class LLMUsageRecorder:
    def __init__(self):
        pass

    @staticmethod
    def _calculate_input_cost(model_info: ModelInfo, model_usage: UsageRecord) -> float:
        """Compute the input-token cost based on the model's cache configuration."""
        prompt_tokens = model_usage.prompt_tokens or 0
        if not model_info.cache:
            return (prompt_tokens / 1000000) * model_info.price_in
        cache_hit_tokens = model_usage.prompt_cache_hit_tokens or 0
        cache_miss_tokens = model_usage.prompt_cache_miss_tokens or 0
        if cache_miss_tokens == 0 and cache_hit_tokens > 0:
            cache_miss_tokens = max(prompt_tokens - cache_hit_tokens, 0)
        if cache_hit_tokens + cache_miss_tokens == 0:
            cache_miss_tokens = prompt_tokens
        cached_cost = (cache_hit_tokens / 1000000) * model_info.cache_price_in
        uncached_cost = (cache_miss_tokens / 1000000) * model_info.price_in
        return cached_cost + uncached_cost
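A worked example of the split, using hypothetical per-million-token prices (price_in = 2.0, cache_price_in = 0.5) and a prompt of 1,000,000 tokens, 800,000 of them cache hits:

cached_cost = (800_000 / 1_000_000) * 0.5    # 0.4
uncached_cost = (200_000 / 1_000_000) * 2.0  # 0.4
assert cached_cost + uncached_cost == 0.8    # vs. 2.0 without caching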
    def record_usage_to_database(
        self,
        model_info: ModelInfo,
@@ -177,7 +196,7 @@ class LLMUsageRecorder:
        endpoint: str,
        time_cost: float = 0.0,
    ):
        input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
        input_cost = self._calculate_input_cost(model_info, model_usage)
        output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
        total_cost = round(input_cost + output_cost, 6)
        try:

View File

@@ -211,6 +211,8 @@ class LLMOrchestrator:
            prompt_tokens=usage.prompt_tokens if usage is not None else 0,
            completion_tokens=usage.completion_tokens if usage is not None else 0,
            total_tokens=usage.total_tokens if usage is not None else 0,
            prompt_cache_hit_tokens=usage.prompt_cache_hit_tokens if usage is not None else 0,
            prompt_cache_miss_tokens=usage.prompt_cache_miss_tokens if usage is not None else 0,
        )

    async def generate_response_for_image(