From 2e3dd44ee95c04a4c2b932d922f6ded7fb93befb Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Sat, 7 Mar 2026 00:57:33 +0800 Subject: [PATCH] Refactor chat stream handling to use BotChatSession - Updated imports and references from ChatStream to BotChatSession across multiple files. - Adjusted method signatures and internal logic to accommodate the new session management. - Ensured compatibility with existing functionality while improving code clarity and maintainability. --- scripts/replyer_action_stats.py | 4 +- scripts/test_memory_retrieval.py | 2 +- src/bw_learner/expression_learner.py | 6 +- src/bw_learner/expression_selector.py | 17 +- src/bw_learner/jargon_miner.py | 8 +- src/bw_learner/message_recorder.py | 6 +- src/bw_learner/reflect_tracker.py | 19 +- src/chat/__init__.py | 4 +- src/chat/brain_chat/PFC/conversation.py | 57 +++--- src/chat/brain_chat/PFC/message_sender.py | 39 ++-- src/chat/brain_chat/brain_chat.py | 6 +- src/chat/brain_chat/brain_planner.py | 4 +- src/chat/heart_flow/hfc_utils.py | 23 +-- src/chat/message_receive/bot.py | 17 +- src/chat/message_receive/message.py | 180 +++++++++++++++++- .../message_receive/uni_message_sender.py | 61 +++--- src/chat/planner_actions/action_manager.py | 4 +- src/chat/planner_actions/action_modifier.py | 11 +- src/chat/planner_actions/planner.py | 65 +++---- src/chat/replyer/group_generator.py | 46 ++--- src/chat/replyer/private_generator.py | 40 ++-- src/chat/replyer/replyer_manager.py | 11 +- src/chat/utils/common_utils.py | 7 +- src/chat/utils/statistic.py | 17 +- src/chat/utils/utils.py | 30 +-- src/common/utils/utils_image.py | 29 +++ src/dream/dream_generator.py | 7 +- src/main.py | 6 +- src/memory_system/chat_history_summarizer.py | 4 +- src/memory_system/memory_retrieval.py | 43 ++--- .../retrieval_tools/query_chat_history.py | 14 +- src/person_info/person_info.py | 18 +- src/plugin_system/apis/chat_api.py | 114 ++++++----- src/plugin_system/apis/emoji_api.py | 12 +- src/plugin_system/apis/generator_api.py | 14 +- src/plugin_system/apis/send_api.py | 98 +++++----- src/plugin_system/apis/tool_api.py | 4 +- src/plugin_system/base/base_action.py | 6 +- src/plugin_system/base/base_command.py | 68 +++---- src/plugin_system/base/base_tool.py | 6 +- src/plugin_system/core/events_manager.py | 93 +++++---- src/plugin_system/core/tool_use.py | 6 +- src/webui/routers/expression.py | 43 ++--- 43 files changed, 706 insertions(+), 563 deletions(-) diff --git a/scripts/replyer_action_stats.py b/scripts/replyer_action_stats.py index 8d8904bf..8370affb 100644 --- a/scripts/replyer_action_stats.py +++ b/scripts/replyer_action_stats.py @@ -19,10 +19,10 @@ sys.path.insert(0, project_root) try: from src.common.database.database_model import ChatStreams - from src.chat.message_receive.chat_stream import get_chat_manager + from src.chat.message_receive.chat_manager import chat_manager as _script_chat_manager except ImportError: ChatStreams = None - get_chat_manager = None + _script_chat_manager = None def get_chat_name(chat_id: str) -> str: diff --git a/scripts/test_memory_retrieval.py b/scripts/test_memory_retrieval.py index 7519a306..63bc3bd4 100644 --- a/scripts/test_memory_retrieval.py +++ b/scripts/test_memory_retrieval.py @@ -23,7 +23,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import initialize_logging, get_logger from src.common.database.database import db from src.common.database.database_model import LLMUsage -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_manager import BotChatSession from maim_message import UserInfo, GroupInfo logger = get_logger("test_memory_retrieval") diff --git a/src/bw_learner/expression_learner.py b/src/bw_learner/expression_learner.py index 35cb3501..a0a3fe34 100644 --- a/src/bw_learner/expression_learner.py +++ b/src/bw_learner/expression_learner.py @@ -12,7 +12,7 @@ from src.chat.utils.chat_message_builder import ( build_anonymous_messages, ) from src.prompt.prompt_manager import prompt_manager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.bw_learner.learner_utils import ( filter_message_content, is_bot_message, @@ -42,8 +42,8 @@ class ExpressionLearner: ) self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化 self.chat_id = chat_id - self.chat_stream = get_chat_manager().get_stream(chat_id) - self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id + self.chat_stream = _chat_manager.get_session_by_session_id(chat_id) + self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id # 学习锁,防止并发执行学习任务 self._learning_lock = asyncio.Lock() diff --git a/src/bw_learner/expression_selector.py b/src/bw_learner/expression_selector.py index 863481a3..78c0948d 100644 --- a/src/bw_learner/expression_selector.py +++ b/src/bw_learner/expression_selector.py @@ -10,7 +10,7 @@ from src.common.logger import get_logger from src.common.database.database_model import Expression from src.prompt.prompt_manager import prompt_manager from src.bw_learner.learner_utils import weighted_sample -from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.utils.utils_session import SessionUtils from src.chat.utils.common_utils import TempMethodsExpression logger = get_logger("expression_selector") @@ -50,8 +50,9 @@ class ExpressionSelector: id_str = parts[1] stream_type = parts[2] is_group = stream_type == "group" - # 统一通过 chat_manager 生成 stream_id,避免各处自行实现哈希逻辑 - return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) + return SessionUtils.calculate_session_id( + platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str) + ) except Exception: return None @@ -127,8 +128,7 @@ class ExpressionSelector: logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择") # 完全没有高 count 样本时,退化为全量随机抽样(不进入LLM流程) fallback_num = min(3, max_num) if max_num > 0 else 3 - fallback_selected = self._random_expressions(chat_id, fallback_num) - if fallback_selected: + if fallback_selected := self._random_expressions(chat_id, fallback_num): self.update_expressions_last_active_time(fallback_selected) selected_ids = [expr["id"] for expr in fallback_selected] logger.info( @@ -199,12 +199,7 @@ class ExpressionSelector: ] # 随机抽样 - if style_exprs: - selected_style = weighted_sample(style_exprs, total_num) - else: - selected_style = [] - - return selected_style + return weighted_sample(style_exprs, total_num) if style_exprs else [] except Exception as e: logger.error(f"随机选择表达方式失败: {e}") diff --git a/src/bw_learner/jargon_miner.py b/src/bw_learner/jargon_miner.py index 0d1622a9..e3f86bd1 100644 --- a/src/bw_learner/jargon_miner.py +++ b/src/bw_learner/jargon_miner.py @@ -10,7 +10,7 @@ from src.common.logger import get_logger from src.common.database.database_model import Jargon from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.prompt.prompt_manager import prompt_manager from src.bw_learner.learner_utils import ( parse_chat_id_list, @@ -99,9 +99,9 @@ class JargonMiner: ) # 初始化stream_name作为类属性,避免重复提取 - chat_manager = get_chat_manager() - stream_name = chat_manager.get_stream_name(self.chat_id) - self.stream_name = stream_name if stream_name else self.chat_id + chat_manager = _chat_manager + stream_name = chat_manager.get_session_name(self.chat_id) + self.stream_name = stream_name or self.chat_id self.cache_limit = 50 self.cache: OrderedDict[str, None] = OrderedDict() diff --git a/src/bw_learner/message_recorder.py b/src/bw_learner/message_recorder.py index 39be834f..bdf13fed 100644 --- a/src/bw_learner/message_recorder.py +++ b/src/bw_learner/message_recorder.py @@ -2,7 +2,7 @@ import time import asyncio from typing import List, Any from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.common_utils import TempMethodsExpression from src.bw_learner.expression_learner import expression_learner_manager @@ -18,8 +18,8 @@ class MessageRecorder: def __init__(self, chat_id: str) -> None: self.chat_id = chat_id - self.chat_stream = get_chat_manager().get_stream(chat_id) - self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id + self.chat_stream = _chat_manager.get_session_by_session_id(chat_id) + self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id # 维护每个chat的上次提取时间 self.last_extraction_time: float = time.time() diff --git a/src/bw_learner/reflect_tracker.py b/src/bw_learner/reflect_tracker.py index 5db5c6e9..d0455245 100644 --- a/src/bw_learner/reflect_tracker.py +++ b/src/bw_learner/reflect_tracker.py @@ -5,20 +5,19 @@ from src.common.database.database_model import Expression from src.llm_models.utils_model import LLMRequest from src.prompt.prompt_manager import prompt_manager from src.config.config import model_config -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.utils.chat_message_builder import ( get_raw_msg_by_timestamp_with_chat, build_readable_messages, ) -if TYPE_CHECKING: - pass + logger = get_logger("reflect_tracker") class ReflectTracker: - def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float): + def __init__(self, chat_stream: BotChatSession, expression: Expression, created_time: float): self.chat_stream = chat_stream self.expression = expression self.created_time = created_time @@ -42,7 +41,7 @@ class ReflectTracker: # Fetch messages since creation msg_list = get_raw_msg_by_timestamp_with_chat( - chat_id=self.chat_stream.stream_id, + chat_id=self.chat_stream.session_id, timestamp_start=self.created_time, timestamp_end=time.time(), ) @@ -90,10 +89,7 @@ class ReflectTracker: from json_repair import repair_json json_pattern = r"```json\s*(.*?)\s*```" - matches = re.findall(json_pattern, response, re.DOTALL) - if not matches: - # Try to parse raw response if no code block - matches = [response] + matches = re.findall(json_pattern, response, re.DOTALL) or [response] json_obj = json.loads(repair_json(matches[0])) @@ -122,10 +118,7 @@ class ReflectTracker: self.expression.style = corrected_style # 如果拒绝但未更新,标记为 rejected=1 - if not has_update: - self.expression.rejected = True - else: - self.expression.rejected = False + self.expression.rejected = not has_update self.expression.save() diff --git a/src/chat/__init__.py b/src/chat/__init__.py index 35bd5e02..098d19a3 100644 --- a/src/chat/__init__.py +++ b/src/chat/__init__.py @@ -4,10 +4,10 @@ MaiBot模块系统 """ from src.chat.emoji_system.emoji_manager import emoji_manager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager # 导出主要组件供外部使用 __all__ = [ - "get_chat_manager", + "chat_manager", "emoji_manager", ] diff --git a/src/chat/brain_chat/PFC/conversation.py b/src/chat/brain_chat/PFC/conversation.py index 4fe2f168..5b9b60ac 100644 --- a/src/chat/brain_chat/PFC/conversation.py +++ b/src/chat/brain_chat/PFC/conversation.py @@ -7,7 +7,7 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw # from src.config.config import global_config from typing import Dict, Any, Optional -from src.chat.message_receive.message import Message +from src.common.data_models.mai_message_data_model import MaiMessage from .pfc_types import ConversationState from .pfc import ChatObserver, GoalAnalyzer from .message_sender import DirectMessageSender @@ -16,9 +16,8 @@ from .action_planner import ActionPlanner from .observation_info import ObservationInfo from .conversation_info import ConversationInfo # 确保导入 ConversationInfo from .reply_generator import ReplyGenerator -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from maim_message import UserInfo -from src.chat.message_receive.chat_stream import get_chat_manager from .pfc_KnowledgeFetcher import KnowledgeFetcher from .waiter import Waiter @@ -60,7 +59,7 @@ class Conversation: self.direct_sender = DirectMessageSender(self.private_name) # 获取聊天流信息 - self.chat_stream = get_chat_manager().get_stream(self.stream_id) + self.chat_stream = _chat_manager.get_session_by_session_id(self.stream_id) self.stop_action_planner = False except Exception as e: @@ -265,34 +264,34 @@ class Conversation: return True return False - def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: - """将消息字典转换为Message对象""" + def _convert_to_message(self, msg_dict: Dict[str, Any]) -> MaiMessage: + """将消息字典转换为MaiMessage对象""" + from datetime import datetime as dt + from src.common.data_models.mai_message_data_model import UserInfo as MaiUserInfo, MessageInfo + from src.common.data_models.message_component_data_model import MessageSequence + try: - # 尝试从 msg_dict 直接获取 chat_stream,如果失败则从全局 get_chat_manager 获取 - chat_info = msg_dict.get("chat_info") - if chat_info and isinstance(chat_info, dict): - chat_stream = ChatStream.from_dict(chat_info) - elif self.chat_stream: # 使用实例变量中的 chat_stream - chat_stream = self.chat_stream - else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id) - chat_stream = get_chat_manager().get_stream(self.stream_id) - if not chat_stream: - raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}") - - user_info = UserInfo.from_dict(msg_dict.get("user_info", {})) - - return Message( - message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID - chat_stream=chat_stream, # 使用确定的 chat_stream - time=msg_dict.get("time", time.time()), # 提供默认时间 - user_info=user_info, - processed_plain_text=msg_dict.get("processed_plain_text", ""), - detailed_plain_text=msg_dict.get("detailed_plain_text", ""), + user_info_dict = msg_dict.get("user_info", {}) + user_info = MaiUserInfo( + user_id=user_info_dict.get("user_id", ""), + user_nickname=user_info_dict.get("user_nickname", ""), + user_cardname=user_info_dict.get("user_cardname"), ) + + msg = MaiMessage( + message_id=msg_dict.get("message_id", f"gen_{time.time()}"), + timestamp=dt.fromtimestamp(msg_dict.get("time", time.time())), + ) + msg.message_info = MessageInfo(user_info=user_info) + msg.platform = user_info_dict.get("platform", "") + msg.session_id = self.stream_id + msg.processed_plain_text = msg_dict.get("processed_plain_text", "") + msg.raw_message = MessageSequence(components=[]) + msg.initialized = True + return msg except Exception as e: logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}") - # 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题 - raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e + raise ValueError(f"无法将字典转换为 MaiMessage 对象: {e}") from e async def _handle_action( self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo @@ -687,7 +686,7 @@ class Conversation: logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。") return if not self.chat_stream: - logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。") + logger.error(f"[私聊][{self.private_name}]会话未初始化,无法发送回复。") return await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content) diff --git a/src/chat/brain_chat/PFC/message_sender.py b/src/chat/brain_chat/PFC/message_sender.py index a35576cc..10387319 100644 --- a/src/chat/brain_chat/PFC/message_sender.py +++ b/src/chat/brain_chat/PFC/message_sender.py @@ -1,10 +1,12 @@ import time from typing import Optional from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import ChatStream -from src.chat.message_receive.message import Message, MessageSending -from maim_message import UserInfo, Seg -from src.chat.message_receive.storage import MessageStorage +from maim_message import Seg + +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.chat.message_receive.chat_manager import BotChatSession +from src.chat.message_receive.message import MessageSending +from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.config.config import global_config from rich.traceback import install @@ -19,18 +21,17 @@ class DirectMessageSender: def __init__(self, private_name: str): self.private_name = private_name - self.storage = MessageStorage() async def send_message( self, - chat_stream: ChatStream, + chat_stream: BotChatSession, content: str, - reply_to_message: Optional[Message] = None, + reply_to_message: Optional[MaiMessage] = None, ) -> None: """发送消息到聊天流 Args: - chat_stream: 聊天流 + chat_stream: 聊天会话 content: 消息内容 reply_to_message: 要回复的消息(可选) """ @@ -42,18 +43,22 @@ class DirectMessageSender: bot_user_info = UserInfo( user_id=global_config.bot.qq_account, user_nickname=global_config.bot.nickname, - platform=chat_stream.platform, ) # 用当前时间作为message_id,和之前那套sender一样 message_id = f"dm{round(time.time(), 2)}" + # 构建发送者信息(私聊时为接收者) + sender_info = None + if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info: + sender_info = reply_to_message.message_info.user_info + # 构建消息对象 message = MessageSending( message_id=message_id, - chat_stream=chat_stream, + session=chat_stream, bot_user_info=bot_user_info, - sender_info=reply_to_message.message_info.user_info if reply_to_message else None, + sender_info=sender_info, message_segment=segments, reply=reply_to_message, is_head=True, @@ -61,17 +66,11 @@ class DirectMessageSender: thinking_start_time=time.time(), ) - # 处理消息 - await message.process() - - # 发送消息(直接调用底层 API) - from src.chat.message_receive.uni_message_sender import _send_message - - sent = await _send_message(message, show_log=True) + # 发送消息 + message_sender = UniversalMessageSender() + sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True) if sent: - # 存储消息 - await self.storage.store_message(message, chat_stream) logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}") else: logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index d1b3c535..85ccc238 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -9,7 +9,7 @@ from src.config.config import global_config from src.common.logger import get_logger from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.message_data_model import ReplyContentType -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer from src.chat.brain_chat.brain_planner import BrainPlanner @@ -73,10 +73,10 @@ class BrainChatting: """ # 基础属性 self.stream_id: str = chat_id # 聊天流ID - self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore + self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore if not self.chat_stream: raise ValueError(f"无法找到聊天流: {self.stream_id}") - self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" + self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]" self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id) diff --git a/src/chat/brain_chat/brain_planner.py b/src/chat/brain_chat/brain_planner.py index 59fd2394..44eaf0bc 100644 --- a/src/chat/brain_chat/brain_planner.py +++ b/src/chat/brain_chat/brain_planner.py @@ -22,7 +22,7 @@ from src.chat.utils.chat_message_builder import ( ) from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.planner_actions.action_manager import ActionManager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType from src.plugin_system.core.component_registry import component_registry @@ -38,7 +38,7 @@ install(extra_lines=3) class BrainPlanner: def __init__(self, chat_id: str, action_manager: ActionManager): self.chat_id = chat_id - self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" + self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]" self.action_manager = action_manager # LLM规划器配置 self.planner_llm = LLMRequest( diff --git a/src/chat/heart_flow/hfc_utils.py b/src/chat/heart_flow/hfc_utils.py index 36d9d6fb..3ad36ac8 100644 --- a/src/chat/heart_flow/hfc_utils.py +++ b/src/chat/heart_flow/hfc_utils.py @@ -5,9 +5,8 @@ import time from src.config.config import global_config from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.plugin_system.apis import send_api -from maim_message.message_base import GroupInfo from src.common.message_repository import count_messages @@ -121,28 +120,24 @@ def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None) async def send_typing(): - group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") - - chat = await get_chat_manager().get_or_create_stream( + chat = await _chat_manager.get_or_create_session( platform="amaidesu_default", - user_info=None, - group_info=group_info, + user_id="114514", + group_id="114514", ) await send_api.custom_to_stream( - message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False + message_type="state", content="typing", stream_id=chat.session_id, storage_message=False ) async def stop_typing(): - group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") - - chat = await get_chat_manager().get_or_create_stream( + chat = await _chat_manager.get_or_create_session( platform="amaidesu_default", - user_info=None, - group_info=group_info, + user_id="114514", + group_id="114514", ) await send_api.custom_to_stream( - message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False + message_type="state", content="stop_typing", stream_id=chat.session_id, storage_message=False ) diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index e0b91d31..de1dc7dd 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -8,7 +8,6 @@ from typing import Dict, Any from src.common.logger import get_logger from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_session import SessionUtils -from src.chat.message_receive.message_old import MessageRecv from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager @@ -41,14 +40,14 @@ class ChatBot: self._started = True - async def _create_pfc_chat(self, message: MessageRecv): + async def _create_pfc_chat(self, message: SessionMessage): """创建或获取PFC对话实例 Args: message: 消息对象 """ try: - chat_id = str(message.chat_stream.stream_id) + chat_id = message.session_id private_name = str(message.message_info.user_info.user_nickname) logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}") @@ -177,12 +176,12 @@ class ChatBot: logger.error(f"[新运行时] 执行命令 {matched.full_name} 异常: {e}", exc_info=True) return True, str(e), True - async def handle_notice_message(self, message: MessageRecv): - if message.message_info.message_id == "notice": + async def handle_notice_message(self, message: SessionMessage): + if message.message_id == "notice": message.is_notify = True logger.debug("notice消息") try: - seg = message.message_segment + seg = getattr(message, "message_segment", None) # SessionMessage 没有 message_segment mi = message.message_info sub_type = None scene = None @@ -246,10 +245,8 @@ class ChatBot: return mmc_message_id = message_data.get("echo") actual_message_id = message_data.get("actual_id") - if MessageStorage.update_message(mmc_message_id, actual_message_id): - logger.debug(f"更新消息ID成功: {mmc_message_id} -> {actual_message_id}") - else: - logger.warning(f"更新消息ID失败: {mmc_message_id} -> {actual_message_id}") + # TODO: Implement message ID update in new architecture + logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}") async def message_process(self, message_data: Dict[str, Any]) -> None: """处理转化后的统一格式消息 diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index be2ef026..19fd22d8 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,14 +1,23 @@ from asyncio import Task +from datetime import datetime +from maim_message import ( + MessageBase, + UserInfo as MaimUserInfo, + GroupInfo as MaimGroupInfo, + BaseMessageInfo as MaimBaseMessageInfo, + Seg, +) from rich.traceback import install from sqlmodel import select -from typing import List, Dict, Tuple, Sequence +from typing import List, Dict, Optional, Tuple, Sequence, TYPE_CHECKING import asyncio +import time from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import Messages -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo from src.common.data_models.message_component_data_model import ( TextComponent, ImageComponent, @@ -19,6 +28,10 @@ from src.common.data_models.message_component_data_model import ( ForwardNodeComponent, StandardMessageComponents, ) +from src.common.utils.utils_message import MessageUtils + +if TYPE_CHECKING: + from src.chat.message_receive.chat_manager import BotChatSession install(extra_lines=3) @@ -207,3 +220,166 @@ class SessionMessage(MaiMessage): else: processed_texts.append(result) return " ".join(processed_texts) + + +class MessageSending(MaiMessage): + """发送状态的消息类,继承 MaiMessage 基类。 + + 用于构建、处理和发送机器人的回复消息。 + 复用 MaiMessage 的 to_maim_message() 和 to_db_instance() 方法, + 额外管理发送专属的会话信息和控制字段。 + """ + + def __init__( + self, + message_id: str, + session: "BotChatSession", + bot_user_info: UserInfo, + message_segment: Seg, + sender_info: Optional[UserInfo] = None, + reply: Optional[MaiMessage] = None, + display_message: str = "", + is_head: bool = False, + is_emoji: bool = False, + thinking_start_time: float = 0, + reply_to: Optional[str] = None, + selected_expressions: Optional[List[int]] = None, + ): + # 初始化 MaiMessage 基类 + super().__init__(message_id=message_id, timestamp=datetime.now()) + + # 发送专属字段 + self.session = session + self.sender_info = sender_info + self.message_segment = message_segment + self.reply = reply + self.is_head = is_head + self.thinking_start_time = thinking_start_time + self.selected_expressions = selected_expressions + self.reply_to_message_id: Optional[str] = reply.message_id if reply else None + self.interest_value: float = 0.0 + + # 填充 MaiMessage 标准字段 + self.platform = session.platform + self.session_id = session.session_id + self.is_emoji = is_emoji + self.reply_to = reply_to + self.display_message = display_message + self.processed_plain_text = "" + + # 构建 message_info:DB 存储时 user_info 始终为 bot 信息 + # 私聊/群聊的 user_info 差异仅在 to_maim_message() 覆写中处理 + group_info = self._resolve_group_info() + self.message_info = MessageInfo(user_info=bot_user_info, group_info=group_info) + + # bot_user_info 单独保存,to_maim_message 覆写时还需要 + self.bot_user_info = bot_user_info + + # 将 Seg 转换为 MessageSequence,供基类的 to_db_instance / to_maim_message 使用 + self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq( + MessageBase(message_info=None, message_segment=message_segment) + ) + + self.initialized = True + + def _resolve_group_info(self) -> Optional[GroupInfo]: + """从 session 中解析群信息""" + if not self.session.group_id: + return None + group_name = "" + if ( + self.session.context + and self.session.context.message + and self.session.context.message.message_info.group_info + ): + group_name = self.session.context.message.message_info.group_info.group_name + return GroupInfo(group_id=self.session.group_id, group_name=group_name) + + async def process(self) -> None: + """处理消息段,生成 processed_plain_text(使用 SessionMessage 的组件处理能力)""" + # 同步 message_segment → raw_message(插件可能修改了 message_segment) + self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq( + MessageBase(message_info=None, message_segment=self.message_segment) + ) + if self.raw_message and self.raw_message.components: + tasks = [self._process_component(c) for c in self.raw_message.components] + results = await asyncio.gather(*tasks, return_exceptions=True) + texts = [] + for r in results: + if isinstance(r, BaseException): + logger.error(f"处理发送消息组件时发生错误: {r}") + elif r: + texts.append(r) + self.processed_plain_text = " ".join(texts) + + async def _process_component(self, component: StandardMessageComponents) -> str: + """简单处理单个标准组件为纯文本描述""" + if isinstance(component, TextComponent): + return component.text + elif isinstance(component, ImageComponent): + return "[图片]" + elif isinstance(component, EmojiComponent): + return "[表情包]" + elif isinstance(component, VoiceComponent): + return "[语音]" + elif isinstance(component, AtComponent): + return f"[@{component.target_user_id}]" + elif isinstance(component, ReplyComponent): + return "" + else: + return f"[{type(component).__name__}]" + + def build_reply(self) -> None: + """构建回复消息段,在 message_segment 前插入 reply 段""" + if self.reply: + self.reply_to_message_id = self.reply.message_id + self.message_segment = Seg( + type="seglist", + data=[ + Seg(type="reply", data=self.reply.message_id), + self.message_segment, + ], + ) + # 同步更新 raw_message + self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq( + MessageBase(message_info=None, message_segment=self.message_segment) + ) + + async def to_maim_message(self) -> MessageBase: + """覆写基类方法:发送消息需要特殊处理 user_info(私聊/群聊差异)""" + maim_bot_user_info = MaimUserInfo( + user_id=self.bot_user_info.user_id, + user_nickname=self.bot_user_info.user_nickname, + user_cardname=self.bot_user_info.user_cardname, + platform=self.platform, + ) + + maim_group_info = None + if self.message_info.group_info: + maim_group_info = MaimGroupInfo( + group_id=self.message_info.group_info.group_id, + group_name=self.message_info.group_info.group_name, + platform=self.platform, + ) + + # 私聊时 user_info 填接收者信息(sender_info),群聊时填 bot + if maim_group_info is None and self.sender_info: + msg_user_info = MaimUserInfo( + user_id=self.sender_info.user_id, + user_nickname=self.sender_info.user_nickname, + user_cardname=self.sender_info.user_cardname, + platform=self.platform, + ) + else: + msg_user_info = maim_bot_user_info + + maim_msg_info = MaimBaseMessageInfo( + platform=self.platform, + message_id=self.message_id, + time=time.time(), + group_info=maim_group_info, + user_info=msg_user_info, + ) + + msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message) + return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments)) diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index dfc23e4b..eb04109c 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -6,8 +6,8 @@ from maim_message import Seg from src.common.message_server.api import get_global_api from src.common.logger import get_logger +from src.common.database.database import get_db_session from src.chat.message_receive.message import MessageSending -from src.chat.message_receive.storage import MessageStorage from src.chat.utils.utils import truncate_message from src.chat.utils.utils import calculate_typing_time @@ -130,8 +130,8 @@ def parse_message_segments(segment) -> list: async def _send_message(message: MessageSending, show_log=True) -> bool: """合并后的消息发送函数,包含WS发送和日志记录""" message_preview = truncate_message(message.processed_plain_text, max_length=200) - platform = message.message_info.platform - group_id = message.message_info.group_info.group_id if message.message_info.group_info else None + platform = message.platform + group_id = message.session.group_id try: # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 @@ -221,33 +221,14 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: raise legacy_exception return False - # 使用 MessageConverter 转换 Legacy MessageBase 到 APIMessageBase - # 发送场景:MaiMBot 发送回复消息给外部用户 - # group_info/user_info 是消息接收者信息,放入 receiver_info + # 使用 MessageConverter 转换为 API 消息 from maim_message import MessageConverter - # 修复 API Server Fallback 模式下的 user_info 问题 - # 在 Legacy 模式下,MessageSending.to_dict() 的第 454 行会将 user_info 替换为 chat_stream.user_info - # 但在 API Server Fallback 模式下,MessageConverter.to_api_send() 直接访问 message 对象,不调用 to_dict() - # 需要手动应用相同的变通方案:在私聊场景下,user_info 应该是接收者(sender_info) - message_for_conversion = message - if hasattr(message, "message_info") and message.message_info.group_info is None: - # 私聊场景:group_info 为 None - # user_info 应该是接收者,从 chat_stream.user_info 或 sender_info 获取 - temp_dict = message.to_dict() - if ( - hasattr(message, "chat_stream") - and message.chat_stream - and hasattr(message.chat_stream, "user_info") - ): - temp_dict["message_info"]["user_info"] = message.chat_stream.user_info.to_dict() - # 重新构建 MessageBase 对象(不保留 sender_info 等扩展属性) - from maim_message import MessageBase - - message_for_conversion = MessageBase.from_dict(temp_dict) + # 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异 + message_base = await message.to_maim_message() api_message = MessageConverter.to_api_send( - message=message_for_conversion, + message=message_base, api_key=target_api_key, platform=platform, ) @@ -278,10 +259,11 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: return False try: - send_result = await get_global_api().send_message(message) + message_base = await message.to_maim_message() + send_result = await get_global_api().send_message(message_base) if send_result: if show_log: - logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'") + logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'") return True else: # Legacy API 返回 False (发送失败但未报错),尝试 Fallback @@ -297,7 +279,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool: return await send_with_new_api(legacy_exception=legacy_e) except Exception as e: - logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}") + logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}") traceback.print_exc() raise e # 重新抛出其他异常 @@ -306,7 +288,7 @@ class UniversalMessageSender: """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。""" def __init__(self): - self.storage = MessageStorage() + pass async def send_message( self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True @@ -321,15 +303,15 @@ class UniversalMessageSender: 用法: - typing=True 时,发送前会有打字等待。 """ - if not message.chat_stream: - logger.error("消息缺少 chat_stream,无法发送") - raise ValueError("消息缺少 chat_stream,无法发送") - if not message.message_info or not message.message_info.message_id: - logger.error("消息缺少 message_info 或 message_id,无法发送") - raise ValueError("消息缺少 message_info 或 message_id,无法发送") + if not message.session: + logger.error("消息缺少 session,无法发送") + raise ValueError("消息缺少 session,无法发送") + if not message.message_id: + logger.error("消息缺少 message_id,无法发送") + raise ValueError("消息缺少 message_id,无法发送") - chat_id = message.chat_stream.stream_id - message_id = message.message_info.message_id + chat_id = message.session_id + message_id = message.message_id try: if set_reply: @@ -391,7 +373,8 @@ class UniversalMessageSender: message.processed_plain_text = modified_message.plain_text if storage_message: - await self.storage.store_message(message, message.chat_stream) + with get_db_session() as db_session: + db_session.add(message.to_db_instance()) return sent_msg diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 287d4063..2e11474a 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,6 +1,6 @@ from typing import Dict, Optional, Type -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_manager import BotChatSession from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages from src.plugin_system.core.component_registry import component_registry @@ -35,7 +35,7 @@ class ActionManager: action_reasoning: str, cycle_timers: dict, thinking_id: str, - chat_stream: ChatStream, + chat_stream: BotChatSession, log_prefix: str, shutting_down: bool = False, action_message: Optional[DatabaseMessages] = None, diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index f85934ad..b7ea858a 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -4,15 +4,12 @@ from typing import List, Dict, TYPE_CHECKING, Tuple from src.common.logger import get_logger from src.config.config import global_config -from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.planner_actions.action_manager import ActionManager from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.plugin_system.base.component_types import ActionInfo, ActionActivationType from src.plugin_system.core.global_announcement_manager import global_announcement_manager -if TYPE_CHECKING: - from src.chat.message_receive.chat_stream import ChatStream - logger = get_logger("action_manager") @@ -27,8 +24,8 @@ class ActionModifier: def __init__(self, action_manager: ActionManager, chat_id: str): """初始化动作处理器""" self.chat_id = chat_id - self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore - self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" + self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.chat_id) # type: ignore + self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]" self.action_manager = action_manager @@ -121,7 +118,7 @@ class ActionModifier: available_actions_text = "、".join(available_actions) if available_actions else "无" logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}") - def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): + def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: BotChatSession): type_mismatched_actions: List[Tuple[str, str]] = [] for action_name, action_info in all_actions.items(): if action_info.associated_types and not chat_context.check_types(action_info.associated_types): diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 87c42c84..c9c7f2f8 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -3,6 +3,7 @@ import time import traceback import random import re +import contextlib from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union from collections import OrderedDict from rich.traceback import install @@ -21,7 +22,7 @@ from src.chat.utils.chat_message_builder import ( ) from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.planner_actions.action_manager import ActionManager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType from src.plugin_system.core.component_registry import component_registry from src.plugin_system.apis.message_api import translate_pid_to_description @@ -39,7 +40,7 @@ install(extra_lines=3) class ActionPlanner: def __init__(self, chat_id: str, action_manager: ActionManager): self.chat_id = chat_id - self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" + self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]" self.action_manager = action_manager # LLM规划器配置 self.planner_llm = LLMRequest( @@ -80,7 +81,7 @@ class ActionPlanner: if not text: return text - id_to_message = {msg_id: msg for msg_id, msg in message_id_list} + id_to_message = dict(message_id_list) # 匹配m后带2-4位数字,前后不是字母数字下划线 pattern = r"(? 0: - self.unknown_words_cache.popitem(last=False) - logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存") + if not has_extracted_unknown_words and len(self.unknown_words_cache) > 0: + self.unknown_words_cache.popitem(last=False) + logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存") # 对于每个 reply action,合并缓存和新提取的黑话 for action in actions: @@ -363,10 +362,7 @@ class ActionPlanner: new_words = action_data.get("unknown_words") # 合并新提取的和缓存的黑话列表 - merged_words = self._merge_unknown_words_with_cache(new_words) - - # 更新 action_data - if merged_words: + if merged_words := self._merge_unknown_words_with_cache(new_words): action_data["unknown_words"] = merged_words logger.debug( f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个," @@ -449,15 +445,12 @@ class ActionPlanner: # 如果有强制回复消息,确保回复该消息 if force_reply_message: # 检查是否已经有回复该消息的 action - has_reply_to_force_message = False - for action in actions: - if ( - action.action_type == "reply" - and action.action_message - and action.action_message.message_id == force_reply_message.message_id - ): - has_reply_to_force_message = True - break + has_reply_to_force_message = any( + action.action_type == "reply" + and action.action_message + and action.action_message.message_id == force_reply_message.message_id + for action in actions + ) # 如果没有回复该消息,强制添加回复 action if not has_reply_to_force_message: @@ -532,13 +525,10 @@ class ActionPlanner: # 从后往前遍历,收集最新的记录 for reasoning, timestamp, content in reversed(self.plan_log): if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content): - # 这是action记录 if len(action_records) < max_action_records: action_records.append((reasoning, timestamp, content, "action")) - else: - # 这是执行结果记录 - if len(execution_records) < max_execution_records: - execution_records.append((reasoning, timestamp, content, "execution")) + elif len(execution_records) < max_execution_records: + execution_records.append((reasoning, timestamp, content, "execution")) # 合并所有记录并按时间戳排序 all_records = action_records + execution_records @@ -700,15 +690,9 @@ class ActionPlanner: param_text = param_text.rstrip("\n") # 构建要求文本 - require_text = "" - for require_item in action_info.action_require: - require_text += f"- {require_item}\n" - require_text = require_text.rstrip("\n") + require_text = "\n".join(f"- {require_item}" for require_item in action_info.action_require) - if not action_info.parallel_action: - parallel_text = "(当选择这个动作时,请不要选择其他动作)" - else: - parallel_text = "" + parallel_text = "" if action_info.parallel_action else "(当选择这个动作时,请不要选择其他动作)" # 获取动作提示模板并填充 using_action_prompt = prompt_manager.get_prompt("action") @@ -864,20 +848,15 @@ class ActionPlanner: # 尝试按行分割,每行可能是一个JSON对象 lines = [line.strip() for line in json_str.split("\n") if line.strip()] for line in lines: - try: - # 尝试解析每一行作为独立的JSON对象 + with contextlib.suppress(json.JSONDecodeError): json_obj = json.loads(repair_json(line)) if isinstance(json_obj, dict): - # 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况 if json_obj: json_objects.append(json_obj) elif isinstance(json_obj, list): for item in json_obj: if isinstance(item, dict) and item: json_objects.append(item) - except json.JSONDecodeError: - # 如果单行解析失败,尝试将整个块作为一个JSON对象或数组 - pass # 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组 if not json_objects: diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 98ecfcbb..f49ffda3 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending -from src.chat.message_receive.chat_stream import ChatStream +from maim_message import Seg + +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.chat.message_receive.message import MessageSending +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self @@ -45,17 +48,17 @@ logger = get_logger("replyer") class DefaultReplyer: def __init__( self, - chat_stream: ChatStream, + chat_stream: BotChatSession, request_type: str = "replyer", ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) + self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) self.heart_fc_sender = UniversalMessageSender() from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 - self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) + self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) async def generate_reply_with_context( self, @@ -132,7 +135,7 @@ class DefaultReplyer: if log_reply: try: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.stream_id, + chat_id=self.chat_stream.session_id, prompt="", output=None, processed_output=None, @@ -202,7 +205,7 @@ class DefaultReplyer: try: if log_reply: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.stream_id, + chat_id=self.chat_stream.session_id, prompt=prompt, output=content, processed_output=None, @@ -259,7 +262,7 @@ class DefaultReplyer: if log_reply: try: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.stream_id, + chat_id=self.chat_stream.session_id, prompt=prompt or "", output=None, processed_output=None, @@ -353,14 +356,14 @@ class DefaultReplyer: str: 表达习惯信息字符串 """ # 检查是否允许在此聊天流中使用表达 - use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) + use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id) if not use_expression: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 # 使用模型预测选择表达方式 selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( - self.chat_stream.stream_id, + self.chat_stream.session_id, chat_history, max_num=8, target_message=target, @@ -702,10 +705,11 @@ class DefaultReplyer: # 判断是否为群聊 is_group = stream_type == "group" - # 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑 - from src.chat.message_receive.chat_stream import get_chat_manager + from src.common.utils.utils_session import SessionUtils - chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) + chat_id = SessionUtils.calculate_session_id( + platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None + ) return chat_id, prompt_content except (ValueError, IndexError): @@ -778,7 +782,7 @@ class DefaultReplyer: if available_actions is None: available_actions = {} chat_stream = self.chat_stream - chat_id = chat_stream.stream_id + chat_id = chat_stream.session_id _is_group_chat = bool(chat_stream.group_info) platform = chat_stream.platform @@ -1005,7 +1009,7 @@ class DefaultReplyer: reply_to: str, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream - chat_id = chat_stream.stream_id + chat_id = chat_stream.session_id sender, target = self._parse_reply_target(reply_to) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) @@ -1105,29 +1109,27 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: Optional[MessageRecv] = None, + anchor_message: Optional[MaiMessage] = None, ) -> MessageSending: """构建单个发送消息""" bot_user_info = UserInfo( user_id=str(global_config.bot.qq_account), user_nickname=global_config.bot.nickname, - platform=self.chat_stream.platform, ) - # await anchor_message.process() sender_info = anchor_message.message_info.user_info if anchor_message else None return MessageSending( - message_id=message_id, # 使用片段的唯一ID - chat_stream=self.chat_stream, + message_id=message_id, + session=self.chat_stream, bot_user_info=bot_user_info, sender_info=sender_info, message_segment=message_segment, - reply=anchor_message, # 回复原始锚点 + reply=anchor_message, is_head=reply_to, is_emoji=is_emoji, - thinking_start_time=thinking_start_time, # 传递原始思考开始时间 + thinking_start_time=thinking_start_time, display_message=display_message, ) diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 30ed14a9..c79b0c54 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending -from src.chat.message_receive.chat_stream import ChatStream +from maim_message import Seg + +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.chat.message_receive.message import MessageSending +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self @@ -43,18 +46,18 @@ logger = get_logger("replyer") class PrivateReplyer: def __init__( self, - chat_stream: ChatStream, + chat_stream: BotChatSession, request_type: str = "replyer", ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) + self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) self.heart_fc_sender = UniversalMessageSender() # self.memory_activator = MemoryActivator() from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 - self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) + self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) async def generate_reply_with_context( self, @@ -253,14 +256,14 @@ class PrivateReplyer: str: 表达习惯信息字符串 """ # 检查是否允许在此聊天流中使用表达 - use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) + use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id) if not use_expression: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 # 使用模型预测选择表达方式 selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( - self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason + self.chat_stream.session_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason ) if selected_expressions: @@ -550,10 +553,11 @@ class PrivateReplyer: # 判断是否为群聊 is_group = stream_type == "group" - # 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑 - from src.chat.message_receive.chat_stream import get_chat_manager + from src.common.utils.utils_session import SessionUtils - chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) + chat_id = SessionUtils.calculate_session_id( + platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None + ) return chat_id, prompt_content except (ValueError, IndexError): @@ -624,7 +628,7 @@ class PrivateReplyer: if available_actions is None: available_actions = {} chat_stream = self.chat_stream - chat_id = chat_stream.stream_id + chat_id = chat_stream.session_id platform = chat_stream.platform user_id = "用户ID" @@ -843,7 +847,7 @@ class PrivateReplyer: reply_to: str, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream - chat_id = chat_stream.stream_id + chat_id = chat_stream.session_id sender, target = self._parse_reply_target(reply_to) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) @@ -948,29 +952,27 @@ class PrivateReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: Optional[MessageRecv] = None, + anchor_message: Optional[MaiMessage] = None, ) -> MessageSending: """构建单个发送消息""" bot_user_info = UserInfo( user_id=str(global_config.bot.qq_account), user_nickname=global_config.bot.nickname, - platform=self.chat_stream.platform, ) - # await anchor_message.process() sender_info = anchor_message.message_info.user_info if anchor_message else None return MessageSending( - message_id=message_id, # 使用片段的唯一ID - chat_stream=self.chat_stream, + message_id=message_id, + session=self.chat_stream, bot_user_info=bot_user_info, sender_info=sender_info, message_segment=message_segment, - reply=anchor_message, # 回复原始锚点 + reply=anchor_message, is_head=reply_to, is_emoji=is_emoji, - thinking_start_time=thinking_start_time, # 传递原始思考开始时间 + thinking_start_time=thinking_start_time, display_message=display_message, ) diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index c7afddc9..eb430585 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,7 +1,7 @@ from typing import Dict, Optional from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager from src.chat.replyer.group_generator import DefaultReplyer from src.chat.replyer.private_generator import PrivateReplyer @@ -14,7 +14,7 @@ class ReplyerManager: def get_replyer( self, - chat_stream: Optional[ChatStream] = None, + chat_stream: Optional[BotChatSession] = None, chat_id: Optional[str] = None, request_type: str = "replyer", ) -> Optional[DefaultReplyer | PrivateReplyer]: @@ -24,7 +24,7 @@ class ReplyerManager: model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。 后续调用将返回已缓存的实例,忽略 model_configs 参数。 """ - stream_id = chat_stream.stream_id if chat_stream else chat_id + stream_id = chat_stream.session_id if chat_stream else chat_id if not stream_id: logger.warning("[ReplyerManager] 缺少 stream_id,无法获取回复器。") return None @@ -39,15 +39,14 @@ class ReplyerManager: target_stream = chat_stream if not target_stream: - if chat_manager := get_chat_manager(): - target_stream = chat_manager.get_stream(stream_id) + target_stream = _chat_manager.get_session_by_session_id(stream_id) if not target_stream: logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。") return None # model_configs 只在此时(初始化时)生效 - if target_stream.group_info: + if target_stream.is_group_session: replyer = DefaultReplyer( chat_stream=target_stream, request_type=request_type, diff --git a/src/chat/utils/common_utils.py b/src/chat/utils/common_utils.py index bffd5557..0692a904 100644 --- a/src/chat/utils/common_utils.py +++ b/src/chat/utils/common_utils.py @@ -61,9 +61,12 @@ class TempMethodsExpression: str: 生成的聊天流ID(哈希值) """ try: - from src.chat.message_receive.chat_stream import get_chat_manager + from src.common.utils.utils_session import SessionUtils - return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) + if is_group: + return SessionUtils.calculate_session_id(platform, group_id=str(id_str)) + else: + return SessionUtils.calculate_session_id(platform, user_id=str(id_str)) except Exception as e: logger.error(f"生成聊天流ID失败: {e}") return None diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index 21eef538..ba3b5f7d 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -1051,18 +1051,13 @@ class StatisticOutputTask(AsyncTask): """从chat_id获取显示名称""" try: # 首先尝试从chat_stream获取真实群组名称 - from src.chat.message_receive.chat_stream import get_chat_manager + from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager - chat_manager = get_chat_manager() - - if chat_id in chat_manager.streams: - stream = chat_manager.streams[chat_id] - if stream.group_info and hasattr(stream.group_info, "group_name"): - group_name = stream.group_info.group_name - if group_name and group_name.strip(): - return group_name.strip() - elif stream.user_info and hasattr(stream.user_info, "user_nickname"): - user_name = stream.user_info.user_nickname + if chat_id in _stat_chat_manager.sessions: + session = _stat_chat_manager.sessions[chat_id] + name = _stat_chat_manager.get_session_name(chat_id) + if name and name.strip(): + return name.strip() if user_name and user_name.strip(): return user_name.strip() diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index cd66b919..81097184 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -12,8 +12,8 @@ from typing import Optional, Tuple, List, TYPE_CHECKING from src.common.logger import get_logger from src.common.data_models.database_data_model import DatabaseMessages from src.config.config import global_config, model_config -from src.chat.message_receive.message import MessageRecv -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.message import SessionMessage +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import Person from .typo_generator import ChineseTypoGenerator @@ -114,10 +114,10 @@ def is_bot_self(platform: str, user_id: str) -> bool: return user_id_str == qq_account -def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]: +def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, float]: """检查消息是否提到了机器人(统一多平台实现)""" text = message.processed_plain_text or "" - platform = getattr(message.message_info, "platform", "") or "" + platform = message.platform or "" # 获取各平台账号 platforms_list = getattr(global_config.bot, "platforms", []) or [] @@ -696,15 +696,23 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP chat_target_info = None try: - if chat_stream := get_chat_manager().get_stream(chat_id): - if chat_stream.group_info: + if chat_stream := _chat_manager.get_session_by_session_id(chat_id): + if chat_stream.is_group_session: is_group_chat = True chat_target_info = None # Explicitly None for group chat - elif chat_stream.user_info: # It's a private chat + elif chat_stream.user_id: # It's a private chat is_group_chat = False - user_info = chat_stream.user_info platform: str = chat_stream.platform - user_id: str = user_info.user_id # type: ignore + user_id: str = chat_stream.user_id + + # Try to get nickname from context + user_nickname = None + if ( + chat_stream.context + and chat_stream.context.message + and chat_stream.context.message.message_info.user_info + ): + user_nickname = chat_stream.context.message.message_info.user_info.user_nickname from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题 @@ -712,7 +720,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP target_info = TargetPersonInfo( platform=platform, user_id=user_id, - user_nickname=user_info.user_nickname, # type: ignore + user_nickname=user_nickname, # type: ignore person_id=None, person_name=None, ) @@ -721,7 +729,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP try: person = Person(platform=platform, user_id=user_id) if not person.is_known: - logger.warning(f"用户 {user_info.user_nickname} 尚未认识") + logger.warning(f"用户 {user_nickname} 尚未认识") # 如果用户尚未认识,则返回False和None return False, None if person.person_id: diff --git a/src/common/utils/utils_image.py b/src/common/utils/utils_image.py index 944e53a0..6132afd6 100644 --- a/src/common/utils/utils_image.py +++ b/src/common/utils/utils_image.py @@ -1,4 +1,6 @@ +from pathlib import Path from PIL import Image as PILImage, ImageSequence +from typing import Optional, Union import base64 import io @@ -102,3 +104,30 @@ class ImageUtils: logger.error("输入的图片字节数据无效") raise ValueError("输入的图片字节数据无效") return base64.b64encode(image_bytes).decode("utf-8") + + @staticmethod + def image_path_to_base64(image_path: Union[str, Path]) -> Optional[str]: + """读取图片文件并转换为 Base64 编码字符串""" + try: + path = Path(image_path) + if not path.exists(): + logger.error(f"图片文件不存在: {path}") + return None + image_bytes = path.read_bytes() + return base64.b64encode(image_bytes).decode("utf-8") + except Exception as e: + logger.error(f"读取图片文件失败: {e}") + return None + + @staticmethod + def base64_to_image(base64_str: str, save_path: Union[str, Path]) -> bool: + """将 Base64 编码字符串解码并保存为图片文件""" + try: + image_bytes = base64.b64decode(base64_str) + path = Path(save_path) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(image_bytes) + return True + except Exception as e: + logger.error(f"保存图片文件失败: {e}") + return False diff --git a/src/dream/dream_generator.py b/src/dream/dream_generator.py index 316cac99..4934bee1 100644 --- a/src/dream/dream_generator.py +++ b/src/dream/dream_generator.py @@ -7,7 +7,7 @@ from src.config.config import global_config, model_config from src.llm_models.payload_content.message import RoleType, Message from src.prompt.prompt_manager import prompt_manager from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.utils.utils_session import SessionUtils from src.plugin_system.apis import send_api logger = get_logger("dream_generator") @@ -178,10 +178,9 @@ async def generate_dream_summary( logger.warning(f"[dream][梦境总结] dream_send 平台或用户ID为空,当前值: {dream_send_raw!r}") else: # 默认为私聊会话 - stream_id = get_chat_manager().get_stream_id( + stream_id = SessionUtils.calculate_session_id( platform=platform, - id=str(user_id), - is_group=False, + user_id=str(user_id), ) if not stream_id: logger.error( diff --git a/src/main.py b/src/main.py index 29b00682..c3634037 100644 --- a/src/main.py +++ b/src/main.py @@ -8,7 +8,7 @@ from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask # from src.chat.utils.token_statistics import TokenStatisticsTask from src.chat.emoji_system.emoji_manager import emoji_manager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager from src.config.config import config_manager, global_config from src.chat.message_receive.bot import chat_bot from src.common.logger import get_logger @@ -119,8 +119,8 @@ class MainSystem: logger.info("表情包管理器初始化成功") # 初始化聊天管理器 - await get_chat_manager()._initialize() - asyncio.create_task(get_chat_manager()._auto_save_task()) + await chat_manager.initialize() + asyncio.create_task(chat_manager.regularly_save_sessions()) logger.info("聊天管理器初始化成功") diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index d45d52f3..f8ff7adc 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -21,7 +21,7 @@ from src.plugin_system.apis import message_api from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.utils import is_bot_self from src.person_info.person_info import Person -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.prompt.prompt_manager import prompt_manager logger = get_logger("chat_history_summarizer") @@ -100,7 +100,7 @@ class ChatHistorySummarizer: def _get_chat_display_name(self) -> str: """获取聊天显示名称""" try: - chat_name = get_chat_manager().get_stream_name(self.chat_id) + chat_name = _chat_manager.get_session_name(self.chat_id) if chat_name: return chat_name # 如果获取失败,使用简化的chat_id显示 diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 8d85964a..e0d4fcda 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -1,3 +1,4 @@ +import contextlib import time import json import asyncio @@ -12,7 +13,7 @@ from src.common.database.database import get_db_session from src.common.database.database_model import ThinkingQuestion from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon logger = get_logger("memory_retrieval") @@ -133,10 +134,10 @@ async def _react_agent_solve_question( Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时) """ start_time = time.time() - collected_info = initial_info if initial_info else "" + collected_info = initial_info or "" # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流 try: - chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id + chat_name = _chat_manager.get_session_name(chat_id) or chat_id except Exception: chat_name = chat_id react_log_prefix = f"[{chat_name}] " @@ -235,7 +236,7 @@ async def _react_agent_solve_question( # head_prompt应该只构建一次,使用初始的collected_info,后续迭代都复用同一个 if first_head_prompt is None: # 第一次构建,使用初始的collected_info(即initial_info) - initial_collected_info = initial_info if initial_info else "" + initial_collected_info = initial_info or "" # 根据配置选择使用哪个 prompt prompt_name = ( "memory_retrieval_react_prompt_head_lpmm" @@ -362,7 +363,7 @@ async def _react_agent_solve_question( return information except (json.JSONDecodeError, ValueError, TypeError): # 如果JSON解析失败,尝试在文本中查找JSON对象 - try: + with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError): # 查找第一个 { 和最后一个 } 之间的内容(更健壮的JSON提取) first_brace = text.find("{") if first_brace != -1: @@ -384,8 +385,6 @@ async def _react_agent_solve_question( if isinstance(data, dict) and "return_information" in data: information = data.get("information", "") return information - except (json.JSONDecodeError, ValueError, TypeError): - pass return None @@ -679,7 +678,7 @@ async def _react_agent_solve_question( evaluation_prompt_template.add_context("bot_name", bot_name) evaluation_prompt_template.add_context("time_now", time_now) evaluation_prompt_template.add_context("chat_history", chat_history) - evaluation_prompt_template.add_context("collected_info", collected_info if collected_info else "暂无信息") + evaluation_prompt_template.add_context("collected_info", collected_info or "暂无信息") evaluation_prompt_template.add_context("current_iteration", str(current_iteration)) evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations)) evaluation_prompt_template.add_context("max_iterations", str(max_iterations)) @@ -800,8 +799,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) if not records: return "" - history_lines = [] - history_lines.append("最近已查询的问题和结果:") + history_lines = ["最近已查询的问题和结果:"] for record in records: status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" @@ -813,8 +811,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) if len(record.answer) > 100: answer_preview += "..." - history_lines.append(f"- 问题:{record.question}") - history_lines.append(f" 状态:{status}") + history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"]) if answer_preview: history_lines.append(f" 答案:{answer_preview}") history_lines.append("") # 空行分隔 @@ -855,12 +852,11 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) if not records: return [] - found_answers = [] - for record in records: - if record.answer: - found_answers.append(f"问题:{record.question}\n答案:{record.answer}") - - return found_answers + return [ + f"问题:{record.question}\n答案:{record.answer}" + for record in records + if record.answer + ] except Exception as e: logger.error(f"获取最近已找到答案的记录失败: {e}") @@ -892,8 +888,7 @@ def _store_thinking_back( .order_by(col(ThinkingQuestion.updated_timestamp).desc()) .limit(1) ) - record = session.exec(statement).first() - if record: + if record := session.exec(statement).first(): record.context = context record.found_answer = found_answer record.answer = answer @@ -957,10 +952,7 @@ async def _process_memory_retrieval( if is_timeout: logger.info("ReAct Agent超时,不返回结果") - if found_answer and answer: - return answer - - return None + return answer if found_answer and answer else None async def build_memory_retrieval_prompt( @@ -1013,8 +1005,7 @@ async def build_memory_retrieval_prompt( cleaned_concepts = [] for word in unknown_words: if isinstance(word, str): - cleaned = word.strip() - if cleaned: + if cleaned := word.strip(): cleaned_concepts.append(cleaned) if cleaned_concepts: # 对匹配到的概念进行jargon检索,作为初始信息 diff --git a/src/memory_system/retrieval_tools/query_chat_history.py b/src/memory_system/retrieval_tools/query_chat_history.py index 0a9f502f..4f6e114e 100644 --- a/src/memory_system/retrieval_tools/query_chat_history.py +++ b/src/memory_system/retrieval_tools/query_chat_history.py @@ -30,9 +30,8 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]: return chat_ids try: - from src.chat.message_receive.chat_stream import get_chat_manager + from src.common.utils.utils_session import SessionUtils - chat_manager = get_chat_manager() for blacklist_item in blacklist: if not isinstance(blacklist_item, str): continue @@ -51,7 +50,10 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]: is_group = stream_type == "group" # 转换为chat_id - chat_id = chat_manager.get_stream_id(platform, str(id_str), is_group=is_group) + if is_group: + chat_id = SessionUtils.calculate_session_id(platform, group_id=str(id_str)) + else: + chat_id = SessionUtils.calculate_session_id(platform, user_id=str(id_str)) if chat_id: chat_ids.add(chat_id) else: @@ -225,9 +227,9 @@ async def search_chat_history( if keyword: keyword_matched = False # 解析多个关键词(支持空格、逗号等分隔符) - keywords_list = parse_keywords_string(keyword) - if not keywords_list: - keywords_list = [keyword.strip()] if keyword.strip() else [] + keywords_list = parse_keywords_string(keyword) or ( + [keyword.strip()] if keyword.strip() else [] + ) # 转换为小写以便匹配 keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()] diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 4e245b67..d200b4ee 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -16,7 +16,7 @@ from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager logger = get_logger("person_info") @@ -818,22 +818,22 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: 聊天ID """ try: - # 从chat_id获取chat_stream - chat_stream = get_chat_manager().get_stream(chat_id) - if not chat_stream: - logger.warning(f"无法获取chat_stream for chat_id: {chat_id}") + # 从 chat_id 获取 session + session = _chat_manager.get_session_by_session_id(chat_id) + if not session: + logger.warning(f"无法获取session for chat_id: {chat_id}") return - platform = chat_stream.platform + platform = session.platform # 尝试从person_name查找person_id # 首先尝试通过person_name查找 person_id = get_person_id_by_person_name(person_name) if not person_id: - # 如果通过person_name找不到,尝试从chat_stream获取user_info - if platform and chat_stream.user_info and chat_stream.user_info.user_id: - user_id = chat_stream.user_info.user_id + # 如果通过person_name找不到,尝试从 session 获取 user_id + if platform and session.user_id: + user_id = session.user_id person_id = get_person_id(platform, user_id) else: logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py index 9e995d36..faab54b0 100644 --- a/src/plugin_system/apis/chat_api.py +++ b/src/plugin_system/apis/chat_api.py @@ -16,7 +16,7 @@ from typing import List, Dict, Any, Optional from enum import Enum from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager logger = get_logger("chat_api") @@ -31,7 +31,7 @@ class ChatManager: """聊天管理器 - 专门负责聊天信息的查询和管理""" @staticmethod - def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: # sourcery skip: for-append-to-extend """获取所有聊天流 @@ -39,7 +39,7 @@ class ChatManager: platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: - List[ChatStream]: 聊天流列表 + List[BotChatSession]: 聊天流列表 Raises: TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 @@ -48,7 +48,7 @@ class ChatManager: raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: - for _, stream in get_chat_manager().streams.items(): + for _, stream in _chat_manager.sessions.items(): if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform: streams.append(stream) logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流") @@ -57,7 +57,7 @@ class ChatManager: return streams @staticmethod - def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: # sourcery skip: for-append-to-extend """获取所有群聊聊天流 @@ -65,14 +65,14 @@ class ChatManager: platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: - List[ChatStream]: 群聊聊天流列表 + List[BotChatSession]: 群聊聊天流列表 """ if not isinstance(platform, (str, SpecialTypes)): raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: - for _, stream in get_chat_manager().streams.items(): - if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info: + for _, stream in _chat_manager.sessions.items(): + if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session: streams.append(stream) logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流") except Exception as e: @@ -80,7 +80,7 @@ class ChatManager: return streams @staticmethod - def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: + def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: # sourcery skip: for-append-to-extend """获取所有私聊聊天流 @@ -88,7 +88,7 @@ class ChatManager: platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: - List[ChatStream]: 私聊聊天流列表 + List[BotChatSession]: 私聊聊天流列表 Raises: TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 @@ -97,8 +97,10 @@ class ChatManager: raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: - for _, stream in get_chat_manager().streams.items(): - if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info: + for _, stream in _chat_manager.sessions.items(): + if ( + platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform + ) and not stream.is_group_session: streams.append(stream) logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流") except Exception as e: @@ -108,7 +110,7 @@ class ChatManager: @staticmethod def get_group_stream_by_group_id( group_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast + ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast """根据群ID获取聊天流 Args: @@ -116,7 +118,7 @@ class ChatManager: platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: - Optional[ChatStream]: 聊天流对象,如果未找到返回None + Optional[BotChatSession]: 聊天流对象,如果未找到返回None Raises: ValueError: 如果 group_id 为空字符串 @@ -129,11 +131,11 @@ class ChatManager: if not group_id: raise ValueError("group_id 不能为空") try: - for _, stream in get_chat_manager().streams.items(): + for _, stream in _chat_manager.sessions.items(): if ( - stream.group_info - and str(stream.group_info.group_id) == str(group_id) - and stream.platform == platform + stream.is_group_session + and str(stream.group_id) == str(group_id) + and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) ): logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流") return stream @@ -145,7 +147,7 @@ class ChatManager: @staticmethod def get_private_stream_by_user_id( user_id: str, platform: Optional[str] | SpecialTypes = "qq" - ) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast + ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast """根据用户ID获取私聊流 Args: @@ -153,7 +155,7 @@ class ChatManager: platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 Returns: - Optional[ChatStream]: 聊天流对象,如果未找到返回None + Optional[BotChatSession]: 聊天流对象,如果未找到返回None Raises: ValueError: 如果 user_id 为空字符串 @@ -166,11 +168,11 @@ class ChatManager: if not user_id: raise ValueError("user_id 不能为空") try: - for _, stream in get_chat_manager().streams.items(): + for _, stream in _chat_manager.sessions.items(): if ( - not stream.group_info - and str(stream.user_info.user_id) == str(user_id) - and stream.platform == platform + not stream.is_group_session + and str(stream.user_id) == str(user_id) + and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) ): logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流") return stream @@ -180,7 +182,7 @@ class ChatManager: return None @staticmethod - def get_stream_type(chat_stream: ChatStream) -> str: + def get_stream_type(chat_stream: BotChatSession) -> str: """获取聊天流类型 Args: @@ -190,20 +192,18 @@ class ChatManager: str: 聊天类型 ("group", "private", "unknown") Raises: - TypeError: 如果 chat_stream 不是 ChatStream 类型 + TypeError: 如果 chat_stream 不是 BotChatSession 类型 ValueError: 如果 chat_stream 为空 """ - if not isinstance(chat_stream, ChatStream): - raise TypeError("chat_stream 必须是 ChatStream 类型") + if not isinstance(chat_stream, BotChatSession): + raise TypeError("chat_stream 必须是 BotChatSession 类型") if not chat_stream: raise ValueError("chat_stream 不能为 None") - if hasattr(chat_stream, "group_info"): - return "group" if chat_stream.group_info else "private" - return "unknown" + return "group" if chat_stream.is_group_session else "private" @staticmethod - def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: + def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]: """获取聊天流详细信息 Args: @@ -213,36 +213,34 @@ class ChatManager: Dict ({str: Any}): 聊天流信息字典 Raises: - TypeError: 如果 chat_stream 不是 ChatStream 类型 + TypeError: 如果 chat_stream 不是 BotChatSession 类型 ValueError: 如果 chat_stream 为空 """ if not chat_stream: raise ValueError("chat_stream 不能为 None") - if not isinstance(chat_stream, ChatStream): - raise TypeError("chat_stream 必须是 ChatStream 类型") + if not isinstance(chat_stream, BotChatSession): + raise TypeError("chat_stream 必须是 BotChatSession 类型") try: info: Dict[str, Any] = { - "stream_id": chat_stream.stream_id, + "session_id": chat_stream.session_id, "platform": chat_stream.platform, "type": ChatManager.get_stream_type(chat_stream), } - if chat_stream.group_info: - info.update( - { - "group_id": chat_stream.group_info.group_id, - "group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"), - } - ) - - if chat_stream.user_info: - info.update( - { - "user_id": chat_stream.user_info.user_id, - "user_name": chat_stream.user_info.user_nickname, - } - ) + if chat_stream.is_group_session: + info["group_id"] = chat_stream.group_id + # Try to get group name from context + if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info: + info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊" + else: + info["group_name"] = "未知群聊" + else: + info["user_id"] = chat_stream.user_id + if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info: + info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname + else: + info["user_name"] = "未知用户" return info except Exception as e: @@ -285,37 +283,37 @@ class ChatManager: # ============================================================================= -def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: """获取所有聊天流的便捷函数""" return ChatManager.get_all_streams(platform) -def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: """获取群聊聊天流的便捷函数""" return ChatManager.get_group_streams(platform) -def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: +def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]: """获取私聊聊天流的便捷函数""" return ChatManager.get_private_streams(platform) -def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: +def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]: """根据群ID获取聊天流的便捷函数""" return ChatManager.get_group_stream_by_group_id(group_id, platform) -def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: +def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]: """根据用户ID获取私聊流的便捷函数""" return ChatManager.get_private_stream_by_user_id(user_id, platform) -def get_stream_type(chat_stream: ChatStream) -> str: +def get_stream_type(chat_stream: BotChatSession) -> str: """获取聊天流类型的便捷函数""" return ChatManager.get_stream_type(chat_stream) -def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: +def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]: """获取聊天流信息的便捷函数""" return ChatManager.get_stream_info(chat_stream) diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py index 2a99d25c..cbc6dc50 100644 --- a/src/plugin_system/apis/emoji_api.py +++ b/src/plugin_system/apis/emoji_api.py @@ -16,7 +16,7 @@ import uuid from typing import Optional, Tuple, List, Dict, Any from src.common.logger import get_logger from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR -from src.chat.utils.utils_image import image_path_to_base64, base64_to_image +from src.common.utils.utils_image import ImageUtils from src.config.config import global_config logger = get_logger("emoji_api") @@ -56,7 +56,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]] emoji_path = str(emoji_obj.full_path) emoji_description = emoji_obj.description matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "" - emoji_base64 = image_path_to_base64(emoji_path) + emoji_base64 = ImageUtils.image_path_to_base64(emoji_path) if not emoji_base64: logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}") @@ -115,7 +115,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: results = [] for selected_emoji in selected_emojis: - emoji_base64 = image_path_to_base64(str(selected_emoji.full_path)) + emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path)) if not emoji_base64: logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}") @@ -174,7 +174,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: # 随机选择匹配的表情包 selected_emoji = random.choice(matching_emojis) - emoji_base64 = image_path_to_base64(selected_emoji.full_path) + emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path) if not emoji_base64: logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}") @@ -263,7 +263,7 @@ async def get_all() -> List[Tuple[str, str, str]]: if emoji_obj.is_deleted: continue - emoji_base64 = image_path_to_base64(str(emoji_obj.full_path)) + emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path)) if not emoji_base64: logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}") @@ -429,7 +429,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D try: # 解码base64并保存图片 - if not base64_to_image(image_base64, temp_file_path): + if not ImageUtils.base64_to_image(image_base64, temp_file_path): logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}") return { "success": False, diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index af985b96..3217817c 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -16,7 +16,7 @@ from src.common.logger import get_logger from src.common.data_models.message_data_model import ReplySetModel from src.chat.replyer.group_generator import DefaultReplyer from src.chat.replyer.private_generator import PrivateReplyer -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.utils.utils import process_llm_response from src.chat.replyer.replyer_manager import replyer_manager from src.plugin_system.base.component_types import ActionInfo @@ -38,7 +38,7 @@ logger = get_logger("generator_api") def get_replyer( - chat_stream: Optional[ChatStream] = None, + chat_stream: Optional[BotChatSession] = None, chat_id: Optional[str] = None, request_type: str = "replyer", ) -> Optional[DefaultReplyer | PrivateReplyer]: @@ -79,7 +79,7 @@ def get_replyer( async def generate_reply( - chat_stream: Optional[ChatStream] = None, + chat_stream: Optional[BotChatSession] = None, chat_id: Optional[str] = None, action_data: Optional[Dict[str, Any]] = None, reply_message: Optional["DatabaseMessages"] = None, @@ -161,7 +161,7 @@ async def generate_reply( unknown_words=unknown_words, think_level=think_level, from_plugin=from_plugin, - stream_id=chat_stream.stream_id if chat_stream else chat_id, + stream_id=chat_stream.session_id if chat_stream else chat_id, reply_time_point=reply_time_point, log_reply=False, ) @@ -181,7 +181,7 @@ async def generate_reply( # 统一在这里记录最终回复日志(包含分割后的 processed_output) try: PlanReplyLogger.log_reply( - chat_id=chat_stream.stream_id if chat_stream else (chat_id or ""), + chat_id=chat_stream.session_id if chat_stream else (chat_id or ""), prompt=llm_response.prompt or "", output=llm_response.content, processed_output=llm_response.processed_output, @@ -210,7 +210,7 @@ async def generate_reply( async def rewrite_reply( - chat_stream: Optional[ChatStream] = None, + chat_stream: Optional[BotChatSession] = None, reply_data: Optional[Dict[str, Any]] = None, chat_id: Optional[str] = None, enable_splitter: bool = True, @@ -302,7 +302,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: async def generate_response_custom( - chat_stream: Optional[ChatStream] = None, + chat_stream: Optional[BotChatSession] = None, chat_id: Optional[str] = None, request_type: str = "generator_api", prompt: str = "", diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index fd2b723f..b18f8378 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -26,10 +26,12 @@ from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple from src.common.logger import get_logger from src.common.data_models.message_data_model import ReplyContentType from src.config.config import global_config -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.chat.message_receive.message import MessageSending, MessageRecv -from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo +from maim_message import Seg + +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.chat.message_receive.message import MessageSending if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages @@ -77,7 +79,7 @@ async def _send_to_target( logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}") # 查找目标聊天流 - target_stream = get_chat_manager().get_stream(stream_id) + target_stream = _chat_manager.get_session_by_session_id(stream_id) if not target_stream: logger.error(f"[SendAPI] 未找到聊天流: {stream_id}") return False @@ -93,27 +95,29 @@ async def _send_to_target( bot_user_info = UserInfo( user_id=global_config.bot.qq_account, user_nickname=global_config.bot.nickname, - platform=target_stream.platform, ) reply_to_platform_id = "" - anchor_message: Union["MessageRecv", None] = None + anchor_message: Optional[MaiMessage] = None if reply_message: - anchor_message = db_message_to_message_recv(reply_message) - logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") # type: ignore + anchor_message = db_message_to_mai_message(reply_message) if anchor_message: - anchor_message.update_chat_stream(target_stream) - assert anchor_message.message_info.user_info, "用户信息缺失" + logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") reply_to_platform_id = ( - f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" + f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}" ) + # 构建 sender_info(私聊时为接收者信息) + sender_info = None + if target_stream.context and target_stream.context.message: + sender_info = target_stream.context.message.message_info.user_info + # 构建发送消息对象 bot_message = MessageSending( message_id=message_id, - chat_stream=target_stream, + session=target_stream, bot_user_info=bot_user_info, - sender_info=target_stream.user_info, + sender_info=sender_info, message_segment=message_segment, display_message=display_message, reply=anchor_message, @@ -146,51 +150,43 @@ async def _send_to_target( return False -def db_message_to_message_recv(message_obj: "DatabaseMessages") -> MessageRecv: - """将数据库dict重建为MessageRecv对象 +def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]: + """将数据库消息重建为 MaiMessage 对象,用于回复引用。 + Args: - message_dict: 消息字典 + message_obj: 插件系统的 DatabaseMessages 数据对象 Returns: - Optional[MessageRecv]: 找到的消息,如果没找到则返回None + Optional[MaiMessage]: 构建的消息对象,如果信息不足则返回 None """ - # 构建MessageRecv对象 - user_info = { - "platform": message_obj.user_info.platform or "", - "user_id": message_obj.user_info.user_id or "", - "user_nickname": message_obj.user_info.user_nickname or "", - "user_cardname": message_obj.user_info.user_cardname or "", - } + from datetime import datetime + from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo + from src.common.data_models.message_component_data_model import MessageSequence - group_info = {} + user_info = UserInfo( + user_id=message_obj.user_info.user_id or "", + user_nickname=message_obj.user_info.user_nickname or "", + user_cardname=message_obj.user_info.user_cardname, + ) + + group_info = None if message_obj.chat_info.group_info: - group_info = { - "platform": message_obj.chat_info.group_info.group_platform or "", - "group_id": message_obj.chat_info.group_info.group_id or "", - "group_name": message_obj.chat_info.group_info.group_name or "", - } + group_info = GroupInfo( + group_id=message_obj.chat_info.group_info.group_id or "", + group_name=message_obj.chat_info.group_info.group_name or "", + ) - format_info = {"content_format": "", "accept_format": ""} - template_info = {"template_items": {}} - - message_info = { - "platform": message_obj.chat_info.platform or "", - "message_id": message_obj.message_id, - "time": message_obj.time, - "group_info": group_info, - "user_info": user_info, - "additional_config": message_obj.additional_config, - "format_info": format_info, - "template_info": template_info, - } - - message_dict_recv = { - "message_info": message_info, - "raw_message": message_obj.processed_plain_text, - "processed_plain_text": message_obj.processed_plain_text, - } - - return MessageRecv(message_dict_recv) + msg = MaiMessage( + message_id=message_obj.message_id, + timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(), + ) + msg.message_info = MessageInfo(user_info=user_info, group_info=group_info) + msg.platform = message_obj.chat_info.platform or "" + msg.session_id = message_obj.chat_info.stream_id or "" + msg.processed_plain_text = message_obj.processed_plain_text + msg.raw_message = MessageSequence(components=[]) + msg.initialized = True + return msg # ============================================================================= diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index bc0b32f0..00464ea7 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -5,12 +5,12 @@ from src.plugin_system.base.component_types import ComponentType from src.common.logger import get_logger if TYPE_CHECKING: - from src.chat.message_receive.chat_stream import ChatStream + from src.chat.message_receive.chat_manager import BotChatSession logger = get_logger("tool_api") -def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]: +def get_tool_instance(tool_name: str, chat_stream: Optional["BotChatSession"] = None) -> Optional[BaseTool]: """获取公开工具实例 Args: diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 1c962798..1f2af8a3 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -6,7 +6,7 @@ from typing import Tuple, Optional, TYPE_CHECKING, Dict, List from src.common.logger import get_logger from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_manager import BotChatSession from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType from src.plugin_system.apis import send_api, database_api, message_api @@ -36,7 +36,7 @@ class BaseAction(ABC): action_reasoning: str, cycle_timers: dict, thinking_id: str, - chat_stream: ChatStream, + chat_stream: BotChatSession, plugin_config: Optional[dict] = None, action_message: Optional["DatabaseMessages"] = None, **kwargs, @@ -92,7 +92,7 @@ class BaseAction(ABC): # 获取聊天流对象 self.chat_stream = chat_stream or kwargs.get("chat_stream") - self.chat_id = self.chat_stream.stream_id + self.chat_id = self.chat_stream.session_id self.platform = getattr(self.chat_stream, "platform", None) # 初始化基础信息(带类型注解) diff --git a/src/plugin_system/base/base_command.py b/src/plugin_system/base/base_command.py index 0cfdb5d0..ffa1e46b 100644 --- a/src/plugin_system/base/base_command.py +++ b/src/plugin_system/base/base_command.py @@ -3,7 +3,7 @@ from typing import Dict, Tuple, Optional, TYPE_CHECKING, List from src.common.logger import get_logger from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode from src.plugin_system.base.component_types import CommandInfo, ComponentType -from src.chat.message_receive.message import MessageRecv +from src.chat.message_receive.message import SessionMessage from src.plugin_system.apis import send_api if TYPE_CHECKING: @@ -31,7 +31,7 @@ class BaseCommand(ABC): command_pattern: str = r"" """命令匹配的正则表达式""" - def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): + def __init__(self, message: SessionMessage, plugin_config: Optional[dict] = None): """初始化Command组件 Args: @@ -107,14 +107,14 @@ class BaseCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False return await send_api.text_to_stream( text=content, - stream_id=chat_stream.stream_id, + stream_id=session_id, set_reply=set_reply, reply_message=reply_message, storage_message=storage_message, @@ -135,14 +135,14 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False return await send_api.image_to_stream( image_base64, - chat_stream.stream_id, + session_id, set_reply=set_reply, reply_message=reply_message, storage_message=storage_message, @@ -166,13 +166,13 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False return await send_api.emoji_to_stream( - emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message + emoji_base64, session_id, set_reply=set_reply, reply_message=reply_message ) async def send_command( @@ -195,9 +195,9 @@ class BaseCommand(ABC): """ try: # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False # 构造命令数据 @@ -205,7 +205,7 @@ class BaseCommand(ABC): success = await send_api.command_to_stream( command=command_data, - stream_id=chat_stream.stream_id, + stream_id=session_id, storage_message=storage_message, display_message=display_message, ) @@ -229,15 +229,15 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False return await send_api.custom_to_stream( message_type="voice", content=voice_base64, - stream_id=chat_stream.stream_id, + stream_id=session_id, typing=False, set_reply=False, reply_message=None, @@ -262,15 +262,15 @@ class BaseCommand(ABC): reply_message: 回复的消息对象 storage_message: 是否存储消息到数据库 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False reply_set = ReplySetModel() reply_set.add_hybrid_content_by_raw(message_tuple_list) return await send_api.custom_reply_set_to_stream( reply_set=reply_set, - stream_id=chat_stream.stream_id, + stream_id=session_id, typing=typing, set_reply=set_reply, reply_message=reply_message, @@ -293,9 +293,9 @@ class BaseCommand(ABC): Returns: bool: 是否发送成功 """ - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False reply_set = ReplySetModel() forward_message_nodes: List[ForwardNode] = [] @@ -318,7 +318,7 @@ class BaseCommand(ABC): reply_set.add_forward_content(forward_message_nodes) return await send_api.custom_reply_set_to_stream( reply_set=reply_set, - stream_id=chat_stream.stream_id, + stream_id=session_id, storage_message=storage_message, set_reply=False, reply_message=None, @@ -349,15 +349,15 @@ class BaseCommand(ABC): bool: 是否发送成功 """ # 获取聊天流信息 - chat_stream = self.message.chat_stream - if not chat_stream or not hasattr(chat_stream, "stream_id"): - logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") + session_id = self.message.session_id + if not session_id: + logger.error(f"{self.log_prefix} 缺少session_id") return False return await send_api.custom_to_stream( message_type=message_type, content=content, - stream_id=chat_stream.stream_id, + stream_id=session_id, display_message=display_message, typing=typing, set_reply=set_reply, diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py index 71d55101..3938027e 100644 --- a/src/plugin_system/base/base_tool.py +++ b/src/plugin_system/base/base_tool.py @@ -6,7 +6,7 @@ from src.common.logger import get_logger from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType if TYPE_CHECKING: - from src.chat.message_receive.chat_stream import ChatStream + from src.chat.message_receive.chat_manager import BotChatSession install(extra_lines=3) @@ -32,7 +32,7 @@ class BaseTool(ABC): available_for_llm: bool = False """是否可供LLM使用""" - def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None): + def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["BotChatSession"] = None): """初始化工具基类 Args: @@ -47,7 +47,7 @@ class BaseTool(ABC): # 获取聊天流对象 self.chat_stream = chat_stream - self.chat_id = self.chat_stream.stream_id if self.chat_stream else None + self.chat_id = self.chat_stream.session_id if self.chat_stream else None self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None @classmethod diff --git a/src/plugin_system/core/events_manager.py b/src/plugin_system/core/events_manager.py index be435848..fa0a0637 100644 --- a/src/plugin_system/core/events_manager.py +++ b/src/plugin_system/core/events_manager.py @@ -2,8 +2,8 @@ import asyncio import contextlib from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING -from src.chat.message_receive.message import MessageRecv, MessageSending -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.message import MessageSending, SessionMessage +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.common.logger import get_logger from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult from src.plugin_system.base.base_events_handler import BaseEventHandler @@ -72,7 +72,7 @@ class EventsManager: async def handle_mai_events( self, event_type: EventType | str, - message: Optional[MessageRecv | MessageSending | MaiMessages] = None, + message: Optional[SessionMessage | MessageSending | MaiMessages] = None, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None, stream_id: Optional[str] = None, @@ -87,7 +87,7 @@ class EventsManager: # 1. 准备消息 transformed_message = self._prepare_message( - event_type, message, llm_prompt, llm_response, stream_id, action_usage + event_type, message, llm_prompt, llm_response, stream_id, action_usage # type: ignore[arg-type] ) if transformed_message: transformed_message = transformed_message.deepcopy() @@ -134,7 +134,7 @@ class EventsManager: async def handle_workflow_message( self, - message: Optional[MessageRecv | MessageSending | MaiMessages] = None, + message: Optional[SessionMessage | MessageSending | MaiMessages] = None, stream_id: Optional[str] = None, action_usage: Optional[List[str]] = None, context: Optional[WorkflowContext] = None, @@ -248,11 +248,13 @@ class EventsManager: def _transform_event_message( self, - message: MessageRecv | MessageSending, + message: SessionMessage | MessageSending, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None, ) -> MaiMessages: """转换事件消息格式""" + from maim_message import Seg + # 直接赋值部分内容 transformed_message = MaiMessages( llm_prompt=llm_prompt, @@ -260,45 +262,62 @@ class EventsManager: llm_response_reasoning=llm_response.reasoning if llm_response else None, llm_response_model=llm_response.model if llm_response else None, llm_response_tool_call=llm_response.tool_calls if llm_response else None, - raw_message=message.raw_message, - additional_data=message.message_info.additional_config or {}, + raw_message=message.processed_plain_text or "", + additional_data={}, ) # 消息段处理 - if message.message_segment.type == "seglist": - transformed_message.message_segments = list(message.message_segment.data) # type: ignore + if isinstance(message, MessageSending): + if message.message_segment.type == "seglist": + transformed_message.message_segments = list(message.message_segment.data) # type: ignore + else: + transformed_message.message_segments = [message.message_segment] else: - transformed_message.message_segments = [message.message_segment] + # SessionMessage: 使用 processed_plain_text 构造简单段 + transformed_message.message_segments = [Seg(type="text", data=message.processed_plain_text or "")] # stream_id 处理 - if hasattr(message, "chat_stream") and message.chat_stream: - transformed_message.stream_id = message.chat_stream.stream_id + transformed_message.stream_id = message.session_id if hasattr(message, "session_id") else "" # 处理后文本 transformed_message.plain_text = message.processed_plain_text # 基本信息 - if hasattr(message, "message_info") and message.message_info: - if message.message_info.platform: - transformed_message.message_base_info["platform"] = message.message_info.platform + if isinstance(message, MessageSending): + transformed_message.message_base_info["platform"] = message.platform + if message.session.group_id: + transformed_message.is_group_message = True + group_name = "" + if message.session.context and message.session.context.message and message.session.context.message.message_info.group_info: + group_name = message.session.context.message.message_info.group_info.group_name + transformed_message.message_base_info.update({ + "group_id": message.session.group_id, + "group_name": group_name, + }) + transformed_message.message_base_info.update({ + "user_id": message.bot_user_info.user_id, + "user_cardname": message.bot_user_info.user_cardname, + "user_nickname": message.bot_user_info.user_nickname, + }) + if not transformed_message.is_group_message: + transformed_message.is_private_message = True + elif hasattr(message, "message_info") and message.message_info: + if message.platform: + transformed_message.message_base_info["platform"] = message.platform if message.message_info.group_info: transformed_message.is_group_message = True - transformed_message.message_base_info.update( - { - "group_id": message.message_info.group_info.group_id, - "group_name": message.message_info.group_info.group_name, - } - ) + transformed_message.message_base_info.update({ + "group_id": message.message_info.group_info.group_id, + "group_name": message.message_info.group_info.group_name, + }) if message.message_info.user_info: if not transformed_message.is_group_message: transformed_message.is_private_message = True - transformed_message.message_base_info.update( - { - "user_id": message.message_info.user_info.user_id, - "user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称 - "user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名) - } - ) + transformed_message.message_base_info.update({ + "user_id": message.message_info.user_info.user_id, + "user_cardname": message.message_info.user_info.user_cardname, + "user_nickname": message.message_info.user_info.user_nickname, + }) return transformed_message @@ -306,9 +325,9 @@ class EventsManager: self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None ) -> MaiMessages: """从流ID构建消息""" - chat_stream = get_chat_manager().get_stream(stream_id) - assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流" - message = chat_stream.context.get_last_message() + session = _chat_manager.get_session_by_session_id(stream_id) + assert session, f"未找到流ID为 {stream_id} 的会话" + message = session.context.message return self._transform_event_message(message, llm_prompt, llm_response) def _transform_event_without_message( @@ -319,8 +338,8 @@ class EventsManager: action_usage: Optional[List[str]] = None, ) -> MaiMessages: """没有message对象时进行转换""" - chat_stream = get_chat_manager().get_stream(stream_id) - assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流" + session = _chat_manager.get_session_by_session_id(stream_id) + assert session, f"未找到流ID为 {stream_id} 的会话" return MaiMessages( stream_id=stream_id, llm_prompt=llm_prompt, @@ -328,8 +347,8 @@ class EventsManager: llm_response_reasoning=(llm_response.reasoning if llm_response else None), llm_response_model=(llm_response.model if llm_response else None), llm_response_tool_call=(llm_response.tool_calls if llm_response else None), - is_group_message=(not (not chat_stream.group_info)), - is_private_message=(not chat_stream.group_info), + is_group_message=session.is_group_session, + is_private_message=not session.is_group_session, action_usage=action_usage, additional_data={"response_is_processed": True}, ) @@ -373,7 +392,7 @@ class EventsManager: def _prepare_message( self, event_type: EventType | str, - message: Optional[MessageRecv | MessageSending | MaiMessages] = None, + message: Optional[SessionMessage | MessageSending | MaiMessages] = None, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None, stream_id: Optional[str] = None, diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 4fa083e4..0cdbb472 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest from src.llm_models.payload_content import ToolCall from src.config.config import global_config, model_config from src.prompt.prompt_manager import prompt_manager -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.common.logger import get_logger logger = get_logger("tool_use") @@ -28,8 +28,8 @@ class ToolExecutor: cache_ttl: 缓存生存时间(周期数) """ self.chat_id = chat_id - self.chat_stream = get_chat_manager().get_stream(self.chat_id) - self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" + self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id) + self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]" self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") diff --git a/src/webui/routers/expression.py b/src/webui/routers/expression.py index 622ec488..3e1fc187 100644 --- a/src/webui/routers/expression.py +++ b/src/webui/routers/expression.py @@ -11,7 +11,7 @@ from sqlmodel import col, select, delete from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import Expression -from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.webui.core import verify_auth_token_from_cookie_or_header logger = get_logger("webui.expression") @@ -118,14 +118,11 @@ def expression_to_response(expression: Expression) -> ExpressionResponse: def get_chat_name(chat_id: str) -> str: """根据 chat_id 获取聊天名称""" try: - chat_stream = get_chat_manager().get_stream(chat_id) - if not chat_stream: + session = _chat_manager.get_session_by_session_id(chat_id) + if not session: return chat_id - if chat_stream.group_info and chat_stream.group_info.group_name: - return chat_stream.group_info.group_name - if chat_stream.user_info and chat_stream.user_info.user_nickname: - return chat_stream.user_info.user_nickname - return chat_id + name = _chat_manager.get_session_name(chat_id) + return name or chat_id except Exception: return chat_id @@ -134,15 +131,9 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]: """批量获取聊天名称""" result = {cid: cid for cid in chat_ids} # 默认值为原始ID try: - chat_manager = get_chat_manager() for chat_id in chat_ids: - chat_stream = chat_manager.get_stream(chat_id) - if not chat_stream: - continue - if chat_stream.group_info and chat_stream.group_info.group_name: - result[chat_id] = chat_stream.group_info.group_name - elif chat_stream.user_info and chat_stream.user_info.user_nickname: - result[chat_id] = chat_stream.user_info.user_nickname + if name := _chat_manager.get_session_name(chat_id): + result[chat_id] = name except Exception as e: logger.warning(f"批量获取聊天名称失败: {e}") return result @@ -179,17 +170,14 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat verify_auth_token(maibot_session, authorization) chat_list = [] - for stream_id, stream in get_chat_manager().streams.items(): - chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None - if not chat_name and stream.user_info and stream.user_info.user_nickname: - chat_name = stream.user_info.user_nickname - chat_name = chat_name or stream_id + for session_id, session in _chat_manager.sessions.items(): + chat_name = _chat_manager.get_session_name(session_id) or session_id chat_list.append( ChatInfo( - chat_id=stream_id, + chat_id=session_id, chat_name=chat_name, - platform=stream.platform, - is_group=bool(stream.group_info and stream.group_info.group_id), + platform=session.platform, + is_group=session.is_group_session, ) ) @@ -495,11 +483,10 @@ async def batch_delete_expressions( # 查找所有要删除的表达方式 with get_db_session() as session: statements = select(Expression.id).where(col(Expression.id).in_(request.ids)) - found_ids = [expr_id for expr_id in session.exec(statements).all()] + found_ids = list(session.exec(statements).all()) # 检查是否有未找到的ID - not_found_ids = set(request.ids) - set(found_ids) - if not_found_ids: + if not_found_ids := set(request.ids) - set(found_ids): logger.warning(f"部分表达方式未找到: {not_found_ids}") # 执行批量删除 @@ -800,7 +787,7 @@ async def batch_review_expressions( session.add(db_expression) results.append( - BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝") + BatchReviewResultItem(id=item.id, success=True, message="拒绝" if item.rejected else "通过") ) succeeded += 1