Merge branch 'dev' of https://github.com/Mai-with-u/MaiBot into dev
This commit is contained in:
7
bot.py
7
bot.py
@@ -78,6 +78,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
|||||||
# 关闭 WebUI 服务器
|
# 关闭 WebUI 服务器
|
||||||
try:
|
try:
|
||||||
from src.webui.webui_server import get_webui_server
|
from src.webui.webui_server import get_webui_server
|
||||||
|
|
||||||
webui_server = get_webui_server()
|
webui_server = get_webui_server()
|
||||||
if webui_server and webui_server._server:
|
if webui_server and webui_server._server:
|
||||||
await webui_server.shutdown()
|
await webui_server.shutdown()
|
||||||
@@ -236,15 +237,15 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.warning("收到中断信号,正在优雅关闭...")
|
logger.warning("收到中断信号,正在优雅关闭...")
|
||||||
|
|
||||||
# 取消主任务
|
# 取消主任务
|
||||||
if 'main_tasks' in locals() and main_tasks and not main_tasks.done():
|
if "main_tasks" in locals() and main_tasks and not main_tasks.done():
|
||||||
main_tasks.cancel()
|
main_tasks.cancel()
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(main_tasks)
|
loop.run_until_complete(main_tasks)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 执行优雅关闭
|
# 执行优雅关闭
|
||||||
if loop and not loop.is_closed():
|
if loop and not loop.is_closed():
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ pyarrow>=20.0.0
|
|||||||
pydantic>=2.11.7
|
pydantic>=2.11.7
|
||||||
pypinyin>=0.54.0
|
pypinyin>=0.54.0
|
||||||
python-dotenv>=1.1.1
|
python-dotenv>=1.1.1
|
||||||
|
python-multipart>=0.0.20
|
||||||
quick-algo>=0.1.3
|
quick-algo>=0.1.3
|
||||||
rich>=14.0.0
|
rich>=14.0.0
|
||||||
ruff>=0.12.2
|
ruff>=0.12.2
|
||||||
|
|||||||
@@ -235,13 +235,13 @@ class BrainChatting:
|
|||||||
if recent_messages_list is None:
|
if recent_messages_list is None:
|
||||||
recent_messages_list = []
|
recent_messages_list = []
|
||||||
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
_reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# ReflectTracker Check
|
# ReflectTracker Check
|
||||||
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
from src.express.reflect_tracker import reflect_tracker_manager
|
from src.express.reflect_tracker import reflect_tracker_manager
|
||||||
|
|
||||||
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
||||||
if tracker:
|
if tracker:
|
||||||
resolved = await tracker.trigger_tracker()
|
resolved = await tracker.trigger_tracker()
|
||||||
@@ -254,6 +254,7 @@ class BrainChatting:
|
|||||||
# 检查是否需要提问表达反思
|
# 检查是否需要提问表达反思
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
from src.express.expression_reflector import expression_reflector_manager
|
from src.express.expression_reflector import expression_reflector_manager
|
||||||
|
|
||||||
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
||||||
asyncio.create_task(reflector.check_and_ask())
|
asyncio.create_task(reflector.check_and_ask())
|
||||||
|
|
||||||
|
|||||||
@@ -400,7 +400,7 @@ class HeartFChatting:
|
|||||||
# ReflectTracker Check
|
# ReflectTracker Check
|
||||||
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
||||||
await reflector.check_and_ask()
|
await reflector.check_and_ask()
|
||||||
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
|
||||||
@@ -410,7 +410,6 @@ class HeartFChatting:
|
|||||||
reflect_tracker_manager.remove_tracker(self.stream_id)
|
reflect_tracker_manager.remove_tracker(self.stream_id)
|
||||||
logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.")
|
logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.")
|
||||||
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||||
@@ -427,7 +426,9 @@ class HeartFChatting:
|
|||||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||||
|
|
||||||
cycle_timers, thinking_id = self.start_cycle()
|
cycle_timers, thinking_id = self.start_cycle()
|
||||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
|
||||||
|
)
|
||||||
|
|
||||||
# 第一步:动作检查
|
# 第一步:动作检查
|
||||||
available_actions: Dict[str, ActionInfo] = {}
|
available_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|||||||
@@ -39,6 +39,11 @@ class HeartFCMessageReceiver:
|
|||||||
message_data: 原始消息字符串
|
message_data: 原始消息字符串
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# 通知消息不处理
|
||||||
|
if message.is_notify:
|
||||||
|
logger.debug("通知消息,跳过处理")
|
||||||
|
return
|
||||||
|
|
||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
chat = message.chat_stream
|
chat = message.chat_stream
|
||||||
|
|||||||
@@ -33,6 +33,11 @@ class MessageStorage:
|
|||||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
try:
|
try:
|
||||||
|
# 通知消息不存储
|
||||||
|
if isinstance(message, MessageRecv) and message.is_notify:
|
||||||
|
logger.debug("通知消息,跳过存储")
|
||||||
|
return
|
||||||
|
|
||||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||||
|
|
||||||
# print(message)
|
# print(message)
|
||||||
|
|||||||
@@ -15,12 +15,57 @@ install(extra_lines=3)
|
|||||||
|
|
||||||
logger = get_logger("sender")
|
logger = get_logger("sender")
|
||||||
|
|
||||||
|
# WebUI 聊天室的消息广播器(延迟导入避免循环依赖)
|
||||||
|
_webui_chat_broadcaster = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_webui_chat_broadcaster():
|
||||||
|
"""获取 WebUI 聊天室广播器"""
|
||||||
|
global _webui_chat_broadcaster
|
||||||
|
if _webui_chat_broadcaster is None:
|
||||||
|
try:
|
||||||
|
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
|
||||||
|
|
||||||
|
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||||
|
except ImportError:
|
||||||
|
_webui_chat_broadcaster = (None, None)
|
||||||
|
return _webui_chat_broadcaster
|
||||||
|
|
||||||
|
|
||||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||||
|
platform = message.message_info.platform
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 检查是否是 WebUI 平台的消息
|
||||||
|
chat_manager, webui_platform = get_webui_chat_broadcaster()
|
||||||
|
if platform == webui_platform and chat_manager is not None:
|
||||||
|
# WebUI 聊天室消息,通过 WebSocket 广播
|
||||||
|
import time
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
|
"type": "bot_message",
|
||||||
|
"content": message.processed_plain_text,
|
||||||
|
"message_type": "text",
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"sender": {
|
||||||
|
"name": global_config.bot.nickname,
|
||||||
|
"avatar": None,
|
||||||
|
"is_bot": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||||
|
# 无需手动保存
|
||||||
|
|
||||||
|
if show_log:
|
||||||
|
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
|
||||||
|
return True
|
||||||
|
|
||||||
# 直接调用API发送消息
|
# 直接调用API发送消息
|
||||||
await get_global_api().send_message(message)
|
await get_global_api().send_message(message)
|
||||||
if show_log:
|
if show_log:
|
||||||
|
|||||||
@@ -181,8 +181,12 @@ class ActionPlanner:
|
|||||||
found_ids = set(matches)
|
found_ids = set(matches)
|
||||||
missing_ids = found_ids - available_ids
|
missing_ids = found_ids - available_ids
|
||||||
if missing_ids:
|
if missing_ids:
|
||||||
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
|
logger.info(
|
||||||
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中")
|
f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中"
|
||||||
|
)
|
||||||
|
|
||||||
def _replace(match: re.Match[str]) -> str:
|
def _replace(match: re.Match[str]) -> str:
|
||||||
msg_id = match.group(0)
|
msg_id = match.group(0)
|
||||||
@@ -234,17 +238,11 @@ class ActionPlanner:
|
|||||||
target_message = message_id_list[-1][1]
|
target_message = message_id_list[-1][1]
|
||||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||||
|
|
||||||
if (
|
if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message):
|
||||||
action != "no_reply"
|
|
||||||
and target_message is not None
|
|
||||||
and self._is_message_from_self(target_message)
|
|
||||||
):
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
|
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
|
||||||
)
|
)
|
||||||
reasoning = (
|
reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
|
||||||
f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
|
|
||||||
)
|
|
||||||
action = "no_reply"
|
action = "no_reply"
|
||||||
target_message = None
|
target_message = None
|
||||||
|
|
||||||
@@ -295,10 +293,9 @@ class ActionPlanner:
|
|||||||
def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
|
def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
|
||||||
"""判断消息是否由机器人自身发送"""
|
"""判断消息是否由机器人自身发送"""
|
||||||
try:
|
try:
|
||||||
return (
|
return str(message.user_info.user_id) == str(global_config.bot.qq_account) and (
|
||||||
str(message.user_info.user_id) == str(global_config.bot.qq_account)
|
message.user_info.platform or ""
|
||||||
and (message.user_info.platform or "") == (global_config.bot.platform or "")
|
) == (global_config.bot.platform or "")
|
||||||
)
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
|
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
|
||||||
return False
|
return False
|
||||||
@@ -780,20 +777,20 @@ class ActionPlanner:
|
|||||||
json_content_start = json_start_pos + 7 # ```json的长度
|
json_content_start = json_start_pos + 7 # ```json的长度
|
||||||
# 提取从```json之后到内容结尾的所有内容
|
# 提取从```json之后到内容结尾的所有内容
|
||||||
incomplete_json_str = content[json_content_start:].strip()
|
incomplete_json_str = content[json_content_start:].strip()
|
||||||
|
|
||||||
# 提取JSON之前的内容作为推理文本
|
# 提取JSON之前的内容作为推理文本
|
||||||
if json_start_pos > 0:
|
if json_start_pos > 0:
|
||||||
reasoning_content = content[:json_start_pos].strip()
|
reasoning_content = content[:json_start_pos].strip()
|
||||||
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
|
||||||
reasoning_content = reasoning_content.strip()
|
reasoning_content = reasoning_content.strip()
|
||||||
|
|
||||||
if incomplete_json_str:
|
if incomplete_json_str:
|
||||||
try:
|
try:
|
||||||
# 清理可能的注释和格式问题
|
# 清理可能的注释和格式问题
|
||||||
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
|
||||||
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
|
||||||
json_str = json_str.strip()
|
json_str = json_str.strip()
|
||||||
|
|
||||||
if json_str:
|
if json_str:
|
||||||
# 尝试按行分割,每行可能是一个JSON对象
|
# 尝试按行分割,每行可能是一个JSON对象
|
||||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||||
@@ -808,7 +805,7 @@ class ActionPlanner:
|
|||||||
json_objects.append(item)
|
json_objects.append(item)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
# 如果按行解析没有成功,尝试将整个块作为一个JSON对象或数组
|
||||||
if not json_objects:
|
if not json_objects:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -959,7 +959,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: b
|
|||||||
header = f"[{i + 1}] {anon_name}说 "
|
header = f"[{i + 1}] {anon_name}说 "
|
||||||
else:
|
else:
|
||||||
header = f"{anon_name}说 "
|
header = f"{anon_name}说 "
|
||||||
|
|
||||||
output_lines.append(header)
|
output_lines.append(header)
|
||||||
stripped_line = content.strip()
|
stripped_line = content.strip()
|
||||||
if stripped_line:
|
if stripped_line:
|
||||||
|
|||||||
67
src/common/toml_utils.py
Normal file
67
src/common/toml_utils.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
TOML 工具函数
|
||||||
|
|
||||||
|
提供 TOML 文件的格式化保存功能,确保数组等元素以美观的多行格式输出。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
import tomlkit
|
||||||
|
from tomlkit.items import AoT, Table, Array
|
||||||
|
|
||||||
|
|
||||||
|
def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
|
||||||
|
"""递归格式化 TOML 值,将数组转换为多行格式"""
|
||||||
|
# 处理 AoT (Array of Tables) - 保持原样,递归处理内部
|
||||||
|
if isinstance(obj, AoT):
|
||||||
|
for item in obj:
|
||||||
|
_format_toml_value(item, threshold, depth)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
# 处理字典类型 (dict 或 Table)
|
||||||
|
if isinstance(obj, (dict, Table)):
|
||||||
|
for k, v in obj.items():
|
||||||
|
obj[k] = _format_toml_value(v, threshold, depth)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
# 处理列表类型 (list 或 Array)
|
||||||
|
if isinstance(obj, (list, Array)):
|
||||||
|
# 如果是纯 list (非 tomlkit Array) 且包含字典/表,视为 AoT 的列表形式
|
||||||
|
# 保持结构递归处理,避免转换为 Inline Table Array (因为 Inline Table 必须单行,复杂对象不友好)
|
||||||
|
if isinstance(obj, list) and not isinstance(obj, Array) and obj and isinstance(obj[0], (dict, Table)):
|
||||||
|
for i, item in enumerate(obj):
|
||||||
|
obj[i] = _format_toml_value(item, threshold, depth)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
# 决定是否多行:仅在顶层且长度超过阈值时
|
||||||
|
should_multiline = (depth == 0 and len(obj) > threshold)
|
||||||
|
|
||||||
|
# 如果已经是 tomlkit Array,原地修改以保留注释
|
||||||
|
if isinstance(obj, Array):
|
||||||
|
obj.multiline(should_multiline)
|
||||||
|
for i, item in enumerate(obj):
|
||||||
|
obj[i] = _format_toml_value(item, threshold, depth + 1)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
# 普通 list:转换为 tomlkit 数组
|
||||||
|
arr = tomlkit.array()
|
||||||
|
arr.multiline(should_multiline)
|
||||||
|
|
||||||
|
for item in obj:
|
||||||
|
arr.append(_format_toml_value(item, threshold, depth + 1))
|
||||||
|
return arr
|
||||||
|
|
||||||
|
# 其他基本类型直接返回
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None:
|
||||||
|
"""格式化 TOML 数据并保存到文件"""
|
||||||
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
|
tomlkit.dump(formatted, f)
|
||||||
|
|
||||||
|
|
||||||
|
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
|
||||||
|
"""格式化 TOML 数据并返回字符串"""
|
||||||
|
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
|
||||||
|
return tomlkit.dumps(formatted)
|
||||||
@@ -11,6 +11,7 @@ from rich.traceback import install
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.toml_utils import format_toml_string
|
||||||
from src.config.config_base import ConfigBase
|
from src.config.config_base import ConfigBase
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
BotConfig,
|
BotConfig,
|
||||||
@@ -252,7 +253,7 @@ def _update_config_generic(config_name: str, template_name: str):
|
|||||||
# 如果配置有更新,立即保存到文件
|
# 如果配置有更新,立即保存到文件
|
||||||
if config_updated:
|
if config_updated:
|
||||||
with open(old_config_path, "w", encoding="utf-8") as f:
|
with open(old_config_path, "w", encoding="utf-8") as f:
|
||||||
f.write(tomlkit.dumps(old_config))
|
f.write(format_toml_string(old_config))
|
||||||
logger.info(f"已保存更新后的{config_name}配置文件")
|
logger.info(f"已保存更新后的{config_name}配置文件")
|
||||||
else:
|
else:
|
||||||
logger.info(f"未检测到{config_name}模板默认值变动")
|
logger.info(f"未检测到{config_name}模板默认值变动")
|
||||||
@@ -313,9 +314,9 @@ def _update_config_generic(config_name: str, template_name: str):
|
|||||||
logger.info(f"开始合并{config_name}新旧配置...")
|
logger.info(f"开始合并{config_name}新旧配置...")
|
||||||
_update_dict(new_config, old_config)
|
_update_dict(new_config, old_config)
|
||||||
|
|
||||||
# 保存更新后的配置(保留注释和格式)
|
# 保存更新后的配置(保留注释和格式,数组多行格式化)
|
||||||
with open(new_config_path, "w", encoding="utf-8") as f:
|
with open(new_config_path, "w", encoding="utf-8") as f:
|
||||||
f.write(tomlkit.dumps(new_config))
|
f.write(format_toml_string(new_config))
|
||||||
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ def _compute_weights(population: List[Dict]) -> List[float]:
|
|||||||
|
|
||||||
# 如果checked,权重乘以3
|
# 如果checked,权重乘以3
|
||||||
weights = []
|
weights = []
|
||||||
for base_weight, checked in zip(base_weights, checked_flags):
|
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
|
||||||
if checked:
|
if checked:
|
||||||
weights.append(base_weight * 3.0)
|
weights.append(base_weight * 3.0)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -226,19 +226,19 @@ class ExpressionLearner:
|
|||||||
match_responses = []
|
match_responses = []
|
||||||
try:
|
try:
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
|
|
||||||
# 尝试提取JSON代码块(如果存在)
|
# 尝试提取JSON代码块(如果存在)
|
||||||
json_pattern = r"```json\s*(.*?)\s*```"
|
json_pattern = r"```json\s*(.*?)\s*```"
|
||||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||||
if matches:
|
if matches:
|
||||||
response = matches[0].strip()
|
response = matches[0].strip()
|
||||||
|
|
||||||
# 移除可能的markdown代码块标记(如果没有找到```json,但可能有```)
|
# 移除可能的markdown代码块标记(如果没有找到```json,但可能有```)
|
||||||
if not matches:
|
if not matches:
|
||||||
response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
|
response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
|
||||||
response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
|
response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
|
|
||||||
# 检查是否已经是标准JSON数组格式
|
# 检查是否已经是标准JSON数组格式
|
||||||
if response.startswith("[") and response.endswith("]"):
|
if response.startswith("[") and response.endswith("]"):
|
||||||
match_responses = json.loads(response)
|
match_responses = json.loads(response)
|
||||||
|
|||||||
@@ -13,28 +13,28 @@ logger = get_logger("expression_reflector")
|
|||||||
|
|
||||||
class ExpressionReflector:
|
class ExpressionReflector:
|
||||||
"""表达反思器,管理单个聊天流的表达反思提问"""
|
"""表达反思器,管理单个聊天流的表达反思提问"""
|
||||||
|
|
||||||
def __init__(self, chat_id: str):
|
def __init__(self, chat_id: str):
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.last_ask_time: float = 0.0
|
self.last_ask_time: float = 0.0
|
||||||
|
|
||||||
async def check_and_ask(self) -> bool:
|
async def check_and_ask(self) -> bool:
|
||||||
"""
|
"""
|
||||||
检查是否需要提问表达反思,如果需要则提问
|
检查是否需要提问表达反思,如果需要则提问
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否执行了提问
|
bool: 是否执行了提问
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
|
logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
|
||||||
|
|
||||||
if not global_config.expression.reflect:
|
if not global_config.expression.reflect:
|
||||||
logger.debug(f"[Expression Reflection] 表达反思功能未启用,跳过")
|
logger.debug("[Expression Reflection] 表达反思功能未启用,跳过")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
operator_config = global_config.expression.reflect_operator_id
|
operator_config = global_config.expression.reflect_operator_id
|
||||||
if not operator_config:
|
if not operator_config:
|
||||||
logger.debug(f"[Expression Reflection] Operator ID 未配置,跳过")
|
logger.debug("[Expression Reflection] Operator ID 未配置,跳过")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查是否在允许列表中
|
# 检查是否在允许列表中
|
||||||
@@ -48,7 +48,7 @@ class ExpressionReflector:
|
|||||||
allow_reflect_chat_ids.append(parsed_chat_id)
|
allow_reflect_chat_ids.append(parsed_chat_id)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}")
|
logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}")
|
||||||
|
|
||||||
if self.chat_id not in allow_reflect_chat_ids:
|
if self.chat_id not in allow_reflect_chat_ids:
|
||||||
logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过")
|
logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过")
|
||||||
return False
|
return False
|
||||||
@@ -56,17 +56,21 @@ class ExpressionReflector:
|
|||||||
# 检查上一次提问时间
|
# 检查上一次提问时间
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
time_since_last_ask = current_time - self.last_ask_time
|
time_since_last_ask = current_time - self.last_ask_time
|
||||||
|
|
||||||
# 5-10分钟间隔,随机选择
|
# 5-10分钟间隔,随机选择
|
||||||
min_interval = 10 * 60 # 5分钟
|
min_interval = 10 * 60 # 5分钟
|
||||||
max_interval = 15 * 60 # 10分钟
|
max_interval = 15 * 60 # 10分钟
|
||||||
interval = random.uniform(min_interval, max_interval)
|
interval = random.uniform(min_interval, max_interval)
|
||||||
|
|
||||||
logger.info(f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask/60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval/60:.2f}分钟)")
|
logger.info(
|
||||||
|
f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)"
|
||||||
|
)
|
||||||
|
|
||||||
if time_since_last_ask < interval:
|
if time_since_last_ask < interval:
|
||||||
remaining_time = interval - time_since_last_ask
|
remaining_time = interval - time_since_last_ask
|
||||||
logger.info(f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time/60:.2f}分钟),跳过")
|
logger.info(
|
||||||
|
f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查是否已经有针对该 Operator 的 Tracker 在运行
|
# 检查是否已经有针对该 Operator 的 Tracker 在运行
|
||||||
@@ -77,56 +81,59 @@ class ExpressionReflector:
|
|||||||
|
|
||||||
# 获取未检查的表达
|
# 获取未检查的表达
|
||||||
try:
|
try:
|
||||||
logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达")
|
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||||
expressions = (Expression
|
expressions = (
|
||||||
.select()
|
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
|
||||||
.where((Expression.checked == False) & (Expression.rejected == False))
|
)
|
||||||
.limit(50))
|
|
||||||
|
|
||||||
expr_list = list(expressions)
|
expr_list = list(expressions)
|
||||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||||
|
|
||||||
if not expr_list:
|
if not expr_list:
|
||||||
logger.info(f"[Expression Reflection] 没有可用的表达,跳过")
|
logger.info("[Expression Reflection] 没有可用的表达,跳过")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
target_expr: Expression = random.choice(expr_list)
|
target_expr: Expression = random.choice(expr_list)
|
||||||
logger.info(f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}")
|
logger.info(
|
||||||
|
f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}"
|
||||||
|
)
|
||||||
|
|
||||||
# 生成询问文本
|
# 生成询问文本
|
||||||
ask_text = _generate_ask_text(target_expr)
|
ask_text = _generate_ask_text(target_expr)
|
||||||
if not ask_text:
|
if not ask_text:
|
||||||
logger.warning(f"[Expression Reflection] 生成询问文本失败,跳过")
|
logger.warning("[Expression Reflection] 生成询问文本失败,跳过")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问")
|
logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问")
|
||||||
# 发送给 Operator
|
# 发送给 Operator
|
||||||
await _send_to_operator(operator_config, ask_text, target_expr)
|
await _send_to_operator(operator_config, ask_text, target_expr)
|
||||||
|
|
||||||
# 更新上一次提问时间
|
# 更新上一次提问时间
|
||||||
self.last_ask_time = current_time
|
self.last_ask_time = current_time
|
||||||
logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}")
|
logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ExpressionReflectorManager:
|
class ExpressionReflectorManager:
|
||||||
"""表达反思管理器,管理多个聊天流的表达反思实例"""
|
"""表达反思管理器,管理多个聊天流的表达反思实例"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reflectors: Dict[str, ExpressionReflector] = {}
|
self.reflectors: Dict[str, ExpressionReflector] = {}
|
||||||
|
|
||||||
def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector:
|
def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector:
|
||||||
"""获取或创建指定聊天流的表达反思实例"""
|
"""获取或创建指定聊天流的表达反思实例"""
|
||||||
if chat_id not in self.reflectors:
|
if chat_id not in self.reflectors:
|
||||||
@@ -141,6 +148,7 @@ expression_reflector_manager = ExpressionReflectorManager()
|
|||||||
async def _check_tracker_exists(operator_config: str) -> bool:
|
async def _check_tracker_exists(operator_config: str) -> bool:
|
||||||
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
||||||
from src.express.reflect_tracker import reflect_tracker_manager
|
from src.express.reflect_tracker import reflect_tracker_manager
|
||||||
|
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = None
|
chat_stream = None
|
||||||
|
|
||||||
@@ -150,12 +158,12 @@ async def _check_tracker_exists(operator_config: str) -> bool:
|
|||||||
platform = parts[0]
|
platform = parts[0]
|
||||||
id_str = parts[1]
|
id_str = parts[1]
|
||||||
stream_type = parts[2]
|
stream_type = parts[2]
|
||||||
|
|
||||||
user_info = None
|
user_info = None
|
||||||
group_info = None
|
group_info = None
|
||||||
|
|
||||||
from maim_message import UserInfo, GroupInfo
|
from maim_message import UserInfo, GroupInfo
|
||||||
|
|
||||||
if stream_type == "group":
|
if stream_type == "group":
|
||||||
group_info = GroupInfo(group_id=id_str, platform=platform)
|
group_info = GroupInfo(group_id=id_str, platform=platform)
|
||||||
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
||||||
@@ -203,12 +211,12 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
|||||||
platform = parts[0]
|
platform = parts[0]
|
||||||
id_str = parts[1]
|
id_str = parts[1]
|
||||||
stream_type = parts[2]
|
stream_type = parts[2]
|
||||||
|
|
||||||
user_info = None
|
user_info = None
|
||||||
group_info = None
|
group_info = None
|
||||||
|
|
||||||
from maim_message import UserInfo, GroupInfo
|
from maim_message import UserInfo, GroupInfo
|
||||||
|
|
||||||
if stream_type == "group":
|
if stream_type == "group":
|
||||||
group_info = GroupInfo(group_id=id_str, platform=platform)
|
group_info = GroupInfo(group_id=id_str, platform=platform)
|
||||||
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
|
||||||
@@ -232,20 +240,13 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
|||||||
return
|
return
|
||||||
|
|
||||||
stream_id = chat_stream.stream_id
|
stream_id = chat_stream.stream_id
|
||||||
|
|
||||||
# 注册 Tracker
|
# 注册 Tracker
|
||||||
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
|
||||||
|
|
||||||
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
|
||||||
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
||||||
|
|
||||||
# 发送消息
|
# 发送消息
|
||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(text=text, stream_id=stream_id, typing=True)
|
||||||
text=text,
|
|
||||||
stream_id=stream_id,
|
|
||||||
typing=True
|
|
||||||
)
|
|
||||||
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")
|
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class ExpressionSelector:
|
|||||||
|
|
||||||
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
# 优化:一次性查询所有相关chat_id的表达方式,排除 rejected=1 的表达
|
||||||
style_query = Expression.select().where(
|
style_query = Expression.select().where(
|
||||||
(Expression.chat_id.in_(related_chat_ids)) & (Expression.rejected == False)
|
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
|
||||||
)
|
)
|
||||||
|
|
||||||
style_exprs = [
|
style_exprs = [
|
||||||
|
|||||||
@@ -4,34 +4,32 @@ from src.common.logger import get_logger
|
|||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
get_raw_msg_by_timestamp_with_chat,
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
)
|
)
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
pass
|
||||||
|
|
||||||
logger = get_logger("reflect_tracker")
|
logger = get_logger("reflect_tracker")
|
||||||
|
|
||||||
|
|
||||||
class ReflectTracker:
|
class ReflectTracker:
|
||||||
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
|
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
self.expression = expression
|
self.expression = expression
|
||||||
self.created_time = created_time
|
self.created_time = created_time
|
||||||
# self.message_count = 0 # Replaced by checking message list length
|
# self.message_count = 0 # Replaced by checking message list length
|
||||||
self.last_check_msg_count = 0
|
self.last_check_msg_count = 0
|
||||||
self.max_message_count = 30
|
self.max_message_count = 30
|
||||||
self.max_duration = 15 * 60 # 15 minutes
|
self.max_duration = 15 * 60 # 15 minutes
|
||||||
|
|
||||||
# LLM for judging response
|
# LLM for judging response
|
||||||
self.judge_model = LLMRequest(
|
self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
|
||||||
model_set=model_config.model_task_config.utils, request_type="reflect.tracker"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._init_prompts()
|
self._init_prompts()
|
||||||
|
|
||||||
def _init_prompts(self):
|
def _init_prompts(self):
|
||||||
@@ -72,16 +70,16 @@ class ReflectTracker:
|
|||||||
if time.time() - self.created_time > self.max_duration:
|
if time.time() - self.created_time > self.max_duration:
|
||||||
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (duration).")
|
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (duration).")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Fetch messages since creation
|
# Fetch messages since creation
|
||||||
msg_list = get_raw_msg_by_timestamp_with_chat(
|
msg_list = get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id=self.chat_stream.stream_id,
|
chat_id=self.chat_stream.stream_id,
|
||||||
timestamp_start=self.created_time,
|
timestamp_start=self.created_time,
|
||||||
timestamp_end=time.time(),
|
timestamp_end=time.time(),
|
||||||
)
|
)
|
||||||
|
|
||||||
current_msg_count = len(msg_list)
|
current_msg_count = len(msg_list)
|
||||||
|
|
||||||
# Check message limit
|
# Check message limit
|
||||||
if current_msg_count > self.max_message_count:
|
if current_msg_count > self.max_message_count:
|
||||||
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (message count).")
|
logger.info(f"ReflectTracker for expr {self.expression.id} timed out (message count).")
|
||||||
@@ -90,9 +88,9 @@ class ReflectTracker:
|
|||||||
# If no new messages since last check, skip
|
# If no new messages since last check, skip
|
||||||
if current_msg_count <= self.last_check_msg_count:
|
if current_msg_count <= self.last_check_msg_count:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
self.last_check_msg_count = current_msg_count
|
self.last_check_msg_count = current_msg_count
|
||||||
|
|
||||||
# Build context block
|
# Build context block
|
||||||
# Use simple readable format
|
# Use simple readable format
|
||||||
context_block = build_readable_messages(
|
context_block = build_readable_messages(
|
||||||
@@ -109,78 +107,83 @@ class ReflectTracker:
|
|||||||
"reflect_judge_prompt",
|
"reflect_judge_prompt",
|
||||||
situation=self.expression.situation,
|
situation=self.expression.situation,
|
||||||
style=self.expression.style,
|
style=self.expression.style,
|
||||||
context_block=context_block
|
context_block=context_block,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"ReflectTracker LLM Prompt: {prompt}")
|
logger.info(f"ReflectTracker LLM Prompt: {prompt}")
|
||||||
|
|
||||||
response, _ = await self.judge_model.generate_response_async(prompt, temperature=0.1)
|
response, _ = await self.judge_model.generate_response_async(prompt, temperature=0.1)
|
||||||
|
|
||||||
logger.info(f"ReflectTracker LLM Response: {response}")
|
logger.info(f"ReflectTracker LLM Response: {response}")
|
||||||
|
|
||||||
# Parse JSON
|
# Parse JSON
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
json_pattern = r"```json\s*(.*?)\s*```"
|
json_pattern = r"```json\s*(.*?)\s*```"
|
||||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||||
if not matches:
|
if not matches:
|
||||||
# Try to parse raw response if no code block
|
# Try to parse raw response if no code block
|
||||||
matches = [response]
|
matches = [response]
|
||||||
|
|
||||||
json_obj = json.loads(repair_json(matches[0]))
|
json_obj = json.loads(repair_json(matches[0]))
|
||||||
|
|
||||||
judgment = json_obj.get("judgment")
|
judgment = json_obj.get("judgment")
|
||||||
|
|
||||||
if judgment == "Approve":
|
if judgment == "Approve":
|
||||||
self.expression.checked = True
|
self.expression.checked = True
|
||||||
self.expression.rejected = False
|
self.expression.rejected = False
|
||||||
self.expression.save()
|
self.expression.save()
|
||||||
logger.info(f"Expression {self.expression.id} approved by operator.")
|
logger.info(f"Expression {self.expression.id} approved by operator.")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
elif judgment == "Reject":
|
elif judgment == "Reject":
|
||||||
self.expression.checked = True
|
self.expression.checked = True
|
||||||
corrected_situation = json_obj.get("corrected_situation")
|
corrected_situation = json_obj.get("corrected_situation")
|
||||||
corrected_style = json_obj.get("corrected_style")
|
corrected_style = json_obj.get("corrected_style")
|
||||||
|
|
||||||
# 检查是否有更新
|
# 检查是否有更新
|
||||||
has_update = bool(corrected_situation or corrected_style)
|
has_update = bool(corrected_situation or corrected_style)
|
||||||
|
|
||||||
if corrected_situation:
|
if corrected_situation:
|
||||||
self.expression.situation = corrected_situation
|
self.expression.situation = corrected_situation
|
||||||
if corrected_style:
|
if corrected_style:
|
||||||
self.expression.style = corrected_style
|
self.expression.style = corrected_style
|
||||||
|
|
||||||
# 如果拒绝但未更新,标记为 rejected=1
|
# 如果拒绝但未更新,标记为 rejected=1
|
||||||
if not has_update:
|
if not has_update:
|
||||||
self.expression.rejected = True
|
self.expression.rejected = True
|
||||||
else:
|
else:
|
||||||
self.expression.rejected = False
|
self.expression.rejected = False
|
||||||
|
|
||||||
self.expression.save()
|
self.expression.save()
|
||||||
|
|
||||||
if has_update:
|
if has_update:
|
||||||
logger.info(f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}")
|
logger.info(
|
||||||
|
f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1.")
|
logger.info(
|
||||||
|
f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1."
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
elif judgment == "Ignore":
|
elif judgment == "Ignore":
|
||||||
logger.info(f"ReflectTracker for expr {self.expression.id} judged as Ignore.")
|
logger.info(f"ReflectTracker for expr {self.expression.id} judged as Ignore.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in ReflectTracker check: {e}")
|
logger.error(f"Error in ReflectTracker check: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
# Global manager for trackers
|
# Global manager for trackers
|
||||||
class ReflectTrackerManager:
|
class ReflectTrackerManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.trackers: Dict[str, ReflectTracker] = {} # chat_id -> tracker
|
self.trackers: Dict[str, ReflectTracker] = {} # chat_id -> tracker
|
||||||
|
|
||||||
def add_tracker(self, chat_id: str, tracker: ReflectTracker):
|
def add_tracker(self, chat_id: str, tracker: ReflectTracker):
|
||||||
self.trackers[chat_id] = tracker
|
self.trackers[chat_id] = tracker
|
||||||
@@ -192,5 +195,5 @@ class ReflectTrackerManager:
|
|||||||
if chat_id in self.trackers:
|
if chat_id in self.trackers:
|
||||||
del self.trackers[chat_id]
|
del self.trackers[chat_id]
|
||||||
|
|
||||||
reflect_tracker_manager = ReflectTrackerManager()
|
|
||||||
|
|
||||||
|
reflect_tracker_manager = ReflectTrackerManager()
|
||||||
|
|||||||
@@ -44,9 +44,7 @@ class JargonExplainer:
|
|||||||
request_type="jargon.explain",
|
request_type="jargon.explain",
|
||||||
)
|
)
|
||||||
|
|
||||||
def match_jargon_from_messages(
|
def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
|
||||||
self, messages: List[Any]
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
"""
|
"""
|
||||||
通过直接匹配数据库中的jargon字符串来提取黑话
|
通过直接匹配数据库中的jargon字符串来提取黑话
|
||||||
|
|
||||||
@@ -57,7 +55,7 @@ class JargonExplainer:
|
|||||||
List[Dict[str, str]]: 提取到的黑话列表,每个元素包含content
|
List[Dict[str, str]]: 提取到的黑话列表,每个元素包含content
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
if not messages:
|
if not messages:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -67,8 +65,10 @@ class JargonExplainer:
|
|||||||
# 跳过机器人自己的消息
|
# 跳过机器人自己的消息
|
||||||
if is_bot_message(msg):
|
if is_bot_message(msg):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
|
msg_text = (
|
||||||
|
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||||
|
).strip()
|
||||||
if msg_text:
|
if msg_text:
|
||||||
message_texts.append(msg_text)
|
message_texts.append(msg_text)
|
||||||
|
|
||||||
@@ -79,9 +79,7 @@ class JargonExplainer:
|
|||||||
combined_text = " ".join(message_texts)
|
combined_text = " ".join(message_texts)
|
||||||
|
|
||||||
# 查询所有有meaning的jargon记录
|
# 查询所有有meaning的jargon记录
|
||||||
query = Jargon.select().where(
|
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 根据all_global配置决定查询逻辑
|
# 根据all_global配置决定查询逻辑
|
||||||
if global_config.jargon.all_global:
|
if global_config.jargon.all_global:
|
||||||
@@ -98,7 +96,7 @@ class JargonExplainer:
|
|||||||
# 执行查询并匹配
|
# 执行查询并匹配
|
||||||
matched_jargon: Dict[str, Dict[str, str]] = {}
|
matched_jargon: Dict[str, Dict[str, str]] = {}
|
||||||
query_time = time.time()
|
query_time = time.time()
|
||||||
|
|
||||||
for jargon in query:
|
for jargon in query:
|
||||||
content = jargon.content or ""
|
content = jargon.content or ""
|
||||||
if not content or not content.strip():
|
if not content or not content.strip():
|
||||||
@@ -123,13 +121,13 @@ class JargonExplainer:
|
|||||||
pattern = re.escape(content)
|
pattern = re.escape(content)
|
||||||
# 使用单词边界或中文字符边界来匹配,避免部分匹配
|
# 使用单词边界或中文字符边界来匹配,避免部分匹配
|
||||||
# 对于中文,使用Unicode字符类;对于英文,使用单词边界
|
# 对于中文,使用Unicode字符类;对于英文,使用单词边界
|
||||||
if re.search(r'[\u4e00-\u9fff]', content):
|
if re.search(r"[\u4e00-\u9fff]", content):
|
||||||
# 包含中文,使用更宽松的匹配
|
# 包含中文,使用更宽松的匹配
|
||||||
search_pattern = pattern
|
search_pattern = pattern
|
||||||
else:
|
else:
|
||||||
# 纯英文/数字,使用单词边界
|
# 纯英文/数字,使用单词边界
|
||||||
search_pattern = r'\b' + pattern + r'\b'
|
search_pattern = r"\b" + pattern + r"\b"
|
||||||
|
|
||||||
if re.search(search_pattern, combined_text, re.IGNORECASE):
|
if re.search(search_pattern, combined_text, re.IGNORECASE):
|
||||||
# 找到匹配,记录(去重)
|
# 找到匹配,记录(去重)
|
||||||
if content not in matched_jargon:
|
if content not in matched_jargon:
|
||||||
@@ -139,7 +137,7 @@ class JargonExplainer:
|
|||||||
total_time = match_time - start_time
|
total_time = match_time - start_time
|
||||||
query_duration = query_time - start_time
|
query_duration = query_time - start_time
|
||||||
match_duration = match_time - query_time
|
match_duration = match_time - query_time
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, "
|
f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, "
|
||||||
f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话"
|
f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话"
|
||||||
@@ -147,9 +145,7 @@ class JargonExplainer:
|
|||||||
|
|
||||||
return list(matched_jargon.values())
|
return list(matched_jargon.values())
|
||||||
|
|
||||||
async def explain_jargon(
|
async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||||
self, messages: List[Any], chat_context: str
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
"""
|
||||||
解释上下文中的黑话
|
解释上下文中的黑话
|
||||||
|
|
||||||
@@ -183,7 +179,7 @@ class JargonExplainer:
|
|||||||
jargon_explanations: List[str] = []
|
jargon_explanations: List[str] = []
|
||||||
for entry in jargon_list:
|
for entry in jargon_list:
|
||||||
content = entry["content"]
|
content = entry["content"]
|
||||||
|
|
||||||
# 根据是否开启全局黑话,决定查询方式
|
# 根据是否开启全局黑话,决定查询方式
|
||||||
if global_config.jargon.all_global:
|
if global_config.jargon.all_global:
|
||||||
# 开启全局黑话:查询所有is_global=True的记录
|
# 开启全局黑话:查询所有is_global=True的记录
|
||||||
@@ -239,9 +235,7 @@ class JargonExplainer:
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
async def explain_jargon_in_context(
|
async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||||
chat_id: str, messages: List[Any], chat_context: str
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
"""
|
||||||
解释上下文中的黑话(便捷函数)
|
解释上下文中的黑话(便捷函数)
|
||||||
|
|
||||||
@@ -255,4 +249,3 @@ async def explain_jargon_in_context(
|
|||||||
"""
|
"""
|
||||||
explainer = JargonExplainer(chat_id)
|
explainer = JargonExplainer(chat_id)
|
||||||
return await explainer.explain_jargon(messages, chat_context)
|
return await explainer.explain_jargon(messages, chat_context)
|
||||||
|
|
||||||
|
|||||||
@@ -17,20 +17,18 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
)
|
)
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.jargon.jargon_utils import (
|
from src.jargon.jargon_utils import (
|
||||||
is_bot_message,
|
is_bot_message,
|
||||||
build_context_paragraph,
|
build_context_paragraph,
|
||||||
contains_bot_self_name,
|
contains_bot_self_name,
|
||||||
parse_chat_id_list,
|
parse_chat_id_list,
|
||||||
chat_id_list_contains,
|
chat_id_list_contains,
|
||||||
update_chat_id_list
|
update_chat_id_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("jargon")
|
logger = get_logger("jargon")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _init_prompt() -> None:
|
def _init_prompt() -> None:
|
||||||
prompt_str = """
|
prompt_str = """
|
||||||
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
||||||
@@ -126,7 +124,6 @@ _init_prompt()
|
|||||||
_init_inference_prompts()
|
_init_inference_prompts()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||||
"""
|
"""
|
||||||
判断是否需要进行含义推断
|
判断是否需要进行含义推断
|
||||||
@@ -211,7 +208,9 @@ class JargonMiner:
|
|||||||
processed_pairs = set()
|
processed_pairs = set()
|
||||||
|
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
|
msg_text = (
|
||||||
|
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||||
|
).strip()
|
||||||
if not msg_text or is_bot_message(msg):
|
if not msg_text or is_bot_message(msg):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -270,7 +269,7 @@ class JargonMiner:
|
|||||||
prompt1 = await global_prompt_manager.format_prompt(
|
prompt1 = await global_prompt_manager.format_prompt(
|
||||||
"jargon_inference_with_context_prompt",
|
"jargon_inference_with_context_prompt",
|
||||||
content=content,
|
content=content,
|
||||||
bot_name = global_config.bot.nickname,
|
bot_name=global_config.bot.nickname,
|
||||||
raw_content_list=raw_content_text,
|
raw_content_list=raw_content_text,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -588,7 +587,6 @@ class JargonMiner:
|
|||||||
content = entry["content"]
|
content = entry["content"]
|
||||||
raw_content_list = entry["raw_content"] # 已经是列表
|
raw_content_list = entry["raw_content"] # 已经是列表
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 查询所有content匹配的记录
|
# 查询所有content匹配的记录
|
||||||
query = Jargon.select().where(Jargon.content == content)
|
query = Jargon.select().where(Jargon.content == content)
|
||||||
@@ -782,15 +780,15 @@ def search_jargon(
|
|||||||
# 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含
|
# 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含
|
||||||
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):
|
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 只返回有meaning的记录
|
# 只返回有meaning的记录
|
||||||
if not jargon.meaning or jargon.meaning.strip() == "":
|
if not jargon.meaning or jargon.meaning.strip() == "":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
|
||||||
|
|
||||||
# 达到限制数量后停止
|
# 达到限制数量后停止
|
||||||
if len(results) >= limit:
|
if len(results) >= limit:
|
||||||
break
|
break
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
@@ -2,30 +2,29 @@ import json
|
|||||||
from typing import List, Dict, Optional, Any
|
from typing import List, Dict, Optional, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Jargon
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
build_readable_messages_with_id,
|
|
||||||
)
|
)
|
||||||
from src.chat.utils.utils import parse_platform_accounts
|
from src.chat.utils.utils import parse_platform_accounts
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("jargon")
|
logger = get_logger("jargon")
|
||||||
|
|
||||||
|
|
||||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||||
"""
|
"""
|
||||||
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
|
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
|
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
|
||||||
"""
|
"""
|
||||||
if not chat_id_value:
|
if not chat_id_value:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 如果是字符串,尝试解析为JSON
|
# 如果是字符串,尝试解析为JSON
|
||||||
if isinstance(chat_id_value, str):
|
if isinstance(chat_id_value, str):
|
||||||
# 尝试解析JSON
|
# 尝试解析JSON
|
||||||
@@ -54,12 +53,12 @@ def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
|||||||
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
|
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
|
||||||
"""
|
"""
|
||||||
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
|
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...]
|
chat_id_list: 当前的chat_id列表,格式为 [[chat_id, count], ...]
|
||||||
target_chat_id: 要更新或添加的chat_id
|
target_chat_id: 要更新或添加的chat_id
|
||||||
increment: 增加的计数,默认为1
|
increment: 增加的计数,默认为1
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[Any]]: 更新后的chat_id列表
|
List[List[Any]]: 更新后的chat_id列表
|
||||||
"""
|
"""
|
||||||
@@ -74,22 +73,22 @@ def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, incr
|
|||||||
item.append(increment)
|
item.append(increment)
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
|
|
||||||
if not found:
|
if not found:
|
||||||
# 未找到,添加新条目
|
# 未找到,添加新条目
|
||||||
chat_id_list.append([target_chat_id, increment])
|
chat_id_list.append([target_chat_id, increment])
|
||||||
|
|
||||||
return chat_id_list
|
return chat_id_list
|
||||||
|
|
||||||
|
|
||||||
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
|
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
检查chat_id列表中是否包含指定的chat_id
|
检查chat_id列表中是否包含指定的chat_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
chat_id_list: chat_id列表,格式为 [[chat_id, count], ...]
|
||||||
target_chat_id: 要查找的chat_id
|
target_chat_id: 要查找的chat_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 如果包含则返回True
|
bool: 如果包含则返回True
|
||||||
"""
|
"""
|
||||||
@@ -168,10 +167,7 @@ def is_bot_message(msg: Any) -> bool:
|
|||||||
.strip()
|
.strip()
|
||||||
.lower()
|
.lower()
|
||||||
)
|
)
|
||||||
user_id = (
|
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||||
str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not platform or not user_id:
|
if not platform or not user_id:
|
||||||
return False
|
return False
|
||||||
@@ -196,4 +192,4 @@ def is_bot_message(msg: Any) -> bool:
|
|||||||
bot_accounts[plat] = account
|
bot_accounts[plat] = account
|
||||||
|
|
||||||
bot_account = bot_accounts.get(platform)
|
bot_account = bot_accounts.get(platform)
|
||||||
return bool(bot_account and user_id == bot_account)
|
return bool(bot_account and user_id == bot_account)
|
||||||
|
|||||||
@@ -338,8 +338,10 @@ class LLMRequest:
|
|||||||
if e.__cause__:
|
if e.__cause__:
|
||||||
original_error_type = type(e.__cause__).__name__
|
original_error_type = type(e.__cause__).__name__
|
||||||
original_error_msg = str(e.__cause__)
|
original_error_msg = str(e.__cause__)
|
||||||
original_error_info = f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
original_error_info = (
|
||||||
|
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
retry_remain -= 1
|
retry_remain -= 1
|
||||||
if retry_remain <= 0:
|
if retry_remain <= 0:
|
||||||
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。{original_error_info}")
|
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。{original_error_info}")
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class MainSystem:
|
|||||||
from src.webui.webui_server import get_webui_server
|
from src.webui.webui_server import get_webui_server
|
||||||
|
|
||||||
self.webui_server = get_webui_server()
|
self.webui_server = get_webui_server()
|
||||||
|
|
||||||
if webui_mode == "development":
|
if webui_mode == "development":
|
||||||
logger.info("📝 WebUI 开发模式已启用")
|
logger.info("📝 WebUI 开发模式已启用")
|
||||||
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
|
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
|
||||||
@@ -64,9 +64,9 @@ class MainSystem:
|
|||||||
logger.info("💡 前端将运行在 http://localhost:7999")
|
logger.info("💡 前端将运行在 http://localhost:7999")
|
||||||
else:
|
else:
|
||||||
logger.info("✅ WebUI 生产模式已启用")
|
logger.info("✅ WebUI 生产模式已启用")
|
||||||
logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001")
|
logger.info("🌐 WebUI 将运行在 http://0.0.0.0:8001")
|
||||||
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
|
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
|
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -296,7 +296,6 @@ def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
|||||||
if not content:
|
if not content:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
if not global_config.jargon.all_global and not jargon.is_global:
|
if not global_config.jargon.all_global and not jargon.is_global:
|
||||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||||
if not chat_id_list_contains(chat_id_list, chat_id):
|
if not chat_id_list_contains(chat_id_list, chat_id):
|
||||||
@@ -586,9 +585,7 @@ async def _react_agent_solve_question(
|
|||||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
||||||
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}")
|
||||||
f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}"
|
|
||||||
)
|
|
||||||
return True, found_answer_content, thinking_steps, False
|
return True, found_answer_content, thinking_steps, False
|
||||||
|
|
||||||
if not_enough_info_reason:
|
if not_enough_info_reason:
|
||||||
@@ -1016,9 +1013,7 @@ async def build_memory_retrieval_prompt(
|
|||||||
|
|
||||||
if question_results:
|
if question_results:
|
||||||
retrieved_memory = "\n\n".join(question_results)
|
retrieved_memory = "\n\n".join(question_results)
|
||||||
logger.info(
|
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆")
|
||||||
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆"
|
|
||||||
)
|
|
||||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||||
else:
|
else:
|
||||||
logger.debug("所有问题均未找到答案")
|
logger.debug("所有问题均未找到答案")
|
||||||
|
|||||||
@@ -54,7 +54,9 @@ async def search_chat_history(
|
|||||||
if record.participants:
|
if record.participants:
|
||||||
try:
|
try:
|
||||||
participants_data = (
|
participants_data = (
|
||||||
json.loads(record.participants) if isinstance(record.participants, str) else record.participants
|
json.loads(record.participants)
|
||||||
|
if isinstance(record.participants, str)
|
||||||
|
else record.participants
|
||||||
)
|
)
|
||||||
if isinstance(participants_data, list):
|
if isinstance(participants_data, list):
|
||||||
participants_list = [str(p).lower() for p in participants_data]
|
participants_list = [str(p).lower() for p in participants_data]
|
||||||
@@ -156,9 +158,7 @@ async def search_chat_history(
|
|||||||
# 添加关键词
|
# 添加关键词
|
||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
|
||||||
)
|
|
||||||
if isinstance(keywords_data, list) and keywords_data:
|
if isinstance(keywords_data, list) and keywords_data:
|
||||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||||
result_parts.append(f"关键词:{keywords_str}")
|
result_parts.append(f"关键词:{keywords_str}")
|
||||||
@@ -208,9 +208,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
|
|||||||
return "未提供有效的记忆ID"
|
return "未提供有效的记忆ID"
|
||||||
|
|
||||||
# 查询记录
|
# 查询记录
|
||||||
query = ChatHistory.select().where(
|
query = ChatHistory.select().where((ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list)))
|
||||||
(ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list))
|
|
||||||
)
|
|
||||||
records = list(query.order_by(ChatHistory.start_time.desc()))
|
records = list(query.order_by(ChatHistory.start_time.desc()))
|
||||||
|
|
||||||
if not records:
|
if not records:
|
||||||
@@ -256,9 +254,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
|
|||||||
# 添加关键词
|
# 添加关键词
|
||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
|
||||||
)
|
|
||||||
if isinstance(keywords_data, list) and keywords_data:
|
if isinstance(keywords_data, list) and keywords_data:
|
||||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||||
result_parts.append(f"关键词:{keywords_str}")
|
result_parts.append(f"关键词:{keywords_str}")
|
||||||
|
|||||||
@@ -1,18 +1,263 @@
|
|||||||
"""
|
"""
|
||||||
插件系统配置类型定义
|
插件系统配置类型定义
|
||||||
|
|
||||||
|
提供插件配置的类型定义,支持 WebUI 可视化配置编辑。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Optional, List
|
from typing import Any, Optional, List, Dict, Union
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConfigField:
|
class ConfigField:
|
||||||
"""配置字段定义"""
|
"""
|
||||||
|
配置字段定义
|
||||||
|
|
||||||
type: type # 字段类型
|
用于定义插件配置项的元数据,支持类型验证、UI 渲染等功能。
|
||||||
|
|
||||||
|
基础示例:
|
||||||
|
ConfigField(type=str, default="", description="API密钥")
|
||||||
|
|
||||||
|
完整示例:
|
||||||
|
ConfigField(
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
description="API密钥",
|
||||||
|
input_type="password",
|
||||||
|
placeholder="请输入API密钥",
|
||||||
|
required=True,
|
||||||
|
hint="从服务商控制台获取",
|
||||||
|
order=1
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# === 基础字段(必需) ===
|
||||||
|
type: type # 字段类型: str, int, float, bool, list, dict
|
||||||
default: Any # 默认值
|
default: Any # 默认值
|
||||||
description: str # 字段描述
|
description: str # 字段描述(也用作默认标签)
|
||||||
example: Optional[str] = None # 示例值
|
|
||||||
|
# === 验证相关 ===
|
||||||
|
example: Optional[str] = None # 示例值(用于生成配置文件注释)
|
||||||
required: bool = False # 是否必需
|
required: bool = False # 是否必需
|
||||||
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表
|
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表(用于下拉选择)
|
||||||
|
min: Optional[float] = None # 最小值(数字类型)
|
||||||
|
max: Optional[float] = None # 最大值(数字类型)
|
||||||
|
step: Optional[float] = None # 步进值(数字类型)
|
||||||
|
pattern: Optional[str] = None # 正则验证(字符串类型)
|
||||||
|
max_length: Optional[int] = None # 最大长度(字符串类型)
|
||||||
|
|
||||||
|
# === UI 显示控制 ===
|
||||||
|
label: Optional[str] = None # 显示标签(默认使用 description)
|
||||||
|
placeholder: Optional[str] = None # 输入框占位符
|
||||||
|
hint: Optional[str] = None # 字段下方的提示文字
|
||||||
|
icon: Optional[str] = None # 字段图标名称
|
||||||
|
hidden: bool = False # 是否在 UI 中隐藏
|
||||||
|
disabled: bool = False # 是否禁用编辑
|
||||||
|
order: int = 0 # 排序权重(数字越小越靠前)
|
||||||
|
|
||||||
|
# === 输入控件类型 ===
|
||||||
|
# 可选值: text, password, textarea, number, color, code, file, json
|
||||||
|
# 不指定时根据 type 和 choices 自动推断
|
||||||
|
input_type: Optional[str] = None
|
||||||
|
|
||||||
|
# === textarea 专用 ===
|
||||||
|
rows: int = 3 # 文本域行数
|
||||||
|
|
||||||
|
# === 分组与布局 ===
|
||||||
|
group: Optional[str] = None # 字段分组(在 section 内再细分)
|
||||||
|
|
||||||
|
# === 条件显示 ===
|
||||||
|
depends_on: Optional[str] = None # 依赖的字段路径,如 "section.field"
|
||||||
|
depends_value: Any = None # 依赖字段需要的值(当依赖字段等于此值时显示)
|
||||||
|
|
||||||
|
def get_ui_type(self) -> str:
|
||||||
|
"""
|
||||||
|
获取 UI 控件类型
|
||||||
|
|
||||||
|
如果指定了 input_type 则直接返回,否则根据 type 和 choices 自动推断。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
控件类型字符串
|
||||||
|
"""
|
||||||
|
if self.input_type:
|
||||||
|
return self.input_type
|
||||||
|
|
||||||
|
# 根据 type 和 choices 自动推断
|
||||||
|
if self.type is bool:
|
||||||
|
return "switch"
|
||||||
|
elif self.type in (int, float):
|
||||||
|
if self.min is not None and self.max is not None:
|
||||||
|
return "slider"
|
||||||
|
return "number"
|
||||||
|
elif self.type is str:
|
||||||
|
if self.choices:
|
||||||
|
return "select"
|
||||||
|
return "text"
|
||||||
|
elif self.type is list:
|
||||||
|
return "list"
|
||||||
|
elif self.type is dict:
|
||||||
|
return "json"
|
||||||
|
else:
|
||||||
|
return "text"
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
转换为可序列化的字典(用于 API 传输)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含所有配置信息的字典
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"type": self.type.__name__ if isinstance(self.type, type) else str(self.type),
|
||||||
|
"default": self.default,
|
||||||
|
"description": self.description,
|
||||||
|
"example": self.example,
|
||||||
|
"required": self.required,
|
||||||
|
"choices": self.choices if self.choices else None,
|
||||||
|
"min": self.min,
|
||||||
|
"max": self.max,
|
||||||
|
"step": self.step,
|
||||||
|
"pattern": self.pattern,
|
||||||
|
"max_length": self.max_length,
|
||||||
|
"label": self.label or self.description,
|
||||||
|
"placeholder": self.placeholder,
|
||||||
|
"hint": self.hint,
|
||||||
|
"icon": self.icon,
|
||||||
|
"hidden": self.hidden,
|
||||||
|
"disabled": self.disabled,
|
||||||
|
"order": self.order,
|
||||||
|
"input_type": self.input_type,
|
||||||
|
"ui_type": self.get_ui_type(),
|
||||||
|
"rows": self.rows,
|
||||||
|
"group": self.group,
|
||||||
|
"depends_on": self.depends_on,
|
||||||
|
"depends_value": self.depends_value,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConfigSection:
|
||||||
|
"""
|
||||||
|
配置节定义
|
||||||
|
|
||||||
|
用于描述配置文件中一个 section 的元数据。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
ConfigSection(
|
||||||
|
title="API配置",
|
||||||
|
description="外部API连接参数",
|
||||||
|
icon="cloud",
|
||||||
|
order=1
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
title: str # 显示标题
|
||||||
|
description: Optional[str] = None # 详细描述
|
||||||
|
icon: Optional[str] = None # 图标名称
|
||||||
|
collapsed: bool = False # 默认是否折叠
|
||||||
|
order: int = 0 # 排序权重
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为可序列化的字典"""
|
||||||
|
return {
|
||||||
|
"title": self.title,
|
||||||
|
"description": self.description,
|
||||||
|
"icon": self.icon,
|
||||||
|
"collapsed": self.collapsed,
|
||||||
|
"order": self.order,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConfigTab:
|
||||||
|
"""
|
||||||
|
配置标签页定义
|
||||||
|
|
||||||
|
用于将多个 section 组织到一个标签页中。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
ConfigTab(
|
||||||
|
id="general",
|
||||||
|
title="通用设置",
|
||||||
|
icon="settings",
|
||||||
|
sections=["plugin", "api"]
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str # 标签页 ID
|
||||||
|
title: str # 显示标题
|
||||||
|
sections: List[str] = field(default_factory=list) # 包含的 section 名称列表
|
||||||
|
icon: Optional[str] = None # 图标名称
|
||||||
|
order: int = 0 # 排序权重
|
||||||
|
badge: Optional[str] = None # 角标文字(如 "Beta", "New")
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为可序列化的字典"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"title": self.title,
|
||||||
|
"sections": self.sections,
|
||||||
|
"icon": self.icon,
|
||||||
|
"order": self.order,
|
||||||
|
"badge": self.badge,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConfigLayout:
|
||||||
|
"""
|
||||||
|
配置页面布局定义
|
||||||
|
|
||||||
|
用于定义插件配置页面的整体布局结构。
|
||||||
|
|
||||||
|
布局类型:
|
||||||
|
- "auto": 自动布局,sections 作为折叠面板显示
|
||||||
|
- "tabs": 标签页布局
|
||||||
|
- "pages": 分页布局(左侧导航 + 右侧内容)
|
||||||
|
|
||||||
|
简单示例(标签页布局):
|
||||||
|
ConfigLayout(
|
||||||
|
type="tabs",
|
||||||
|
tabs=[
|
||||||
|
ConfigTab(id="basic", title="基础", sections=["plugin", "api"]),
|
||||||
|
ConfigTab(id="advanced", title="高级", sections=["debug"]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "auto" # 布局类型: auto, tabs, pages
|
||||||
|
tabs: List[ConfigTab] = field(default_factory=list) # 标签页列表
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为可序列化的字典"""
|
||||||
|
return {
|
||||||
|
"type": self.type,
|
||||||
|
"tabs": [tab.to_dict() for tab in self.tabs],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def section_meta(
|
||||||
|
title: str, description: Optional[str] = None, icon: Optional[str] = None, collapsed: bool = False, order: int = 0
|
||||||
|
) -> Union[str, ConfigSection]:
|
||||||
|
"""
|
||||||
|
便捷函数:创建 section 元数据
|
||||||
|
|
||||||
|
可以在 config_section_descriptions 中使用,提供比纯字符串更丰富的信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: 显示标题
|
||||||
|
description: 详细描述
|
||||||
|
icon: 图标名称
|
||||||
|
collapsed: 默认是否折叠
|
||||||
|
order: 排序权重
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConfigSection 实例
|
||||||
|
|
||||||
|
示例:
|
||||||
|
config_section_descriptions = {
|
||||||
|
"api": section_meta("API配置", icon="cloud", order=1),
|
||||||
|
"debug": section_meta("调试设置", collapsed=True, order=99),
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return ConfigSection(title=title, description=description, icon=icon, collapsed=collapsed, order=order)
|
||||||
|
|||||||
@@ -12,7 +12,11 @@ from src.plugin_system.base.component_types import (
|
|||||||
PluginInfo,
|
PluginInfo,
|
||||||
PythonDependency,
|
PythonDependency,
|
||||||
)
|
)
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.plugin_system.base.config_types import (
|
||||||
|
ConfigField,
|
||||||
|
ConfigSection,
|
||||||
|
ConfigLayout,
|
||||||
|
)
|
||||||
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
||||||
|
|
||||||
logger = get_logger("plugin_base")
|
logger = get_logger("plugin_base")
|
||||||
@@ -60,7 +64,10 @@ class PluginBase(ABC):
|
|||||||
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
config_section_descriptions: Dict[str, str] = {}
|
config_section_descriptions: Dict[str, Union[str, ConfigSection]] = {}
|
||||||
|
|
||||||
|
# 布局配置(可选,不定义则使用自动布局)
|
||||||
|
config_layout: ConfigLayout = None
|
||||||
|
|
||||||
def __init__(self, plugin_dir: str):
|
def __init__(self, plugin_dir: str):
|
||||||
"""初始化插件
|
"""初始化插件
|
||||||
@@ -563,6 +570,93 @@ class PluginBase(ABC):
|
|||||||
|
|
||||||
return current
|
return current
|
||||||
|
|
||||||
|
def get_webui_config_schema(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取 WebUI 配置 Schema
|
||||||
|
|
||||||
|
返回完整的配置 schema,包含:
|
||||||
|
- 插件基本信息
|
||||||
|
- 所有 section 及其字段定义
|
||||||
|
- 布局配置
|
||||||
|
|
||||||
|
用于 WebUI 动态生成配置表单。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 完整的配置 schema
|
||||||
|
"""
|
||||||
|
schema = {
|
||||||
|
"plugin_id": self.plugin_name,
|
||||||
|
"plugin_info": {
|
||||||
|
"name": self.display_name,
|
||||||
|
"version": self.plugin_version,
|
||||||
|
"description": self.plugin_description,
|
||||||
|
"author": self.plugin_author,
|
||||||
|
},
|
||||||
|
"sections": {},
|
||||||
|
"layout": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 处理 sections
|
||||||
|
for section_name, fields in self.config_schema.items():
|
||||||
|
if not isinstance(fields, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
section_data = {
|
||||||
|
"name": section_name,
|
||||||
|
"title": section_name,
|
||||||
|
"description": None,
|
||||||
|
"icon": None,
|
||||||
|
"collapsed": False,
|
||||||
|
"order": 0,
|
||||||
|
"fields": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 获取 section 元数据
|
||||||
|
section_meta = self.config_section_descriptions.get(section_name)
|
||||||
|
if section_meta:
|
||||||
|
if isinstance(section_meta, str):
|
||||||
|
section_data["title"] = section_meta
|
||||||
|
elif isinstance(section_meta, ConfigSection):
|
||||||
|
section_data["title"] = section_meta.title
|
||||||
|
section_data["description"] = section_meta.description
|
||||||
|
section_data["icon"] = section_meta.icon
|
||||||
|
section_data["collapsed"] = section_meta.collapsed
|
||||||
|
section_data["order"] = section_meta.order
|
||||||
|
elif isinstance(section_meta, dict):
|
||||||
|
section_data.update(section_meta)
|
||||||
|
|
||||||
|
# 处理字段
|
||||||
|
for field_name, field_def in fields.items():
|
||||||
|
if isinstance(field_def, ConfigField):
|
||||||
|
field_data = field_def.to_dict()
|
||||||
|
field_data["name"] = field_name
|
||||||
|
section_data["fields"][field_name] = field_data
|
||||||
|
|
||||||
|
schema["sections"][section_name] = section_data
|
||||||
|
|
||||||
|
# 处理布局
|
||||||
|
if self.config_layout:
|
||||||
|
schema["layout"] = self.config_layout.to_dict()
|
||||||
|
else:
|
||||||
|
# 自动布局:按 section order 排序
|
||||||
|
schema["layout"] = {
|
||||||
|
"type": "auto",
|
||||||
|
"tabs": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
def get_current_config_values(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取当前配置值
|
||||||
|
|
||||||
|
返回插件当前的配置值(已从配置文件加载)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 当前配置值
|
||||||
|
"""
|
||||||
|
return self.config.copy()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def register_plugin(self) -> bool:
|
def register_plugin(self) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
380
src/webui/chat_routes.py
Normal file
380
src/webui/chat_routes.py
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
"""本地聊天室路由 - WebUI 与麦麦直接对话"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database_model import Messages
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.chat.message_receive.bot import chat_bot
|
||||||
|
|
||||||
|
logger = get_logger("webui.chat")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||||
|
|
||||||
|
# WebUI 聊天的虚拟群组 ID
|
||||||
|
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||||
|
WEBUI_CHAT_PLATFORM = "webui"
|
||||||
|
|
||||||
|
# 固定的 WebUI 用户 ID 前缀
|
||||||
|
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatHistoryMessage(BaseModel):
|
||||||
|
"""聊天历史消息"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
type: str # 'user' | 'bot' | 'system'
|
||||||
|
content: str
|
||||||
|
timestamp: float
|
||||||
|
sender_name: str
|
||||||
|
sender_id: Optional[str] = None
|
||||||
|
is_bot: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ChatHistoryManager:
|
||||||
|
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
|
||||||
|
|
||||||
|
def __init__(self, max_messages: int = 200):
|
||||||
|
self.max_messages = max_messages
|
||||||
|
|
||||||
|
def _message_to_dict(self, msg: Messages) -> Dict[str, Any]:
|
||||||
|
"""将数据库消息转换为前端格式"""
|
||||||
|
# 判断是否是机器人消息
|
||||||
|
# WebUI 用户的 user_id 以 "webui_" 开头,其他都是机器人消息
|
||||||
|
user_id = msg.user_id or ""
|
||||||
|
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": msg.message_id,
|
||||||
|
"type": "bot" if is_bot else "user",
|
||||||
|
"content": msg.processed_plain_text or msg.display_message or "",
|
||||||
|
"timestamp": msg.time,
|
||||||
|
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||||
|
"sender_id": "bot" if is_bot else user_id,
|
||||||
|
"is_bot": is_bot,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
|
||||||
|
"""从数据库获取最近的历史记录"""
|
||||||
|
try:
|
||||||
|
# 查询 WebUI 平台的消息,按时间排序
|
||||||
|
messages = (
|
||||||
|
Messages.select()
|
||||||
|
.where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID)
|
||||||
|
.order_by(Messages.time.desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为列表并反转(使最旧的消息在前)
|
||||||
|
result = [self._message_to_dict(msg) for msg in messages]
|
||||||
|
result.reverse()
|
||||||
|
|
||||||
|
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def clear_history(self) -> int:
|
||||||
|
"""清空 WebUI 聊天历史记录"""
|
||||||
|
try:
|
||||||
|
deleted = Messages.delete().where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID).execute()
|
||||||
|
logger.info(f"已清空 {deleted} 条 WebUI 聊天记录")
|
||||||
|
return deleted
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"清空聊天记录失败: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
# 全局聊天历史管理器
|
||||||
|
chat_history = ChatHistoryManager()
|
||||||
|
|
||||||
|
|
||||||
|
# 存储 WebSocket 连接
|
||||||
|
class ChatConnectionManager:
|
||||||
|
"""聊天连接管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.active_connections: Dict[str, WebSocket] = {}
|
||||||
|
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
|
||||||
|
|
||||||
|
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
|
||||||
|
await websocket.accept()
|
||||||
|
self.active_connections[session_id] = websocket
|
||||||
|
self.user_sessions[user_id] = session_id
|
||||||
|
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||||
|
|
||||||
|
def disconnect(self, session_id: str, user_id: str):
|
||||||
|
if session_id in self.active_connections:
|
||||||
|
del self.active_connections[session_id]
|
||||||
|
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||||
|
del self.user_sessions[user_id]
|
||||||
|
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||||
|
|
||||||
|
async def send_message(self, session_id: str, message: dict):
|
||||||
|
if session_id in self.active_connections:
|
||||||
|
try:
|
||||||
|
await self.active_connections[session_id].send_json(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送消息失败: {e}")
|
||||||
|
|
||||||
|
async def broadcast(self, message: dict):
|
||||||
|
"""广播消息给所有连接"""
|
||||||
|
for session_id in list(self.active_connections.keys()):
|
||||||
|
await self.send_message(session_id, message)
|
||||||
|
|
||||||
|
|
||||||
|
chat_manager = ChatConnectionManager()
|
||||||
|
|
||||||
|
|
||||||
|
def create_message_data(
|
||||||
|
content: str, user_id: str, user_name: str, message_id: Optional[str] = None, is_at_bot: bool = True
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""创建符合麦麦消息格式的消息数据"""
|
||||||
|
if message_id is None:
|
||||||
|
message_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message_info": {
|
||||||
|
"platform": WEBUI_CHAT_PLATFORM,
|
||||||
|
"message_id": message_id,
|
||||||
|
"time": time.time(),
|
||||||
|
"group_info": {
|
||||||
|
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||||
|
"group_name": "WebUI本地聊天室",
|
||||||
|
"platform": WEBUI_CHAT_PLATFORM,
|
||||||
|
},
|
||||||
|
"user_info": {
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_nickname": user_name,
|
||||||
|
"user_cardname": user_name,
|
||||||
|
"platform": WEBUI_CHAT_PLATFORM,
|
||||||
|
},
|
||||||
|
"additional_config": {
|
||||||
|
"at_bot": is_at_bot,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"message_segment": {
|
||||||
|
"type": "seglist",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"data": content,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "mention_bot",
|
||||||
|
"data": "1.0",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"raw_message": content,
|
||||||
|
"processed_plain_text": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history")
|
||||||
|
async def get_chat_history(
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||||
|
):
|
||||||
|
"""获取聊天历史记录
|
||||||
|
|
||||||
|
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
|
||||||
|
"""
|
||||||
|
history = chat_history.get_history(limit)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"messages": history,
|
||||||
|
"total": len(history),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/history")
|
||||||
|
async def clear_chat_history():
|
||||||
|
"""清空聊天历史记录"""
|
||||||
|
deleted = chat_history.clear_history()
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"已清空 {deleted} 条聊天记录",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/ws")
|
||||||
|
async def websocket_chat(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: Optional[str] = Query(default=None),
|
||||||
|
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||||
|
):
|
||||||
|
"""WebSocket 聊天端点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户唯一标识(由前端生成并持久化)
|
||||||
|
user_name: 用户显示昵称(可修改)
|
||||||
|
"""
|
||||||
|
# 生成会话 ID(每次连接都是新的)
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 如果没有提供 user_id,生成一个新的
|
||||||
|
if not user_id:
|
||||||
|
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||||
|
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||||
|
# 确保 user_id 有正确的前缀
|
||||||
|
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||||
|
|
||||||
|
await chat_manager.connect(websocket, session_id, user_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||||
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "session_info",
|
||||||
|
"session_id": session_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"user_name": user_name,
|
||||||
|
"bot_name": global_config.bot.nickname,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送历史记录
|
||||||
|
history = chat_history.get_history(50)
|
||||||
|
if history:
|
||||||
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "history",
|
||||||
|
"messages": history,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 发送欢迎消息(不保存到历史)
|
||||||
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "system",
|
||||||
|
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
|
||||||
|
"timestamp": time.time(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
data = await websocket.receive_json()
|
||||||
|
|
||||||
|
if data.get("type") == "message":
|
||||||
|
content = data.get("content", "").strip()
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 用户可以更新昵称
|
||||||
|
current_user_name = data.get("user_name", user_name)
|
||||||
|
|
||||||
|
message_id = str(uuid.uuid4())
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
# 广播用户消息给所有连接(包括发送者)
|
||||||
|
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||||
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
|
"type": "user_message",
|
||||||
|
"content": content,
|
||||||
|
"message_id": message_id,
|
||||||
|
"timestamp": timestamp,
|
||||||
|
"sender": {
|
||||||
|
"name": current_user_name,
|
||||||
|
"user_id": user_id,
|
||||||
|
"is_bot": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建麦麦消息格式
|
||||||
|
message_data = create_message_data(
|
||||||
|
content=content,
|
||||||
|
user_id=user_id,
|
||||||
|
user_name=current_user_name,
|
||||||
|
message_id=message_id,
|
||||||
|
is_at_bot=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 显示正在输入状态
|
||||||
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
|
"type": "typing",
|
||||||
|
"is_typing": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用麦麦的消息处理
|
||||||
|
await chat_bot.message_process(message_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理消息时出错: {e}")
|
||||||
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"content": f"处理消息时出错: {str(e)}",
|
||||||
|
"timestamp": time.time(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
|
"type": "typing",
|
||||||
|
"is_typing": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
elif data.get("type") == "ping":
|
||||||
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "pong",
|
||||||
|
"timestamp": time.time(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
elif data.get("type") == "update_nickname":
|
||||||
|
# 允许用户更新昵称
|
||||||
|
if new_name := data.get("user_name", "").strip():
|
||||||
|
current_user_name = new_name
|
||||||
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"type": "nickname_updated",
|
||||||
|
"user_name": current_user_name,
|
||||||
|
"timestamp": time.time(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket 错误: {e}")
|
||||||
|
finally:
|
||||||
|
chat_manager.disconnect(session_id, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/info")
|
||||||
|
async def get_chat_info():
|
||||||
|
"""获取聊天室信息"""
|
||||||
|
return {
|
||||||
|
"bot_name": global_config.bot.nickname,
|
||||||
|
"platform": WEBUI_CHAT_PLATFORM,
|
||||||
|
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||||
|
"active_sessions": len(chat_manager.active_connections),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_webui_chat_broadcaster() -> tuple:
|
||||||
|
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
|
||||||
|
"""
|
||||||
|
return (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||||
@@ -5,9 +5,10 @@
|
|||||||
import os
|
import os
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from fastapi import APIRouter, HTTPException, Body
|
from fastapi import APIRouter, HTTPException, Body
|
||||||
from typing import Any
|
from typing import Any, Annotated
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.toml_utils import save_toml_with_format
|
||||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||||
from src.config.official_configs import (
|
from src.config.official_configs import (
|
||||||
BotConfig,
|
BotConfig,
|
||||||
@@ -41,6 +42,12 @@ from src.webui.config_schema import ConfigSchemaGenerator
|
|||||||
|
|
||||||
logger = get_logger("webui")
|
logger = get_logger("webui")
|
||||||
|
|
||||||
|
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||||
|
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||||
|
SectionBody = Annotated[Any, Body()]
|
||||||
|
RawContentBody = Annotated[str, Body(embed=True)]
|
||||||
|
PathBody = Annotated[dict[str, str], Body()]
|
||||||
|
|
||||||
router = APIRouter(prefix="/config", tags=["config"])
|
router = APIRouter(prefix="/config", tags=["config"])
|
||||||
|
|
||||||
|
|
||||||
@@ -90,7 +97,7 @@ async def get_bot_config_schema():
|
|||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取配置架构失败: {e}")
|
logger.error(f"获取配置架构失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/schema/model")
|
@router.get("/schema/model")
|
||||||
@@ -101,7 +108,7 @@ async def get_model_config_schema():
|
|||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取模型配置架构失败: {e}")
|
logger.error(f"获取模型配置架构失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 子配置架构获取接口 =====
|
# ===== 子配置架构获取接口 =====
|
||||||
@@ -174,7 +181,7 @@ async def get_config_section_schema(section_name: str):
|
|||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取配置节架构失败: {e}")
|
logger.error(f"获取配置节架构失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 配置读取接口 =====
|
# ===== 配置读取接口 =====
|
||||||
@@ -196,7 +203,7 @@ async def get_bot_config():
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取配置文件失败: {e}")
|
logger.error(f"读取配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/model")
|
@router.get("/model")
|
||||||
@@ -215,26 +222,25 @@ async def get_model_config():
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取配置文件失败: {e}")
|
logger.error(f"读取配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 配置更新接口 =====
|
# ===== 配置更新接口 =====
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot")
|
@router.post("/bot")
|
||||||
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
async def update_bot_config(config_data: ConfigBody):
|
||||||
"""更新麦麦主程序配置"""
|
"""更新麦麦主程序配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
try:
|
try:
|
||||||
Config.from_dict(config_data)
|
Config.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件
|
# 保存配置文件(格式化数组为多行)
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
save_toml_with_format(config_data, config_path)
|
||||||
tomlkit.dump(config_data, f)
|
|
||||||
|
|
||||||
logger.info("麦麦主程序配置已更新")
|
logger.info("麦麦主程序配置已更新")
|
||||||
return {"success": True, "message": "配置已保存"}
|
return {"success": True, "message": "配置已保存"}
|
||||||
@@ -242,23 +248,22 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存配置文件失败: {e}")
|
logger.error(f"保存配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model")
|
@router.post("/model")
|
||||||
async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
async def update_model_config(config_data: ConfigBody):
|
||||||
"""更新模型配置"""
|
"""更新模型配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
try:
|
try:
|
||||||
APIAdapterConfig.from_dict(config_data)
|
APIAdapterConfig.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件
|
# 保存配置文件(格式化数组为多行)
|
||||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
save_toml_with_format(config_data, config_path)
|
||||||
tomlkit.dump(config_data, f)
|
|
||||||
|
|
||||||
logger.info("模型配置已更新")
|
logger.info("模型配置已更新")
|
||||||
return {"success": True, "message": "配置已保存"}
|
return {"success": True, "message": "配置已保存"}
|
||||||
@@ -266,14 +271,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存配置文件失败: {e}")
|
logger.error(f"保存配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 配置节更新接口 =====
|
# ===== 配置节更新接口 =====
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot/section/{section_name}")
|
@router.post("/bot/section/{section_name}")
|
||||||
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
|
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
@@ -304,11 +309,10 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
|||||||
try:
|
try:
|
||||||
Config.from_dict(config_data)
|
Config.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置(tomlkit.dump 会保留注释)
|
# 保存配置(格式化数组为多行,保留注释)
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
save_toml_with_format(config_data, config_path)
|
||||||
tomlkit.dump(config_data, f)
|
|
||||||
|
|
||||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||||
@@ -316,7 +320,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新配置节失败: {e}")
|
logger.error(f"更新配置节失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 原始 TOML 文件操作接口 =====
|
# ===== 原始 TOML 文件操作接口 =====
|
||||||
@@ -338,24 +342,24 @@ async def get_bot_config_raw():
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取配置文件失败: {e}")
|
logger.error(f"读取配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot/raw")
|
@router.post("/bot/raw")
|
||||||
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||||
try:
|
try:
|
||||||
# 验证 TOML 格式
|
# 验证 TOML 格式
|
||||||
try:
|
try:
|
||||||
config_data = tomlkit.loads(raw_content)
|
config_data = tomlkit.loads(raw_content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||||
|
|
||||||
# 验证配置数据结构
|
# 验证配置数据结构
|
||||||
try:
|
try:
|
||||||
Config.from_dict(config_data)
|
Config.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件
|
# 保存配置文件
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
@@ -368,11 +372,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存配置文件失败: {e}")
|
logger.error(f"保存配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model/section/{section_name}")
|
@router.post("/model/section/{section_name}")
|
||||||
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||||
"""更新模型配置的指定节(保留注释和格式)"""
|
"""更新模型配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
@@ -403,11 +407,10 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
|||||||
try:
|
try:
|
||||||
APIAdapterConfig.from_dict(config_data)
|
APIAdapterConfig.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置(tomlkit.dump 会保留注释)
|
# 保存配置(格式化数组为多行,保留注释)
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
save_toml_with_format(config_data, config_path)
|
||||||
tomlkit.dump(config_data, f)
|
|
||||||
|
|
||||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||||
@@ -415,7 +418,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新配置节失败: {e}")
|
logger.error(f"更新配置节失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 适配器配置管理接口 =====
|
# ===== 适配器配置管理接口 =====
|
||||||
@@ -425,11 +428,11 @@ def _normalize_adapter_path(path: str) -> str:
|
|||||||
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
|
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
|
||||||
if not path:
|
if not path:
|
||||||
return path
|
return path
|
||||||
|
|
||||||
# 如果已经是绝对路径,直接返回
|
# 如果已经是绝对路径,直接返回
|
||||||
if os.path.isabs(path):
|
if os.path.isabs(path):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
# 相对路径,转换为相对于项目根目录的绝对路径
|
# 相对路径,转换为相对于项目根目录的绝对路径
|
||||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||||
|
|
||||||
@@ -438,17 +441,17 @@ def _to_relative_path(path: str) -> str:
|
|||||||
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
||||||
if not path or not os.path.isabs(path):
|
if not path or not os.path.isabs(path):
|
||||||
return path
|
return path
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试获取相对路径
|
# 尝试获取相对路径
|
||||||
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
||||||
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
||||||
if not rel_path.startswith('..'):
|
if not rel_path.startswith(".."):
|
||||||
return rel_path
|
return rel_path
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 无法转换为相对路径,返回绝对路径
|
# 无法转换为相对路径,返回绝对路径
|
||||||
return path
|
return path
|
||||||
|
|
||||||
@@ -463,6 +466,7 @@ async def get_adapter_config_path():
|
|||||||
return {"success": True, "path": None}
|
return {"success": True, "path": None}
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||||
webui_data = json.load(f)
|
webui_data = json.load(f)
|
||||||
|
|
||||||
@@ -472,10 +476,11 @@ async def get_adapter_config_path():
|
|||||||
|
|
||||||
# 将路径规范化为绝对路径
|
# 将路径规范化为绝对路径
|
||||||
abs_path = _normalize_adapter_path(adapter_config_path)
|
abs_path = _normalize_adapter_path(adapter_config_path)
|
||||||
|
|
||||||
# 检查文件是否存在并返回最后修改时间
|
# 检查文件是否存在并返回最后修改时间
|
||||||
if os.path.exists(abs_path):
|
if os.path.exists(abs_path):
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
mtime = os.path.getmtime(abs_path)
|
mtime = os.path.getmtime(abs_path)
|
||||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||||
# 返回相对路径(如果可能)
|
# 返回相对路径(如果可能)
|
||||||
@@ -487,11 +492,11 @@ async def get_adapter_config_path():
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取适配器配置路径失败: {e}")
|
logger.error(f"获取适配器配置路径失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config/path")
|
@router.post("/adapter-config/path")
|
||||||
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
async def save_adapter_config_path(data: PathBody):
|
||||||
"""保存适配器配置文件路径偏好"""
|
"""保存适配器配置文件路径偏好"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
@@ -511,10 +516,10 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
|||||||
|
|
||||||
# 将路径规范化为绝对路径
|
# 将路径规范化为绝对路径
|
||||||
abs_path = _normalize_adapter_path(path)
|
abs_path = _normalize_adapter_path(path)
|
||||||
|
|
||||||
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
||||||
save_path = _to_relative_path(abs_path)
|
save_path = _to_relative_path(abs_path)
|
||||||
|
|
||||||
# 更新路径
|
# 更新路径
|
||||||
webui_data["adapter_config_path"] = save_path
|
webui_data["adapter_config_path"] = save_path
|
||||||
|
|
||||||
@@ -530,7 +535,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存适配器配置路径失败: {e}")
|
logger.error(f"保存适配器配置路径失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/adapter-config")
|
@router.get("/adapter-config")
|
||||||
@@ -542,7 +547,7 @@ async def get_adapter_config(path: str):
|
|||||||
|
|
||||||
# 将路径规范化为绝对路径
|
# 将路径规范化为绝对路径
|
||||||
abs_path = _normalize_adapter_path(path)
|
abs_path = _normalize_adapter_path(path)
|
||||||
|
|
||||||
# 检查文件是否存在
|
# 检查文件是否存在
|
||||||
if not os.path.exists(abs_path):
|
if not os.path.exists(abs_path):
|
||||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||||
@@ -562,11 +567,11 @@ async def get_adapter_config(path: str):
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取适配器配置失败: {e}")
|
logger.error(f"读取适配器配置失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config")
|
@router.post("/adapter-config")
|
||||||
async def save_adapter_config(data: dict[str, str] = Body(...)):
|
async def save_adapter_config(data: PathBody):
|
||||||
"""保存适配器配置到指定路径"""
|
"""保存适配器配置到指定路径"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
@@ -579,17 +584,16 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
|||||||
|
|
||||||
# 将路径规范化为绝对路径
|
# 将路径规范化为绝对路径
|
||||||
abs_path = _normalize_adapter_path(path)
|
abs_path = _normalize_adapter_path(path)
|
||||||
|
|
||||||
# 检查文件扩展名
|
# 检查文件扩展名
|
||||||
if not abs_path.endswith(".toml"):
|
if not abs_path.endswith(".toml"):
|
||||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||||
|
|
||||||
# 验证 TOML 格式
|
# 验证 TOML 格式
|
||||||
try:
|
try:
|
||||||
import toml
|
tomlkit.loads(content)
|
||||||
toml.loads(content)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||||
|
|
||||||
# 确保目录存在
|
# 确保目录存在
|
||||||
dir_path = os.path.dirname(abs_path)
|
dir_path = os.path.dirname(abs_path)
|
||||||
@@ -607,5 +611,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存适配器配置失败: {e}")
|
logger.error(f"保存适配器配置失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|||||||
@@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
|
|||||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||||
# 单行文档字符串
|
# 单行文档字符串
|
||||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||||
else:
|
else:
|
||||||
# 多行文档字符串
|
# 多行文档字符串
|
||||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||||
@@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
|
|||||||
next_line = lines[i + 1].strip()
|
next_line = lines[i + 1].strip()
|
||||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||||
else:
|
else:
|
||||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||||
description_lines.append(next_line.strip(quote).strip())
|
description_lines.append(next_line.strip(quote).strip())
|
||||||
@@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
|
|||||||
return FieldType.ARRAY, None, items
|
return FieldType.ARRAY, None, items
|
||||||
|
|
||||||
# 处理基本类型
|
# 处理基本类型
|
||||||
if field_type is bool or field_type == bool:
|
if field_type is bool:
|
||||||
return FieldType.BOOLEAN, None, None
|
return FieldType.BOOLEAN, None, None
|
||||||
elif field_type is int or field_type == int:
|
elif field_type is int:
|
||||||
return FieldType.INTEGER, None, None
|
return FieldType.INTEGER, None, None
|
||||||
elif field_type is float or field_type == float:
|
elif field_type is float:
|
||||||
return FieldType.NUMBER, None, None
|
return FieldType.NUMBER, None, None
|
||||||
elif field_type is str or field_type == str:
|
elif field_type is str:
|
||||||
return FieldType.STRING, None, None
|
return FieldType.STRING, None, None
|
||||||
elif field_type is dict or origin is dict:
|
elif field_type is dict or origin is dict:
|
||||||
return FieldType.OBJECT, None, None
|
return FieldType.OBJECT, None, None
|
||||||
|
|||||||
@@ -1,18 +1,27 @@
|
|||||||
"""表情包管理 API 路由"""
|
"""表情包管理 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query
|
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Annotated
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
import hashlib
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
|
||||||
logger = get_logger("webui.emoji")
|
logger = get_logger("webui.emoji")
|
||||||
|
|
||||||
|
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||||
|
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||||
|
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||||
|
DescriptionForm = Annotated[str, Form(description="表情包描述")]
|
||||||
|
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
|
||||||
|
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
|
||||||
|
|
||||||
# 创建路由器
|
# 创建路由器
|
||||||
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||||
|
|
||||||
@@ -572,3 +581,290 @@ async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Option
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"批量删除表情包失败: {e}")
|
logger.exception(f"批量删除表情包失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# 表情包存储目录
|
||||||
|
EMOJI_REGISTERED_DIR = os.path.join("data", "emoji_registed")
|
||||||
|
|
||||||
|
|
||||||
|
class EmojiUploadResponse(BaseModel):
|
||||||
|
"""表情包上传响应"""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
data: Optional[EmojiResponse] = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/upload", response_model=EmojiUploadResponse)
|
||||||
|
async def upload_emoji(
|
||||||
|
file: EmojiFile,
|
||||||
|
description: DescriptionForm = "",
|
||||||
|
emotion: EmotionForm = "",
|
||||||
|
is_registered: IsRegisteredForm = True,
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上传并注册表情包
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 表情包图片文件 (支持 jpg, jpeg, png, gif, webp)
|
||||||
|
description: 表情包描述
|
||||||
|
emotion: 情感标签,多个用逗号分隔
|
||||||
|
is_registered: 是否直接注册,默认为 True
|
||||||
|
authorization: Authorization header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
上传结果和表情包信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
verify_auth_token(authorization)
|
||||||
|
|
||||||
|
# 验证文件类型
|
||||||
|
if not file.content_type:
|
||||||
|
raise HTTPException(status_code=400, detail="无法识别文件类型")
|
||||||
|
|
||||||
|
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||||
|
if file.content_type not in allowed_types:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"不支持的文件类型: {file.content_type},支持: {', '.join(allowed_types)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 读取文件内容
|
||||||
|
file_content = await file.read()
|
||||||
|
|
||||||
|
if not file_content:
|
||||||
|
raise HTTPException(status_code=400, detail="文件内容为空")
|
||||||
|
|
||||||
|
# 验证图片并获取格式
|
||||||
|
try:
|
||||||
|
with Image.open(io.BytesIO(file_content)) as img:
|
||||||
|
img_format = img.format.lower() if img.format else "png"
|
||||||
|
# 验证图片可以正常打开
|
||||||
|
img.verify()
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"无效的图片文件: {str(e)}") from e
|
||||||
|
|
||||||
|
# 重新打开图片(verify后需要重新打开)
|
||||||
|
with Image.open(io.BytesIO(file_content)) as img:
|
||||||
|
img_format = img.format.lower() if img.format else "png"
|
||||||
|
|
||||||
|
# 计算文件哈希
|
||||||
|
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||||
|
|
||||||
|
# 检查是否已存在相同哈希的表情包
|
||||||
|
existing_emoji = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||||
|
if existing_emoji:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail=f"已存在相同的表情包 (ID: {existing_emoji.id})",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# 生成文件名
|
||||||
|
timestamp = int(time.time())
|
||||||
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||||
|
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||||
|
|
||||||
|
# 如果文件已存在,添加随机后缀
|
||||||
|
counter = 1
|
||||||
|
while os.path.exists(full_path):
|
||||||
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||||
|
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
# 保存文件
|
||||||
|
with open(full_path, "wb") as f:
|
||||||
|
f.write(file_content)
|
||||||
|
|
||||||
|
logger.info(f"表情包文件已保存: {full_path}")
|
||||||
|
|
||||||
|
# 处理情感标签
|
||||||
|
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||||
|
|
||||||
|
# 创建数据库记录
|
||||||
|
current_time = time.time()
|
||||||
|
emoji = Emoji.create(
|
||||||
|
full_path=full_path,
|
||||||
|
format=img_format,
|
||||||
|
emoji_hash=emoji_hash,
|
||||||
|
description=description,
|
||||||
|
emotion=emotion_str,
|
||||||
|
query_count=0,
|
||||||
|
is_registered=is_registered,
|
||||||
|
is_banned=False,
|
||||||
|
record_time=current_time,
|
||||||
|
register_time=current_time if is_registered else None,
|
||||||
|
usage_count=0,
|
||||||
|
last_used_time=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"表情包已上传并注册: ID={emoji.id}, hash={emoji_hash}")
|
||||||
|
|
||||||
|
return EmojiUploadResponse(
|
||||||
|
success=True,
|
||||||
|
message="表情包上传成功" + ("并已注册" if is_registered else ""),
|
||||||
|
data=emoji_to_response(emoji),
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"上传表情包失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch/upload")
|
||||||
|
async def batch_upload_emoji(
|
||||||
|
files: EmojiFiles,
|
||||||
|
emotion: EmotionForm = "",
|
||||||
|
is_registered: IsRegisteredForm = True,
|
||||||
|
authorization: Optional[str] = Header(None),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
批量上传表情包
|
||||||
|
|
||||||
|
Args:
|
||||||
|
files: 多个表情包图片文件
|
||||||
|
emotion: 共用的情感标签
|
||||||
|
is_registered: 是否直接注册
|
||||||
|
authorization: Authorization header
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
批量上传结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
verify_auth_token(authorization)
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"success": True,
|
||||||
|
"total": len(files),
|
||||||
|
"uploaded": 0,
|
||||||
|
"failed": 0,
|
||||||
|
"details": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed_types = ["image/jpeg", "image/png", "image/gif", "image/webp"]
|
||||||
|
os.makedirs(EMOJI_REGISTERED_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
try:
|
||||||
|
# 验证文件类型
|
||||||
|
if file.content_type not in allowed_types:
|
||||||
|
results["failed"] += 1
|
||||||
|
results["details"].append(
|
||||||
|
{
|
||||||
|
"filename": file.filename,
|
||||||
|
"success": False,
|
||||||
|
"error": f"不支持的文件类型: {file.content_type}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 读取文件内容
|
||||||
|
file_content = await file.read()
|
||||||
|
|
||||||
|
if not file_content:
|
||||||
|
results["failed"] += 1
|
||||||
|
results["details"].append(
|
||||||
|
{
|
||||||
|
"filename": file.filename,
|
||||||
|
"success": False,
|
||||||
|
"error": "文件内容为空",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 验证图片
|
||||||
|
try:
|
||||||
|
with Image.open(io.BytesIO(file_content)) as img:
|
||||||
|
img_format = img.format.lower() if img.format else "png"
|
||||||
|
except Exception as e:
|
||||||
|
results["failed"] += 1
|
||||||
|
results["details"].append(
|
||||||
|
{
|
||||||
|
"filename": file.filename,
|
||||||
|
"success": False,
|
||||||
|
"error": f"无效的图片: {str(e)}",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算哈希
|
||||||
|
emoji_hash = hashlib.md5(file_content).hexdigest()
|
||||||
|
|
||||||
|
# 检查重复
|
||||||
|
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
|
||||||
|
results["failed"] += 1
|
||||||
|
results["details"].append(
|
||||||
|
{
|
||||||
|
"filename": file.filename,
|
||||||
|
"success": False,
|
||||||
|
"error": "已存在相同的表情包",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 生成文件名并保存
|
||||||
|
timestamp = int(time.time())
|
||||||
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}.{img_format}"
|
||||||
|
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||||
|
|
||||||
|
counter = 1
|
||||||
|
while os.path.exists(full_path):
|
||||||
|
filename = f"emoji_{timestamp}_{emoji_hash[:8]}_{counter}.{img_format}"
|
||||||
|
full_path = os.path.join(EMOJI_REGISTERED_DIR, filename)
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
with open(full_path, "wb") as f:
|
||||||
|
f.write(file_content)
|
||||||
|
|
||||||
|
# 处理情感标签
|
||||||
|
emotion_str = ",".join(e.strip() for e in emotion.split(",") if e.strip()) if emotion else ""
|
||||||
|
|
||||||
|
# 创建数据库记录
|
||||||
|
current_time = time.time()
|
||||||
|
emoji = Emoji.create(
|
||||||
|
full_path=full_path,
|
||||||
|
format=img_format,
|
||||||
|
emoji_hash=emoji_hash,
|
||||||
|
description="", # 批量上传暂不设置描述
|
||||||
|
emotion=emotion_str,
|
||||||
|
query_count=0,
|
||||||
|
is_registered=is_registered,
|
||||||
|
is_banned=False,
|
||||||
|
record_time=current_time,
|
||||||
|
register_time=current_time if is_registered else None,
|
||||||
|
usage_count=0,
|
||||||
|
last_used_time=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
results["uploaded"] += 1
|
||||||
|
results["details"].append(
|
||||||
|
{
|
||||||
|
"filename": file.filename,
|
||||||
|
"success": True,
|
||||||
|
"id": emoji.id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
results["failed"] += 1
|
||||||
|
results["details"].append(
|
||||||
|
{
|
||||||
|
"filename": file.filename,
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
||||||
|
return results
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"批量上传表情包失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e
|
||||||
|
|||||||
@@ -602,9 +602,9 @@ class GitMirrorService:
|
|||||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def run_git_clone():
|
def run_git_clone(clone_cmd=cmd):
|
||||||
return subprocess.run(
|
return subprocess.run(
|
||||||
cmd,
|
clone_cmd,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=300, # 5分钟超时
|
timeout=300, # 5分钟超时
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""知识库图谱可视化 API 路由"""
|
"""知识库图谱可视化 API 路由"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
|||||||
|
|
||||||
class KnowledgeNode(BaseModel):
|
class KnowledgeNode(BaseModel):
|
||||||
"""知识节点"""
|
"""知识节点"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
type: str # 'entity' or 'paragraph'
|
type: str # 'entity' or 'paragraph'
|
||||||
content: str
|
content: str
|
||||||
@@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
|
|||||||
|
|
||||||
class KnowledgeEdge(BaseModel):
|
class KnowledgeEdge(BaseModel):
|
||||||
"""知识边"""
|
"""知识边"""
|
||||||
|
|
||||||
source: str
|
source: str
|
||||||
target: str
|
target: str
|
||||||
weight: float
|
weight: float
|
||||||
@@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
|
|||||||
|
|
||||||
class KnowledgeGraph(BaseModel):
|
class KnowledgeGraph(BaseModel):
|
||||||
"""知识图谱"""
|
"""知识图谱"""
|
||||||
|
|
||||||
nodes: List[KnowledgeNode]
|
nodes: List[KnowledgeNode]
|
||||||
edges: List[KnowledgeEdge]
|
edges: List[KnowledgeEdge]
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeStats(BaseModel):
|
class KnowledgeStats(BaseModel):
|
||||||
"""知识库统计信息"""
|
"""知识库统计信息"""
|
||||||
|
|
||||||
total_nodes: int
|
total_nodes: int
|
||||||
total_edges: int
|
total_edges: int
|
||||||
entity_nodes: int
|
entity_nodes: int
|
||||||
@@ -45,7 +50,7 @@ def _load_kg_manager():
|
|||||||
"""延迟加载 KGManager"""
|
"""延迟加载 KGManager"""
|
||||||
try:
|
try:
|
||||||
from src.chat.knowledge.kg_manager import KGManager
|
from src.chat.knowledge.kg_manager import KGManager
|
||||||
|
|
||||||
kg_manager = KGManager()
|
kg_manager = KGManager()
|
||||||
kg_manager.load_from_file()
|
kg_manager.load_from_file()
|
||||||
return kg_manager
|
return kg_manager
|
||||||
@@ -58,31 +63,26 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
|||||||
"""将 DiGraph 转换为 JSON 格式"""
|
"""将 DiGraph 转换为 JSON 格式"""
|
||||||
if kg_manager is None or kg_manager.graph is None:
|
if kg_manager is None or kg_manager.graph is None:
|
||||||
return KnowledgeGraph(nodes=[], edges=[])
|
return KnowledgeGraph(nodes=[], edges=[])
|
||||||
|
|
||||||
graph = kg_manager.graph
|
graph = kg_manager.graph
|
||||||
nodes = []
|
nodes = []
|
||||||
edges = []
|
edges = []
|
||||||
|
|
||||||
# 转换节点
|
# 转换节点
|
||||||
node_list = graph.get_node_list()
|
node_list = graph.get_node_list()
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||||
content = node_data['content'] if 'content' in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||||
|
|
||||||
nodes.append(KnowledgeNode(
|
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||||
id=node_id,
|
|
||||||
type=node_type,
|
|
||||||
content=content,
|
|
||||||
create_time=create_time
|
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 转换边
|
# 转换边
|
||||||
edge_list = graph.get_edge_list()
|
edge_list = graph.get_edge_list()
|
||||||
for edge_tuple in edge_list:
|
for edge_tuple in edge_list:
|
||||||
@@ -91,37 +91,35 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
|||||||
source, target = edge_tuple[0], edge_tuple[1]
|
source, target = edge_tuple[0], edge_tuple[1]
|
||||||
# 通过 graph[source, target] 获取边的属性数据
|
# 通过 graph[source, target] 获取边的属性数据
|
||||||
edge_data = graph[source, target]
|
edge_data = graph[source, target]
|
||||||
|
|
||||||
# edge_data 支持 [] 操作符但不支持 .get()
|
# edge_data 支持 [] 操作符但不支持 .get()
|
||||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||||
|
|
||||||
edges.append(KnowledgeEdge(
|
edges.append(
|
||||||
source=source,
|
KnowledgeEdge(
|
||||||
target=target,
|
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||||
weight=weight,
|
)
|
||||||
create_time=create_time,
|
)
|
||||||
update_time=update_time
|
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return KnowledgeGraph(nodes=nodes, edges=edges)
|
return KnowledgeGraph(nodes=nodes, edges=edges)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/graph", response_model=KnowledgeGraph)
|
@router.get("/graph", response_model=KnowledgeGraph)
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph")
|
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||||
):
|
):
|
||||||
"""获取知识图谱(限制节点数量)
|
"""获取知识图谱(限制节点数量)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
limit: 返回的最大节点数,默认 100,最大 10000
|
limit: 返回的最大节点数,默认 100,最大 10000
|
||||||
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
|
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
|
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
|
||||||
"""
|
"""
|
||||||
@@ -130,46 +128,43 @@ async def get_knowledge_graph(
|
|||||||
if kg_manager is None:
|
if kg_manager is None:
|
||||||
logger.warning("KGManager 未初始化,返回空图谱")
|
logger.warning("KGManager 未初始化,返回空图谱")
|
||||||
return KnowledgeGraph(nodes=[], edges=[])
|
return KnowledgeGraph(nodes=[], edges=[])
|
||||||
|
|
||||||
graph = kg_manager.graph
|
graph = kg_manager.graph
|
||||||
all_node_list = graph.get_node_list()
|
all_node_list = graph.get_node_list()
|
||||||
|
|
||||||
# 按类型过滤节点
|
# 按类型过滤节点
|
||||||
if node_type == "entity":
|
if node_type == "entity":
|
||||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent']
|
all_node_list = [
|
||||||
|
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
|
||||||
|
]
|
||||||
elif node_type == "paragraph":
|
elif node_type == "paragraph":
|
||||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg']
|
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
|
||||||
|
|
||||||
# 限制节点数量
|
# 限制节点数量
|
||||||
total_nodes = len(all_node_list)
|
total_nodes = len(all_node_list)
|
||||||
if len(all_node_list) > limit:
|
if len(all_node_list) > limit:
|
||||||
node_list = all_node_list[:limit]
|
node_list = all_node_list[:limit]
|
||||||
else:
|
else:
|
||||||
node_list = all_node_list
|
node_list = all_node_list
|
||||||
|
|
||||||
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
|
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
|
||||||
|
|
||||||
# 转换节点
|
# 转换节点
|
||||||
nodes = []
|
nodes = []
|
||||||
node_ids = set()
|
node_ids = set()
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||||
content = node_data['content'] if 'content' in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||||
|
|
||||||
nodes.append(KnowledgeNode(
|
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||||
id=node_id,
|
|
||||||
type=node_type_val,
|
|
||||||
content=content,
|
|
||||||
create_time=create_time
|
|
||||||
))
|
|
||||||
node_ids.add(node_id)
|
node_ids.add(node_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 只获取涉及当前节点集的边(保证图的完整性)
|
# 只获取涉及当前节点集的边(保证图的完整性)
|
||||||
edges = []
|
edges = []
|
||||||
edge_list = graph.get_edge_list()
|
edge_list = graph.get_edge_list()
|
||||||
@@ -179,27 +174,25 @@ async def get_knowledge_graph(
|
|||||||
# 只包含两端都在当前节点集中的边
|
# 只包含两端都在当前节点集中的边
|
||||||
if source not in node_ids or target not in node_ids:
|
if source not in node_ids or target not in node_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
edge_data = graph[source, target]
|
edge_data = graph[source, target]
|
||||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||||
|
|
||||||
edges.append(KnowledgeEdge(
|
edges.append(
|
||||||
source=source,
|
KnowledgeEdge(
|
||||||
target=target,
|
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||||
weight=weight,
|
)
|
||||||
create_time=create_time,
|
)
|
||||||
update_time=update_time
|
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
|
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
|
||||||
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
|
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
|
||||||
return graph_data
|
return graph_data
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
|
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
|
||||||
return KnowledgeGraph(nodes=[], edges=[])
|
return KnowledgeGraph(nodes=[], edges=[])
|
||||||
@@ -208,71 +201,59 @@ async def get_knowledge_graph(
|
|||||||
@router.get("/stats", response_model=KnowledgeStats)
|
@router.get("/stats", response_model=KnowledgeStats)
|
||||||
async def get_knowledge_stats():
|
async def get_knowledge_stats():
|
||||||
"""获取知识库统计信息
|
"""获取知识库统计信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
KnowledgeStats: 统计信息
|
KnowledgeStats: 统计信息
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
kg_manager = _load_kg_manager()
|
kg_manager = _load_kg_manager()
|
||||||
if kg_manager is None or kg_manager.graph is None:
|
if kg_manager is None or kg_manager.graph is None:
|
||||||
return KnowledgeStats(
|
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||||
total_nodes=0,
|
|
||||||
total_edges=0,
|
|
||||||
entity_nodes=0,
|
|
||||||
paragraph_nodes=0,
|
|
||||||
avg_connections=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
graph = kg_manager.graph
|
graph = kg_manager.graph
|
||||||
node_list = graph.get_node_list()
|
node_list = graph.get_node_list()
|
||||||
edge_list = graph.get_edge_list()
|
edge_list = graph.get_edge_list()
|
||||||
|
|
||||||
total_nodes = len(node_list)
|
total_nodes = len(node_list)
|
||||||
total_edges = len(edge_list)
|
total_edges = len(edge_list)
|
||||||
|
|
||||||
# 统计节点类型
|
# 统计节点类型
|
||||||
entity_nodes = 0
|
entity_nodes = 0
|
||||||
paragraph_nodes = 0
|
paragraph_nodes = 0
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
node_type = node_data['type'] if 'type' in node_data else 'ent'
|
node_type = node_data["type"] if "type" in node_data else "ent"
|
||||||
if node_type == 'ent':
|
if node_type == "ent":
|
||||||
entity_nodes += 1
|
entity_nodes += 1
|
||||||
elif node_type == 'pg':
|
elif node_type == "pg":
|
||||||
paragraph_nodes += 1
|
paragraph_nodes += 1
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算平均连接数
|
# 计算平均连接数
|
||||||
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
|
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
|
||||||
|
|
||||||
return KnowledgeStats(
|
return KnowledgeStats(
|
||||||
total_nodes=total_nodes,
|
total_nodes=total_nodes,
|
||||||
total_edges=total_edges,
|
total_edges=total_edges,
|
||||||
entity_nodes=entity_nodes,
|
entity_nodes=entity_nodes,
|
||||||
paragraph_nodes=paragraph_nodes,
|
paragraph_nodes=paragraph_nodes,
|
||||||
avg_connections=round(avg_connections, 2)
|
avg_connections=round(avg_connections, 2),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
||||||
return KnowledgeStats(
|
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||||
total_nodes=0,
|
|
||||||
total_edges=0,
|
|
||||||
entity_nodes=0,
|
|
||||||
paragraph_nodes=0,
|
|
||||||
avg_connections=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/search", response_model=List[KnowledgeNode])
|
@router.get("/search", response_model=List[KnowledgeNode])
|
||||||
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||||
"""搜索知识节点
|
"""搜索知识节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 搜索关键词
|
query: 搜索关键词
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[KnowledgeNode]: 匹配的节点列表
|
List[KnowledgeNode]: 匹配的节点列表
|
||||||
"""
|
"""
|
||||||
@@ -280,33 +261,28 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
|||||||
kg_manager = _load_kg_manager()
|
kg_manager = _load_kg_manager()
|
||||||
if kg_manager is None or kg_manager.graph is None:
|
if kg_manager is None or kg_manager.graph is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
graph = kg_manager.graph
|
graph = kg_manager.graph
|
||||||
node_list = graph.get_node_list()
|
node_list = graph.get_node_list()
|
||||||
results = []
|
results = []
|
||||||
query_lower = query.lower()
|
query_lower = query.lower()
|
||||||
|
|
||||||
# 在节点内容中搜索
|
# 在节点内容中搜索
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
content = node_data['content'] if 'content' in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||||
|
|
||||||
if query_lower in content.lower() or query_lower in node_id.lower():
|
if query_lower in content.lower() or query_lower in node_id.lower():
|
||||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||||
results.append(KnowledgeNode(
|
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||||
id=node_id,
|
|
||||||
type=node_type,
|
|
||||||
content=content,
|
|
||||||
create_time=create_time
|
|
||||||
))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
|
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
|
||||||
return results[:50] # 限制返回数量
|
return results[:50] # 限制返回数量
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"搜索节点失败: {e}", exc_info=True)
|
logger.error(f"搜索节点失败: {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -43,25 +43,27 @@ def _normalize_url(url: str) -> str:
|
|||||||
def _parse_openai_response(data: dict) -> list[dict]:
|
def _parse_openai_response(data: dict) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
解析 OpenAI 格式的模型列表响应
|
解析 OpenAI 格式的模型列表响应
|
||||||
|
|
||||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||||
"""
|
"""
|
||||||
models = []
|
models = []
|
||||||
if "data" in data and isinstance(data["data"], list):
|
if "data" in data and isinstance(data["data"], list):
|
||||||
for model in data["data"]:
|
for model in data["data"]:
|
||||||
if isinstance(model, dict) and "id" in model:
|
if isinstance(model, dict) and "id" in model:
|
||||||
models.append({
|
models.append(
|
||||||
"id": model["id"],
|
{
|
||||||
"name": model.get("name") or model["id"],
|
"id": model["id"],
|
||||||
"owned_by": model.get("owned_by", ""),
|
"name": model.get("name") or model["id"],
|
||||||
})
|
"owned_by": model.get("owned_by", ""),
|
||||||
|
}
|
||||||
|
)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
解析 Gemini 格式的模型列表响应
|
解析 Gemini 格式的模型列表响应
|
||||||
|
|
||||||
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
|
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
|
||||||
"""
|
"""
|
||||||
models = []
|
models = []
|
||||||
@@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
|
|||||||
model_id = model["name"]
|
model_id = model["name"]
|
||||||
if model_id.startswith("models/"):
|
if model_id.startswith("models/"):
|
||||||
model_id = model_id[7:] # 去掉 "models/" 前缀
|
model_id = model_id[7:] # 去掉 "models/" 前缀
|
||||||
models.append({
|
models.append(
|
||||||
"id": model_id,
|
{
|
||||||
"name": model.get("displayName") or model_id,
|
"id": model_id,
|
||||||
"owned_by": "google",
|
"name": model.get("displayName") or model_id,
|
||||||
})
|
"owned_by": "google",
|
||||||
|
}
|
||||||
|
)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
@@ -89,55 +93,54 @@ async def _fetch_models_from_provider(
|
|||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
从提供商 API 获取模型列表
|
从提供商 API 获取模型列表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
base_url: 提供商的基础 URL
|
base_url: 提供商的基础 URL
|
||||||
api_key: API 密钥
|
api_key: API 密钥
|
||||||
endpoint: 获取模型列表的端点
|
endpoint: 获取模型列表的端点
|
||||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||||
client_type: 客户端类型 ('openai' | 'gemini')
|
client_type: 客户端类型 ('openai' | 'gemini')
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模型列表
|
模型列表
|
||||||
"""
|
"""
|
||||||
url = f"{_normalize_url(base_url)}{endpoint}"
|
url = f"{_normalize_url(base_url)}{endpoint}"
|
||||||
|
|
||||||
# 根据客户端类型设置请求头
|
# 根据客户端类型设置请求头
|
||||||
headers = {}
|
headers = {}
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
if client_type == "gemini":
|
if client_type == "gemini":
|
||||||
# Gemini 使用 URL 参数传递 API Key
|
# Gemini 使用 URL 参数传递 API Key
|
||||||
params["key"] = api_key
|
params["key"] = api_key
|
||||||
else:
|
else:
|
||||||
# OpenAI 兼容格式使用 Authorization 头
|
# OpenAI 兼容格式使用 Authorization 头
|
||||||
headers["Authorization"] = f"Bearer {api_key}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
response = await client.get(url, headers=headers, params=params)
|
response = await client.get(url, headers=headers, params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException as e:
|
||||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试")
|
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
||||||
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
||||||
if e.response.status_code == 401:
|
if e.response.status_code == 401:
|
||||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期")
|
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
|
||||||
elif e.response.status_code == 403:
|
elif e.response.status_code == 403:
|
||||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限")
|
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
|
||||||
elif e.response.status_code == 404:
|
elif e.response.status_code == 404:
|
||||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表")
|
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=502,
|
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||||
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
) from e
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取模型列表失败: {e}")
|
logger.error(f"获取模型列表失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
|
||||||
|
|
||||||
# 根据解析器类型解析响应
|
# 根据解析器类型解析响应
|
||||||
if parser == "openai":
|
if parser == "openai":
|
||||||
return _parse_openai_response(data)
|
return _parse_openai_response(data)
|
||||||
@@ -150,26 +153,26 @@ async def _fetch_models_from_provider(
|
|||||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
从 model_config.toml 获取指定提供商的配置
|
从 model_config.toml 获取指定提供商的配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
provider_name: 提供商名称
|
provider_name: 提供商名称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
提供商配置,如果未找到则返回 None
|
提供商配置,如果未找到则返回 None
|
||||||
"""
|
"""
|
||||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
config_data = tomlkit.load(f)
|
config_data = tomlkit.load(f)
|
||||||
|
|
||||||
providers = config_data.get("api_providers", [])
|
providers = config_data.get("api_providers", [])
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if provider.get("name") == provider_name:
|
if provider.get("name") == provider_name:
|
||||||
return dict(provider)
|
return dict(provider)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取提供商配置失败: {e}")
|
logger.error(f"读取提供商配置失败: {e}")
|
||||||
@@ -184,23 +187,23 @@ async def get_provider_models(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定提供商的可用模型列表
|
获取指定提供商的可用模型列表
|
||||||
|
|
||||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||||
"""
|
"""
|
||||||
# 获取提供商配置
|
# 获取提供商配置
|
||||||
provider_config = _get_provider_config(provider_name)
|
provider_config = _get_provider_config(provider_name)
|
||||||
if not provider_config:
|
if not provider_config:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||||
|
|
||||||
base_url = provider_config.get("base_url")
|
base_url = provider_config.get("base_url")
|
||||||
api_key = provider_config.get("api_key")
|
api_key = provider_config.get("api_key")
|
||||||
client_type = provider_config.get("client_type", "openai")
|
client_type = provider_config.get("client_type", "openai")
|
||||||
|
|
||||||
if not base_url:
|
if not base_url:
|
||||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||||
|
|
||||||
# 获取模型列表
|
# 获取模型列表
|
||||||
models = await _fetch_models_from_provider(
|
models = await _fetch_models_from_provider(
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
@@ -209,7 +212,7 @@ async def get_provider_models(
|
|||||||
parser=parser,
|
parser=parser,
|
||||||
client_type=client_type,
|
client_type=client_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"models": models,
|
"models": models,
|
||||||
@@ -236,9 +239,132 @@ async def get_models_by_url(
|
|||||||
parser=parser,
|
parser=parser,
|
||||||
client_type=client_type,
|
client_type=client_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"models": models,
|
"models": models,
|
||||||
"count": len(models),
|
"count": len(models),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/test-connection")
|
||||||
|
async def test_provider_connection(
|
||||||
|
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||||
|
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
测试提供商连接状态
|
||||||
|
|
||||||
|
分两步测试:
|
||||||
|
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
|
||||||
|
2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- network_ok: 网络是否连通
|
||||||
|
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
|
||||||
|
- latency_ms: 响应延迟(毫秒)
|
||||||
|
- error: 错误信息(如果有)
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
base_url = _normalize_url(base_url)
|
||||||
|
if not base_url:
|
||||||
|
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"network_ok": False,
|
||||||
|
"api_key_valid": None,
|
||||||
|
"latency_ms": None,
|
||||||
|
"error": None,
|
||||||
|
"http_status": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 第一步:测试网络连通性
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
|
||||||
|
# 尝试 GET 请求 base_url(不需要 API Key)
|
||||||
|
response = await client.get(base_url)
|
||||||
|
latency = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
result["network_ok"] = True
|
||||||
|
result["latency_ms"] = round(latency, 2)
|
||||||
|
result["http_status"] = response.status_code
|
||||||
|
|
||||||
|
except httpx.ConnectError as e:
|
||||||
|
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
|
||||||
|
return result
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
result["error"] = "连接超时:服务器响应时间过长"
|
||||||
|
return result
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
result["error"] = f"请求错误:{str(e)}"
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = f"未知错误:{str(e)}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 第二步:如果提供了 API Key,验证其有效性
|
||||||
|
if api_key:
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
# 尝试获取模型列表
|
||||||
|
models_url = f"{base_url}/models"
|
||||||
|
response = await client.get(models_url, headers=headers)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
result["api_key_valid"] = True
|
||||||
|
elif response.status_code in (401, 403):
|
||||||
|
result["api_key_valid"] = False
|
||||||
|
result["error"] = "API Key 无效或已过期"
|
||||||
|
else:
|
||||||
|
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
|
||||||
|
result["api_key_valid"] = None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# API Key 验证失败不影响网络连通性结果
|
||||||
|
logger.warning(f"API Key 验证失败: {e}")
|
||||||
|
result["api_key_valid"] = None
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/test-connection-by-name")
|
||||||
|
async def test_provider_connection_by_name(
|
||||||
|
provider_name: str = Query(..., description="提供商名称"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
通过提供商名称测试连接(从配置文件读取信息)
|
||||||
|
"""
|
||||||
|
# 读取配置文件
|
||||||
|
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
|
if not os.path.exists(model_config_path):
|
||||||
|
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||||
|
|
||||||
|
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = tomlkit.load(f)
|
||||||
|
|
||||||
|
# 查找提供商
|
||||||
|
providers = config.get("api_providers", [])
|
||||||
|
provider = None
|
||||||
|
for p in providers:
|
||||||
|
if p.get("name") == provider_name:
|
||||||
|
provider = p
|
||||||
|
break
|
||||||
|
|
||||||
|
if not provider:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||||
|
|
||||||
|
base_url = provider.get("base_url", "")
|
||||||
|
api_key = provider.get("api_key", "")
|
||||||
|
|
||||||
|
if not base_url:
|
||||||
|
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||||
|
|
||||||
|
# 调用测试接口
|
||||||
|
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Optional, List, Dict, Any
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
import json
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.toml_utils import save_toml_with_format
|
||||||
from src.config.config import MMC_VERSION
|
from src.config.config import MMC_VERSION
|
||||||
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
|
||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
@@ -29,8 +30,11 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
|||||||
Returns:
|
Returns:
|
||||||
(major, minor, patch) 三元组
|
(major, minor, patch) 三元组
|
||||||
"""
|
"""
|
||||||
# 移除 snapshot 等后缀
|
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
|
||||||
base_version = version_str.split(".snapshot")[0].split(".dev")[0].split(".alpha")[0].split(".beta")[0]
|
import re
|
||||||
|
|
||||||
|
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
|
||||||
|
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||||
|
|
||||||
parts = base_version.split(".")
|
parts = base_version.split(".")
|
||||||
if len(parts) < 3:
|
if len(parts) < 3:
|
||||||
@@ -611,7 +615,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
|
|||||||
for field in required_fields:
|
for field in required_fields:
|
||||||
if field not in manifest:
|
if field not in manifest:
|
||||||
raise ValueError(f"缺少必需字段: {field}")
|
raise ValueError(f"缺少必需字段: {field}")
|
||||||
|
|
||||||
# 将插件 ID 写入 manifest(用于后续准确识别)
|
# 将插件 ID 写入 manifest(用于后续准确识别)
|
||||||
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
# 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
|
||||||
manifest["id"] = request.plugin_id
|
manifest["id"] = request.plugin_id
|
||||||
@@ -703,7 +707,7 @@ async def uninstall_plugin(
|
|||||||
plugin_path = plugins_dir / folder_name
|
plugin_path = plugins_dir / folder_name
|
||||||
# 旧格式:点
|
# 旧格式:点
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = plugins_dir / request.plugin_id
|
||||||
|
|
||||||
# 优先使用新格式,如果不存在则尝试旧格式
|
# 优先使用新格式,如果不存在则尝试旧格式
|
||||||
if not plugin_path.exists():
|
if not plugin_path.exists():
|
||||||
if old_format_path.exists():
|
if old_format_path.exists():
|
||||||
@@ -837,7 +841,7 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
|
|||||||
plugin_path = plugins_dir / folder_name
|
plugin_path = plugins_dir / folder_name
|
||||||
# 旧格式:点
|
# 旧格式:点
|
||||||
old_format_path = plugins_dir / request.plugin_id
|
old_format_path = plugins_dir / request.plugin_id
|
||||||
|
|
||||||
# 优先使用新格式,如果不存在则尝试旧格式
|
# 优先使用新格式,如果不存在则尝试旧格式
|
||||||
if not plugin_path.exists():
|
if not plugin_path.exists():
|
||||||
if old_format_path.exists():
|
if old_format_path.exists():
|
||||||
@@ -1090,21 +1094,21 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
|||||||
# 尝试从 author.name 和 repository_url 构建标准 ID
|
# 尝试从 author.name 和 repository_url 构建标准 ID
|
||||||
author_name = None
|
author_name = None
|
||||||
repo_name = None
|
repo_name = None
|
||||||
|
|
||||||
# 获取作者名
|
# 获取作者名
|
||||||
if "author" in manifest:
|
if "author" in manifest:
|
||||||
if isinstance(manifest["author"], dict) and "name" in manifest["author"]:
|
if isinstance(manifest["author"], dict) and "name" in manifest["author"]:
|
||||||
author_name = manifest["author"]["name"]
|
author_name = manifest["author"]["name"]
|
||||||
elif isinstance(manifest["author"], str):
|
elif isinstance(manifest["author"], str):
|
||||||
author_name = manifest["author"]
|
author_name = manifest["author"]
|
||||||
|
|
||||||
# 从 repository_url 获取仓库名
|
# 从 repository_url 获取仓库名
|
||||||
if "repository_url" in manifest:
|
if "repository_url" in manifest:
|
||||||
repo_url = manifest["repository_url"].rstrip("/")
|
repo_url = manifest["repository_url"].rstrip("/")
|
||||||
if repo_url.endswith(".git"):
|
if repo_url.endswith(".git"):
|
||||||
repo_url = repo_url[:-4]
|
repo_url = repo_url[:-4]
|
||||||
repo_name = repo_url.split("/")[-1]
|
repo_name = repo_url.split("/")[-1]
|
||||||
|
|
||||||
# 构建 ID
|
# 构建 ID
|
||||||
if author_name and repo_name:
|
if author_name and repo_name:
|
||||||
# 标准格式: Author.RepoName
|
# 标准格式: Author.RepoName
|
||||||
@@ -1120,7 +1124,7 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
|||||||
else:
|
else:
|
||||||
# 直接使用文件夹名
|
# 直接使用文件夹名
|
||||||
plugin_id = folder_name
|
plugin_id = folder_name
|
||||||
|
|
||||||
# 将推断的 ID 写入 manifest(方便下次识别)
|
# 将推断的 ID 写入 manifest(方便下次识别)
|
||||||
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
|
logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
|
||||||
manifest["id"] = plugin_id
|
manifest["id"] = plugin_id
|
||||||
@@ -1153,3 +1157,408 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
logger.error(f"获取已安装插件列表失败: {e}", exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
# ============ 插件配置管理 API ============
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatePluginConfigRequest(BaseModel):
|
||||||
|
"""更新插件配置请求"""
|
||||||
|
|
||||||
|
config: Dict[str, Any] = Field(..., description="配置数据")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config/{plugin_id}/schema")
|
||||||
|
async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取插件配置 Schema
|
||||||
|
|
||||||
|
返回插件的完整配置 schema,包含所有 section、字段定义和布局信息。
|
||||||
|
用于前端动态生成配置表单。
|
||||||
|
"""
|
||||||
|
# Token 验证
|
||||||
|
token = authorization.replace("Bearer ", "") if authorization else None
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token or not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
|
logger.info(f"获取插件配置 Schema: {plugin_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 尝试从已加载的插件中获取
|
||||||
|
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||||
|
|
||||||
|
# 查找插件实例
|
||||||
|
plugin_instance = None
|
||||||
|
|
||||||
|
# 遍历所有已加载的插件
|
||||||
|
for loaded_plugin_name in plugin_manager.list_loaded_plugins():
|
||||||
|
instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
|
||||||
|
if instance:
|
||||||
|
# 匹配 plugin_name 或 manifest 中的 id
|
||||||
|
if instance.plugin_name == plugin_id:
|
||||||
|
plugin_instance = instance
|
||||||
|
break
|
||||||
|
# 也尝试匹配 manifest 中的 id
|
||||||
|
manifest_id = instance.get_manifest_info("id", "")
|
||||||
|
if manifest_id == plugin_id:
|
||||||
|
plugin_instance = instance
|
||||||
|
break
|
||||||
|
|
||||||
|
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||||
|
# 从插件实例获取 schema
|
||||||
|
schema = plugin_instance.get_webui_config_schema()
|
||||||
|
return {"success": True, "schema": schema}
|
||||||
|
|
||||||
|
# 如果插件未加载,尝试从文件系统读取
|
||||||
|
# 查找插件目录
|
||||||
|
plugins_dir = Path("plugins")
|
||||||
|
plugin_path = None
|
||||||
|
|
||||||
|
for p in plugins_dir.iterdir():
|
||||||
|
if p.is_dir():
|
||||||
|
manifest_path = p / "_manifest.json"
|
||||||
|
if manifest_path.exists():
|
||||||
|
try:
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
|
manifest = json.load(f)
|
||||||
|
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||||
|
plugin_path = p
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not plugin_path:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||||
|
|
||||||
|
# 读取配置文件获取当前配置
|
||||||
|
config_path = plugin_path / "config.toml"
|
||||||
|
current_config = {}
|
||||||
|
if config_path.exists():
|
||||||
|
import tomlkit
|
||||||
|
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
current_config = tomlkit.load(f)
|
||||||
|
|
||||||
|
# 构建基础 schema(无法获取完整的 ConfigField 信息)
|
||||||
|
schema = {
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"plugin_info": {
|
||||||
|
"name": plugin_id,
|
||||||
|
"version": "",
|
||||||
|
"description": "",
|
||||||
|
"author": "",
|
||||||
|
},
|
||||||
|
"sections": {},
|
||||||
|
"layout": {"type": "auto", "tabs": []},
|
||||||
|
"_note": "插件未加载,仅返回当前配置结构",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 从当前配置推断 schema
|
||||||
|
for section_name, section_data in current_config.items():
|
||||||
|
if isinstance(section_data, dict):
|
||||||
|
schema["sections"][section_name] = {
|
||||||
|
"name": section_name,
|
||||||
|
"title": section_name,
|
||||||
|
"description": None,
|
||||||
|
"icon": None,
|
||||||
|
"collapsed": False,
|
||||||
|
"order": 0,
|
||||||
|
"fields": {},
|
||||||
|
}
|
||||||
|
for field_name, field_value in section_data.items():
|
||||||
|
# 推断字段类型
|
||||||
|
field_type = type(field_value).__name__
|
||||||
|
ui_type = "text"
|
||||||
|
if isinstance(field_value, bool):
|
||||||
|
ui_type = "switch"
|
||||||
|
elif isinstance(field_value, (int, float)):
|
||||||
|
ui_type = "number"
|
||||||
|
elif isinstance(field_value, list):
|
||||||
|
ui_type = "list"
|
||||||
|
elif isinstance(field_value, dict):
|
||||||
|
ui_type = "json"
|
||||||
|
|
||||||
|
schema["sections"][section_name]["fields"][field_name] = {
|
||||||
|
"name": field_name,
|
||||||
|
"type": field_type,
|
||||||
|
"default": field_value,
|
||||||
|
"description": field_name,
|
||||||
|
"label": field_name,
|
||||||
|
"ui_type": ui_type,
|
||||||
|
"required": False,
|
||||||
|
"hidden": False,
|
||||||
|
"disabled": False,
|
||||||
|
"order": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"success": True, "schema": schema}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取插件配置 Schema 失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config/{plugin_id}")
|
||||||
|
async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取插件当前配置值
|
||||||
|
|
||||||
|
返回插件的当前配置值。
|
||||||
|
"""
|
||||||
|
# Token 验证
|
||||||
|
token = authorization.replace("Bearer ", "") if authorization else None
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token or not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
|
logger.info(f"获取插件配置: {plugin_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 查找插件目录
|
||||||
|
plugins_dir = Path("plugins")
|
||||||
|
plugin_path = None
|
||||||
|
|
||||||
|
for p in plugins_dir.iterdir():
|
||||||
|
if p.is_dir():
|
||||||
|
manifest_path = p / "_manifest.json"
|
||||||
|
if manifest_path.exists():
|
||||||
|
try:
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
|
manifest = json.load(f)
|
||||||
|
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||||
|
plugin_path = p
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not plugin_path:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||||
|
|
||||||
|
# 读取配置文件
|
||||||
|
config_path = plugin_path / "config.toml"
|
||||||
|
if not config_path.exists():
|
||||||
|
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||||
|
|
||||||
|
import tomlkit
|
||||||
|
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = tomlkit.load(f)
|
||||||
|
|
||||||
|
return {"success": True, "config": dict(config)}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取插件配置失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/config/{plugin_id}")
|
||||||
|
async def update_plugin_config(
|
||||||
|
plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
更新插件配置
|
||||||
|
|
||||||
|
保存新的配置值到插件的配置文件。
|
||||||
|
"""
|
||||||
|
# Token 验证
|
||||||
|
token = authorization.replace("Bearer ", "") if authorization else None
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token or not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
|
logger.info(f"更新插件配置: {plugin_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 查找插件目录
|
||||||
|
plugins_dir = Path("plugins")
|
||||||
|
plugin_path = None
|
||||||
|
|
||||||
|
for p in plugins_dir.iterdir():
|
||||||
|
if p.is_dir():
|
||||||
|
manifest_path = p / "_manifest.json"
|
||||||
|
if manifest_path.exists():
|
||||||
|
try:
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
|
manifest = json.load(f)
|
||||||
|
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||||
|
plugin_path = p
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not plugin_path:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||||
|
|
||||||
|
config_path = plugin_path / "config.toml"
|
||||||
|
|
||||||
|
# 备份旧配置
|
||||||
|
import shutil
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
if config_path.exists():
|
||||||
|
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
|
backup_path = plugin_path / backup_name
|
||||||
|
shutil.copy(config_path, backup_path)
|
||||||
|
logger.info(f"已备份配置文件: {backup_path}")
|
||||||
|
|
||||||
|
# 写入新配置(使用 tomlkit 保留注释)
|
||||||
|
import tomlkit
|
||||||
|
|
||||||
|
# 先读取原配置以保留注释和格式
|
||||||
|
existing_doc = tomlkit.document()
|
||||||
|
if config_path.exists():
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
existing_doc = tomlkit.load(f)
|
||||||
|
# 更新值
|
||||||
|
for key, value in request.config.items():
|
||||||
|
existing_doc[key] = value
|
||||||
|
save_toml_with_format(existing_doc, str(config_path))
|
||||||
|
|
||||||
|
logger.info(f"已更新插件配置: {plugin_id}")
|
||||||
|
|
||||||
|
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新插件配置失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/config/{plugin_id}/reset")
|
||||||
|
async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
重置插件配置为默认值
|
||||||
|
|
||||||
|
删除当前配置文件,下次加载插件时将使用默认配置。
|
||||||
|
"""
|
||||||
|
# Token 验证
|
||||||
|
token = authorization.replace("Bearer ", "") if authorization else None
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token or not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
|
logger.info(f"重置插件配置: {plugin_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 查找插件目录
|
||||||
|
plugins_dir = Path("plugins")
|
||||||
|
plugin_path = None
|
||||||
|
|
||||||
|
for p in plugins_dir.iterdir():
|
||||||
|
if p.is_dir():
|
||||||
|
manifest_path = p / "_manifest.json"
|
||||||
|
if manifest_path.exists():
|
||||||
|
try:
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
|
manifest = json.load(f)
|
||||||
|
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||||
|
plugin_path = p
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not plugin_path:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||||
|
|
||||||
|
config_path = plugin_path / "config.toml"
|
||||||
|
|
||||||
|
if not config_path.exists():
|
||||||
|
return {"success": True, "message": "配置文件不存在,无需重置"}
|
||||||
|
|
||||||
|
# 备份并删除
|
||||||
|
import shutil
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
|
backup_path = plugin_path / backup_name
|
||||||
|
shutil.move(config_path, backup_path)
|
||||||
|
|
||||||
|
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
|
||||||
|
|
||||||
|
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重置插件配置失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/config/{plugin_id}/toggle")
|
||||||
|
async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
切换插件启用状态
|
||||||
|
|
||||||
|
切换插件配置中的 enabled 字段。
|
||||||
|
"""
|
||||||
|
# Token 验证
|
||||||
|
token = authorization.replace("Bearer ", "") if authorization else None
|
||||||
|
token_manager = get_token_manager()
|
||||||
|
if not token or not token_manager.verify_token(token):
|
||||||
|
raise HTTPException(status_code=401, detail="未授权:无效的访问令牌")
|
||||||
|
|
||||||
|
logger.info(f"切换插件状态: {plugin_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 查找插件目录
|
||||||
|
plugins_dir = Path("plugins")
|
||||||
|
plugin_path = None
|
||||||
|
|
||||||
|
for p in plugins_dir.iterdir():
|
||||||
|
if p.is_dir():
|
||||||
|
manifest_path = p / "_manifest.json"
|
||||||
|
if manifest_path.exists():
|
||||||
|
try:
|
||||||
|
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||||
|
manifest = json.load(f)
|
||||||
|
if manifest.get("id") == plugin_id or p.name == plugin_id:
|
||||||
|
plugin_path = p
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not plugin_path:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
|
||||||
|
|
||||||
|
config_path = plugin_path / "config.toml"
|
||||||
|
|
||||||
|
import tomlkit
|
||||||
|
|
||||||
|
# 读取当前配置(保留注释和格式)
|
||||||
|
config = tomlkit.document()
|
||||||
|
if config_path.exists():
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = tomlkit.load(f)
|
||||||
|
|
||||||
|
# 切换 enabled 状态
|
||||||
|
if "plugin" not in config:
|
||||||
|
config["plugin"] = tomlkit.table()
|
||||||
|
|
||||||
|
current_enabled = config["plugin"].get("enabled", True)
|
||||||
|
new_enabled = not current_enabled
|
||||||
|
config["plugin"]["enabled"] = new_enabled
|
||||||
|
|
||||||
|
# 写入配置(保留注释,格式化数组)
|
||||||
|
save_toml_with_format(config, str(config_path))
|
||||||
|
|
||||||
|
status = "启用" if new_enabled else "禁用"
|
||||||
|
logger.info(f"已{status}插件: {plugin_id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"enabled": new_enabled,
|
||||||
|
"message": f"插件已{status}",
|
||||||
|
"note": "状态更改将在下次加载插件时生效",
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"切换插件状态失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"服务器错误: {str(e)}") from e
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ async def restart_maibot():
|
|||||||
注意:此操作会使麦麦暂时离线。
|
注意:此操作会使麦麦暂时离线。
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 记录重启操作
|
# 记录重启操作
|
||||||
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
print(f"[{datetime.now()}] WebUI 触发重启操作")
|
||||||
@@ -54,7 +54,7 @@ async def restart_maibot():
|
|||||||
python = sys.executable
|
python = sys.executable
|
||||||
args = [python] + sys.argv
|
args = [python] + sys.argv
|
||||||
os.execv(python, args)
|
os.execv(python, args)
|
||||||
|
|
||||||
# 创建后台任务执行重启
|
# 创建后台任务执行重启
|
||||||
asyncio.create_task(delayed_restart())
|
asyncio.create_task(delayed_restart())
|
||||||
|
|
||||||
|
|||||||
@@ -20,10 +20,10 @@ class WebUIServer:
|
|||||||
self.port = port
|
self.port = port
|
||||||
self.app = FastAPI(title="MaiBot WebUI")
|
self.app = FastAPI(title="MaiBot WebUI")
|
||||||
self._server = None
|
self._server = None
|
||||||
|
|
||||||
# 显示 Access Token
|
# 显示 Access Token
|
||||||
self._show_access_token()
|
self._show_access_token()
|
||||||
|
|
||||||
# 重要:先注册 API 路由,再设置静态文件
|
# 重要:先注册 API 路由,再设置静态文件
|
||||||
self._register_api_routes()
|
self._register_api_routes()
|
||||||
self._setup_static_files()
|
self._setup_static_files()
|
||||||
@@ -32,7 +32,7 @@ class WebUIServer:
|
|||||||
"""显示 WebUI Access Token"""
|
"""显示 WebUI Access Token"""
|
||||||
try:
|
try:
|
||||||
from src.webui.token_manager import get_token_manager
|
from src.webui.token_manager import get_token_manager
|
||||||
|
|
||||||
token_manager = get_token_manager()
|
token_manager = get_token_manager()
|
||||||
current_token = token_manager.get_token()
|
current_token = token_manager.get_token()
|
||||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||||
@@ -69,7 +69,7 @@ class WebUIServer:
|
|||||||
# 如果是根路径,直接返回 index.html
|
# 如果是根路径,直接返回 index.html
|
||||||
if not full_path or full_path == "/":
|
if not full_path or full_path == "/":
|
||||||
return FileResponse(static_path / "index.html", media_type="text/html")
|
return FileResponse(static_path / "index.html", media_type="text/html")
|
||||||
|
|
||||||
# 检查是否是静态文件
|
# 检查是否是静态文件
|
||||||
file_path = static_path / full_path
|
file_path = static_path / full_path
|
||||||
if file_path.is_file() and file_path.exists():
|
if file_path.is_file() and file_path.exists():
|
||||||
@@ -88,15 +88,22 @@ class WebUIServer:
|
|||||||
# 导入所有 WebUI 路由
|
# 导入所有 WebUI 路由
|
||||||
from src.webui.routes import router as webui_router
|
from src.webui.routes import router as webui_router
|
||||||
from src.webui.logs_ws import router as logs_router
|
from src.webui.logs_ws import router as logs_router
|
||||||
|
|
||||||
logger.info("开始导入 knowledge_routes...")
|
logger.info("开始导入 knowledge_routes...")
|
||||||
from src.webui.knowledge_routes import router as knowledge_router
|
from src.webui.knowledge_routes import router as knowledge_router
|
||||||
|
|
||||||
logger.info("knowledge_routes 导入成功")
|
logger.info("knowledge_routes 导入成功")
|
||||||
|
|
||||||
|
# 导入本地聊天室路由
|
||||||
|
from src.webui.chat_routes import router as chat_router
|
||||||
|
|
||||||
|
logger.info("chat_routes 导入成功")
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
self.app.include_router(webui_router)
|
self.app.include_router(webui_router)
|
||||||
self.app.include_router(logs_router)
|
self.app.include_router(logs_router)
|
||||||
self.app.include_router(knowledge_router)
|
self.app.include_router(knowledge_router)
|
||||||
|
self.app.include_router(chat_router)
|
||||||
logger.info(f"knowledge_router 路由前缀: {knowledge_router.prefix}")
|
logger.info(f"knowledge_router 路由前缀: {knowledge_router.prefix}")
|
||||||
|
|
||||||
logger.info("✅ WebUI API 路由已注册")
|
logger.info("✅ WebUI API 路由已注册")
|
||||||
@@ -116,6 +123,8 @@ class WebUIServer:
|
|||||||
|
|
||||||
logger.info("🌐 WebUI 服务器启动中...")
|
logger.info("🌐 WebUI 服务器启动中...")
|
||||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||||
|
if self.host == "0.0.0.0":
|
||||||
|
logger.info(f"本机访问请使用 http://localhost:{self.port}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._server.serve()
|
await self._server.serve()
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ api_provider = "SiliconFlow"
|
|||||||
price_in = 2.0
|
price_in = 2.0
|
||||||
price_out = 3.0
|
price_out = 3.0
|
||||||
[models.extra_params] # 可选的额外参数配置
|
[models.extra_params] # 可选的额外参数配置
|
||||||
enable_thinking = true # 不启用思考
|
enable_thinking = true # 启用思考
|
||||||
|
|
||||||
[[models]]
|
[[models]]
|
||||||
model_identifier = "Qwen/Qwen3-Next-80B-A3B-Instruct"
|
model_identifier = "Qwen/Qwen3-Next-80B-A3B-Instruct"
|
||||||
@@ -89,8 +89,7 @@ api_provider = "SiliconFlow"
|
|||||||
price_in = 3.5
|
price_in = 3.5
|
||||||
price_out = 14.0
|
price_out = 14.0
|
||||||
[models.extra_params] # 可选的额外参数配置
|
[models.extra_params] # 可选的额外参数配置
|
||||||
enable_thinking = true # 不启用思考
|
enable_thinking = true # 启用思考
|
||||||
|
|
||||||
|
|
||||||
[[models]]
|
[[models]]
|
||||||
model_identifier = "deepseek-ai/DeepSeek-R1"
|
model_identifier = "deepseek-ai/DeepSeek-R1"
|
||||||
|
|||||||
@@ -8,23 +8,23 @@ if edges:
|
|||||||
e = edges[0]
|
e = edges[0]
|
||||||
print(f"Edge tuple: {e}")
|
print(f"Edge tuple: {e}")
|
||||||
print(f"Edge tuple type: {type(e)}")
|
print(f"Edge tuple type: {type(e)}")
|
||||||
|
|
||||||
edge_data = kg.graph[e[0], e[1]]
|
edge_data = kg.graph[e[0], e[1]]
|
||||||
print(f"\nEdge data type: {type(edge_data)}")
|
print(f"\nEdge data type: {type(edge_data)}")
|
||||||
print(f"Edge data: {edge_data}")
|
print(f"Edge data: {edge_data}")
|
||||||
print(f"Has 'get' method: {hasattr(edge_data, 'get')}")
|
print(f"Has 'get' method: {hasattr(edge_data, 'get')}")
|
||||||
print(f"Is dict: {isinstance(edge_data, dict)}")
|
print(f"Is dict: {isinstance(edge_data, dict)}")
|
||||||
|
|
||||||
# 尝试不同的访问方式
|
# 尝试不同的访问方式
|
||||||
try:
|
try:
|
||||||
print(f"\nUsing []: {edge_data['weight']}")
|
print(f"\nUsing []: {edge_data['weight']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Using [] failed: {e}")
|
print(f"Using [] failed: {e}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(f"Using .get(): {edge_data.get('weight')}")
|
print(f"Using .get(): {edge_data.get('weight')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Using .get() failed: {e}")
|
print(f"Using .get() failed: {e}")
|
||||||
|
|
||||||
# 查看所有属性
|
# 查看所有属性
|
||||||
print(f"\nDir: {[x for x in dir(edge_data) if not x.startswith('_')]}")
|
print(f"\nDir: {[x for x in dir(edge_data) if not x.startswith('_')]}")
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
16
webui/dist/assets/codemirror-BHeANvwm.js
vendored
Normal file
16
webui/dist/assets/codemirror-BHeANvwm.js
vendored
Normal file
File diff suppressed because one or more lines are too long
5
webui/dist/assets/dnd-Dyi3CnuX.js
vendored
Normal file
5
webui/dist/assets/dnd-Dyi3CnuX.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
webui/dist/assets/icons-Bw5y5Hqz.js
vendored
Normal file
1
webui/dist/assets/icons-Bw5y5Hqz.js
vendored
Normal file
File diff suppressed because one or more lines are too long
1
webui/dist/assets/icons-DMlhlQyz.js
vendored
1
webui/dist/assets/icons-DMlhlQyz.js
vendored
File diff suppressed because one or more lines are too long
407
webui/dist/assets/index--0Z4-njD.js
vendored
407
webui/dist/assets/index--0Z4-njD.js
vendored
File diff suppressed because one or more lines are too long
1
webui/dist/assets/index-Bzl8QBn9.css
vendored
1
webui/dist/assets/index-Bzl8QBn9.css
vendored
File diff suppressed because one or more lines are too long
1
webui/dist/assets/index-CUrrfy9B.css
vendored
Normal file
1
webui/dist/assets/index-CUrrfy9B.css
vendored
Normal file
File diff suppressed because one or more lines are too long
52
webui/dist/assets/index-DuV8F13p.js
vendored
Normal file
52
webui/dist/assets/index-DuV8F13p.js
vendored
Normal file
File diff suppressed because one or more lines are too long
295
webui/dist/assets/markdown-A1ShuLvG.js
vendored
Normal file
295
webui/dist/assets/markdown-A1ShuLvG.js
vendored
Normal file
File diff suppressed because one or more lines are too long
27
webui/dist/assets/misc-Ii-X5qWA.js
vendored
Normal file
27
webui/dist/assets/misc-Ii-X5qWA.js
vendored
Normal file
File diff suppressed because one or more lines are too long
45
webui/dist/assets/radix-core-BlBHu_Lw.js
vendored
Normal file
45
webui/dist/assets/radix-core-BlBHu_Lw.js
vendored
Normal file
File diff suppressed because one or more lines are too long
12
webui/dist/assets/radix-extra-Cw1azsjZ.js
vendored
Normal file
12
webui/dist/assets/radix-extra-Cw1azsjZ.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
webui/dist/assets/reactflow-B3n3_Vkw.js
vendored
Normal file
2
webui/dist/assets/reactflow-B3n3_Vkw.js
vendored
Normal file
File diff suppressed because one or more lines are too long
5
webui/dist/assets/router-CWhjJi2n.js
vendored
Normal file
5
webui/dist/assets/router-CWhjJi2n.js
vendored
Normal file
File diff suppressed because one or more lines are too long
2
webui/dist/assets/router-SinpzM5S.js
vendored
2
webui/dist/assets/router-SinpzM5S.js
vendored
File diff suppressed because one or more lines are too long
45
webui/dist/assets/ui-vendor-BLBhIcJ8.js
vendored
45
webui/dist/assets/ui-vendor-BLBhIcJ8.js
vendored
File diff suppressed because one or more lines are too long
11
webui/dist/assets/uppy-DSH7n_-V.js
vendored
Normal file
11
webui/dist/assets/uppy-DSH7n_-V.js
vendored
Normal file
File diff suppressed because one or more lines are too long
6
webui/dist/assets/utils-CCeOswSm.js
vendored
Normal file
6
webui/dist/assets/utils-CCeOswSm.js
vendored
Normal file
File diff suppressed because one or more lines are too long
20
webui/dist/index.html
vendored
20
webui/dist/index.html
vendored
@@ -7,13 +7,21 @@
|
|||||||
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
|
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>MaiBot Dashboard</title>
|
<title>MaiBot Dashboard</title>
|
||||||
<script type="module" crossorigin src="/assets/index--0Z4-njD.js"></script>
|
<script type="module" crossorigin src="/assets/index-DuV8F13p.js"></script>
|
||||||
<link rel="modulepreload" crossorigin href="/assets/react-vendor-Dtc2IqVY.js">
|
<link rel="modulepreload" crossorigin href="/assets/react-vendor-Dtc2IqVY.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/router-SinpzM5S.js">
|
<link rel="modulepreload" crossorigin href="/assets/router-CWhjJi2n.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/charts-0z-hIQr-.js">
|
<link rel="modulepreload" crossorigin href="/assets/utils-CCeOswSm.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/ui-vendor-BLBhIcJ8.js">
|
<link rel="modulepreload" crossorigin href="/assets/radix-core-BlBHu_Lw.js">
|
||||||
<link rel="modulepreload" crossorigin href="/assets/icons-DMlhlQyz.js">
|
<link rel="modulepreload" crossorigin href="/assets/radix-extra-Cw1azsjZ.js">
|
||||||
<link rel="stylesheet" crossorigin href="/assets/index-Bzl8QBn9.css">
|
<link rel="modulepreload" crossorigin href="/assets/charts-Dhri-zxi.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/icons-Bw5y5Hqz.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/codemirror-BHeANvwm.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/misc-Ii-X5qWA.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/dnd-Dyi3CnuX.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/uppy-DSH7n_-V.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/markdown-A1ShuLvG.js">
|
||||||
|
<link rel="modulepreload" crossorigin href="/assets/reactflow-B3n3_Vkw.js">
|
||||||
|
<link rel="stylesheet" crossorigin href="/assets/index-CUrrfy9B.css">
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div id="root" class="notranslate"></div>
|
<div id="root" class="notranslate"></div>
|
||||||
|
|||||||
Reference in New Issue
Block a user