Ruff Fix & format

This commit is contained in:
墨梓柒
2025-11-29 14:38:42 +08:00
parent d7932595e8
commit 3935ce817e
31 changed files with 678 additions and 684 deletions

View File

@@ -25,6 +25,7 @@ WEBUI_USER_ID_PREFIX = "webui_user_"
class ChatHistoryMessage(BaseModel):
"""聊天历史消息"""
id: str
type: str # 'user' | 'bot' | 'system'
content: str
@@ -36,17 +37,17 @@ class ChatHistoryMessage(BaseModel):
class ChatHistoryManager:
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
def __init__(self, max_messages: int = 200):
self.max_messages = max_messages
def _message_to_dict(self, msg: Messages) -> Dict[str, Any]:
"""将数据库消息转换为前端格式"""
# 判断是否是机器人消息
# WebUI 用户的 user_id 以 "webui_" 开头,其他都是机器人消息
user_id = msg.user_id or ""
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
@@ -56,7 +57,7 @@ class ChatHistoryManager:
"sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot,
}
def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
"""从数据库获取最近的历史记录"""
try:
@@ -67,25 +68,21 @@ class ChatHistoryManager:
.order_by(Messages.time.desc())
.limit(limit)
)
# 转换为列表并反转(使最旧的消息在前)
result = [self._message_to_dict(msg) for msg in messages]
result.reverse()
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
return []
def clear_history(self) -> int:
"""清空 WebUI 聊天历史记录"""
try:
deleted = (
Messages.delete()
.where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID)
.execute()
)
deleted = Messages.delete().where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID).execute()
logger.info(f"已清空 {deleted} 条 WebUI 聊天记录")
return deleted
except Exception as e:
@@ -100,31 +97,31 @@ chat_history = ChatHistoryManager()
# 存储 WebSocket 连接
class ChatConnectionManager:
"""聊天连接管理器"""
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def disconnect(self, session_id: str, user_id: str):
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict):
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict):
"""广播消息给所有连接"""
for session_id in list(self.active_connections.keys()):
@@ -135,16 +132,12 @@ chat_manager = ChatConnectionManager()
def create_message_data(
content: str,
user_id: str,
user_name: str,
message_id: Optional[str] = None,
is_at_bot: bool = True
content: str, user_id: str, user_name: str, message_id: Optional[str] = None, is_at_bot: bool = True
) -> Dict[str, Any]:
"""创建符合麦麦消息格式的消息数据"""
if message_id is None:
message_id = str(uuid.uuid4())
return {
"message_info": {
"platform": WEBUI_CHAT_PLATFORM,
@@ -163,7 +156,7 @@ def create_message_data(
},
"additional_config": {
"at_bot": is_at_bot,
}
},
},
"message_segment": {
"type": "seglist",
@@ -175,8 +168,8 @@ def create_message_data(
{
"type": "mention_bot",
"data": "1.0",
}
]
},
],
},
"raw_message": content,
"processed_plain_text": content,
@@ -186,10 +179,10 @@ def create_message_data(
@router.get("/history")
async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None) # 保留参数兼容性,但不用于过滤
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
):
"""获取聊天历史记录
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
"""
history = chat_history.get_history(limit)
@@ -217,76 +210,87 @@ async def websocket_chat(
user_name: Optional[str] = Query(default="WebUI用户"),
):
"""WebSocket 聊天端点
Args:
user_id: 用户唯一标识(由前端生成并持久化)
user_name: 用户显示昵称(可修改)
"""
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
# 如果没有提供 user_id生成一个新的
if not user_id:
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
# 确保 user_id 有正确的前缀
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
await chat_manager.connect(websocket, session_id, user_id)
try:
# 发送会话信息(包含用户 ID前端需要保存
await chat_manager.send_message(session_id, {
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
"user_name": user_name,
"bot_name": global_config.bot.nickname,
})
await chat_manager.send_message(
session_id,
{
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
"user_name": user_name,
"bot_name": global_config.bot.nickname,
},
)
# 发送历史记录
history = chat_history.get_history(50)
if history:
await chat_manager.send_message(session_id, {
"type": "history",
"messages": history,
})
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
# 发送欢迎消息(不保存到历史)
await chat_manager.send_message(session_id, {
"type": "system",
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
"timestamp": time.time(),
})
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
"timestamp": time.time(),
},
)
while True:
data = await websocket.receive_json()
if data.get("type") == "message":
content = data.get("content", "").strip()
if not content:
continue
# 用户可以更新昵称
current_user_name = data.get("user_name", user_name)
message_id = str(uuid.uuid4())
timestamp = time.time()
# 广播用户消息给所有连接(包括发送者)
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
await chat_manager.broadcast({
"type": "user_message",
"content": content,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
"name": current_user_name,
"user_id": user_id,
"is_bot": False,
await chat_manager.broadcast(
{
"type": "user_message",
"content": content,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
"name": current_user_name,
"user_id": user_id,
"is_bot": False,
},
}
})
)
# 创建麦麦消息格式
message_data = create_message_data(
content=content,
@@ -295,46 +299,59 @@ async def websocket_chat(
message_id=message_id,
is_at_bot=True,
)
try:
# 显示正在输入状态
await chat_manager.broadcast({
"type": "typing",
"is_typing": True,
})
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": True,
}
)
# 调用麦麦的消息处理
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await chat_manager.send_message(session_id, {
"type": "error",
"content": f"处理消息时出错: {str(e)}",
"timestamp": time.time(),
})
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"处理消息时出错: {str(e)}",
"timestamp": time.time(),
},
)
finally:
await chat_manager.broadcast({
"type": "typing",
"is_typing": False,
})
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": False,
}
)
elif data.get("type") == "ping":
await chat_manager.send_message(session_id, {
"type": "pong",
"timestamp": time.time(),
})
await chat_manager.send_message(
session_id,
{
"type": "pong",
"timestamp": time.time(),
},
)
elif data.get("type") == "update_nickname":
# 允许用户更新昵称
if new_name := data.get("user_name", "").strip():
current_user_name = new_name
await chat_manager.send_message(session_id, {
"type": "nickname_updated",
"user_name": current_user_name,
"timestamp": time.time(),
})
await chat_manager.send_message(
session_id,
{
"type": "nickname_updated",
"user_name": current_user_name,
"timestamp": time.time(),
},
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
except Exception as e:
@@ -356,7 +373,7 @@ async def get_chat_info():
def get_webui_chat_broadcaster() -> tuple:
"""获取 WebUI 聊天广播器,供外部模块使用
Returns:
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
"""

View File

@@ -5,7 +5,7 @@
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any
from typing import Any, Annotated
from src.common.logger import get_logger
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
@@ -41,6 +41,12 @@ from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
@@ -90,7 +96,7 @@ async def get_bot_config_schema():
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
@router.get("/schema/model")
@@ -101,7 +107,7 @@ async def get_model_config_schema():
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取模型配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
# ===== 子配置架构获取接口 =====
@@ -174,7 +180,7 @@ async def get_config_section_schema(section_name: str):
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置节架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
# ===== 配置读取接口 =====
@@ -196,7 +202,7 @@ async def get_bot_config():
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.get("/model")
@@ -215,21 +221,21 @@ async def get_model_config():
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
# ===== 配置更新接口 =====
@router.post("/bot")
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
async def update_bot_config(config_data: ConfigBody):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -242,18 +248,18 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model")
async def update_model_config(config_data: dict[str, Any] = Body(...)):
async def update_model_config(config_data: ConfigBody):
"""更新模型配置"""
try:
# 验证配置数据
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@@ -266,14 +272,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
# ===== 配置节更新接口 =====
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
async def update_bot_config_section(section_name: str, section_data: SectionBody):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -304,7 +310,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f:
@@ -316,7 +322,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 原始 TOML 文件操作接口 =====
@@ -338,24 +344,24 @@ async def get_bot_config_raw():
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
async def update_bot_config_raw(raw_content: RawContentBody):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
try:
config_data = tomlkit.loads(raw_content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 验证配置数据结构
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@@ -368,11 +374,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
async def update_model_config_section(section_name: str, section_data: SectionBody):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
@@ -403,7 +409,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f:
@@ -415,7 +421,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 适配器配置管理接口 =====
@@ -425,11 +431,11 @@ def _normalize_adapter_path(path: str) -> str:
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
if not path:
return path
# 如果已经是绝对路径,直接返回
if os.path.isabs(path):
return path
# 相对路径,转换为相对于项目根目录的绝对路径
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
@@ -438,17 +444,17 @@ def _to_relative_path(path: str) -> str:
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
if not path or not os.path.isabs(path):
return path
try:
# 尝试获取相对路径
rel_path = os.path.relpath(path, PROJECT_ROOT)
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
if not rel_path.startswith('..'):
if not rel_path.startswith(".."):
return rel_path
except (ValueError, TypeError):
# 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError
pass
# 无法转换为相对路径,返回绝对路径
return path
@@ -463,6 +469,7 @@ async def get_adapter_config_path():
return {"success": True, "path": None}
import json
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
@@ -472,10 +479,11 @@ async def get_adapter_config_path():
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(adapter_config_path)
# 检查文件是否存在并返回最后修改时间
if os.path.exists(abs_path):
import datetime
mtime = os.path.getmtime(abs_path)
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
# 返回相对路径(如果可能)
@@ -487,11 +495,11 @@ async def get_adapter_config_path():
except Exception as e:
logger.error(f"获取适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
async def save_adapter_config_path(data: PathBody):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
@@ -511,10 +519,10 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 尝试转换为相对路径保存(如果文件在项目目录内)
save_path = _to_relative_path(abs_path)
# 更新路径
webui_data["adapter_config_path"] = save_path
@@ -530,7 +538,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
raise
except Exception as e:
logger.error(f"保存适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
@router.get("/adapter-config")
@@ -542,7 +550,7 @@ async def get_adapter_config(path: str):
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件是否存在
if not os.path.exists(abs_path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
@@ -562,11 +570,11 @@ async def get_adapter_config(path: str):
raise
except Exception as e:
logger.error(f"读取适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
@router.post("/adapter-config")
async def save_adapter_config(data: dict[str, str] = Body(...)):
async def save_adapter_config(data: PathBody):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")
@@ -579,17 +587,16 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 验证 TOML 格式
try:
import tomlkit
tomlkit.loads(content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 确保目录存在
dir_path = os.path.dirname(abs_path)
@@ -607,5 +614,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
raise
except Exception as e:
logger.error(f"保存适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e

View File

@@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
if next_line.startswith('"""') or next_line.startswith("'''"):
# 单行文档字符串
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip())
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else:
# 多行文档字符串
quote = '"""' if next_line.startswith('"""') else "'''"
@@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"):
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip())
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else:
quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip())
@@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
return FieldType.ARRAY, None, items
# 处理基本类型
if field_type is bool or field_type == bool:
if field_type is bool:
return FieldType.BOOLEAN, None, None
elif field_type is int or field_type == int:
elif field_type is int:
return FieldType.INTEGER, None, None
elif field_type is float or field_type == float:
elif field_type is float:
return FieldType.NUMBER, None, None
elif field_type is str or field_type == str:
elif field_type is str:
return FieldType.STRING, None, None
elif field_type is dict or origin is dict:
return FieldType.OBJECT, None, None

View File

@@ -3,20 +3,25 @@
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, List
from typing import Optional, List, Annotated
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from .token_manager import get_token_manager
import json
import time
import os
import hashlib
import base64
from PIL import Image
import io
logger = get_logger("webui.emoji")
# 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
DescriptionForm = Annotated[str, Form(description="表情包描述")]
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
# 创建路由器
router = APIRouter(prefix="/emoji", tags=["Emoji"])
@@ -592,10 +597,10 @@ class EmojiUploadResponse(BaseModel):
@router.post("/upload", response_model=EmojiUploadResponse)
async def upload_emoji(
file: UploadFile = File(..., description="表情包图片文件"),
description: str = Form("", description="表情包描述"),
emotion: str = Form("", description="情感标签,多个用逗号分隔"),
is_registered: bool = Form(True, description="是否直接注册"),
file: EmojiFile,
description: DescriptionForm = "",
emotion: EmotionForm = "",
is_registered: IsRegisteredForm = True,
authorization: Optional[str] = Header(None),
):
"""
@@ -713,9 +718,9 @@ async def upload_emoji(
@router.post("/batch/upload")
async def batch_upload_emoji(
files: List[UploadFile] = File(..., description="多个表情包图片文件"),
emotion: str = Form("", description="情感标签,多个用逗号分隔"),
is_registered: bool = Form(True, description="是否直接注册"),
files: EmojiFiles,
emotion: EmotionForm = "",
is_registered: IsRegisteredForm = True,
authorization: Optional[str] = Header(None),
):
"""
@@ -749,11 +754,13 @@ async def batch_upload_emoji(
# 验证文件类型
if file.content_type not in allowed_types:
results["failed"] += 1
results["details"].append({
"filename": file.filename,
"success": False,
"error": f"不支持的文件类型: {file.content_type}",
})
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": f"不支持的文件类型: {file.content_type}",
}
)
continue
# 读取文件内容
@@ -761,11 +768,13 @@ async def batch_upload_emoji(
if not file_content:
results["failed"] += 1
results["details"].append({
"filename": file.filename,
"success": False,
"error": "文件内容为空",
})
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": "文件内容为空",
}
)
continue
# 验证图片
@@ -774,11 +783,13 @@ async def batch_upload_emoji(
img_format = img.format.lower() if img.format else "png"
except Exception as e:
results["failed"] += 1
results["details"].append({
"filename": file.filename,
"success": False,
"error": f"无效的图片: {str(e)}",
})
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": f"无效的图片: {str(e)}",
}
)
continue
# 计算哈希
@@ -787,11 +798,13 @@ async def batch_upload_emoji(
# 检查重复
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
results["failed"] += 1
results["details"].append({
"filename": file.filename,
"success": False,
"error": "已存在相同的表情包",
})
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": "已存在相同的表情包",
}
)
continue
# 生成文件名并保存
@@ -829,19 +842,23 @@ async def batch_upload_emoji(
)
results["uploaded"] += 1
results["details"].append({
"filename": file.filename,
"success": True,
"id": emoji.id,
})
results["details"].append(
{
"filename": file.filename,
"success": True,
"id": emoji.id,
}
)
except Exception as e:
results["failed"] += 1
results["details"].append({
"filename": file.filename,
"success": False,
"error": str(e),
})
results["details"].append(
{
"filename": file.filename,
"success": False,
"error": str(e),
}
)
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']}"
return results
@@ -850,4 +867,4 @@ async def batch_upload_emoji(
raise
except Exception as e:
logger.exception(f"批量上传表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e

View File

@@ -602,9 +602,9 @@ class GitMirrorService:
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone():
def run_git_clone(clone_cmd=cmd):
return subprocess.run(
cmd,
clone_cmd,
capture_output=True,
text=True,
timeout=300, # 5分钟超时

View File

@@ -1,4 +1,5 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query
from pydantic import BaseModel
@@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
class KnowledgeNode(BaseModel):
"""知识节点"""
id: str
type: str # 'entity' or 'paragraph'
content: str
@@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
class KnowledgeEdge(BaseModel):
"""知识边"""
source: str
target: str
weight: float
@@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
class KnowledgeGraph(BaseModel):
"""知识图谱"""
nodes: List[KnowledgeNode]
edges: List[KnowledgeEdge]
class KnowledgeStats(BaseModel):
"""知识库统计信息"""
total_nodes: int
total_edges: int
entity_nodes: int
@@ -45,7 +50,7 @@ def _load_kg_manager():
"""延迟加载 KGManager"""
try:
from src.chat.knowledge.kg_manager import KGManager
kg_manager = KGManager()
kg_manager.load_from_file()
return kg_manager
@@ -58,31 +63,26 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
"""将 DiGraph 转换为 JSON 格式"""
if kg_manager is None or kg_manager.graph is None:
return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph
nodes = []
edges = []
# 转换节点
node_list = graph.get_node_list()
for node_id in node_list:
try:
node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
content = node_data['content'] if 'content' in node_data else node_id
create_time = node_data['create_time'] if 'create_time' in node_data else None
nodes.append(KnowledgeNode(
id=node_id,
type=node_type,
content=content,
create_time=create_time
))
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}")
continue
# 转换边
edge_list = graph.get_edge_list()
for edge_tuple in edge_list:
@@ -91,37 +91,35 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
source, target = edge_tuple[0], edge_tuple[1]
# 通过 graph[source, target] 获取边的属性数据
edge_data = graph[source, target]
# edge_data 支持 [] 操作符但不支持 .get()
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
edges.append(KnowledgeEdge(
source=source,
target=target,
weight=weight,
create_time=create_time,
update_time=update_time
))
weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(
KnowledgeEdge(
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
)
)
except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}")
continue
return KnowledgeGraph(nodes=nodes, edges=edges)
@router.get("/graph", response_model=KnowledgeGraph)
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph")
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
):
"""获取知识图谱(限制节点数量)
Args:
limit: 返回的最大节点数,默认 100,最大 10000
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
Returns:
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
"""
@@ -130,46 +128,43 @@ async def get_knowledge_graph(
if kg_manager is None:
logger.warning("KGManager 未初始化,返回空图谱")
return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph
all_node_list = graph.get_node_list()
# 按类型过滤节点
if node_type == "entity":
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent']
all_node_list = [
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
]
elif node_type == "paragraph":
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg']
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
# 限制节点数量
total_nodes = len(all_node_list)
if len(all_node_list) > limit:
node_list = all_node_list[:limit]
else:
node_list = all_node_list
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
# 转换节点
nodes = []
node_ids = set()
for node_id in node_list:
try:
node_data = graph[node_id]
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
content = node_data['content'] if 'content' in node_data else node_id
create_time = node_data['create_time'] if 'create_time' in node_data else None
nodes.append(KnowledgeNode(
id=node_id,
type=node_type_val,
content=content,
create_time=create_time
))
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
node_ids.add(node_id)
except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}")
continue
# 只获取涉及当前节点集的边(保证图的完整性)
edges = []
edge_list = graph.get_edge_list()
@@ -179,27 +174,25 @@ async def get_knowledge_graph(
# 只包含两端都在当前节点集中的边
if source not in node_ids or target not in node_ids:
continue
edge_data = graph[source, target]
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
edges.append(KnowledgeEdge(
source=source,
target=target,
weight=weight,
create_time=create_time,
update_time=update_time
))
weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(
KnowledgeEdge(
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
)
)
except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}")
continue
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
return graph_data
except Exception as e:
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
return KnowledgeGraph(nodes=[], edges=[])
@@ -208,71 +201,59 @@ async def get_knowledge_graph(
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats():
"""获取知识库统计信息
Returns:
KnowledgeStats: 统计信息
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return KnowledgeStats(
total_nodes=0,
total_edges=0,
entity_nodes=0,
paragraph_nodes=0,
avg_connections=0.0
)
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
graph = kg_manager.graph
node_list = graph.get_node_list()
edge_list = graph.get_edge_list()
total_nodes = len(node_list)
total_edges = len(edge_list)
# 统计节点类型
entity_nodes = 0
paragraph_nodes = 0
for node_id in node_list:
try:
node_data = graph[node_id]
node_type = node_data['type'] if 'type' in node_data else 'ent'
if node_type == 'ent':
node_type = node_data["type"] if "type" in node_data else "ent"
if node_type == "ent":
entity_nodes += 1
elif node_type == 'pg':
elif node_type == "pg":
paragraph_nodes += 1
except Exception:
continue
# 计算平均连接数
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
return KnowledgeStats(
total_nodes=total_nodes,
total_edges=total_edges,
entity_nodes=entity_nodes,
paragraph_nodes=paragraph_nodes,
avg_connections=round(avg_connections, 2)
avg_connections=round(avg_connections, 2),
)
except Exception as e:
logger.error(f"获取统计信息失败: {e}", exc_info=True)
return KnowledgeStats(
total_nodes=0,
total_edges=0,
entity_nodes=0,
paragraph_nodes=0,
avg_connections=0.0
)
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1)):
"""搜索知识节点
Args:
query: 搜索关键词
Returns:
List[KnowledgeNode]: 匹配的节点列表
"""
@@ -280,33 +261,28 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return []
graph = kg_manager.graph
node_list = graph.get_node_list()
results = []
query_lower = query.lower()
# 在节点内容中搜索
for node_id in node_list:
try:
node_data = graph[node_id]
content = node_data['content'] if 'content' in node_data else node_id
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
content = node_data["content"] if "content" in node_data else node_id
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
if query_lower in content.lower() or query_lower in node_id.lower():
create_time = node_data['create_time'] if 'create_time' in node_data else None
results.append(KnowledgeNode(
id=node_id,
type=node_type,
content=content,
create_time=create_time
))
create_time = node_data["create_time"] if "create_time" in node_data else None
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
except Exception:
continue
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
return results[:50] # 限制返回数量
except Exception as e:
logger.error(f"搜索节点失败: {e}", exc_info=True)
return []

View File

@@ -43,25 +43,27 @@ def _normalize_url(url: str) -> str:
def _parse_openai_response(data: dict) -> list[dict]:
"""
解析 OpenAI 格式的模型列表响应
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
"""
models = []
if "data" in data and isinstance(data["data"], list):
for model in data["data"]:
if isinstance(model, dict) and "id" in model:
models.append({
"id": model["id"],
"name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""),
})
models.append(
{
"id": model["id"],
"name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""),
}
)
return models
def _parse_gemini_response(data: dict) -> list[dict]:
"""
解析 Gemini 格式的模型列表响应
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
"""
models = []
@@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
model_id = model["name"]
if model_id.startswith("models/"):
model_id = model_id[7:] # 去掉 "models/" 前缀
models.append({
"id": model_id,
"name": model.get("displayName") or model_id,
"owned_by": "google",
})
models.append(
{
"id": model_id,
"name": model.get("displayName") or model_id,
"owned_by": "google",
}
)
return models
@@ -89,55 +93,54 @@ async def _fetch_models_from_provider(
) -> list[dict]:
"""
从提供商 API 获取模型列表
Args:
base_url: 提供商的基础 URL
api_key: API 密钥
endpoint: 获取模型列表的端点
parser: 响应解析器类型 ('openai' | 'gemini')
client_type: 客户端类型 ('openai' | 'gemini')
Returns:
模型列表
"""
url = f"{_normalize_url(base_url)}{endpoint}"
# 根据客户端类型设置请求头
headers = {}
params = {}
if client_type == "gemini":
# Gemini 使用 URL 参数传递 API Key
params["key"] = api_key
else:
# OpenAI 兼容格式使用 Authorization 头
headers["Authorization"] = f"Bearer {api_key}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="请求超时,请稍后重试")
except httpx.TimeoutException as e:
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
except httpx.HTTPStatusError as e:
# 注意:使用 502 Bad Gateway 而不是原始的 401/403
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
if e.response.status_code == 401:
raise HTTPException(status_code=502, detail="API Key 无效或已过期")
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
elif e.response.status_code == 403:
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限")
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
elif e.response.status_code == 404:
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表")
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
else:
raise HTTPException(
status_code=502,
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
)
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
) from e
except Exception as e:
logger.error(f"获取模型列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
# 根据解析器类型解析响应
if parser == "openai":
return _parse_openai_response(data)
@@ -150,26 +153,26 @@ async def _fetch_models_from_provider(
def _get_provider_config(provider_name: str) -> Optional[dict]:
"""
从 model_config.toml 获取指定提供商的配置
Args:
provider_name: 提供商名称
Returns:
提供商配置,如果未找到则返回 None
"""
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
return None
try:
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
providers = config_data.get("api_providers", [])
for provider in providers:
if provider.get("name") == provider_name:
return dict(provider)
return None
except Exception as e:
logger.error(f"读取提供商配置失败: {e}")
@@ -184,23 +187,23 @@ async def get_provider_models(
):
"""
获取指定提供商的可用模型列表
通过提供商名称查找配置,然后请求对应的模型列表端点
"""
# 获取提供商配置
provider_config = _get_provider_config(provider_name)
if not provider_config:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider_config.get("base_url")
api_key = provider_config.get("api_key")
client_type = provider_config.get("client_type", "openai")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
if not api_key:
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
# 获取模型列表
models = await _fetch_models_from_provider(
base_url=base_url,
@@ -209,7 +212,7 @@ async def get_provider_models(
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
@@ -236,7 +239,7 @@ async def get_models_by_url(
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
@@ -251,11 +254,11 @@ async def test_provider_connection(
):
"""
测试提供商连接状态
分两步测试:
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
2. API Key 验证(可选):如果提供了 api_key尝试获取模型列表验证 Key 是否有效
返回:
- network_ok: 网络是否连通
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
@@ -263,11 +266,11 @@ async def test_provider_connection(
- error: 错误信息(如果有)
"""
import time
base_url = _normalize_url(base_url)
if not base_url:
raise HTTPException(status_code=400, detail="base_url 不能为空")
result = {
"network_ok": False,
"api_key_valid": None,
@@ -275,7 +278,7 @@ async def test_provider_connection(
"error": None,
"http_status": None,
}
# 第一步:测试网络连通性
try:
start_time = time.time()
@@ -283,11 +286,11 @@ async def test_provider_connection(
# 尝试 GET 请求 base_url不需要 API Key
response = await client.get(base_url)
latency = (time.time() - start_time) * 1000
result["network_ok"] = True
result["latency_ms"] = round(latency, 2)
result["http_status"] = response.status_code
except httpx.ConnectError as e:
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
return result
@@ -300,7 +303,7 @@ async def test_provider_connection(
except Exception as e:
result["error"] = f"未知错误:{str(e)}"
return result
# 第二步:如果提供了 API Key验证其有效性
if api_key:
try:
@@ -313,7 +316,7 @@ async def test_provider_connection(
# 尝试获取模型列表
models_url = f"{base_url}/models"
response = await client.get(models_url, headers=headers)
if response.status_code == 200:
result["api_key_valid"] = True
elif response.status_code in (401, 403):
@@ -322,12 +325,12 @@ async def test_provider_connection(
else:
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
result["api_key_valid"] = None
except Exception as e:
# API Key 验证失败不影响网络连通性结果
logger.warning(f"API Key 验证失败: {e}")
result["api_key_valid"] = None
return result
@@ -342,10 +345,10 @@ async def test_provider_connection_by_name(
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(model_config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(model_config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
# 查找提供商
providers = config.get("api_providers", [])
provider = None
@@ -353,15 +356,15 @@ async def test_provider_connection_by_name(
if p.get("name") == provider_name:
provider = p
break
if not provider:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider.get("base_url", "")
api_key = provider.get("api_key", "")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# 调用测试接口
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)

View File

@@ -31,8 +31,9 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
"""
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
import re
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
base_version = re.split(r'[-.](?:snapshot|dev|alpha|beta|rc)', version_str, flags=re.IGNORECASE)[0]
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
parts = base_version.split(".")
if len(parts) < 3:
@@ -613,7 +614,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
for field in required_fields:
if field not in manifest:
raise ValueError(f"缺少必需字段: {field}")
# 将插件 ID 写入 manifest用于后续准确识别
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
manifest["id"] = request.plugin_id
@@ -705,7 +706,7 @@ async def uninstall_plugin(
plugin_path = plugins_dir / folder_name
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
if old_format_path.exists():
@@ -839,7 +840,7 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
plugin_path = plugins_dir / folder_name
# 旧格式:点
old_format_path = plugins_dir / request.plugin_id
# 优先使用新格式,如果不存在则尝试旧格式
if not plugin_path.exists():
if old_format_path.exists():
@@ -1092,21 +1093,21 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
# 尝试从 author.name 和 repository_url 构建标准 ID
author_name = None
repo_name = None
# 获取作者名
if "author" in manifest:
if isinstance(manifest["author"], dict) and "name" in manifest["author"]:
author_name = manifest["author"]["name"]
elif isinstance(manifest["author"], str):
author_name = manifest["author"]
# 从 repository_url 获取仓库名
if "repository_url" in manifest:
repo_url = manifest["repository_url"].rstrip("/")
if repo_url.endswith(".git"):
repo_url = repo_url[:-4]
repo_name = repo_url.split("/")[-1]
# 构建 ID
if author_name and repo_name:
# 标准格式: Author.RepoName
@@ -1122,7 +1123,7 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
else:
# 直接使用文件夹名
plugin_id = folder_name
# 将推断的 ID 写入 manifest方便下次识别
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
manifest["id"] = plugin_id
@@ -1167,12 +1168,10 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema(
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
获取插件配置 Schema
返回插件的完整配置 schema包含所有 section、字段定义和布局信息。
用于前端动态生成配置表单。
"""
@@ -1187,10 +1186,10 @@ async def get_plugin_config_schema(
try:
# 尝试从已加载的插件中获取
from src.plugin_system.core.plugin_manager import plugin_manager
# 查找插件实例
plugin_instance = None
# 遍历所有已加载的插件
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
@@ -1204,17 +1203,17 @@ async def get_plugin_config_schema(
if manifest_id == plugin_id:
plugin_instance = instance
break
if plugin_instance and hasattr(plugin_instance, 'get_webui_config_schema'):
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
# 从插件实例获取 schema
schema = plugin_instance.get_webui_config_schema()
return {"success": True, "schema": schema}
# 如果插件未加载,尝试从文件系统读取
# 查找插件目录
plugins_dir = Path("plugins")
plugin_path = None
for p in plugins_dir.iterdir():
if p.is_dir():
manifest_path = p / "_manifest.json"
@@ -1227,18 +1226,19 @@ async def get_plugin_config_schema(
break
except Exception:
continue
if not plugin_path:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
# 读取配置文件获取当前配置
config_path = plugin_path / "config.toml"
current_config = {}
if config_path.exists():
import tomlkit
with open(config_path, "r", encoding="utf-8") as f:
current_config = tomlkit.load(f)
# 构建基础 schema无法获取完整的 ConfigField 信息)
schema = {
"plugin_id": plugin_id,
@@ -1252,7 +1252,7 @@ async def get_plugin_config_schema(
"layout": {"type": "auto", "tabs": []},
"_note": "插件未加载,仅返回当前配置结构",
}
# 从当前配置推断 schema
for section_name, section_data in current_config.items():
if isinstance(section_data, dict):
@@ -1277,7 +1277,7 @@ async def get_plugin_config_schema(
ui_type = "list"
elif isinstance(field_value, dict):
ui_type = "json"
schema["sections"][section_name]["fields"][field_name] = {
"name": field_name,
"type": field_type,
@@ -1290,7 +1290,7 @@ async def get_plugin_config_schema(
"disabled": False,
"order": 0,
}
return {"success": True, "schema": schema}
except HTTPException:
@@ -1301,12 +1301,10 @@ async def get_plugin_config_schema(
@router.get("/config/{plugin_id}")
async def get_plugin_config(
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
获取插件当前配置值
返回插件的当前配置值。
"""
# Token 验证
@@ -1321,7 +1319,7 @@ async def get_plugin_config(
# 查找插件目录
plugins_dir = Path("plugins")
plugin_path = None
for p in plugins_dir.iterdir():
if p.is_dir():
manifest_path = p / "_manifest.json"
@@ -1334,19 +1332,20 @@ async def get_plugin_config(
break
except Exception:
continue
if not plugin_path:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
# 读取配置文件
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {"success": True, "config": {}, "message": "配置文件不存在"}
import tomlkit
with open(config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
return {"success": True, "config": dict(config)}
except HTTPException:
@@ -1358,13 +1357,11 @@ async def get_plugin_config(
@router.put("/config/{plugin_id}")
async def update_plugin_config(
plugin_id: str,
request: UpdatePluginConfigRequest,
authorization: Optional[str] = Header(None)
plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
"""
更新插件配置
保存新的配置值到插件的配置文件。
"""
# Token 验证
@@ -1379,7 +1376,7 @@ async def update_plugin_config(
# 查找插件目录
plugins_dir = Path("plugins")
plugin_path = None
for p in plugins_dir.iterdir():
if p.is_dir():
manifest_path = p / "_manifest.json"
@@ -1392,23 +1389,25 @@ async def update_plugin_config(
break
except Exception:
continue
if not plugin_path:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
# 备份旧配置
import shutil
import datetime
if config_path.exists():
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
backup_path = plugin_path / backup_name
shutil.copy(config_path, backup_path)
logger.info(f"已备份配置文件: {backup_path}")
# 写入新配置(使用 tomlkit 保留注释)
import tomlkit
# 先读取原配置以保留注释和格式
existing_doc = tomlkit.document()
if config_path.exists():
@@ -1419,14 +1418,10 @@ async def update_plugin_config(
existing_doc[key] = value
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(existing_doc, f)
logger.info(f"已更新插件配置: {plugin_id}")
return {
"success": True,
"message": "配置已保存",
"note": "配置更改将在插件重新加载后生效"
}
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
except HTTPException:
raise
@@ -1436,12 +1431,10 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset")
async def reset_plugin_config(
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
重置插件配置为默认值
删除当前配置文件,下次加载插件时将使用默认配置。
"""
# Token 验证
@@ -1456,7 +1449,7 @@ async def reset_plugin_config(
# 查找插件目录
plugins_dir = Path("plugins")
plugin_path = None
for p in plugins_dir.iterdir():
if p.is_dir():
manifest_path = p / "_manifest.json"
@@ -1469,29 +1462,26 @@ async def reset_plugin_config(
break
except Exception:
continue
if not plugin_path:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
if not config_path.exists():
return {"success": True, "message": "配置文件不存在,无需重置"}
# 备份并删除
import shutil
import datetime
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
backup_path = plugin_path / backup_name
shutil.move(config_path, backup_path)
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
return {
"success": True,
"message": "配置已重置,下次加载插件时将使用默认配置",
"backup": str(backup_path)
}
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
except HTTPException:
raise
@@ -1501,12 +1491,10 @@ async def reset_plugin_config(
@router.post("/config/{plugin_id}/toggle")
async def toggle_plugin(
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
"""
切换插件启用状态
切换插件配置中的 enabled 字段。
"""
# Token 验证
@@ -1521,7 +1509,7 @@ async def toggle_plugin(
# 查找插件目录
plugins_dir = Path("plugins")
plugin_path = None
for p in plugins_dir.iterdir():
if p.is_dir():
manifest_path = p / "_manifest.json"
@@ -1534,40 +1522,40 @@ async def toggle_plugin(
break
except Exception:
continue
if not plugin_path:
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
config_path = plugin_path / "config.toml"
import tomlkit
# 读取当前配置(保留注释和格式)
config = tomlkit.document()
if config_path.exists():
with open(config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
# 切换 enabled 状态
if "plugin" not in config:
config["plugin"] = tomlkit.table()
current_enabled = config["plugin"].get("enabled", True)
new_enabled = not current_enabled
config["plugin"]["enabled"] = new_enabled
# 写入配置(保留注释)
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config, f)
status = "启用" if new_enabled else "禁用"
logger.info(f"{status}插件: {plugin_id}")
return {
"success": True,
"enabled": new_enabled,
"message": f"插件已{status}",
"note": "状态更改将在下次加载插件时生效"
"note": "状态更改将在下次加载插件时生效",
}
except HTTPException:

View File

@@ -43,7 +43,7 @@ async def restart_maibot():
注意:此操作会使麦麦暂时离线。
"""
import asyncio
try:
# 记录重启操作
print(f"[{datetime.now()}] WebUI 触发重启操作")
@@ -54,7 +54,7 @@ async def restart_maibot():
python = sys.executable
args = [python] + sys.argv
os.execv(python, args)
# 创建后台任务执行重启
asyncio.create_task(delayed_restart())

View File

@@ -20,10 +20,10 @@ class WebUIServer:
self.port = port
self.app = FastAPI(title="MaiBot WebUI")
self._server = None
# 显示 Access Token
self._show_access_token()
# 重要:先注册 API 路由,再设置静态文件
self._register_api_routes()
self._setup_static_files()
@@ -32,7 +32,7 @@ class WebUIServer:
"""显示 WebUI Access Token"""
try:
from src.webui.token_manager import get_token_manager
token_manager = get_token_manager()
current_token = token_manager.get_token()
logger.info(f"🔑 WebUI Access Token: {current_token}")
@@ -69,7 +69,7 @@ class WebUIServer:
# 如果是根路径,直接返回 index.html
if not full_path or full_path == "/":
return FileResponse(static_path / "index.html", media_type="text/html")
# 检查是否是静态文件
file_path = static_path / full_path
if file_path.is_file() and file_path.exists():
@@ -88,13 +88,15 @@ class WebUIServer:
# 导入所有 WebUI 路由
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
logger.info("开始导入 knowledge_routes...")
from src.webui.knowledge_routes import router as knowledge_router
logger.info("knowledge_routes 导入成功")
# 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router
logger.info("chat_routes 导入成功")
# 注册路由