This commit is contained in:
tcmofashi
2026-01-16 07:34:11 +00:00
281 changed files with 51502 additions and 3157 deletions

View File

@@ -28,7 +28,7 @@ class ExpressionReflector:
try:
logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
if not global_config.expression.expression_self_reflect:
if not global_config.expression.expression_manual_reflect:
logger.debug("[Expression Reflection] 表达反思功能未启用,跳过")
return False

View File

@@ -2,7 +2,7 @@ import json
import asyncio
import random
from collections import OrderedDict
from typing import List, Dict, Optional, Any, Callable
from typing import List, Dict, Optional, Callable
from json_repair import repair_json
from peewee import fn
@@ -11,14 +11,8 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.bw_learner.learner_utils import (
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
update_chat_id_list,

View File

@@ -1,14 +1,14 @@
import time
from typing import Tuple, Optional # 增加了 Optional
from src.common.logger_manager import get_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import random
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
from src.individuality.individuality import Individuality
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
from src.chat.utils.chat_message_builder import build_readable_messages
logger = get_logger("pfc_action_planner")
@@ -113,12 +113,27 @@ class ActionPlanner:
max_tokens=1500,
request_type="action_planning",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.personality_info = self._get_personality_prompt()
self.name = global_config.BOT_NICKNAME
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
def _get_personality_prompt(self) -> str:
"""获取个性提示信息"""
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (
global_config.personality.states
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.BOT_NICKNAME
return f"你的名字是{bot_name},你{prompt_personality};"
# 修改 plan 方法签名,增加 last_successful_reply_action 参数
async def plan(
self,

View File

@@ -3,22 +3,22 @@ import asyncio
import datetime
# from .message_storage import MongoDBMessageStorage
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
# from ...config.config import global_config
# from src.config.config import global_config
from typing import Dict, Any, Optional
from ..chat.message import Message
from src.chat.message_receive.message import Message
from .pfc_types import ConversationState
from .pfc import ChatObserver, GoalAnalyzer
from .message_sender import DirectMessageSender
from src.common.logger_manager import get_logger
from src.common.logger import get_logger
from .action_planner import ActionPlanner
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
from .reply_generator import ReplyGenerator
from ..chat.chat_stream import ChatStream
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import UserInfo
from src.plugins.chat.chat_stream import chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter
@@ -60,7 +60,7 @@ class Conversation:
self.direct_sender = DirectMessageSender(self.private_name)
# 获取聊天流信息
self.chat_stream = chat_manager.get_stream(self.stream_id)
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
self.stop_action_planner = False
except Exception as e:
@@ -248,14 +248,14 @@ class Conversation:
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
try:
# 尝试从 msg_dict 直接获取 chat_stream如果失败则从全局 chat_manager 获取
# 尝试从 msg_dict 直接获取 chat_stream如果失败则从全局 get_chat_manager 获取
chat_info = msg_dict.get("chat_info")
if chat_info and isinstance(chat_info, dict):
chat_stream = ChatStream.from_dict(chat_info)
elif self.chat_stream: # 使用实例变量中的 chat_stream
chat_stream = self.chat_stream
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
chat_stream = chat_manager.get_stream(self.stream_id)
chat_stream = get_chat_manager().get_stream(self.stream_id)
if not chat_stream:
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")

View File

@@ -1,13 +1,11 @@
import time
from typing import Optional
from src.common.logger import get_module_logger
from ..chat.chat_stream import ChatStream
from ..chat.message import Message
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import Message, MessageSending
from maim_message import UserInfo, Seg
from src.plugins.chat.message import MessageSending, MessageSet
from src.plugins.chat.message_sender import message_manager
from ..storage.storage import MessageStorage
from ...config.config import global_config
from src.chat.message_receive.storage import MessageStorage
from src.config.config import global_config
from rich.traceback import install
install(extra_lines=3)
@@ -66,15 +64,18 @@ class DirectMessageSender:
# 处理消息
await message.process()
# 不知道有什么用先留下来了和之前那套sender一样
_message_json = message.to_dict()
# 发送消息
message_set = MessageSet(chat_stream, message_id)
message_set.add_message(message)
await message_manager.add_message(message_set)
await self.storage.store_message(message, chat_stream)
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
# 发送消息(直接调用底层 API
from src.chat.message_receive.uni_message_sender import _send_message
sent = await _send_message(message, show_log=True)
if sent:
# 存储消息
await self.storage.store_message(message, chat_stream)
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
else:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
raise RuntimeError("消息发送失败")
except Exception as e:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")

View File

@@ -4,7 +4,7 @@ import time
from src.common.logger import get_module_logger
from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType, Notification
from src.plugins.utils.chat_message_builder import build_readable_messages
from src.chat.utils.chat_message_builder import build_readable_messages
import traceback # 导入 traceback 用于调试
logger = get_module_logger("observation_info")

View File

@@ -1,13 +1,13 @@
from typing import List, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import random
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
from src.individuality.individuality import Individuality
from .conversation_info import ConversationInfo
from .observation_info import ObservationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
from src.chat.utils.chat_message_builder import build_readable_messages
from rich.traceback import install
install(extra_lines=3)
@@ -46,7 +46,7 @@ class GoalAnalyzer:
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.personality_info = self._get_personality_prompt()
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.private_name = private_name
@@ -57,6 +57,21 @@ class GoalAnalyzer:
self.max_goals = 3 # 同时保持的最大目标数量
self.current_goal_and_reason = None
def _get_personality_prompt(self) -> str:
"""获取个性提示信息"""
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (
global_config.personality.states
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.bot.nickname
return f"你的名字是{bot_name},你{prompt_personality};"
async def analyze_goal(self, conversation_info: ConversationInfo, observation_info: ObservationInfo):
"""分析对话历史并设定目标

View File

@@ -1,11 +1,11 @@
from typing import List, Tuple
from src.common.logger import get_module_logger
from src.plugins.memory_system.Hippocampus import HippocampusManager
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from ..chat.message import Message
from ..knowledge.knowledge_lib import qa_manager
from ..utils.chat_message_builder import build_readable_messages
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.message_receive.message import Message
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.utils.chat_message_builder import build_readable_messages
logger = get_module_logger("knowledge_fetcher")

View File

@@ -1,8 +1,8 @@
import json
from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from .chat_observer import ChatObserver
from maim_message import UserInfo

View File

@@ -1,15 +1,15 @@
from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import random
from .chat_observer import ChatObserver
from .reply_checker import ReplyChecker
from src.individuality.individuality import Individuality
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
from src.chat.utils.chat_message_builder import build_readable_messages
logger = get_module_logger("reply_generator")
logger = get_logger("reply_generator")
# --- 定义 Prompt 模板 ---
@@ -92,12 +92,27 @@ class ReplyGenerator:
max_tokens=300,
request_type="reply_generation",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.personality_info = self._get_personality_prompt()
self.name = global_config.BOT_NICKNAME
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.reply_checker = ReplyChecker(stream_id, private_name)
def _get_personality_prompt(self) -> str:
"""获取个性提示信息"""
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
if (
global_config.personality.states
and global_config.personality.state_probability > 0
and random.random() < global_config.personality.state_probability
):
prompt_personality = random.choice(global_config.personality.states)
bot_name = global_config.BOT_NICKNAME
return f"你的名字是{bot_name},你{prompt_personality};"
# 修改 generate 方法签名,增加 action_type 参数
async def generate(
self, observation_info: ObservationInfo, conversation_info: ConversationInfo, action_type: str

View File

@@ -3,7 +3,7 @@ from .chat_observer import ChatObserver
from .conversation_info import ConversationInfo
# from src.individuality.individuality import Individuality # 不再需要
from ...config.config import global_config
from src.config.config import global_config
import time
import asyncio

View File

@@ -87,6 +87,7 @@ class BrainChatting:
# 循环控制内部状态
self.running: bool = False
self._loop_task: Optional[asyncio.Task] = None # 主循环任务
self._new_message_event = asyncio.Event() # 新消息事件,用于打断 wait
# 添加循环信息管理相关的属性
self.history_loop: List[CycleDetail] = []
@@ -173,9 +174,10 @@ class BrainChatting:
filter_intercept_message_level=1,
)
# 如果有新消息,更新 last_read_time
# 如果有新消息,更新 last_read_time 并触发事件以打断正在进行的 wait
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
self._new_message_event.set() # 触发新消息事件,打断 wait
# 总是执行一次思考迭代(不管有没有新消息)
# wait 动作会在其内部等待,不需要在这里处理
@@ -434,6 +436,9 @@ class BrainChatting:
last_check_time = self.last_read_time
check_interval = 1.0 # 每秒检查一次
# 清除事件状态,准备等待新消息
self._new_message_event.clear()
while self.running:
# 检查是否有新消息
recent_messages_list = message_api.get_messages_by_time_in_chat(
@@ -453,8 +458,15 @@ class BrainChatting:
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
return
# 等待一段时间后再次检查
await asyncio.sleep(check_interval)
# 等待新消息事件或超时后再次检查
try:
await asyncio.wait_for(self._new_message_event.wait(), timeout=check_interval)
# 事件被触发,说明有新消息
logger.info(f"{self.log_prefix} 检测到新消息事件,恢复循环")
return
except asyncio.TimeoutError:
# 超时后继续检查
continue
async def _handle_action(
self,
@@ -674,7 +686,10 @@ class BrainChatting:
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
wait_seconds = 5
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}")
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}(可被新消息打断)")
# 清除事件状态,准备等待新消息
self._new_message_event.clear()
# 记录动作信息
await database_api.store_action_info(
@@ -687,8 +702,17 @@ class BrainChatting:
action_name="wait",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
# 等待指定时间,但可被新消息打断
try:
await asyncio.wait_for(
self._new_message_event.wait(),
timeout=wait_seconds
)
# 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} wait 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError:
# 超时正常完成
pass
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
@@ -707,7 +731,10 @@ class BrainChatting:
# 使用默认等待时间
wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}")
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}(可被新消息打断)")
# 清除事件状态,准备等待新消息
self._new_message_event.clear()
# 记录动作信息
await database_api.store_action_info(
@@ -720,8 +747,17 @@ class BrainChatting:
action_name="listening",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
# 等待指定时间,但可被新消息打断
try:
await asyncio.wait_for(
self._new_message_event.wait(),
timeout=wait_seconds
)
# 如果事件被触发,说明有新消息到达
logger.info(f"{self.log_prefix} listening 动作被新消息打断,提前结束等待")
except asyncio.TimeoutError:
# 超时正常完成
pass
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")

View File

@@ -402,7 +402,7 @@ class BrainPlanner:
moderation_prompt=moderation_prompt_block,
name_block=name_block,
interest=interest,
plan_style=global_config.personality.private_plan_style,
plan_style=global_config.experimental.private_plan_style,
)
return prompt, message_id_list

View File

@@ -244,12 +244,14 @@ class HeartFChatting:
thinking_id,
actions,
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
reply_set=response_set,
message_data=action_message,
selected_expressions=selected_expressions,
quote_message=quote_message,
)
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
@@ -526,15 +528,26 @@ class HeartFChatting:
reply_set: "ReplySetModel",
message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None,
quote_message: Optional[bool] = None,
) -> str:
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 3)
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息,使用引用回复")
# 根据 llm_quote 配置决定是否使用 quote_message 参数
if global_config.chat.llm_quote:
# 如果配置为 true使用 llm_quote 参数决定是否引用回复
if quote_message is None:
logger.warning(f"{self.log_prefix} quote_message 参数为空,不引用")
need_reply = False
else:
need_reply = quote_message
if need_reply:
logger.info(f"{self.log_prefix} LLM 决定使用引用回复")
else:
# 如果配置为 false使用原来的模式
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
)
need_reply = new_message_count >= random.randint(2, 3) or time.time() - self.last_read_time > 90
if need_reply:
logger.info(f"{self.log_prefix} 从思考到回复,共有{new_message_count}条新消息使用引用回复或者上次回复时间超过90秒")
reply_text = ""
first_replied = False
@@ -640,6 +653,7 @@ class HeartFChatting:
# 从 Planner 的 action_data 中提取未知词语列表(仅在 reply 时使用)
unknown_words = None
quote_message = None
if isinstance(action_planner_info.action_data, dict):
uw = action_planner_info.action_data.get("unknown_words")
if isinstance(uw, list):
@@ -651,6 +665,19 @@ class HeartFChatting:
cleaned_uw.append(s)
if cleaned_uw:
unknown_words = cleaned_uw
# 从 Planner 的 action_data 中提取 quote_message 参数
qm = action_planner_info.action_data.get("quote")
if qm is not None:
# 支持多种格式true/false, "true"/"false", 1/0
if isinstance(qm, bool):
quote_message = qm
elif isinstance(qm, str):
quote_message = qm.lower() in ("true", "1", "yes")
elif isinstance(qm, (int, float)):
quote_message = bool(qm)
logger.info(f"{self.log_prefix} {qm}引用回复设置: {quote_message}")
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
@@ -682,6 +709,7 @@ class HeartFChatting:
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
quote_message=quote_message,
)
self.last_active_time = time.time()
return {

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import time
from typing import List, Union
from typing import List, Union, Dict, Any
from .global_logger import logger
from . import prompt_template
@@ -173,3 +173,50 @@ def info_extract_from_str(
return None, None
return entity_extract_result, rdf_triple_extract_result
class IEProcess:
"""
信息抽取处理器类,提供更方便的批次处理接口。
"""
def __init__(self, llm_ner: LLMRequest, llm_rdf: LLMRequest = None):
self.llm_ner = llm_ner
self.llm_rdf = llm_rdf or llm_ner
async def process_paragraphs(self, paragraphs: List[str]) -> List[dict]:
"""
异步处理多个段落。
"""
from .utils.hash import get_sha256
results = []
total = len(paragraphs)
for i, pg in enumerate(paragraphs, start=1):
# 打印进度日志,让用户知道没有卡死
logger.info(f"[IEProcess] 正在处理第 {i}/{total} 段文本 (长度: {len(pg)})...")
# 使用 asyncio.to_thread 包装同步阻塞调用,防止死锁
# 这样 info_extract_from_str 内部的 asyncio.run 会在独立线程的新 loop 中运行
try:
entities, triples = await asyncio.to_thread(
info_extract_from_str, self.llm_ner, self.llm_rdf, pg
)
if entities is not None:
results.append(
{
"idx": get_sha256(pg),
"passage": pg,
"extracted_entities": entities,
"extracted_triples": triples,
}
)
logger.info(f"[IEProcess] 第 {i}/{total} 段处理完成,提取到 {len(entities)} 个实体")
else:
logger.warning(f"[IEProcess] 第 {i}/{total} 段提取失败(返回为空)")
except Exception as e:
logger.error(f"[IEProcess] 处理第 {i}/{total} 段时发生异常: {e}")
return results

View File

@@ -0,0 +1,388 @@
import asyncio
import os
from functools import partial
from typing import List, Callable, Any
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.qa_manager import QAManager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge import get_qa_manager, lpmm_start_up
logger = get_logger("LPMM-Plugin-API")
class LPMMOperations:
"""
LPMM 内部操作接口。
封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。
"""
def __init__(self):
self._initialized = False
async def _run_cancellable_executor(
self, func: Callable, *args, **kwargs
) -> Any:
"""
在线程池中执行可取消的同步操作。
当任务被取消时(如 Ctrl+C会立即响应并抛出 CancelledError。
注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。
Args:
func: 要执行的同步函数
*args: 函数的位置参数
**kwargs: 函数的关键字参数
Returns:
函数的返回值
Raises:
asyncio.CancelledError: 当任务被取消时
"""
loop = asyncio.get_event_loop()
# 在线程池中执行,当协程被取消时会立即响应
# 虽然线程池中的操作可能仍在运行,但协程不会阻塞
return await loop.run_in_executor(None, func, *args, **kwargs)
async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]:
"""获取并确保 LPMM 管理器已初始化"""
qa_mgr = get_qa_manager()
if qa_mgr is None:
# 如果全局没初始化,尝试初始化
if not global_config.lpmm_knowledge.enable:
logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。")
lpmm_start_up()
qa_mgr = get_qa_manager()
if qa_mgr is None:
raise RuntimeError("无法获取 LPMM QAManager请检查 LPMM 是否已正确安装和配置。")
return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr
async def add_content(self, text: str, auto_split: bool = True) -> dict:
"""
向知识库添加新内容。
Args:
text: 原始文本。
auto_split: 是否自动按双换行符分割段落。
- True: 自动分割(默认),支持多段文本(用双换行分隔)
- False: 不分割,将整个文本作为完整一段处理
Returns:
dict: {"status": "success/error", "count": 导入段落数, "message": "描述"}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 分段处理
if auto_split:
# 自动按双换行符分割
paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
else:
# 不分割,作为完整一段
text_stripped = text.strip()
if not text_stripped:
return {"status": "error", "message": "文本内容为空"}
paragraphs = [text_stripped]
if not paragraphs:
return {"status": "error", "message": "文本内容为空"}
# 2. 实体与三元组抽取 (内部调用大模型)
from src.chat.knowledge.ie_process import IEProcess
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
llm_ner = LLMRequest(model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract")
llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")
ie_process = IEProcess(llm_ner, llm_rdf)
logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...")
extracted_docs = await ie_process.process_paragraphs(paragraphs)
# 3. 构造并导入数据
# 这里我们手动实现导入逻辑,不依赖外部脚本
# a. 准备段落
raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs}
# b. 准备三元组
triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs}
# 向量化并入库
# 注意:此处模仿 import_openie.py 的核心逻辑
# 1. 先进行去重检查,只处理新段落
# store_new_data_set 期望的格式raw_paragraphs 的键是段落hash不带前缀值是段落文本
new_raw_paragraphs = {}
new_triple_list_data = {}
for pg_hash, passage in raw_paragraphs.items():
key = f"paragraph-{pg_hash}"
if key not in embed_mgr.stored_pg_hashes:
new_raw_paragraphs[pg_hash] = passage
new_triple_list_data[pg_hash] = triple_list_data[pg_hash]
if not new_raw_paragraphs:
return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"}
# 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入
# store_new_data_set 会自动处理嵌入生成和存储
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
await self._run_cancellable_executor(
embed_mgr.store_new_data_set,
new_raw_paragraphs,
new_triple_list_data
)
# 3. 构建知识图谱只需要三元组数据和embedding_manager
await self._run_cancellable_executor(
kg_mgr.build_kg,
new_triple_list_data,
embed_mgr
)
# 4. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
return {"status": "success", "count": len(new_raw_paragraphs), "message": f"成功导入 {len(new_raw_paragraphs)} 条知识"}
except asyncio.CancelledError:
logger.warning("[Plugin API] 导入操作被用户中断")
return {"status": "cancelled", "message": "导入操作已被用户中断"}
except Exception as e:
logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
async def search(self, query: str, top_k: int = 3) -> List[str]:
"""
检索知识库。
Args:
query: 查询问题。
top_k: 返回最相关的条目数。
Returns:
List[str]: 相关文段列表。
"""
try:
_, _, qa_mgr = await self._get_managers()
# 直接调用 QAManager 的检索接口
knowledge = qa_mgr.get_knowledge(query, top_k=top_k)
# 返回通常是拼接好的字符串,这里我们可以尝试按其内部规则切分回列表,或者直接返回
return [knowledge] if knowledge else []
except Exception as e:
logger.error(f"[Plugin API] 检索知识失败: {e}")
return []
async def delete(self, keyword: str, exact_match: bool = False) -> dict:
"""
根据关键词或完整文段删除知识库内容。
Args:
keyword: 匹配关键词或完整文段。
exact_match: 是否使用完整文段匹配True=完全匹配False=关键词模糊匹配)。
Returns:
dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 1. 查找匹配的段落
to_delete_keys = []
to_delete_hashes = []
for key, item in embed_mgr.paragraphs_embedding_store.store.items():
if exact_match:
# 完整文段匹配
if item.str.strip() == keyword.strip():
to_delete_keys.append(key)
to_delete_hashes.append(key.replace("paragraph-", "", 1))
else:
# 关键词模糊匹配
if keyword in item.str:
to_delete_keys.append(key)
to_delete_hashes.append(key.replace("paragraph-", "", 1))
if not to_delete_keys:
match_type = "完整文段" if exact_match else "关键词"
return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"}
# 2. 执行删除
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# a. 从向量库删除
deleted_count, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items,
to_delete_keys
)
embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys())
# b. 从知识图谱删除
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs,
to_delete_hashes,
ent_hashes=None,
remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
# 3. 持久化
await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index)
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
match_type = "完整文段" if exact_match else "关键词"
return {"status": "success", "deleted_count": deleted_count, "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)"}
except asyncio.CancelledError:
logger.warning("[Plugin API] 删除操作被用户中断")
return {"status": "cancelled", "message": "删除操作已被用户中断"}
except Exception as e:
logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
async def clear_all(self) -> dict:
"""
清空整个LPMM知识库删除所有段落、实体、关系和知识图谱数据
Returns:
dict: {"status": "success/error", "message": "描述", "stats": {...}}
"""
try:
embed_mgr, kg_mgr, _ = await self._get_managers()
# 记录清空前的统计信息
before_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
"entities": len(embed_mgr.entities_embedding_store.store),
"relations": len(embed_mgr.relation_embedding_store.store),
"kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()),
}
# 将同步阻塞操作放到线程池中执行,避免阻塞事件循环
# 1. 清空所有向量库
# 获取所有keys
para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys())
ent_keys = list(embed_mgr.entities_embedding_store.store.keys())
rel_keys = list(embed_mgr.relation_embedding_store.store.keys())
# 删除所有段落向量
para_deleted, _ = await self._run_cancellable_executor(
embed_mgr.paragraphs_embedding_store.delete_items,
para_keys
)
embed_mgr.stored_pg_hashes.clear()
# 删除所有实体向量
if ent_keys:
ent_deleted, _ = await self._run_cancellable_executor(
embed_mgr.entities_embedding_store.delete_items,
ent_keys
)
else:
ent_deleted = 0
# 删除所有关系向量
if rel_keys:
rel_deleted, _ = await self._run_cancellable_executor(
embed_mgr.relation_embedding_store.delete_items,
rel_keys
)
else:
rel_deleted = 0
# 2. 清空所有 embedding store 的索引和映射
# 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件
def _clear_embedding_indices():
# 清空段落索引
embed_mgr.paragraphs_embedding_store.faiss_index = None
embed_mgr.paragraphs_embedding_store.idx2hash = None
embed_mgr.paragraphs_embedding_store.dirty = False
# 删除旧的索引文件
if os.path.exists(embed_mgr.paragraphs_embedding_store.index_file_path):
os.remove(embed_mgr.paragraphs_embedding_store.index_file_path)
if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path)
# 清空实体索引
embed_mgr.entities_embedding_store.faiss_index = None
embed_mgr.entities_embedding_store.idx2hash = None
embed_mgr.entities_embedding_store.dirty = False
# 删除旧的索引文件
if os.path.exists(embed_mgr.entities_embedding_store.index_file_path):
os.remove(embed_mgr.entities_embedding_store.index_file_path)
if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path)
# 清空关系索引
embed_mgr.relation_embedding_store.faiss_index = None
embed_mgr.relation_embedding_store.idx2hash = None
embed_mgr.relation_embedding_store.dirty = False
# 删除旧的索引文件
if os.path.exists(embed_mgr.relation_embedding_store.index_file_path):
os.remove(embed_mgr.relation_embedding_store.index_file_path)
if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path):
os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path)
await self._run_cancellable_executor(_clear_embedding_indices)
# 3. 清空知识图谱
# 获取所有段落hash
all_pg_hashes = list(kg_mgr.stored_paragraph_hashes)
if all_pg_hashes:
# 删除所有段落节点(这会自动清理相关的边和孤立实体)
# 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数
# 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs
delete_func = partial(
kg_mgr.delete_paragraphs,
all_pg_hashes,
ent_hashes=None,
remove_orphan_entities=True
)
await self._run_cancellable_executor(delete_func)
# 完全清空KG创建新的空图无论是否有段落hash都要执行
from quick_algo import di_graph
kg_mgr.graph = di_graph.DiGraph()
kg_mgr.stored_paragraph_hashes.clear()
kg_mgr.ent_appear_cnt.clear()
# 4. 保存所有数据此时所有store都是空的索引也是None
# 注意即使store为空save_to_file也会保存空的DataFrame这是正确的
await self._run_cancellable_executor(embed_mgr.save_to_file)
await self._run_cancellable_executor(kg_mgr.save_to_file)
after_stats = {
"paragraphs": len(embed_mgr.paragraphs_embedding_store.store),
"entities": len(embed_mgr.entities_embedding_store.store),
"relations": len(embed_mgr.relation_embedding_store.store),
"kg_nodes": len(kg_mgr.graph.get_node_list()),
"kg_edges": len(kg_mgr.graph.get_edge_list()),
}
return {
"status": "success",
"message": f"已成功清空LPMM知识库删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)",
"stats": {
"before": before_stats,
"after": after_stats,
}
}
except asyncio.CancelledError:
logger.warning("[Plugin API] 清空操作被用户中断")
return {"status": "cancelled", "message": "清空操作已被用户中断"}
except Exception as e:
logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True)
return {"status": "error", "message": str(e)}
# 内部使用的单例
lpmm_ops = LPMMOperations()

View File

@@ -7,9 +7,8 @@ from .kg_manager import KGManager
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config, model_config
from src.config.config import global_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度

View File

@@ -4,6 +4,8 @@ from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4
from src.config.config import global_config
class PlanReplyLogger:
"""独立的Plan/Reply日志记录器负责落盘和容量控制。"""
@@ -11,9 +13,13 @@ class PlanReplyLogger:
_BASE_DIR = Path("logs")
_PLAN_DIR = _BASE_DIR / "plan"
_REPLY_DIR = _BASE_DIR / "reply"
_MAX_PER_CHAT = 1000
_TRIM_COUNT = 100
@classmethod
def _get_max_per_chat(cls) -> int:
"""从配置中获取每个聊天流最大保存的日志数量"""
return getattr(global_config.chat, "plan_reply_log_max_per_chat", 1000)
@classmethod
def log_plan(
cls,
@@ -85,7 +91,8 @@ class PlanReplyLogger:
def _trim_overflow(cls, chat_dir: Path) -> None:
"""超过阈值时删除最老的若干文件,避免目录无限增长。"""
files = sorted(chat_dir.glob("*.json"), key=lambda p: p.stat().st_mtime)
if len(files) <= cls._MAX_PER_CHAT:
max_per_chat = cls._get_max_per_chat()
if len(files) <= max_per_chat:
return
# 删除最老的 TRIM_COUNT 条
for old_file in files[: cls._TRIM_COUNT]:

View File

@@ -1,4 +1,5 @@
import time
import asyncio
import urllib3
from abc import abstractmethod
@@ -20,6 +21,9 @@ logger = get_logger("chat_message")
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# VLM 处理并发限制(避免同时处理太多图片导致卡死)
_vlm_semaphore = asyncio.Semaphore(3)
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@@ -73,20 +77,35 @@ class Message(MessageBase):
str: 处理后的文本
"""
if segment.type == "seglist":
# 处理消息段列表
# 处理消息段列表 - 使用并行处理提升性能
tasks = [self._process_message_segments(seg) for seg in segment.data] # type: ignore
results = await asyncio.gather(*tasks, return_exceptions=True)
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg) # type: ignore
if processed:
segments_text.append(processed)
for result in results:
if isinstance(result, Exception):
logger.error(f"处理消息段时出错: {result}")
continue
if result:
segments_text.append(result)
return " ".join(segments_text)
elif segment.type == "forward":
segments_text = []
for node_dict in segment.data:
# 处理转发消息 - 使用并行处理
async def process_forward_node(node_dict):
message = MessageBase.from_dict(node_dict) # type: ignore
processed_text = await self._process_message_segments(message.message_segment)
if processed_text:
segments_text.append(f"{global_config.bot.nickname}: {processed_text}")
return f"{global_config.bot.nickname}: {processed_text}"
return None
tasks = [process_forward_node(node_dict) for node_dict in segment.data]
results = await asyncio.gather(*tasks, return_exceptions=True)
segments_text = []
for result in results:
if isinstance(result, Exception):
logger.error(f"处理转发节点时出错: {result}")
continue
if result:
segments_text.append(result)
return "[合并消息]: " + "\n-- ".join(segments_text)
else:
# 处理单个消息段
@@ -173,8 +192,9 @@ class MessageRecv(Message):
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
# 使用 semaphore 限制 VLM 并发,避免同时处理太多图片
async with _vlm_semaphore:
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
@@ -183,7 +203,9 @@ class MessageRecv(Message):
self.is_picid = False
self.is_voice = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
# 使用 semaphore 限制 VLM 并发
async with _vlm_semaphore:
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "voice":
self.is_picid = False

View File

@@ -4,6 +4,7 @@ import traceback
import random
import re
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from collections import OrderedDict
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
@@ -53,7 +54,6 @@ reply
4.不要选择回复你自己发送的消息
5.不要单独对表情包进行回复
6.将上下文中所有含义不明的疑似黑话的缩写词均写入unknown_words中
7.如果你对上下文存在疑问有需要查询的问题写入question中
{reply_action_example}
no_reply
@@ -111,6 +111,10 @@ class ActionPlanner:
self.last_obs_time_mark = 0.0
self.plan_log: List[Tuple[str, float, Union[List[ActionPlannerInfo], str]]] = []
# 黑话缓存:使用 OrderedDict 实现 LRU最多缓存10个
self.unknown_words_cache: OrderedDict[str, None] = OrderedDict()
self.unknown_words_cache_limit = 10
def find_message_by_id(
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
@@ -225,24 +229,6 @@ class ActionPlanner:
reasoning = "未提供原因"
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
# 验证和清理 question
if "question" in action_data:
q = action_data.get("question")
if isinstance(q, str):
cleaned_q = q.strip()
if cleaned_q:
action_data["question"] = cleaned_q
else:
# 如果清理后为空字符串,移除该字段
action_data.pop("question", None)
elif q is None:
# 如果为 None移除该字段
action_data.pop("question", None)
else:
# 如果不是字符串类型,记录警告并移除
logger.warning(f"{self.log_prefix}question 格式不正确,应为字符串类型,已忽略")
action_data.pop("question", None)
# 非no_reply动作需要target_message_id
target_message = None
@@ -318,6 +304,136 @@ class ActionPlanner:
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
return False
def _update_unknown_words_cache(self, new_words: List[str]) -> None:
    """Merge newly extracted jargon terms into the LRU cache.

    Non-string and whitespace-only entries are ignored. A term already in
    the cache has its recency refreshed; a genuinely new term is appended
    and, if the cache then exceeds its limit, the oldest entry is evicted.

    Args:
        new_words: jargon terms extracted from the current plan
    """
    cache = self.unknown_words_cache
    for raw in new_words:
        if not isinstance(raw, str):
            continue
        term = raw.strip()
        if not term:
            continue
        if term in cache:
            # Known term: refresh its position (LRU semantics).
            cache.move_to_end(term)
            continue
        # New term: record it, then evict the oldest entry on overflow.
        cache[term] = None
        if len(cache) > self.unknown_words_cache_limit:
            cache.popitem(last=False)
            logger.debug(f"{self.log_prefix}黑话缓存已满,移除最老的黑话")
def _merge_unknown_words_with_cache(self, new_words: Optional[List[str]]) -> List[str]:
"""
合并新提取的黑话和缓存中的黑话
Args:
new_words: 新提取的黑话列表可能为None
Returns:
合并后的黑话列表(去重)
"""
# 清理新提取的黑话
cleaned_new_words: List[str] = []
if new_words:
for word in new_words:
if isinstance(word, str):
word = word.strip()
if word:
cleaned_new_words.append(word)
# 获取缓存中的黑话列表
cached_words = list(self.unknown_words_cache.keys())
# 合并并去重(保留顺序:新提取的在前,缓存的在后)
merged_words: List[str] = []
seen = set()
# 先添加新提取的
for word in cleaned_new_words:
if word not in seen:
merged_words.append(word)
seen.add(word)
# 再添加缓存的(如果不在新提取的列表中)
for word in cached_words:
if word not in seen:
merged_words.append(word)
seen.add(word)
return merged_words
def _process_unknown_words_cache(
    self, actions: List[ActionPlannerInfo]
) -> None:
    """Reconcile the jargon (unknown-words) LRU cache with this plan's actions.

    Workflow:
      1. If the cache holds more than 5 entries, evict the 2 oldest.
      2. If no reply action extracted unknown_words this round, evict the
         single oldest entry.
      3. For each reply action, merge cached terms into its unknown_words
         field, then feed its newly extracted terms back into the cache.

    Args:
        actions: parsed planner actions for the current plan
    """
    cache = self.unknown_words_cache
    # Step 1: shrink an oversized cache by dropping the two oldest entries.
    if len(cache) > 5:
        evicted = 0
        for _ in range(2):
            if len(cache) > 0:
                cache.popitem(last=False)
                evicted += 1
        if evicted > 0:
            logger.debug(f"{self.log_prefix}缓存数量大于5移除最老的{evicted}个缓存")

    # Step 2: did any reply action extract a non-empty unknown_words list?
    replies = [a for a in actions if a.action_type == "reply"]
    extracted_any = any(
        isinstance((a.action_data or {}).get("unknown_words"), list)
        and (a.action_data or {}).get("unknown_words")
        for a in replies
    )
    if not extracted_any and len(cache) > 0:
        cache.popitem(last=False)
        logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")

    # Step 3: merge cache into each reply's unknown_words, then refresh cache.
    for action in replies:
        action_data = action.action_data or {}
        new_words = action_data.get("unknown_words")
        merged_words = self._merge_unknown_words_with_cache(new_words)
        if merged_words:
            action_data["unknown_words"] = merged_words
            logger.debug(
                f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个,"
                f"缓存 {len(self.unknown_words_cache)} 个,合并后 {len(merged_words)} 个"
            )
        else:
            # Nothing to carry: drop the field rather than store an empty list.
            action_data.pop("unknown_words", None)
        if new_words:
            self._update_unknown_words_cache(new_words)
async def plan(
self,
available_actions: Dict[str, ActionInfo],
@@ -520,21 +636,32 @@ class ActionPlanner:
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
# 根据 think_mode 配置决定 reply action 的示例 JSON
# 在 JSON 中直接作为 action 参数携带 unknown_words 和 question
# 在 JSON 中直接作为 action 参数携带 unknown_words
if global_config.chat.think_mode == "classic":
reply_action_example = (
reply_action_example = ""
if global_config.chat.llm_quote:
reply_action_example += "5.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
'{{"action":"reply", "target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"], '
'"question":"需要查询的问题"}'
'"unknown_words":["词语1","词语2"]'
)
if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"'
reply_action_example += "}"
else:
reply_action_example = (
"5.think_level表示思考深度0表示该回复不需要思考和回忆1表示该回复需要进行回忆和思考\n"
+ '{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"], '
'"question":"需要查询的问题"}'
)
if global_config.chat.llm_quote:
reply_action_example += "6.如果要明确回复消息使用quote如果消息不多不需要明确回复设置quote为false\n"
reply_action_example += (
'{{"action":"reply", "think_level":数值等级(0或1), '
'"target_message_id":"消息id(m+数字)", '
'"unknown_words":["词语1","词语2"]'
)
if global_config.chat.llm_quote:
reply_action_example += ', "quote":"如果需要引用该message设置为true"'
reply_action_example += "}"
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
@@ -730,6 +857,9 @@ class ActionPlanner:
random.shuffle(shuffled)
actions = list({a.action_type: a for a in shuffled}.values())
# 处理黑话缓存逻辑
self._process_unknown_words_cache(actions)
logger.debug(f"{self.log_prefix}规划器选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
return extracted_reasoning, actions, llm_content, llm_reasoning, llm_duration_ms

View File

@@ -845,18 +845,6 @@ class DefaultReplyer:
chat_id, message_list_before_short, chat_talking_prompt_short, unknown_words
)
# 从 chosen_actions 中提取 question仅在 reply 动作中)
question = None
if chosen_actions:
for action_info in chosen_actions:
if action_info.action_type == "reply" and isinstance(action_info.action_data, dict):
q = action_info.action_data.get("question")
if isinstance(q, str):
cleaned_q = q.strip()
if cleaned_q:
question = cleaned_q
break
# 并行执行构建任务(包括黑话解释,可配置关闭)
task_results = await asyncio.gather(
self._time_and_run_task(
@@ -871,7 +859,7 @@ class DefaultReplyer:
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(
build_memory_retrieval_prompt(
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level, unknown_words=unknown_words, question=question
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level, unknown_words=unknown_words
),
"memory_retrieval",
),

View File

@@ -710,18 +710,6 @@ class PrivateReplyer:
else:
jargon_coroutine = self._build_disabled_jargon_explanation()
# 从 chosen_actions 中提取 question仅在 reply 动作中)
question = None
if chosen_actions:
for action_info in chosen_actions:
if action_info.action_type == "reply" and isinstance(action_info.action_data, dict):
q = action_info.action_data.get("question")
if isinstance(q, str):
cleaned_q = q.strip()
if cleaned_q:
question = cleaned_q
break
# 并行执行九个构建任务(包括黑话解释,可配置关闭)
task_results = await asyncio.gather(
self._time_and_run_task(
@@ -736,7 +724,7 @@ class PrivateReplyer:
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(
build_memory_retrieval_prompt(
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=1, unknown_words=unknown_words, question=question
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=1, unknown_words=unknown_words
),
"memory_retrieval",
),

View File

@@ -368,7 +368,7 @@ class ChatHistory(BaseModel):
theme = TextField() # 主题:这段对话的主要内容,一个简短的标题
keywords = TextField() # 关键词这段对话的关键词JSON格式存储
summary = TextField() # 概括:对这段话的平文本概括
key_point = TextField(null=True) # 关键信息话题中的关键信息点JSON格式存储
# key_point = TextField(null=True) # 关键信息话题中的关键信息点JSON格式存储
count = IntegerField(default=0) # 被检索次数
forget_times = IntegerField(default=0) # 被遗忘检查的次数

View File

@@ -50,7 +50,15 @@ class Server:
async def run(self):
"""启动服务器"""
# 禁用 uvicorn 默认日志和访问日志
config = Config(app=self.app, host=self._host, port=self._port, log_config=None, access_log=False)
# 设置 ws_max_size 为 100MB支持大消息如包含多张图片的转发消息
config = Config(
app=self.app,
host=self._host,
port=self._port,
log_config=None,
access_log=False,
ws_max_size=104_857_600, # 100MB
)
self._server = UvicornServer(config=config)
try:
await self._server.serve()

View File

@@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.12.1"
MMC_VERSION = "0.13.0-snapshot.1"
def get_key_comment(toml_table, key):

View File

@@ -57,9 +57,6 @@ class PersonalityConfig(ConfigBase):
visual_style: str = ""
"""图片提示词"""
private_plan_style: str = ""
"""私聊说话规则,行为风格"""
states: list[str] = field(default_factory=lambda: [])
"""状态列表用于随机替换personality"""
@@ -122,6 +119,12 @@ class ChatConfig(ConfigBase):
- dynamic: think_level由planner动态给出根据planner返回的think_level决定
"""
plan_reply_log_max_per_chat: int = 1024
"""每个聊天流最大保存的Plan/Reply日志数量超过此数量时会自动删除最老的日志"""
llm_quote: bool = False
"""是否在 reply action 中启用 quote 参数,启用后 LLM 可以控制是否引用消息"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
try:
@@ -279,12 +282,20 @@ class MemoryConfig(ConfigBase):
- 当在黑名单中的聊天流进行查询时,仅使用该聊天流的本地记忆
"""
planner_question: bool = True
"""
是否使用 Planner 提供的 question 作为记忆检索问题
- True: 当 Planner 在 reply 动作中提供了 question 时,直接使用该问题进行记忆检索,跳过 LLM 生成问题的步骤
- False: 沿用旧模式,使用 LLM 生成问题
"""
chat_history_topic_check_message_threshold: int = 80
"""聊天历史话题检查的消息数量阈值,当累积消息数达到此值时触发话题检查"""
chat_history_topic_check_time_hours: float = 8.0
"""聊天历史话题检查的时间阈值(小时),当距离上次检查超过此时间且消息数达到最小阈值时触发话题检查"""
chat_history_topic_check_min_messages: int = 20
"""聊天历史话题检查的时间触发模式下的最小消息数阈值"""
chat_history_finalize_no_update_checks: int = 3
"""聊天历史话题打包存储的连续无更新检查次数阈值当话题连续N次检查无新增内容时触发打包存储"""
chat_history_finalize_message_count: int = 5
"""聊天历史话题打包存储的消息条数阈值,当话题的消息条数超过此值时触发打包存储"""
def __post_init__(self):
"""验证配置值"""
@@ -292,6 +303,16 @@ class MemoryConfig(ConfigBase):
raise ValueError(f"max_agent_iterations 必须至少为1当前值: {self.max_agent_iterations}")
if self.agent_timeout_seconds <= 0:
raise ValueError(f"agent_timeout_seconds 必须大于0当前值: {self.agent_timeout_seconds}")
if self.chat_history_topic_check_message_threshold < 1:
raise ValueError(f"chat_history_topic_check_message_threshold 必须至少为1当前值: {self.chat_history_topic_check_message_threshold}")
if self.chat_history_topic_check_time_hours <= 0:
raise ValueError(f"chat_history_topic_check_time_hours 必须大于0当前值: {self.chat_history_topic_check_time_hours}")
if self.chat_history_topic_check_min_messages < 1:
raise ValueError(f"chat_history_topic_check_min_messages 必须至少为1当前值: {self.chat_history_topic_check_min_messages}")
if self.chat_history_finalize_no_update_checks < 1:
raise ValueError(f"chat_history_finalize_no_update_checks 必须至少为1当前值: {self.chat_history_finalize_no_update_checks}")
if self.chat_history_finalize_message_count < 1:
raise ValueError(f"chat_history_finalize_message_count 必须至少为1当前值: {self.chat_history_finalize_message_count}")
@dataclass
@@ -676,6 +697,9 @@ class WebUIConfig(ConfigBase):
secure_cookie: bool = False
"""是否启用安全Cookie仅通过HTTPS传输默认false"""
enable_paragraph_content: bool = False
"""是否在知识图谱中加载段落完整内容需要加载embedding store会占用额外内存"""
@dataclass
class DebugConfig(ConfigBase):
@@ -707,8 +731,8 @@ class DebugConfig(ConfigBase):
class ExperimentalConfig(ConfigBase):
"""实验功能配置类"""
enable_friend_chat: bool = False
"""是否启用好友聊天"""
private_plan_style: str = ""
"""私聊说话规则,行为风格(实验性功能)"""
chat_prompts: list[str] = field(default_factory=lambda: [])
"""
@@ -729,6 +753,9 @@ class ExperimentalConfig(ConfigBase):
- prompt内容: 要添加的额外prompt文本
"""
lpmm_memory: bool = False
"""是否将聊天历史总结导入到LPMM知识库。开启后chat_history_summarizer总结出的历史记录会同时导入到知识库"""
@dataclass
class MaimMessageConfig(ConfigBase):
@@ -898,6 +925,13 @@ class DreamConfig(ConfigBase):
return False
dream_visible: bool = False
"""
做梦结果是否存储到上下文
- True: 将梦境发送给配置的用户后,也会存储到聊天上下文中,在后续对话中可见
- False: 仅发送梦境但不存储,不在后续对话上下文中出现
"""
def __post_init__(self):
"""验证配置值"""
if self.interval_minutes < 1:

View File

@@ -192,7 +192,6 @@ def init_dream_tools(chat_id: str) -> None:
("theme", ToolParamType.STRING, "新的主题标题,如果不需要修改可不填。", False, None),
("summary", ToolParamType.STRING, "新的概括内容,如果不需要修改可不填。", False, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2']。", False, None),
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2']。", False, None),
],
update_chat_history,
)
@@ -201,7 +200,7 @@ def init_dream_tools(chat_id: str) -> None:
_dream_tool_registry.register_tool(
DreamTool(
"create_chat_history",
"根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词、关键信息等)。",
"根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词等)。",
[
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
@@ -212,10 +211,11 @@ def init_dream_tools(chat_id: str) -> None:
True,
None,
),
("original_text", ToolParamType.STRING, "对话原文内容(必填)。", True, None),
(
"key_point",
"participants",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
"参与人的 JSON 字符串,如 ['用户1','用户2'](必填)。",
True,
None,
),
@@ -313,8 +313,7 @@ async def run_dream_agent_once(
f"主题={record.theme or ''}\n"
f"关键词={record.keywords or ''}\n"
f"参与者={record.participants or ''}\n"
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
f"概括={record.summary or ''}"
)
logger.debug(

View File

@@ -44,18 +44,6 @@ def get_random_dream_styles(count: int = 2) -> List[str]:
"""从梦境风格列表中随机选择指定数量的风格"""
return random.sample(DREAM_STYLES, min(count, len(DREAM_STYLES)))
def get_dream_summary_model() -> LLMRequest:
"""获取用于生成梦境总结的 utils 模型实例"""
global _dream_summary_model
if _dream_summary_model is None:
_dream_summary_model = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="dream.summary",
)
return _dream_summary_model
def init_dream_summary_prompt() -> None:
"""初始化梦境总结的提示词"""
Prompt(
@@ -186,10 +174,12 @@ async def generate_dream_summary(
)
# 调用 utils 模型生成梦境
summary_model = get_dream_summary_model()
summary_model = LLMRequest(
model_set=model_config.model_task_config.replyer,
request_type="dream.summary",
)
dream_content, (reasoning, model_name, _) = await summary_model.generate_response_async(
dream_prompt,
max_tokens=512,
temperature=0.8,
)
@@ -225,11 +215,12 @@ async def generate_dream_summary(
f"platform={platform!r}, user_id={user_id!r}"
)
else:
dream_visible = global_config.dream.dream_visible
ok = await send_api.text_to_stream(
dream_content,
stream_id=stream_id,
typing=False,
storage_message=True,
storage_message=dream_visible,
)
if ok:
logger.info(

View File

@@ -11,7 +11,8 @@ def make_create_chat_history(chat_id: str):
theme: str,
summary: str,
keywords: str,
key_point: str,
original_text: str,
participants: str,
start_time: float,
end_time: float,
) -> str:
@@ -20,7 +21,8 @@ def make_create_chat_history(chat_id: str):
logger.info(
f"[dream][tool] 调用 create_chat_history("
f"theme={bool(theme)}, summary={bool(summary)}, "
f"keywords={bool(keywords)}, key_point={bool(key_point)}, "
f"keywords={bool(keywords)}, original_text={bool(original_text)}, "
f"participants={bool(participants)}, "
f"start_time={start_time}, end_time={end_time}) (chat_id={chat_id})"
)
@@ -43,7 +45,8 @@ def make_create_chat_history(chat_id: str):
theme=theme,
summary=summary,
keywords=keywords,
key_point=key_point,
original_text=original_text,
participants=participants,
# 对于由 dream 整理产生的新概括,时间范围优先使用工具提供的时间,否则使用当前时间占位
start_time=start_ts,
end_time=end_ts,

View File

@@ -32,8 +32,7 @@ def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用
f"主题={record.theme or ''}\n"
f"关键词={record.keywords or ''}\n"
f"参与者={record.participants or ''}\n"
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
f"概括={record.summary or ''}"
)
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
return result

View File

@@ -13,13 +13,12 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
theme: Optional[str] = None,
summary: Optional[str] = None,
keywords: Optional[str] = None,
key_point: Optional[str] = None,
) -> str:
"""按字段更新 chat_history字符串字段要求 JSON 的字段须传入已序列化的字符串)"""
try:
logger.info(
f"[dream][tool] 调用 update_chat_history(memory_id={memory_id}, "
f"theme={bool(theme)}, summary={bool(summary)}, keywords={bool(keywords)}, key_point={bool(key_point)})"
f"theme={bool(theme)}, summary={bool(summary)}, keywords={bool(keywords)})"
)
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
if not record:
@@ -34,8 +33,6 @@ def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,
data["summary"] = summary
if keywords is not None:
data["keywords"] = keywords
if key_point is not None:
data["key_point"] = key_point
if not data:
return "未提供任何需要更新的字段。"

View File

@@ -378,12 +378,15 @@ class LLMRequest:
# 空回复:通常为临时问题,单独记录并重试
original_error_info = self._get_original_error_info(e)
retry_remain -= 1
task_display = self.request_type or "未知任务"
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
logger.error(
f"任务 '{task_display}' 的模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
f"任务 '{task_display}'模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
@@ -393,12 +396,15 @@ class LLMRequest:
original_error_info = self._get_original_error_info(e)
retry_remain -= 1
task_display = self.request_type or "未知任务"
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。{original_error_info}")
logger.error(
f"任务 '{task_display}' 的模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
f"模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}{original_error_info}\n"
f"任务 '{task_display}'模型 '{model_info.name}' 遇到网络错误(可重试): {str(e)}{original_error_info}\n"
f" 常见原因: 如请求的API正常但APITimeoutError类型错误过多请尝试调整模型配置中对应API Provider的timeout值\n"
f" 其它可能原因: 网络波动、DNS 故障、连接超时、防火墙限制或代理问题\n"
f" 剩余重试次数: {retry_remain}"
@@ -407,42 +413,50 @@ class LLMRequest:
except RespNotOkException as e:
original_error_info = self._get_original_error_info(e)
task_display = self.request_type or "未知任务"
# 可重试的HTTP错误
if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1
if retry_remain <= 0:
logger.error(
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
f"任务 '{task_display}'模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
f"模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}{original_error_info}。剩余重试次数: {retry_remain}"
f"任务 '{task_display}'模型 '{model_info.name}' 遇到可重试的HTTP错误: {str(e)}{original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
continue
# 特殊处理413尝试压缩
if e.status_code == 413 and message_list and not compressed_messages:
logger.warning(f"模型 '{model_info.name}' 返回413请求体过大尝试压缩后重试...")
logger.warning(
f"任务 '{task_display}' 的模型 '{model_info.name}' 返回413请求体过大尝试压缩后重试..."
)
# 压缩消息本身不消耗重试次数
compressed_messages = compress_messages(message_list)
continue
# 不可重试的HTTP错误
logger.warning(f"模型 '{model_info.name}' 遇到不可重试的HTTP错误: {str(e)}{original_error_info}")
logger.warning(
f"任务 '{task_display}' 的模型 '{model_info.name}' 遇到不可重试的HTTP错误: {str(e)}{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
except Exception as e:
logger.error(traceback.format_exc())
original_error_info = self._get_original_error_info(e)
task_display = self.request_type or "未知任务"
logger.warning(f"模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}{original_error_info}")
logger.warning(
f"任务 '{task_display}' 的模型 '{model_info.name}' 遇到未知的不可重试错误: {str(e)}{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 遇到硬错误", original_exception=e) from e
raise ModelAttemptFailed(f"模型 '{model_info.name}' 未被尝试因为重试次数已配置为0或更少。")
raise ModelAttemptFailed(f"任务 '{self.request_type or '未知任务'}'模型 '{model_info.name}' 未被尝试因为重试次数已配置为0或更少。")
async def _execute_request(
self,

View File

@@ -15,7 +15,7 @@ from json_repair import repair_json
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config
from src.config.config import model_config, global_config
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import message_api
from src.chat.utils.chat_message_builder import build_readable_messages
@@ -71,16 +71,14 @@ def init_prompt():
1. 关键词提取与话题相关的关键词用列表形式返回3-10个关键词
2. 概括对这段话的平文本概括50-200字要求
- 仔细地转述发生的事件和聊天内容;
- 可以适当摘取聊天记录中的原文;
- 重点突出事件的发展过程和结果;
- 围绕话题这个中心进行概括。
3. 关键信息:提取话题中的关键信息点,用列表形式返回3-8个关键信息点每个关键信息点应该简洁明了。
- 提取话题中的关键信息点,关键信息点应该简洁明了。
请以JSON格式返回格式如下
{{
"keywords": ["关键词1", "关键词2", ...],
"summary": "概括内容",
"key_point": ["关键信息1", "关键信息2", ...]
"summary": "概括内容"
}}
聊天记录:
@@ -370,18 +368,24 @@ class ChatHistorySummarizer:
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
# 检查话题检查触发条件
# 检查"话题检查"触发条件
should_check = False
# 条件1: 消息数量 >= 100触发一次检查
if message_count >= 80:
should_check = True
logger.info(f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: 100条")
# 从配置中获取阈值
message_threshold = global_config.memory.chat_history_topic_check_message_threshold
time_threshold_hours = global_config.memory.chat_history_topic_check_time_hours
min_messages = global_config.memory.chat_history_topic_check_min_messages
time_threshold_seconds = time_threshold_hours * 3600
# 条件2: 距离上一次检查 > 3600 * 8 秒8小时且消息数量 >= 20 条,触发一次检查
elif time_since_last_check > 3600 * 8 and message_count >= 20:
# 条件1: 消息数量达到阈值,触发一次检查
if message_count >= message_threshold:
should_check = True
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 8小时消息数量达到 {message_count} 条(阈值: 20条)")
logger.info(f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: {message_threshold}条)")
# 条件2: 距离上一次检查超过时间阈值且消息数量达到最小阈值,触发一次检查
elif time_since_last_check > time_threshold_seconds and message_count >= min_messages:
should_check = True
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: {time_threshold_hours}小时)且消息数量达到 {message_count} 条(阈值: {min_messages}条)")
if should_check:
await self._run_topic_check_and_update_cache(messages)
@@ -528,14 +532,18 @@ class ChatHistorySummarizer:
item.no_update_checks += 1
# 6. 检查是否有话题需要打包存储
# 从配置中获取阈值
no_update_checks_threshold = global_config.memory.chat_history_finalize_no_update_checks
message_count_threshold = global_config.memory.chat_history_finalize_message_count
topics_to_finalize: List[str] = []
for topic, item in self.topic_cache.items():
if item.no_update_checks >= 3:
logger.info(f"{self.log_prefix} 话题[{topic}] 连续 3 次检查无新增内容,触发打包存储")
if item.no_update_checks >= no_update_checks_threshold:
logger.info(f"{self.log_prefix} 话题[{topic}] 连续 {no_update_checks_threshold} 次检查无新增内容,触发打包存储")
topics_to_finalize.append(topic)
continue
if len(item.messages) > 5:
logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 4,触发打包存储")
if len(item.messages) > message_count_threshold:
logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 {message_count_threshold},触发打包存储")
topics_to_finalize.append(topic)
for topic in topics_to_finalize:
@@ -805,12 +813,38 @@ class ChatHistorySummarizer:
original_text = "\n".join(item.messages)
logger.info(
f"{self.log_prefix} 开始打包话题[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
f"{self.log_prefix} 开始将聊天记录构建成记忆:[{topic}] | 消息数: {len(item.messages)} | 时间范围: {start_time:.2f} - {end_time:.2f}"
)
# 使用 LLM 进行总结(基于话题名)
success, keywords, summary, key_point = await self._compress_with_llm(original_text, topic)
if not success:
# 使用 LLM 进行总结(基于话题名),带重试机制
max_retries = 3
attempt = 0
success = False
keywords = []
summary = ""
while attempt < max_retries:
attempt += 1
success, keywords, summary = await self._compress_with_llm(original_text, topic)
if success and keywords and summary:
# 成功获取到有效的 keywords 和 summary
if attempt > 1:
logger.info(
f"{self.log_prefix} 话题[{topic}] LLM 概括在第 {attempt} 次重试后成功"
)
break
if attempt < max_retries:
logger.warning(
f"{self.log_prefix} 话题[{topic}] LLM 概括失败(第 {attempt} 次尝试),准备重试"
)
else:
logger.error(
f"{self.log_prefix} 话题[{topic}] LLM 概括连续 {max_retries} 次失败,放弃存储"
)
if not success or not keywords or not summary:
logger.warning(f"{self.log_prefix} 话题[{topic}] LLM 概括失败,不写入数据库")
return
@@ -824,14 +858,13 @@ class ChatHistorySummarizer:
theme=topic, # 主题直接使用话题名
keywords=keywords,
summary=summary,
key_point=key_point,
)
logger.info(
f"{self.log_prefix} 话题[{topic}] 成功打包并存储 | 消息数: {len(item.messages)} | 参与者数: {len(participants)}"
)
async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str, List[str]]:
async def _compress_with_llm(self, original_text: str, topic: str) -> tuple[bool, List[str], str]:
"""
使用LLM压缩聊天内容用于单个话题的最终总结
@@ -840,7 +873,7 @@ class ChatHistorySummarizer:
topic: 话题名称
Returns:
tuple[bool, List[str], str, List[str]]: (是否成功, 关键词列表, 概括, 关键信息列表)
tuple[bool, List[str], str]: (是否成功, 关键词列表, 概括)
"""
prompt = await global_prompt_manager.format_prompt(
"hippo_topic_summary_prompt",
@@ -910,24 +943,24 @@ class ChatHistorySummarizer:
keywords = result.get("keywords", [])
summary = result.get("summary", "")
key_point = result.get("key_point", [])
if not (keywords and summary) and key_point:
logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少字段原文\n{response}")
# 检查必需字段是否为空
if not keywords or not summary:
logger.warning(f"{self.log_prefix} LLM返回的JSON中缺少必需字段原文\n{response}")
# 返回失败,和模型出错一样,让上层进行重试
return False, [], ""
# 确保keywords和key_point是列表
# 确保keywords是列表
if isinstance(keywords, str):
keywords = [keywords]
if isinstance(key_point, str):
key_point = [key_point]
return True, keywords, summary, key_point
return True, keywords, summary
except Exception as e:
logger.error(f"{self.log_prefix} LLM压缩聊天内容时出错: {e}")
logger.error(f"{self.log_prefix} LLM响应: {response if 'response' in locals() else 'N/A'}")
# 返回失败标志和默认值
return False, [], "压缩失败,无法生成概括", []
return False, [], "压缩失败,无法生成概括"
async def _store_to_database(
self,
@@ -938,7 +971,6 @@ class ChatHistorySummarizer:
theme: str,
keywords: List[str],
summary: str,
key_point: Optional[List[str]] = None,
):
"""存储到数据库"""
try:
@@ -958,10 +990,6 @@ class ChatHistorySummarizer:
"count": 0,
}
# 存储 key_point如果存在
if key_point is not None:
data["key_point"] = json.dumps(key_point, ensure_ascii=False)
# 使用db_save存储使用start_time和chat_id作为唯一标识
# 由于可能有多条记录我们使用组合键但peewee不支持所以使用start_time作为唯一标识
# 但为了避免冲突我们使用组合键chat_id + start_time
@@ -976,6 +1004,15 @@ class ChatHistorySummarizer:
else:
logger.warning(f"{self.log_prefix} 存储聊天历史记录到数据库失败")
# 如果配置开启同时导入到LPMM知识库
if global_config.lpmm_knowledge.enable and global_config.experimental.lpmm_memory:
await self._import_to_lpmm_knowledge(
theme=theme,
summary=summary,
participants=participants,
original_text=original_text,
)
except Exception as e:
logger.error(f"{self.log_prefix} 存储到数据库时出错: {e}")
import traceback
@@ -983,6 +1020,77 @@ class ChatHistorySummarizer:
traceback.print_exc()
raise
async def _import_to_lpmm_knowledge(
    self,
    theme: str,
    summary: str,
    participants: List[str],
    original_text: str,
):
    """Export a chat-history summary into the LPMM knowledge base.

    Only the summary and the participant list are imported; the theme and
    the raw chat text are deliberately excluded from the payload. The parts
    are joined with a single newline because LPMM splits paragraphs on
    blank lines ("\\n\\n") — a single "\\n" keeps the whole summary as one
    paragraph.

    Args:
        theme: topic title (used for logging only)
        summary: condensed summary of the conversation
        participants: participant display names
        original_text: raw chat text (currently not part of the import)

    Import failures are logged and swallowed so they can never break the
    database write that precedes this call.
    """
    try:
        from src.chat.knowledge.lpmm_ops import lpmm_ops

        parts = []
        if summary:
            parts.append(f"概括:{summary}")
        if participants:
            participants_text = "".join(participants)
            parts.append(f"参与者:{participants_text}")

        payload = "\n".join(parts)
        if not payload.strip():
            logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库")
            return

        # auto_split=False: the payload must be imported as a single paragraph.
        result = await lpmm_ops.add_content(text=payload, auto_split=False)
        if result["status"] == "success":
            logger.info(
                f"{self.log_prefix} 成功将聊天历史总结导入到LPMM知识库 | 话题: {theme} | 新增段落数: {result.get('count', 0)}"
            )
        else:
            logger.warning(
                f"{self.log_prefix} 将聊天历史总结导入到LPMM知识库失败 | 话题: {theme} | 错误: {result.get('message', '未知错误')}"
            )
    except Exception as e:
        # Best-effort: never let a knowledge-base failure propagate upward.
        logger.error(f"{self.log_prefix} 导入聊天历史总结到LPMM知识库时出错: {e}", exc_info=True)
async def start(self):
"""启动后台定期检查循环"""
if self._running:

View File

@@ -1,7 +1,6 @@
import time
import json
import asyncio
import re
from typing import List, Dict, Any, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@@ -9,7 +8,6 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.apis import llm_api
from src.common.database.database_model import ThinkingBack
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.memory_system.memory_utils import parse_questions_json
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
@@ -48,53 +46,36 @@ def init_memory_retrieval_prompt():
# 首先注册所有工具
init_all_tools()
# 第步:问题生成prompt
# 第步:ReAct Agent prompt使用function calling要求先思考再行动
Prompt(
"""
的名字是{bot_name}。现在是{time_now}
群里正在进行的聊天内容
"""你的名字是{bot_name}。现在是{time_now}
正在参与聊天,你需要搜集信息来帮助你进行回复
重要,这是当前聊天记录
{chat_history}
聊天记录结束
{recent_query_history}
已收集的信息:
{collected_info}
现在,{sender}发送了内容:{target_message},你想要回复ta。
请仔细分析聊天内容,考虑以下几点:
1. 对话中是否提到了过去发生的事情、人物、事件或信息
2. 是否有需要回忆的内容(比如"之前说过""上次""以前"等)
3. 是否有需要查找历史信息的问题
4. 是否有问题可以搜集信息帮助你聊天
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
- 思考完毕后,使用工具
重要提示:
- **每次只能提出一个问题**,选择最需要查询的关键问题
- 如果"最近已查询的问题和结果"中已经包含了类似的问题并得到了答案,请避免重复生成相同或相似的问题,不需要重复查询
- 如果之前已经查询过某个问题但未找到答案,可以尝试用不同的方式提问或更具体的问题
如果你认为需要从记忆中检索信息来回答,请根据上下文提出**一个**最关键的问题来帮助你回复目标消息,放入"questions"字段
问题格式示例:
- "xxx在前几天干了什么"
- "xxx是什么在什么时候提到过?"
- "xxxx和xxx的关系是什么"
- "xxx在某个时间点发生了什么"
问题要说明前因后果和上下文,使其全面且精准
输出格式示例:
```json
{{
"questions": ["张三在前几天干了什么"] #问题数组(字符串数组),如果不需要检索记忆则输出空数组[],如果需要检索则只输出包含一个问题的数组
}}
```
请只输出JSON对象不要输出其他内容
**工具说明:**
- 如果涉及过往事件或者查询某个过去可能提到过的概念或者某段时间发生的事件。可以使用lpmm知识库查询
- 如果遇到不熟悉的词语、缩写、黑话或网络用语可以使用query_words工具查询其含义
- 你必须使用tool如果需要查询你必须给出使用什么工具进行查询
- 当你决定结束查询时必须调用return_information工具返回总结信息并结束查询
""",
name="memory_retrieval_question_prompt",
name="memory_retrieval_react_prompt_head_lpmm",
)
# 第二步ReAct Agent prompt使用function calling要求先思考再行动
Prompt(
"""你的名字是{bot_name}。现在是{time_now}
你正在参与聊天,你需要搜集信息来回答问题,帮助你参与聊天
当前需要解答的问题:{question}
你正在参与聊天,你需要搜集信息来帮助你进行回复
当前聊天记录:
{chat_history}
已收集的信息:
{collected_info}
@@ -108,7 +89,7 @@ def init_memory_retrieval_prompt():
- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点
- 先思考当前信息是否足够回答问题
- 如果信息不足则需要使用tool查询信息你必须给出使用什么工具进行查询
- 如果当前已收集的信息足够或信息不足确定无法找到答案你必须调用finish_search工具结束查询
- 当你决定结束查询时必须调用return_information工具返回总结信息并结束查询
""",
name="memory_retrieval_react_prompt_head",
)
@@ -116,23 +97,24 @@ def init_memory_retrieval_prompt():
# 额外如果最后一轮迭代ReAct Agent prompt使用function calling要求先思考再行动
Prompt(
"""你的名字是{bot_name}。现在是{time_now}
你正在参与聊天,你需要根据搜集到的信息判断问题是否可以回答问题
你正在参与聊天,你需要根据搜集到的信息总结信息
如果搜集到的信息对于参与聊天,回答问题有帮助,请加入总结,如果无关,请不要加入到总结。
当前聊天记录:
{chat_history}
当前问题:{question}
已收集的信息:
{collected_info}
分析:
- 当前信息是否足够回答问题?
- **如果信息足够且能找到明确答案**,在思考中直接给出答案,格式为:found_answer(answer="你的答案内容")
- **如果信息不足或无法找到答案**,在思考中给出:not_enough_info(reason="信息不足或无法找到答案的原因")
- 基于已收集的信息,总结出对当前聊天有帮助的相关信息
- **如果收集的信息对当前聊天有帮助**,在思考中直接给出总结信息,格式为:return_information(information="你的总结信息")
- **如果信息无关或没有帮助**,在思考中给出:return_information(information="")
**重要规则:**
- 必须严格使用检索到的信息回答问题,不要编造信息
- 答案必须精简,不要过多解释
- **只有在检索到明确、具体的答案时才使用found_answer**
- **如果信息不足、无法确定、找不到相关信息必须使用not_enough_info不要使用found_answer**
- 答案必须给出,格式为 found_answer(answer="...") 或 not_enough_info(reason="...")。
""",
name="memory_retrieval_react_final_prompt",
)
@@ -199,20 +181,20 @@ def _log_conversation_messages(
async def _react_agent_solve_question(
question: str,
chat_id: str,
max_iterations: int = 5,
timeout: float = 30.0,
initial_info: str = "",
chat_history: str = "",
) -> Tuple[bool, str, List[Dict[str, Any]], bool]:
"""使用ReAct架构的Agent来解决问题
Args:
question: 要回答的问题
chat_id: 聊天ID
max_iterations: 最大迭代次数
timeout: 超时时间(秒)
initial_info: 初始信息将作为collected_info的初始值
chat_history: 聊天记录,将传递给 ReAct Agent prompt
Returns:
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
@@ -260,8 +242,8 @@ async def _react_agent_solve_question(
Args:
text: 要搜索的文本
func_name: 函数名,如 'found_answer'
param_name: 参数名,如 'answer'
func_name: 函数名,如 'return_information'
param_name: 参数名,如 'information'
Returns:
提取的参数值如果未找到则返回None
@@ -312,7 +294,7 @@ async def _react_agent_solve_question(
return None
# 正常迭代使用head_prompt决定调用哪些工具包含finish_search工具)
# 正常迭代使用head_prompt决定调用哪些工具包含return_information工具)
tool_definitions = tool_registry.get_tool_definitions()
# tool_names = [tool_def["name"] for tool_def in tool_definitions]
# logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)")
@@ -321,11 +303,13 @@ async def _react_agent_solve_question(
if first_head_prompt is None:
# 第一次构建使用初始的collected_info即initial_info
initial_collected_info = initial_info if initial_info else ""
# 根据配置选择使用哪个 prompt
prompt_name = "memory_retrieval_react_prompt_head_lpmm" if global_config.experimental.lpmm_memory else "memory_retrieval_react_prompt_head"
first_head_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_prompt_head",
prompt_name,
bot_name=bot_name,
time_now=time_now,
question=question,
chat_history=chat_history,
collected_info=initial_collected_info,
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
@@ -373,7 +357,7 @@ async def _react_agent_solve_question(
logger.error(f"ReAct Agent LLM调用失败: {response}")
break
# 注意:这里会检查finish_search工具调用如果检测到finish_search工具会根据found_answer参数决定返回答案或退出查询
# 注意:这里会检查return_information工具调用如果检测到return_information工具会根据information参数决定返回信息或退出查询
assistant_message: Optional[Message] = None
if tool_calls:
@@ -403,115 +387,137 @@ async def _react_agent_solve_question(
# 处理工具调用
if not tool_calls:
# 如果没有工具调用,检查响应文本中是否包含finish_search函数调用格式
# 如果没有工具调用,检查响应文本中是否包含return_information函数调用格式或JSON格式
if response and response.strip():
# 尝试从文本中解析finish_search函数调用
def parse_finish_search_from_text(text: str):
"""从文本中解析finish_search函数调用返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
# 首先尝试解析JSON格式的return_information
def parse_json_return_information(text: str):
"""从文本中解析JSON格式的return_information返回information字符串如果未找到则返回None"""
if not text:
return None, None
try:
# 尝试提取JSON对象可能包含在代码块中或直接是JSON
json_text = text.strip()
# 如果包含代码块标记提取JSON部分
if "```json" in json_text:
start = json_text.find("```json") + 7
end = json_text.find("```", start)
if end != -1:
json_text = json_text[start:end].strip()
elif "```" in json_text:
start = json_text.find("```") + 3
end = json_text.find("```", start)
if end != -1:
json_text = json_text[start:end].strip()
# 尝试解析JSON
data = json.loads(json_text)
# 检查是否包含return_information字段
if isinstance(data, dict) and "return_information" in data:
information = data.get("information", "")
return information
except (json.JSONDecodeError, ValueError, TypeError):
# 如果JSON解析失败尝试在文本中查找JSON对象
try:
# 查找第一个 { 和最后一个 } 之间的内容更健壮的JSON提取
first_brace = text.find('{')
if first_brace != -1:
# 从第一个 { 开始,找到匹配的 }
brace_count = 0
json_end = -1
for i in range(first_brace, len(text)):
if text[i] == '{':
brace_count += 1
elif text[i] == '}':
brace_count -= 1
if brace_count == 0:
json_end = i + 1
break
if json_end != -1:
json_text = text[first_brace:json_end]
data = json.loads(json_text)
if isinstance(data, dict) and "return_information" in data:
information = data.get("information", "")
return information
except (json.JSONDecodeError, ValueError, TypeError):
pass
return None
# 尝试从文本中解析return_information函数调用
def parse_return_information_from_text(text: str):
"""从文本中解析return_information函数调用返回information字符串如果未找到则返回None"""
if not text:
return None
# 查找finish_search函数调用位置(不区分大小写)
func_pattern = "finish_search"
# 查找return_information函数调用位置(不区分大小写)
func_pattern = "return_information"
text_lower = text.lower()
func_pos = text_lower.find(func_pattern)
if func_pos == -1:
return None, None
return None
# 查找函数调用的开始和结束位置
# 从func_pos开始向后查找左括号
start_pos = text.find("(", func_pos)
if start_pos == -1:
return None, None
# 解析information参数字符串使用extract_quoted_content
information = extract_quoted_content(text, "return_information", "information")
# 如果information存在即使是空字符串也返回它
return information
# 查找匹配的右括号(考虑嵌套)
paren_count = 0
end_pos = start_pos
for i in range(start_pos, len(text)):
if text[i] == "(":
paren_count += 1
elif text[i] == ")":
paren_count -= 1
if paren_count == 0:
end_pos = i
break
else:
# 没有找到匹配的右括号
return None, None
# 首先尝试解析JSON格式
parsed_information_json = parse_json_return_information(response)
is_json_format = parsed_information_json is not None
# 如果JSON解析成功使用JSON结果
if is_json_format:
parsed_information = parsed_information_json
else:
# 如果JSON解析失败尝试解析函数调用格式
parsed_information = parse_return_information_from_text(response)
# 提取函数参数部分
params_text = text[start_pos + 1 : end_pos]
# 解析found_answer参数布尔值可能是true/false/True/False
found_answer = None
found_answer_patterns = [
r"found_answer\s*=\s*true",
r"found_answer\s*=\s*True",
r"found_answer\s*=\s*false",
r"found_answer\s*=\s*False",
]
for pattern in found_answer_patterns:
match = re.search(pattern, params_text, re.IGNORECASE)
if match:
found_answer = "true" in match.group(0).lower()
break
# 解析answer参数字符串使用extract_quoted_content
answer = extract_quoted_content(text, "finish_search", "answer")
return found_answer, answer
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
if parsed_found_answer is not None:
# 检测到finish_search函数调用格式
if parsed_found_answer:
# 找到了答案
if parsed_answer:
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{parsed_answer}",
)
return True, parsed_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误继续迭代
logger.warning(
f"{react_log_prefix}{iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}}
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
if parsed_information is not None or is_json_format:
# 检测到return_information格式可能是JSON格式或函数调用格式
format_type = "JSON格式" if is_json_format else "函数调用格式"
# 返回信息(即使为空字符串也返回
step["actions"].append(
{
"action_type": "return_information",
"action_params": {"information": parsed_information or ""},
}
)
if parsed_information and parsed_information.strip():
step["observations"] = [f"检测到return_information{format_type}调用,返回信息"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案"
f"{react_log_prefix}{iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_information[:100]}..."
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search文本格式判断未找到答案",
final_status=f"返回信息:{parsed_information}",
)
return True, parsed_information, thinking_steps, False
else:
# 信息为空,直接退出查询
step["observations"] = [f"检测到return_information{format_type}调用,信息为空"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过return_information{format_type}判断信息为空"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="信息为空通过return_information文本格式判断信息为空",
)
return False, "", thinking_steps, False
# 如果没有检测到finish_search格式,记录思考过程,继续下一轮迭代
# 如果没有检测到return_information格式,记录思考过程,继续下一轮迭代
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 思考完成但未调用工具: {response}"
@@ -525,62 +531,54 @@ async def _react_agent_solve_question(
continue
# 处理工具调用
# 首先检查是否有finish_search工具调用,如果有则立即返回,不再处理其他工具
finish_search_found = None
finish_search_answer = None
# 首先检查是否有return_information工具调用,如果有则立即返回,不再处理其他工具
return_information_info = None
for tool_call in tool_calls:
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
if tool_name == "finish_search":
finish_search_found = tool_args.get("found_answer", False)
finish_search_answer = tool_args.get("answer", "")
if tool_name == "return_information":
return_information_info = tool_args.get("information", "")
if finish_search_found:
# 找到了答案
if finish_search_answer:
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": finish_search_answer},
}
)
step["observations"] = ["检测到finish_search工具调用,找到答案"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{finish_search_answer}",
)
return True, finish_search_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误
logger.warning(
f"{react_log_prefix}{iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["observations"] = ["检测到finish_search工具调用未找到答案"]
# 返回信息(即使为空也返回)
step["actions"].append(
{
"action_type": "return_information",
"action_params": {"information": return_information_info},
}
)
if return_information_info and return_information_info.strip():
# 有信息,返回
step["observations"] = ["检测到return_information工具调用,返回信息"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过finish_search工具判断未找到答案"
f"{react_log_prefix}{iteration + 1} 次迭代 通过return_information工具返回信息: {return_information_info}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search工具判断未找到答案",
final_status=f"返回信息:{return_information_info}",
)
return True, return_information_info, thinking_steps, False
else:
# 信息为空,直接退出查询
step["observations"] = ["检测到return_information工具调用信息为空"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过return_information工具判断信息为空"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="信息为空通过return_information工具判断信息为空",
)
return False, "", thinking_steps, False
# 如果没有finish_search工具调用,继续处理其他工具
# 如果没有return_information工具调用,继续处理其他工具
tool_tasks = []
for i, tool_call in enumerate(tool_calls):
tool_name = tool_call.func_name
@@ -590,8 +588,8 @@ async def _react_agent_solve_question(
f"{react_log_prefix}{iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
)
# 跳过finish_search工具调用(已经在上面处理过了)
if tool_name == "finish_search":
# 跳过return_information工具调用(已经在上面处理过了)
if tool_name == "return_information":
continue
# 记录最后一次使用的工具名称(用于判断是否需要额外迭代)
@@ -689,8 +687,8 @@ async def _react_agent_solve_question(
Args:
text: 要搜索的文本
func_name: 函数名,如 'found_answer'
param_name: 参数名,如 'answer'
func_name: 函数名,如 'return_information'
param_name: 参数名,如 'information'
Returns:
提取的参数值如果未找到则返回None
@@ -746,7 +744,7 @@ async def _react_agent_solve_question(
"memory_retrieval_react_final_prompt",
bot_name=bot_name,
time_now=time_now,
question=question,
chat_history=chat_history,
collected_info=collected_info if collected_info else "暂无信息",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
@@ -779,64 +777,49 @@ async def _react_agent_solve_question(
logger.info(f"{react_log_prefix}最终评估Prompt: {evaluation_prompt}")
logger.info(f"{react_log_prefix}最终评估响应: {eval_response}")
# 从最终评估响应中提取found_answer或not_enough_info
found_answer_content = None
not_enough_info_reason = None
# 从最终评估响应中提取return_information
return_information_content = None
if eval_response:
found_answer_content = extract_quoted_content(eval_response, "found_answer", "answer")
if not found_answer_content:
not_enough_info_reason = extract_quoted_content(eval_response, "not_enough_info", "reason")
return_information_content = extract_quoted_content(eval_response, "return_information", "information")
# 如果找到答案,返回(找到答案时,无论是否超时,都视为成功完成)
if found_answer_content:
# 如果提取到信息,返回(无论是否超时,都视为成功完成)
if return_information_content is not None:
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"],
"actions": [{"action_type": "return_information", "action_params": {"information": return_information_content}}],
"observations": ["最终评估阶段检测到return_information"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案{found_answer_content}",
)
return True, found_answer_content, thinking_steps, False
# 如果评估为not_enough_info返回空字符串不返回任何信息
if not_enough_info_reason:
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"未找到答案:{not_enough_info_reason}",
)
return False, "", thinking_steps, is_timeout
if return_information_content and return_information_content.strip():
logger.info(f"ReAct Agent 最终评估阶段返回信息: {return_information_content}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"返回信息{return_information_content}",
)
return True, return_information_content, thinking_steps, False
else:
logger.info("ReAct Agent 最终评估阶段判断信息为空")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="信息为空:最终评估阶段判断信息为空",
)
return False, "", thinking_steps, False
# 如果没有明确判断视为not_enough_info返回空字符串不返回任何信息
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
{"action_type": "return_information", "action_params": {"information": ""}}
],
"observations": ["已到达最大迭代次数,无法找到答案"],
"observations": ["已到达最大迭代次数,信息为空"],
}
thinking_steps.append(eval_step)
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
logger.info("ReAct Agent 已到达最大迭代次数,信息为空")
_log_conversation_messages(
conversation_messages,
@@ -1003,66 +986,48 @@ def _store_thinking_back(
logger.error(f"存储思考过程失败: {e}")
async def _process_single_question(
question: str,
async def _process_memory_retrieval(
chat_id: str,
context: str,
initial_info: str = "",
max_iterations: Optional[int] = None,
chat_history: str = "",
) -> Optional[str]:
"""处理单个问题的查询
"""处理记忆检索
Args:
question: 要查询的问题
chat_id: 聊天ID
context: 上下文信息
initial_info: 初始信息将传递给ReAct Agent
max_iterations: 最大迭代次数
chat_history: 聊天记录,将传递给 ReAct Agent
Returns:
Optional[str]: 如果找到答案,返回格式化的结果字符串否则返回None
Optional[str]: 如果找到答案,返回答案内容否则返回None
"""
# 如果question为空或None直接返回None不进行查询
if not question or not question.strip():
logger.debug("问题为空,跳过查询")
return None
# logger.info(f"开始处理问题: {question}")
_cleanup_stale_not_found_thinking_back()
question_initial_info = initial_info or ""
# 直接使用ReAct Agent查询不再从thinking_back获取缓存
# logger.info(f"使用ReAct Agent查询问题: {question[:50]}...")
# 直接使用ReAct Agent进行记忆检索
# 如果未指定max_iterations使用配置的默认值
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
max_iterations=max_iterations,
timeout=global_config.memory.agent_timeout_seconds,
initial_info=question_initial_info,
chat_history=chat_history,
)
# 存储查询历史到数据库(超时时不存储)
if not is_timeout:
_store_thinking_back(
chat_id=chat_id,
question=question,
context=context,
found_answer=found_answer,
answer=answer,
thinking_steps=thinking_steps,
)
else:
logger.info(f"ReAct Agent超时不存储到数据库问题: {question[:50]}...")
# 不再存储到数据库,直接返回答案
if is_timeout:
logger.info("ReAct Agent超时不返回结果")
if found_answer and answer:
return f"问题:{question}\n答案:{answer}"
return answer
return None
@@ -1074,11 +1039,8 @@ async def build_memory_retrieval_prompt(
chat_stream,
think_level: int = 1,
unknown_words: Optional[List[str]] = None,
question: Optional[str] = None,
) -> str:
"""构建记忆检索提示
使用两段式查询第一步生成问题第二步使用ReAct Agent查询答案
Args:
message: 聊天历史记录
sender: 发送者名称
@@ -1086,7 +1048,6 @@ async def build_memory_retrieval_prompt(
chat_stream: 聊天流对象
think_level: 思考深度等级
unknown_words: Planner 提供的未知词语列表,优先使用此列表而不是从聊天记录匹配
question: Planner 提供的问题,当 planner_question 配置开启时,直接使用此问题进行检索
Returns:
str: 记忆检索结果字符串
@@ -1112,62 +1073,8 @@ async def build_memory_retrieval_prompt(
logger.info(f"{log_prefix}检测是否需要回忆,元消息:{message[:30]}...,消息长度: {len(message)}")
try:
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
chat_id = chat_stream.stream_id
# 获取最近查询历史最近10分钟内的查询用于避免重复查询
recent_query_history = _get_recent_query_history(chat_id, time_window_seconds=600.0)
if not recent_query_history:
recent_query_history = "最近没有查询记录。"
# 第一步:生成问题或使用 Planner 提供的问题
single_question: Optional[str] = None
# 如果 planner_question 配置开启,只使用 Planner 提供的问题,不使用旧模式
if global_config.memory.planner_question:
if question and isinstance(question, str) and question.strip():
# 清理和验证 question
single_question = question.strip()
logger.info(f"{log_prefix}使用 Planner 提供的 question: {single_question}")
else:
# planner_question 开启但没有提供 question跳过记忆检索
logger.debug(f"{log_prefix}planner_question 已开启但未提供 question跳过记忆检索")
end_time = time.time()
logger.info(f"{log_prefix}无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}")
return ""
else:
# planner_question 关闭使用旧模式LLM 生成问题
question_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_question_prompt",
bot_name=bot_name,
time_now=time_now,
chat_history=message,
recent_query_history=recent_query_history,
sender=sender,
target_message=target,
)
success, response, reasoning_content, model_name = await llm_api.generate_with_model(
question_prompt,
model_config=model_config.model_task_config.tool_use,
request_type="memory.question",
)
if global_config.debug.show_memory_prompt:
logger.info(f"{log_prefix}记忆检索问题生成提示词: {question_prompt}")
# logger.info(f"记忆检索问题生成响应: {response}")
if not success:
logger.error(f"{log_prefix}LLM生成问题失败: {response}")
return ""
# 解析概念列表和问题列表,只取第一个问题
_, questions = parse_questions_json(response)
if questions and len(questions) > 0:
single_question = questions[0].strip()
logger.info(f"{log_prefix}解析到问题: {single_question}")
# 初始阶段:使用 Planner 提供的 unknown_words 进行检索(如果提供)
initial_info = ""
if unknown_words and len(unknown_words) > 0:
@@ -1189,13 +1096,7 @@ async def build_memory_retrieval_prompt(
else:
logger.debug(f"{log_prefix}unknown_words 检索未找到任何结果")
if not single_question:
logger.debug(f"{log_prefix}模型认为不需要检索记忆或解析失败,不返回任何查询结果")
end_time = time.time()
logger.info(f"{log_prefix}无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}")
return ""
# 第二步:处理问题(使用配置的最大迭代次数和超时时间)
# 直接使用 ReAct Agent 进行记忆检索(跳过问题生成步骤)
base_max_iterations = global_config.memory.max_agent_iterations
# 根据think_level调整迭代次数think_level=1时不变think_level=0时减半
if think_level == 0:
@@ -1204,60 +1105,31 @@ async def build_memory_retrieval_prompt(
max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(
f"{log_prefix}问题: {single_question}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
f"{log_prefix}直接使用 ReAct Agent 进行记忆检索think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 处理单个问题
# 直接调用 ReAct Agent 处理记忆检索
try:
result = await _process_single_question(
question=single_question,
result = await _process_memory_retrieval(
chat_id=chat_id,
context=message,
initial_info=initial_info,
max_iterations=max_iterations,
chat_history=message,
)
except Exception as e:
logger.error(f"{log_prefix}处理问题 '{single_question}' 时发生异常: {e}")
logger.error(f"{log_prefix}处理记忆检索时发生异常: {e}")
result = None
# 获取最近10分钟内已找到答案的缓存记录
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
all_results = []
# 先添加当前查询的结果
current_question = None
if result:
all_results.append(result)
# 提取问题(格式为 "问题xxx\n答案xxx"
if result.startswith("问题:"):
question_end = result.find("\n答案:")
if question_end != -1:
current_question = result[4:question_end]
# 添加缓存答案(排除当前查询的问题)
for cached_answer in cached_answers:
if cached_answer.startswith("问题:"):
question_end = cached_answer.find("\n答案:")
if question_end != -1:
cached_question = cached_answer[4:question_end]
if cached_question != current_question:
all_results.append(cached_answer)
end_time = time.time()
if all_results:
retrieved_memory = "\n\n".join(all_results)
current_count = 1 if result else 0
cached_count = len(all_results) - current_count
if result:
logger.info(
f"{log_prefix}记忆检索成功,耗时: {(end_time - start_time):.3f}"
f"当前查询 {current_count} 条记忆,缓存 {cached_count} 条记忆,共 {len(all_results)} 条记忆"
f"{log_prefix}记忆检索成功,耗时: {(end_time - start_time):.3f}"
)
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
return f"你回忆起了以下信息:\n{result}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else:
logger.debug(f"{log_prefix}问题未找到答案,且无缓存答案")
logger.debug(f"{log_prefix}记忆检索未找到相关信息")
return ""
except Exception as e:

View File

@@ -15,16 +15,18 @@ from .query_chat_history import register_tool as register_query_chat_history
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info
from .query_words import register_tool as register_query_words
from .found_answer import register_tool as register_finish_search
from .return_information import register_tool as register_return_information
from src.config.config import global_config
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_chat_history()
register_query_person_info()
register_query_words() # 注册query_words工具
register_finish_search() # 注册finish_search工具
# 如果开启了lpmm_memory则不注册query_chat_history工具
if not global_config.experimental.lpmm_memory:
register_query_chat_history()
register_query_person_info()
register_query_words()
register_return_information()
if global_config.lpmm_knowledge.lpmm_mode == "agent":
register_lpmm_knowledge()

View File

@@ -1,49 +0,0 @@
"""
finish_search工具 - 用于在记忆检索过程中结束查询
"""
from src.common.logger import get_logger
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def finish_search(found_answer: bool, answer: str = "") -> str:
"""结束查询
Args:
found_answer: 是否找到了答案
answer: 如果找到了答案,提供答案内容;如果未找到,可以为空
Returns:
str: 确认信息
"""
if found_answer:
logger.info(f"找到答案: {answer}")
return f"已确认找到答案: {answer}"
else:
logger.info("未找到答案,结束查询")
return "未找到答案,查询结束"
def register_tool():
"""注册finish_search工具"""
register_memory_retrieval_tool(
name="finish_search",
description="当你决定结束查询时调用此工具。如果找到了明确答案设置found_answer为true并在answer中提供答案如果未找到答案设置found_answer为false。只有在检索到明确、具体的答案时才设置found_answer为true不要编造信息。",
parameters=[
{
"name": "found_answer",
"type": "boolean",
"description": "是否找到了答案",
"required": True,
},
{
"name": "answer",
"type": "string",
"description": "如果found_answer为true提供找到的答案内容必须基于已收集的信息不要编造如果found_answer为false可以为空",
"required": False,
},
],
execute_func=finish_search,
)

View File

@@ -82,21 +82,51 @@ def _is_chat_id_in_blacklist(chat_id: str) -> bool:
return chat_id in blacklist_chat_ids
async def search_chat_history(chat_id: str, keyword: Optional[str] = None, participant: Optional[str] = None) -> str:
async def search_chat_history(
chat_id: str,
keyword: Optional[str] = None,
participant: Optional[str] = None,
start_time: Optional[str] = None,
end_time: Optional[str] = None,
) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords
Args:
chat_id: 聊天ID
keyword: 关键词(可选,支持多个关键词,可用空格、逗号等分隔。匹配规则:如果关键词数量<=2必须全部匹配如果关键词数量>2允许n-1个关键词匹配
participant: 参与人昵称(可选)
start_time: 开始时间(可选,格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'。如果只提供start_time查询该时间点之后的记录
end_time: 结束时间(可选,格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'。如果只提供end_time查询该时间点之前的记录。如果同时提供start_time和end_time查询该时间段内的记录
Returns:
str: 查询结果包含记忆id、theme和keywords
"""
try:
# 检查参数
if not keyword and not participant:
return "未指定查询参数需要提供keywordparticipant之一"
if not keyword and not participant and not start_time and not end_time:
return "未指定查询参数需要提供keywordparticipant、start_time或end_time之一)"
# 解析时间参数
start_timestamp = None
end_timestamp = None
if start_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
start_timestamp = parse_datetime_to_timestamp(start_time)
except ValueError as e:
return f"开始时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
if end_time:
try:
from src.memory_system.memory_utils import parse_datetime_to_timestamp
end_timestamp = parse_datetime_to_timestamp(end_time)
except ValueError as e:
return f"结束时间格式错误: {str(e)},支持格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'"
# 验证时间范围
if start_timestamp and end_timestamp and start_timestamp > end_timestamp:
return "开始时间不能晚于结束时间"
# 构建查询条件
# 检查当前chat_id是否在黑名单中
@@ -128,6 +158,40 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti
f"search_chat_history 当前聊天流在黑名单中强制使用本地查询chat_id={chat_id}, keyword={keyword}, participant={participant}"
)
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 添加时间过滤条件
if start_timestamp is not None and end_timestamp is not None:
# 查询指定时间段内的记录(记录的时间范围与查询时间段有交集)
# 记录的开始时间在查询时间段内,或记录的结束时间在查询时间段内,或记录完全包含查询时间段
query = query.where(
(
(ChatHistory.start_time >= start_timestamp)
& (ChatHistory.start_time <= end_timestamp)
) # 记录开始时间在查询时间段内
| (
(ChatHistory.end_time >= start_timestamp)
& (ChatHistory.end_time <= end_timestamp)
) # 记录结束时间在查询时间段内
| (
(ChatHistory.start_time <= start_timestamp)
& (ChatHistory.end_time >= end_timestamp)
) # 记录完全包含查询时间段
)
logger.debug(
f"search_chat_history 添加时间范围过滤: {start_timestamp} - {end_timestamp}, keyword={keyword}, participant={participant}"
)
elif start_timestamp is not None:
# 只提供开始时间,查询该时间点之后的记录(记录的开始时间或结束时间在该时间点之后)
query = query.where(ChatHistory.end_time >= start_timestamp)
logger.debug(
f"search_chat_history 添加开始时间过滤: >= {start_timestamp}, keyword={keyword}, participant={participant}"
)
elif end_timestamp is not None:
# 只提供结束时间,查询该时间点之前的记录(记录的开始时间或结束时间在该时间点之前)
query = query.where(ChatHistory.start_time <= end_timestamp)
logger.debug(
f"search_chat_history 添加结束时间过滤: <= {end_timestamp}, keyword={keyword}, participant={participant}"
)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
@@ -217,21 +281,31 @@ async def search_chat_history(chat_id: str, keyword: Optional[str] = None, parti
filtered_records.append(record)
if not filtered_records:
if keyword and participant:
keywords_str = "".join(parse_keywords_string(keyword) if keyword else [])
return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
elif keyword:
# 构建查询条件描述
conditions = []
if keyword:
keywords_str = "".join(parse_keywords_string(keyword))
keywords_list = parse_keywords_string(keyword)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return (
f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
)
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
return f"未找到参与人包含'{participant}'的聊天记录"
conditions.append(f"关键词'{keywords_str}'")
if participant:
conditions.append(f"参与人'{participant}'")
if start_timestamp or end_timestamp:
time_desc = ""
if start_timestamp and end_timestamp:
start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
time_desc = f"时间范围'{start_str}''{end_str}'"
elif start_timestamp:
start_str = datetime.fromtimestamp(start_timestamp).strftime("%Y-%m-%d %H:%M:%S")
time_desc = f"时间>='{start_str}'"
elif end_timestamp:
end_str = datetime.fromtimestamp(end_timestamp).strftime("%Y-%m-%d %H:%M:%S")
time_desc = f"时间<='{end_str}'"
if time_desc:
conditions.append(time_desc)
if conditions:
conditions_str = "".join(conditions)
return f"未找到满足条件({conditions_str})的聊天记录"
else:
return "未找到相关聊天记录"
@@ -389,18 +463,6 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
if record.summary:
result_parts.append(f"概括:{record.summary}")
# 添加关键信息点
if record.key_point:
try:
key_point_data = (
json.loads(record.key_point) if isinstance(record.key_point, str) else record.key_point
)
if isinstance(key_point_data, list) and key_point_data:
key_point_str = "\n".join([f" - {str(kp)}" for kp in key_point_data])
result_parts.append(f"关键信息点:\n{key_point_str}")
except (json.JSONDecodeError, TypeError, ValueError):
pass
results.append("\n".join(result_parts))
if not results:
@@ -419,7 +481,7 @@ def register_tool():
# 注册工具1搜索记忆
register_memory_retrieval_tool(
name="search_chat_history",
description="根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则如果关键词数量<=2必须全部匹配如果关键词数量>2允许n-1个关键词匹配容错匹配",
description="根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywords。用于快速搜索和定位相关记忆。匹配规则如果关键词数量<=2必须全部匹配如果关键词数量>2允许n-1个关键词匹配容错匹配支持按时间点或时间段进行查询。",
parameters=[
{
"name": "keyword",
@@ -433,6 +495,18 @@ def register_tool():
"description": "参与人昵称(可选),用于查询包含该参与人的记忆",
"required": False,
},
{
"name": "start_time",
"type": "string",
"description": "开始时间(可选),格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'。如果只提供start_time查询该时间点之后的记录。如果同时提供start_time和end_time查询该时间段内的记录",
"required": False,
},
{
"name": "end_time",
"type": "string",
"description": "结束时间(可选),格式如:'2025-01-01''2025-01-01 12:00:00''2025/01/01'。如果只提供end_time查询该时间点之前的记录。如果同时提供start_time和end_time查询该时间段内的记录",
"required": False,
},
],
execute_func=search_chat_history,
)

View File

@@ -56,12 +56,12 @@ def register_tool():
"""注册LPMM知识库查询工具"""
register_memory_retrieval_tool(
name="lpmm_search_knowledge",
description="LPMM知识库中搜索相关信息,适用于需要知识支持的场景。",
description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索",
parameters=[
{
"name": "query",
"type": "string",
"description": "需要查询的关键词或问题",
"description": "需要查询的问题使用一句疑问句提问例如什么是AI",
"required": True,
},
{

View File

@@ -3,7 +3,6 @@
用于在记忆检索过程中主动查询未知词语或黑话的含义
"""
from typing import List, Optional
from src.common.logger import get_logger
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
from .tool_registry import register_memory_retrieval_tool

View File

@@ -0,0 +1,42 @@
"""
return_information工具 - 用于在记忆检索过程中返回总结信息并结束查询
"""
from src.common.logger import get_logger
from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def return_information(information: str) -> str:
    """Finish the retrieval loop by echoing the summarized information.

    Args:
        information: Summary distilled from the gathered material; an empty
            or whitespace-only string signals that nothing useful was found.

    Returns:
        str: A confirmation message carrying the summary, or a notice that
        the query ended without useful information.
    """
    has_summary = bool(information) and bool(information.strip())
    if has_summary:
        logger.info(f"返回总结信息: {information}")
        return f"已确认返回信息: {information}"
    logger.info("未收集到相关信息,结束查询")
    return "未收集到相关信息,查询结束"
def register_tool():
    """Register the return_information tool with the retrieval tool registry."""
    tool_parameters = [
        {
            "name": "information",
            "type": "string",
            "description": "基于已收集信息总结出的相关信息,用于帮助回复。必须基于已收集的信息,不要编造。如果信息对当前聊天没有帮助,可以返回空字符串。",
            "required": True,
        },
    ]
    register_memory_retrieval_tool(
        name="return_information",
        description="当你决定结束查询时调用此工具。基于已收集的信息总结出一段相关信息用于帮助回复。如果收集的信息对当前聊天有帮助在information参数中提供总结信息如果信息无关或没有帮助可以提供空字符串。",
        parameters=tool_parameters,
        execute_func=return_information,
    )

View File

@@ -12,7 +12,6 @@ import time
from typing import List, Dict, Any, Tuple, Optional
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import Images
from src.config.config import global_config
from src.chat.utils.utils import is_bot_self
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,

161
src/webui/app.py Normal file
View File

@@ -0,0 +1,161 @@
"""FastAPI 应用工厂 - 创建和配置 WebUI 应用实例"""
import mimetypes
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from src.common.logger import get_logger
logger = get_logger("webui.app")
def create_app(
    host: str = "0.0.0.0",
    port: int = 8001,
    enable_static: bool = True,
) -> FastAPI:
    """Build and configure the WebUI FastAPI application.

    Args:
        host: Server bind address (kept for interface compatibility).
        port: Server port, also whitelisted as a CORS origin.
        enable_static: Whether to mount the built SPA static files.

    Returns:
        FastAPI: The fully configured application instance.
    """
    application = FastAPI(title="MaiBot WebUI")
    # Order matters: anti-crawler first, then CORS, API routes, robots.txt.
    _setup_anti_crawler(application)
    _setup_cors(application, port)
    _register_api_routes(application)
    _setup_robots_txt(application)
    if enable_static:
        _setup_static_files(application)
    return application
def _setup_cors(app: FastAPI, port: int):
    """Attach CORS middleware allowing local dev servers and the serving port."""
    # Vite dev server (5173), legacy port (7999), plus the configured port.
    allowed_origins = [
        f"http://{dev_host}:{dev_port}"
        for dev_port in (5173, 7999, port)
        for dev_host in ("localhost", "127.0.0.1")
    ]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
        allow_headers=[
            "Content-Type",
            "Authorization",
            "Accept",
            "Origin",
            "X-Requested-With",
        ],
        expose_headers=["Content-Length", "Content-Type"],
    )
    logger.debug("✅ CORS 中间件已配置")
def _setup_anti_crawler(app: FastAPI):
    """Install the anti-crawler middleware according to the configured mode."""
    try:
        from src.webui.middleware import AntiCrawlerMiddleware
        from src.config.config import global_config

        anti_crawler_mode = global_config.webui.anti_crawler_mode
        app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
        # Unknown modes fall back to the "basic" description, mirroring the
        # middleware's own default behaviour.
        mode_desc = {
            "false": "已禁用",
            "strict": "严格模式",
            "loose": "宽松模式",
            "basic": "基础模式",
        }.get(anti_crawler_mode, "基础模式")
        logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
    except Exception as e:
        logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
def _setup_robots_txt(app: FastAPI):
    """Expose /robots.txt so well-behaved crawlers are told to stay away."""
    try:
        from src.webui.middleware import create_robots_txt_response

        async def robots_txt():
            return create_robots_txt_response()

        # Register without the decorator form; hidden from the OpenAPI schema.
        app.get("/robots.txt", include_in_schema=False)(robots_txt)
        logger.debug("✅ robots.txt 路由已注册")
    except Exception as e:
        logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
def _register_api_routes(app: FastAPI):
    """Mount every WebUI API router onto the application."""
    try:
        from src.webui.routers import get_all_routers

        for api_router in get_all_routers():
            app.include_router(api_router)
        logger.info("✅ WebUI API 路由已注册")
    except Exception as e:
        logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True)
def _setup_static_files(app: FastAPI):
    """Serve the built WebUI SPA (webui/dist) via a catch-all route.

    Unknown paths fall back to index.html so client-side routing works.
    Every HTML response carries an X-Robots-Tag header to keep the UI out
    of search-engine indexes.
    """
    mimetypes.init()
    # Some platforms ship incomplete MIME tables; register the asset types
    # explicitly so browsers accept module scripts and stylesheets.
    mimetypes.add_type("application/javascript", ".js")
    mimetypes.add_type("application/javascript", ".mjs")
    mimetypes.add_type("text/css", ".css")
    mimetypes.add_type("application/json", ".json")

    base_dir = Path(__file__).parent.parent.parent
    static_path = base_dir / "webui" / "dist"

    if not static_path.exists():
        logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
        logger.warning("💡 请先构建前端: cd webui && npm run build")
        return

    if not (static_path / "index.html").exists():
        logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
        logger.warning("💡 请确认前端已正确构建")
        return

    # Resolved root used to reject path-traversal attempts below.
    static_root = static_path.resolve()

    def _index_response() -> FileResponse:
        """index.html with a noindex header (SPA fallback)."""
        response = FileResponse(static_path / "index.html", media_type="text/html")
        response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
        return response

    @app.get("/{full_path:path}", include_in_schema=False)
    async def serve_spa(full_path: str):
        if not full_path or full_path == "/":
            return _index_response()

        # Resolve and confine the requested path to the dist directory;
        # otherwise "../../etc/passwd"-style paths could escape it.
        file_path = (static_path / full_path).resolve()
        if not file_path.is_relative_to(static_root):
            return _index_response()

        if file_path.is_file():
            media_type = mimetypes.guess_type(str(file_path))[0]
            response = FileResponse(file_path, media_type=media_type)
            if file_path.suffix == ".html":
                response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
            return response

        # Not a real file: hand back the SPA shell.
        return _index_response()

    logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
def show_access_token():
    """Log the WebUI access token at startup so the operator can log in."""
    try:
        from src.webui.core import get_token_manager

        current_token = get_token_manager().get_token()
        logger.info(f"🔑 WebUI Access Token: {current_token}")
        logger.info("💡 请使用此 Token 登录 WebUI")
    except Exception as e:
        logger.error(f"❌ 获取 Access Token 失败: {e}")

View File

@@ -0,0 +1,30 @@
from .security import TokenManager, get_token_manager
from .rate_limiter import (
RateLimiter,
get_rate_limiter,
check_auth_rate_limit,
check_api_rate_limit,
)
from .auth import (
COOKIE_NAME,
COOKIE_MAX_AGE,
get_current_token,
set_auth_cookie,
clear_auth_cookie,
verify_auth_token_from_cookie_or_header,
)
__all__ = [
"TokenManager",
"get_token_manager",
"RateLimiter",
"get_rate_limiter",
"check_auth_rate_limit",
"check_api_rate_limit",
"COOKIE_NAME",
"COOKIE_MAX_AGE",
"get_current_token",
"set_auth_cookie",
"clear_auth_cookie",
"verify_auth_token_from_cookie_or_header",
]

View File

@@ -7,7 +7,7 @@ from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from src.common.logger import get_logger
from src.config.config import global_config
from .token_manager import get_token_manager
from .security import get_token_manager
logger = get_logger("webui.auth")
@@ -27,7 +27,7 @@ def _is_secure_environment() -> bool:
if global_config.webui.secure_cookie:
logger.info("配置中启用了 secure_cookie")
return True
# 检查是否是生产环境
if global_config.webui.mode == "production":
logger.info("WebUI运行在生产模式启用 secure cookie")
@@ -88,7 +88,7 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] =
"""
# 根据环境和实际请求协议决定安全设置
is_secure = _is_secure_environment()
# 如果提供了 request检测实际使用的协议
if request:
# 检查 X-Forwarded-Proto header代理/负载均衡器)
@@ -100,7 +100,7 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] =
# 检查 request.url.scheme
is_https = request.url.scheme == "https"
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
# 如果是 HTTP 连接,强制禁用 secure 标志
if not is_https and is_secure:
logger.warning("=" * 80)
@@ -110,7 +110,7 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] =
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
logger.warning("=" * 80)
is_secure = False
# 设置 Cookie
response.set_cookie(
key=COOKIE_NAME,
@@ -121,8 +121,10 @@ def set_auth_cookie(response: Response, token: str, request: Optional[Request] =
secure=is_secure, # 根据实际协议决定
path="/", # 确保 Cookie 在所有路径下可用
)
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
logger.info(
f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})"
)
logger.debug(f"完整 token 前缀: {token[:20]}...")

View File

@@ -24,8 +24,8 @@ class TokenManager:
config_path: 配置文件路径默认为项目根目录的 data/webui.json
"""
if config_path is None:
# 获取项目根目录 (src/webui -> src -> 根目录)
project_root = Path(__file__).parent.parent.parent
# 获取项目根目录 (src/webui/core -> src/webui -> src -> 根目录)
project_root = Path(__file__).parent.parent.parent.parent
config_path = project_root / "data" / "webui.json"
self.config_path = config_path

87
src/webui/dependencies.py Normal file
View File

@@ -0,0 +1,87 @@
from typing import Optional
from fastapi import Depends, Cookie, Header, Request, HTTPException
from .core import get_current_token, get_token_manager, check_auth_rate_limit, check_api_rate_limit
async def require_auth(
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> str:
    """FastAPI dependency: require a valid authentication token.

    Reads the token from the session cookie or the Authorization header
    and validates it; use on routes that need authentication.

    Returns:
        The validated token string.

    Raises:
        HTTPException: 401 when authentication fails.
    """
    return get_current_token(request, maibot_session, authorization)
async def require_auth_with_rate_limit(
    request: Request,
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
    _rate_limit: None = Depends(check_auth_rate_limit),
) -> str:
    """FastAPI dependency: valid authentication plus request rate limiting.

    Combines the auth check with the auth-endpoint rate limiter; intended
    for sensitive operations.

    Returns:
        The validated token string.

    Raises:
        HTTPException: 401 on authentication failure.
        HTTPException: 429 when requests arrive too frequently.
    """
    return get_current_token(request, maibot_session, authorization)
def get_optional_token(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> Optional[str]:
    """FastAPI dependency: extract a token without validating it.

    Prefers the session cookie; otherwise accepts a "Bearer <token>"
    Authorization header. Useful where the presence of a token matters
    but verification is not enforced.

    Returns:
        The raw token string, or None when none is supplied.
    """
    if maibot_session:
        return maibot_session
    if authorization and authorization.startswith("Bearer "):
        # removeprefix only strips the leading scheme; str.replace would also
        # mangle tokens that happen to contain the substring "Bearer ".
        return authorization.removeprefix("Bearer ")
    return None
async def verify_token_optional(
    maibot_session: Optional[str] = Cookie(None),
    authorization: Optional[str] = Header(None),
) -> bool:
    """FastAPI dependency: report whether a supplied token is valid.

    Never raises: returns False when no token is present or the token
    fails verification.

    Returns:
        True if a token was supplied and verifies, otherwise False.
    """
    token: Optional[str] = None
    if maibot_session:
        token = maibot_session
    elif authorization and authorization.startswith("Bearer "):
        # removeprefix strips only the leading scheme, never inner occurrences.
        token = authorization.removeprefix("Bearer ")
    if not token:
        return False
    token_manager = get_token_manager()
    return token_manager.verify_token(token)

View File

@@ -0,0 +1,17 @@
from .anti_crawler import (
AntiCrawlerMiddleware,
create_robots_txt_response,
ANTI_CRAWLER_MODE,
ALLOWED_IPS,
TRUSTED_PROXIES,
TRUST_XFF,
)
__all__ = [
"AntiCrawlerMiddleware",
"create_robots_txt_response",
"ANTI_CRAWLER_MODE",
"ALLOWED_IPS",
"TRUSTED_PROXIES",
"TRUST_XFF",
]

View File

@@ -3,7 +3,6 @@ WebUI 防爬虫模块
提供爬虫检测和阻止功能保护 WebUI 不被搜索引擎和恶意爬虫访问
"""
import os
import time
import ipaddress
import re

View File

@@ -0,0 +1,35 @@
"""WebUI 路由聚合模块 - 提供统一的路由注册接口"""
from fastapi import APIRouter
def get_api_router() -> APIRouter:
    """Return the aggregated main API router (includes all sub-routes)."""
    from src.webui.routes import router

    return router
def get_all_routers() -> list[APIRouter]:
    """Collect every router that must be registered independently on the app."""
    # Imports stay lazy (and in this exact order) to avoid import cycles and
    # preserve first-import side effects of the router modules.
    from src.webui.routes import router as main_router
    from src.webui.routers.websocket.logs import router as logs_router
    from src.webui.routers.knowledge import router as knowledge_router
    from src.webui.routers.chat import router as chat_router
    from src.webui.api.planner import router as planner_router
    from src.webui.api.replier import router as replier_router

    all_routers = [
        main_router,
        logs_router,
        knowledge_router,
        chat_router,
        planner_router,
        replier_router,
    ]
    return all_routers
__all__ = [
"get_api_router",
"get_all_routers",
]

View File

@@ -18,7 +18,7 @@ from src.common.database.database_model import (
ActionRecords,
Jargon,
)
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.annual_report")
@@ -54,9 +54,8 @@ class SocialNetworkData(BaseModel):
"""社交网络数据"""
total_groups: int = Field(0, description="加入的群组总数")
new_friends_count: int = Field(0, description="今年新认识的朋友数")
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP3")
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP3")
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
at_count: int = Field(0, description="被@次数")
mentioned_count: int = Field(0, description="被提及次数")
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
@@ -71,6 +70,7 @@ class BrainPowerData(BaseModel):
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
@@ -89,13 +89,15 @@ class ExpressionVibeData(BaseModel):
"""个性与表达数据"""
top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP5表情包")
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="最常用的表达风格")
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
checked_expression_count: int = Field(0, description="已检查的表达次数")
total_expressions: int = Field(0, description="表达总数")
action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
image_processed_count: int = Field(0, description="处理的图片数量")
late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
class AchievementData(BaseModel):
@@ -111,6 +113,7 @@ class AnnualReportData(BaseModel):
"""年度报告完整数据"""
year: int = Field(2025, description="报告年份")
bot_name: str = Field("麦麦", description="Bot名称")
generated_at: str = Field(..., description="报告生成时间")
time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
@@ -231,25 +234,19 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
async def get_social_network(year: int = 2025) -> SocialNetworkData:
"""获取社交网络数据"""
from src.config.config import global_config
data = SocialNetworkData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于过滤
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 加入的群组总数
data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
# 2. 今年新认识的朋友数
data.new_friends_count = (
PersonInfo.select()
.where(
(PersonInfo.know_times.is_null(False))
& (PersonInfo.know_times >= start_ts)
& (PersonInfo.know_times <= end_ts)
)
.count()
)
# 3. 话痨群组 TOP3
# 2. 话痨群组 TOP3
top_groups_query = (
Messages.select(
Messages.chat_info_group_id,
@@ -263,18 +260,19 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
)
.group_by(Messages.chat_info_group_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(3)
.limit(5)
)
data.top_groups = [
{
"group_id": row["chat_info_group_id"],
"group_name": row["chat_info_group_name"] or "未知群组",
"message_count": row["count"],
"is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
}
for row in top_groups_query.dicts()
]
# 4. 互动最多的用户 TOP3
# 3. 互动最多的用户 TOP5过滤 bot 自身)
top_users_query = (
Messages.select(
Messages.user_id,
@@ -285,21 +283,23 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id.is_null(False))
& (Messages.user_id != bot_qq) # 过滤 bot 自身
)
.group_by(Messages.user_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(3)
.limit(5)
)
data.top_users = [
{
"user_id": row["user_id"],
"user_nickname": row["user_nickname"] or "未知用户",
"message_count": row["count"],
"is_webui": str(row["user_id"]).startswith("webui_"),
}
for row in top_users_query.dicts()
]
# 5. 被@次数
# 4. 被@次数
data.at_count = (
Messages.select()
.where(
@@ -310,7 +310,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
.count()
)
# 6. 被提及次数
# 5. 被提及次数
data.mentioned_count = (
Messages.select()
.where(
@@ -321,15 +321,17 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
.count()
)
# 7. 最长情陪伴的用户
# 找出跨度时间最长的用户
# 6. 最长情陪伴的用户(过滤 bot 自身)
companion_query = (
ChatStreams.select(
ChatStreams.user_id,
ChatStreams.user_nickname,
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
)
.where(ChatStreams.user_id.is_null(False))
.where(
(ChatStreams.user_id.is_null(False))
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
)
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
.limit(1)
)
@@ -403,14 +405,19 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
data.most_expensive_cost = round(expensive_result.cost or 0, 4)
data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
# 4. 烧钱大户 TOP3 (按用户)
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
consumer_query = (
LLMUsage.select(
LLMUsage.user_id,
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
)
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (LLMUsage.user_id != "system") # 过滤 system 用户
& (LLMUsage.user_id.is_null(False))
)
.group_by(LLMUsage.user_id)
.order_by(fn.SUM(LLMUsage.cost).desc())
.limit(3)
@@ -424,7 +431,32 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
for row in consumer_query.dicts()
]
# 5. 高冷指数 (沉默率) - 基于 ActionRecords
# 5. 最喜欢的回复模型 TOP5按模型的回复次数统计只统计 replyer 调用)
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
reply_model_query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
fn.COUNT(LLMUsage.id).alias("count"),
)
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (
LLMUsage.model_assign_name.contains("replyer")
| LLMUsage.model_assign_name.contains("回复")
| LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况
)
)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(5)
)
data.top_reply_models = [
{"model": row["model"], "count": row["count"]}
for row in reply_model_query.dicts()
]
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
total_actions = ActionRecords.select().where(
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
).count()
@@ -504,8 +536,13 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
"""获取个性与表达数据"""
from src.config.config import global_config
data = ExpressionVibeData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 表情包之王 - 使用次数最多的表情包
@@ -586,7 +623,12 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
.count()
)
# 6. 动作类型分布 (非 reply 动作)
# 6. 动作类型分布 (过滤无意义的动作)
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
excluded_actions = [
"reply", "no_reply", "no_reply_until_call", "make_question",
"no_action", "wait", "complete_talk", "listening", "block_and_ignore"
]
action_query = (
ActionRecords.select(
ActionRecords.action_name,
@@ -595,8 +637,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name != "reply")
& (ActionRecords.action_name != "no_reply")
& (ActionRecords.action_name.not_in(excluded_actions))
)
.group_by(ActionRecords.action_name)
.order_by(fn.COUNT(ActionRecords.id).desc())
@@ -618,6 +659,128 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
.count()
)
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
import random
import re
def clean_message_content(content: str) -> str:
    """Strip reply-quote markers and media placeholders, then collapse whitespace."""
    if not content:
        return ""
    # Drop quoted-reply prefixes like "[回复<nick:id> 的消息:...]".
    without_quote = re.sub(r'\[回复<[^>]+>\s*的消息[:][^\]]*\]', '', content)
    # Drop media placeholders such as [图片]/[表情]/[语音]/[视频]/[文件].
    without_media = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', without_quote)
    # Normalize any whitespace run to a single space and trim the edges.
    return re.sub(r'\s+', ' ', without_media).strip()
# 使用 user_id 判断是否是 bot 发送的消息
late_night_messages = list(
Messages.select(
Messages.time,
Messages.processed_plain_text,
Messages.display_message,
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id == bot_qq) # bot 发送的消息
)
.order_by(Messages.time.desc())
)
# 筛选出0-6点的消息
late_night_filtered = []
for msg in late_night_messages:
msg_dt = datetime.fromtimestamp(msg.time)
hour = msg_dt.hour
if 0 <= hour < 6: # 0点到6点
raw_content = msg.processed_plain_text or msg.display_message or ""
cleaned_content = clean_message_content(raw_content)
# 只保留有意义的内容
if cleaned_content and len(cleaned_content) > 2:
late_night_filtered.append({
"time": msg.time,
"hour": hour,
"minute": msg_dt.minute,
"content": cleaned_content,
"datetime_str": msg_dt.strftime("%H:%M"),
})
if len(late_night_filtered) >= 10:
break
if late_night_filtered:
selected = random.choice(late_night_filtered)
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
data.late_night_reply = {
"time": selected["datetime_str"],
"content": content,
}
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
from collections import Counter
import json as json_lib
reply_records = (
ActionRecords.select(ActionRecords.action_data)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "reply")
& (ActionRecords.action_data.is_null(False))
& (ActionRecords.action_data != "")
)
)
reply_contents = []
for record in reply_records:
try:
action_data = record.action_data
if action_data:
content = None
# 尝试解析 JSON 格式
try:
parsed = json_lib.loads(action_data)
if isinstance(parsed, dict):
# 优先使用 reply_text其次使用 content
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (json_lib.JSONDecodeError, TypeError):
pass
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
# 例如: "{'reply_text': '墨白灵不知道哦'}"
if content is None:
import ast
try:
parsed = ast.literal_eval(action_data)
if isinstance(parsed, dict):
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (ValueError, SyntaxError):
# 无法解析,使用原始字符串
content = action_data
# 只统计有意义的回复长度大于2
if content and len(content) > 2:
reply_contents.append(content)
except Exception:
continue
if reply_contents:
content_counter = Counter(reply_contents)
most_common = content_counter.most_common(1)
if most_common:
fav_content, fav_count = most_common[0]
# 截断过长的内容
display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
data.favorite_reply = {
"content": display_content,
"count": fav_count,
}
except Exception as e:
logger.error(f"获取个性与表达数据失败: {e}")
@@ -692,7 +855,12 @@ async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require
完整的年度报告数据
"""
try:
from src.config.config import global_config
logger.info(f"开始生成 {year} 年度报告...")
# 获取 bot 名称
bot_name = global_config.bot.nickname or "麦麦"
# 并行获取各维度数据
time_footprint = await get_time_footprint(year)
@@ -703,6 +871,7 @@ async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require
report = AnnualReportData(
year=year,
bot_name=bot_name,
generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
time_footprint=time_footprint,
social_network=social_network,

View File

@@ -15,9 +15,8 @@ from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
from src.webui.core import verify_auth_token_from_cookie_or_header, get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")

View File

@@ -8,7 +8,7 @@ from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (

View File

@@ -6,8 +6,7 @@ from pydantic import BaseModel
from typing import Optional, List, Annotated
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from .token_manager import get_token_manager
from .auth import verify_auth_token_from_cookie_or_header
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
import time
import os
import hashlib

View File

@@ -1,11 +1,11 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel, NonNegativeFloat
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from .auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
import time
logger = get_logger("webui.expression")
@@ -224,10 +224,7 @@ async def get_expression_list(
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search))
| (Expression.style.contains(search))
)
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
# 聊天ID过滤
if chat_id:
@@ -363,21 +360,21 @@ async def update_expression(
if request.require_unchecked and expression.checked:
raise HTTPException(
status_code=409,
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表"
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表",
)
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
# 移除 require_unchecked它不是数据库字段
update_data.pop('require_unchecked', None)
update_data.pop("require_unchecked", None)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 如果更新了 checked 或 rejected标记为用户修改
if 'checked' in update_data or 'rejected' in update_data:
update_data['modified_by'] = 'user'
if "checked" in update_data or "rejected" in update_data:
update_data["modified_by"] = "user"
# 更新最后活跃时间
update_data["last_active_time"] = time.time()
@@ -542,8 +539,10 @@ async def get_expression_stats(
# ============ 审核相关接口 ============
class ReviewStatsResponse(BaseModel):
"""审核统计响应"""
total: int
unchecked: int
passed: int
@@ -553,10 +552,7 @@ class ReviewStatsResponse(BaseModel):
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None)
):
async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取审核统计数据
@@ -568,14 +564,10 @@ async def get_review_stats(
total = Expression.select().count()
unchecked = Expression.select().where(Expression.checked == False).count()
passed = Expression.select().where(
(Expression.checked == True) & (Expression.rejected == False)
).count()
rejected = Expression.select().where(
(Expression.checked == True) & (Expression.rejected == True)
).count()
ai_checked = Expression.select().where(Expression.modified_by == 'ai').count()
user_checked = Expression.select().where(Expression.modified_by == 'user').count()
passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count()
rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count()
ai_checked = Expression.select().where(Expression.modified_by == "ai").count()
user_checked = Expression.select().where(Expression.modified_by == "user").count()
return ReviewStatsResponse(
total=total,
@@ -583,7 +575,7 @@ async def get_review_stats(
passed=passed,
rejected=rejected,
ai_checked=ai_checked,
user_checked=user_checked
user_checked=user_checked,
)
except HTTPException:
@@ -595,6 +587,7 @@ async def get_review_stats(
class ReviewListResponse(BaseModel):
"""审核列表响应"""
success: bool
total: int
page: int
@@ -641,9 +634,7 @@ async def get_review_list(
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search)) | (Expression.style.contains(search))
)
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
# 聊天ID过滤
if chat_id:
@@ -651,10 +642,8 @@ async def get_review_list(
# 排序:创建时间倒序
from peewee import Case
query = query.order_by(
Case(None, [(Expression.create_date.is_null(), 1)], 0),
Expression.create_date.desc()
)
query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc())
total = query.count()
offset = (page - 1) * page_size
@@ -665,7 +654,7 @@ async def get_review_list(
total=total,
page=page,
page_size=page_size,
data=[expression_to_response(expr) for expr in expressions]
data=[expression_to_response(expr) for expr in expressions],
)
except HTTPException:
@@ -677,6 +666,7 @@ async def get_review_list(
class BatchReviewItem(BaseModel):
"""批量审核项"""
id: int
rejected: bool
require_unchecked: bool = True # 默认要求未检查状态
@@ -684,11 +674,13 @@ class BatchReviewItem(BaseModel):
class BatchReviewRequest(BaseModel):
"""批量审核请求"""
items: List[BatchReviewItem]
class BatchReviewResultItem(BaseModel):
"""批量审核结果项"""
id: int
success: bool
message: str
@@ -696,6 +688,7 @@ class BatchReviewResultItem(BaseModel):
class BatchReviewResponse(BaseModel):
"""批量审核响应"""
success: bool
total: int
succeeded: int
@@ -733,54 +726,44 @@ async def batch_review_expressions(
expression = Expression.get_or_none(Expression.id == item.id)
if not expression:
results.append(BatchReviewResultItem(
id=item.id,
success=False,
message=f"未找到 ID 为 {item.id} 的表达方式"
))
results.append(
BatchReviewResultItem(id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式")
)
failed += 1
continue
# 冲突检测
if item.require_unchecked and expression.checked:
results.append(BatchReviewResultItem(
id=item.id,
success=False,
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查"
))
results.append(
BatchReviewResultItem(
id=item.id,
success=False,
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查",
)
)
failed += 1
continue
# 更新状态
expression.checked = True
expression.rejected = item.rejected
expression.modified_by = 'user'
expression.modified_by = "user"
expression.last_active_time = time.time()
expression.save()
results.append(BatchReviewResultItem(
id=item.id,
success=True,
message="通过" if not item.rejected else "拒绝"
))
results.append(
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")
)
succeeded += 1
except Exception as e:
results.append(BatchReviewResultItem(
id=item.id,
success=False,
message=str(e)
))
results.append(BatchReviewResultItem(id=item.id, success=False, message=str(e)))
failed += 1
logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
return BatchReviewResponse(
success=True,
total=len(request.items),
succeeded=succeeded,
failed=failed,
results=results
success=True, total=len(request.items), succeeded=succeeded, failed=failed, results=results
)
except HTTPException:

View File

@@ -4,12 +4,84 @@ from typing import List, Optional
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
from src.config.config import global_config
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
# 延迟初始化的轻量级 embedding store只读仅用于获取段落完整文本
_paragraph_store_cache = None
def _get_paragraph_store():
    """Lazily load the paragraph embedding store (read-only, lightweight).

    Returns:
        EmbeddingStore | None: the paragraph store when the feature is
        enabled and loading succeeds, otherwise None.
    """
    # Feature toggle: skip all loading work when paragraph content is disabled.
    if not global_config.webui.enable_paragraph_content:
        return None

    global _paragraph_store_cache
    if _paragraph_store_cache is not None:
        return _paragraph_store_cache

    try:
        import os

        from src.chat.knowledge.embedding_store import EmbeddingStore

        # Resolve the data directory relative to this module's location.
        here = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.abspath(os.path.join(here, "..", ".."))
        embedding_dir = os.path.join(project_root, "data/embedding")

        # Only the paragraph namespace is loaded; keep the store lightweight.
        store = EmbeddingStore(
            namespace="paragraph",
            dir_path=embedding_dir,
            max_workers=1,  # read-only access needs no worker threads
            chunk_size=100,
        )
        store.load_from_file()

        _paragraph_store_cache = store
        logger.info(f"成功加载段落 embedding store包含 {len(store.store)} 个段落")
        return store
    except Exception as e:
        # Best-effort: a load failure is logged and the feature degrades to None.
        logger.warning(f"加载段落 embedding store 失败: {e}")
        return None
def _get_paragraph_content(node_id: str) -> tuple[Optional[str], bool]:
    """Fetch the full text of a paragraph node from the embedding store.

    Args:
        node_id: paragraph node id, formatted as 'paragraph-{hash}'.

    Returns:
        tuple[str | None, bool]: (full paragraph text or None,
        whether the paragraph-content feature is enabled).
    """
    try:
        store = _get_paragraph_store()
        if store is None:
            # Feature disabled (or store unavailable).
            return None, False

        item = store.store.get(node_id)
        if item is None:
            return None, True

        # EmbeddingStoreItem keeps the full text on its `str` attribute.
        text: str = getattr(item, "str", "")
        return (text, True) if text else (None, True)
    except Exception as e:
        logger.debug(f"获取段落内容失败: {e}")
        return None, True
def require_auth(
maibot_session: Optional[str] = Cookie(None),
@@ -84,7 +156,14 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data["content"] if "content" in node_data else node_id
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
@@ -166,7 +245,14 @@ async def get_knowledge_graph(
try:
node_data = graph[node_id]
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data["content"] if "content" in node_data else node_id
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type_val == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
@@ -281,8 +367,14 @@ async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bo
for node_id in node_list:
try:
node_data = graph[node_id]
content = node_data["content"] if "content" in node_data else node_id
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
# 对于段落节点,尝试从 embedding store 获取完整内容
if node_type == "paragraph":
full_content, _ = _get_paragraph_content(node_id)
content = full_content if full_content is not None else (node_data["content"] if "content" in node_data else node_id)
else:
content = node_data["content"] if "content" in node_data else node_id
if query_lower in content.lower() or query_lower in node_id.lower():
create_time = node_data["create_time"] if "create_time" in node_data else None

View File

@@ -12,7 +12,7 @@ import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from .auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
import json
import time

View File

@@ -7,9 +7,9 @@ from src.common.logger import get_logger
from src.common.toml_utils import save_toml_with_format
from src.config.config import MMC_VERSION
from src.plugin_system.base.config_types import ConfigField
from .git_mirror_service import get_git_mirror_service, set_update_progress_callback
from .token_manager import get_token_manager
from .plugin_progress_ws import update_progress
from src.webui.git_mirror_service import get_git_mirror_service, set_update_progress_callback
from src.webui.core import get_token_manager
from src.webui.routers.websocket.plugin_progress import update_progress
logger = get_logger("webui.plugin_routes")
@@ -1370,21 +1370,19 @@ async def get_installed_plugins(
seen_ids = {} # 记录 ID -> 路径的映射
unique_plugins = []
duplicates = []
for plugin in installed_plugins:
plugin_id = plugin["id"]
plugin_path = plugin["path"]
if plugin_id not in seen_ids:
seen_ids[plugin_id] = plugin_path
unique_plugins.append(plugin)
else:
duplicates.append(plugin)
first_path = seen_ids[plugin_id]
logger.warning(
f"重复插件 {plugin_id}: 保留 {first_path}, 跳过 {plugin_path}"
)
logger.warning(f"重复插件 {plugin_id}: 保留 {first_path}, 跳过 {plugin_path}")
if duplicates:
logger.warning(f"共检测到 {len(duplicates)} 个重复插件已去重")
@@ -1420,34 +1418,35 @@ async def get_local_plugin_readme(
try:
plugins_dir = Path("plugins")
# 查找插件目录
plugin_path = None
for folder in plugins_dir.iterdir():
if not folder.is_dir():
continue
manifest_path = folder / "_manifest.json"
if manifest_path.exists():
try:
import json as json_module
with open(manifest_path, "r", encoding="utf-8") as f:
manifest = json_module.load(f)
# 检查是否匹配 plugin_id
if manifest.get("id") == plugin_id:
plugin_path = folder
break
except Exception:
continue
if not plugin_path:
return {"success": False, "error": "插件未安装"}
# 查找 README 文件(支持多种命名)
readme_files = ["README.md", "readme.md", "Readme.md", "README.MD"]
readme_content = None
for readme_name in readme_files:
readme_path = plugin_path / readme_name
if readme_path.exists():
@@ -1459,12 +1458,12 @@ async def get_local_plugin_readme(
except Exception as e:
logger.warning(f"读取 {readme_path} 失败: {e}")
continue
if readme_content:
return {"success": True, "data": readme_content}
else:
return {"success": False, "error": "本地未找到 README 文件"}
except Exception as e:
logger.error(f"获取本地 README 失败: {e}", exc_info=True)
return {"success": False, "error": str(e)}
@@ -1756,10 +1755,10 @@ async def update_plugin_config_raw(
# 验证 TOML 格式
import tomlkit
if not isinstance(request.config, str):
raise HTTPException(status_code=400, detail="配置必须是字符串格式的 TOML 内容")
try:
tomlkit.loads(request.config)
except Exception as e:

View File

@@ -8,7 +8,7 @@ from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")

View File

@@ -12,7 +12,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.core import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")

View File

@@ -0,0 +1,9 @@
from .logs import router as logs_router
from .plugin_progress import get_progress_router
from .auth import router as ws_auth_router
__all__ = [
"logs_router",
"get_progress_router",
"ws_auth_router",
]

View File

@@ -9,7 +9,7 @@ from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.core import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()

View File

@@ -5,8 +5,8 @@ from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()

View File

@@ -5,8 +5,8 @@ from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
from src.webui.core import get_token_manager
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.plugin_progress")

View File

@@ -4,21 +4,25 @@ from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie,
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from .token_manager import get_token_manager
from .auth import set_auth_cookie, clear_auth_cookie
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
from .config_routes import router as config_router
from .statistics_routes import router as statistics_router
from .person_routes import router as person_router
from .expression_routes import router as expression_router
from .jargon_routes import router as jargon_router
from .emoji_routes import router as emoji_router
from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
from .model_routes import router as model_router
from .ws_auth import router as ws_auth_router
from .annual_report_routes import router as annual_report_router
from src.webui.core import (
get_token_manager,
set_auth_cookie,
clear_auth_cookie,
get_rate_limiter,
check_auth_rate_limit,
)
from src.webui.routers.config import router as config_router
from src.webui.routers.statistics import router as statistics_router
from src.webui.routers.person import router as person_router
from src.webui.routers.expression import router as expression_router
from src.webui.routers.jargon import router as jargon_router
from src.webui.routers.emoji import router as emoji_router
from src.webui.routers.plugin import router as plugin_router
from src.webui.routers.websocket.plugin_progress import get_progress_router
from src.webui.routers.system import router as system_router
from src.webui.routers.model import router as model_router
from src.webui.routers.websocket.auth import router as ws_auth_router
from src.webui.routers.annual_report import router as annual_report_router
logger = get_logger("webui.api")
@@ -198,9 +202,11 @@ async def check_auth_status(
"""
try:
token = None
# 记录请求信息用于调试
logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}")
logger.debug(
f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}"
)
# 优先从 Cookie 获取
if maibot_session:
@@ -218,7 +224,7 @@ async def check_auth_status(
token_manager = get_token_manager()
is_valid = token_manager.verify_token(token)
logger.debug(f"Token 验证结果: {is_valid}")
if is_valid:
return {"authenticated": True}
else:

View File

@@ -0,0 +1,109 @@
"""WebUI Schemas - Pydantic models for API requests and responses."""
# Auth schemas
from .auth import (
TokenVerifyRequest,
TokenVerifyResponse,
TokenUpdateRequest,
TokenUpdateResponse,
TokenRegenerateResponse,
FirstSetupStatusResponse,
CompleteSetupResponse,
ResetSetupResponse,
)
# Statistics schemas
from .statistics import (
StatisticsSummary,
ModelStatistics,
TimeSeriesData,
DashboardData,
)
# Emoji schemas
from .emoji import (
EmojiResponse,
EmojiListResponse,
EmojiDetailResponse,
EmojiUpdateRequest,
EmojiUpdateResponse,
EmojiDeleteResponse,
BatchDeleteRequest,
BatchDeleteResponse,
EmojiUploadResponse,
ThumbnailCacheStatsResponse,
ThumbnailCleanupResponse,
ThumbnailPreheatResponse,
)
# Chat schemas
from .chat import (
VirtualIdentityConfig,
ChatHistoryMessage,
)
# Plugin schemas
from .plugin import (
FetchRawFileRequest,
FetchRawFileResponse,
CloneRepositoryRequest,
CloneRepositoryResponse,
MirrorConfigResponse,
AvailableMirrorsResponse,
AddMirrorRequest,
UpdateMirrorRequest,
GitStatusResponse,
InstallPluginRequest,
VersionResponse,
UninstallPluginRequest,
UpdatePluginRequest,
UpdatePluginConfigRequest,
)
__all__ = [
# Auth
"TokenVerifyRequest",
"TokenVerifyResponse",
"TokenUpdateRequest",
"TokenUpdateResponse",
"TokenRegenerateResponse",
"FirstSetupStatusResponse",
"CompleteSetupResponse",
"ResetSetupResponse",
# Statistics
"StatisticsSummary",
"ModelStatistics",
"TimeSeriesData",
"DashboardData",
# Emoji
"EmojiResponse",
"EmojiListResponse",
"EmojiDetailResponse",
"EmojiUpdateRequest",
"EmojiUpdateResponse",
"EmojiDeleteResponse",
"BatchDeleteRequest",
"BatchDeleteResponse",
"EmojiUploadResponse",
"ThumbnailCacheStatsResponse",
"ThumbnailCleanupResponse",
"ThumbnailPreheatResponse",
# Chat
"VirtualIdentityConfig",
"ChatHistoryMessage",
# Plugin
"FetchRawFileRequest",
"FetchRawFileResponse",
"CloneRepositoryRequest",
"CloneRepositoryResponse",
"MirrorConfigResponse",
"AvailableMirrorsResponse",
"AddMirrorRequest",
"UpdateMirrorRequest",
"GitStatusResponse",
"InstallPluginRequest",
"VersionResponse",
"UninstallPluginRequest",
"UpdatePluginRequest",
"UpdatePluginConfigRequest",
]

41
src/webui/schemas/auth.py Normal file
View File

@@ -0,0 +1,41 @@
from pydantic import BaseModel, Field
class TokenVerifyRequest(BaseModel):
    """Request payload for verifying a WebUI access token."""

    token: str = Field(..., description="访问令牌")


class TokenVerifyResponse(BaseModel):
    """Result of a token verification attempt."""

    valid: bool = Field(..., description="Token 是否有效")
    message: str = Field(..., description="验证结果消息")
    is_first_setup: bool = Field(False, description="是否为首次设置")


class TokenUpdateRequest(BaseModel):
    """Request payload for replacing the current access token."""

    new_token: str = Field(..., description="新的访问令牌", min_length=10)


class TokenUpdateResponse(BaseModel):
    """Result of a token update operation."""

    success: bool = Field(..., description="是否更新成功")
    message: str = Field(..., description="更新结果消息")


class TokenRegenerateResponse(BaseModel):
    """Result of regenerating a fresh access token."""

    success: bool = Field(..., description="是否生成成功")
    token: str = Field(..., description="新生成的令牌")
    message: str = Field(..., description="生成结果消息")


class FirstSetupStatusResponse(BaseModel):
    """Reports whether the WebUI is in first-time-setup state."""

    is_first_setup: bool = Field(..., description="是否为首次配置")
    message: str = Field(..., description="状态消息")


class CompleteSetupResponse(BaseModel):
    """Result of completing the first-time setup flow."""

    success: bool = Field(..., description="是否成功")
    message: str = Field(..., description="结果消息")


class ResetSetupResponse(BaseModel):
    """Result of resetting the first-time setup state."""

    success: bool = Field(..., description="是否成功")
    message: str = Field(..., description="结果消息")

26
src/webui/schemas/chat.py Normal file
View File

@@ -0,0 +1,26 @@
from pydantic import BaseModel
from typing import Optional
class VirtualIdentityConfig(BaseModel):
    """Virtual identity configuration for the local chat room.

    All fields default to disabled/absent; callers enable the identity by
    setting `enabled` and filling in the relevant platform/user/group fields.
    """

    enabled: bool = False
    platform: Optional[str] = None
    person_id: Optional[str] = None
    user_id: Optional[str] = None
    user_nickname: Optional[str] = None
    group_id: Optional[str] = None
    group_name: Optional[str] = None


class ChatHistoryMessage(BaseModel):
    """A single message in the chat history."""

    id: str
    type: str  # 'user' | 'bot' | 'system'
    content: str
    timestamp: float  # assumed to be a Unix epoch timestamp — TODO confirm with producer
    sender_name: str
    sender_id: Optional[str] = None
    is_bot: bool = False

115
src/webui/schemas/emoji.py Normal file
View File

@@ -0,0 +1,115 @@
from pydantic import BaseModel
from typing import Optional, List
class EmojiResponse(BaseModel):
    """Emoji record returned by the WebUI emoji API."""

    id: int
    full_path: str
    format: str
    emoji_hash: str
    description: str
    query_count: int
    is_registered: bool
    is_banned: bool
    emotion: Optional[str]
    record_time: float  # presumably a Unix epoch timestamp — verify against DB model
    register_time: Optional[float]
    usage_count: int
    last_used_time: Optional[float]


class EmojiListResponse(BaseModel):
    """Paginated emoji list response."""

    success: bool
    total: int
    page: int
    page_size: int
    data: List[EmojiResponse]


class EmojiDetailResponse(BaseModel):
    """Single-emoji detail response."""

    success: bool
    data: EmojiResponse


class EmojiUpdateRequest(BaseModel):
    """Partial-update request; only non-None fields are applied."""

    description: Optional[str] = None
    is_registered: Optional[bool] = None
    is_banned: Optional[bool] = None
    emotion: Optional[str] = None


class EmojiUpdateResponse(BaseModel):
    """Result of an emoji update operation."""

    success: bool
    message: str
    data: Optional[EmojiResponse] = None


class EmojiDeleteResponse(BaseModel):
    """Result of deleting a single emoji."""

    success: bool
    message: str


class BatchDeleteRequest(BaseModel):
    """Request to delete multiple emojis by id."""

    emoji_ids: List[int]


class BatchDeleteResponse(BaseModel):
    """Result of a batch delete, with per-id failure accounting."""

    success: bool
    message: str
    deleted_count: int
    failed_count: int
    # NOTE: mutable default is safe here — pydantic copies defaults per instance.
    failed_ids: List[int] = []


class EmojiUploadResponse(BaseModel):
    """Result of uploading a new emoji."""

    success: bool
    message: str
    data: Optional[EmojiResponse] = None


class ThumbnailCacheStatsResponse(BaseModel):
    """Statistics about the thumbnail cache on disk."""

    success: bool
    cache_dir: str
    total_count: int
    total_size_mb: float
    emoji_count: int
    coverage_percent: float


class ThumbnailCleanupResponse(BaseModel):
    """Result of cleaning stale thumbnails from the cache."""

    success: bool
    message: str
    cleaned_count: int
    kept_count: int


class ThumbnailPreheatResponse(BaseModel):
    """Result of pre-generating thumbnails for existing emojis."""

    success: bool
    message: str
    generated_count: int
    skipped_count: int
    failed_count: int

135
src/webui/schemas/plugin.py Normal file
View File

@@ -0,0 +1,135 @@
from pydantic import BaseModel, Field
from typing import Optional, List, Dict, Any
class FetchRawFileRequest(BaseModel):
    """Request to fetch a raw file from a Git hosting mirror."""

    owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
    repo: str = Field(..., description="仓库名称", example="plugin-repo")
    branch: str = Field(..., description="分支名称", example="main")
    file_path: str = Field(..., description="文件路径", example="plugin_details.json")
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
    custom_url: Optional[str] = Field(None, description="自定义完整 URL")


class FetchRawFileResponse(BaseModel):
    """Result of a raw-file fetch, including which mirror served it."""

    success: bool = Field(..., description="是否成功")
    data: Optional[str] = Field(None, description="文件内容")
    error: Optional[str] = Field(None, description="错误信息")
    mirror_used: Optional[str] = Field(None, description="使用的镜像源")
    attempts: int = Field(..., description="尝试次数")
    url: Optional[str] = Field(None, description="实际请求的 URL")


class CloneRepositoryRequest(BaseModel):
    """Request to clone a repository via a configured mirror."""

    owner: str = Field(..., description="仓库所有者", example="MaiM-with-u")
    repo: str = Field(..., description="仓库名称", example="plugin-repo")
    target_path: str = Field(..., description="目标路径(相对于插件目录)")
    branch: Optional[str] = Field(None, description="分支名称", example="main")
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")
    custom_url: Optional[str] = Field(None, description="自定义克隆 URL")
    depth: Optional[int] = Field(None, description="克隆深度(浅克隆)", ge=1)


class CloneRepositoryResponse(BaseModel):
    """Result of a repository clone attempt."""

    success: bool = Field(..., description="是否成功")
    path: Optional[str] = Field(None, description="克隆路径")
    error: Optional[str] = Field(None, description="错误信息")
    mirror_used: Optional[str] = Field(None, description="使用的镜像源")
    attempts: int = Field(..., description="尝试次数")
    url: Optional[str] = Field(None, description="实际克隆的 URL")
    message: Optional[str] = Field(None, description="附加信息")


class MirrorConfigResponse(BaseModel):
    """A single mirror-source configuration entry."""

    id: str = Field(..., description="镜像源 ID")
    name: str = Field(..., description="镜像源名称")
    raw_prefix: str = Field(..., description="Raw 文件前缀")
    clone_prefix: str = Field(..., description="克隆前缀")
    enabled: bool = Field(..., description="是否启用")
    priority: int = Field(..., description="优先级(数字越小优先级越高)")


class AvailableMirrorsResponse(BaseModel):
    """List of configured mirrors plus the default priority order."""

    mirrors: List[MirrorConfigResponse] = Field(..., description="镜像源列表")
    default_priority: List[str] = Field(..., description="默认优先级顺序ID 列表)")


class AddMirrorRequest(BaseModel):
    """Request to register a new mirror source."""

    id: str = Field(..., description="镜像源 ID", example="custom-mirror")
    name: str = Field(..., description="镜像源名称", example="自定义镜像源")
    raw_prefix: str = Field(..., description="Raw 文件前缀", example="https://example.com/raw")
    clone_prefix: str = Field(..., description="克隆前缀", example="https://example.com/clone")
    enabled: bool = Field(True, description="是否启用")
    priority: Optional[int] = Field(None, description="优先级")


class UpdateMirrorRequest(BaseModel):
    """Partial update of a mirror source; only non-None fields are applied."""

    name: Optional[str] = Field(None, description="镜像源名称")
    raw_prefix: Optional[str] = Field(None, description="Raw 文件前缀")
    clone_prefix: Optional[str] = Field(None, description="克隆前缀")
    enabled: Optional[bool] = Field(None, description="是否启用")
    priority: Optional[int] = Field(None, description="优先级")


class GitStatusResponse(BaseModel):
    """Reports whether a usable Git executable was found."""

    installed: bool = Field(..., description="是否已安装 Git")
    version: Optional[str] = Field(None, description="Git 版本号")
    path: Optional[str] = Field(None, description="Git 可执行文件路径")
    error: Optional[str] = Field(None, description="错误信息")


class InstallPluginRequest(BaseModel):
    """Request to install a plugin from a repository."""

    plugin_id: str = Field(..., description="插件 ID")
    repository_url: str = Field(..., description="插件仓库 URL")
    branch: Optional[str] = Field("main", description="分支名称")
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")


class VersionResponse(BaseModel):
    """MaiBot version information, split into semver components."""

    version: str = Field(..., description="麦麦版本号")
    version_major: int = Field(..., description="主版本号")
    version_minor: int = Field(..., description="次版本号")
    version_patch: int = Field(..., description="补丁版本号")


class UninstallPluginRequest(BaseModel):
    """Request to uninstall a plugin by id."""

    plugin_id: str = Field(..., description="插件 ID")


class UpdatePluginRequest(BaseModel):
    """Request to update an installed plugin from its repository."""

    plugin_id: str = Field(..., description="插件 ID")
    repository_url: str = Field(..., description="插件仓库 URL")
    branch: Optional[str] = Field("main", description="分支名称")
    mirror_id: Optional[str] = Field(None, description="指定镜像源 ID")


class UpdatePluginConfigRequest(BaseModel):
    """Request carrying a plugin's new configuration payload."""

    config: Dict[str, Any] = Field(..., description="配置数据")

View File

@@ -0,0 +1,45 @@
from pydantic import BaseModel, Field
from typing import Dict, Any, List
class StatisticsSummary(BaseModel):
    """Aggregate usage statistics shown on the dashboard."""

    total_requests: int = Field(0, description="总请求数")
    total_cost: float = Field(0.0, description="总花费")
    total_tokens: int = Field(0, description="总token数")
    online_time: float = Field(0.0, description="在线时间(秒)")
    total_messages: int = Field(0, description="总消息数")
    total_replies: int = Field(0, description="总回复数")
    avg_response_time: float = Field(0.0, description="平均响应时间")
    cost_per_hour: float = Field(0.0, description="每小时花费")
    tokens_per_hour: float = Field(0.0, description="每小时token数")


class ModelStatistics(BaseModel):
    """Per-model usage statistics."""

    model_name: str
    request_count: int
    total_cost: float
    total_tokens: int
    avg_response_time: float


class TimeSeriesData(BaseModel):
    """One bucket of time-series usage data."""

    timestamp: str
    requests: int = 0
    cost: float = 0.0
    tokens: int = 0


class DashboardData(BaseModel):
    """Full dashboard payload: summary, per-model stats, and time series."""

    summary: StatisticsSummary
    model_stats: List[ModelStatistics]
    hourly_data: List[TimeSeriesData]
    daily_data: List[TimeSeriesData]
    recent_activity: List[Dict[str, Any]]

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

View File

@@ -1,13 +1,9 @@
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
import asyncio
import mimetypes
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from uvicorn import Config, Server as UvicornServer
from src.common.logger import get_logger
from src.webui.app import create_app, show_access_token
logger = get_logger("webui_server")
@@ -18,174 +14,10 @@ class WebUIServer:
def __init__(self, host: str = "0.0.0.0", port: int = 8001):
self.host = host
self.port = port
self.app = FastAPI(title="MaiBot WebUI")
self.app = create_app(host=host, port=port, enable_static=True)
self._server = None
# 配置防爬虫中间件需要在CORS之前注册
self._setup_anti_crawler()
# 配置 CORS支持开发环境跨域请求
self._setup_cors()
# 显示 Access Token
self._show_access_token()
# 重要:先注册 API 路由,再设置静态文件
self._register_api_routes()
self._setup_static_files()
# 注册robots.txt路由
self._setup_robots_txt()
def _setup_cors(self):
"""配置 CORS 中间件"""
# 开发环境需要允许前端开发服务器的跨域请求
self.app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:5173", # Vite 开发服务器
"http://127.0.0.1:5173",
"http://localhost:7999", # 前端开发服务器备用端口
"http://127.0.0.1:7999",
"http://localhost:8001", # 生产环境
"http://127.0.0.1:8001",
],
allow_credentials=True, # 允许携带 Cookie
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"X-Requested-With",
], # 明确指定允许的头
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
)
logger.debug("✅ CORS 中间件已配置")
def _show_access_token(self):
"""显示 WebUI Access Token"""
try:
from src.webui.token_manager import get_token_manager
token_manager = get_token_manager()
current_token = token_manager.get_token()
logger.info(f"🔑 WebUI Access Token: {current_token}")
logger.info("💡 请使用此 Token 登录 WebUI")
except Exception as e:
logger.error(f"❌ 获取 Access Token 失败: {e}")
def _setup_static_files(self):
"""设置静态文件服务"""
# 确保正确的 MIME 类型映射
mimetypes.init()
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("application/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/json", ".json")
base_dir = Path(__file__).parent.parent.parent
static_path = base_dir / "webui" / "dist"
if not static_path.exists():
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
logger.warning("💡 请先构建前端: cd webui && npm run build")
return
if not (static_path / "index.html").exists():
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
logger.warning("💡 请确认前端已正确构建")
return
# 处理 SPA 路由 - 注意:这个路由优先级最低
@self.app.get("/{full_path:path}", include_in_schema=False)
async def serve_spa(full_path: str):
"""服务单页应用 - 只处理非 API 请求"""
# 如果是根路径,直接返回 index.html
if not full_path or full_path == "/":
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 检查是否是静态文件
file_path = static_path / full_path
if file_path.is_file() and file_path.exists():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
response = FileResponse(file_path, media_type=media_type)
# HTML 文件添加防索引头
if str(file_path).endswith(".html"):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 其他路径返回 index.htmlSPA 路由)
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
def _setup_anti_crawler(self):
"""配置防爬虫中间件"""
try:
from src.webui.anti_crawler import AntiCrawlerMiddleware
from src.config.config import global_config
# 从配置读取防爬虫模式
anti_crawler_mode = global_config.webui.anti_crawler_mode
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
# 我们需要在CORS之前注册这样防爬虫检查会在CORS之前执行
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
except Exception as e:
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
def _setup_robots_txt(self):
"""设置robots.txt路由"""
try:
from src.webui.anti_crawler import create_robots_txt_response
@self.app.get("/robots.txt", include_in_schema=False)
async def robots_txt():
"""返回robots.txt禁止所有爬虫"""
return create_robots_txt_response()
logger.debug("✅ robots.txt 路由已注册")
except Exception as e:
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
def _register_api_routes(self):
"""注册所有 WebUI API 路由"""
try:
# 导入所有 WebUI 路由
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
from src.webui.knowledge_routes import router as knowledge_router
# 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router
# 导入规划器监控路由
from src.webui.api.planner import router as planner_router
# 导入回复器监控路由
from src.webui.api.replier import router as replier_router
# 注册路由
self.app.include_router(webui_router)
self.app.include_router(logs_router)
self.app.include_router(knowledge_router)
self.app.include_router(chat_router)
self.app.include_router(planner_router)
self.app.include_router(replier_router)
logger.info("✅ WebUI API 路由已注册")
except Exception as e:
logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True)
show_access_token()
async def start(self):
"""启动服务器"""
@@ -209,9 +41,9 @@ class WebUIServer:
self._server = UvicornServer(config=config)
logger.info("🌐 WebUI 服务器启动中...")
# 根据地址类型显示正确的访问地址
if ':' in self.host:
if ":" in self.host:
# IPv6 地址需要用方括号包裹
logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}")
if self.host == "::":
@@ -245,7 +77,7 @@ class WebUIServer:
import socket
# 判断使用 IPv4 还是 IPv6
if ':' in self.host:
if ":" in self.host:
# IPv6 地址
family = socket.AF_INET6
test_host = self.host if self.host != "::" else "::1"
@@ -289,6 +121,7 @@ def get_webui_server() -> WebUIServer:
if _webui_server is None:
# 从环境变量读取
import os
host = os.getenv("WEBUI_HOST", "127.0.0.1")
port = int(os.getenv("WEBUI_PORT", "8001"))
_webui_server = WebUIServer(host=host, port=port)