Merge branch 'dev' of https://github.com/A-Dawn/MaiBot into dev
This commit is contained in:
127
src/webui/auth.py
Normal file
127
src/webui/auth.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
WebUI 认证模块
|
||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
|
||||
# Cookie 配置
|
||||
COOKIE_NAME = "maibot_session"
|
||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
"""
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取
|
||||
samesite="lax", # 允许同站导航时发送 Cookie(兼容开发环境代理)
|
||||
secure=False, # 本地开发不强制 HTTPS,生产环境建议设为 True
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
logger.debug(f"已设置认证 Cookie: {token[:8]}...")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
|
||||
|
||||
def verify_auth_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
731
src/webui/chat_routes.py
Normal file
731
src/webui/chat_routes.py
Normal file
@@ -0,0 +1,731 @@
|
||||
"""本地聊天室路由 - WebUI 与麦麦直接对话
|
||||
|
||||
支持两种模式:
|
||||
1. WebUI 模式:使用 WebUI 平台独立身份聊天
|
||||
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
|
||||
# 虚拟身份模式的群 ID 前缀
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
# 固定的 WebUI 用户 ID 前缀
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置"""
|
||||
|
||||
enabled: bool = False # 是否启用虚拟身份模式
|
||||
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
|
||||
person_id: Optional[str] = None # PersonInfo 的 person_id
|
||||
user_id: Optional[str] = None # 原始平台用户 ID
|
||||
user_nickname: Optional[str] = None # 用户昵称
|
||||
group_id: Optional[str] = None # 虚拟群 ID(自动生成或用户指定)
|
||||
group_name: Optional[str] = None # 虚拟群名(用户自定义)
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息"""
|
||||
|
||||
id: str
|
||||
type: str # 'user' | 'bot' | 'system'
|
||||
content: str
|
||||
timestamp: float
|
||||
sender_name: str
|
||||
sender_id: Optional[str] = None
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
|
||||
|
||||
def __init__(self, max_messages: int = 200):
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
# 虚拟群:user_id 等于机器人 QQ 账号的是机器人消息
|
||||
bot_qq = str(global_config.bot.qq_account)
|
||||
is_bot = user_id == bot_qq
|
||||
else:
|
||||
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
|
||||
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"timestamp": msg.time,
|
||||
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""从数据库获取最近的历史记录
|
||||
|
||||
Args:
|
||||
limit: 获取的消息数量
|
||||
group_id: 群 ID,默认为 WEBUI_CHAT_GROUP_ID
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
# 查询指定群的消息,按时间排序
|
||||
messages = (
|
||||
Messages.select()
|
||||
.where(Messages.chat_info_group_id == target_group_id)
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# 转换为列表并反转(使最旧的消息在前)
|
||||
# 传递 group_id 以便正确判断虚拟群中的机器人消息
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
result.reverse()
|
||||
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 全局聊天历史管理器
|
||||
chat_history = ChatHistoryManager()
|
||||
|
||||
|
||||
# 存储 WebSocket 连接
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str):
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict):
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""广播消息给所有连接"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def create_message_data(
|
||||
content: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建符合麦麦消息格式的消息数据
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
user_id: 用户 ID
|
||||
user_name: 用户昵称
|
||||
message_id: 消息 ID(可选,自动生成)
|
||||
is_at_bot: 是否 @ 机器人
|
||||
virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份)
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# 确定使用的平台、群信息和用户信息
|
||||
if virtual_config and virtual_config.enabled:
|
||||
# 虚拟身份模式:使用真实平台身份
|
||||
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
|
||||
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
group_name = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
actual_user_id = virtual_config.user_id or user_id
|
||||
actual_user_name = virtual_config.user_nickname or user_name
|
||||
else:
|
||||
# 标准 WebUI 模式
|
||||
platform = WEBUI_CHAT_PLATFORM
|
||||
group_id = WEBUI_CHAT_GROUP_ID
|
||||
group_name = "WebUI本地聊天室"
|
||||
actual_user_id = user_id
|
||||
actual_user_name = user_name
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_name,
|
||||
"user_cardname": actual_user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
|
||||
如果指定了 group_id,则获取该虚拟群的历史记录
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
history = chat_history.get_history(limit, target_group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"messages": history,
|
||||
"total": len(history),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms():
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
"""
|
||||
try:
|
||||
from peewee import fn
|
||||
|
||||
# 查询所有不同的平台
|
||||
platforms = (
|
||||
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
|
||||
.group_by(PersonInfo.platform)
|
||||
.order_by(fn.COUNT(PersonInfo.id).desc())
|
||||
)
|
||||
|
||||
result = []
|
||||
for p in platforms:
|
||||
if p.platform: # 排除空平台
|
||||
result.append({"platform": p.platform, "count": p.count})
|
||||
|
||||
return {"success": True, "platforms": result}
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "platforms": []}
|
||||
|
||||
|
||||
@router.get("/persons")
|
||||
async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
Args:
|
||||
platform: 平台名称(如 qq, discord 等)
|
||||
search: 搜索关键词(匹配昵称、用户名、user_id)
|
||||
limit: 返回数量限制
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = PersonInfo.select().where(PersonInfo.platform == platform)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
# 按最后交互时间排序,优先显示活跃用户
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
query = query.limit(limit)
|
||||
|
||||
result = []
|
||||
for person in query:
|
||||
result.append(
|
||||
{
|
||||
"person_id": person.person_id,
|
||||
"user_id": person.user_id,
|
||||
"person_name": person.person_name,
|
||||
"nickname": person.nickname,
|
||||
"is_known": person.is_known,
|
||||
"platform": person.platform,
|
||||
"display_name": person.person_name or person.nickname or person.user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return {"success": True, "persons": result, "total": len(result)}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "persons": []}
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 可选,指定要清空的群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已清空 {deleted} 条聊天记录",
|
||||
}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
Args:
|
||||
user_id: 用户唯一标识(由前端生成并持久化)
|
||||
user_name: 用户显示昵称(可修改)
|
||||
platform: 虚拟身份模式的平台(可选)
|
||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
"""
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# 如果没有提供 user_id,生成一个新的
|
||||
if not user_id:
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
# 确保 user_id 有正确的前缀
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||
|
||||
# 当前会话的虚拟身份配置(可通过消息动态更新)
|
||||
current_virtual_config: Optional[VirtualIdentityConfig] = None
|
||||
|
||||
# 如果 URL 参数中提供了虚拟身份信息,自动配置
|
||||
if platform and person_id:
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if person:
|
||||
# 使用前端传递的 group_id,如果没有则生成一个稳定的
|
||||
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
await chat_manager.connect(websocket, session_id, user_id)
|
||||
|
||||
try:
|
||||
# 构建会话信息
|
||||
session_info_data = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
}
|
||||
|
||||
# 如果有虚拟身份配置,添加到会话信息中
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
session_info_data["virtual_mode"] = True
|
||||
session_info_data["group_id"] = current_virtual_config.group_id
|
||||
session_info_data["virtual_identity"] = {
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
}
|
||||
|
||||
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||
await chat_manager.send_message(session_id, session_info_data)
|
||||
|
||||
# 发送历史记录(根据模式选择不同的群)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
else:
|
||||
history = chat_history.get_history(50)
|
||||
if history:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送欢迎消息(不保存到历史)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
|
||||
else:
|
||||
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": welcome_msg,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "message":
|
||||
content = data.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 用户可以更新昵称
|
||||
current_user_name = data.get("user_name", user_name)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
timestamp = time.time()
|
||||
|
||||
# 确定发送者信息(根据是否使用虚拟身份)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
sender_name = current_virtual_config.user_nickname or current_user_name
|
||||
sender_user_id = current_virtual_config.user_id or user_id
|
||||
else:
|
||||
sender_name = current_user_name
|
||||
sender_user_id = user_id
|
||||
|
||||
# 广播用户消息给所有连接(包括发送者)
|
||||
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": sender_name,
|
||||
"user_id": sender_user_id,
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
|
||||
}
|
||||
)
|
||||
|
||||
# 创建麦麦消息格式
|
||||
message_data = create_message_data(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
user_name=current_user_name,
|
||||
message_id=message_id,
|
||||
is_at_bot=True,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# 显示正在输入状态
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": True,
|
||||
}
|
||||
)
|
||||
|
||||
# 调用麦麦的消息处理
|
||||
await chat_bot.message_process(message_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"处理消息时出错: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif data.get("type") == "ping":
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "update_nickname":
|
||||
# 允许用户更新昵称
|
||||
if new_name := data.get("user_name", "").strip():
|
||||
current_user_name = new_name
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "nickname_updated",
|
||||
"user_name": current_user_name,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "set_virtual_identity":
|
||||
# 设置或更新虚拟身份配置
|
||||
virtual_data = data.get("config", {})
|
||||
if virtual_data.get("enabled"):
|
||||
# 验证必要字段
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 获取用户信息
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
|
||||
if not person:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"找不到用户: {virtual_data.get('person_id')}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成虚拟群 ID
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
if custom_group_id:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
else:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
|
||||
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=group_id,
|
||||
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
|
||||
)
|
||||
|
||||
# 发送虚拟身份已激活的消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {
|
||||
"enabled": True,
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 加载虚拟群的历史记录
|
||||
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": virtual_history,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送系统消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"设置虚拟身份失败: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 禁用虚拟身份模式
|
||||
current_virtual_config = None
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {"enabled": False},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 重新加载默认聊天室历史
|
||||
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": default_history,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
},
|
||||
)
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": "已切换回 WebUI 独立用户模式",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info():
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||
|
||||
Returns:
|
||||
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
|
||||
"""
|
||||
return (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
@@ -5,9 +5,10 @@
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from typing import Any
|
||||
from typing import Any, Annotated
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
@@ -28,9 +29,7 @@ from src.config.official_configs import (
|
||||
ToolConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
MoodConfig,
|
||||
VoiceConfig,
|
||||
JargonConfig,
|
||||
)
|
||||
from src.config.api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
@@ -41,43 +40,15 @@ from src.webui.config_schema import ConfigSchemaGenerator
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||
SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[dict[str, str], Body()]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
# ===== 辅助函数 =====
|
||||
|
||||
|
||||
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
|
||||
"""
|
||||
递归合并字典,保留 target 中的注释和格式
|
||||
将 source 的值更新到 target 中(仅更新已存在的键)
|
||||
|
||||
Args:
|
||||
target: 目标字典(tomlkit 对象,包含注释)
|
||||
source: 源字典(普通 dict 或 list)
|
||||
"""
|
||||
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
|
||||
if isinstance(source, list):
|
||||
return # 调用者需要直接赋值
|
||||
|
||||
# 如果都是字典,递归合并
|
||||
if isinstance(source, dict) and isinstance(target, dict):
|
||||
for key, value in source.items():
|
||||
if key == "version":
|
||||
continue # 跳过版本号
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
# 递归处理嵌套字典
|
||||
if isinstance(value, dict) and isinstance(target_value, dict):
|
||||
_update_dict_preserve_comments(target_value, value)
|
||||
else:
|
||||
# 使用 tomlkit.item 保持类型
|
||||
try:
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
target[key] = value
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@@ -90,7 +61,7 @@ async def get_bot_config_schema():
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
@@ -101,7 +72,7 @@ async def get_model_config_schema():
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 子配置架构获取接口 =====
|
||||
@@ -131,7 +102,6 @@ async def get_config_section_schema(section_name: str):
|
||||
- tool: ToolConfig
|
||||
- memory: MemoryConfig
|
||||
- debug: DebugConfig
|
||||
- mood: MoodConfig
|
||||
- voice: VoiceConfig
|
||||
- jargon: JargonConfig
|
||||
- model_task_config: ModelTaskConfig
|
||||
@@ -157,9 +127,7 @@ async def get_config_section_schema(section_name: str):
|
||||
"tool": ToolConfig,
|
||||
"memory": MemoryConfig,
|
||||
"debug": DebugConfig,
|
||||
"mood": MoodConfig,
|
||||
"voice": VoiceConfig,
|
||||
"jargon": JargonConfig,
|
||||
"model_task_config": ModelTaskConfig,
|
||||
"api_provider": APIProvider,
|
||||
"model_info": ModelInfo,
|
||||
@@ -174,7 +142,7 @@ async def get_config_section_schema(section_name: str):
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置节架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置读取接口 =====
|
||||
@@ -196,7 +164,7 @@ async def get_bot_config():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
@@ -215,26 +183,25 @@ async def get_model_config():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("麦麦主程序配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
@@ -242,23 +209,22 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("模型配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
@@ -266,14 +232,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置节更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -295,7 +261,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
@@ -304,11 +270,10 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(tomlkit.dump 会保留注释)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
@@ -316,7 +281,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 原始 TOML 文件操作接口 =====
|
||||
@@ -338,24 +303,24 @@ async def get_bot_config_raw():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
config_data = tomlkit.loads(raw_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 验证配置数据结构
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -368,11 +333,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -394,7 +359,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_dict_preserve_comments(config_data[section_name], section_data)
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
@@ -403,11 +368,22 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
orphaned_models = [
|
||||
m.get("name") for m in models
|
||||
if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
]
|
||||
if orphaned_models:
|
||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(tomlkit.dump 会保留注释)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config_data, f)
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
@@ -415,7 +391,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 适配器配置管理接口 =====
|
||||
@@ -425,11 +401,11 @@ def _normalize_adapter_path(path: str) -> str:
|
||||
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
|
||||
# 如果已经是绝对路径,直接返回
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
|
||||
|
||||
# 相对路径,转换为相对于项目根目录的绝对路径
|
||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||
|
||||
@@ -438,17 +414,17 @@ def _to_relative_path(path: str) -> str:
|
||||
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
||||
if not path or not os.path.isabs(path):
|
||||
return path
|
||||
|
||||
|
||||
try:
|
||||
# 尝试获取相对路径
|
||||
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
||||
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
||||
if not rel_path.startswith('..'):
|
||||
if not rel_path.startswith(".."):
|
||||
return rel_path
|
||||
except (ValueError, TypeError):
|
||||
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
||||
pass
|
||||
|
||||
|
||||
# 无法转换为相对路径,返回绝对路径
|
||||
return path
|
||||
|
||||
@@ -463,6 +439,7 @@ async def get_adapter_config_path():
|
||||
return {"success": True, "path": None}
|
||||
|
||||
import json
|
||||
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
|
||||
@@ -472,10 +449,11 @@ async def get_adapter_config_path():
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(adapter_config_path)
|
||||
|
||||
|
||||
# 检查文件是否存在并返回最后修改时间
|
||||
if os.path.exists(abs_path):
|
||||
import datetime
|
||||
|
||||
mtime = os.path.getmtime(abs_path)
|
||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||
# 返回相对路径(如果可能)
|
||||
@@ -487,11 +465,11 @@ async def get_adapter_config_path():
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
async def save_adapter_config_path(data: PathBody):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -511,10 +489,10 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
||||
save_path = _to_relative_path(abs_path)
|
||||
|
||||
|
||||
# 更新路径
|
||||
webui_data["adapter_config_path"] = save_path
|
||||
|
||||
@@ -530,7 +508,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
@@ -542,7 +520,7 @@ async def get_adapter_config(path: str):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(abs_path):
|
||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||
@@ -562,11 +540,11 @@ async def get_adapter_config(path: str):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
async def save_adapter_config(data: PathBody):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -579,17 +557,16 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
import toml
|
||||
toml.loads(content)
|
||||
tomlkit.loads(content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 确保目录存在
|
||||
dir_path = os.path.dirname(abs_path)
|
||||
@@ -607,5 +584,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
|
||||
|
||||
@@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
# 单行文档字符串
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
# 多行文档字符串
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
@@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
description_lines.append(next_line.strip(quote).strip())
|
||||
@@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
|
||||
return FieldType.ARRAY, None, items
|
||||
|
||||
# 处理基本类型
|
||||
if field_type is bool or field_type == bool:
|
||||
if field_type is bool:
|
||||
return FieldType.BOOLEAN, None, None
|
||||
elif field_type is int or field_type == int:
|
||||
elif field_type is int:
|
||||
return FieldType.INTEGER, None, None
|
||||
elif field_type is float or field_type == float:
|
||||
elif field_type is float:
|
||||
return FieldType.NUMBER, None, None
|
||||
elif field_type is str or field_type == str:
|
||||
elif field_type is str:
|
||||
return FieldType.STRING, None, None
|
||||
elif field_type is dict or origin is dict:
|
||||
return FieldType.OBJECT, None, None
|
||||
@@ -296,7 +296,6 @@ class ConfigSchemaGenerator:
|
||||
"plan_style",
|
||||
"visual_style",
|
||||
"private_plan_style",
|
||||
"emotion_style",
|
||||
"reaction",
|
||||
"filtration_prompt",
|
||||
]:
|
||||
|
||||
@@ -1,18 +1,176 @@
|
||||
"""表情包管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form, Cookie
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Annotated
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Emoji
|
||||
from .token_manager import get_token_manager
|
||||
import json
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
import os
|
||||
import hashlib
|
||||
from PIL import Image
|
||||
import io
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
logger = get_logger("webui.emoji")
|
||||
|
||||
# ==================== 缩略图缓存配置 ====================
|
||||
# 缩略图缓存目录
|
||||
THUMBNAIL_CACHE_DIR = Path("data/emoji_thumbnails")
|
||||
# 缩略图尺寸 (宽, 高)
|
||||
THUMBNAIL_SIZE = (200, 200)
|
||||
# 缩略图质量 (WebP 格式, 1-100)
|
||||
THUMBNAIL_QUALITY = 80
|
||||
# 缓存锁,防止并发生成同一缩略图
|
||||
_thumbnail_locks: dict[str, threading.Lock] = {}
|
||||
_locks_lock = threading.Lock()
|
||||
# 缩略图生成专用线程池(避免阻塞事件循环)
|
||||
_thumbnail_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="thumbnail")
|
||||
# 正在生成中的缩略图哈希集合(防止重复提交任务)
|
||||
_generating_thumbnails: set[str] = set()
|
||||
_generating_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_thumbnail_lock(file_hash: str) -> threading.Lock:
|
||||
"""获取指定文件哈希的锁,用于防止并发生成同一缩略图"""
|
||||
with _locks_lock:
|
||||
if file_hash not in _thumbnail_locks:
|
||||
_thumbnail_locks[file_hash] = threading.Lock()
|
||||
return _thumbnail_locks[file_hash]
|
||||
|
||||
|
||||
def _background_generate_thumbnail(source_path: str, file_hash: str) -> None:
|
||||
"""
|
||||
后台生成缩略图(在线程池中执行)
|
||||
|
||||
生成完成后自动从 generating 集合中移除
|
||||
"""
|
||||
try:
|
||||
_generate_thumbnail(source_path, file_hash)
|
||||
except Exception as e:
|
||||
logger.warning(f"后台生成缩略图失败 {file_hash}: {e}")
|
||||
finally:
|
||||
with _generating_lock:
|
||||
_generating_thumbnails.discard(file_hash)
|
||||
|
||||
|
||||
def _ensure_thumbnail_cache_dir() -> Path:
|
||||
"""确保缩略图缓存目录存在"""
|
||||
THUMBNAIL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return THUMBNAIL_CACHE_DIR
|
||||
|
||||
|
||||
def _get_thumbnail_cache_path(file_hash: str) -> Path:
|
||||
"""获取缩略图缓存路径"""
|
||||
return THUMBNAIL_CACHE_DIR / f"{file_hash}.webp"
|
||||
|
||||
|
||||
def _generate_thumbnail(source_path: str, file_hash: str) -> Path:
|
||||
"""
|
||||
生成缩略图并保存到缓存目录
|
||||
|
||||
Args:
|
||||
source_path: 原图路径
|
||||
file_hash: 文件哈希值,用作缓存文件名
|
||||
|
||||
Returns:
|
||||
缩略图路径
|
||||
|
||||
Features:
|
||||
- GIF: 提取第一帧作为缩略图
|
||||
- 所有格式统一转为 WebP
|
||||
- 保持宽高比缩放
|
||||
"""
|
||||
_ensure_thumbnail_cache_dir()
|
||||
cache_path = _get_thumbnail_cache_path(file_hash)
|
||||
|
||||
# 使用锁防止并发生成同一缩略图
|
||||
lock = _get_thumbnail_lock(file_hash)
|
||||
with lock:
|
||||
# 双重检查,可能在等待锁时已被其他线程生成
|
||||
if cache_path.exists():
|
||||
return cache_path
|
||||
|
||||
try:
|
||||
with Image.open(source_path) as img:
|
||||
# GIF 处理:提取第一帧
|
||||
if hasattr(img, "n_frames") and img.n_frames > 1:
|
||||
img.seek(0) # 确保在第一帧
|
||||
|
||||
# 转换为 RGB/RGBA(WebP 支持透明度)
|
||||
if img.mode in ("P", "PA"):
|
||||
# 调色板模式转换为 RGBA 以保留透明度
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode == "LA":
|
||||
img = img.convert("RGBA")
|
||||
elif img.mode not in ("RGB", "RGBA"):
|
||||
img = img.convert("RGB")
|
||||
|
||||
# 创建缩略图(保持宽高比)
|
||||
img.thumbnail(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存为 WebP 格式
|
||||
img.save(cache_path, "WEBP", quality=THUMBNAIL_QUALITY, method=6)
|
||||
|
||||
logger.debug(f"生成缩略图: {file_hash} -> {cache_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"生成缩略图失败 {file_hash}: {e},将返回原图")
|
||||
# 生成失败时不创建缓存文件,下次会重试
|
||||
raise
|
||||
|
||||
return cache_path
|
||||
|
||||
|
||||
def cleanup_orphaned_thumbnails() -> tuple[int, int]:
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已不存在的缩略图)
|
||||
|
||||
Returns:
|
||||
(清理数量, 保留数量)
|
||||
"""
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return 0, 0
|
||||
|
||||
# 获取所有表情包的哈希值
|
||||
valid_hashes = set()
|
||||
for emoji in Emoji.select(Emoji.emoji_hash):
|
||||
valid_hashes.add(emoji.emoji_hash)
|
||||
|
||||
cleaned = 0
|
||||
kept = 0
|
||||
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
file_hash = cache_file.stem
|
||||
if file_hash not in valid_hashes:
|
||||
try:
|
||||
cache_file.unlink()
|
||||
cleaned += 1
|
||||
logger.debug(f"清理孤立缩略图: {cache_file.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清理缩略图失败 {cache_file.name}: {e}")
|
||||
else:
|
||||
kept += 1
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f"清理孤立缩略图: 删除 {cleaned} 个,保留 {kept} 个")
|
||||
|
||||
return cleaned, kept
|
||||
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||
DescriptionForm = Annotated[str, Form(description="表情包描述")]
|
||||
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
|
||||
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||
|
||||
@@ -92,18 +250,12 @@ class BatchDeleteResponse(BaseModel):
|
||||
failed_ids: List[int] = []
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
|
||||
@@ -135,6 +287,7 @@ async def get_emoji_list(
|
||||
format: Optional[str] = Query(None, description="格式筛选"),
|
||||
sort_by: Optional[str] = Query("usage_count", description="排序字段"),
|
||||
sort_order: Optional[str] = Query("desc", description="排序方向"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -155,7 +308,7 @@ async def get_emoji_list(
|
||||
表情包列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Emoji.select()
|
||||
@@ -213,7 +366,9 @@ async def get_emoji_list(
|
||||
|
||||
|
||||
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
|
||||
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_detail(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表情包详细信息
|
||||
|
||||
@@ -225,7 +380,7 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
|
||||
表情包详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -242,7 +397,12 @@ async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(
|
||||
|
||||
|
||||
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
|
||||
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def update_emoji(
|
||||
emoji_id: int,
|
||||
request: EmojiUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表情包(只更新提供的字段)
|
||||
|
||||
@@ -255,7 +415,7 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -294,7 +454,9 @@ async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization
|
||||
|
||||
|
||||
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
|
||||
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def delete_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表情包
|
||||
|
||||
@@ -306,7 +468,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -331,7 +493,7 @@ async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取表情包统计数据
|
||||
|
||||
@@ -342,7 +504,7 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Emoji.select().count()
|
||||
registered = Emoji.select().where(Emoji.is_registered).count()
|
||||
@@ -386,7 +548,9 @@ async def get_emoji_stats(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
|
||||
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def register_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
注册表情包(快捷操作)
|
||||
|
||||
@@ -398,7 +562,7 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -426,7 +590,9 @@ async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(No
|
||||
|
||||
|
||||
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
|
||||
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def ban_emoji(
|
||||
emoji_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
禁用表情包(快捷操作)
|
||||
|
||||
@@ -438,7 +604,7 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -465,28 +631,47 @@ async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_emoji_thumbnail(
|
||||
emoji_id: int,
|
||||
token: Optional[str] = Query(None, description="访问令牌"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
original: bool = Query(False, description="是否返回原图"),
|
||||
):
|
||||
"""
|
||||
获取表情包缩略图
|
||||
获取表情包缩略图(懒加载生成 + 缓存)
|
||||
|
||||
Args:
|
||||
emoji_id: 表情包ID
|
||||
token: 访问令牌(通过 query parameter)
|
||||
token: 访问令牌(通过 query parameter,用于向后兼容)
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header
|
||||
original: 是否返回原图(用于详情页查看原图)
|
||||
|
||||
Returns:
|
||||
表情包图片文件
|
||||
表情包缩略图(WebP 格式)或原图
|
||||
|
||||
Features:
|
||||
- 懒加载:首次请求时生成缩略图
|
||||
- 缓存:后续请求直接返回缓存
|
||||
- GIF 支持:提取第一帧作为缩略图
|
||||
- 格式统一:所有缩略图统一为 WebP 格式
|
||||
"""
|
||||
try:
|
||||
# 优先使用 query parameter 中的 token(用于 img 标签)
|
||||
if token:
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
else:
|
||||
# 如果没有 query token,则验证 Authorization header
|
||||
verify_auth_token(authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_valid = False
|
||||
|
||||
# 1. 优先使用 Cookie
|
||||
if maibot_session and token_manager.verify_token(maibot_session):
|
||||
is_valid = True
|
||||
# 2. 其次使用 query parameter(用于向后兼容 img 标签)
|
||||
elif token and token_manager.verify_token(token):
|
||||
is_valid = True
|
||||
# 3. 最后使用 Authorization header
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
auth_token = authorization.replace("Bearer ", "")
|
||||
if token_manager.verify_token(auth_token):
|
||||
is_valid = True
|
||||
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
|
||||
|
||||
@@ -497,19 +682,51 @@ async def get_emoji_thumbnail(
|
||||
if not os.path.exists(emoji.full_path):
|
||||
raise HTTPException(status_code=404, detail="表情包文件不存在")
|
||||
|
||||
# 根据格式设置 MIME 类型
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
# 如果请求原图,直接返回原文件
|
||||
if original:
|
||||
mime_types = {
|
||||
"png": "image/png",
|
||||
"jpg": "image/jpeg",
|
||||
"jpeg": "image/jpeg",
|
||||
"gif": "image/gif",
|
||||
"webp": "image/webp",
|
||||
"bmp": "image/bmp",
|
||||
}
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
return FileResponse(
|
||||
path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}"
|
||||
)
|
||||
|
||||
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
|
||||
# 尝试获取或生成缩略图
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
|
||||
# 检查缓存是否存在
|
||||
if cache_path.exists():
|
||||
# 缓存命中,直接返回
|
||||
return FileResponse(
|
||||
path=str(cache_path), media_type="image/webp", filename=f"{emoji.emoji_hash}_thumb.webp"
|
||||
)
|
||||
|
||||
# 缓存未命中,触发后台生成并返回 202
|
||||
with _generating_lock:
|
||||
if emoji.emoji_hash not in _generating_thumbnails:
|
||||
# 标记为正在生成
|
||||
_generating_thumbnails.add(emoji.emoji_hash)
|
||||
# 提交到线程池后台生成
|
||||
_thumbnail_executor.submit(_background_generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
|
||||
# 返回 202 Accepted,告诉前端缩略图正在生成中
|
||||
return JSONResponse(
|
||||
status_code=202,
|
||||
content={
|
||||
"status": "generating",
|
||||
"message": "缩略图正在生成中,请稍后重试",
|
||||
"emoji_id": emoji_id,
|
||||
},
|
||||
headers={
|
||||
"Retry-After": "1", # 建议 1 秒后重试
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -519,7 +736,11 @@ async def get_emoji_thumbnail(
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_emojis(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表情包
|
||||
|
||||
@@ -531,7 +752,7 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.emoji_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表情包ID")
|
||||
@@ -572,3 +793,519 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||
|
||||
|
||||
# 表情包存储目录
|
||||
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
||||
|
||||
|
||||
class EmojiUploadResponse(BaseModel):
|
||||
"""表情包上传响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[EmojiResponse] = None
|
||||
|
||||
|
||||
@router.post("/upload", response_model=EmojiUploadResponse)
|
||||
async def upload_emoji(
|
||||
file: EmojiFile,
|
||||
description: DescriptionForm = "",
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
上传并注册表情包
|
||||
|
||||
Args:
|
||||
file: 表情包图片文件 (支持 jpg, jpeg, png, gif, webp)
|
||||
description: 表情包描述
|
||||
emotion: 情感标签,多个用逗号分隔
|
||||
is_registered: 是否直接注册,默认为 True
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
上传结果和表情包信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 验证文件类型
|
||||
if not file.content_type:
|
||||
raise HTTPException(status_code=400, detail="无法识别文件类型")
|
||||
|
||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
if file.content_type not in allowed_types:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的文件类型: {file.content_type},支持: {', '.join(allowed_types)}",
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
if not file_content:
|
||||
raise HTTPException(status_code=400, detail="文件内容为空")
|
||||
|
||||
# 验证图片并获取格式
|
||||
try:
|
||||
with Image.open(io.BytesIO(file_content)) as img:
|
||||
img_format = img.format.lower() if img.format else "png"
|
||||
# 验证图片可以正常打开
|
||||
img.verify()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"无效的图片文件: {str(e)}") from e
|
||||
|
||||
# 重新打开图片(verify后需要重新打开)
|
||||
with Image.open(io.BytesIO(file_content)) as img:
|
||||
img_format = img.format.lower() if img.format else "png"
|
||||
|
||||
# 计算文件哈希
|
||||
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||
|
||||
# 检查是否已存在相同哈希的表情包
|
||||
existing_emoji = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
if existing_emoji:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
|
||||
)
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
|
||||
# 生成文件名
|
||||
timestamp = int(time.time())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
|
||||
# 如果文件已存在,添加随机后缀
|
||||
counter = 1
|
||||
while os.path.exists(full_path):
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
# 保存文件
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
logger.info(f"表情包文件已保存: {full_path}")
|
||||
|
||||
# 处理情感标签
|
||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||
|
||||
# 创建数据库记录
|
||||
current_time = time.time()
|
||||
emoji = Emoji.create(
|
||||
full_path=full_path,
|
||||
format=img_format,
|
||||
emoji_hash=emoji_hash,
|
||||
description=description,
|
||||
emotion=emotion_str,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
usage_count=0,
|
||||
last_used_time=None,
|
||||
)
|
||||
|
||||
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
|
||||
|
||||
return EmojiUploadResponse(
|
||||
success=True,
|
||||
message="表情包上传成功" + ("并已注册" if is_registered else ""),
|
||||
data=emoji_to_response(emoji),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"上传表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/upload")
|
||||
async def batch_upload_emoji(
|
||||
files: EmojiFiles,
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量上传表情包
|
||||
|
||||
Args:
|
||||
files: 多个表情包图片文件
|
||||
emotion: 共用的情感标签
|
||||
is_registered: 是否直接注册
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
批量上传结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
results = {
|
||||
"success": True,
|
||||
"total": len(files),
|
||||
"uploaded": 0,
|
||||
"failed": 0,
|
||||
"details": [],
|
||||
}
|
||||
|
||||
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||
|
||||
for file in files:
|
||||
try:
|
||||
# 验证文件类型
|
||||
if file.content_type not in allowed_types:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"不支持的文件类型: {file.content_type}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
if not file_content:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "文件内容为空",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 验证图片
|
||||
try:
|
||||
with Image.open(io.BytesIO(file_content)) as img:
|
||||
img_format = img.format.lower() if img.format else "png"
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"无效的图片: {str(e)}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 计算哈希
|
||||
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||
|
||||
# 检查重复
|
||||
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "已存在相同的表情包",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成文件名并保存
|
||||
timestamp = int(time.time())
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
|
||||
counter = 1
|
||||
while os.path.exists(full_path):
|
||||
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||
counter += 1
|
||||
|
||||
with open(full_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
# 处理情感标签
|
||||
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||
|
||||
# 创建数据库记录
|
||||
current_time = time.time()
|
||||
emoji = Emoji.create(
|
||||
full_path=full_path,
|
||||
format=img_format,
|
||||
emoji_hash=emoji_hash,
|
||||
description="", # 批量上传暂不设置描述
|
||||
emotion=emotion_str,
|
||||
query_count=0,
|
||||
is_registered=is_registered,
|
||||
is_banned=False,
|
||||
record_time=current_time,
|
||||
register_time=current_time if is_registered else None,
|
||||
usage_count=0,
|
||||
last_used_time=None,
|
||||
)
|
||||
|
||||
results["uploaded"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": True,
|
||||
"id": emoji.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
||||
return results
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量上传表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ==================== 缩略图缓存管理 API ====================
|
||||
|
||||
|
||||
class ThumbnailCacheStatsResponse(BaseModel):
|
||||
"""缩略图缓存统计响应"""
|
||||
|
||||
success: bool
|
||||
cache_dir: str
|
||||
total_count: int
|
||||
total_size_mb: float
|
||||
emoji_count: int
|
||||
coverage_percent: float
|
||||
|
||||
|
||||
class ThumbnailCleanupResponse(BaseModel):
|
||||
"""缩略图清理响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
cleaned_count: int
|
||||
kept_count: int
|
||||
|
||||
|
||||
class ThumbnailPreheatResponse(BaseModel):
|
||||
"""缩略图预热响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
generated_count: int
|
||||
skipped_count: int
|
||||
failed_count: int
|
||||
|
||||
|
||||
@router.get("/thumbnail-cache/stats", response_model=ThumbnailCacheStatsResponse)
|
||||
async def get_thumbnail_cache_stats(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取缩略图缓存统计信息
|
||||
|
||||
Returns:
|
||||
缓存目录、缓存数量、总大小、覆盖率等统计信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
# 统计缓存文件
|
||||
cache_files = list(THUMBNAIL_CACHE_DIR.glob("*.webp"))
|
||||
total_count = len(cache_files)
|
||||
total_size = sum(f.stat().st_size for f in cache_files)
|
||||
total_size_mb = round(total_size / (1024 * 1024), 2)
|
||||
|
||||
# 统计表情包总数
|
||||
emoji_count = Emoji.select().count()
|
||||
|
||||
# 计算覆盖率
|
||||
coverage_percent = round((total_count / emoji_count * 100) if emoji_count > 0 else 0, 1)
|
||||
|
||||
return ThumbnailCacheStatsResponse(
|
||||
success=True,
|
||||
cache_dir=str(THUMBNAIL_CACHE_DIR.absolute()),
|
||||
total_count=total_count,
|
||||
total_size_mb=total_size_mb,
|
||||
emoji_count=emoji_count,
|
||||
coverage_percent=coverage_percent,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取缩略图缓存统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/thumbnail-cache/cleanup", response_model=ThumbnailCleanupResponse)
|
||||
async def cleanup_thumbnail_cache(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
清理孤立的缩略图缓存(原图已删除的表情包对应的缩略图)
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
cleaned, kept = cleanup_orphaned_thumbnails()
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"清理完成:删除 {cleaned} 个孤立缓存,保留 {kept} 个有效缓存",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=kept,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"清理缩略图缓存失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"清理失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/thumbnail-cache/preheat", response_model=ThumbnailPreheatResponse)
|
||||
async def preheat_thumbnail_cache(
|
||||
limit: int = Query(100, ge=1, le=1000, description="最多预热数量"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
预热缩略图缓存(提前生成未缓存的缩略图)
|
||||
|
||||
优先处理使用次数高的表情包
|
||||
|
||||
Args:
|
||||
limit: 最多预热数量 (1-1000)
|
||||
|
||||
Returns:
|
||||
预热结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
_ensure_thumbnail_cache_dir()
|
||||
|
||||
# 获取使用次数最高的表情包(未缓存的优先)
|
||||
emojis = (
|
||||
Emoji.select()
|
||||
.where(Emoji.is_banned == False) # noqa: E712 Peewee ORM requires == for boolean comparison
|
||||
.order_by(Emoji.usage_count.desc())
|
||||
.limit(limit * 2) # 多查一些,因为有些可能已缓存
|
||||
)
|
||||
|
||||
generated = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
for emoji in emojis:
|
||||
if generated >= limit:
|
||||
break
|
||||
|
||||
cache_path = _get_thumbnail_cache_path(emoji.emoji_hash)
|
||||
|
||||
# 已缓存,跳过
|
||||
if cache_path.exists():
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# 原文件不存在,跳过
|
||||
if not os.path.exists(emoji.full_path):
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
# 使用线程池异步生成缩略图,避免阻塞事件循环
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(_thumbnail_executor, _generate_thumbnail, emoji.full_path, emoji.emoji_hash)
|
||||
generated += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"预热缩略图失败 {emoji.emoji_hash}: {e}")
|
||||
failed += 1
|
||||
|
||||
return ThumbnailPreheatResponse(
|
||||
success=True,
|
||||
message=f"预热完成:生成 {generated} 个,跳过 {skipped} 个已缓存,失败 {failed} 个",
|
||||
generated_count=generated,
|
||||
skipped_count=skipped,
|
||||
failed_count=failed,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"预热缩略图缓存失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"预热失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/thumbnail-cache/clear", response_model=ThumbnailCleanupResponse)
|
||||
async def clear_all_thumbnail_cache(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
清空所有缩略图缓存(下次访问时会重新生成)
|
||||
|
||||
Returns:
|
||||
清理结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not THUMBNAIL_CACHE_DIR.exists():
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message="缓存目录不存在,无需清理",
|
||||
cleaned_count=0,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
cleaned = 0
|
||||
for cache_file in THUMBNAIL_CACHE_DIR.glob("*.webp"):
|
||||
try:
|
||||
cache_file.unlink()
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"删除缓存文件失败 {cache_file.name}: {e}")
|
||||
|
||||
logger.info(f"已清空缩略图缓存: 删除 {cleaned} 个文件")
|
||||
|
||||
return ThumbnailCleanupResponse(
|
||||
success=True,
|
||||
message=f"已清空所有缩略图缓存:删除 {cleaned} 个文件",
|
||||
cleaned_count=cleaned,
|
||||
kept_count=0,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"清空缩略图缓存失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"清空失败: {str(e)}") from e
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from pydantic import BaseModel
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel, NonNegativeFloat
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
@@ -21,7 +21,6 @@ class ExpressionResponse(BaseModel):
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str]
|
||||
up_content: Optional[str]
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
create_date: Optional[float]
|
||||
@@ -49,8 +48,7 @@ class ExpressionCreateRequest(BaseModel):
|
||||
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str] = None
|
||||
up_content: Optional[str] = None
|
||||
context: Optional[str] = NonNegativeFloat
|
||||
chat_id: str
|
||||
|
||||
|
||||
@@ -60,7 +58,6 @@ class ExpressionUpdateRequest(BaseModel):
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
up_content: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
|
||||
|
||||
@@ -87,18 +84,12 @@ class ExpressionCreateResponse(BaseModel):
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
@@ -108,7 +99,6 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
context=expression.context,
|
||||
up_content=expression.up_content,
|
||||
last_active_time=expression.last_active_time,
|
||||
chat_id=expression.chat_id,
|
||||
create_date=expression.create_date,
|
||||
@@ -162,7 +152,7 @@ class ChatListResponse(BaseModel):
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
|
||||
@@ -173,7 +163,7 @@ async def get_chat_list(authorization: Optional[str] = Header(None)):
|
||||
聊天列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
chat_list = []
|
||||
for cs in ChatStreams.select():
|
||||
@@ -205,6 +195,7 @@ async def get_expression_list(
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -221,7 +212,7 @@ async def get_expression_list(
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
@@ -265,7 +256,9 @@ async def get_expression_list(
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_detail(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
@@ -277,7 +270,7 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -294,7 +287,11 @@ async def get_expression_detail(expression_id: int, authorization: Optional[str]
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
@@ -306,7 +303,7 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
@@ -315,7 +312,6 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
|
||||
situation=request.situation,
|
||||
style=request.style,
|
||||
context=request.context,
|
||||
up_content=request.up_content,
|
||||
chat_id=request.chat_id,
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
@@ -336,7 +332,10 @@ async def create_expression(request: ExpressionCreateRequest, authorization: Opt
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
@@ -350,7 +349,7 @@ async def update_expression(
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -386,7 +385,9 @@ async def update_expression(
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
|
||||
async def delete_expression(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
@@ -398,7 +399,7 @@ async def delete_expression(expression_id: int, authorization: Optional[str] = H
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
@@ -429,7 +430,11 @@ class BatchDeleteRequest(BaseModel):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_expressions(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
@@ -441,7 +446,7 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
@@ -470,7 +475,9 @@ async def batch_delete_expressions(request: BatchDeleteRequest, authorization: O
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_expression_stats(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
@@ -481,7 +488,7 @@ async def get_expression_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
|
||||
|
||||
@@ -602,9 +602,9 @@ class GitMirrorService:
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def run_git_clone():
|
||||
def run_git_clone(clone_cmd=cmd):
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
clone_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
|
||||
532
src/webui/jargon_routes.py
Normal file
532
src/webui/jargon_routes.py
Normal file
@@ -0,0 +1,532 @@
|
||||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon, ChatStreams
|
||||
|
||||
logger = get_logger("webui.jargon")
|
||||
|
||||
router = APIRouter(prefix="/jargon", tags=["Jargon"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
解析 chat_id 字段,提取所有 stream_id
|
||||
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 尝试解析为 JSON
|
||||
parsed = json.loads(chat_id_str)
|
||||
if isinstance(parsed, list):
|
||||
# 格式: [["stream_id", user_id], ...]
|
||||
stream_ids = []
|
||||
for item in parsed:
|
||||
if isinstance(item, list) and len(item) >= 1:
|
||||
stream_ids.append(str(item[0]))
|
||||
return stream_ids
|
||||
else:
|
||||
# 其他格式,返回原始字符串
|
||||
return [chat_id_str]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 不是有效的 JSON,可能是直接的 stream_id
|
||||
return [chat_id_str]
|
||||
|
||||
|
||||
def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
if not stream_ids:
|
||||
return chat_id_str
|
||||
|
||||
# 查询所有 stream_id 对应的名称
|
||||
names = []
|
||||
for stream_id in stream_ids:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream and chat_stream.group_name:
|
||||
names.append(chat_stream.group_name)
|
||||
else:
|
||||
# 如果没找到,显示截断的 stream_id
|
||||
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
|
||||
|
||||
return ", ".join(names) if names else chat_id_str
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
|
||||
class JargonResponse(BaseModel):
|
||||
"""黑话信息响应"""
|
||||
|
||||
id: int
|
||||
content: str
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: str
|
||||
stream_id: Optional[str] = None # 解析后的 stream_id,用于前端编辑时匹配
|
||||
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
|
||||
is_global: bool = False
|
||||
count: int = 0
|
||||
is_jargon: Optional[bool] = None
|
||||
is_complete: bool = False
|
||||
inference_with_context: Optional[str] = None
|
||||
inference_content_only: Optional[str] = None
|
||||
|
||||
|
||||
class JargonListResponse(BaseModel):
|
||||
"""黑话列表响应"""
|
||||
|
||||
success: bool = True
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[JargonResponse]
|
||||
|
||||
|
||||
class JargonDetailResponse(BaseModel):
|
||||
"""黑话详情响应"""
|
||||
|
||||
success: bool = True
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonCreateRequest(BaseModel):
|
||||
"""黑话创建请求"""
|
||||
|
||||
content: str = Field(..., description="黑话内容")
|
||||
raw_content: Optional[str] = Field(None, description="原始内容")
|
||||
meaning: Optional[str] = Field(None, description="含义")
|
||||
chat_id: str = Field(..., description="聊天ID")
|
||||
is_global: bool = Field(False, description="是否全局")
|
||||
|
||||
|
||||
class JargonUpdateRequest(BaseModel):
|
||||
"""黑话更新请求"""
|
||||
|
||||
content: Optional[str] = None
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
is_global: Optional[bool] = None
|
||||
is_jargon: Optional[bool] = None
|
||||
|
||||
|
||||
class JargonCreateResponse(BaseModel):
|
||||
"""黑话创建响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonUpdateResponse(BaseModel):
|
||||
"""黑话更新响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: Optional[JargonResponse] = None
|
||||
|
||||
|
||||
class JargonDeleteResponse(BaseModel):
|
||||
"""黑话删除响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
deleted_count: int = 0
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
ids: List[int] = Field(..., description="要删除的黑话ID列表")
|
||||
|
||||
|
||||
class JargonStatsResponse(BaseModel):
|
||||
"""黑话统计响应"""
|
||||
|
||||
success: bool = True
|
||||
data: dict
|
||||
|
||||
|
||||
class ChatInfoResponse(BaseModel):
|
||||
"""聊天信息响应"""
|
||||
|
||||
chat_id: str
|
||||
chat_name: str
|
||||
platform: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
"""聊天列表响应"""
|
||||
|
||||
success: bool = True
|
||||
data: List[ChatInfoResponse]
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon) -> dict:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
# 解析 chat_id 获取显示名称和 stream_id
|
||||
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
|
||||
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
|
||||
stream_id = stream_ids[0] if stream_ids else None
|
||||
|
||||
return {
|
||||
"id": jargon.id,
|
||||
"content": jargon.content,
|
||||
"raw_content": jargon.raw_content,
|
||||
"meaning": jargon.meaning,
|
||||
"chat_id": jargon.chat_id,
|
||||
"stream_id": stream_id,
|
||||
"chat_name": chat_name,
|
||||
"is_global": jargon.is_global,
|
||||
"count": jargon.count,
|
||||
"is_jargon": jargon.is_jargon,
|
||||
"is_complete": jargon.is_complete,
|
||||
"inference_with_context": jargon.inference_with_context,
|
||||
"inference_content_only": jargon.inference_content_only,
|
||||
}
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
|
||||
@router.get("/list", response_model=JargonListResponse)
|
||||
async def get_jargon_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
|
||||
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
|
||||
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
|
||||
):
|
||||
"""获取黑话列表"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = Jargon.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Jargon.content.contains(search))
|
||||
| (Jargon.meaning.contains(search))
|
||||
| (Jargon.raw_content.contains(search))
|
||||
)
|
||||
|
||||
# 按聊天ID筛选(使用 contains 匹配,因为 chat_id 是 JSON 格式)
|
||||
if chat_id:
|
||||
# 从传入的 chat_id 中解析出 stream_id
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
# 使用第一个 stream_id 进行模糊匹配
|
||||
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
|
||||
else:
|
||||
# 如果无法解析,使用精确匹配
|
||||
query = query.where(Jargon.chat_id == chat_id)
|
||||
|
||||
# 按是否是黑话筛选
|
||||
if is_jargon is not None:
|
||||
query = query.where(Jargon.is_jargon == is_jargon)
|
||||
|
||||
# 按是否全局筛选
|
||||
if is_global is not None:
|
||||
query = query.where(Jargon.is_global == is_global)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页和排序(按使用次数降序)
|
||||
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
|
||||
query = query.paginate(page, page_size)
|
||||
|
||||
# 转换为响应格式
|
||||
data = [jargon_to_dict(j) for j in query]
|
||||
|
||||
return JargonListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
# 获取所有不同的 chat_id
|
||||
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||
|
||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
|
||||
for chat_id in chat_id_list:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
seen_stream_ids.add(stream_ids[0])
|
||||
|
||||
result = []
|
||||
for stream_id in seen_stream_ids:
|
||||
# 尝试从 ChatStreams 表获取聊天名称
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id,方便筛选匹配
|
||||
chat_name=chat_stream.group_name or stream_id,
|
||||
platform=chat_stream.platform,
|
||||
is_group=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id
|
||||
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
)
|
||||
|
||||
return ChatListResponse(success=True, data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary", response_model=JargonStatsResponse)
|
||||
async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
try:
|
||||
# 总数量
|
||||
total = Jargon.select().count()
|
||||
|
||||
# 已确认是黑话的数量
|
||||
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
|
||||
|
||||
# 已确认不是黑话的数量
|
||||
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
|
||||
|
||||
# 未判定的数量
|
||||
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
|
||||
|
||||
# 全局黑话数量
|
||||
global_count = Jargon.select().where(Jargon.is_global).count()
|
||||
|
||||
# 已完成推断的数量
|
||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||
|
||||
# 关联的聊天数量
|
||||
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||
|
||||
# 按聊天统计 TOP 5
|
||||
top_chats = (
|
||||
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
|
||||
.group_by(Jargon.chat_id)
|
||||
.order_by(fn.COUNT(Jargon.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
data={
|
||||
"total": total,
|
||||
"confirmed_jargon": confirmed_jargon,
|
||||
"confirmed_not_jargon": confirmed_not_jargon,
|
||||
"pending": pending,
|
||||
"global_count": global_count,
|
||||
"complete_count": complete_count,
|
||||
"chat_count": chat_count,
|
||||
"top_chats": top_chats_dict,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
|
||||
async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/", response_model=JargonCreateResponse)
|
||||
async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
# 检查是否已存在相同内容的黑话
|
||||
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
# 创建黑话
|
||||
jargon = Jargon.create(
|
||||
content=request.content,
|
||||
raw_content=request.raw_content,
|
||||
meaning=request.meaning,
|
||||
chat_id=request.chat_id,
|
||||
is_global=request.is_global,
|
||||
count=0,
|
||||
is_jargon=None,
|
||||
is_complete=False,
|
||||
)
|
||||
|
||||
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
|
||||
|
||||
return JargonCreateResponse(
|
||||
success=True,
|
||||
message="创建成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"创建黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
"""更新黑话(增量更新)"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
# 增量更新字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
for field, value in update_data.items():
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
jargon.save()
|
||||
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message="更新成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
|
||||
async def delete_jargon(jargon_id: int):
|
||||
"""删除黑话"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
content = jargon.content
|
||||
jargon.delete_instance()
|
||||
|
||||
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message="删除成功",
|
||||
deleted_count=1,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=JargonDeleteResponse)
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
"""批量删除黑话"""
|
||||
try:
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
|
||||
|
||||
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除 {deleted_count} 条黑话",
|
||||
deleted_count=deleted_count,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/set-jargon", response_model=JargonUpdateResponse)
|
||||
async def batch_set_jargon_status(
|
||||
ids: Annotated[List[int], Query(description="黑话ID列表")],
|
||||
is_jargon: Annotated[bool, Query(description="是否是黑话")],
|
||||
):
|
||||
"""批量设置黑话状态"""
|
||||
try:
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {updated_count} 条黑话状态",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新黑话状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量更新黑话状态失败: {str(e)}") from e
|
||||
@@ -1,4 +1,5 @@
|
||||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from pydantic import BaseModel
|
||||
@@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
id: str
|
||||
type: str # 'entity' or 'paragraph'
|
||||
content: str
|
||||
@@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
|
||||
|
||||
class KnowledgeEdge(BaseModel):
|
||||
"""知识边"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
weight: float
|
||||
@@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""知识图谱"""
|
||||
|
||||
nodes: List[KnowledgeNode]
|
||||
edges: List[KnowledgeEdge]
|
||||
|
||||
|
||||
class KnowledgeStats(BaseModel):
|
||||
"""知识库统计信息"""
|
||||
|
||||
total_nodes: int
|
||||
total_edges: int
|
||||
entity_nodes: int
|
||||
@@ -45,7 +50,7 @@ def _load_kg_manager():
|
||||
"""延迟加载 KGManager"""
|
||||
try:
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
|
||||
|
||||
kg_manager = KGManager()
|
||||
kg_manager.load_from_file()
|
||||
return kg_manager
|
||||
@@ -58,31 +63,26 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
"""将 DiGraph 转换为 JSON 格式"""
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
|
||||
# 转换节点
|
||||
node_list = graph.get_node_list()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 转换边
|
||||
edge_list = graph.get_edge_list()
|
||||
for edge_tuple in edge_list:
|
||||
@@ -91,37 +91,35 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
source, target = edge_tuple[0], edge_tuple[1]
|
||||
# 通过 graph[source, target] 获取边的属性数据
|
||||
edge_data = graph[source, target]
|
||||
|
||||
|
||||
# edge_data 支持 [] 操作符但不支持 .get()
|
||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
||||
|
||||
edges.append(KnowledgeEdge(
|
||||
source=source,
|
||||
target=target,
|
||||
weight=weight,
|
||||
create_time=create_time,
|
||||
update_time=update_time
|
||||
))
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
@router.get("/graph", response_model=KnowledgeGraph)
|
||||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph")
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
|
||||
Args:
|
||||
limit: 返回的最大节点数,默认 100,最大 10000
|
||||
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
|
||||
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
|
||||
"""
|
||||
@@ -130,46 +128,43 @@ async def get_knowledge_graph(
|
||||
if kg_manager is None:
|
||||
logger.warning("KGManager 未初始化,返回空图谱")
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
all_node_list = graph.get_node_list()
|
||||
|
||||
|
||||
# 按类型过滤节点
|
||||
if node_type == "entity":
|
||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent']
|
||||
all_node_list = [
|
||||
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
|
||||
]
|
||||
elif node_type == "paragraph":
|
||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg']
|
||||
|
||||
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
|
||||
|
||||
# 限制节点数量
|
||||
total_nodes = len(all_node_list)
|
||||
if len(all_node_list) > limit:
|
||||
node_list = all_node_list[:limit]
|
||||
else:
|
||||
node_list = all_node_list
|
||||
|
||||
|
||||
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
|
||||
|
||||
|
||||
# 转换节点
|
||||
nodes = []
|
||||
node_ids = set()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type_val,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||
node_ids.add(node_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 只获取涉及当前节点集的边(保证图的完整性)
|
||||
edges = []
|
||||
edge_list = graph.get_edge_list()
|
||||
@@ -179,27 +174,25 @@ async def get_knowledge_graph(
|
||||
# 只包含两端都在当前节点集中的边
|
||||
if source not in node_ids or target not in node_ids:
|
||||
continue
|
||||
|
||||
|
||||
edge_data = graph[source, target]
|
||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
||||
|
||||
edges.append(KnowledgeEdge(
|
||||
source=source,
|
||||
target=target,
|
||||
weight=weight,
|
||||
create_time=create_time,
|
||||
update_time=update_time
|
||||
))
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
|
||||
return graph_data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
@@ -208,71 +201,59 @@ async def get_knowledge_graph(
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats():
|
||||
"""获取知识库统计信息
|
||||
|
||||
|
||||
Returns:
|
||||
KnowledgeStats: 统计信息
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeStats(
|
||||
total_nodes=0,
|
||||
total_edges=0,
|
||||
entity_nodes=0,
|
||||
paragraph_nodes=0,
|
||||
avg_connections=0.0
|
||||
)
|
||||
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
edge_list = graph.get_edge_list()
|
||||
|
||||
|
||||
total_nodes = len(node_list)
|
||||
total_edges = len(edge_list)
|
||||
|
||||
|
||||
# 统计节点类型
|
||||
entity_nodes = 0
|
||||
paragraph_nodes = 0
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type = node_data['type'] if 'type' in node_data else 'ent'
|
||||
if node_type == 'ent':
|
||||
node_type = node_data["type"] if "type" in node_data else "ent"
|
||||
if node_type == "ent":
|
||||
entity_nodes += 1
|
||||
elif node_type == 'pg':
|
||||
elif node_type == "pg":
|
||||
paragraph_nodes += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# 计算平均连接数
|
||||
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
|
||||
|
||||
|
||||
return KnowledgeStats(
|
||||
total_nodes=total_nodes,
|
||||
total_edges=total_edges,
|
||||
entity_nodes=entity_nodes,
|
||||
paragraph_nodes=paragraph_nodes,
|
||||
avg_connections=round(avg_connections, 2)
|
||||
avg_connections=round(avg_connections, 2),
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
||||
return KnowledgeStats(
|
||||
total_nodes=0,
|
||||
total_edges=0,
|
||||
entity_nodes=0,
|
||||
paragraph_nodes=0,
|
||||
avg_connections=0.0
|
||||
)
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
"""搜索知识节点
|
||||
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
|
||||
Returns:
|
||||
List[KnowledgeNode]: 匹配的节点列表
|
||||
"""
|
||||
@@ -280,33 +261,28 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return []
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
|
||||
|
||||
# 在节点内容中搜索
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
if query_lower in content.lower() or query_lower in node_id.lower():
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
results.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
|
||||
return results[:50] # 限制返回数量
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索节点失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
@@ -43,25 +43,27 @@ def _normalize_url(url: str) -> str:
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 OpenAI 格式的模型列表响应
|
||||
|
||||
|
||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
if isinstance(model, dict) and "id" in model:
|
||||
models.append({
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
})
|
||||
models.append(
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 Gemini 格式的模型列表响应
|
||||
|
||||
|
||||
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
|
||||
"""
|
||||
models = []
|
||||
@@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
model_id = model["name"]
|
||||
if model_id.startswith("models/"):
|
||||
model_id = model_id[7:] # 去掉 "models/" 前缀
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
})
|
||||
models.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
@@ -89,55 +93,54 @@ async def _fetch_models_from_provider(
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
|
||||
|
||||
Args:
|
||||
base_url: 提供商的基础 URL
|
||||
api_key: API 密钥
|
||||
endpoint: 获取模型列表的端点
|
||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||
client_type: 客户端类型 ('openai' | 'gemini')
|
||||
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
url = f"{_normalize_url(base_url)}{endpoint}"
|
||||
|
||||
|
||||
# 根据客户端类型设置请求头
|
||||
headers = {}
|
||||
params = {}
|
||||
|
||||
|
||||
if client_type == "gemini":
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.TimeoutException:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试")
|
||||
except httpx.TimeoutException as e:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
||||
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
||||
if e.response.status_code == 401:
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期")
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
|
||||
elif e.response.status_code == 403:
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限")
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
|
||||
elif e.response.status_code == 404:
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表")
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
)
|
||||
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
|
||||
|
||||
# 根据解析器类型解析响应
|
||||
if parser == "openai":
|
||||
return _parse_openai_response(data)
|
||||
@@ -150,26 +153,26 @@ async def _fetch_models_from_provider(
|
||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
"""
|
||||
从 model_config.toml 获取指定提供商的配置
|
||||
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称
|
||||
|
||||
|
||||
Returns:
|
||||
提供商配置,如果未找到则返回 None
|
||||
"""
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
|
||||
providers = config_data.get("api_providers", [])
|
||||
for provider in providers:
|
||||
if provider.get("name") == provider_name:
|
||||
return dict(provider)
|
||||
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"读取提供商配置失败: {e}")
|
||||
@@ -184,23 +187,23 @@ async def get_provider_models(
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
|
||||
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||
"""
|
||||
# 获取提供商配置
|
||||
provider_config = _get_provider_config(provider_name)
|
||||
if not provider_config:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
|
||||
base_url = provider_config.get("base_url")
|
||||
api_key = provider_config.get("api_key")
|
||||
client_type = provider_config.get("client_type", "openai")
|
||||
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||
|
||||
|
||||
# 获取模型列表
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
@@ -209,7 +212,7 @@ async def get_provider_models(
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
@@ -236,9 +239,132 @@ async def get_models_by_url(
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/test-connection")
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
|
||||
分两步测试:
|
||||
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
|
||||
2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
|
||||
|
||||
返回:
|
||||
- network_ok: 网络是否连通
|
||||
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
|
||||
- latency_ms: 响应延迟(毫秒)
|
||||
- error: 错误信息(如果有)
|
||||
"""
|
||||
import time
|
||||
|
||||
base_url = _normalize_url(base_url)
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||
|
||||
result = {
|
||||
"network_ok": False,
|
||||
"api_key_valid": None,
|
||||
"latency_ms": None,
|
||||
"error": None,
|
||||
"http_status": None,
|
||||
}
|
||||
|
||||
# 第一步:测试网络连通性
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
|
||||
# 尝试 GET 请求 base_url(不需要 API Key)
|
||||
response = await client.get(base_url)
|
||||
latency = (time.time() - start_time) * 1000
|
||||
|
||||
result["network_ok"] = True
|
||||
result["latency_ms"] = round(latency, 2)
|
||||
result["http_status"] = response.status_code
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
|
||||
return result
|
||||
except httpx.TimeoutException:
|
||||
result["error"] = "连接超时:服务器响应时间过长"
|
||||
return result
|
||||
except httpx.RequestError as e:
|
||||
result["error"] = f"请求错误:{str(e)}"
|
||||
return result
|
||||
except Exception as e:
|
||||
result["error"] = f"未知错误:{str(e)}"
|
||||
return result
|
||||
|
||||
# 第二步:如果提供了 API Key,验证其有效性
|
||||
if api_key:
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
# 尝试获取模型列表
|
||||
models_url = f"{base_url}/models"
|
||||
response = await client.get(models_url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
result["api_key_valid"] = True
|
||||
elif response.status_code in (401, 403):
|
||||
result["api_key_valid"] = False
|
||||
result["error"] = "API Key 无效或已过期"
|
||||
else:
|
||||
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
|
||||
result["api_key_valid"] = None
|
||||
|
||||
except Exception as e:
|
||||
# API Key 验证失败不影响网络连通性结果
|
||||
logger.warning(f"API Key 验证失败: {e}")
|
||||
result["api_key_valid"] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
"""
|
||||
# 读取配置文件
|
||||
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(model_config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 查找提供商
|
||||
providers = config.get("api_providers", [])
|
||||
provider = None
|
||||
for p in providers:
|
||||
if p.get("name") == provider_name:
|
||||
provider = p
|
||||
break
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
base_url = provider.get("base_url", "")
|
||||
api_key = provider.get("api_key", "")
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import json
|
||||
import time
|
||||
|
||||
@@ -91,18 +91,12 @@ class BatchDeleteResponse(BaseModel):
|
||||
failed_ids: List[str] = []
|
||||
|
||||
|
||||
def verify_auth_token(authorization: Optional[str]) -> bool:
|
||||
"""验证认证 Token"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
|
||||
@@ -141,6 +135,7 @@ async def get_person_list(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -158,7 +153,7 @@ async def get_person_list(
|
||||
人物信息列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = PersonInfo.select()
|
||||
@@ -205,7 +200,9 @@ async def get_person_list(
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
async def get_person_detail(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
@@ -217,7 +214,7 @@ async def get_person_detail(person_id: str, authorization: Optional[str] = Heade
|
||||
人物详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
@@ -234,7 +231,12 @@ async def get_person_detail(person_id: str, authorization: Optional[str] = Heade
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
@@ -247,7 +249,7 @@ async def update_person(person_id: str, request: PersonUpdateRequest, authorizat
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
@@ -283,7 +285,9 @@ async def update_person(person_id: str, request: PersonUpdateRequest, authorizat
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
|
||||
async def delete_person(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
@@ -295,7 +299,7 @@ async def delete_person(person_id: str, authorization: Optional[str] = Header(No
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
@@ -320,7 +324,7 @@ async def delete_person(person_id: str, authorization: Optional[str] = Header(No
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
@@ -331,7 +335,7 @@ async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = PersonInfo.select().count()
|
||||
known = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
@@ -353,7 +357,11 @@ async def get_person_stats(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
|
||||
async def batch_delete_persons(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
@@ -365,7 +373,7 @@ async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optio
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(authorization)
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.person_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from fastapi import APIRouter, HTTPException, Header, Cookie
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
from typing import Optional, List, Dict, Any, get_origin
|
||||
from pathlib import Path
|
||||
import json
|
||||
from src.common.logger import get_logger
|
||||
from src.common.toml_utils import save_toml_with_format
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||
from .token_manager import get_token_manager
|
||||
from .plugin_progress_ws import update_progress
|
||||
@@ -18,6 +20,20 @@ router = APIRouter(prefix="/plugins", tags=["插件管理"])
|
||||
set_update_progress_callback(update_progress)
|
||||
|
||||
|
||||
def get_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""从 Cookie 或 Header 获取 token"""
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
return maibot_session
|
||||
# 其次从 Header 获取
|
||||
if authorization and authorization.startswith("Bearer "):
|
||||
return authorization.replace("Bearer ", "")
|
||||
return None
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
解析版本号字符串
|
||||
@@ -29,8 +45,11 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
Returns:
|
||||
(major, minor, patch) 三元组
|
||||
"""
|
||||
# 移除 snapshot 等后缀
|
||||
base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0]
|
||||
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
|
||||
import re
|
||||
|
||||
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
@@ -47,6 +66,95 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
return (0, 0, 0)
|
||||
|
||||
|
||||
# ============ 工具函数(避免在请求内重复定义) ============
|
||||
|
||||
|
||||
def _deep_merge(dst: Dict[str, Any], src: Dict[str, Any]) -> None:
|
||||
"""深度合并两个字典,src 的值会覆盖或合并到 dst 中。"""
|
||||
for k, v in src.items():
|
||||
if k in dst and isinstance(dst[k], dict) and isinstance(v, dict):
|
||||
_deep_merge(dst[k], v)
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
|
||||
def normalize_dotted_keys(obj: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
将形如 {'a.b': 1} 的键展开为嵌套结构 {'a': {'b': 1}}。
|
||||
若遇到中间节点已存在且非字典,记录日志并覆盖为字典。
|
||||
"""
|
||||
result: Dict[str, Any] = {}
|
||||
dotted_items = []
|
||||
|
||||
# 先处理非点号键,避免后续展开覆盖已有结构
|
||||
for k, v in obj.items():
|
||||
if "." in k:
|
||||
dotted_items.append((k, v))
|
||||
else:
|
||||
result[k] = normalize_dotted_keys(v) if isinstance(v, dict) else v
|
||||
|
||||
# 再处理点号键
|
||||
for dotted_key, v in dotted_items:
|
||||
value = normalize_dotted_keys(v) if isinstance(v, dict) else v
|
||||
parts = dotted_key.split(".")
|
||||
if "" in parts:
|
||||
logger.warning(f"键路径包含空段: '{dotted_key}'")
|
||||
parts = [p for p in parts if p]
|
||||
if not parts:
|
||||
logger.warning(f"忽略空键路径: '{dotted_key}'")
|
||||
continue
|
||||
current = result
|
||||
# 中间层
|
||||
for idx, part in enumerate(parts[:-1]):
|
||||
if part in current and not isinstance(current[part], dict):
|
||||
path_ctx = ".".join(parts[: idx + 1])
|
||||
logger.warning(f"键冲突:{part} 已存在且非字典,覆盖为字典以展开 {dotted_key} (路径 {path_ctx})")
|
||||
current[part] = {}
|
||||
current = current.setdefault(part, {})
|
||||
# 最后一层
|
||||
last_part = parts[-1]
|
||||
if last_part in current and isinstance(current[last_part], dict) and isinstance(value, dict):
|
||||
_deep_merge(current[last_part], value)
|
||||
else:
|
||||
current[last_part] = value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def coerce_types(schema_part: Dict[str, Any], config_part: Dict[str, Any]) -> None:
|
||||
"""
|
||||
根据 schema 将配置中的类型纠正(目前只纠正 list-from-str)。
|
||||
"""
|
||||
|
||||
def _is_list_type(tp: Any) -> bool:
|
||||
origin = get_origin(tp)
|
||||
return tp is list or origin is list
|
||||
|
||||
for key, schema_val in schema_part.items():
|
||||
if key not in config_part:
|
||||
continue
|
||||
value = config_part[key]
|
||||
if isinstance(schema_val, ConfigField):
|
||||
if _is_list_type(schema_val.type) and isinstance(value, str):
|
||||
config_part[key] = [item.strip() for item in value.split(",") if item.strip()]
|
||||
elif isinstance(schema_val, dict) and isinstance(value, dict):
|
||||
coerce_types(schema_val, value)
|
||||
|
||||
|
||||
def find_plugin_instance(plugin_id: str) -> Optional[Any]:
|
||||
"""
|
||||
按 plugin_id 或 plugin_name 查找已加载的插件实例。
|
||||
局部导入 plugin_manager 以规避循环依赖。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
||||
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
||||
if instance and (instance.plugin_name == plugin_id or instance.get_manifest_info("id", "") == plugin_id):
|
||||
return instance
|
||||
return None
|
||||
|
||||
|
||||
# ============ 请求/响应模型 ============
|
||||
|
||||
|
||||
@@ -206,12 +314,14 @@ async def check_git_status() -> GitStatusResponse:
|
||||
|
||||
|
||||
@router.get("/mirrors", response_model=AvailableMirrorsResponse)
|
||||
async def get_available_mirrors(authorization: Optional[str] = Header(None)) -> AvailableMirrorsResponse:
|
||||
async def get_available_mirrors(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> AvailableMirrorsResponse:
|
||||
"""
|
||||
获取所有可用的镜像源配置
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -236,12 +346,14 @@ async def get_available_mirrors(authorization: Optional[str] = Header(None)) ->
|
||||
|
||||
|
||||
@router.post("/mirrors", response_model=MirrorConfigResponse)
|
||||
async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = Header(None)) -> MirrorConfigResponse:
|
||||
async def add_mirror(
|
||||
request: AddMirrorRequest, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> MirrorConfigResponse:
|
||||
"""
|
||||
添加新的镜像源
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -276,13 +388,16 @@ async def add_mirror(request: AddMirrorRequest, authorization: Optional[str] = H
|
||||
|
||||
@router.put("/mirrors/{mirror_id}", response_model=MirrorConfigResponse)
|
||||
async def update_mirror(
|
||||
mirror_id: str, request: UpdateMirrorRequest, authorization: Optional[str] = Header(None)
|
||||
mirror_id: str,
|
||||
request: UpdateMirrorRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> MirrorConfigResponse:
|
||||
"""
|
||||
更新镜像源配置
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -319,12 +434,14 @@ async def update_mirror(
|
||||
|
||||
|
||||
@router.delete("/mirrors/{mirror_id}")
|
||||
async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def delete_mirror(
|
||||
mirror_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
删除镜像源
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -342,7 +459,9 @@ async def delete_mirror(mirror_id: str, authorization: Optional[str] = Header(No
|
||||
|
||||
@router.post("/fetch-raw", response_model=FetchRawFileResponse)
|
||||
async def fetch_raw_file(
|
||||
request: FetchRawFileRequest, authorization: Optional[str] = Header(None)
|
||||
request: FetchRawFileRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> FetchRawFileResponse:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
@@ -352,7 +471,7 @@ async def fetch_raw_file(
|
||||
注意:此接口可公开访问,用于获取插件仓库等公开资源
|
||||
"""
|
||||
# Token 验证(可选,用于日志记录)
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
is_authenticated = token and token_manager.verify_token(token)
|
||||
|
||||
@@ -427,7 +546,9 @@ async def fetch_raw_file(
|
||||
|
||||
@router.post("/clone", response_model=CloneRepositoryResponse)
|
||||
async def clone_repository(
|
||||
request: CloneRepositoryRequest, authorization: Optional[str] = Header(None)
|
||||
request: CloneRepositoryRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> CloneRepositoryResponse:
|
||||
"""
|
||||
克隆 GitHub 仓库到本地
|
||||
@@ -435,7 +556,7 @@ async def clone_repository(
|
||||
支持多镜像源自动切换和错误重试
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -467,14 +588,18 @@ async def clone_repository(
|
||||
|
||||
|
||||
@router.post("/install")
|
||||
async def install_plugin(request: InstallPluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def install_plugin(
|
||||
request: InstallPluginRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
安装插件
|
||||
|
||||
从 Git 仓库克隆插件到本地插件目录
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -611,7 +736,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
|
||||
for field in required_fields:
|
||||
if field not in manifest:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
|
||||
|
||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||
manifest["id"] = request.plugin_id
|
||||
@@ -671,7 +796,9 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
|
||||
|
||||
@router.post("/uninstall")
|
||||
async def uninstall_plugin(
|
||||
request: UninstallPluginRequest, authorization: Optional[str] = Header(None)
|
||||
request: UninstallPluginRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
卸载插件
|
||||
@@ -679,7 +806,7 @@ async def uninstall_plugin(
|
||||
删除插件目录及其所有文件
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -703,7 +830,7 @@ async def uninstall_plugin(
|
||||
plugin_path = plugins_dir / folder_name
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
if old_format_path.exists():
|
||||
@@ -806,14 +933,18 @@ async def uninstall_plugin(
|
||||
|
||||
|
||||
@router.post("/update")
|
||||
async def update_plugin(request: UpdatePluginRequest, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def update_plugin(
|
||||
request: UpdatePluginRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件
|
||||
|
||||
删除旧版本,重新克隆新版本
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -837,7 +968,7 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
|
||||
plugin_path = plugins_dir / folder_name
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
if old_format_path.exists():
|
||||
@@ -1025,14 +1156,16 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
|
||||
|
||||
|
||||
@router.get("/installed")
|
||||
async def get_installed_plugins(authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
async def get_installed_plugins(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取已安装的插件列表
|
||||
|
||||
扫描 plugins 目录,返回所有已安装插件的 ID 和基本信息
|
||||
"""
|
||||
# Token 验证
|
||||
token = authorization.replace("Bearer ", "") if authorization else None
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
@@ -1090,21 +1223,21 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
# 尝试从 author.name 和 repository_url 构建标准 ID
|
||||
author_name = None
|
||||
repo_name = None
|
||||
|
||||
|
||||
# 获取作者名
|
||||
if "author" in manifest:
|
||||
if isinstance(manifest["author"], dict) and "name" in manifest["author"]:
|
||||
author_name = manifest["author"]["name"]
|
||||
elif isinstance(manifest["author"], str):
|
||||
author_name = manifest["author"]
|
||||
|
||||
|
||||
# 从 repository_url 获取仓库名
|
||||
if "repository_url" in manifest:
|
||||
repo_url = manifest["repository_url"].rstrip("/")
|
||||
if repo_url.endswith(".git"):
|
||||
repo_url = repo_url[:-4]
|
||||
repo_name = repo_url.split("/")[-1]
|
||||
|
||||
|
||||
# 构建 ID
|
||||
if author_name and repo_name:
|
||||
# 标准格式: Author.RepoName
|
||||
@@ -1120,7 +1253,7 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
else:
|
||||
# 直接使用文件夹名
|
||||
plugin_id = folder_name
|
||||
|
||||
|
||||
# 将推断的 ID 写入 manifest(方便下次识别)
|
||||
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
|
||||
manifest["id"] = plugin_id
|
||||
@@ -1153,3 +1286,459 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
except Exception as e:
|
||||
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
# ============ 插件配置管理 API ============
|
||||
|
||||
|
||||
class UpdatePluginConfigRequest(BaseModel):
|
||||
"""更新插件配置请求"""
|
||||
|
||||
config: Dict[str, Any] = Field(..., description="配置数据")
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/schema")
|
||||
async def get_plugin_config_schema(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件配置 Schema
|
||||
|
||||
返回插件的完整配置 schema,包含所有 section、字段定义和布局信息。
|
||||
用于前端动态生成配置表单。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 尝试从已加载的插件中获取
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
# 查找插件实例
|
||||
plugin_instance = None
|
||||
|
||||
# 遍历所有已加载的插件
|
||||
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
||||
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
||||
if instance:
|
||||
# 匹配 plugin_name 或 manifest 中的 id
|
||||
if instance.plugin_name == plugin_id:
|
||||
plugin_instance = instance
|
||||
break
|
||||
# 也尝试匹配 manifest 中的 id
|
||||
manifest_id = instance.get_manifest_info("id", "")
|
||||
if manifest_id == plugin_id:
|
||||
plugin_instance = instance
|
||||
break
|
||||
|
||||
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||
# 从插件实例获取 schema
|
||||
schema = plugin_instance.get_webui_config_schema()
|
||||
return {"success": True, "schema": schema}
|
||||
|
||||
# 如果插件未加载,尝试从文件系统读取
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
# 读取配置文件获取当前配置
|
||||
config_path = plugin_path / "config.toml"
|
||||
current_config = {}
|
||||
if config_path.exists():
|
||||
import tomlkit
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
current_config = tomlkit.load(f)
|
||||
|
||||
# 构建基础 schema(无法获取完整的 ConfigField 信息)
|
||||
schema = {
|
||||
"plugin_id": plugin_id,
|
||||
"plugin_info": {
|
||||
"name": plugin_id,
|
||||
"version": "",
|
||||
"description": "",
|
||||
"author": "",
|
||||
},
|
||||
"sections": {},
|
||||
"layout": {"type": "auto", "tabs": []},
|
||||
"_note": "插件未加载,仅返回当前配置结构",
|
||||
}
|
||||
|
||||
# 从当前配置推断 schema
|
||||
for section_name, section_data in current_config.items():
|
||||
if isinstance(section_data, dict):
|
||||
schema["sections"][section_name] = {
|
||||
"name": section_name,
|
||||
"title": section_name,
|
||||
"description": None,
|
||||
"icon": None,
|
||||
"collapsed": False,
|
||||
"order": 0,
|
||||
"fields": {},
|
||||
}
|
||||
for field_name, field_value in section_data.items():
|
||||
# 推断字段类型
|
||||
field_type = type(field_value).__name__
|
||||
ui_type = "text"
|
||||
item_type = None
|
||||
item_fields = None
|
||||
|
||||
if isinstance(field_value, bool):
|
||||
ui_type = "switch"
|
||||
elif isinstance(field_value, (int, float)):
|
||||
ui_type = "number"
|
||||
elif isinstance(field_value, list):
|
||||
ui_type = "list"
|
||||
# 推断数组元素类型
|
||||
if field_value:
|
||||
first_item = field_value[0]
|
||||
if isinstance(first_item, dict):
|
||||
item_type = "object"
|
||||
# 从第一个元素推断字段结构
|
||||
item_fields = {}
|
||||
for k, v in first_item.items():
|
||||
item_fields[k] = {
|
||||
"type": "number" if isinstance(v, (int, float)) else "string",
|
||||
"label": k,
|
||||
"default": "" if isinstance(v, str) else 0,
|
||||
}
|
||||
elif isinstance(first_item, (int, float)):
|
||||
item_type = "number"
|
||||
else:
|
||||
item_type = "string"
|
||||
else:
|
||||
item_type = "string"
|
||||
elif isinstance(field_value, dict):
|
||||
ui_type = "json"
|
||||
|
||||
schema["sections"][section_name]["fields"][field_name] = {
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
"default": field_value,
|
||||
"description": field_name,
|
||||
"label": field_name,
|
||||
"ui_type": ui_type,
|
||||
"required": False,
|
||||
"hidden": False,
|
||||
"disabled": False,
|
||||
"order": 0,
|
||||
"item_type": item_type,
|
||||
"item_fields": item_fields,
|
||||
"min_items": None,
|
||||
"max_items": None,
|
||||
# 补充缺失的字段
|
||||
"placeholder": None,
|
||||
"hint": None,
|
||||
"icon": None,
|
||||
"example": None,
|
||||
"choices": None,
|
||||
"min": None,
|
||||
"max": None,
|
||||
"step": None,
|
||||
"pattern": None,
|
||||
"max_length": None,
|
||||
"input_type": None,
|
||||
"rows": 3,
|
||||
"group": None,
|
||||
"depends_on": None,
|
||||
"depends_value": None,
|
||||
}
|
||||
|
||||
return {"success": True, "schema": schema}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件配置 Schema 失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}")
|
||||
async def get_plugin_config(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件当前配置值
|
||||
|
||||
返回插件的当前配置值。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"获取插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
# 读取配置文件
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||
|
||||
import tomlkit
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
return {"success": True, "config": dict(config)}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.put("/config/{plugin_id}")
|
||||
async def update_plugin_config(
|
||||
plugin_id: str,
|
||||
request: UpdatePluginConfigRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件配置
|
||||
|
||||
保存新的配置值到插件的配置文件。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"更新插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
plugin_instance = find_plugin_instance(plugin_id)
|
||||
|
||||
# 纠正 WebUI 提交的数据结构(扁平键与字符串列表)
|
||||
if plugin_instance and isinstance(request.config, dict):
|
||||
request.config = normalize_dotted_keys(request.config)
|
||||
if isinstance(plugin_instance.config_schema, dict):
|
||||
coerce_types(plugin_instance.config_schema, request.config)
|
||||
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
# 备份旧配置
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
if config_path.exists():
|
||||
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = plugin_path / backup_name
|
||||
shutil.copy(config_path, backup_path)
|
||||
logger.info(f"已备份配置文件: {backup_path}")
|
||||
|
||||
# 写入新配置(自动保留注释和格式)
|
||||
save_toml_with_format(request.config, str(config_path))
|
||||
|
||||
logger.info(f"已更新插件配置: {plugin_id}")
|
||||
|
||||
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/reset")
|
||||
async def reset_plugin_config(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
重置插件配置为默认值
|
||||
|
||||
删除当前配置文件,下次加载插件时将使用默认配置。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"重置插件配置: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
if not config_path.exists():
|
||||
return {"success": True, "message": "配置文件不存在,无需重置"}
|
||||
|
||||
# 备份并删除
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = plugin_path / backup_name
|
||||
shutil.move(config_path, backup_path)
|
||||
|
||||
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
|
||||
|
||||
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置插件配置失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/toggle")
|
||||
async def toggle_plugin(
|
||||
plugin_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
切换插件启用状态
|
||||
|
||||
切换插件配置中的 enabled 字段。
|
||||
"""
|
||||
# Token 验证
|
||||
token = get_token_from_cookie_or_header(maibot_session, authorization)
|
||||
token_manager = get_token_manager()
|
||||
if not token or not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||
|
||||
logger.info(f"切换插件状态: {plugin_id}")
|
||||
|
||||
try:
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
if manifest_path.exists():
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
manifest = json.load(f)
|
||||
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||
plugin_path = p
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
import tomlkit
|
||||
|
||||
# 读取当前配置(保留注释和格式)
|
||||
config = tomlkit.document()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 切换 enabled 状态
|
||||
if "plugin" not in config:
|
||||
config["plugin"] = tomlkit.table()
|
||||
|
||||
current_enabled = config["plugin"].get("enabled", True)
|
||||
new_enabled = not current_enabled
|
||||
config["plugin"]["enabled"] = new_enabled
|
||||
|
||||
# 写入配置(保留注释,格式化数组)
|
||||
save_toml_with_format(config, str(config_path))
|
||||
|
||||
status = "启用" if new_enabled else "禁用"
|
||||
logger.info(f"已{status}插件: {plugin_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": new_enabled,
|
||||
"message": f"插件已{status}",
|
||||
"note": "状态更改将在下次加载插件时生效",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"切换插件状态失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||
|
||||
@@ -5,14 +5,15 @@
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.common.logger import get_logger
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
logger = get_logger("webui_system")
|
||||
|
||||
# 记录启动时间
|
||||
_start_time = time.time()
|
||||
@@ -39,22 +40,23 @@ async def restart_maibot():
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
使用 os.execv 重启当前进程,配置更改将在重启后生效。
|
||||
请求重启当前进程,配置更改将在重启后生效。
|
||||
注意:此操作会使麦麦暂时离线。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
||||
try:
|
||||
# 记录重启操作
|
||||
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
||||
logger.info("WebUI 触发重启操作")
|
||||
|
||||
# 定义延迟重启的异步任务
|
||||
async def delayed_restart():
|
||||
await asyncio.sleep(0.5) # 延迟0.5秒,确保响应已发送
|
||||
python = sys.executable
|
||||
args = [python] + sys.argv
|
||||
os.execv(python, args)
|
||||
|
||||
# 使用 os._exit(42) 退出当前进程,配合外部 runner 脚本进行重启
|
||||
# 42 是约定的重启状态码
|
||||
logger.info("WebUI 请求重启,退出代码 42")
|
||||
os._exit(42)
|
||||
|
||||
# 创建后台任务执行重启
|
||||
asyncio.create_task(delayed_restart())
|
||||
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import set_auth_cookie, clear_auth_cookie
|
||||
from .config_routes import router as config_router
|
||||
from .statistics_routes import router as statistics_router
|
||||
from .person_routes import router as person_router
|
||||
from .expression_routes import router as expression_router
|
||||
from .jargon_routes import router as jargon_router
|
||||
from .emoji_routes import router as emoji_router
|
||||
from .plugin_routes import router as plugin_router
|
||||
from .plugin_progress_ws import get_progress_router
|
||||
@@ -28,6 +30,8 @@ router.include_router(statistics_router)
|
||||
router.include_router(person_router)
|
||||
# 注册表达方式管理路由
|
||||
router.include_router(expression_router)
|
||||
# 注册黑话管理路由
|
||||
router.include_router(jargon_router)
|
||||
# 注册表情包管理路由
|
||||
router.include_router(emoji_router)
|
||||
# 注册插件管理路由
|
||||
@@ -51,6 +55,7 @@ class TokenVerifyResponse(BaseModel):
|
||||
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
is_first_setup: bool = Field(False, description="是否为首次设置")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
@@ -102,22 +107,27 @@ async def health_check():
|
||||
|
||||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(request: TokenVerifyRequest):
|
||||
async def verify_token(request: TokenVerifyRequest, response: Response):
|
||||
"""
|
||||
验证访问令牌
|
||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||
|
||||
Args:
|
||||
request: 包含 token 的验证请求
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
验证结果
|
||||
验证结果(包含首次配置状态)
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(request.token)
|
||||
|
||||
if is_valid:
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功")
|
||||
# 设置 HttpOnly Cookie
|
||||
set_auth_cookie(response, request.token)
|
||||
# 同时返回首次配置状态,避免额外请求
|
||||
is_first_setup = token_manager.is_first_setup()
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||
else:
|
||||
return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
|
||||
except Exception as e:
|
||||
@@ -125,24 +135,86 @@ async def verify_token(request: TokenVerifyRequest):
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/logout")
|
||||
async def logout(response: Response):
|
||||
"""
|
||||
登出并清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
登出结果
|
||||
"""
|
||||
clear_auth_cookie(response)
|
||||
return {"success": True, "message": "已成功登出"}
|
||||
|
||||
|
||||
@router.get("/auth/check")
|
||||
async def check_auth_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
检查当前认证状态(用于前端判断是否已登录)
|
||||
|
||||
Returns:
|
||||
认证状态
|
||||
"""
|
||||
try:
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
return {"authenticated": False}
|
||||
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
return {"authenticated": True}
|
||||
else:
|
||||
return {"authenticated": False}
|
||||
except Exception:
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
response: Response,
|
||||
req: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -151,6 +223,10 @@ async def update_token(request: TokenUpdateRequest, authorization: Optional[str]
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
# 如果更新成功,清除 Cookie,要求用户重新登录
|
||||
if success:
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -160,22 +236,34 @@ async def update_token(request: TokenUpdateRequest, authorization: Optional[str]
|
||||
|
||||
|
||||
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
|
||||
async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
async def regenerate_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -184,6 +272,9 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
# 清除 Cookie,要求用户重新登录
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -193,22 +284,32 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
|
||||
async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
async def get_setup_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取首次配置状态
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
首次配置状态
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -226,22 +327,32 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/setup/complete", response_model=CompleteSetupResponse)
|
||||
async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
async def complete_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
标记首次配置完成
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
完成结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
@@ -259,22 +370,32 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
|
||||
|
||||
|
||||
@router.post("/setup/reset", response_model=ResetSetupResponse)
|
||||
async def reset_setup(authorization: Optional[str] = Header(None)):
|
||||
async def reset_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
重置结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
|
||||
@@ -160,13 +160,29 @@ class TokenManager:
|
||||
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token
|
||||
重新生成 token(保留 first_setup_completed 状态)
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
logger.info("正在重新生成 WebUI Token...")
|
||||
return self._create_new_token()
|
||||
|
||||
# 生成新的 64 位十六进制字符串
|
||||
new_token = secrets.token_hex(32)
|
||||
|
||||
# 加载现有配置,保留 first_setup_completed 状态
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8] if config.get("access_token") else "无"
|
||||
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return new_token
|
||||
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,7 @@ import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
from src.common.logger import get_logger
|
||||
@@ -20,25 +21,39 @@ class WebUIServer:
|
||||
self.port = port
|
||||
self.app = FastAPI(title="MaiBot WebUI")
|
||||
self._server = None
|
||||
|
||||
# 配置防爬虫中间件(需要在CORS之前注册)
|
||||
self._setup_anti_crawler()
|
||||
|
||||
|
||||
# 配置 CORS(支持开发环境跨域请求)
|
||||
self._setup_cors()
|
||||
|
||||
# 显示 Access Token
|
||||
self._show_access_token()
|
||||
|
||||
|
||||
# 重要:先注册 API 路由,再设置静态文件
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
|
||||
# 注册robots.txt路由
|
||||
self._setup_robots_txt()
|
||||
|
||||
def _setup_cors(self):
|
||||
"""配置 CORS 中间件"""
|
||||
# 开发环境需要允许前端开发服务器的跨域请求
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:5173", # Vite 开发服务器
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:8001", # 生产环境
|
||||
"http://127.0.0.1:8001",
|
||||
],
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
logger.debug("✅ CORS 中间件已配置")
|
||||
|
||||
def _show_access_token(self):
|
||||
"""显示 WebUI Access Token"""
|
||||
try:
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
@@ -75,7 +90,7 @@ class WebUIServer:
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
@@ -88,62 +103,22 @@ class WebUIServer:
|
||||
|
||||
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||
|
||||
def _setup_anti_crawler(self):
|
||||
"""配置防爬虫中间件"""
|
||||
try:
|
||||
from src.webui.anti_crawler import AntiCrawlerMiddleware
|
||||
|
||||
# 从环境变量读取防爬虫模式(false/strict/loose/standard)
|
||||
anti_crawler_mode = os.getenv("WEBUI_ANTI_CRAWLER_MODE", "standard").lower()
|
||||
|
||||
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
|
||||
# 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行
|
||||
self.app.add_middleware(
|
||||
AntiCrawlerMiddleware,
|
||||
mode=anti_crawler_mode
|
||||
)
|
||||
|
||||
mode_descriptions = {
|
||||
"false": "已禁用",
|
||||
"strict": "严格模式",
|
||||
"loose": "宽松模式",
|
||||
"standard": "标准模式"
|
||||
}
|
||||
mode_desc = mode_descriptions.get(anti_crawler_mode, "标准模式")
|
||||
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
|
||||
|
||||
def _setup_robots_txt(self):
|
||||
"""设置robots.txt路由"""
|
||||
try:
|
||||
from src.webui.anti_crawler import create_robots_txt_response
|
||||
|
||||
@self.app.get("/robots.txt", include_in_schema=False)
|
||||
async def robots_txt():
|
||||
"""返回robots.txt,禁止所有爬虫"""
|
||||
return create_robots_txt_response()
|
||||
|
||||
logger.debug("✅ robots.txt 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
|
||||
|
||||
def _register_api_routes(self):
|
||||
"""注册所有 WebUI API 路由"""
|
||||
try:
|
||||
# 导入所有 WebUI 路由
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
|
||||
logger.info("开始导入 knowledge_routes...")
|
||||
from src.webui.knowledge_routes import router as knowledge_router
|
||||
logger.info("knowledge_routes 导入成功")
|
||||
|
||||
# 导入本地聊天室路由
|
||||
from src.webui.chat_routes import router as chat_router
|
||||
|
||||
# 注册路由
|
||||
self.app.include_router(webui_router)
|
||||
self.app.include_router(logs_router)
|
||||
self.app.include_router(knowledge_router)
|
||||
logger.info(f"knowledge_router 路由前缀: {knowledge_router.prefix}")
|
||||
self.app.include_router(chat_router)
|
||||
|
||||
logger.info("✅ WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
@@ -151,6 +126,16 @@ class WebUIServer:
|
||||
|
||||
async def start(self):
|
||||
"""启动服务器"""
|
||||
# 预先检查端口是否可用
|
||||
if not self._check_port_available():
|
||||
error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
|
||||
logger.error(error_msg)
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以通过环境变量 WEBUI_PORT 修改 WebUI 端口")
|
||||
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
|
||||
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
|
||||
raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
|
||||
|
||||
config = Config(
|
||||
app=self.app,
|
||||
host=self.host,
|
||||
@@ -162,12 +147,36 @@ class WebUIServer:
|
||||
|
||||
logger.info("🌐 WebUI 服务器启动中...")
|
||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||
if self.host == "0.0.0.0":
|
||||
logger.info(f"本机访问请使用 http://localhost:{self.port}")
|
||||
|
||||
try:
|
||||
await self._server.serve()
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}")
|
||||
except OSError as e:
|
||||
# 处理端口绑定相关的错误
|
||||
if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows
|
||||
logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以通过环境变量 WEBUI_PORT 修改 WebUI 端口")
|
||||
else:
|
||||
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _check_port_available(self) -> bool:
|
||||
"""检查端口是否可用"""
|
||||
import socket
|
||||
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
# 尝试绑定端口
|
||||
s.bind((self.host if self.host != "0.0.0.0" else "127.0.0.1", self.port))
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
|
||||
Reference in New Issue
Block a user