feat: support model caching and related configuration
@@ -163,6 +163,8 @@ class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
             prompt_tokens=usage_record[0],
             completion_tokens=usage_record[1],
             total_tokens=usage_record[2],
+            prompt_cache_hit_tokens=usage_record[3] if len(usage_record) > 3 else 0,
+            prompt_cache_miss_tokens=usage_record[4] if len(usage_record) > 4 else 0,
         )
 
     def _attach_usage_record(
@@ -35,6 +35,12 @@ class UsageRecord:
     total_tokens: int
     """Total token count."""
 
+    prompt_cache_hit_tokens: int = 0
+    """Number of prompt tokens served from the cache (cache hits)."""
+
+    prompt_cache_miss_tokens: int = 0
+    """Number of prompt tokens not served from the cache (cache misses)."""
+
 
 @dataclass
 class APIResponse:
@@ -61,8 +67,8 @@ class APIResponse:
     """Raw response data."""
 
 
-UsageTuple = Tuple[int, int, int]
-"""Unified usage triple, ordered as `(prompt_tokens, completion_tokens, total_tokens)`."""
+UsageTuple = Tuple[int, ...]
+"""Unified usage tuple, ordered as `(prompt_tokens, completion_tokens, total_tokens, prompt_cache_hit_tokens, prompt_cache_miss_tokens)`."""
 
 StreamResponseHandler = Callable[
     [Any, asyncio.Event | None],
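Widening `UsageTuple` from a fixed triple to `Tuple[int, ...]` stays backward compatible because consumers index the tuple defensively, as the `AdapterClient` hunk above does. A minimal self-contained sketch of that pattern (the `to_usage_record` helper is hypothetical; the dataclass mirrors `UsageRecord` from this commit):

```python
from dataclasses import dataclass
from typing import Tuple

UsageTuple = Tuple[int, ...]

@dataclass
class UsageRecord:
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    prompt_cache_hit_tokens: int = 0
    prompt_cache_miss_tokens: int = 0

def to_usage_record(usage_record: UsageTuple) -> UsageRecord:
    # Legacy 3-tuples still construct a valid record; the cache fields default to 0.
    return UsageRecord(
        prompt_tokens=usage_record[0],
        completion_tokens=usage_record[1],
        total_tokens=usage_record[2],
        prompt_cache_hit_tokens=usage_record[3] if len(usage_record) > 3 else 0,
        prompt_cache_miss_tokens=usage_record[4] if len(usage_record) > 4 else 0,
    )

print(to_usage_record((100, 20, 120)))          # legacy triple
print(to_usage_record((100, 20, 120, 80, 20)))  # cache-aware 5-tuple
```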
@@ -5,6 +5,8 @@ import io
 import json
 import re
 from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
 from typing import Any, Callable, Coroutine, Dict, List, Tuple, cast
 from uuid import uuid4
 
@@ -71,6 +73,8 @@ from ..request_snapshot import (
 
 logger = get_logger("llm_models")
 
+DEBUG_REPLY_CACHE_DIR = Path("logs/debug_reply_cache")
+
 SUPPORTED_OPENAI_IMAGE_FORMATS = {"jpeg", "png", "webp"}
 """Set of image formats reliably supported for OpenAI-compatible image input."""
 
@@ -120,6 +124,26 @@ OpenAIResponseParser = Callable[[ChatCompletion], Tuple[APIResponse, UsageTuple
 """OpenAI non-streaming response parser function type."""
 
 
+def _build_debug_provider_request_filename(model_name: str) -> str:
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+    raw_name = f"provider_{timestamp}_{model_name or 'unknown'}.json"
+    return "".join(char if char.isalnum() or char in ("-", "_", ".") else "_" for char in raw_name)
+
+
+def _save_debug_provider_request_payload(model_name: str, request_payload: Dict[str, Any]) -> None:
+    if model_name != "deepseek-v4p":
+        return
+
+    try:
+        DEBUG_REPLY_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        file_path = DEBUG_REPLY_CACHE_DIR / _build_debug_provider_request_filename(model_name)
+        with file_path.open("w", encoding="utf-8") as file:
+            json.dump(request_payload, file, ensure_ascii=False, indent=2)
+        logger.info(f"DeepSeek provider request payload saved: {file_path.resolve()}")
+    except Exception as exc:
+        logger.warning(f"Failed to save DeepSeek provider request payload: {exc}")
+
+
 def _build_fallback_tool_call_id(prefix: str) -> str:
     """Generate a unique fallback identifier for tool calls that lack an original call ID."""
 
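`_build_debug_provider_request_filename` replaces every character outside alphanumerics, `-`, `_`, and `.` with an underscore, so a model name containing `/` or `:` cannot produce a path outside the debug directory. A standalone sketch of the same sanitization (hypothetical `build_debug_filename` wrapper):

```python
from datetime import datetime

def build_debug_filename(model_name: str) -> str:
    # Mirrors _build_debug_provider_request_filename: any character that is not
    # alphanumeric, "-", "_", or "." is replaced with "_".
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    raw_name = f"provider_{timestamp}_{model_name or 'unknown'}.json"
    return "".join(c if c.isalnum() or c in ("-", "_", ".") else "_" for c in raw_name)

# e.g. "provider_20240101_120000_000000_org_model_latest.json"
print(build_debug_filename("org/model:latest"))
```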
@@ -492,10 +516,20 @@ def _extract_usage_record(usage: Any) -> UsageTuple | None:
     """
     if usage is None:
         return None
+    prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
+    prompt_cache_hit_tokens = getattr(usage, "prompt_cache_hit_tokens", 0) or 0
+    prompt_cache_miss_tokens = getattr(usage, "prompt_cache_miss_tokens", 0) or 0
+    prompt_tokens_details = getattr(usage, "prompt_tokens_details", None)
+    if prompt_cache_hit_tokens == 0 and prompt_tokens_details is not None:
+        prompt_cache_hit_tokens = getattr(prompt_tokens_details, "cached_tokens", 0) or 0
+    if prompt_cache_miss_tokens == 0 and prompt_cache_hit_tokens > 0:
+        prompt_cache_miss_tokens = max(prompt_tokens - prompt_cache_hit_tokens, 0)
     return (
-        getattr(usage, "prompt_tokens", 0) or 0,
+        prompt_tokens,
         getattr(usage, "completion_tokens", 0) or 0,
         getattr(usage, "total_tokens", 0) or 0,
+        prompt_cache_hit_tokens,
+        prompt_cache_miss_tokens,
     )
 
 
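`_extract_usage_record` accepts both provider conventions: DeepSeek reports `prompt_cache_hit_tokens` and `prompt_cache_miss_tokens` directly, while OpenAI-style usage only nests `cached_tokens` under `prompt_tokens_details`, in which case misses are derived from the remaining prompt tokens. A quick check of those fallbacks with a stand-in usage object (`SimpleNamespace` substitutes for the SDK's usage type):

```python
from types import SimpleNamespace

# Stand-in for an OpenAI-style usage object: no DeepSeek cache fields,
# cached token count nested under prompt_tokens_details.
usage = SimpleNamespace(
    prompt_tokens=1000,
    completion_tokens=50,
    total_tokens=1050,
    prompt_cache_hit_tokens=0,
    prompt_cache_miss_tokens=0,
    prompt_tokens_details=SimpleNamespace(cached_tokens=600),
)

# First fallback: take hits from prompt_tokens_details.cached_tokens.
hit = usage.prompt_cache_hit_tokens or usage.prompt_tokens_details.cached_tokens
# Second fallback: derive misses from the remaining prompt tokens.
miss = max(usage.prompt_tokens - hit, 0)
assert (hit, miss) == (600, 400)
```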
@@ -1147,6 +1181,17 @@ class OpenaiClient(AdapterClient[AsyncStream[ChatCompletionChunk], ChatCompletio
             "temperature": _snapshot_openai_argument(temperature_argument),
             "tools": tools_payload,
         }
+        _save_debug_provider_request_payload(
+            model_info.name,
+            {
+                "base_url": self.api_provider.base_url,
+                "endpoint": "/chat/completions",
+                "model_name": model_info.name,
+                "model_identifier": model_info.model_identifier,
+                "created_at": datetime.now().isoformat(timespec="seconds"),
+                "request_kwargs": snapshot_provider_request["request_kwargs"],
+            },
+        )
 
         if model_info.force_stream_mode:
             stream_task: asyncio.Task[AsyncStream[ChatCompletionChunk]] = asyncio.create_task(
@@ -168,6 +168,25 @@ class LLMUsageRecorder:
     def __init__(self):
         pass
 
+    @staticmethod
+    def _calculate_input_cost(model_info: ModelInfo, model_usage: UsageRecord) -> float:
+        """Compute the input-token cost based on the model's cache pricing configuration."""
+
+        prompt_tokens = model_usage.prompt_tokens or 0
+        if not model_info.cache:
+            return (prompt_tokens / 1000000) * model_info.price_in
+
+        cache_hit_tokens = model_usage.prompt_cache_hit_tokens or 0
+        cache_miss_tokens = model_usage.prompt_cache_miss_tokens or 0
+        if cache_miss_tokens == 0 and cache_hit_tokens > 0:
+            cache_miss_tokens = max(prompt_tokens - cache_hit_tokens, 0)
+        if cache_hit_tokens + cache_miss_tokens == 0:
+            cache_miss_tokens = prompt_tokens
+
+        cached_cost = (cache_hit_tokens / 1000000) * model_info.cache_price_in
+        uncached_cost = (cache_miss_tokens / 1000000) * model_info.price_in
+        return cached_cost + uncached_cost
+
     def record_usage_to_database(
         self,
         model_info: ModelInfo,
@@ -177,7 +196,7 @@
         endpoint: str,
         time_cost: float = 0.0,
     ):
-        input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
+        input_cost = self._calculate_input_cost(model_info, model_usage)
         output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
         total_cost = round(input_cost + output_cost, 6)
         try:
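`_calculate_input_cost` bills cache hits at `cache_price_in` and misses at the regular `price_in`, both per million tokens. A worked example with hypothetical DeepSeek-style prices (the real values come from `ModelInfo`, the token counts from `UsageRecord`):

```python
# Hypothetical per-million-token prices; real values come from ModelInfo.
price_in = 0.27        # uncached prompt tokens
cache_price_in = 0.07  # cached prompt tokens

prompt_tokens = 1_000_000
cache_hit_tokens = 600_000
cache_miss_tokens = max(prompt_tokens - cache_hit_tokens, 0)  # 400_000

cached_cost = (cache_hit_tokens / 1_000_000) * cache_price_in   # 0.042
uncached_cost = (cache_miss_tokens / 1_000_000) * price_in      # 0.108
print(round(cached_cost + uncached_cost, 6))                    # 0.15
```

At a 60% cache hit rate the input bill for a million prompt tokens drops from $0.27 to $0.15 under these assumed prices.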
@@ -211,6 +211,8 @@ class LLMOrchestrator:
             prompt_tokens=usage.prompt_tokens if usage is not None else 0,
             completion_tokens=usage.completion_tokens if usage is not None else 0,
             total_tokens=usage.total_tokens if usage is not None else 0,
+            prompt_cache_hit_tokens=usage.prompt_cache_hit_tokens if usage is not None else 0,
+            prompt_cache_miss_tokens=usage.prompt_cache_miss_tokens if usage is not None else 0,
         )
 
     async def generate_response_for_image(