Merge branch 'r-dev' of github.com:Mai-with-u/MaiBot into r-dev
This commit is contained in:
89
src/common/toml_utils.py
Normal file
89
src/common/toml_utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
TOML文件工具函数 - 保留格式和注释
|
||||
"""
|
||||
|
||||
import os
|
||||
import tomlkit
|
||||
from typing import Any
|
||||
|
||||
|
||||
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
|
||||
"""
|
||||
保存TOML数据到文件,保留现有格式(如果文件存在)
|
||||
|
||||
Args:
|
||||
data: 要保存的数据字典
|
||||
file_path: 文件路径
|
||||
"""
|
||||
# 如果文件不存在,直接创建
|
||||
if not os.path.exists(file_path):
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(data, f)
|
||||
return
|
||||
|
||||
# 如果文件存在,尝试读取现有文件以保留格式
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
existing_doc = tomlkit.load(f)
|
||||
except Exception:
|
||||
# 如果读取失败,直接覆盖
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(data, f)
|
||||
return
|
||||
|
||||
# 递归更新,保留现有格式
|
||||
_merge_toml_preserving_format(existing_doc, data)
|
||||
|
||||
# 保存
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(existing_doc, f)
|
||||
|
||||
|
||||
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||
"""
|
||||
递归合并source到target,保留target中的格式和注释
|
||||
|
||||
Args:
|
||||
target: 目标文档(保留格式)
|
||||
source: 源数据(新数据)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
if key in target:
|
||||
# 如果两个都是字典且都是表格,递归合并
|
||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||
if hasattr(target[key], "items"): # 确实是字典/表格
|
||||
_merge_toml_preserving_format(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
else:
|
||||
# 其他情况直接替换
|
||||
target[key] = value
|
||||
else:
|
||||
# 新键直接添加
|
||||
target[key] = value
|
||||
|
||||
|
||||
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||
"""
|
||||
更新TOML文档中的字段,保留现有的格式和注释
|
||||
|
||||
这是一个递归函数,用于在部分更新配置时保留现有的格式和注释。
|
||||
|
||||
Args:
|
||||
target: 目标表格(会被修改)
|
||||
source: 源数据(新数据)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
if key in target:
|
||||
# 如果两个都是字典,递归更新
|
||||
if isinstance(value, dict) and isinstance(target[key], dict):
|
||||
if hasattr(target[key], "items"): # 确实是表格
|
||||
_update_toml_doc(target[key], value)
|
||||
else:
|
||||
target[key] = value
|
||||
else:
|
||||
# 直接更新值,保留注释
|
||||
target[key] = value
|
||||
else:
|
||||
# 新键直接添加
|
||||
target[key] = value
|
||||
@@ -5,7 +5,7 @@ import types
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set, Literal
|
||||
from typing import Any, Dict, List, Literal, Set, Tuple, Union, cast, get_args, get_origin
|
||||
|
||||
__all__ = ["ConfigBase", "Field", "AttributeData"]
|
||||
|
||||
@@ -44,6 +44,16 @@ class AttrDocBase:
|
||||
# 从类定义节点中提取字段文档
|
||||
return self._extract_field_docs(class_node, allow_extra_methods)
|
||||
|
||||
@classmethod
|
||||
def get_class_field_docs(cls) -> dict[str, str]:
|
||||
class_source = cls._get_class_source()
|
||||
class_node = cls._find_class_node(class_source)
|
||||
return AttrDocBase._extract_field_docs(
|
||||
cast(AttrDocBase, cast(Any, cls)),
|
||||
class_node,
|
||||
allow_extra_methods=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_class_source(cls) -> str:
|
||||
"""获取类定义所在文件的完整源代码"""
|
||||
@@ -265,7 +275,7 @@ class ConfigBase(BaseModel, AttrDocBase):
|
||||
if origin_type in (int, float, str, bool, complex, bytes, Any):
|
||||
continue
|
||||
# 允许嵌套的ConfigBase自定义类
|
||||
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore
|
||||
if isinstance(origin_type, type) and issubclass(cast(type, origin_type), ConfigBase):
|
||||
continue
|
||||
# 只允许 list, set, dict 三类泛型
|
||||
if origin_type not in (list, set, dict, List, Set, Dict, Literal):
|
||||
|
||||
@@ -5,25 +5,73 @@ from .config_base import ConfigBase, Field
|
||||
class APIProvider(ConfigBase):
|
||||
"""API提供商配置类"""
|
||||
|
||||
name: str = ""
|
||||
name: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "tag",
|
||||
},
|
||||
)
|
||||
"""API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)"""
|
||||
|
||||
base_url: str = ""
|
||||
base_url: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "link",
|
||||
},
|
||||
)
|
||||
"""API服务商的BaseURL"""
|
||||
|
||||
api_key: str = Field(default_factory=str, repr=False)
|
||||
api_key: str = Field(
|
||||
default_factory=str,
|
||||
repr=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "key",
|
||||
},
|
||||
)
|
||||
"""API密钥"""
|
||||
|
||||
client_type: str = Field(default="openai")
|
||||
client_type: str = Field(
|
||||
default="openai",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "settings",
|
||||
},
|
||||
)
|
||||
"""客户端类型 (可选: openai/google, 默认为openai)"""
|
||||
|
||||
max_retry: int = Field(default=2)
|
||||
max_retry: int = Field(
|
||||
default=2,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "repeat",
|
||||
},
|
||||
)
|
||||
"""最大重试次数 (单个模型API调用失败, 最多重试的次数)"""
|
||||
|
||||
timeout: int = 10
|
||||
timeout: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "clock",
|
||||
"step": 1,
|
||||
},
|
||||
)
|
||||
"""API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)"""
|
||||
|
||||
retry_interval: int = 10
|
||||
retry_interval: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "timer",
|
||||
"step": 1,
|
||||
},
|
||||
)
|
||||
"""重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
|
||||
|
||||
def model_post_init(self, context: Any = None):
|
||||
@@ -39,34 +87,93 @@ class APIProvider(ConfigBase):
|
||||
|
||||
class ModelInfo(ConfigBase):
|
||||
"""单个模型信息配置类"""
|
||||
|
||||
_validate_any: bool = False
|
||||
suppress_any_warning: bool = True
|
||||
|
||||
model_identifier: str = ""
|
||||
model_identifier: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "package",
|
||||
},
|
||||
)
|
||||
"""模型标识符 (API服务商提供的模型标识符)"""
|
||||
|
||||
name: str = ""
|
||||
name: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "tag",
|
||||
},
|
||||
)
|
||||
"""模型名称 (可随意命名, 在models中需使用这个命名)"""
|
||||
|
||||
api_provider: str = ""
|
||||
api_provider: str = Field(
|
||||
default="",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "link",
|
||||
},
|
||||
)
|
||||
"""API服务商名称 (对应在api_providers中配置的服务商名称)"""
|
||||
|
||||
price_in: float = Field(default=0.0)
|
||||
price_in: float = Field(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "dollar-sign",
|
||||
"step": 0.001,
|
||||
},
|
||||
)
|
||||
"""输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
|
||||
|
||||
price_out: float = Field(default=0.0)
|
||||
price_out: float = Field(
|
||||
default=0.0,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "dollar-sign",
|
||||
"step": 0.001,
|
||||
},
|
||||
)
|
||||
"""输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
|
||||
|
||||
temperature: float | None = Field(default=None)
|
||||
|
||||
temperature: float | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "thermometer",
|
||||
},
|
||||
)
|
||||
"""模型级别温度(可选),会覆盖任务配置中的温度"""
|
||||
|
||||
max_tokens: int | None = Field(default=None)
|
||||
max_tokens: int | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "layers",
|
||||
},
|
||||
)
|
||||
"""模型级别最大token数(可选),会覆盖任务配置中的max_tokens"""
|
||||
|
||||
force_stream_mode: bool = Field(default=False)
|
||||
force_stream_mode: bool = Field(
|
||||
default=False,
|
||||
json_schema_extra={
|
||||
"x-widget": "switch",
|
||||
"x-icon": "zap",
|
||||
},
|
||||
)
|
||||
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
|
||||
|
||||
extra_params: dict[str, Any] = Field(default_factory=dict)
|
||||
extra_params: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "sliders",
|
||||
},
|
||||
)
|
||||
"""额外参数 (用于API调用时的额外配置)"""
|
||||
|
||||
def model_post_init(self, context: Any = None):
|
||||
@@ -82,48 +189,139 @@ class ModelInfo(ConfigBase):
|
||||
class TaskConfig(ConfigBase):
|
||||
"""任务配置类"""
|
||||
|
||||
model_list: list[str] = Field(default_factory=list)
|
||||
model_list: list[str] = Field(
|
||||
default_factory=list,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "list",
|
||||
},
|
||||
)
|
||||
"""使用的模型列表, 每个元素对应上面的模型名称(name)"""
|
||||
|
||||
max_tokens: int = 1024
|
||||
max_tokens: int = Field(
|
||||
default=1024,
|
||||
ge=1,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "layers",
|
||||
"step": 1,
|
||||
},
|
||||
)
|
||||
"""任务最大输出token数"""
|
||||
|
||||
temperature: float = 0.3
|
||||
temperature: float = Field(
|
||||
default=0.3,
|
||||
ge=0,
|
||||
le=2,
|
||||
json_schema_extra={
|
||||
"x-widget": "slider",
|
||||
"x-icon": "thermometer",
|
||||
"step": 0.1,
|
||||
},
|
||||
)
|
||||
"""模型温度"""
|
||||
|
||||
slow_threshold: float = 15.0
|
||||
|
||||
slow_threshold: float = Field(
|
||||
default=15.0,
|
||||
ge=0,
|
||||
json_schema_extra={
|
||||
"x-widget": "input",
|
||||
"x-icon": "alert-circle",
|
||||
"step": 0.1,
|
||||
},
|
||||
)
|
||||
"""慢请求阈值(秒),超过此值会输出警告日志"""
|
||||
|
||||
selection_strategy: str = Field(default="balance")
|
||||
selection_strategy: str = Field(
|
||||
default="balance",
|
||||
json_schema_extra={
|
||||
"x-widget": "select",
|
||||
"x-icon": "shuffle",
|
||||
},
|
||||
)
|
||||
"""模型选择策略:balance(负载均衡)或 random(随机选择)"""
|
||||
|
||||
|
||||
class ModelTaskConfig(ConfigBase):
|
||||
"""模型配置类"""
|
||||
|
||||
utils: TaskConfig = Field(default_factory=TaskConfig)
|
||||
utils: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "wrench",
|
||||
},
|
||||
)
|
||||
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
|
||||
|
||||
replyer: TaskConfig = Field(default_factory=TaskConfig)
|
||||
replyer: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "message-square",
|
||||
},
|
||||
)
|
||||
"""首要回复模型配置, 还用于表达器和表达方式学习"""
|
||||
|
||||
vlm: TaskConfig = Field(default_factory=TaskConfig)
|
||||
vlm: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "image",
|
||||
},
|
||||
)
|
||||
"""视觉模型配置"""
|
||||
|
||||
voice: TaskConfig = Field(default_factory=TaskConfig)
|
||||
voice: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "volume-2",
|
||||
},
|
||||
)
|
||||
"""语音识别模型配置"""
|
||||
|
||||
tool_use: TaskConfig = Field(default_factory=TaskConfig)
|
||||
tool_use: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "tools",
|
||||
},
|
||||
)
|
||||
"""工具使用模型配置, 需要使用支持工具调用的模型"""
|
||||
|
||||
planner: TaskConfig = Field(default_factory=TaskConfig)
|
||||
planner: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "map",
|
||||
},
|
||||
)
|
||||
"""规划模型配置"""
|
||||
|
||||
embedding: TaskConfig = Field(default_factory=TaskConfig)
|
||||
embedding: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "database",
|
||||
},
|
||||
)
|
||||
"""嵌入模型配置"""
|
||||
|
||||
lpmm_entity_extract: TaskConfig = Field(default_factory=TaskConfig)
|
||||
lpmm_entity_extract: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "filter",
|
||||
},
|
||||
)
|
||||
"""LPMM实体提取模型配置"""
|
||||
|
||||
lpmm_rdf_build: TaskConfig = Field(default_factory=TaskConfig)
|
||||
lpmm_rdf_build: TaskConfig = Field(
|
||||
default_factory=TaskConfig,
|
||||
json_schema_extra={
|
||||
"x-widget": "custom",
|
||||
"x-icon": "network",
|
||||
},
|
||||
)
|
||||
"""LPMM RDF构建模型配置"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
301
src/webui/api/planner.py
Normal file
301
src/webui/api/planner.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
规划器监控API
|
||||
提供规划器日志数据的查询接口
|
||||
|
||||
性能优化:
|
||||
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/planner", tags=["planner"])
|
||||
|
||||
# 规划器日志目录
|
||||
PLAN_LOG_DIR = Path("logs/plan")
|
||||
|
||||
|
||||
class ChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
chat_id: str
|
||||
plan_count: int
|
||||
latest_timestamp: float
|
||||
latest_filename: str
|
||||
|
||||
|
||||
class PlanLogSummary(BaseModel):
|
||||
"""规划日志摘要"""
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
action_count: int
|
||||
action_types: List[str] # 动作类型列表
|
||||
total_plan_ms: float
|
||||
llm_duration_ms: float
|
||||
reasoning_preview: str
|
||||
|
||||
|
||||
class PlanLogDetail(BaseModel):
|
||||
"""规划日志详情"""
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
prompt: str
|
||||
reasoning: str
|
||||
raw_output: str
|
||||
actions: List[Dict]
|
||||
timing: Dict
|
||||
extra: Optional[Dict] = None
|
||||
|
||||
|
||||
class PlannerOverview(BaseModel):
|
||||
"""规划器总览 - 轻量级统计"""
|
||||
total_chats: int
|
||||
total_plans: int
|
||||
chats: List[ChatSummary]
|
||||
|
||||
|
||||
class PaginatedChatLogs(BaseModel):
|
||||
"""分页的聊天日志列表"""
|
||||
data: List[PlanLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
chat_id: str
|
||||
|
||||
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
return 0
|
||||
|
||||
|
||||
@router.get("/overview", response_model=PlannerOverview)
|
||||
async def get_planner_overview():
|
||||
"""
|
||||
获取规划器总览 - 轻量级接口
|
||||
只统计文件数量,不读取文件内容
|
||||
"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
|
||||
|
||||
chats = []
|
||||
total_plans = 0
|
||||
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
plan_count = len(json_files)
|
||||
total_plans += plan_count
|
||||
|
||||
if plan_count == 0:
|
||||
continue
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return PlannerOverview(
|
||||
total_chats=len(chats),
|
||||
total_plans=total_plans,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
||||
async def get_chat_plan_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
):
|
||||
"""
|
||||
获取指定聊天的规划日志列表(分页)
|
||||
需要读取文件内容获取摘要信息
|
||||
支持搜索提示词内容
|
||||
"""
|
||||
chat_dir = PLAN_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedChatLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
actions = data.get('actions', [])
|
||||
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
|
||||
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else ''
|
||||
))
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedChatLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
||||
async def get_log_detail(chat_id: str, filename: str):
|
||||
"""获取规划日志详情 - 按需加载完整内容"""
|
||||
log_file = PLAN_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return PlanLogDetail(**data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
|
||||
|
||||
# ========== 兼容旧接口 ==========
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_planner_stats():
|
||||
"""获取规划器统计信息 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
|
||||
# 获取最近10条计划的摘要
|
||||
recent_plans = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
try:
|
||||
chat_logs = await get_chat_plan_logs(chat.chat_id, page=1, page_size=2)
|
||||
recent_plans.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_plans = recent_plans[:10]
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_plans": overview.total_plans,
|
||||
"avg_plan_time_ms": 0,
|
||||
"avg_llm_time_ms": 0,
|
||||
"recent_plans": recent_plans
|
||||
}
|
||||
|
||||
|
||||
@router.get("/chats")
|
||||
async def get_chat_list():
|
||||
"""获取所有聊天ID列表 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
|
||||
|
||||
@router.get("/all-logs")
|
||||
async def get_all_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100)
|
||||
):
|
||||
"""获取所有规划日志 - 兼容旧接口"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
||||
# 收集所有文件
|
||||
all_files = []
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if chat_dir.is_dir():
|
||||
for log_file in chat_dir.glob("*.json"):
|
||||
all_files.append((chat_dir.name, log_file))
|
||||
|
||||
# 按时间戳排序
|
||||
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
||||
|
||||
total = len(all_files)
|
||||
offset = (page - 1) * page_size
|
||||
page_files = all_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for chat_id, log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
logs.append({
|
||||
"chat_id": data.get('chat_id', chat_id),
|
||||
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get('actions', [])),
|
||||
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
|
||||
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else ''
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {"data": logs, "total": total, "page": page, "page_size": page_size}
|
||||
269
src/webui/api/replier.py
Normal file
269
src/webui/api/replier.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
回复器监控API
|
||||
提供回复器日志数据的查询接口
|
||||
|
||||
性能优化:
|
||||
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/replier", tags=["replier"])
|
||||
|
||||
# 回复器日志目录
|
||||
REPLY_LOG_DIR = Path("logs/reply")
|
||||
|
||||
|
||||
class ReplierChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
chat_id: str
|
||||
reply_count: int
|
||||
latest_timestamp: float
|
||||
latest_filename: str
|
||||
|
||||
|
||||
class ReplyLogSummary(BaseModel):
|
||||
"""回复日志摘要"""
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
model: str
|
||||
success: bool
|
||||
llm_ms: float
|
||||
overall_ms: float
|
||||
output_preview: str
|
||||
|
||||
|
||||
class ReplyLogDetail(BaseModel):
|
||||
"""回复日志详情"""
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
prompt: str
|
||||
output: str
|
||||
processed_output: List[str]
|
||||
model: str
|
||||
reasoning: str
|
||||
think_level: int
|
||||
timing: Dict
|
||||
error: Optional[str] = None
|
||||
success: bool
|
||||
|
||||
|
||||
class ReplierOverview(BaseModel):
|
||||
"""回复器总览 - 轻量级统计"""
|
||||
total_chats: int
|
||||
total_replies: int
|
||||
chats: List[ReplierChatSummary]
|
||||
|
||||
|
||||
class PaginatedReplyLogs(BaseModel):
|
||||
"""分页的回复日志列表"""
|
||||
data: List[ReplyLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
chat_id: str
|
||||
|
||||
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
return 0
|
||||
|
||||
|
||||
@router.get("/overview", response_model=ReplierOverview)
|
||||
async def get_replier_overview():
|
||||
"""
|
||||
获取回复器总览 - 轻量级接口
|
||||
只统计文件数量,不读取文件内容
|
||||
"""
|
||||
if not REPLY_LOG_DIR.exists():
|
||||
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
|
||||
|
||||
chats = []
|
||||
total_replies = 0
|
||||
|
||||
for chat_dir in REPLY_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
reply_count = len(json_files)
|
||||
total_replies += reply_count
|
||||
|
||||
if reply_count == 0:
|
||||
continue
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return ReplierOverview(
|
||||
total_chats=len(chats),
|
||||
total_replies=total_replies,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
||||
async def get_chat_reply_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
):
|
||||
"""
|
||||
获取指定聊天的回复日志列表(分页)
|
||||
需要读取文件内容获取摘要信息
|
||||
支持搜索提示词内容
|
||||
"""
|
||||
chat_dir = REPLY_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedReplyLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
output = data.get('output', '')
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get('model', ''),
|
||||
success=data.get('success', True),
|
||||
llm_ms=data.get('timing', {}).get('llm_ms', 0),
|
||||
overall_ms=data.get('timing', {}).get('overall_ms', 0),
|
||||
output_preview=output[:100] if output else ''
|
||||
))
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model='',
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedReplyLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
||||
async def get_reply_log_detail(chat_id: str, filename: str):
|
||||
"""获取回复日志详情 - 按需加载完整内容"""
|
||||
log_file = REPLY_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return ReplyLogDetail(
|
||||
type=data.get('type', 'reply'),
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', 0),
|
||||
prompt=data.get('prompt', ''),
|
||||
output=data.get('output', ''),
|
||||
processed_output=data.get('processed_output', []),
|
||||
model=data.get('model', ''),
|
||||
reasoning=data.get('reasoning', ''),
|
||||
think_level=data.get('think_level', 0),
|
||||
timing=data.get('timing', {}),
|
||||
error=data.get('error'),
|
||||
success=data.get('success', True)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
|
||||
|
||||
# ========== 兼容接口 ==========
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_replier_stats():
|
||||
"""获取回复器统计信息"""
|
||||
overview = await get_replier_overview()
|
||||
|
||||
# 获取最近10条回复的摘要
|
||||
recent_replies = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
try:
|
||||
chat_logs = await get_chat_reply_logs(chat.chat_id, page=1, page_size=2)
|
||||
recent_replies.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_replies = recent_replies[:10]
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_replies": overview.total_replies,
|
||||
"recent_replies": recent_replies
|
||||
}
|
||||
|
||||
|
||||
@router.get("/chats")
|
||||
async def get_replier_chat_list():
|
||||
"""获取所有聊天ID列表"""
|
||||
overview = await get_replier_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
133
src/webui/config_schema.py
Normal file
133
src/webui/config_schema.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import inspect
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
|
||||
class ConfigSchemaGenerator:
|
||||
@classmethod
|
||||
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
|
||||
return cls.generate_config_schema(config_class, include_nested=include_nested)
|
||||
|
||||
@classmethod
|
||||
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
|
||||
fields: list[dict[str, Any]] = []
|
||||
nested: dict[str, dict[str, Any]] = {}
|
||||
|
||||
for field_name, field_info in config_class.model_fields.items():
|
||||
if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}:
|
||||
continue
|
||||
|
||||
field_schema = cls._build_field_schema(config_class, field_name, field_info.annotation, field_info)
|
||||
fields.append(field_schema)
|
||||
|
||||
if include_nested:
|
||||
nested_schema = cls._build_nested_schema(field_info.annotation)
|
||||
if nested_schema is not None:
|
||||
nested[field_name] = nested_schema
|
||||
|
||||
return {
|
||||
"className": config_class.__name__,
|
||||
"classDoc": (config_class.__doc__ or "").strip(),
|
||||
"fields": fields,
|
||||
"nested": nested,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_nested_schema(cls, annotation: Any) -> dict[str, Any] | None:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
|
||||
return cls.generate_config_schema(annotation)
|
||||
|
||||
if origin in {list, tuple} and args:
|
||||
first = args[0]
|
||||
if inspect.isclass(first) and issubclass(first, ConfigBase):
|
||||
return cls.generate_config_schema(first)
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _build_field_schema(
|
||||
cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any
|
||||
) -> dict[str, Any]:
|
||||
field_docs = config_class.get_class_field_docs()
|
||||
field_type = cls._map_field_type(annotation)
|
||||
schema: dict[str, Any] = {
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
"label": field_name,
|
||||
"description": field_docs.get(field_name, field_info.description or ""),
|
||||
"required": field_info.is_required(),
|
||||
}
|
||||
|
||||
if field_info.default is not PydanticUndefined:
|
||||
schema["default"] = field_info.default
|
||||
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin is list and args:
|
||||
schema["items"] = {"type": cls._map_field_type(args[0])}
|
||||
|
||||
options = cls._extract_options(annotation)
|
||||
if options:
|
||||
schema["options"] = options
|
||||
|
||||
# Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.)
|
||||
if hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra:
|
||||
schema.update(field_info.json_schema_extra)
|
||||
|
||||
# Task 1d: Map Pydantic constraints to minValue/maxValue (frontend naming convention)
|
||||
if hasattr(field_info, "metadata") and field_info.metadata:
|
||||
for constraint in field_info.metadata:
|
||||
if hasattr(constraint, "ge"):
|
||||
schema["minValue"] = constraint.ge
|
||||
if hasattr(constraint, "le"):
|
||||
schema["maxValue"] = constraint.le
|
||||
|
||||
return schema
|
||||
|
||||
@staticmethod
|
||||
def _extract_options(annotation: Any) -> list[str] | None:
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return None
|
||||
if str(origin) != "typing.Literal":
|
||||
return None
|
||||
|
||||
args = get_args(annotation)
|
||||
options = [str(item) for item in args]
|
||||
return options or None
|
||||
|
||||
@classmethod
|
||||
def _map_field_type(cls, annotation: Any) -> str:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
if origin in {list, tuple}:
|
||||
return "array"
|
||||
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
|
||||
return "object"
|
||||
if annotation is bool:
|
||||
return "boolean"
|
||||
if annotation is int:
|
||||
return "integer"
|
||||
if annotation is float:
|
||||
return "number"
|
||||
if annotation is str:
|
||||
return "string"
|
||||
|
||||
if origin in {list, tuple} and args:
|
||||
return "array"
|
||||
|
||||
if origin in {dict}:
|
||||
return "object"
|
||||
|
||||
if origin is not None and str(origin) == "typing.Literal":
|
||||
return "select"
|
||||
|
||||
return "string"
|
||||
@@ -10,7 +10,7 @@ from typing import Any, Annotated, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
@@ -77,7 +77,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型配置架构失败: {e}")
|
||||
@@ -227,7 +227,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
ModelConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
@@ -377,7 +377,7 @@ async def update_model_config_section(
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
ModelConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
|
||||
@@ -1,21 +1,27 @@
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, List, Optional
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import threading
|
||||
|
||||
from fastapi import APIRouter, Cookie, File, Form, Header, HTTPException, Query, UploadFile
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Annotated
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Emoji
|
||||
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
import os
|
||||
import hashlib
|
||||
from PIL import Image
|
||||
import io
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Images, ImageType
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.emoji")
|
||||
|
||||
@@ -61,7 +67,7 @@ def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
|
||||
def _ensure_thumbnail_cache_dir() -> Path:
|
||||
"""确保缩略图缓存目录存在"""
|
||||
THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
_ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return THUMBNAIL_CACHE_DIR
|
||||
|
||||
|
||||
@@ -99,7 +105,7 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
# GIF 处理:提取第一帧
|
||||
if hasattr(img, "n_frames") and img.n_frames > 1:
|
||||
if getattr(img, "n_frames", 1) > 1:
|
||||
img.seek(0) # 确保在第一帧
|
||||
|
||||
# 转换为 RGB/RGBA(WebP 支持透明度)
|
||||
@@ -138,9 +144,9 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
return 0, 0
|
||||
|
||||
# 获取所有表情包的哈希值
|
||||
valid_hashes = set()
|
||||
for emoji in Emoji.select(Emoji.emoji_hash):
|
||||
valid_hashes.add(emoji.emoji_hash)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images.image_hash).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
valid_hashes = set(session.exec(statement).all())
|
||||
|
||||
cleaned = 0
|
||||
kept = 0
|
||||
@@ -179,7 +185,6 @@ class EmojiResponse(BaseModel):
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
format: str
|
||||
emoji_hash: str
|
||||
description: str
|
||||
query_count: int
|
||||
@@ -188,7 +193,6 @@ class EmojiResponse(BaseModel):
|
||||
emotion: Optional[str] # 直接返回字符串
|
||||
record_time: float
|
||||
register_time: Optional[float]
|
||||
usage_count: int
|
||||
last_used_time: Optional[float]
|
||||
|
||||
|
||||
@@ -257,22 +261,19 @@ def verify_auth_token(
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
||||
"""将 Emoji 模型转换为响应对象"""
|
||||
def emoji_to_response(image: Images) -> EmojiResponse:
|
||||
return EmojiResponse(
|
||||
id=emoji.id,
|
||||
full_path=emoji.full_path,
|
||||
format=emoji.format,
|
||||
emoji_hash=emoji.emoji_hash,
|
||||
description=emoji.description,
|
||||
query_count=emoji.query_count,
|
||||
is_registered=emoji.is_registered,
|
||||
is_banned=emoji.is_banned,
|
||||
emotion=str(emoji.emotion) if emoji.emotion is not None else None,
|
||||
record_time=emoji.record_time,
|
||||
register_time=emoji.register_time,
|
||||
usage_count=emoji.usage_count,
|
||||
last_used_time=emoji.last_used_time,
|
||||
id=image.id if image.id is not None else 0,
|
||||
full_path=image.full_path,
|
||||
emoji_hash=image.image_hash,
|
||||
description=image.description,
|
||||
query_count=image.query_count,
|
||||
is_registered=image.is_registered,
|
||||
is_banned=image.is_banned,
|
||||
emotion=image.emotion,
|
||||
record_time=image.record_time.timestamp() if image.record_time else 0.0,
|
||||
register_time=image.register_time.timestamp() if image.register_time else None,
|
||||
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
|
||||
)
|
||||
|
||||
|
||||
@@ -283,8 +284,7 @@ async def get_emoji_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
|
||||
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
|
||||
format: Optional[str] = Query(None, description="格式筛选"),
|
||||
sort_by: Optional[str] = Query("usage_count", description="排序字段"),
|
||||
sort_by: Optional[str] = Query("query_count", description="排序字段"),
|
||||
sort_order: Optional[str] = Query("desc", description="排序方向"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
@@ -298,8 +298,7 @@ async def get_emoji_list(
|
||||
search: 搜索关键词 (匹配 description, emoji_hash)
|
||||
is_registered: 是否已注册筛选
|
||||
is_banned: 是否被禁用筛选
|
||||
format: 格式筛选
|
||||
sort_by: 排序字段 (usage_count, register_time, record_time, last_used_time)
|
||||
sort_by: 排序字段 (query_count, register_time, record_time, last_used_time)
|
||||
sort_order: 排序方向 (asc, desc)
|
||||
authorization: Authorization header
|
||||
|
||||
@@ -310,47 +309,58 @@ async def get_emoji_list(
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Emoji.select()
|
||||
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
|
||||
statement = statement.where(
|
||||
(col(Images.description).contains(search)) | (col(Images.image_hash).contains(search))
|
||||
)
|
||||
|
||||
# 注册状态过滤
|
||||
if is_registered is not None:
|
||||
query = query.where(Emoji.is_registered == is_registered)
|
||||
statement = statement.where(col(Images.is_registered) == is_registered)
|
||||
|
||||
# 禁用状态过滤
|
||||
if is_banned is not None:
|
||||
query = query.where(Emoji.is_banned == is_banned)
|
||||
|
||||
# 格式过滤
|
||||
if format:
|
||||
query = query.where(Emoji.format == format)
|
||||
statement = statement.where(col(Images.is_banned) == is_banned)
|
||||
|
||||
# 排序字段映射
|
||||
sort_field_map = {
|
||||
"usage_count": Emoji.usage_count,
|
||||
"register_time": Emoji.register_time,
|
||||
"record_time": Emoji.record_time,
|
||||
"last_used_time": Emoji.last_used_time,
|
||||
"usage_count": col(Images.query_count),
|
||||
"query_count": col(Images.query_count),
|
||||
"register_time": col(Images.register_time),
|
||||
"record_time": col(Images.record_time),
|
||||
"last_used_time": col(Images.last_used_time),
|
||||
}
|
||||
|
||||
# 获取排序字段,默认使用 usage_count
|
||||
sort_field = sort_field_map.get(sort_by, Emoji.usage_count)
|
||||
sort_key = sort_by or "query_count"
|
||||
sort_field = sort_field_map.get(sort_key, col(Images.query_count))
|
||||
|
||||
# 应用排序
|
||||
if sort_order == "asc":
|
||||
query = query.order_by(sort_field.asc())
|
||||
statement = statement.order_by(sort_field.asc())
|
||||
else:
|
||||
query = query.order_by(sort_field.desc())
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
statement = statement.order_by(sort_field.desc())
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
emojis = query.offset(offset).limit(page_size)
|
||||
statement = statement.offset(offset).limit(page_size)
|
||||
|
||||
with get_db_session() as session:
|
||||
emojis = session.exec(statement).all()
|
||||
|
||||
count_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
if search:
|
||||
count_statement = count_statement.where(
|
||||
(col(Images.description).contains(search)) | (col(Images.image_hash).contains(search))
|
||||
)
|
||||
if is_registered is not None:
|
||||
count_statement = count_statement.where(col(Images.is_registered) == is_registered)
|
||||
if is_banned is not None:
|
||||
count_statement = count_statement.where(col(Images.is_banned) == is_banned)
|
||||
total = session.exec(count_statement).one()
|
||||
|
||||
# 转换为响应对象
|
||||
data = [emoji_to_response(emoji) for emoji in emojis]
|
||||
@@ -381,12 +391,17 @@ async def get_emoji_detail(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
col(Images.id) == emoji_id,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
emoji = session.exec(statement).first()
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
|
||||
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -416,34 +431,37 @@ async def update_emoji(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
col(Images.id) == emoji_id,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
emoji = session.exec(statement).first()
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# emotion 字段直接使用字符串,无需转换
|
||||
# 如果注册状态从 False 变为 True,记录注册时间
|
||||
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||
update_data["register_time"] = datetime.now()
|
||||
|
||||
# 如果注册状态从 False 变为 True,记录注册时间
|
||||
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
|
||||
update_data["register_time"] = time.time()
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(emoji, field, value)
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(emoji, field, value)
|
||||
session.add(emoji)
|
||||
|
||||
emoji.save()
|
||||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return EmojiUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
|
||||
)
|
||||
return EmojiUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -469,20 +487,22 @@ async def delete_emoji(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
col(Images.id) == emoji_id,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
emoji = session.exec(statement).first()
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
# 记录删除信息
|
||||
emoji_hash = emoji.emoji_hash
|
||||
emoji_hash = emoji.image_hash
|
||||
session.delete(emoji)
|
||||
|
||||
# 执行删除
|
||||
emoji.delete_instance()
|
||||
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
||||
|
||||
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
|
||||
|
||||
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
|
||||
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -505,27 +525,51 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Emoji.select().count()
|
||||
registered = Emoji.select().where(Emoji.is_registered).count()
|
||||
banned = Emoji.select().where(Emoji.is_banned).count()
|
||||
with get_db_session() as session:
|
||||
total_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
registered_statement = (
|
||||
select(func.count())
|
||||
.select_from(Images)
|
||||
.where(
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
col(Images.is_registered) == True,
|
||||
)
|
||||
)
|
||||
banned_statement = (
|
||||
select(func.count())
|
||||
.select_from(Images)
|
||||
.where(
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
col(Images.is_banned) == True,
|
||||
)
|
||||
)
|
||||
|
||||
# 按格式统计
|
||||
formats = {}
|
||||
for emoji in Emoji.select(Emoji.format):
|
||||
fmt = emoji.format
|
||||
formats[fmt] = formats.get(fmt, 0) + 1
|
||||
total = session.exec(total_statement).one()
|
||||
registered = session.exec(registered_statement).one()
|
||||
banned = session.exec(banned_statement).one()
|
||||
|
||||
# 获取最常用的表情包(前10)
|
||||
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
|
||||
top_used_list = [
|
||||
{
|
||||
"id": emoji.id,
|
||||
"emoji_hash": emoji.emoji_hash,
|
||||
"description": emoji.description,
|
||||
"usage_count": emoji.usage_count,
|
||||
}
|
||||
for emoji in top_used
|
||||
]
|
||||
formats: dict[str, int] = {}
|
||||
format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
for full_path in session.exec(format_statement).all():
|
||||
suffix = Path(full_path).suffix.lower().lstrip(".")
|
||||
fmt = suffix or "unknown"
|
||||
formats[fmt] = formats.get(fmt, 0) + 1
|
||||
|
||||
top_used_statement = (
|
||||
select(Images)
|
||||
.where(col(Images.image_type) == ImageType.EMOJI)
|
||||
.order_by(col(Images.query_count).desc())
|
||||
.limit(10)
|
||||
)
|
||||
top_used_list = [
|
||||
{
|
||||
"id": emoji.id,
|
||||
"emoji_hash": emoji.image_hash,
|
||||
"description": emoji.description,
|
||||
"usage_count": emoji.query_count,
|
||||
}
|
||||
for emoji in session.exec(top_used_statement).all()
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
@@ -563,23 +607,27 @@ async def register_emoji(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
col(Images.id) == emoji_id,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
emoji = session.exec(statement).first()
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
if emoji.is_registered:
|
||||
raise HTTPException(status_code=400, detail="该表情包已经注册")
|
||||
if emoji.is_registered:
|
||||
raise HTTPException(status_code=400, detail="该表情包已经注册")
|
||||
|
||||
# 注册表情包(如果已封禁,自动解除封禁)
|
||||
emoji.is_registered = True
|
||||
emoji.is_banned = False # 注册时自动解除封禁
|
||||
emoji.register_time = time.time()
|
||||
emoji.save()
|
||||
emoji.is_registered = True
|
||||
emoji.is_banned = False
|
||||
emoji.register_time = datetime.now()
|
||||
session.add(emoji)
|
||||
|
||||
logger.info(f"表情包已注册: ID={emoji_id}")
|
||||
logger.info(f"表情包已注册: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
|
||||
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -605,19 +653,23 @@ async def ban_emoji(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
col(Images.id) == emoji_id,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
emoji = session.exec(statement).first()
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
# 禁用表情包(同时取消注册)
|
||||
emoji.is_banned = True
|
||||
emoji.is_registered = False
|
||||
emoji.save()
|
||||
emoji.is_banned = True
|
||||
emoji.is_registered = False
|
||||
session.add(emoji)
|
||||
|
||||
logger.info(f"表情包已禁用: ID={emoji_id}")
|
||||
logger.info(f"表情包已禁用: ID={emoji_id}")
|
||||
|
||||
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
|
||||
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -672,61 +724,58 @@ async def get_emoji_thumbnail(
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(emoji.full_path):
|
||||
raise HTTPException(status_code=404, detail="表情包文件不存在")
|
||||
|
||||
# 如果请求原图,直接返回原文件
|
||||
if original:
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
return FileResponse(
|
||||
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
col(Images.id) == emoji_id,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
emoji = session.exec(statement).first()
|
||||
|
||||
# 尝试获取或生成缩略图
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
if not emoji:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
|
||||
|
||||
# 检查缓存是否存在
|
||||
if cache_path.exists():
|
||||
# 缓存命中,直接返回
|
||||
return FileResponse(
|
||||
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||
if not os.path.exists(emoji.full_path):
|
||||
raise HTTPException(status_code=404, detail="表情包文件不存在")
|
||||
|
||||
if original:
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
suffix = Path(emoji.full_path).suffix.lower().lstrip(".")
|
||||
media_type = mime_types.get(suffix, "application/octet-stream")
|
||||
return FileResponse(
|
||||
path=emoji.full_path, media_type=media_type, filename=f"{emoji.image_hash}.{suffix}"
|
||||
)
|
||||
|
||||
cache_path = _get_thumbnail_cache_path(emoji.image_hash)
|
||||
|
||||
if cache_path.exists():
|
||||
return FileResponse(
|
||||
path=str(cache_path), media_type="image/webp", filename=f"{emoji.image_hash}_thumb.webp"
|
||||
)
|
||||
|
||||
with _generating_lock:
|
||||
if emoji.image_hash not in _generating_thumbnails:
|
||||
_generating_thumbnails.add(emoji.image_hash)
|
||||
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.image_hash)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={
|
||||
"status": "generating",
|
||||
"message": "缩略图正在生成中,请稍后重试",
|
||||
"emoji_id": emoji_id,
|
||||
},
|
||||
headers={
|
||||
"Retry-After": "1",
|
||||
},
|
||||
)
|
||||
|
||||
# 缓存未命中,触发后台生成并返回 202
|
||||
with _generating_lock:
|
||||
if emoji.emoji_hash not in _generating_thumbnails:
|
||||
# 标记为正在生成
|
||||
_generating_thumbnails.add(emoji.emoji_hash)
|
||||
# 提交到线程池后台生成
|
||||
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
|
||||
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={
|
||||
"status": "generating",
|
||||
"message": "缩略图正在生成中,请稍后重试",
|
||||
"emoji_id": emoji_id,
|
||||
},
|
||||
headers={
|
||||
"Retry-After": "1", # 建议 1 秒后重试
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -762,14 +811,19 @@ async def batch_delete_emojis(
|
||||
|
||||
for emoji_id in request.emoji_ids:
|
||||
try:
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
if emoji:
|
||||
emoji.delete_instance()
|
||||
deleted_count += 1
|
||||
logger.info(f"批量删除表情包: {emoji_id}")
|
||||
else:
|
||||
failed_count += 1
|
||||
failed_ids.append(emoji_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(Images).where(
|
||||
col(Images.id) == emoji_id,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
emoji = session.exec(statement).first()
|
||||
if emoji:
|
||||
session.delete(emoji)
|
||||
deleted_count += 1
|
||||
logger.info(f"批量删除表情包: {emoji_id}")
|
||||
else:
|
||||
failed_count += 1
|
||||
failed_ids.append(emoji_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除表情包 {emoji_id} 失败: {e}")
|
||||
failed_count += 1
|
||||
@@ -864,19 +918,23 @@ async def upload_emoji(
|
||||
# 计算文件哈希
|
||||
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||
|
||||
# 检查是否已存在相同哈希的表情包
|
||||
existing_emoji = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if existing_emoji:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
|
||||
with get_db_session() as session:
|
||||
existing_statement = select(Images).where(
|
||||
col(Images.image_hash) == emoji_hash,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
existing_emoji = session.exec(existing_statement).first()
|
||||
if existing_emoji:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
|
||||
)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
timestamp = int(time.time())
|
||||
timestamp = int(datetime.now().timestamp())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
|
||||
@@ -889,37 +947,38 @@ async def upload_emoji(
|
||||
|
||||
# 保存文件
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
_ = f.write(file_content)
|
||||
|
||||
logger.info(f"表情包文件已保存: {full_path}")
|
||||
|
||||
# 处理情感标签
|
||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||
|
||||
# 创建数据库记录
|
||||
current_time = time.time()
|
||||
emoji = Emoji.create(
|
||||
full_path=full_path,
|
||||
format=img_format,
|
||||
emoji_hash=emoji_hash,
|
||||
description=description,
|
||||
emotion=emotion_str,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
usage_count=0,
|
||||
last_used_time=None,
|
||||
)
|
||||
current_time = datetime.now()
|
||||
with get_db_session() as session:
|
||||
emoji = Images(
|
||||
image_type=ImageType.EMOJI,
|
||||
full_path=full_path,
|
||||
image_hash=emoji_hash,
|
||||
description=description,
|
||||
emotion=emotion_str or None,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
last_used_time=None,
|
||||
)
|
||||
session.add(emoji)
|
||||
session.flush()
|
||||
|
||||
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
|
||||
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
|
||||
|
||||
return EmojiUploadResponse(
|
||||
success=True,
|
||||
message="表情包上传成功" + ("并已注册" if is_registered else ""),
|
||||
data=emoji_to_response(emoji),
|
||||
)
|
||||
return EmojiUploadResponse(
|
||||
success=True,
|
||||
message="表情包上传成功" + ("并已注册" if is_registered else ""),
|
||||
data=emoji_to_response(emoji),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -951,7 +1010,7 @@ async def batch_upload_emoji(
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
results = {
|
||||
results: dict[str, Any] = {
|
||||
"success": True,
|
||||
"total": len(files),
|
||||
"uploaded": 0,
|
||||
@@ -1008,20 +1067,24 @@ async def batch_upload_emoji(
|
||||
# 计算哈希
|
||||
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||
|
||||
# 检查重复
|
||||
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "已存在相同的表情包",
|
||||
}
|
||||
with get_db_session() as session:
|
||||
existing_statement = select(Images).where(
|
||||
col(Images.image_hash) == emoji_hash,
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
)
|
||||
continue
|
||||
if session.exec(existing_statement).first():
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "已存在相同的表情包",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成文件名并保存
|
||||
timestamp = int(time.time())
|
||||
timestamp = int(datetime.now().timestamp())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
|
||||
@@ -1032,36 +1095,37 @@ async def batch_upload_emoji(
|
||||
counter += 1
|
||||
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
_ = f.write(file_content)
|
||||
|
||||
# 处理情感标签
|
||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||
|
||||
# 创建数据库记录
|
||||
current_time = time.time()
|
||||
emoji = Emoji.create(
|
||||
full_path=full_path,
|
||||
format=img_format,
|
||||
emoji_hash=emoji_hash,
|
||||
description="", # 批量上传暂不设置描述
|
||||
emotion=emotion_str,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
usage_count=0,
|
||||
last_used_time=None,
|
||||
)
|
||||
current_time = datetime.now()
|
||||
with get_db_session() as session:
|
||||
emoji = Images(
|
||||
image_type=ImageType.EMOJI,
|
||||
full_path=full_path,
|
||||
image_hash=emoji_hash,
|
||||
description="",
|
||||
emotion=emotion_str or None,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
last_used_time=None,
|
||||
)
|
||||
session.add(emoji)
|
||||
session.flush()
|
||||
|
||||
results["uploaded"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": True,
|
||||
"id": emoji.id,
|
||||
}
|
||||
)
|
||||
results["uploaded"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": True,
|
||||
"id": emoji.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
@@ -1138,8 +1202,9 @@ async def get_thumbnail_cache_stats(
|
||||
total_size = sum(f.stat().st_size for f in cache_files)
|
||||
total_size_mb = round(total_size / (1024 * 1024), 2)
|
||||
|
||||
# 统计表情包总数
|
||||
emoji_count = Emoji.select().count()
|
||||
with get_db_session() as session:
|
||||
count_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
|
||||
emoji_count = session.exec(count_statement).one()
|
||||
|
||||
# 计算覆盖率
|
||||
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
|
||||
@@ -1213,12 +1278,17 @@ async def preheat_thumbnail_cache(
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
# 获取使用次数最高的表情包(未缓存的优先)
|
||||
emojis = (
|
||||
Emoji.select()
|
||||
.where(Emoji.is_banned == False) # noqa: E712 Peewee ORM requires == for boolean comparison
|
||||
.order_by(Emoji.usage_count.desc())
|
||||
.limit(limit * 2) # 多查一些,因为有些可能已缓存
|
||||
)
|
||||
with get_db_session() as session:
|
||||
statement = (
|
||||
select(Images)
|
||||
.where(
|
||||
col(Images.image_type) == ImageType.EMOJI,
|
||||
col(Images.is_banned) == False,
|
||||
)
|
||||
.order_by(col(Images.query_count).desc())
|
||||
.limit(limit * 2)
|
||||
)
|
||||
emojis = session.exec(statement).all()
|
||||
|
||||
generated = 0
|
||||
skipped = 0
|
||||
@@ -1228,25 +1298,22 @@ async def preheat_thumbnail_cache(
|
||||
if generated >= limit:
|
||||
break
|
||||
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
cache_path = _get_thumbnail_cache_path(emoji.image_hash)
|
||||
|
||||
# 已缓存,跳过
|
||||
if cache_path.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 原文件不存在,跳过
|
||||
if not os.path.exists(emoji.full_path):
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.image_hash)
|
||||
generated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
||||
logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
|
||||
failed += 1
|
||||
|
||||
return ThumbnailPreheatResponse(
|
||||
|
||||
@@ -65,9 +65,6 @@ class ExpressionUpdateRequest(BaseModel):
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
checked: Optional[bool] = None
|
||||
rejected: Optional[bool] = None
|
||||
require_unchecked: Optional[bool] = False # 用于人工审核时的冲突检测
|
||||
|
||||
|
||||
class ExpressionUpdateResponse(BaseModel):
|
||||
@@ -388,26 +385,16 @@ async def update_expression(
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 冲突检测:如果要求未检查状态,但已经被检查了
|
||||
if request.require_unchecked and getattr(expression, "checked", False):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"此表达方式已被{'AI自动' if getattr(expression, 'modified_by', None) == 'ai' else '人工'}检查,请刷新列表",
|
||||
)
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
# 移除 require_unchecked,它不是数据库字段
|
||||
update_data.pop("require_unchecked", None)
|
||||
# 映射 API 字段名到数据库字段名
|
||||
if "chat_id" in update_data:
|
||||
update_data["session_id"] = update_data.pop("chat_id")
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 如果更新了 checked 或 rejected,标记为用户修改
|
||||
if "checked" in update_data or "rejected" in update_data:
|
||||
update_data["modified_by"] = "user"
|
||||
|
||||
# 更新最后活跃时间
|
||||
update_data["last_active_time"] = datetime.now()
|
||||
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Annotated
|
||||
from typing import Annotated, Any, List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func as fn
|
||||
from sqlmodel import Session, col, delete, select
|
||||
|
||||
import json
|
||||
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession, Jargon
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon, ChatStreams
|
||||
|
||||
logger = get_logger("webui.jargon")
|
||||
|
||||
@@ -43,27 +46,26 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
return [chat_id_str]
|
||||
|
||||
|
||||
def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
尝试解析 JSON 并查询 ChatSession 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
if not stream_ids:
|
||||
return chat_id_str
|
||||
return chat_id_str[:20]
|
||||
|
||||
# 查询所有 stream_id 对应的名称
|
||||
names = []
|
||||
for stream_id in stream_ids:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream and chat_stream.group_name:
|
||||
names.append(chat_stream.group_name)
|
||||
else:
|
||||
# 如果没找到,显示截断的 stream_id
|
||||
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
|
||||
stream_id = stream_ids[0]
|
||||
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
|
||||
|
||||
return ", ".join(names) if names else chat_id_str
|
||||
if not chat_session:
|
||||
return stream_id[:20]
|
||||
|
||||
if chat_session.group_id:
|
||||
return str(chat_session.group_id)
|
||||
|
||||
return chat_session.session_id[:20]
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
@@ -79,7 +81,6 @@ class JargonResponse(BaseModel):
|
||||
chat_id: str
|
||||
stream_id: Optional[str] = None # 解析后的 stream_id,用于前端编辑时匹配
|
||||
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
|
||||
is_global: bool = False
|
||||
count: int = 0
|
||||
is_jargon: Optional[bool] = None
|
||||
is_complete: bool = False
|
||||
@@ -94,7 +95,7 @@ class JargonListResponse(BaseModel):
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[JargonResponse]
|
||||
data: List[dict[str, Any]]
|
||||
|
||||
|
||||
class JargonDetailResponse(BaseModel):
|
||||
@@ -111,7 +112,6 @@ class JargonCreateRequest(BaseModel):
|
||||
raw_content: Optional[str] = Field(None, description="原始内容")
|
||||
meaning: Optional[str] = Field(None, description="含义")
|
||||
chat_id: str = Field(..., description="聊天ID")
|
||||
is_global: bool = Field(False, description="是否全局")
|
||||
|
||||
|
||||
class JargonUpdateRequest(BaseModel):
|
||||
@@ -121,7 +121,6 @@ class JargonUpdateRequest(BaseModel):
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
is_global: Optional[bool] = None
|
||||
is_jargon: Optional[bool] = None
|
||||
|
||||
|
||||
@@ -159,7 +158,7 @@ class JargonStatsResponse(BaseModel):
|
||||
"""黑话统计响应"""
|
||||
|
||||
success: bool = True
|
||||
data: dict
|
||||
data: dict[str, Any]
|
||||
|
||||
|
||||
class ChatInfoResponse(BaseModel):
|
||||
@@ -181,27 +180,24 @@ class ChatListResponse(BaseModel):
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon) -> dict:
|
||||
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
# 解析 chat_id 获取显示名称和 stream_id
|
||||
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
|
||||
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
|
||||
stream_id = stream_ids[0] if stream_ids else None
|
||||
chat_id = jargon.session_id or ""
|
||||
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
|
||||
|
||||
return {
|
||||
"id": jargon.id,
|
||||
"content": jargon.content,
|
||||
"raw_content": jargon.raw_content,
|
||||
"meaning": jargon.meaning,
|
||||
"chat_id": jargon.chat_id,
|
||||
"stream_id": stream_id,
|
||||
"chat_id": chat_id,
|
||||
"stream_id": jargon.session_id,
|
||||
"chat_name": chat_name,
|
||||
"is_global": jargon.is_global,
|
||||
"count": jargon.count,
|
||||
"is_jargon": jargon.is_jargon,
|
||||
"is_complete": jargon.is_complete,
|
||||
"inference_with_context": jargon.inference_with_context,
|
||||
"inference_content_only": jargon.inference_content_only,
|
||||
"inference_content_only": jargon.inference_with_content_only,
|
||||
}
|
||||
|
||||
|
||||
@@ -215,49 +211,41 @@ async def get_jargon_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
|
||||
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
|
||||
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
|
||||
):
|
||||
"""获取黑话列表"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = Jargon.select()
|
||||
statement = select(Jargon)
|
||||
count_statement = select(fn.count()).select_from(Jargon)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Jargon.content.contains(search))
|
||||
| (Jargon.meaning.contains(search))
|
||||
| (Jargon.raw_content.contains(search))
|
||||
search_filter = (
|
||||
(col(Jargon.content).contains(search))
|
||||
| (col(Jargon.meaning).contains(search))
|
||||
| (col(Jargon.raw_content).contains(search))
|
||||
)
|
||||
statement = statement.where(search_filter)
|
||||
count_statement = count_statement.where(search_filter)
|
||||
|
||||
# 按聊天ID筛选(使用 contains 匹配,因为 chat_id 是 JSON 格式)
|
||||
if chat_id:
|
||||
# 从传入的 chat_id 中解析出 stream_id
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
# 使用第一个 stream_id 进行模糊匹配
|
||||
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
|
||||
chat_filter = col(Jargon.session_id).contains(stream_ids[0])
|
||||
else:
|
||||
# 如果无法解析,使用精确匹配
|
||||
query = query.where(Jargon.chat_id == chat_id)
|
||||
chat_filter = col(Jargon.session_id) == chat_id
|
||||
statement = statement.where(chat_filter)
|
||||
count_statement = count_statement.where(chat_filter)
|
||||
|
||||
# 按是否是黑话筛选
|
||||
if is_jargon is not None:
|
||||
query = query.where(Jargon.is_jargon == is_jargon)
|
||||
statement = statement.where(col(Jargon.is_jargon) == is_jargon)
|
||||
count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon)
|
||||
|
||||
# 按是否全局筛选
|
||||
if is_global is not None:
|
||||
query = query.where(Jargon.is_global == is_global)
|
||||
statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc())
|
||||
statement = statement.offset((page - 1) * page_size).limit(page_size)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页和排序(按使用次数降序)
|
||||
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
|
||||
query = query.paginate(page, page_size)
|
||||
|
||||
# 转换为响应格式
|
||||
data = [jargon_to_dict(j) for j in query]
|
||||
with get_db_session() as session:
|
||||
total = session.exec(count_statement).one()
|
||||
jargons = session.exec(statement).all()
|
||||
data = [jargon_to_dict(jargon, session) for jargon in jargons]
|
||||
|
||||
return JargonListResponse(
|
||||
success=True,
|
||||
@@ -276,10 +264,9 @@ async def get_jargon_list(
|
||||
async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
# 获取所有不同的 chat_id
|
||||
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||
|
||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||
with get_db_session() as session:
|
||||
statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None))
|
||||
chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id]
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
@@ -290,27 +277,28 @@ async def get_chat_list():
|
||||
seen_stream_ids.add(stream_ids[0])
|
||||
|
||||
result = []
|
||||
for stream_id in seen_stream_ids:
|
||||
# 尝试从 ChatStreams 表获取聊天名称
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id,方便筛选匹配
|
||||
chat_name=chat_stream.group_name or stream_id,
|
||||
platform=chat_stream.platform,
|
||||
is_group=True,
|
||||
with get_db_session() as session:
|
||||
for stream_id in seen_stream_ids:
|
||||
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
|
||||
if chat_session:
|
||||
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id,
|
||||
chat_name=chat_name,
|
||||
platform=chat_session.platform,
|
||||
is_group=bool(chat_session.group_id),
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id
|
||||
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
|
||||
platform=None,
|
||||
is_group=False,
|
||||
else:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id,
|
||||
chat_name=stream_id[:20],
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return ChatListResponse(success=True, data=result)
|
||||
|
||||
@@ -323,35 +311,35 @@ async def get_chat_list():
|
||||
async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
try:
|
||||
# 总数量
|
||||
total = Jargon.select().count()
|
||||
with get_db_session() as session:
|
||||
total = session.exec(select(fn.count()).select_from(Jargon)).one()
|
||||
|
||||
# 已确认是黑话的数量
|
||||
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
|
||||
confirmed_jargon = session.exec(
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True)
|
||||
).one()
|
||||
confirmed_not_jargon = session.exec(
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False)
|
||||
).one()
|
||||
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
|
||||
|
||||
# 已确认不是黑话的数量
|
||||
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
|
||||
complete_count = session.exec(
|
||||
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True)
|
||||
).one()
|
||||
|
||||
# 未判定的数量
|
||||
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
|
||||
chat_count = session.exec(
|
||||
select(fn.count()).select_from(
|
||||
select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery()
|
||||
)
|
||||
).one()
|
||||
|
||||
# 全局黑话数量
|
||||
global_count = Jargon.select().where(Jargon.is_global).count()
|
||||
|
||||
# 已完成推断的数量
|
||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||
|
||||
# 关联的聊天数量
|
||||
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||
|
||||
# 按聊天统计 TOP 5
|
||||
top_chats = (
|
||||
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
|
||||
.group_by(Jargon.chat_id)
|
||||
.order_by(fn.COUNT(Jargon.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
|
||||
top_chats = session.exec(
|
||||
select(col(Jargon.session_id), fn.count().label("count"))
|
||||
.where(col(Jargon.session_id).is_not(None))
|
||||
.group_by(col(Jargon.session_id))
|
||||
.order_by(fn.count().desc())
|
||||
.limit(5)
|
||||
).all()
|
||||
top_chats_dict = {session_id: count for session_id, count in top_chats if session_id}
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
@@ -360,7 +348,6 @@ async def get_jargon_stats():
|
||||
"confirmed_jargon": confirmed_jargon,
|
||||
"confirmed_not_jargon": confirmed_not_jargon,
|
||||
"pending": pending,
|
||||
"global_count": global_count,
|
||||
"complete_count": complete_count,
|
||||
"chat_count": chat_count,
|
||||
"top_chats": top_chats_dict,
|
||||
@@ -376,11 +363,13 @@ async def get_jargon_stats():
|
||||
async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
data = JargonResponse(**jargon_to_dict(jargon, session))
|
||||
|
||||
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
|
||||
return JargonDetailResponse(success=True, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -393,30 +382,31 @@ async def get_jargon_detail(jargon_id: int):
|
||||
async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
# 检查是否已存在相同内容的黑话
|
||||
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
with get_db_session() as session:
|
||||
existing = session.exec(
|
||||
select(Jargon).where(
|
||||
(col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id)
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
# 创建黑话
|
||||
jargon = Jargon.create(
|
||||
content=request.content,
|
||||
raw_content=request.raw_content,
|
||||
meaning=request.meaning,
|
||||
chat_id=request.chat_id,
|
||||
is_global=request.is_global,
|
||||
count=0,
|
||||
is_jargon=None,
|
||||
is_complete=False,
|
||||
)
|
||||
jargon = Jargon(
|
||||
content=request.content,
|
||||
raw_content=request.raw_content,
|
||||
meaning=request.meaning or "",
|
||||
session_id=request.chat_id,
|
||||
count=0,
|
||||
is_jargon=None,
|
||||
is_complete=False,
|
||||
)
|
||||
session.add(jargon)
|
||||
session.flush()
|
||||
|
||||
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
|
||||
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
|
||||
data = JargonResponse(**jargon_to_dict(jargon, session))
|
||||
|
||||
return JargonCreateResponse(
|
||||
success=True,
|
||||
message="创建成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
return JargonCreateResponse(success=True, message="创建成功", data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -429,25 +419,27 @@ async def create_jargon(request: JargonCreateRequest):
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
"""更新黑话(增量更新)"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
# 增量更新字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
for field, value in update_data.items():
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
jargon.save()
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
for field, value in update_data.items():
|
||||
if field == "is_global":
|
||||
continue
|
||||
if field == "chat_id":
|
||||
jargon.session_id = value
|
||||
continue
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
session.add(jargon)
|
||||
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
data = JargonResponse(**jargon_to_dict(jargon, session))
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message="更新成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
return JargonUpdateResponse(success=True, message="更新成功", data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -460,20 +452,17 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
async def delete_jargon(jargon_id: int):
|
||||
"""删除黑话"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
with get_db_session() as session:
|
||||
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
content = jargon.content
|
||||
jargon.delete_instance()
|
||||
content = jargon.content
|
||||
session.delete(jargon)
|
||||
|
||||
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
|
||||
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message="删除成功",
|
||||
deleted_count=1,
|
||||
)
|
||||
return JargonDeleteResponse(success=True, message="删除成功", deleted_count=1)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -489,9 +478,11 @@ async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
|
||||
with get_db_session() as session:
|
||||
result = session.exec(delete(Jargon).where(col(Jargon.id).in_(request.ids)))
|
||||
deleted_count = result.rowcount or 0
|
||||
|
||||
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
|
||||
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
@@ -516,14 +507,16 @@ async def batch_set_jargon_status(
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||
with get_db_session() as session:
|
||||
jargons = session.exec(select(Jargon).where(col(Jargon.id).in_(ids))).all()
|
||||
for jargon in jargons:
|
||||
jargon.is_jargon = is_jargon
|
||||
session.add(jargon)
|
||||
updated_count = len(jargons)
|
||||
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {updated_count} 条黑话状态",
|
||||
)
|
||||
return JargonUpdateResponse(success=True, message=f"成功更新 {updated_count} 条黑话状态")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.webui.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||
from src.webui.core import get_token_manager
|
||||
from src.webui.routers.websocket.plugin_progress import update_progress
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ class EmojiResponse(BaseModel):
|
||||
|
||||
id: int
|
||||
full_path: str
|
||||
format: str
|
||||
emoji_hash: str
|
||||
description: str
|
||||
query_count: int
|
||||
@@ -16,7 +15,6 @@ class EmojiResponse(BaseModel):
|
||||
emotion: Optional[str]
|
||||
record_time: float
|
||||
register_time: Optional[float]
|
||||
usage_count: int
|
||||
last_used_time: Optional[float]
|
||||
|
||||
|
||||
|
||||
662
src/webui/services/git_mirror_service.py
Normal file
662
src/webui/services/git_mirror_service.py
Normal file
@@ -0,0 +1,662 @@
|
||||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
import httpx
|
||||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.git_mirror")
|
||||
|
||||
# 导入进度更新函数(避免循环导入)
|
||||
_update_progress = None
|
||||
|
||||
|
||||
def set_update_progress_callback(callback):
|
||||
"""设置进度更新回调函数"""
|
||||
global _update_progress
|
||||
_update_progress = callback
|
||||
|
||||
|
||||
class MirrorType(str, Enum):
|
||||
"""镜像源类型"""
|
||||
|
||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
|
||||
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
|
||||
GITHUB = "github" # GitHub 官方源(兜底)
|
||||
CUSTOM = "custom" # 自定义镜像源
|
||||
|
||||
|
||||
class GitMirrorConfig:
|
||||
"""Git 镜像源配置管理"""
|
||||
|
||||
# 配置文件路径
|
||||
CONFIG_FILE = Path("data/webui.json")
|
||||
|
||||
# 默认镜像源配置
|
||||
DEFAULT_MIRRORS = [
|
||||
{
|
||||
"id": "gh-proxy",
|
||||
"name": "gh-proxy 镜像",
|
||||
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 1,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "hk-gh-proxy",
|
||||
"name": "gh-proxy 香港节点",
|
||||
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 2,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "cdn-gh-proxy",
|
||||
"name": "gh-proxy CDN 节点",
|
||||
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 3,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "edgeone-gh-proxy",
|
||||
"name": "gh-proxy EdgeOne 节点",
|
||||
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 4,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "meyzh-github",
|
||||
"name": "Meyzh GitHub 镜像",
|
||||
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 5,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "github",
|
||||
"name": "GitHub 官方源(兜底)",
|
||||
"raw_prefix": "https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 999,
|
||||
"created_at": None,
|
||||
},
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""初始化配置管理器"""
|
||||
self.config_file = self.CONFIG_FILE
|
||||
self.mirrors: List[Dict[str, Any]] = []
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 检查是否有镜像源配置
|
||||
if "git_mirrors" not in data or not data["git_mirrors"]:
|
||||
logger.info("配置文件中未找到镜像源配置,使用默认配置")
|
||||
self._init_default_mirrors()
|
||||
else:
|
||||
self.mirrors = data["git_mirrors"]
|
||||
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
|
||||
else:
|
||||
logger.info("配置文件不存在,创建默认配置")
|
||||
self._init_default_mirrors()
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
self._init_default_mirrors()
|
||||
|
||||
def _init_default_mirrors(self) -> None:
|
||||
"""初始化默认镜像源"""
|
||||
current_time = datetime.now().isoformat()
|
||||
self.mirrors = []
|
||||
|
||||
for mirror in self.DEFAULT_MIRRORS:
|
||||
mirror_copy = mirror.copy()
|
||||
mirror_copy["created_at"] = current_time
|
||||
self.mirrors.append(mirror_copy)
|
||||
|
||||
self._save_config()
|
||||
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
|
||||
|
||||
def _save_config(self) -> None:
|
||||
"""保存配置到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取现有配置
|
||||
existing_data = {}
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
# 更新镜像源配置
|
||||
existing_data["git_mirrors"] = self.mirrors
|
||||
|
||||
# 写入文件
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.debug(f"配置已保存到 {self.config_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
|
||||
def get_all_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有镜像源"""
|
||||
return self.mirrors.copy()
|
||||
|
||||
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有启用的镜像源,按优先级排序"""
|
||||
enabled = [m for m in self.mirrors if m.get("enabled", False)]
|
||||
return sorted(enabled, key=lambda x: x.get("priority", 999))
|
||||
|
||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取镜像源"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
return mirror.copy()
|
||||
return None
|
||||
|
||||
def add_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: str,
|
||||
raw_prefix: str,
|
||||
clone_prefix: str,
|
||||
enabled: bool = True,
|
||||
priority: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
添加新的镜像源
|
||||
|
||||
Returns:
|
||||
添加的镜像源配置
|
||||
|
||||
Raises:
|
||||
ValueError: 如果镜像源 ID 已存在
|
||||
"""
|
||||
# 检查 ID 是否已存在
|
||||
if self.get_mirror_by_id(mirror_id):
|
||||
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
|
||||
|
||||
# 如果未指定优先级,使用最大优先级 + 1
|
||||
if priority is None:
|
||||
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
|
||||
priority = max_priority + 1
|
||||
|
||||
new_mirror = {
|
||||
"id": mirror_id,
|
||||
"name": name,
|
||||
"raw_prefix": raw_prefix,
|
||||
"clone_prefix": clone_prefix,
|
||||
"enabled": enabled,
|
||||
"priority": priority,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
self.mirrors.append(new_mirror)
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已添加镜像源: {mirror_id} - {name}")
|
||||
return new_mirror.copy()
|
||||
|
||||
def update_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: Optional[str] = None,
|
||||
raw_prefix: Optional[str] = None,
|
||||
clone_prefix: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
priority: Optional[int] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
||||
Returns:
|
||||
更新后的镜像源配置,如果不存在则返回 None
|
||||
"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
if name is not None:
|
||||
mirror["name"] = name
|
||||
if raw_prefix is not None:
|
||||
mirror["raw_prefix"] = raw_prefix
|
||||
if clone_prefix is not None:
|
||||
mirror["clone_prefix"] = clone_prefix
|
||||
if enabled is not None:
|
||||
mirror["enabled"] = enabled
|
||||
if priority is not None:
|
||||
mirror["priority"] = priority
|
||||
|
||||
mirror["updated_at"] = datetime.now().isoformat()
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已更新镜像源: {mirror_id}")
|
||||
return mirror.copy()
|
||||
|
||||
return None
|
||||
|
||||
def delete_mirror(self, mirror_id: str) -> bool:
|
||||
"""
|
||||
删除镜像源
|
||||
|
||||
Returns:
|
||||
True 如果删除成功,False 如果镜像源不存在
|
||||
"""
|
||||
for i, mirror in enumerate(self.mirrors):
|
||||
if mirror.get("id") == mirror_id:
|
||||
self.mirrors.pop(i)
|
||||
self._save_config()
|
||||
logger.info(f"已删除镜像源: {mirror_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_default_priority_list(self) -> List[str]:
|
||||
"""获取默认优先级列表(仅启用的镜像源 ID)"""
|
||||
enabled = self.get_enabled_mirrors()
|
||||
return [m["id"] for m in enabled]
|
||||
|
||||
|
||||
class GitMirrorService:
|
||||
"""Git 镜像源服务"""
|
||||
|
||||
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||
"""
|
||||
初始化 Git 镜像源服务
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
timeout: 请求超时时间(秒)
|
||||
config: 镜像源配置管理器(可选,默认创建新实例)
|
||||
"""
|
||||
self.max_retries = max_retries
|
||||
self.timeout = timeout
|
||||
self.config = config or GitMirrorConfig()
|
||||
logger.info(f"Git镜像源服务初始化完成,已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
|
||||
|
||||
def get_mirror_config(self) -> GitMirrorConfig:
|
||||
"""获取镜像源配置管理器"""
|
||||
return self.config
|
||||
|
||||
@staticmethod
|
||||
def check_git_installed() -> Dict[str, Any]:
|
||||
"""
|
||||
检查本机是否安装了 Git
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- installed: bool - 是否已安装 Git
|
||||
- version: str - Git 版本号(如果已安装)
|
||||
- path: str - Git 可执行文件路径(如果已安装)
|
||||
- error: str - 错误信息(如果未安装或检测失败)
|
||||
"""
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
try:
|
||||
# 查找 git 可执行文件路径
|
||||
git_path = shutil.which("git")
|
||||
|
||||
if not git_path:
|
||||
logger.warning("未找到 Git 可执行文件")
|
||||
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||
|
||||
# 获取 Git 版本
|
||||
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||
return {"installed": True, "version": version, "path": git_path}
|
||||
else:
|
||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Git 版本检测超时")
|
||||
return {"installed": False, "error": "Git 版本检测超时"}
|
||||
except Exception as e:
|
||||
logger.error(f"检测 Git 时发生错误: {e}")
|
||||
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||
|
||||
async def fetch_raw_file(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
branch: 分支名称
|
||||
file_path: 文件路径
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义完整 URL(如果提供,将忽略其他参数)
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
- data: str - 文件内容(成功时)
|
||||
- error: str - 错误信息(失败时)
|
||||
- mirror_used: str - 使用的镜像源
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._fetch_with_url(custom_url, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
total_mirrors = len(mirrors_to_try)
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for index, mirror in enumerate(mirrors_to_try, 1):
|
||||
# 推送进度:正在尝试第 N 个镜像源
|
||||
if _update_progress:
|
||||
try:
|
||||
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=progress,
|
||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||
|
||||
if result["success"]:
|
||||
# 成功,推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=70,
|
||||
message=f"成功从 {mirror['name']} 获取数据",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
return result
|
||||
|
||||
# 失败,记录日志并推送失败信息
|
||||
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
|
||||
|
||||
if _update_progress and index < total_mirrors:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=30 + int(index / total_mirrors * 40),
|
||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _fetch_raw_from_mirror(
|
||||
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源获取文件"""
|
||||
# 构建 URL
|
||||
raw_prefix = mirror["raw_prefix"]
|
||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||
|
||||
return await self._fetch_with_url(url, mirror["id"])
|
||||
|
||||
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
|
||||
"""使用指定 URL 获取文件,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
try:
|
||||
logger.debug(f"尝试 #{attempt + 1}: {url}")
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
logger.info(f"成功获取文件: {url}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": response.text,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
}
|
||||
except httpx.HTTPStatusError as e:
|
||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
except httpx.TimeoutException as e:
|
||||
last_error = f"请求超时: {e}"
|
||||
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
async def clone_repository(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str] = None,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
depth: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
克隆 GitHub 仓库
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
target_path: 目标路径
|
||||
branch: 分支名称(可选)
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义克隆 URL
|
||||
depth: 克隆深度(浅克隆)
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
- path: str - 克隆路径(成功时)
|
||||
- error: str - 错误信息(失败时)
|
||||
- mirror_used: str - 使用的镜像源
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for mirror in mirrors_to_try:
|
||||
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||
if result["success"]:
|
||||
return result
|
||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _clone_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源克隆仓库"""
|
||||
# 构建克隆 URL
|
||||
clone_prefix = mirror["clone_prefix"]
|
||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||
|
||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||
|
||||
async def _clone_with_url(
|
||||
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用指定 URL 克隆仓库,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
|
||||
try:
|
||||
# 确保目标路径不存在
|
||||
if target_path.exists():
|
||||
logger.warning(f"目标路径已存在,删除: {target_path}")
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
# 构建 git clone 命令
|
||||
cmd = ["git", "clone"]
|
||||
|
||||
# 添加分支参数
|
||||
if branch:
|
||||
cmd.extend(["-b", branch])
|
||||
|
||||
# 添加深度参数(浅克隆)
|
||||
if depth:
|
||||
cmd.extend(["--depth", str(depth)])
|
||||
|
||||
# 添加 URL 和目标路径
|
||||
cmd.extend([url, str(target_path)])
|
||||
|
||||
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
|
||||
|
||||
# 推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=20 + attempt * 10,
|
||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||
operation="install",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def run_git_clone(clone_cmd=cmd):
|
||||
return subprocess.run(
|
||||
clone_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
|
||||
process = await loop.run_in_executor(None, run_git_clone)
|
||||
|
||||
if process.returncode == 0:
|
||||
logger.info(f"成功克隆仓库: {url} -> {target_path}")
|
||||
return {
|
||||
"success": True,
|
||||
"path": str(target_path),
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
"branch": branch or "default",
|
||||
}
|
||||
else:
|
||||
last_error = f"Git 克隆失败: {process.stderr}"
|
||||
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
last_error = "克隆超时(超过 5 分钟)"
|
||||
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
except FileNotFoundError:
|
||||
last_error = "Git 未安装或不在 PATH 中"
|
||||
logger.error(f"Git 未找到: {last_error}")
|
||||
break # Git 不存在,不需要重试
|
||||
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
_git_mirror_service: Optional[GitMirrorService] = None
|
||||
|
||||
|
||||
def get_git_mirror_service() -> GitMirrorService:
|
||||
"""获取 Git 镜像源服务实例(单例)"""
|
||||
global _git_mirror_service
|
||||
if _git_mirror_service is None:
|
||||
_git_mirror_service = GitMirrorService()
|
||||
return _git_mirror_service
|
||||
Reference in New Issue
Block a user