Merge branch 'r-dev' of github.com:Mai-with-u/MaiBot into r-dev

This commit is contained in:
UnCLAS-Prommer
2026-02-18 16:00:58 +08:00
45 changed files with 7190 additions and 744 deletions

src/common/toml_utils.py (new file, 89 lines)

@@ -0,0 +1,89 @@
"""
TOML文件工具函数 - 保留格式和注释
"""
import os
import tomlkit
from typing import Any
def save_toml_with_format(data: dict[str, Any], file_path: str) -> None:
"""
保存TOML数据到文件保留现有格式如果文件存在
Args:
data: 要保存的数据字典
file_path: 文件路径
"""
# 如果文件不存在,直接创建
if not os.path.exists(file_path):
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(data, f)
return
# 如果文件存在,尝试读取现有文件以保留格式
try:
with open(file_path, "r", encoding="utf-8") as f:
existing_doc = tomlkit.load(f)
except Exception:
# 如果读取失败,直接覆盖
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(data, f)
return
# 递归更新,保留现有格式
_merge_toml_preserving_format(existing_doc, data)
# 保存
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(existing_doc, f)
def _merge_toml_preserving_format(target: dict[str, Any], source: dict[str, Any]) -> None:
"""
递归合并source到target保留target中的格式和注释
Args:
target: 目标文档(保留格式)
source: 源数据(新数据)
"""
for key, value in source.items():
if key in target:
# 如果两个都是字典且都是表格,递归合并
if isinstance(value, dict) and isinstance(target[key], dict):
if hasattr(target[key], "items"): # 确实是字典/表格
_merge_toml_preserving_format(target[key], value)
else:
target[key] = value
else:
# 其他情况直接替换
target[key] = value
else:
# 新键直接添加
target[key] = value
def _update_toml_doc(target: dict[str, Any], source: dict[str, Any]) -> None:
"""
更新TOML文档中的字段保留现有的格式和注释
这是一个递归函数,用于在部分更新配置时保留现有的格式和注释。
Args:
target: 目标表格(会被修改)
source: 源数据(新数据)
"""
for key, value in source.items():
if key in target:
# 如果两个都是字典,递归更新
if isinstance(value, dict) and isinstance(target[key], dict):
if hasattr(target[key], "items"): # 确实是表格
_update_toml_doc(target[key], value)
else:
target[key] = value
else:
# 直接更新值,保留注释
target[key] = value
else:
# 新键直接添加
target[key] = value
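A minimal usage sketch (the path and keys below are hypothetical): because tomlkit round-trips the document, save_toml_with_format can rewrite a single value in an existing config without losing the comments around it.

    # Hypothetical example: change one value in an existing TOML file while
    # keeping every comment and the original key order intact.
    from src.common.toml_utils import save_toml_with_format

    save_toml_with_format({"bot": {"nickname": "Mai"}}, "config/bot_config.toml")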


@@ -5,7 +5,7 @@ import types
from dataclasses import dataclass, field
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field
from typing import Union, get_args, get_origin, Tuple, Any, List, Dict, Set, Literal
from typing import Any, Dict, List, Literal, Set, Tuple, Union, cast, get_args, get_origin
__all__ = ["ConfigBase", "Field", "AttributeData"]
@@ -44,6 +44,16 @@ class AttrDocBase:
# Extract field docs from the class definition node
return self._extract_field_docs(class_node, allow_extra_methods)
@classmethod
def get_class_field_docs(cls) -> dict[str, str]:
class_source = cls._get_class_source()
class_node = cls._find_class_node(class_source)
return AttrDocBase._extract_field_docs(
cast(AttrDocBase, cast(Any, cls)),
class_node,
allow_extra_methods=False,
)
@classmethod
def _get_class_source(cls) -> str:
"""获取类定义所在文件的完整源代码"""
@@ -265,7 +275,7 @@ class ConfigBase(BaseModel, AttrDocBase):
if origin_type in (int, float, str, bool, complex, bytes, Any):
continue
# Allow nested ConfigBase custom classes
if inspect.isclass(origin_type) and issubclass(origin_type, ConfigBase): # type: ignore
if isinstance(origin_type, type) and issubclass(cast(type, origin_type), ConfigBase):
continue
# Only the list, set, and dict generic kinds are allowed
if origin_type not in (list, set, dict, List, Set, Dict, Literal):
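The new get_class_field_docs classmethod exposes the attribute-docstring extraction as a plain mapping. A sketch of the expected behavior, with a hypothetical config class (the exact return values depend on _extract_field_docs):

    from src.config.config_base import ConfigBase

    class Demo(ConfigBase):  # hypothetical class, for illustration only
        port: int = 8080
        """Port the service listens on"""

    # Expected: a field-name -> attribute-docstring mapping
    print(Demo.get_class_field_docs())  # -> {"port": "Port the service listens on"}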


@@ -5,25 +5,73 @@ from .config_base import ConfigBase, Field
class APIProvider(ConfigBase):
"""API提供商配置类"""
name: str = ""
name: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "tag",
},
)
"""API服务商名称 (可随意命名, 在models的api-provider中需使用这个命名)"""
base_url: str = ""
base_url: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "link",
},
)
"""API服务商的BaseURL"""
api_key: str = Field(default_factory=str, repr=False)
api_key: str = Field(
default_factory=str,
repr=False,
json_schema_extra={
"x-widget": "input",
"x-icon": "key",
},
)
"""API密钥"""
client_type: str = Field(default="openai")
client_type: str = Field(
default="openai",
json_schema_extra={
"x-widget": "select",
"x-icon": "settings",
},
)
"""客户端类型 (可选: openai/google, 默认为openai)"""
max_retry: int = Field(default=2)
max_retry: int = Field(
default=2,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "repeat",
},
)
"""最大重试次数 (单个模型API调用失败, 最多重试的次数)"""
timeout: int = 10
timeout: int = Field(
default=10,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "clock",
"step": 1,
},
)
"""API调用的超时时长 (超过这个时长, 本次请求将被视为"请求超时", 单位: 秒)"""
retry_interval: int = 10
retry_interval: int = Field(
default=10,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "timer",
"step": 1,
},
)
"""重试间隔 (如果API调用失败, 重试的间隔时间, 单位: 秒)"""
def model_post_init(self, context: Any = None):
@@ -39,34 +87,93 @@ class APIProvider(ConfigBase):
class ModelInfo(ConfigBase):
"""单个模型信息配置类"""
_validate_any: bool = False
suppress_any_warning: bool = True
model_identifier: str = ""
model_identifier: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "package",
},
)
"""模型标识符 (API服务商提供的模型标识符)"""
name: str = ""
name: str = Field(
default="",
json_schema_extra={
"x-widget": "input",
"x-icon": "tag",
},
)
"""模型名称 (可随意命名, 在models中需使用这个命名)"""
api_provider: str = ""
api_provider: str = Field(
default="",
json_schema_extra={
"x-widget": "select",
"x-icon": "link",
},
)
"""API服务商名称 (对应在api_providers中配置的服务商名称)"""
price_in: float = Field(default=0.0)
price_in: float = Field(
default=0.0,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "dollar-sign",
"step": 0.001,
},
)
"""输入价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
price_out: float = Field(default=0.0)
price_out: float = Field(
default=0.0,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "dollar-sign",
"step": 0.001,
},
)
"""输出价格 (用于API调用统计, 单位:元/ M token) (可选, 若无该字段, 默认值为0)"""
temperature: float | None = Field(default=None)
temperature: float | None = Field(
default=None,
json_schema_extra={
"x-widget": "input",
"x-icon": "thermometer",
},
)
"""模型级别温度(可选),会覆盖任务配置中的温度"""
max_tokens: int | None = Field(default=None)
max_tokens: int | None = Field(
default=None,
json_schema_extra={
"x-widget": "input",
"x-icon": "layers",
},
)
"""模型级别最大token数可选会覆盖任务配置中的max_tokens"""
force_stream_mode: bool = Field(default=False)
force_stream_mode: bool = Field(
default=False,
json_schema_extra={
"x-widget": "switch",
"x-icon": "zap",
},
)
"""强制流式输出模式 (若模型不支持非流式输出, 请设置为true启用强制流式输出, 默认值为false)"""
extra_params: dict[str, Any] = Field(default_factory=dict)
extra_params: dict[str, Any] = Field(
default_factory=dict,
json_schema_extra={
"x-widget": "custom",
"x-icon": "sliders",
},
)
"""额外参数 (用于API调用时的额外配置)"""
def model_post_init(self, context: Any = None):
@@ -82,48 +189,139 @@ class ModelInfo(ConfigBase):
class TaskConfig(ConfigBase):
"""任务配置类"""
model_list: list[str] = Field(default_factory=list)
model_list: list[str] = Field(
default_factory=list,
json_schema_extra={
"x-widget": "custom",
"x-icon": "list",
},
)
"""使用的模型列表, 每个元素对应上面的模型名称(name)"""
max_tokens: int = 1024
max_tokens: int = Field(
default=1024,
ge=1,
json_schema_extra={
"x-widget": "input",
"x-icon": "layers",
"step": 1,
},
)
"""任务最大输出token数"""
temperature: float = 0.3
temperature: float = Field(
default=0.3,
ge=0,
le=2,
json_schema_extra={
"x-widget": "slider",
"x-icon": "thermometer",
"step": 0.1,
},
)
"""模型温度"""
slow_threshold: float = 15.0
slow_threshold: float = Field(
default=15.0,
ge=0,
json_schema_extra={
"x-widget": "input",
"x-icon": "alert-circle",
"step": 0.1,
},
)
"""慢请求阈值(秒),超过此值会输出警告日志"""
selection_strategy: str = Field(default="balance")
selection_strategy: str = Field(
default="balance",
json_schema_extra={
"x-widget": "select",
"x-icon": "shuffle",
},
)
"""模型选择策略balance负载均衡或 random随机选择"""
class ModelTaskConfig(ConfigBase):
"""模型配置类"""
utils: TaskConfig = Field(default_factory=TaskConfig)
utils: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "wrench",
},
)
"""组件使用的模型, 例如表情包模块, 取名模块, 关系模块, 麦麦的情绪变化等,是麦麦必须的模型"""
replyer: TaskConfig = Field(default_factory=TaskConfig)
replyer: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "message-square",
},
)
"""首要回复模型配置, 还用于表达器和表达方式学习"""
vlm: TaskConfig = Field(default_factory=TaskConfig)
vlm: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "image",
},
)
"""视觉模型配置"""
voice: TaskConfig = Field(default_factory=TaskConfig)
voice: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "volume-2",
},
)
"""语音识别模型配置"""
tool_use: TaskConfig = Field(default_factory=TaskConfig)
tool_use: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "tools",
},
)
"""工具使用模型配置, 需要使用支持工具调用的模型"""
planner: TaskConfig = Field(default_factory=TaskConfig)
planner: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "map",
},
)
"""规划模型配置"""
embedding: TaskConfig = Field(default_factory=TaskConfig)
embedding: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "database",
},
)
"""嵌入模型配置"""
lpmm_entity_extract: TaskConfig = Field(default_factory=TaskConfig)
lpmm_entity_extract: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "filter",
},
)
"""LPMM实体提取模型配置"""
lpmm_rdf_build: TaskConfig = Field(default_factory=TaskConfig)
lpmm_rdf_build: TaskConfig = Field(
default_factory=TaskConfig,
json_schema_extra={
"x-widget": "custom",
"x-icon": "network",
},
)
"""LPMM RDF构建模型配置"""

File diff suppressed because it is too large.

src/webui/api/planner.py (new file, 301 lines)

@@ -0,0 +1,301 @@
"""
规划器监控API
提供规划器日志数据的查询接口
性能优化:
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/planner", tags=["planner"])
# Planner log directory
PLAN_LOG_DIR = Path("logs/plan")
class ChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
plan_count: int
latest_timestamp: float
latest_filename: str
class PlanLogSummary(BaseModel):
"""规划日志摘要"""
chat_id: str
timestamp: float
filename: str
action_count: int
action_types: List[str]  # list of action types
total_plan_ms: float
llm_duration_ms: float
reasoning_preview: str
class PlanLogDetail(BaseModel):
"""规划日志详情"""
type: str
chat_id: str
timestamp: float
prompt: str
reasoning: str
raw_output: str
actions: List[Dict]
timing: Dict
extra: Optional[Dict] = None
class PlannerOverview(BaseModel):
"""规划器总览 - 轻量级统计"""
total_chats: int
total_plans: int
chats: List[ChatSummary]
class PaginatedChatLogs(BaseModel):
"""分页的聊天日志列表"""
data: List[PlanLogSummary]
total: int
page: int
page_size: int
chat_id: str
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
# The timestamp is in milliseconds; convert to seconds
return float(timestamp_str) / 1000
except (ValueError, IndexError):
return 0
@router.get("/overview", response_model=PlannerOverview)
async def get_planner_overview():
"""
获取规划器总览 - 轻量级接口
只统计文件数量,不读取文件内容
"""
if not PLAN_LOG_DIR.exists():
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
chats = []
total_plans = 0
for chat_dir in PLAN_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# Only count JSON files
json_files = list(chat_dir.glob("*.json"))
plan_count = len(json_files)
total_plans += plan_count
if plan_count == 0:
continue
# Take the latest timestamp from the file names
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ChatSummary(
chat_id=chat_dir.name,
plan_count=plan_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
# Sort by latest timestamp
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return PlannerOverview(
total_chats=len(chats),
total_plans=total_plans,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
async def get_chat_plan_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="Search keyword, matched against prompt content")
):
"""
获取指定聊天的规划日志列表(分页)
需要读取文件内容获取摘要信息
支持搜索提示词内容
"""
chat_dir = PLAN_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedChatLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# Fetch all files first and sort by timestamp
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# If a search keyword is given, filter the files
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
prompt = data.get('prompt', '')
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# Pagination - only read files on the current page
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
actions = data.get('actions', [])
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
logs.append(PlanLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
action_count=len(actions),
action_types=action_types,
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
reasoning_preview=reasoning[:100] if reasoning else ''
))
except Exception:
# Fall back to file-name info when the file cannot be read
logs.append(PlanLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
action_count=0,
action_types=[],
total_plan_ms=0,
llm_duration_ms=0,
reasoning_preview='[read failed]'
))
return PaginatedChatLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
async def get_log_detail(chat_id: str, filename: str):
"""获取规划日志详情 - 按需加载完整内容"""
log_file = PLAN_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="Log file does not exist")
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
return PlanLogDetail(**data)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to read log: {str(e)}")
# ========== Legacy-compatible endpoints ==========
@router.get("/stats")
async def get_planner_stats():
"""获取规划器统计信息 - 兼容旧接口"""
overview = await get_planner_overview()
# Collect summaries of the 10 most recent plans
recent_plans = []
for chat in overview.chats[:5]:  # taken from the 5 most recently active chats
try:
chat_logs = await get_chat_plan_logs(chat.chat_id, page=1, page_size=2)
recent_plans.extend(chat_logs.data)
except Exception:
continue
# Sort by time and keep the top 10
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
recent_plans = recent_plans[:10]
return {
"total_chats": overview.total_chats,
"total_plans": overview.total_plans,
"avg_plan_time_ms": 0,
"avg_llm_time_ms": 0,
"recent_plans": recent_plans
}
@router.get("/chats")
async def get_chat_list():
"""获取所有聊天ID列表 - 兼容旧接口"""
overview = await get_planner_overview()
return [chat.chat_id for chat in overview.chats]
@router.get("/all-logs")
async def get_all_logs(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100)
):
"""获取所有规划日志 - 兼容旧接口"""
if not PLAN_LOG_DIR.exists():
return {"data": [], "total": 0, "page": page, "page_size": page_size}
# Collect all files
all_files = []
for chat_dir in PLAN_LOG_DIR.iterdir():
if chat_dir.is_dir():
for log_file in chat_dir.glob("*.json"):
all_files.append((chat_dir.name, log_file))
# Sort by timestamp
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
total = len(all_files)
offset = (page - 1) * page_size
page_files = all_files[offset:offset + page_size]
logs = []
for chat_id, log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
logs.append({
"chat_id": data.get('chat_id', chat_id),
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name,
"action_count": len(data.get('actions', [])),
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
"reasoning_preview": reasoning[:100] if reasoning else ''
})
except Exception:
continue
return {"data": logs, "total": total, "page": page, "page_size": page_size}

src/webui/api/replier.py (new file, 269 lines)

@@ -0,0 +1,269 @@
"""
回复器监控API
提供回复器日志数据的查询接口
性能优化:
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/replier", tags=["replier"])
# Replier log directory
REPLY_LOG_DIR = Path("logs/reply")
class ReplierChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
reply_count: int
latest_timestamp: float
latest_filename: str
class ReplyLogSummary(BaseModel):
"""回复日志摘要"""
chat_id: str
timestamp: float
filename: str
model: str
success: bool
llm_ms: float
overall_ms: float
output_preview: str
class ReplyLogDetail(BaseModel):
"""回复日志详情"""
type: str
chat_id: str
timestamp: float
prompt: str
output: str
processed_output: List[str]
model: str
reasoning: str
think_level: int
timing: Dict
error: Optional[str] = None
success: bool
class ReplierOverview(BaseModel):
"""回复器总览 - 轻量级统计"""
total_chats: int
total_replies: int
chats: List[ReplierChatSummary]
class PaginatedReplyLogs(BaseModel):
"""分页的回复日志列表"""
data: List[ReplyLogSummary]
total: int
page: int
page_size: int
chat_id: str
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
# The timestamp is in milliseconds; convert to seconds
return float(timestamp_str) / 1000
except (ValueError, IndexError):
return 0
@router.get("/overview", response_model=ReplierOverview)
async def get_replier_overview():
"""
获取回复器总览 - 轻量级接口
只统计文件数量,不读取文件内容
"""
if not REPLY_LOG_DIR.exists():
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
chats = []
total_replies = 0
for chat_dir in REPLY_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# Only count JSON files
json_files = list(chat_dir.glob("*.json"))
reply_count = len(json_files)
total_replies += reply_count
if reply_count == 0:
continue
# Take the latest timestamp from the file names
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ReplierChatSummary(
chat_id=chat_dir.name,
reply_count=reply_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
# Sort by latest timestamp
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return ReplierOverview(
total_chats=len(chats),
total_replies=total_replies,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
async def get_chat_reply_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="Search keyword, matched against prompt content")
):
"""
获取指定聊天的回复日志列表(分页)
需要读取文件内容获取摘要信息
支持搜索提示词内容
"""
chat_dir = REPLY_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedReplyLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# Fetch all files first and sort by timestamp
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# If a search keyword is given, filter the files
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
prompt = data.get('prompt', '')
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# Pagination - only read files on the current page
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
output = data.get('output', '')
logs.append(ReplyLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
model=data.get('model', ''),
success=data.get('success', True),
llm_ms=data.get('timing', {}).get('llm_ms', 0),
overall_ms=data.get('timing', {}).get('overall_ms', 0),
output_preview=output[:100] if output else ''
))
except Exception:
# Fall back to file-name info when the file cannot be read
logs.append(ReplyLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
model='',
success=False,
llm_ms=0,
overall_ms=0,
output_preview='[read failed]'
))
return PaginatedReplyLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
async def get_reply_log_detail(chat_id: str, filename: str):
"""获取回复日志详情 - 按需加载完整内容"""
log_file = REPLY_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="Log file does not exist")
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
return ReplyLogDetail(
type=data.get('type', 'reply'),
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', 0),
prompt=data.get('prompt', ''),
output=data.get('output', ''),
processed_output=data.get('processed_output', []),
model=data.get('model', ''),
reasoning=data.get('reasoning', ''),
think_level=data.get('think_level', 0),
timing=data.get('timing', {}),
error=data.get('error'),
success=data.get('success', True)
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to read log: {str(e)}")
# ========== Compatibility endpoints ==========
@router.get("/stats")
async def get_replier_stats():
"""获取回复器统计信息"""
overview = await get_replier_overview()
# Collect summaries of the 10 most recent replies
recent_replies = []
for chat in overview.chats[:5]:  # taken from the 5 most recently active chats
try:
chat_logs = await get_chat_reply_logs(chat.chat_id, page=1, page_size=2)
recent_replies.extend(chat_logs.data)
except Exception:
continue
# Sort by time and keep the top 10
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
recent_replies = recent_replies[:10]
return {
"total_chats": overview.total_chats,
"total_replies": overview.total_replies,
"recent_replies": recent_replies
}
@router.get("/chats")
async def get_replier_chat_list():
"""获取所有聊天ID列表"""
overview = await get_replier_overview()
return [chat.chat_id for chat in overview.chats]
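For reference, a made-up example of one logs/reply/<chat_id>/<timestamp>_<id>.json payload that would parse into ReplyLogDetail (every value here is illustrative; the field names follow the model above):

    # Illustrative payload only; field names follow the ReplyLogDetail model.
    sample_reply_log = {
        "type": "reply",
        "chat_id": "af92bdb1",
        "timestamp": 1766497488.220,
        "prompt": "...",
        "output": "hello!",
        "processed_output": ["hello!"],
        "model": "gpt4o-mini",
        "reasoning": "",
        "think_level": 0,
        "timing": {"llm_ms": 850.0, "overall_ms": 910.0},
        "error": None,
        "success": True,
    }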

src/webui/config_schema.py (new file, 133 lines)

@@ -0,0 +1,133 @@
import inspect
from typing import Any, get_args, get_origin
from pydantic_core import PydanticUndefined
from src.config.config_base import ConfigBase
class ConfigSchemaGenerator:
@classmethod
def generate_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
return cls.generate_config_schema(config_class, include_nested=include_nested)
@classmethod
def generate_config_schema(cls, config_class: type[ConfigBase], include_nested: bool = True) -> dict[str, Any]:
fields: list[dict[str, Any]] = []
nested: dict[str, dict[str, Any]] = {}
for field_name, field_info in config_class.model_fields.items():
if field_name in {"field_docs", "_validate_any", "suppress_any_warning"}:
continue
field_schema = cls._build_field_schema(config_class, field_name, field_info.annotation, field_info)
fields.append(field_schema)
if include_nested:
nested_schema = cls._build_nested_schema(field_info.annotation)
if nested_schema is not None:
nested[field_name] = nested_schema
return {
"className": config_class.__name__,
"classDoc": (config_class.__doc__ or "").strip(),
"fields": fields,
"nested": nested,
}
@classmethod
def _build_nested_schema(cls, annotation: Any) -> dict[str, Any] | None:
origin = get_origin(annotation)
args = get_args(annotation)
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
return cls.generate_config_schema(annotation)
if origin in {list, tuple} and args:
first = args[0]
if inspect.isclass(first) and issubclass(first, ConfigBase):
return cls.generate_config_schema(first)
return None
@classmethod
def _build_field_schema(
cls, config_class: type[ConfigBase], field_name: str, annotation: Any, field_info: Any
) -> dict[str, Any]:
field_docs = config_class.get_class_field_docs()
field_type = cls._map_field_type(annotation)
schema: dict[str, Any] = {
"name": field_name,
"type": field_type,
"label": field_name,
"description": field_docs.get(field_name, field_info.description or ""),
"required": field_info.is_required(),
}
if field_info.default is not PydanticUndefined:
schema["default"] = field_info.default
origin = get_origin(annotation)
args = get_args(annotation)
if origin is list and args:
schema["items"] = {"type": cls._map_field_type(args[0])}
options = cls._extract_options(annotation)
if options:
schema["options"] = options
# Task 1c: Merge json_schema_extra (x-widget, x-icon, step, etc.)
if hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra:
schema.update(field_info.json_schema_extra)
# Task 1d: Map Pydantic constraints to minValue/maxValue (frontend naming convention)
if hasattr(field_info, "metadata") and field_info.metadata:
for constraint in field_info.metadata:
if hasattr(constraint, "ge"):
schema["minValue"] = constraint.ge
if hasattr(constraint, "le"):
schema["maxValue"] = constraint.le
return schema
@staticmethod
def _extract_options(annotation: Any) -> list[str] | None:
origin = get_origin(annotation)
if origin is None:
return None
if str(origin) != "typing.Literal":
return None
args = get_args(annotation)
options = [str(item) for item in args]
return options or None
@classmethod
def _map_field_type(cls, annotation: Any) -> str:
origin = get_origin(annotation)
args = get_args(annotation)
if origin in {list, tuple}:
return "array"
if inspect.isclass(annotation) and issubclass(annotation, ConfigBase):
return "object"
if annotation is bool:
return "boolean"
if annotation is int:
return "integer"
if annotation is float:
return "number"
if annotation is str:
return "string"
if origin in {list, tuple} and args:
return "array"
if origin in {dict}:
return "object"
if origin is not None and str(origin) == "typing.Literal":
return "select"
return "string"


@@ -10,7 +10,7 @@ from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.config import Config, ModelConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
@@ -77,7 +77,7 @@ async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
schema = ConfigSchemaGenerator.generate_config_schema(ModelConfig)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取模型配置架构失败: {e}")
@@ -227,7 +227,7 @@ async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(req
try:
# Validate the config data
try:
APIAdapterConfig.from_dict(config_data)
ModelConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"Config data validation failed: {str(e)}") from e
@@ -377,7 +377,7 @@ async def update_model_config_section(
# Validate the full config
try:
APIAdapterConfig.from_dict(config_data)
ModelConfig.from_dict(config_data)
except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider


@@ -1,21 +1,27 @@
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Annotated, Any, List, Optional
import asyncio
import hashlib
import io
import os
import threading
from fastapi import APIRouter, Cookie, File, Form, Header, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from typing import Optional, List, Annotated
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
import time
import os
import hashlib
from PIL import Image
import io
from pathlib import Path
import threading
import asyncio
from concurrent.futures import ThreadPoolExecutor
from sqlalchemy import func
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.common.logger import get_logger
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
logger = get_logger("webui.emoji")
@@ -61,7 +67,7 @@ def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
def _ensure_thumbnail_cache_dir() -> Path:
"""确保缩略图缓存目录存在"""
THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
_ = THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
return THUMBNAIL_CACHE_DIR
@@ -99,7 +105,7 @@ def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
try:
with Image.open(source_path) as img:
# GIF handling: extract the first frame
if hasattr(img, "n_frames") and img.n_frames > 1:
if getattr(img, "n_frames", 1) > 1:
img.seek(0)  # make sure we are on the first frame
# Convert to RGB/RGBA (WebP supports transparency)
@@ -138,9 +144,9 @@ def cleanup_orphaned_thumbnails() -> tuple[int, int]:
return 0, 0
# Collect the hashes of all emojis
valid_hashes = set()
for emoji in Emoji.select(Emoji.emoji_hash):
valid_hashes.add(emoji.emoji_hash)
with get_db_session() as session:
statement = select(Images.image_hash).where(col(Images.image_type) == ImageType.EMOJI)
valid_hashes = set(session.exec(statement).all())
cleaned = 0
kept = 0
@@ -179,7 +185,6 @@ class EmojiResponse(BaseModel):
id: int
full_path: str
format: str
emoji_hash: str
description: str
query_count: int
@@ -188,7 +193,6 @@ class EmojiResponse(BaseModel):
emotion: Optional[str]  # returned directly as a string
record_time: float
register_time: Optional[float]
usage_count: int
last_used_time: Optional[float]
@@ -257,22 +261,19 @@ def verify_auth_token(
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
"""将 Emoji 模型转换为响应对象"""
def emoji_to_response(image: Images) -> EmojiResponse:
return EmojiResponse(
id=emoji.id,
full_path=emoji.full_path,
format=emoji.format,
emoji_hash=emoji.emoji_hash,
description=emoji.description,
query_count=emoji.query_count,
is_registered=emoji.is_registered,
is_banned=emoji.is_banned,
emotion=str(emoji.emotion) if emoji.emotion is not None else None,
record_time=emoji.record_time,
register_time=emoji.register_time,
usage_count=emoji.usage_count,
last_used_time=emoji.last_used_time,
id=image.id if image.id is not None else 0,
full_path=image.full_path,
emoji_hash=image.image_hash,
description=image.description,
query_count=image.query_count,
is_registered=image.is_registered,
is_banned=image.is_banned,
emotion=image.emotion,
record_time=image.record_time.timestamp() if image.record_time else 0.0,
register_time=image.register_time.timestamp() if image.register_time else None,
last_used_time=image.last_used_time.timestamp() if image.last_used_time else None,
)
@@ -283,8 +284,7 @@ async def get_emoji_list(
search: Optional[str] = Query(None, description="Search keyword"),
is_registered: Optional[bool] = Query(None, description="Filter by registered status"),
is_banned: Optional[bool] = Query(None, description="Filter by banned status"),
format: Optional[str] = Query(None, description="Filter by format"),
sort_by: Optional[str] = Query("usage_count", description="Sort field"),
sort_by: Optional[str] = Query("query_count", description="Sort field"),
sort_order: Optional[str] = Query("desc", description="Sort direction"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
@@ -298,8 +298,7 @@ async def get_emoji_list(
search: Search keyword (matched against description, emoji_hash)
is_registered: Filter by registered status
is_banned: Filter by banned status
format: Filter by format
sort_by: Sort field (usage_count, register_time, record_time, last_used_time)
sort_by: Sort field (query_count, register_time, record_time, last_used_time)
sort_order: Sort direction (asc, desc)
authorization: Authorization header
@@ -310,47 +309,58 @@ async def get_emoji_list(
verify_auth_token(maibot_session, authorization)
# Build the query
query = Emoji.select()
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
# Search filter
if search:
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
statement = statement.where(
(col(Images.description).contains(search)) | (col(Images.image_hash).contains(search))
)
# Registered-status filter
if is_registered is not None:
query = query.where(Emoji.is_registered == is_registered)
statement = statement.where(col(Images.is_registered) == is_registered)
# Banned-status filter
if is_banned is not None:
query = query.where(Emoji.is_banned == is_banned)
# Format filter
if format:
query = query.where(Emoji.format == format)
statement = statement.where(col(Images.is_banned) == is_banned)
# Sort-field mapping
sort_field_map = {
"usage_count": Emoji.usage_count,
"register_time": Emoji.register_time,
"record_time": Emoji.record_time,
"last_used_time": Emoji.last_used_time,
"usage_count": col(Images.query_count),
"query_count": col(Images.query_count),
"register_time": col(Images.register_time),
"record_time": col(Images.record_time),
"last_used_time": col(Images.last_used_time),
}
# Resolve the sort field; defaults to usage_count
sort_field = sort_field_map.get(sort_by, Emoji.usage_count)
sort_key = sort_by or "query_count"
sort_field = sort_field_map.get(sort_key, col(Images.query_count))
# Apply sorting
if sort_order == "asc":
query = query.order_by(sort_field.asc())
statement = statement.order_by(sort_field.asc())
else:
query = query.order_by(sort_field.desc())
# Get the total count
total = query.count()
statement = statement.order_by(sort_field.desc())
# Pagination
offset = (page - 1) * page_size
emojis = query.offset(offset).limit(page_size)
statement = statement.offset(offset).limit(page_size)
with get_db_session() as session:
emojis = session.exec(statement).all()
count_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
if search:
count_statement = count_statement.where(
(col(Images.description).contains(search)) | (col(Images.image_hash).contains(search))
)
if is_registered is not None:
count_statement = count_statement.where(col(Images.is_registered) == is_registered)
if is_banned is not None:
count_statement = count_statement.where(col(Images.is_banned) == is_banned)
total = session.exec(count_statement).one()
# Convert to response objects
data = [emoji_to_response(emoji) for emoji in emojis]
@@ -381,12 +391,17 @@ async def get_emoji_detail(
try:
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
except HTTPException:
raise
@@ -416,34 +431,37 @@ async def update_emoji(
try:
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
# Only update the provided fields
update_data = request.model_dump(exclude_unset=True)
# Only update the provided fields
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="No fields to update were provided")
if not update_data:
raise HTTPException(status_code=400, detail="No fields to update were provided")
# The emotion field is used as a string directly; no conversion needed
# If the registered status flips from False to True, record the registration time
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = datetime.now()
# If the registered status flips from False to True, record the registration time
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = time.time()
# Apply the update
for field, value in update_data.items():
setattr(emoji, field, value)
# Apply the update
for field, value in update_data.items():
setattr(emoji, field, value)
session.add(emoji)
emoji.save()
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
return EmojiUpdateResponse(
success=True, message=f"Successfully updated {len(update_data)} field(s)", data=emoji_to_response(emoji)
)
return EmojiUpdateResponse(
success=True, message=f"Successfully updated {len(update_data)} field(s)", data=emoji_to_response(emoji)
)
except HTTPException:
raise
@@ -469,20 +487,22 @@ async def delete_emoji(
try:
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
# Record the deletion info
emoji_hash = emoji.emoji_hash
emoji_hash = emoji.image_hash
session.delete(emoji)
# Perform the deletion
emoji.delete_instance()
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
return EmojiDeleteResponse(success=True, message=f"Successfully deleted emoji: {emoji_hash}")
return EmojiDeleteResponse(success=True, message=f"Successfully deleted emoji: {emoji_hash}")
except HTTPException:
raise
@@ -505,27 +525,51 @@ async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authoriz
try:
verify_auth_token(maibot_session, authorization)
total = Emoji.select().count()
registered = Emoji.select().where(Emoji.is_registered).count()
banned = Emoji.select().where(Emoji.is_banned).count()
with get_db_session() as session:
total_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
registered_statement = (
select(func.count())
.select_from(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_registered) == True,
)
)
banned_statement = (
select(func.count())
.select_from(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_banned) == True,
)
)
# Count by format
formats = {}
for emoji in Emoji.select(Emoji.format):
fmt = emoji.format
formats[fmt] = formats.get(fmt, 0) + 1
total = session.exec(total_statement).one()
registered = session.exec(registered_statement).one()
banned = session.exec(banned_statement).one()
# Get the top 10 most-used emojis
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
top_used_list = [
{
"id": emoji.id,
"emoji_hash": emoji.emoji_hash,
"description": emoji.description,
"usage_count": emoji.usage_count,
}
for emoji in top_used
]
formats: dict[str, int] = {}
format_statement = select(Images.full_path).where(col(Images.image_type) == ImageType.EMOJI)
for full_path in session.exec(format_statement).all():
suffix = Path(full_path).suffix.lower().lstrip(".")
fmt = suffix or "unknown"
formats[fmt] = formats.get(fmt, 0) + 1
top_used_statement = (
select(Images)
.where(col(Images.image_type) == ImageType.EMOJI)
.order_by(col(Images.query_count).desc())
.limit(10)
)
top_used_list = [
{
"id": emoji.id,
"emoji_hash": emoji.image_hash,
"description": emoji.description,
"usage_count": emoji.query_count,
}
for emoji in session.exec(top_used_statement).all()
]
return {
"success": True,
@@ -563,23 +607,27 @@ async def register_emoji(
try:
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
if emoji.is_registered:
raise HTTPException(status_code=400, detail="This emoji is already registered")
if emoji.is_registered:
raise HTTPException(status_code=400, detail="This emoji is already registered")
# Register the emoji (automatically unban it if it was banned)
emoji.is_registered = True
emoji.is_banned = False  # automatically unban on registration
emoji.register_time = time.time()
emoji.save()
emoji.is_registered = True
emoji.is_banned = False
emoji.register_time = datetime.now()
session.add(emoji)
logger.info(f"表情包已注册: ID={emoji_id}")
logger.info(f"表情包已注册: ID={emoji_id}")
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
except HTTPException:
raise
@@ -605,19 +653,23 @@ async def ban_emoji(
try:
verify_auth_token(maibot_session, authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
# Ban the emoji (and unregister it at the same time)
emoji.is_banned = True
emoji.is_registered = False
emoji.save()
emoji.is_banned = True
emoji.is_registered = False
session.add(emoji)
logger.info(f"表情包已禁用: ID={emoji_id}")
logger.info(f"表情包已禁用: ID={emoji_id}")
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
except HTTPException:
raise
@@ -672,61 +724,58 @@ async def get_emoji_thumbnail(
if not is_valid:
raise HTTPException(status_code=401, detail="Token is invalid or expired")
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
# Check that the file exists
if not os.path.exists(emoji.full_path):
raise HTTPException(status_code=404, detail="Emoji file does not exist")
# If the original image is requested, return the original file directly
if original:
mime_types = {
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
# Try to fetch or generate the thumbnail
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
if not emoji:
raise HTTPException(status_code=404, detail=f"Emoji with ID {emoji_id} not found")
# Check whether the cache exists
if cache_path.exists():
# Cache hit; return it directly
return FileResponse(
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
if not os.path.exists(emoji.full_path):
raise HTTPException(status_code=404, detail="Emoji file does not exist")
if original:
mime_types = {
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
suffix = Path(emoji.full_path).suffix.lower().lstrip(".")
media_type = mime_types.get(suffix, "application/octet-stream")
return FileResponse(
path=emoji.full_path, media_type=media_type, filename=f"{emoji.image_hash}.{suffix}"
)
cache_path = _get_thumbnail_cache_path(emoji.image_hash)
if cache_path.exists():
return FileResponse(
path=str(cache_path), media_type="image/webp", filename=f"{emoji.image_hash}_thumb.webp"
)
with _generating_lock:
if emoji.image_hash not in _generating_thumbnails:
_generating_thumbnails.add(emoji.image_hash)
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.image_hash)
return JSONResponse(
status_code=202,
content={
"status": "generating",
"message": "缩略图正在生成中,请稍后重试",
"emoji_id": emoji_id,
},
headers={
"Retry-After": "1",
},
)
# Cache miss: trigger background generation and return 202
with _generating_lock:
if emoji.emoji_hash not in _generating_thumbnails:
# Mark as generating
_generating_thumbnails.add(emoji.emoji_hash)
# Submit to the thread pool for background generation
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
# Return 202 Accepted to tell the frontend the thumbnail is being generated
return JSONResponse(
status_code=202,
content={
"status": "generating",
"message": "缩略图正在生成中,请稍后重试",
"emoji_id": emoji_id,
},
headers={
"Retry-After": "1", # 建议 1 秒后重试
},
)
except HTTPException:
raise
except Exception as e:
@@ -762,14 +811,19 @@ async def batch_delete_emojis(
for emoji_id in request.emoji_ids:
try:
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if emoji:
emoji.delete_instance()
deleted_count += 1
logger.info(f"批量删除表情包: {emoji_id}")
else:
failed_count += 1
failed_ids.append(emoji_id)
with get_db_session() as session:
statement = select(Images).where(
col(Images.id) == emoji_id,
col(Images.image_type) == ImageType.EMOJI,
)
emoji = session.exec(statement).first()
if emoji:
session.delete(emoji)
deleted_count += 1
logger.info(f"批量删除表情包: {emoji_id}")
else:
failed_count += 1
failed_ids.append(emoji_id)
except Exception as e:
logger.error(f"删除表情包 {emoji_id} 失败: {e}")
failed_count += 1
@@ -864,19 +918,23 @@ async def upload_emoji(
# Compute the file hash
emoji_hash = hashlib.md5(file_content).hexdigest()
# Check whether an emoji with the same hash already exists
existing_emoji = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if existing_emoji:
raise HTTPException(
status_code=409,
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
with get_db_session() as session:
existing_statement = select(Images).where(
col(Images.image_hash) == emoji_hash,
col(Images.image_type) == ImageType.EMOJI,
)
existing_emoji = session.exec(existing_statement).first()
if existing_emoji:
raise HTTPException(
status_code=409,
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
)
# Ensure the directory exists
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
# Generate the file name
timestamp = int(time.time())
timestamp = int(datetime.now().timestamp())
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
@@ -889,37 +947,38 @@ async def upload_emoji(
# Save the file
with open(full_path, "wb") as f:
f.write(file_content)
_ = f.write(file_content)
logger.info(f"表情包文件已保存: {full_path}")
# Process the emotion tags
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
# Create the database record
current_time = time.time()
emoji = Emoji.create(
full_path=full_path,
format=img_format,
emoji_hash=emoji_hash,
description=description,
emotion=emotion_str,
query_count=0,
is_registered=is_registered,
is_banned=False,
record_time=current_time,
register_time=current_time if is_registered else None,
usage_count=0,
last_used_time=None,
)
current_time = datetime.now()
with get_db_session() as session:
emoji = Images(
image_type=ImageType.EMOJI,
full_path=full_path,
image_hash=emoji_hash,
description=description,
emotion=emotion_str or None,
query_count=0,
is_registered=is_registered,
is_banned=False,
record_time=current_time,
register_time=current_time if is_registered else None,
last_used_time=None,
)
session.add(emoji)
session.flush()
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
return EmojiUploadResponse(
success=True,
message="表情包上传成功" + ("并已注册" if is_registered else ""),
data=emoji_to_response(emoji),
)
return EmojiUploadResponse(
success=True,
message="表情包上传成功" + ("并已注册" if is_registered else ""),
data=emoji_to_response(emoji),
)
except HTTPException:
raise
@@ -951,7 +1010,7 @@ async def batch_upload_emoji(
try:
verify_auth_token(maibot_session, authorization)
results = {
results: dict[str, Any] = {
"success": True,
"total": len(files),
"uploaded": 0,
@@ -1008,20 +1067,24 @@ async def batch_upload_emoji(
# Compute the hash
emoji_hash = hashlib.md5(file_content).hexdigest()
# Check for duplicates
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
results["failed"] += 1
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": "已存在相同的表情包",
}
with get_db_session() as session:
existing_statement = select(Images).where(
col(Images.image_hash) == emoji_hash,
col(Images.image_type) == ImageType.EMOJI,
)
continue
if session.exec(existing_statement).first():
results["failed"] += 1
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": "已存在相同的表情包",
}
)
continue
# Generate the file name and save it
timestamp = int(time.time())
timestamp = int(datetime.now().timestamp())
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
@@ -1032,36 +1095,37 @@ async def batch_upload_emoji(
counter += 1
with open(full_path, "wb") as f:
f.write(file_content)
_ = f.write(file_content)
# Process the emotion tags
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
# Create the database record
current_time = time.time()
emoji = Emoji.create(
full_path=full_path,
format=img_format,
emoji_hash=emoji_hash,
description="", # 批量上传暂不设置描述
emotion=emotion_str,
query_count=0,
is_registered=is_registered,
is_banned=False,
record_time=current_time,
register_time=current_time if is_registered else None,
usage_count=0,
last_used_time=None,
)
current_time = datetime.now()
with get_db_session() as session:
emoji = Images(
image_type=ImageType.EMOJI,
full_path=full_path,
image_hash=emoji_hash,
description="",
emotion=emotion_str or None,
query_count=0,
is_registered=is_registered,
is_banned=False,
record_time=current_time,
register_time=current_time if is_registered else None,
last_used_time=None,
)
session.add(emoji)
session.flush()
results["uploaded"] += 1
results["details"].append(
{
"filename": file.filename,
"success": True,
"id": emoji.id,
}
)
results["uploaded"] += 1
results["details"].append(
{
"filename": file.filename,
"success": True,
"id": emoji.id,
}
)
except Exception as e:
results["failed"] += 1
@@ -1138,8 +1202,9 @@ async def get_thumbnail_cache_stats(
total_size = sum(f.stat().st_size for f in cache_files)
total_size_mb = round(total_size / (1024 * 1024), 2)
# Count the total number of emojis
emoji_count = Emoji.select().count()
with get_db_session() as session:
count_statement = select(func.count()).select_from(Images).where(col(Images.image_type) == ImageType.EMOJI)
emoji_count = session.exec(count_statement).one()
# Compute the coverage
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
@@ -1213,12 +1278,17 @@ async def preheat_thumbnail_cache(
_ensure_thumbnail_cache_dir()
# Get the most-used emojis (uncached ones first)
emojis = (
Emoji.select()
.where(Emoji.is_banned == False) # noqa: E712 Peewee ORM requires == for boolean comparison
.order_by(Emoji.usage_count.desc())
.limit(limit * 2)  # fetch extra, since some may already be cached
)
with get_db_session() as session:
statement = (
select(Images)
.where(
col(Images.image_type) == ImageType.EMOJI,
col(Images.is_banned) == False,
)
.order_by(col(Images.query_count).desc())
.limit(limit * 2)
)
emojis = session.exec(statement).all()
generated = 0
skipped = 0
@@ -1228,25 +1298,22 @@ async def preheat_thumbnail_cache(
if generated >= limit:
break
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
cache_path = _get_thumbnail_cache_path(emoji.image_hash)
# Already cached; skip
if cache_path.exists():
skipped += 1
continue
# Source file missing; skip
if not os.path.exists(emoji.full_path):
failed += 1
continue
try:
# Generate the thumbnail in the thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.image_hash)
generated += 1
except Exception as e:
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
logger.warning(f"预热缩略图失败 {emoji.image_hash}: {e}")
failed += 1
return ThumbnailPreheatResponse(

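The thumbnail endpoint returns 202 with a Retry-After header while generation is in flight. A client-side polling sketch (the URL path, host, and emoji ID are assumptions; only the 202/Retry-After contract comes from the code above):

    import time
    import requests  # hypothetical client

    def fetch_thumbnail(base: str, emoji_id: int, attempts: int = 5) -> bytes | None:
        # Poll until the backend finishes generating the thumbnail (202 -> 200)
        for _ in range(attempts):
            resp = requests.get(f"{base}/api/emoji/{emoji_id}/thumbnail")  # assumed route
            if resp.status_code == 200:
                return resp.content
            if resp.status_code == 202:
                time.sleep(int(resp.headers.get("Retry-After", "1")))
                continue
            resp.raise_for_status()
        return None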

@@ -65,9 +65,6 @@ class ExpressionUpdateRequest(BaseModel):
situation: Optional[str] = None
style: Optional[str] = None
chat_id: Optional[str] = None
checked: Optional[bool] = None
rejected: Optional[bool] = None
require_unchecked: Optional[bool] = False  # used for conflict detection during manual review
class ExpressionUpdateResponse(BaseModel):
@@ -388,26 +385,16 @@ async def update_expression(
if not expression:
raise HTTPException(status_code=404, detail=f"Expression with ID {expression_id} not found")
# Conflict detection: the unchecked state was required, but the expression was already checked
if request.require_unchecked and getattr(expression, "checked", False):
raise HTTPException(
status_code=409,
detail=f"This expression was already checked {'automatically by AI' if getattr(expression, 'modified_by', None) == 'ai' else 'manually'}; please refresh the list",
)
# Only update the provided fields
update_data = request.model_dump(exclude_unset=True)
# Remove require_unchecked; it is not a database field
update_data.pop("require_unchecked", None)
# Map API field names to database fields
if "chat_id" in update_data:
update_data["session_id"] = update_data.pop("chat_id")
if not update_data:
raise HTTPException(status_code=400, detail="No fields to update were provided")
# If checked or rejected was updated, mark it as modified by the user
if "checked" in update_data or "rejected" in update_data:
update_data["modified_by"] = "user"
# Update the last-active time
update_data["last_active_time"] = datetime.now()


@@ -1,13 +1,16 @@
"""黑话(俚语)管理路由"""
import json
from typing import Optional, List, Annotated
from typing import Annotated, Any, List, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from sqlalchemy import func as fn
from sqlmodel import Session, col, delete, select
import json
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession, Jargon
from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams
logger = get_logger("webui.jargon")
@@ -43,27 +46,26 @@ def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
return [chat_id_str]
def get_display_name_for_chat_id(chat_id_str: str) -> str:
def get_display_name_for_chat_id(chat_id_str: str, session: Session) -> str:
"""
获取 chat_id 的显示名称
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
尝试解析 JSON 并查询 ChatSession 表获取群聊名称
"""
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids:
return chat_id_str
return chat_id_str[:20]
# Look up the name for each stream_id
names = []
for stream_id in stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream and chat_stream.group_name:
names.append(chat_stream.group_name)
else:
# If not found, show a truncated stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
stream_id = stream_ids[0]
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
return ", ".join(names) if names else chat_id_str
if not chat_session:
return stream_id[:20]
if chat_session.group_id:
return str(chat_session.group_id)
return chat_session.session_id[:20]
# ==================== 请求/响应模型 ====================
@@ -79,7 +81,6 @@ class JargonResponse(BaseModel):
chat_id: str
stream_id: Optional[str] = None  # parsed stream_id, used for matching when editing in the frontend
chat_name: Optional[str] = None  # parsed chat name, used for frontend display
is_global: bool = False
count: int = 0
is_jargon: Optional[bool] = None
is_complete: bool = False
@@ -94,7 +95,7 @@ class JargonListResponse(BaseModel):
total: int
page: int
page_size: int
data: List[JargonResponse]
data: List[dict[str, Any]]
class JargonDetailResponse(BaseModel):
@@ -111,7 +112,6 @@ class JargonCreateRequest(BaseModel):
raw_content: Optional[str] = Field(None, description="Raw content")
meaning: Optional[str] = Field(None, description="Meaning")
chat_id: str = Field(..., description="Chat ID")
is_global: bool = Field(False, description="Whether global")
class JargonUpdateRequest(BaseModel):
@@ -121,7 +121,6 @@ class JargonUpdateRequest(BaseModel):
raw_content: Optional[str] = None
meaning: Optional[str] = None
chat_id: Optional[str] = None
is_global: Optional[bool] = None
is_jargon: Optional[bool] = None
@@ -159,7 +158,7 @@ class JargonStatsResponse(BaseModel):
"""黑话统计响应"""
success: bool = True
data: dict
data: dict[str, Any]
class ChatInfoResponse(BaseModel):
@@ -181,27 +180,24 @@ class ChatListResponse(BaseModel):
# ==================== Utility functions ====================
def jargon_to_dict(jargon: Jargon) -> dict:
def jargon_to_dict(jargon: Jargon, session: Session) -> dict[str, Any]:
"""将 Jargon ORM 对象转换为字典"""
# 解析 chat_id 获取显示名称和 stream_id
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
chat_id = jargon.session_id or ""
chat_name = get_display_name_for_chat_id(chat_id, session) if chat_id else None
return {
"id": jargon.id,
"content": jargon.content,
"raw_content": jargon.raw_content,
"meaning": jargon.meaning,
"chat_id": jargon.chat_id,
"stream_id": stream_id,
"chat_id": chat_id,
"stream_id": jargon.session_id,
"chat_name": chat_name,
"is_global": jargon.is_global,
"count": jargon.count,
"is_jargon": jargon.is_jargon,
"is_complete": jargon.is_complete,
"inference_with_context": jargon.inference_with_context,
"inference_content_only": jargon.inference_content_only,
"inference_content_only": jargon.inference_with_content_only,
}
@@ -215,49 +211,41 @@ async def get_jargon_list(
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
):
"""获取黑话列表"""
try:
# 构建查询
query = Jargon.select()
statement = select(Jargon)
count_statement = select(fn.count()).select_from(Jargon)
# 搜索过滤
if search:
query = query.where(
(Jargon.content.contains(search))
| (Jargon.meaning.contains(search))
| (Jargon.raw_content.contains(search))
search_filter = (
(col(Jargon.content).contains(search))
| (col(Jargon.meaning).contains(search))
| (col(Jargon.raw_content).contains(search))
)
statement = statement.where(search_filter)
count_statement = count_statement.where(search_filter)
# 按聊天ID筛选(使用 contains 匹配,因为 chat_id 是 JSON 格式)
if chat_id:
# 从传入的 chat_id 中解析出 stream_id
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
# 使用第一个 stream_id 进行模糊匹配
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
chat_filter = col(Jargon.session_id).contains(stream_ids[0])
else:
# 如果无法解析,使用精确匹配
query = query.where(Jargon.chat_id == chat_id)
chat_filter = col(Jargon.session_id) == chat_id
statement = statement.where(chat_filter)
count_statement = count_statement.where(chat_filter)
# 按是否是黑话筛选
if is_jargon is not None:
query = query.where(Jargon.is_jargon == is_jargon)
statement = statement.where(col(Jargon.is_jargon) == is_jargon)
count_statement = count_statement.where(col(Jargon.is_jargon) == is_jargon)
# 按是否全局筛选
if is_global is not None:
query = query.where(Jargon.is_global == is_global)
statement = statement.order_by(col(Jargon.count).desc(), col(Jargon.id).desc())
statement = statement.offset((page - 1) * page_size).limit(page_size)
# 获取总数
total = query.count()
# 分页和排序(按使用次数降序)
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
query = query.paginate(page, page_size)
# 转换为响应格式
data = [jargon_to_dict(j) for j in query]
with get_db_session() as session:
total = session.exec(count_statement).one()
jargons = session.exec(statement).all()
data = [jargon_to_dict(jargon, session) for jargon in jargons]
return JargonListResponse(
success=True,
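The rewritten list endpoint keeps a data statement and a count statement with identical filters, then applies offset/limit pagination before executing both in one session. A self-contained sketch of that pattern, using an illustrative Item model rather than the project's Jargon table:

from typing import Optional
from sqlalchemy import func as fn
from sqlmodel import Field, Session, SQLModel, col, create_engine, select

class Item(SQLModel, table=True):  # illustrative stand-in model
    id: Optional[int] = Field(default=None, primary_key=True)
    content: str = ""
    count: int = 0

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

def list_items(search: Optional[str], page: int, page_size: int):
    statement = select(Item)
    count_statement = select(fn.count()).select_from(Item)
    if search:
        # the same filter is applied to both statements so total matches the page data
        search_filter = col(Item.content).contains(search)
        statement = statement.where(search_filter)
        count_statement = count_statement.where(search_filter)
    statement = statement.order_by(col(Item.count).desc(), col(Item.id).desc())
    statement = statement.offset((page - 1) * page_size).limit(page_size)
    with Session(engine) as session:
        total = session.exec(count_statement).one()
        rows = session.exec(statement).all()
    return total, rows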
@@ -276,10 +264,9 @@ async def get_jargon_list(
async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
with get_db_session() as session:
statement = select(Jargon.session_id).distinct().where(col(Jargon.session_id).is_not(None))
chat_id_list = [chat_id for chat_id in session.exec(statement).all() if chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
@@ -290,27 +277,28 @@ async def get_chat_list():
seen_stream_ids.add(stream_ids[0])
result = []
for stream_id in seen_stream_ids:
# 尝试从 ChatStreams 表获取聊天名称
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id,方便筛选匹配
chat_name=chat_stream.group_name or stream_id,
platform=chat_stream.platform,
is_group=True,
with get_db_session() as session:
for stream_id in seen_stream_ids:
chat_session = session.exec(select(ChatSession).where(col(ChatSession.session_id) == stream_id)).first()
if chat_session:
chat_name = str(chat_session.group_id) if chat_session.group_id else stream_id[:20]
result.append(
ChatInfoResponse(
chat_id=stream_id,
chat_name=chat_name,
platform=chat_session.platform,
is_group=bool(chat_session.group_id),
)
)
)
else:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
platform=None,
is_group=False,
else:
result.append(
ChatInfoResponse(
chat_id=stream_id,
chat_name=stream_id[:20],
platform=None,
is_group=False,
)
)
)
return ChatListResponse(success=True, data=result)
@@ -323,35 +311,35 @@ async def get_chat_list():
async def get_jargon_stats():
"""获取黑话统计数据"""
try:
# 总数量
total = Jargon.select().count()
with get_db_session() as session:
total = session.exec(select(fn.count()).select_from(Jargon)).one()
# 已确认是黑话的数量
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
confirmed_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == True)
).one()
confirmed_not_jargon = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon) == False)
).one()
pending = session.exec(select(fn.count()).select_from(Jargon).where(col(Jargon.is_jargon).is_(None))).one()
# 已确认不是黑话的数量
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
complete_count = session.exec(
select(fn.count()).select_from(Jargon).where(col(Jargon.is_complete) == True)
).one()
# 未判定的数量
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
chat_count = session.exec(
select(fn.count()).select_from(
select(col(Jargon.session_id)).distinct().where(col(Jargon.session_id).is_not(None)).subquery()
)
).one()
# 全局黑话数量
global_count = Jargon.select().where(Jargon.is_global).count()
# 已完成推断的数量
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
.group_by(Jargon.chat_id)
.order_by(fn.COUNT(Jargon.id).desc())
.limit(5)
)
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
top_chats = session.exec(
select(col(Jargon.session_id), fn.count().label("count"))
.where(col(Jargon.session_id).is_not(None))
.group_by(col(Jargon.session_id))
.order_by(fn.count().desc())
.limit(5)
).all()
top_chats_dict = {session_id: count for session_id, count in top_chats if session_id}
return JargonStatsResponse(
success=True,
@@ -360,7 +348,6 @@ async def get_jargon_stats():
"confirmed_jargon": confirmed_jargon,
"confirmed_not_jargon": confirmed_not_jargon,
"pending": pending,
"global_count": global_count,
"complete_count": complete_count,
"chat_count": chat_count,
"top_chats": top_chats_dict,
@@ -376,11 +363,13 @@ async def get_jargon_stats():
async def get_jargon_detail(jargon_id: int):
"""获取黑话详情"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
with get_db_session() as session:
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
return JargonDetailResponse(success=True, data=data)
except HTTPException:
raise
@@ -393,30 +382,31 @@ async def get_jargon_detail(jargon_id: int):
async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
with get_db_session() as session:
existing = session.exec(
select(Jargon).where(
(col(Jargon.content) == request.content) & (col(Jargon.session_id) == request.chat_id)
)
).first()
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
# 创建黑话
jargon = Jargon.create(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning,
chat_id=request.chat_id,
is_global=request.is_global,
count=0,
is_jargon=None,
is_complete=False,
)
jargon = Jargon(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning or "",
session_id=request.chat_id,
count=0,
is_jargon=None,
is_complete=False,
)
session.add(jargon)
session.flush()
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonCreateResponse(
success=True,
message="创建成功",
data=jargon_to_dict(jargon),
)
return JargonCreateResponse(success=True, message="创建成功", data=data)
except HTTPException:
raise
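create_jargon adds the new row and calls session.flush() so the autogenerated id is available for the log line and the response before the session is finalised. A minimal sketch of that flow with an illustrative Entry model; in the router the final commit is presumably handled by get_db_session:

from typing import Optional
from sqlmodel import Field, Session, SQLModel, create_engine

class Entry(SQLModel, table=True):  # illustrative stand-in model
    id: Optional[int] = Field(default=None, primary_key=True)
    content: str = ""

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

with Session(engine) as session:
    entry = Entry(content="example")
    session.add(entry)
    session.flush()              # pushes the INSERT so entry.id is populated
    assert entry.id is not None  # id can now be logged or returned
    session.commit()             # persist; handled by the context manager in the router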
@@ -429,25 +419,27 @@ async def create_jargon(request: JargonCreateRequest):
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
"""更新黑话(增量更新)"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
with get_db_session() as session:
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
# 增量更新字段
update_data = request.model_dump(exclude_unset=True)
if update_data:
for field, value in update_data.items():
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
jargon.save()
update_data = request.model_dump(exclude_unset=True)
if update_data:
for field, value in update_data.items():
if field == "is_global":
continue
if field == "chat_id":
jargon.session_id = value
continue
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
session.add(jargon)
logger.info(f"更新黑话成功: id={jargon_id}")
logger.info(f"更新黑话成功: id={jargon_id}")
data = JargonResponse(**jargon_to_dict(jargon, session))
return JargonUpdateResponse(
success=True,
message="更新成功",
data=jargon_to_dict(jargon),
)
return JargonUpdateResponse(success=True, message="更新成功", data=data)
except HTTPException:
raise
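The incremental update relies on Pydantic's exclude_unset so that only fields the client actually sent are applied, while an explicit null can still clear meaning, raw_content or is_jargon. A small self-contained sketch of that behaviour (the request model here is illustrative):

from typing import Optional
from pydantic import BaseModel

class UpdateRequest(BaseModel):  # illustrative request model
    content: Optional[str] = None
    meaning: Optional[str] = None
    is_jargon: Optional[bool] = None

# Only "meaning" was provided by the client, so exclude_unset drops the rest,
# letting an explicit None remain meaningful for fields that were actually sent.
req = UpdateRequest.model_validate({"meaning": None})
print(req.model_dump(exclude_unset=True))  # {'meaning': None}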
@@ -460,20 +452,17 @@ async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
async def delete_jargon(jargon_id: int):
"""删除黑话"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
with get_db_session() as session:
jargon = session.exec(select(Jargon).where(col(Jargon.id) == jargon_id)).first()
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
content = jargon.content
jargon.delete_instance()
content = jargon.content
session.delete(jargon)
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
return JargonDeleteResponse(
success=True,
message="删除成功",
deleted_count=1,
)
return JargonDeleteResponse(success=True, message="删除成功", deleted_count=1)
except HTTPException:
raise
@@ -489,9 +478,11 @@ async def batch_delete_jargons(request: BatchDeleteRequest):
if not request.ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
with get_db_session() as session:
result = session.exec(delete(Jargon).where(col(Jargon.id).in_(request.ids)))
deleted_count = result.rowcount or 0
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
return JargonDeleteResponse(
success=True,
@@ -516,14 +507,16 @@ async def batch_set_jargon_status(
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
with get_db_session() as session:
jargons = session.exec(select(Jargon).where(col(Jargon.id).in_(ids))).all()
for jargon in jargons:
jargon.is_jargon = is_jargon
session.add(jargon)
updated_count = len(jargons)
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
return JargonUpdateResponse(
success=True,
message=f"成功更新 {updated_count} 条黑话状态",
)
return JargonUpdateResponse(success=True, message=f"成功更新 {updated_count} 条黑话状态")
except HTTPException:
raise

View File

@@ -7,7 +7,7 @@ from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION
from src.plugin_system.base.config_types import ConfigField
from src.webui.git_mirror_service import get_git_mirror_service, set_update_progress_callback
from src.webui.services.git_mirror_service import get_git_mirror_service, set_update_progress_callback
from src.webui.core import get_token_manager
from src.webui.routers.websocket.plugin_progress import update_progress

View File

@@ -7,7 +7,6 @@ class EmojiResponse(BaseModel):
id: int
full_path: str
format: str
emoji_hash: str
description: str
query_count: int
@@ -16,7 +15,6 @@ class EmojiResponse(BaseModel):
emotion: Optional[str]
record_time: float
register_time: Optional[float]
usage_count: int
last_used_time: Optional[float]

View File

@@ -0,0 +1,662 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
import json
import asyncio
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from src.common.logger import get_logger
logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
_update_progress = callback
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
GITHUB = "github" # GitHub 官方源(兜底)
CUSTOM = "custom" # 自定义镜像源
class GitMirrorConfig:
"""Git 镜像源配置管理"""
# 配置文件路径
CONFIG_FILE = Path("data/webui.json")
# 默认镜像源配置
DEFAULT_MIRRORS = [
{
"id": "gh-proxy",
"name": "gh-proxy 镜像",
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None,
},
{
"id": "hk-gh-proxy",
"name": "gh-proxy 香港节点",
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None,
},
{
"id": "cdn-gh-proxy",
"name": "gh-proxy CDN 节点",
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None,
},
{
"id": "edgeone-gh-proxy",
"name": "gh-proxy EdgeOne 节点",
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None,
},
{
"id": "meyzh-github",
"name": "Meyzh GitHub 镜像",
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None,
},
{
"id": "github",
"name": "GitHub 官方源(兜底)",
"raw_prefix": "https://raw.githubusercontent.com",
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None,
},
]
def __init__(self):
"""初始化配置管理器"""
self.config_file = self.CONFIG_FILE
self.mirrors: List[Dict[str, Any]] = []
self._load_config()
def _load_config(self) -> None:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
data = json.load(f)
# 检查是否有镜像源配置
if "git_mirrors" not in data or not data["git_mirrors"]:
logger.info("配置文件中未找到镜像源配置,使用默认配置")
self._init_default_mirrors()
else:
self.mirrors = data["git_mirrors"]
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
else:
logger.info("配置文件不存在,创建默认配置")
self._init_default_mirrors()
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
self._init_default_mirrors()
def _init_default_mirrors(self) -> None:
"""初始化默认镜像源"""
current_time = datetime.now().isoformat()
self.mirrors = []
for mirror in self.DEFAULT_MIRRORS:
mirror_copy = mirror.copy()
mirror_copy["created_at"] = current_time
self.mirrors.append(mirror_copy)
self._save_config()
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
def _save_config(self) -> None:
"""保存配置到文件"""
try:
# 确保目录存在
self.config_file.parent.mkdir(parents=True, exist_ok=True)
# 读取现有配置
existing_data = {}
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
existing_data = json.load(f)
# 更新镜像源配置
existing_data["git_mirrors"] = self.mirrors
# 写入文件
with open(self.config_file, "w", encoding="utf-8") as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
logger.debug(f"配置已保存到 {self.config_file}")
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
def get_all_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有镜像源"""
return self.mirrors.copy()
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有启用的镜像源,按优先级排序"""
enabled = [m for m in self.mirrors if m.get("enabled", False)]
return sorted(enabled, key=lambda x: x.get("priority", 999))
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
"""根据 ID 获取镜像源"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
return mirror.copy()
return None
def add_mirror(
self,
mirror_id: str,
name: str,
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None,
) -> Dict[str, Any]:
"""
添加新的镜像源
Returns:
添加的镜像源配置
Raises:
ValueError: 如果镜像源 ID 已存在
"""
# 检查 ID 是否已存在
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
priority = max_priority + 1
new_mirror = {
"id": mirror_id,
"name": name,
"raw_prefix": raw_prefix,
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat(),
}
self.mirrors.append(new_mirror)
self._save_config()
logger.info(f"已添加镜像源: {mirror_id} - {name}")
return new_mirror.copy()
def update_mirror(
self,
mirror_id: str,
name: Optional[str] = None,
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
Returns:
更新后的镜像源配置,如果不存在则返回 None
"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
if name is not None:
mirror["name"] = name
if raw_prefix is not None:
mirror["raw_prefix"] = raw_prefix
if clone_prefix is not None:
mirror["clone_prefix"] = clone_prefix
if enabled is not None:
mirror["enabled"] = enabled
if priority is not None:
mirror["priority"] = priority
mirror["updated_at"] = datetime.now().isoformat()
self._save_config()
logger.info(f"已更新镜像源: {mirror_id}")
return mirror.copy()
return None
def delete_mirror(self, mirror_id: str) -> bool:
"""
删除镜像源
Returns:
True 如果删除成功,False 如果镜像源不存在
"""
for i, mirror in enumerate(self.mirrors):
if mirror.get("id") == mirror_id:
self.mirrors.pop(i)
self._save_config()
logger.info(f"已删除镜像源: {mirror_id}")
return True
return False
def get_default_priority_list(self) -> List[str]:
"""获取默认优先级列表(仅启用的镜像源 ID"""
enabled = self.get_enabled_mirrors()
return [m["id"] for m in enabled]
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
"""
初始化 Git 镜像源服务
Args:
max_retries: 最大重试次数
timeout: 请求超时时间(秒)
config: 镜像源配置管理器(可选,默认创建新实例)
"""
self.max_retries = max_retries
self.timeout = timeout
self.config = config or GitMirrorConfig()
logger.info(f"Git镜像源服务初始化完成已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
def get_mirror_config(self) -> GitMirrorConfig:
"""获取镜像源配置管理器"""
return self.config
@staticmethod
def check_git_installed() -> Dict[str, Any]:
"""
检查本机是否安装了 Git
Returns:
Dict 包含:
- installed: bool - 是否已安装 Git
- version: str - Git 版本号(如果已安装)
- path: str - Git 可执行文件路径(如果已安装)
- error: str - 错误信息(如果未安装或检测失败)
"""
import subprocess
import shutil
try:
# 查找 git 可执行文件路径
git_path = shutil.which("git")
if not git_path:
logger.warning("未找到 Git 可执行文件")
return {"installed": False, "error": "系统中未找到 Git请先安装 Git"}
# 获取 Git 版本
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
version = result.stdout.strip()
logger.info(f"检测到 Git: {version} at {git_path}")
return {"installed": True, "version": version, "path": git_path}
else:
logger.warning(f"Git 命令执行失败: {result.stderr}")
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
except subprocess.TimeoutExpired:
logger.error("Git 版本检测超时")
return {"installed": False, "error": "Git 版本检测超时"}
except Exception as e:
logger.error(f"检测 Git 时发生错误: {e}")
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
async def fetch_raw_file(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
Args:
owner: 仓库所有者
repo: 仓库名称
branch: 分支名称
file_path: 文件路径
mirror_id: 指定的镜像源 ID
custom_url: 自定义完整 URL,如果提供将忽略其他参数
Returns:
Dict 包含:
- success: bool - 是否成功
- data: str - 文件内容(成功时)
- error: str - 错误信息(失败时)
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
total_mirrors = len(mirrors_to_try)
# 依次尝试每个镜像源
for index, mirror in enumerate(mirrors_to_try, 1):
# 推送进度:正在尝试第 N 个镜像源
if _update_progress:
try:
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
await _update_progress(
stage="loading",
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
if result["success"]:
# 成功,推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
return result
# 失败,记录日志并推送失败信息
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
if _update_progress and index < total_mirrors:
try:
await _update_progress(
stage="loading",
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _fetch_raw_from_mirror(
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
raw_prefix = mirror["raw_prefix"]
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
return await self._fetch_with_url(url, mirror["id"])
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
"""使用指定 URL 获取文件,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
logger.debug(f"尝试 #{attempt + 1}: {url}")
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(url)
response.raise_for_status()
logger.info(f"成功获取文件: {url}")
return {
"success": True,
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except httpx.TimeoutException as e:
last_error = f"请求超时: {e}"
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
async def clone_repository(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None,
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
Args:
owner: 仓库所有者
repo: 仓库名称
target_path: 目标路径
branch: 分支名称(可选)
mirror_id: 指定的镜像源 ID
custom_url: 自定义克隆 URL
depth: 克隆深度(浅克隆)
Returns:
Dict 包含:
- success: bool - 是否成功
- path: str - 克隆路径(成功时)
- error: str - 错误信息(失败时)
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _clone_from_mirror(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
# 确保目标路径不存在
if target_path.exists():
logger.warning(f"目标路径已存在,删除: {target_path}")
shutil.rmtree(target_path, ignore_errors=True)
# 构建 git clone 命令
cmd = ["git", "clone"]
# 添加分支参数
if branch:
cmd.extend(["-b", branch])
# 添加深度参数(浅克隆)
if depth:
cmd.extend(["--depth", str(depth)])
# 添加 URL 和目标路径
cmd.extend([url, str(target_path)])
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
# 推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install",
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone(clone_cmd=cmd):
return subprocess.run(
clone_cmd,
capture_output=True,
text=True,
timeout=300, # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
if process.returncode == 0:
logger.info(f"成功克隆仓库: {url} -> {target_path}")
return {
"success": True,
"path": str(target_path),
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default",
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except subprocess.TimeoutExpired:
last_error = "克隆超时(超过 5 分钟)"
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
except FileNotFoundError:
last_error = "Git 未安装或不在 PATH 中"
logger.error(f"Git 未找到: {last_error}")
break # Git 不存在,不需要重试
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
# 全局服务实例
_git_mirror_service: Optional[GitMirrorService] = None
def get_git_mirror_service() -> GitMirrorService:
"""获取 Git 镜像源服务实例(单例)"""
global _git_mirror_service
if _git_mirror_service is None:
_git_mirror_service = GitMirrorService()
return _git_mirror_service
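
A short usage sketch of the new service from async code; the owner/repo values below are placeholders, not anything referenced by this commit:

import asyncio
from pathlib import Path

from src.webui.services.git_mirror_service import get_git_mirror_service

async def demo() -> None:
    service = get_git_mirror_service()

    # Raw file fetch: tries enabled mirrors in priority order, GitHub last as fallback.
    result = await service.fetch_raw_file(
        owner="example-owner",   # placeholder
        repo="example-repo",     # placeholder
        branch="main",
        file_path="README.md",
    )
    if result["success"]:
        print(result["mirror_used"], len(result["data"]))

    # Shallow clone through the same mirror list.
    clone_result = await service.clone_repository(
        owner="example-owner",
        repo="example-repo",
        target_path=Path("data/tmp/example-repo"),
        depth=1,
    )
    print(clone_result["success"], clone_result.get("error"))

if __name__ == "__main__":
    asyncio.run(demo())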