Ruff Fix & format
This commit is contained in:
@@ -25,6 +25,7 @@ WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息"""
|
||||
|
||||
id: str
|
||||
type: str # 'user' | 'bot' | 'system'
|
||||
content: str
|
||||
@@ -36,17 +37,17 @@ class ChatHistoryMessage(BaseModel):
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
|
||||
|
||||
|
||||
def __init__(self, max_messages: int = 200):
|
||||
self.max_messages = max_messages
|
||||
|
||||
|
||||
def _message_to_dict(self, msg: Messages) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式"""
|
||||
# 判断是否是机器人消息
|
||||
# WebUI 用户的 user_id 以 "webui_" 开头,其他都是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
@@ -56,7 +57,7 @@ class ChatHistoryManager:
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
|
||||
def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||
"""从数据库获取最近的历史记录"""
|
||||
try:
|
||||
@@ -67,25 +68,21 @@ class ChatHistoryManager:
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
|
||||
# 转换为列表并反转(使最旧的消息在前)
|
||||
result = [self._message_to_dict(msg) for msg in messages]
|
||||
result.reverse()
|
||||
|
||||
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def clear_history(self) -> int:
|
||||
"""清空 WebUI 聊天历史记录"""
|
||||
try:
|
||||
deleted = (
|
||||
Messages.delete()
|
||||
.where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID)
|
||||
.execute()
|
||||
)
|
||||
deleted = Messages.delete().where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID).execute()
|
||||
logger.info(f"已清空 {deleted} 条 WebUI 聊天记录")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
@@ -100,31 +97,31 @@ chat_history = ChatHistoryManager()
|
||||
# 存储 WebSocket 连接
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
|
||||
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str):
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
|
||||
async def send_message(self, session_id: str, message: dict):
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""广播消息给所有连接"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
@@ -135,16 +132,12 @@ chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def create_message_data(
|
||||
content: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True
|
||||
content: str, user_id: str, user_name: str, message_id: Optional[str] = None, is_at_bot: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""创建符合麦麦消息格式的消息数据"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
@@ -163,7 +156,7 @@ def create_message_data(
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
}
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
@@ -175,8 +168,8 @@ def create_message_data(
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
}
|
||||
]
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
@@ -186,10 +179,10 @@ def create_message_data(
|
||||
@router.get("/history")
|
||||
async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None) # 保留参数兼容性,但不用于过滤
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
|
||||
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
|
||||
"""
|
||||
history = chat_history.get_history(limit)
|
||||
@@ -217,76 +210,87 @@ async def websocket_chat(
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户唯一标识(由前端生成并持久化)
|
||||
user_name: 用户显示昵称(可修改)
|
||||
"""
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
# 如果没有提供 user_id,生成一个新的
|
||||
if not user_id:
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
# 确保 user_id 有正确的前缀
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||
|
||||
|
||||
await chat_manager.connect(websocket, session_id, user_id)
|
||||
|
||||
|
||||
try:
|
||||
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||
await chat_manager.send_message(session_id, {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
})
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送历史记录
|
||||
history = chat_history.get_history(50)
|
||||
if history:
|
||||
await chat_manager.send_message(session_id, {
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
})
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送欢迎消息(不保存到历史)
|
||||
await chat_manager.send_message(session_id, {
|
||||
"type": "system",
|
||||
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
|
||||
if data.get("type") == "message":
|
||||
content = data.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
|
||||
# 用户可以更新昵称
|
||||
current_user_name = data.get("user_name", user_name)
|
||||
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
timestamp = time.time()
|
||||
|
||||
|
||||
# 广播用户消息给所有连接(包括发送者)
|
||||
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||
await chat_manager.broadcast({
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": current_user_name,
|
||||
"user_id": user_id,
|
||||
"is_bot": False,
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": current_user_name,
|
||||
"user_id": user_id,
|
||||
"is_bot": False,
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
)
|
||||
|
||||
# 创建麦麦消息格式
|
||||
message_data = create_message_data(
|
||||
content=content,
|
||||
@@ -295,46 +299,59 @@ async def websocket_chat(
|
||||
message_id=message_id,
|
||||
is_at_bot=True,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 显示正在输入状态
|
||||
await chat_manager.broadcast({
|
||||
"type": "typing",
|
||||
"is_typing": True,
|
||||
})
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": True,
|
||||
}
|
||||
)
|
||||
|
||||
# 调用麦麦的消息处理
|
||||
await chat_bot.message_process(message_data)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await chat_manager.send_message(session_id, {
|
||||
"type": "error",
|
||||
"content": f"处理消息时出错: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"处理消息时出错: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
await chat_manager.broadcast({
|
||||
"type": "typing",
|
||||
"is_typing": False,
|
||||
})
|
||||
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif data.get("type") == "ping":
|
||||
await chat_manager.send_message(session_id, {
|
||||
"type": "pong",
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "update_nickname":
|
||||
# 允许用户更新昵称
|
||||
if new_name := data.get("user_name", "").strip():
|
||||
current_user_name = new_name
|
||||
await chat_manager.send_message(session_id, {
|
||||
"type": "nickname_updated",
|
||||
"user_name": current_user_name,
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "nickname_updated",
|
||||
"user_name": current_user_name,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||
except Exception as e:
|
||||
@@ -356,7 +373,7 @@ async def get_chat_info():
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||
|
||||
|
||||
Returns:
|
||||
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from typing import Any
|
||||
from typing import Any, Annotated
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
@@ -41,6 +41,12 @@ from src.webui.config_schema import ConfigSchemaGenerator
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||
SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[dict[str, str], Body()]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@@ -90,7 +96,7 @@ async def get_bot_config_schema():
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
@@ -101,7 +107,7 @@ async def get_model_config_schema():
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 子配置架构获取接口 =====
|
||||
@@ -174,7 +180,7 @@ async def get_config_section_schema(section_name: str):
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置节架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置读取接口 =====
|
||||
@@ -196,7 +202,7 @@ async def get_bot_config():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
@@ -215,21 +221,21 @@ async def get_model_config():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
||||
async def update_bot_config(config_data: ConfigBody):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -242,18 +248,18 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
||||
async def update_model_config(config_data: ConfigBody):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
@@ -266,14 +272,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置节更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -304,7 +310,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(tomlkit.dump 会保留注释)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
@@ -316,7 +322,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 原始 TOML 文件操作接口 =====
|
||||
@@ -338,24 +344,24 @@ async def get_bot_config_raw():
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
config_data = tomlkit.loads(raw_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 验证配置数据结构
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
@@ -368,11 +374,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
||||
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
@@ -403,7 +409,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(tomlkit.dump 会保留注释)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
@@ -415,7 +421,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 适配器配置管理接口 =====
|
||||
@@ -425,11 +431,11 @@ def _normalize_adapter_path(path: str) -> str:
|
||||
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
|
||||
# 如果已经是绝对路径,直接返回
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
|
||||
|
||||
# 相对路径,转换为相对于项目根目录的绝对路径
|
||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||
|
||||
@@ -438,17 +444,17 @@ def _to_relative_path(path: str) -> str:
|
||||
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
||||
if not path or not os.path.isabs(path):
|
||||
return path
|
||||
|
||||
|
||||
try:
|
||||
# 尝试获取相对路径
|
||||
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
||||
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
||||
if not rel_path.startswith('..'):
|
||||
if not rel_path.startswith(".."):
|
||||
return rel_path
|
||||
except (ValueError, TypeError):
|
||||
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
||||
pass
|
||||
|
||||
|
||||
# 无法转换为相对路径,返回绝对路径
|
||||
return path
|
||||
|
||||
@@ -463,6 +469,7 @@ async def get_adapter_config_path():
|
||||
return {"success": True, "path": None}
|
||||
|
||||
import json
|
||||
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
|
||||
@@ -472,10 +479,11 @@ async def get_adapter_config_path():
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(adapter_config_path)
|
||||
|
||||
|
||||
# 检查文件是否存在并返回最后修改时间
|
||||
if os.path.exists(abs_path):
|
||||
import datetime
|
||||
|
||||
mtime = os.path.getmtime(abs_path)
|
||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||
# 返回相对路径(如果可能)
|
||||
@@ -487,11 +495,11 @@ async def get_adapter_config_path():
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
async def save_adapter_config_path(data: PathBody):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -511,10 +519,10 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
||||
save_path = _to_relative_path(abs_path)
|
||||
|
||||
|
||||
# 更新路径
|
||||
webui_data["adapter_config_path"] = save_path
|
||||
|
||||
@@ -530,7 +538,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
@@ -542,7 +550,7 @@ async def get_adapter_config(path: str):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(abs_path):
|
||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||
@@ -562,11 +570,11 @@ async def get_adapter_config(path: str):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
async def save_adapter_config(data: PathBody):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
@@ -579,17 +587,16 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
import tomlkit
|
||||
tomlkit.loads(content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 确保目录存在
|
||||
dir_path = os.path.dirname(abs_path)
|
||||
@@ -607,5 +614,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
|
||||
|
||||
@@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
# 单行文档字符串
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
# 多行文档字符串
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
@@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
description_lines.append(next_line.strip(quote).strip())
|
||||
@@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
|
||||
return FieldType.ARRAY, None, items
|
||||
|
||||
# 处理基本类型
|
||||
if field_type is bool or field_type == bool:
|
||||
if field_type is bool:
|
||||
return FieldType.BOOLEAN, None, None
|
||||
elif field_type is int or field_type == int:
|
||||
elif field_type is int:
|
||||
return FieldType.INTEGER, None, None
|
||||
elif field_type is float or field_type == float:
|
||||
elif field_type is float:
|
||||
return FieldType.NUMBER, None, None
|
||||
elif field_type is str or field_type == str:
|
||||
elif field_type is str:
|
||||
return FieldType.STRING, None, None
|
||||
elif field_type is dict or origin is dict:
|
||||
return FieldType.OBJECT, None, None
|
||||
|
||||
@@ -3,20 +3,25 @@
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Annotated
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Emoji
|
||||
from .token_manager import get_token_manager
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import hashlib
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
logger = get_logger("webui.emoji")
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||
DescriptionForm = Annotated[str, Form(description="表情包描述")]
|
||||
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
|
||||
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||
|
||||
@@ -592,10 +597,10 @@ class EmojiUploadResponse(BaseModel):
|
||||
|
||||
@router.post("/upload", response_model=EmojiUploadResponse)
|
||||
async def upload_emoji(
|
||||
file: UploadFile = File(..., description="表情包图片文件"),
|
||||
description: str = Form("", description="表情包描述"),
|
||||
emotion: str = Form("", description="情感标签,多个用逗号分隔"),
|
||||
is_registered: bool = Form(True, description="是否直接注册"),
|
||||
file: EmojiFile,
|
||||
description: DescriptionForm = "",
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -713,9 +718,9 @@ async def upload_emoji(
|
||||
|
||||
@router.post("/batch/upload")
|
||||
async def batch_upload_emoji(
|
||||
files: List[UploadFile] = File(..., description="多个表情包图片文件"),
|
||||
emotion: str = Form("", description="情感标签,多个用逗号分隔"),
|
||||
is_registered: bool = Form(True, description="是否直接注册"),
|
||||
files: EmojiFiles,
|
||||
emotion: EmotionForm = "",
|
||||
is_registered: IsRegisteredForm = True,
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
@@ -749,11 +754,13 @@ async def batch_upload_emoji(
|
||||
# 验证文件类型
|
||||
if file.content_type not in allowed_types:
|
||||
results["failed"] += 1
|
||||
results["details"].append({
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"不支持的文件类型: {file.content_type}",
|
||||
})
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"不支持的文件类型: {file.content_type}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 读取文件内容
|
||||
@@ -761,11 +768,13 @@ async def batch_upload_emoji(
|
||||
|
||||
if not file_content:
|
||||
results["failed"] += 1
|
||||
results["details"].append({
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "文件内容为空",
|
||||
})
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "文件内容为空",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 验证图片
|
||||
@@ -774,11 +783,13 @@ async def batch_upload_emoji(
|
||||
img_format = img.format.lower() if img.format else "png"
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
results["details"].append({
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"无效的图片: {str(e)}",
|
||||
})
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": f"无效的图片: {str(e)}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 计算哈希
|
||||
@@ -787,11 +798,13 @@ async def batch_upload_emoji(
|
||||
# 检查重复
|
||||
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
|
||||
results["failed"] += 1
|
||||
results["details"].append({
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "已存在相同的表情包",
|
||||
})
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": "已存在相同的表情包",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成文件名并保存
|
||||
@@ -829,19 +842,23 @@ async def batch_upload_emoji(
|
||||
)
|
||||
|
||||
results["uploaded"] += 1
|
||||
results["details"].append({
|
||||
"filename": file.filename,
|
||||
"success": True,
|
||||
"id": emoji.id,
|
||||
})
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": True,
|
||||
"id": emoji.id,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
results["failed"] += 1
|
||||
results["details"].append({
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
})
|
||||
results["details"].append(
|
||||
{
|
||||
"filename": file.filename,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
||||
return results
|
||||
@@ -850,4 +867,4 @@ async def batch_upload_emoji(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量上传表情包失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e
|
||||
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e
|
||||
|
||||
@@ -602,9 +602,9 @@ class GitMirrorService:
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def run_git_clone():
|
||||
def run_git_clone(clone_cmd=cmd):
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
clone_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query
|
||||
from pydantic import BaseModel
|
||||
@@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
id: str
|
||||
type: str # 'entity' or 'paragraph'
|
||||
content: str
|
||||
@@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
|
||||
|
||||
class KnowledgeEdge(BaseModel):
|
||||
"""知识边"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
weight: float
|
||||
@@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""知识图谱"""
|
||||
|
||||
nodes: List[KnowledgeNode]
|
||||
edges: List[KnowledgeEdge]
|
||||
|
||||
|
||||
class KnowledgeStats(BaseModel):
|
||||
"""知识库统计信息"""
|
||||
|
||||
total_nodes: int
|
||||
total_edges: int
|
||||
entity_nodes: int
|
||||
@@ -45,7 +50,7 @@ def _load_kg_manager():
|
||||
"""延迟加载 KGManager"""
|
||||
try:
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
|
||||
|
||||
kg_manager = KGManager()
|
||||
kg_manager.load_from_file()
|
||||
return kg_manager
|
||||
@@ -58,31 +63,26 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
"""将 DiGraph 转换为 JSON 格式"""
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
|
||||
# 转换节点
|
||||
node_list = graph.get_node_list()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 转换边
|
||||
edge_list = graph.get_edge_list()
|
||||
for edge_tuple in edge_list:
|
||||
@@ -91,37 +91,35 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
source, target = edge_tuple[0], edge_tuple[1]
|
||||
# 通过 graph[source, target] 获取边的属性数据
|
||||
edge_data = graph[source, target]
|
||||
|
||||
|
||||
# edge_data 支持 [] 操作符但不支持 .get()
|
||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
||||
|
||||
edges.append(KnowledgeEdge(
|
||||
source=source,
|
||||
target=target,
|
||||
weight=weight,
|
||||
create_time=create_time,
|
||||
update_time=update_time
|
||||
))
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
@router.get("/graph", response_model=KnowledgeGraph)
|
||||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph")
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
|
||||
Args:
|
||||
limit: 返回的最大节点数,默认 100,最大 10000
|
||||
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
|
||||
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
|
||||
"""
|
||||
@@ -130,46 +128,43 @@ async def get_knowledge_graph(
|
||||
if kg_manager is None:
|
||||
logger.warning("KGManager 未初始化,返回空图谱")
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
all_node_list = graph.get_node_list()
|
||||
|
||||
|
||||
# 按类型过滤节点
|
||||
if node_type == "entity":
|
||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent']
|
||||
all_node_list = [
|
||||
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
|
||||
]
|
||||
elif node_type == "paragraph":
|
||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg']
|
||||
|
||||
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
|
||||
|
||||
# 限制节点数量
|
||||
total_nodes = len(all_node_list)
|
||||
if len(all_node_list) > limit:
|
||||
node_list = all_node_list[:limit]
|
||||
else:
|
||||
node_list = all_node_list
|
||||
|
||||
|
||||
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
|
||||
|
||||
|
||||
# 转换节点
|
||||
nodes = []
|
||||
node_ids = set()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type_val,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||
node_ids.add(node_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 只获取涉及当前节点集的边(保证图的完整性)
|
||||
edges = []
|
||||
edge_list = graph.get_edge_list()
|
||||
@@ -179,27 +174,25 @@ async def get_knowledge_graph(
|
||||
# 只包含两端都在当前节点集中的边
|
||||
if source not in node_ids or target not in node_ids:
|
||||
continue
|
||||
|
||||
|
||||
edge_data = graph[source, target]
|
||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
||||
|
||||
edges.append(KnowledgeEdge(
|
||||
source=source,
|
||||
target=target,
|
||||
weight=weight,
|
||||
create_time=create_time,
|
||||
update_time=update_time
|
||||
))
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
|
||||
return graph_data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
@@ -208,71 +201,59 @@ async def get_knowledge_graph(
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats():
|
||||
"""获取知识库统计信息
|
||||
|
||||
|
||||
Returns:
|
||||
KnowledgeStats: 统计信息
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeStats(
|
||||
total_nodes=0,
|
||||
total_edges=0,
|
||||
entity_nodes=0,
|
||||
paragraph_nodes=0,
|
||||
avg_connections=0.0
|
||||
)
|
||||
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
edge_list = graph.get_edge_list()
|
||||
|
||||
|
||||
total_nodes = len(node_list)
|
||||
total_edges = len(edge_list)
|
||||
|
||||
|
||||
# 统计节点类型
|
||||
entity_nodes = 0
|
||||
paragraph_nodes = 0
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type = node_data['type'] if 'type' in node_data else 'ent'
|
||||
if node_type == 'ent':
|
||||
node_type = node_data["type"] if "type" in node_data else "ent"
|
||||
if node_type == "ent":
|
||||
entity_nodes += 1
|
||||
elif node_type == 'pg':
|
||||
elif node_type == "pg":
|
||||
paragraph_nodes += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
# 计算平均连接数
|
||||
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
|
||||
|
||||
|
||||
return KnowledgeStats(
|
||||
total_nodes=total_nodes,
|
||||
total_edges=total_edges,
|
||||
entity_nodes=entity_nodes,
|
||||
paragraph_nodes=paragraph_nodes,
|
||||
avg_connections=round(avg_connections, 2)
|
||||
avg_connections=round(avg_connections, 2),
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
||||
return KnowledgeStats(
|
||||
total_nodes=0,
|
||||
total_edges=0,
|
||||
entity_nodes=0,
|
||||
paragraph_nodes=0,
|
||||
avg_connections=0.0
|
||||
)
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
"""搜索知识节点
|
||||
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
|
||||
Returns:
|
||||
List[KnowledgeNode]: 匹配的节点列表
|
||||
"""
|
||||
@@ -280,33 +261,28 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return []
|
||||
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
|
||||
|
||||
# 在节点内容中搜索
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
content = node_data['content'] if 'content' in node_data else node_id
|
||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
||||
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
if query_lower in content.lower() or query_lower in node_id.lower():
|
||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
||||
results.append(KnowledgeNode(
|
||||
id=node_id,
|
||||
type=node_type,
|
||||
content=content,
|
||||
create_time=create_time
|
||||
))
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
|
||||
return results[:50] # 限制返回数量
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索节点失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
@@ -43,25 +43,27 @@ def _normalize_url(url: str) -> str:
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 OpenAI 格式的模型列表响应
|
||||
|
||||
|
||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
if isinstance(model, dict) and "id" in model:
|
||||
models.append({
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
})
|
||||
models.append(
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 Gemini 格式的模型列表响应
|
||||
|
||||
|
||||
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
|
||||
"""
|
||||
models = []
|
||||
@@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
model_id = model["name"]
|
||||
if model_id.startswith("models/"):
|
||||
model_id = model_id[7:] # 去掉 "models/" 前缀
|
||||
models.append({
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
})
|
||||
models.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
@@ -89,55 +93,54 @@ async def _fetch_models_from_provider(
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
|
||||
|
||||
Args:
|
||||
base_url: 提供商的基础 URL
|
||||
api_key: API 密钥
|
||||
endpoint: 获取模型列表的端点
|
||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||
client_type: 客户端类型 ('openai' | 'gemini')
|
||||
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
url = f"{_normalize_url(base_url)}{endpoint}"
|
||||
|
||||
|
||||
# 根据客户端类型设置请求头
|
||||
headers = {}
|
||||
params = {}
|
||||
|
||||
|
||||
if client_type == "gemini":
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.TimeoutException:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试")
|
||||
except httpx.TimeoutException as e:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
||||
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
||||
if e.response.status_code == 401:
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期")
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
|
||||
elif e.response.status_code == 403:
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限")
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
|
||||
elif e.response.status_code == 404:
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表")
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
)
|
||||
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
||||
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
|
||||
|
||||
# 根据解析器类型解析响应
|
||||
if parser == "openai":
|
||||
return _parse_openai_response(data)
|
||||
@@ -150,26 +153,26 @@ async def _fetch_models_from_provider(
|
||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
"""
|
||||
从 model_config.toml 获取指定提供商的配置
|
||||
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称
|
||||
|
||||
|
||||
Returns:
|
||||
提供商配置,如果未找到则返回 None
|
||||
"""
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
|
||||
providers = config_data.get("api_providers", [])
|
||||
for provider in providers:
|
||||
if provider.get("name") == provider_name:
|
||||
return dict(provider)
|
||||
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"读取提供商配置失败: {e}")
|
||||
@@ -184,23 +187,23 @@ async def get_provider_models(
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
|
||||
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||
"""
|
||||
# 获取提供商配置
|
||||
provider_config = _get_provider_config(provider_name)
|
||||
if not provider_config:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
|
||||
base_url = provider_config.get("base_url")
|
||||
api_key = provider_config.get("api_key")
|
||||
client_type = provider_config.get("client_type", "openai")
|
||||
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||
|
||||
|
||||
# 获取模型列表
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
@@ -209,7 +212,7 @@ async def get_provider_models(
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
@@ -236,7 +239,7 @@ async def get_models_by_url(
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
@@ -251,11 +254,11 @@ async def test_provider_connection(
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
|
||||
|
||||
分两步测试:
|
||||
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
|
||||
2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
|
||||
|
||||
|
||||
返回:
|
||||
- network_ok: 网络是否连通
|
||||
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
|
||||
@@ -263,11 +266,11 @@ async def test_provider_connection(
|
||||
- error: 错误信息(如果有)
|
||||
"""
|
||||
import time
|
||||
|
||||
|
||||
base_url = _normalize_url(base_url)
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||
|
||||
|
||||
result = {
|
||||
"network_ok": False,
|
||||
"api_key_valid": None,
|
||||
@@ -275,7 +278,7 @@ async def test_provider_connection(
|
||||
"error": None,
|
||||
"http_status": None,
|
||||
}
|
||||
|
||||
|
||||
# 第一步:测试网络连通性
|
||||
try:
|
||||
start_time = time.time()
|
||||
@@ -283,11 +286,11 @@ async def test_provider_connection(
|
||||
# 尝试 GET 请求 base_url(不需要 API Key)
|
||||
response = await client.get(base_url)
|
||||
latency = (time.time() - start_time) * 1000
|
||||
|
||||
|
||||
result["network_ok"] = True
|
||||
result["latency_ms"] = round(latency, 2)
|
||||
result["http_status"] = response.status_code
|
||||
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
|
||||
return result
|
||||
@@ -300,7 +303,7 @@ async def test_provider_connection(
|
||||
except Exception as e:
|
||||
result["error"] = f"未知错误:{str(e)}"
|
||||
return result
|
||||
|
||||
|
||||
# 第二步:如果提供了 API Key,验证其有效性
|
||||
if api_key:
|
||||
try:
|
||||
@@ -313,7 +316,7 @@ async def test_provider_connection(
|
||||
# 尝试获取模型列表
|
||||
models_url = f"{base_url}/models"
|
||||
response = await client.get(models_url, headers=headers)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
result["api_key_valid"] = True
|
||||
elif response.status_code in (401, 403):
|
||||
@@ -322,12 +325,12 @@ async def test_provider_connection(
|
||||
else:
|
||||
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
|
||||
result["api_key_valid"] = None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# API Key 验证失败不影响网络连通性结果
|
||||
logger.warning(f"API Key 验证失败: {e}")
|
||||
result["api_key_valid"] = None
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -342,10 +345,10 @@ async def test_provider_connection_by_name(
|
||||
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(model_config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
|
||||
# 查找提供商
|
||||
providers = config.get("api_providers", [])
|
||||
provider = None
|
||||
@@ -353,15 +356,15 @@ async def test_provider_connection_by_name(
|
||||
if p.get("name") == provider_name:
|
||||
provider = p
|
||||
break
|
||||
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
|
||||
base_url = provider.get("base_url", "")
|
||||
api_key = provider.get("api_key", "")
|
||||
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||
|
||||
@@ -31,8 +31,9 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
"""
|
||||
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
|
||||
import re
|
||||
|
||||
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
|
||||
base_version = re.split(r'[-.](?:snapshot|dev|alpha|beta|rc)', version_str, flags=re.IGNORECASE)[0]
|
||||
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||
|
||||
parts = base_version.split(".")
|
||||
if len(parts) < 3:
|
||||
@@ -613,7 +614,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
|
||||
for field in required_fields:
|
||||
if field not in manifest:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
|
||||
|
||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||
manifest["id"] = request.plugin_id
|
||||
@@ -705,7 +706,7 @@ async def uninstall_plugin(
|
||||
plugin_path = plugins_dir / folder_name
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
if old_format_path.exists():
|
||||
@@ -839,7 +840,7 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
|
||||
plugin_path = plugins_dir / folder_name
|
||||
# 旧格式:点
|
||||
old_format_path = plugins_dir / request.plugin_id
|
||||
|
||||
|
||||
# 优先使用新格式,如果不存在则尝试旧格式
|
||||
if not plugin_path.exists():
|
||||
if old_format_path.exists():
|
||||
@@ -1092,21 +1093,21 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
# 尝试从 author.name 和 repository_url 构建标准 ID
|
||||
author_name = None
|
||||
repo_name = None
|
||||
|
||||
|
||||
# 获取作者名
|
||||
if "author" in manifest:
|
||||
if isinstance(manifest["author"], dict) and "name" in manifest["author"]:
|
||||
author_name = manifest["author"]["name"]
|
||||
elif isinstance(manifest["author"], str):
|
||||
author_name = manifest["author"]
|
||||
|
||||
|
||||
# 从 repository_url 获取仓库名
|
||||
if "repository_url" in manifest:
|
||||
repo_url = manifest["repository_url"].rstrip("/")
|
||||
if repo_url.endswith(".git"):
|
||||
repo_url = repo_url[:-4]
|
||||
repo_name = repo_url.split("/")[-1]
|
||||
|
||||
|
||||
# 构建 ID
|
||||
if author_name and repo_name:
|
||||
# 标准格式: Author.RepoName
|
||||
@@ -1122,7 +1123,7 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
||||
else:
|
||||
# 直接使用文件夹名
|
||||
plugin_id = folder_name
|
||||
|
||||
|
||||
# 将推断的 ID 写入 manifest(方便下次识别)
|
||||
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
|
||||
manifest["id"] = plugin_id
|
||||
@@ -1167,12 +1168,10 @@ class UpdatePluginConfigRequest(BaseModel):
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}/schema")
|
||||
async def get_plugin_config_schema(
|
||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件配置 Schema
|
||||
|
||||
|
||||
返回插件的完整配置 schema,包含所有 section、字段定义和布局信息。
|
||||
用于前端动态生成配置表单。
|
||||
"""
|
||||
@@ -1187,10 +1186,10 @@ async def get_plugin_config_schema(
|
||||
try:
|
||||
# 尝试从已加载的插件中获取
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
|
||||
# 查找插件实例
|
||||
plugin_instance = None
|
||||
|
||||
|
||||
# 遍历所有已加载的插件
|
||||
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
||||
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
||||
@@ -1204,17 +1203,17 @@ async def get_plugin_config_schema(
|
||||
if manifest_id == plugin_id:
|
||||
plugin_instance = instance
|
||||
break
|
||||
|
||||
if plugin_instance and hasattr(plugin_instance, 'get_webui_config_schema'):
|
||||
|
||||
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||
# 从插件实例获取 schema
|
||||
schema = plugin_instance.get_webui_config_schema()
|
||||
return {"success": True, "schema": schema}
|
||||
|
||||
|
||||
# 如果插件未加载,尝试从文件系统读取
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
@@ -1227,18 +1226,19 @@ async def get_plugin_config_schema(
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
|
||||
# 读取配置文件获取当前配置
|
||||
config_path = plugin_path / "config.toml"
|
||||
current_config = {}
|
||||
if config_path.exists():
|
||||
import tomlkit
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
current_config = tomlkit.load(f)
|
||||
|
||||
|
||||
# 构建基础 schema(无法获取完整的 ConfigField 信息)
|
||||
schema = {
|
||||
"plugin_id": plugin_id,
|
||||
@@ -1252,7 +1252,7 @@ async def get_plugin_config_schema(
|
||||
"layout": {"type": "auto", "tabs": []},
|
||||
"_note": "插件未加载,仅返回当前配置结构",
|
||||
}
|
||||
|
||||
|
||||
# 从当前配置推断 schema
|
||||
for section_name, section_data in current_config.items():
|
||||
if isinstance(section_data, dict):
|
||||
@@ -1277,7 +1277,7 @@ async def get_plugin_config_schema(
|
||||
ui_type = "list"
|
||||
elif isinstance(field_value, dict):
|
||||
ui_type = "json"
|
||||
|
||||
|
||||
schema["sections"][section_name]["fields"][field_name] = {
|
||||
"name": field_name,
|
||||
"type": field_type,
|
||||
@@ -1290,7 +1290,7 @@ async def get_plugin_config_schema(
|
||||
"disabled": False,
|
||||
"order": 0,
|
||||
}
|
||||
|
||||
|
||||
return {"success": True, "schema": schema}
|
||||
|
||||
except HTTPException:
|
||||
@@ -1301,12 +1301,10 @@ async def get_plugin_config_schema(
|
||||
|
||||
|
||||
@router.get("/config/{plugin_id}")
|
||||
async def get_plugin_config(
|
||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
获取插件当前配置值
|
||||
|
||||
|
||||
返回插件的当前配置值。
|
||||
"""
|
||||
# Token 验证
|
||||
@@ -1321,7 +1319,7 @@ async def get_plugin_config(
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
@@ -1334,19 +1332,20 @@ async def get_plugin_config(
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
|
||||
# 读取配置文件
|
||||
config_path = plugin_path / "config.toml"
|
||||
if not config_path.exists():
|
||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||
|
||||
|
||||
import tomlkit
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
|
||||
return {"success": True, "config": dict(config)}
|
||||
|
||||
except HTTPException:
|
||||
@@ -1358,13 +1357,11 @@ async def get_plugin_config(
|
||||
|
||||
@router.put("/config/{plugin_id}")
|
||||
async def update_plugin_config(
|
||||
plugin_id: str,
|
||||
request: UpdatePluginConfigRequest,
|
||||
authorization: Optional[str] = Header(None)
|
||||
plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新插件配置
|
||||
|
||||
|
||||
保存新的配置值到插件的配置文件。
|
||||
"""
|
||||
# Token 验证
|
||||
@@ -1379,7 +1376,7 @@ async def update_plugin_config(
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
@@ -1392,23 +1389,25 @@ async def update_plugin_config(
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
|
||||
# 备份旧配置
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
if config_path.exists():
|
||||
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = plugin_path / backup_name
|
||||
shutil.copy(config_path, backup_path)
|
||||
logger.info(f"已备份配置文件: {backup_path}")
|
||||
|
||||
|
||||
# 写入新配置(使用 tomlkit 保留注释)
|
||||
import tomlkit
|
||||
|
||||
# 先读取原配置以保留注释和格式
|
||||
existing_doc = tomlkit.document()
|
||||
if config_path.exists():
|
||||
@@ -1419,14 +1418,10 @@ async def update_plugin_config(
|
||||
existing_doc[key] = value
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(existing_doc, f)
|
||||
|
||||
|
||||
logger.info(f"已更新插件配置: {plugin_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已保存",
|
||||
"note": "配置更改将在插件重新加载后生效"
|
||||
}
|
||||
|
||||
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -1436,12 +1431,10 @@ async def update_plugin_config(
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/reset")
|
||||
async def reset_plugin_config(
|
||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
重置插件配置为默认值
|
||||
|
||||
|
||||
删除当前配置文件,下次加载插件时将使用默认配置。
|
||||
"""
|
||||
# Token 验证
|
||||
@@ -1456,7 +1449,7 @@ async def reset_plugin_config(
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
@@ -1469,29 +1462,26 @@ async def reset_plugin_config(
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
|
||||
if not config_path.exists():
|
||||
return {"success": True, "message": "配置文件不存在,无需重置"}
|
||||
|
||||
|
||||
# 备份并删除
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||
backup_path = plugin_path / backup_name
|
||||
shutil.move(config_path, backup_path)
|
||||
|
||||
|
||||
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "配置已重置,下次加载插件时将使用默认配置",
|
||||
"backup": str(backup_path)
|
||||
}
|
||||
|
||||
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -1501,12 +1491,10 @@ async def reset_plugin_config(
|
||||
|
||||
|
||||
@router.post("/config/{plugin_id}/toggle")
|
||||
async def toggle_plugin(
|
||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
||||
) -> Dict[str, Any]:
|
||||
async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||
"""
|
||||
切换插件启用状态
|
||||
|
||||
|
||||
切换插件配置中的 enabled 字段。
|
||||
"""
|
||||
# Token 验证
|
||||
@@ -1521,7 +1509,7 @@ async def toggle_plugin(
|
||||
# 查找插件目录
|
||||
plugins_dir = Path("plugins")
|
||||
plugin_path = None
|
||||
|
||||
|
||||
for p in plugins_dir.iterdir():
|
||||
if p.is_dir():
|
||||
manifest_path = p / "_manifest.json"
|
||||
@@ -1534,40 +1522,40 @@ async def toggle_plugin(
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
if not plugin_path:
|
||||
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||
|
||||
|
||||
config_path = plugin_path / "config.toml"
|
||||
|
||||
|
||||
import tomlkit
|
||||
|
||||
|
||||
# 读取当前配置(保留注释和格式)
|
||||
config = tomlkit.document()
|
||||
if config_path.exists():
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
|
||||
# 切换 enabled 状态
|
||||
if "plugin" not in config:
|
||||
config["plugin"] = tomlkit.table()
|
||||
|
||||
|
||||
current_enabled = config["plugin"].get("enabled", True)
|
||||
new_enabled = not current_enabled
|
||||
config["plugin"]["enabled"] = new_enabled
|
||||
|
||||
|
||||
# 写入配置(保留注释)
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
tomlkit.dump(config, f)
|
||||
|
||||
|
||||
status = "启用" if new_enabled else "禁用"
|
||||
logger.info(f"已{status}插件: {plugin_id}")
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"enabled": new_enabled,
|
||||
"message": f"插件已{status}",
|
||||
"note": "状态更改将在下次加载插件时生效"
|
||||
"note": "状态更改将在下次加载插件时生效",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@@ -43,7 +43,7 @@ async def restart_maibot():
|
||||
注意:此操作会使麦麦暂时离线。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
||||
try:
|
||||
# 记录重启操作
|
||||
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
||||
@@ -54,7 +54,7 @@ async def restart_maibot():
|
||||
python = sys.executable
|
||||
args = [python] + sys.argv
|
||||
os.execv(python, args)
|
||||
|
||||
|
||||
# 创建后台任务执行重启
|
||||
asyncio.create_task(delayed_restart())
|
||||
|
||||
|
||||
@@ -20,10 +20,10 @@ class WebUIServer:
|
||||
self.port = port
|
||||
self.app = FastAPI(title="MaiBot WebUI")
|
||||
self._server = None
|
||||
|
||||
|
||||
# 显示 Access Token
|
||||
self._show_access_token()
|
||||
|
||||
|
||||
# 重要:先注册 API 路由,再设置静态文件
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
@@ -32,7 +32,7 @@ class WebUIServer:
|
||||
"""显示 WebUI Access Token"""
|
||||
try:
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
@@ -69,7 +69,7 @@ class WebUIServer:
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
@@ -88,13 +88,15 @@ class WebUIServer:
|
||||
# 导入所有 WebUI 路由
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
|
||||
|
||||
logger.info("开始导入 knowledge_routes...")
|
||||
from src.webui.knowledge_routes import router as knowledge_router
|
||||
|
||||
logger.info("knowledge_routes 导入成功")
|
||||
|
||||
|
||||
# 导入本地聊天室路由
|
||||
from src.webui.chat_routes import router as chat_router
|
||||
|
||||
logger.info("chat_routes 导入成功")
|
||||
|
||||
# 注册路由
|
||||
|
||||
Reference in New Issue
Block a user