From 898fab6de9c73e770d515a46dfe96ee9c7b859a2 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Fri, 13 Mar 2026 23:36:17 +0800 Subject: [PATCH] =?UTF-8?q?=E9=83=A8=E5=88=86=E6=A8=A1=E5=9D=97=E7=9A=84?= =?UTF-8?q?=E6=96=B0=E6=95=B0=E6=8D=AE=E7=BB=93=E6=9E=84=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/brain_chat/brain_chat.py | 11 +- src/chat/replyer/group_generator.py | 105 +++--- src/chat/replyer/private_generator.py | 103 ++---- src/common/message_repository.py | 79 ++--- src/services/generator_service.py | 28 +- src/services/message_service.py | 461 +++++++++++++++++++++----- src/services/send_service.py | 192 ++++------- 7 files changed, 580 insertions(+), 399 deletions(-) diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 9c6e5847..b4a43b23 100644 --- a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -8,7 +8,7 @@ from rich.traceback import install from src.config.config import global_config from src.common.logger import get_logger from src.common.data_models.info_data_model import ActionPlannerInfo -from src.common.data_models.message_data_model import ReplyContentType +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.timer_calculator import Timer @@ -35,7 +35,6 @@ from src.chat.utils.chat_message_builder import ( if TYPE_CHECKING: from src.common.data_models.database_data_model import DatabaseMessages - from src.common.data_models.message_data_model import ReplySetModel ERROR_LOOP_INFO = { @@ -513,7 +512,7 @@ class BrainChatting: async def _send_response( self, - reply_set: "ReplySetModel", + reply_set: MessageSequence, message_data: "DatabaseMessages", selected_expressions: Optional[List[int]] = None, ) -> str: @@ -528,10 +527,10 @@ class BrainChatting: reply_text = "" first_replied = False - for reply_content in reply_set.reply_data: - if reply_content.content_type != ReplyContentType.TEXT: + for component in reply_set.components: + if not isinstance(component, TextComponent): continue - data: str = reply_content.content # type: ignore + data = component.text if not first_replied: await send_api.text_to_stream( text=data, diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 441e8bb8..4bfc78fc 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -7,29 +7,31 @@ import re from typing import List, Optional, Dict, Any, Tuple from datetime import datetime from src.common.logger import get_logger -from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from src.chat.message_receive.message_old import UserInfo, Seg, MessageRecv, MessageSending -from src.chat.message_receive.chat_stream import ChatStream +from maim_message import BaseMessageInfo, MessageBase, Seg + +from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo +from src.chat.message_receive.message import SessionMessage +from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager -from src.chat.utils.chat_message_builder import ( +from src.services.message_service import ( build_readable_messages, get_raw_msg_before_timestamp_with_chat, replace_user_references, + translate_pid_to_description, ) from src.bw_learner.expression_selector import expression_selector -from src.plugin_system.apis.message_api import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person -from src.plugin_system.base.component_types import ActionInfo, EventType -from src.plugin_system.apis import llm_api +from src.core.types import ActionInfo, EventType +from src.services import llm_service as llm_api from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt @@ -45,17 +47,17 @@ logger = get_logger("replyer") class DefaultReplyer: def __init__( self, - chat_stream: ChatStream, + chat_stream: BotChatSession, request_type: str = "replyer", ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.chat_stream = chat_stream - self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) + self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id) self.heart_fc_sender = UniversalMessageSender() - from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 + from src.chat.tool_executor import ToolExecutor - self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) + self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3) async def generate_reply_with_context( self, @@ -66,7 +68,7 @@ class DefaultReplyer: enable_tool: bool = True, from_plugin: bool = True, stream_id: Optional[str] = None, - reply_message: Optional[DatabaseMessages] = None, + reply_message: Optional[SessionMessage] = None, reply_time_point: float = time.time(), think_level: int = 1, unknown_words: Optional[List[str]] = None, @@ -132,7 +134,7 @@ class DefaultReplyer: if log_reply: try: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.stream_id, + chat_id=self.chat_stream.session_id, prompt="", output=None, processed_output=None, @@ -146,12 +148,12 @@ class DefaultReplyer: except Exception: logger.exception("记录reply日志失败") return False, llm_response - from src.plugin_system.core.events_manager import events_manager + from src.core.event_bus import event_bus + from src.chat.event_helpers import build_event_message if not from_plugin: - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.POST_LLM, None, prompt, None, stream_id=stream_id - ) + _event_msg = build_event_message(EventType.POST_LLM, llm_prompt=prompt, stream_id=stream_id) + continue_flag, modified_message = await event_bus.emit(EventType.POST_LLM, _event_msg) if not continue_flag: raise UserWarning("插件于请求前中断了内容生成") if modified_message and modified_message._modify_flags.modify_llm_prompt: @@ -202,7 +204,7 @@ class DefaultReplyer: try: if log_reply: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.stream_id, + chat_id=self.chat_stream.session_id, prompt=prompt, output=content, processed_output=None, @@ -214,9 +216,10 @@ class DefaultReplyer: ) except Exception: logger.exception("记录reply日志失败") - continue_flag, modified_message = await events_manager.handle_mai_events( - EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id + _event_msg = build_event_message( + EventType.AFTER_LLM, llm_prompt=prompt, llm_response=llm_response, stream_id=stream_id ) + continue_flag, modified_message = await event_bus.emit(EventType.AFTER_LLM, _event_msg) if not from_plugin and not continue_flag: raise UserWarning("插件于请求后取消了内容生成") if modified_message: @@ -259,7 +262,7 @@ class DefaultReplyer: if log_reply: try: PlanReplyLogger.log_reply( - chat_id=self.chat_stream.stream_id, + chat_id=self.chat_stream.session_id, prompt=prompt or "", output=None, processed_output=None, @@ -353,14 +356,14 @@ class DefaultReplyer: str: 表达习惯信息字符串 """ # 检查是否允许在此聊天流中使用表达 - use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) + use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id) if not use_expression: return "", [] style_habits = [] # 使用从处理器传来的选中表达方式 # 使用模型预测选择表达方式 selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( - self.chat_stream.stream_id, + self.chat_stream.session_id, chat_history, max_num=8, target_message=target, @@ -594,7 +597,7 @@ class DefaultReplyer: async def _build_jargon_explanation( self, chat_id: str, - messages_short: List[DatabaseMessages], + messages_short: List[SessionMessage], chat_talking_prompt_short: str, unknown_words: Optional[List[str]], ) -> str: @@ -703,9 +706,13 @@ class DefaultReplyer: is_group = stream_type == "group" # 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑 - from src.chat.message_receive.chat_stream import get_chat_manager + from src.common.utils.utils_session import SessionUtils - chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) + chat_id = SessionUtils.calculate_session_id( + platform, + group_id=str(id_str) if is_group else None, + user_id=str(id_str) if not is_group else None, + ) return chat_id, prompt_content except (ValueError, IndexError): @@ -751,7 +758,7 @@ class DefaultReplyer: async def build_prompt_reply_context( self, - reply_message: Optional[DatabaseMessages] = None, + reply_message: Optional[SessionMessage] = None, extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, @@ -778,7 +785,7 @@ class DefaultReplyer: if available_actions is None: available_actions = {} chat_stream = self.chat_stream - chat_id = chat_stream.stream_id + chat_id = chat_stream.session_id _is_group_chat = bool(chat_stream.group_info) platform = chat_stream.platform @@ -1005,7 +1012,7 @@ class DefaultReplyer: reply_to: str, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream - chat_id = chat_stream.stream_id + chat_id = chat_stream.session_id sender, target = self._parse_reply_target(reply_to) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) @@ -1105,31 +1112,29 @@ class DefaultReplyer: is_emoji: bool, thinking_start_time: float, display_message: str, - anchor_message: Optional[MessageRecv] = None, - ) -> MessageSending: + anchor_message: Optional[MaiMessage] = None, + ) -> SessionMessage: """构建单个发送消息""" - bot_user_info = UserInfo( - user_id=str(global_config.bot.qq_account), - user_nickname=global_config.bot.nickname, - platform=self.chat_stream.platform, - ) - - # await anchor_message.process() - sender_info = anchor_message.message_info.user_info if anchor_message else None - - return MessageSending( - message_id=message_id, # 使用片段的唯一ID - chat_stream=self.chat_stream, - bot_user_info=bot_user_info, - sender_info=sender_info, + maim_message = MessageBase( + message_info=BaseMessageInfo( + platform=self.chat_stream.platform, + message_id=message_id, + time=thinking_start_time, + user_info=UserInfo( + user_id=str(global_config.bot.qq_account), + user_nickname=global_config.bot.nickname, + ), + additional_config={}, + ), message_segment=message_segment, - reply=anchor_message, # 回复原始锚点 - is_head=reply_to, - is_emoji=is_emoji, - thinking_start_time=thinking_start_time, # 传递原始思考开始时间 - display_message=display_message, ) + message = SessionMessage.from_maim_message(maim_message) + message.session_id = self.chat_stream.session_id + message.display_message = display_message + message.reply_to = anchor_message.message_id if reply_to and anchor_message else None + message.is_emoji = is_emoji + return message async def llm_generate_content(self, prompt: str): with Timer("LLM生成", {}): # 内部计时器,可选保留 diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index 39e66e91..13d73018 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -7,33 +7,31 @@ import re from typing import List, Optional, Dict, Any, Tuple from datetime import datetime from src.common.logger import get_logger -from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from maim_message import Seg +from maim_message import BaseMessageInfo, MessageBase, Seg from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo -from src.chat.message_receive.message_old import MessageSending +from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.chat_manager import BotChatSession from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.prompt.prompt_manager import prompt_manager from src.chat.utils.common_utils import TempMethodsExpression -from src.chat.utils.chat_message_builder import ( +from src.services.message_service import ( build_readable_messages, get_raw_msg_before_timestamp_with_chat, replace_user_references, + translate_pid_to_description, ) from src.bw_learner.expression_selector import expression_selector -from src.services.message_service import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator from src.person_info.person_info import Person, is_person_known from src.core.types import ActionInfo, EventType -from src.services import llm_service as llm_api from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt from src.bw_learner.jargon_explainer_old import explain_jargon_in_context @@ -69,7 +67,7 @@ class PrivateReplyer: from_plugin: bool = True, think_level: int = 1, stream_id: Optional[str] = None, - reply_message: Optional[DatabaseMessages] = None, + reply_message: Optional[SessionMessage] = None, reply_time_point: Optional[float] = time.time(), unknown_words: Optional[List[str]] = None, log_reply: bool = True, @@ -604,7 +602,7 @@ class PrivateReplyer: async def build_prompt_reply_context( self, - reply_message: Optional[DatabaseMessages] = None, + reply_message: Optional[SessionMessage] = None, extra_info: str = "", reply_reason: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, @@ -954,28 +952,29 @@ class PrivateReplyer: thinking_start_time: float, display_message: str, anchor_message: Optional[MaiMessage] = None, - ) -> MessageSending: + ) -> SessionMessage: """构建单个发送消息""" - bot_user_info = UserInfo( - user_id=str(global_config.bot.qq_account), - user_nickname=global_config.bot.nickname, - ) - - sender_info = anchor_message.message_info.user_info if anchor_message else None - - return MessageSending( - message_id=message_id, - session=self.chat_stream, - bot_user_info=bot_user_info, - sender_info=sender_info, + maim_message = MessageBase( + message_info=BaseMessageInfo( + platform=self.chat_stream.platform, + message_id=message_id, + time=thinking_start_time, + user_info=UserInfo( + user_id=str(global_config.bot.qq_account), + user_nickname=global_config.bot.nickname, + ), + group_info=None, + additional_config={}, + ), message_segment=message_segment, - reply=anchor_message, - is_head=reply_to, - is_emoji=is_emoji, - thinking_start_time=thinking_start_time, - display_message=display_message, ) + message = SessionMessage.from_maim_message(maim_message) + message.session_id = self.chat_stream.session_id + message.display_message = display_message + message.reply_to = anchor_message.message_id if reply_to and anchor_message else None + message.is_emoji = is_emoji + return message async def llm_generate_content(self, prompt: str): with Timer("LLM生成", {}): # 内部计时器,可选保留 @@ -999,55 +998,9 @@ class PrivateReplyer: return content, reasoning_content, model_name, tool_calls async def get_prompt_info(self, message: str, sender: str, target: str): - related_info = "" - start_time = time.time() - from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool - - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - # 从LPMM知识库获取知识 - try: - # 检查LPMM知识库是否启用 - if not global_config.lpmm_knowledge.enable: - logger.debug("LPMM知识库未启用,跳过获取知识库内容") - return "" - - if global_config.lpmm_knowledge.lpmm_mode == "agent": - return "" - - prompt_template = prompt_manager.get_prompt("lpmm_get_knowledge") - prompt_template.add_context("bot_name", global_config.bot.nickname) - prompt_template.add_context("time_now", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) - prompt_template.add_context("chat_history", message) - prompt_template.add_context("sender", sender) - prompt_template.add_context("target_message", target) - prompt = await prompt_manager.render_prompt(prompt_template) - - _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools( - prompt, - model_config=model_config.model_task_config.tool_use, - tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()], - ) - if tool_calls: - result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool()) - end_time = time.time() - if not result or not result.get("content"): - logger.debug("从LPMM知识库获取知识失败,返回空知识...") - return "" - found_knowledge_from_lpmm = result.get("content", "") - logger.debug( - f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" - ) - related_info += found_knowledge_from_lpmm - logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") - logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - - return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" - else: - logger.debug("模型认为不需要使用LPMM知识库") - return "" - except Exception as e: - logger.error(f"获取知识库内容时发生异常: {str(e)}") - return "" + logger.debug(f"已跳过知识库信息获取,元消息:{message[:30]}...,消息长度: {len(message)}") + del message, sender, target + return "" def weighted_sample_no_replacement(items, weights, k) -> list: diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 8ade577a..7215ffa3 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,5 +1,6 @@ import traceback from datetime import datetime +from types import SimpleNamespace from typing import Any import json @@ -9,7 +10,7 @@ from sqlmodel import col, select from src.common.database.database import get_db_session from src.common.database.database_model import Messages -from src.common.data_models.database_data_model import DatabaseMessages +from src.chat.message_receive.message import SessionMessage from src.common.logger import get_logger from src.config.config import global_config @@ -58,53 +59,37 @@ def _normalize_optional_str(value: object) -> str | None: return str(value) -def _message_to_instance(message: Messages) -> DatabaseMessages: +def _message_to_instance(message: Messages) -> SessionMessage: config = _parse_additional_config(message) - timestamp_value = message.timestamp - if isinstance(timestamp_value, datetime): - time_value = timestamp_value.timestamp() - else: - time_value = float(timestamp_value) - selected_expressions = _normalize_optional_str(config.get("selected_expressions")) - priority_info = _normalize_optional_str(config.get("priority_info")) - return DatabaseMessages( - message_id=message.message_id, - time=time_value, - chat_id=message.session_id, - reply_to=message.reply_to, - interest_value=config.get("interest_value"), - key_words=_normalize_optional_str(config.get("key_words")), - key_words_lite=_normalize_optional_str(config.get("key_words_lite")), - is_mentioned=message.is_mentioned, - is_at=message.is_at, - reply_probability_boost=config.get("reply_probability_boost"), - processed_plain_text=message.processed_plain_text, - display_message=message.display_message, - priority_mode=_normalize_optional_str(config.get("priority_mode")), - priority_info=priority_info, - additional_config=message.additional_config, - is_emoji=message.is_emoji, - is_picid=message.is_picture, - is_command=message.is_command, - intercept_message_level=config.get("intercept_message_level", 0), - is_notify=message.is_notify, - selected_expressions=selected_expressions, - user_id=message.user_id, - user_nickname=message.user_nickname, - user_cardname=message.user_cardname, - user_platform=message.platform, - chat_info_group_id=message.group_id, - chat_info_group_name=message.group_name, - chat_info_group_platform=message.platform, - chat_info_user_id=message.user_id, - chat_info_user_nickname=message.user_nickname, - chat_info_user_cardname=message.user_cardname, - chat_info_user_platform=message.platform, - chat_info_stream_id=message.session_id, - chat_info_platform=message.platform, - chat_info_create_time=0.0, - chat_info_last_active_time=0.0, + instance = SessionMessage.from_db_instance(message) + instance.interest_value = config.get("interest_value") + instance.key_words = _normalize_optional_str(config.get("key_words")) + instance.key_words_lite = _normalize_optional_str(config.get("key_words_lite")) + instance.reply_probability_boost = config.get("reply_probability_boost") + instance.priority_mode = _normalize_optional_str(config.get("priority_mode")) + instance.priority_info = _normalize_optional_str(config.get("priority_info")) + instance.intercept_message_level = config.get("intercept_message_level", 0) + instance.selected_expressions = _normalize_optional_str(config.get("selected_expressions")) + group_info = instance.message_info.group_info + legacy_group_info = None + if group_info: + legacy_group_info = SimpleNamespace( + group_id=group_info.group_id, + group_name=group_info.group_name, + ) + instance.user_info = SimpleNamespace( + user_id=instance.message_info.user_info.user_id, + user_nickname=instance.message_info.user_info.user_nickname, + user_cardname=instance.message_info.user_info.user_cardname, + platform=instance.platform, ) + instance.chat_info = SimpleNamespace( + platform=instance.platform, + stream_id=instance.session_id, + group_info=legacy_group_info, + ) + instance.time = instance.timestamp.timestamp() + return instance def _coerce_datetime(value: Any) -> Any: @@ -147,7 +132,7 @@ def find_messages( filter_bot: bool = False, filter_command: bool = False, filter_intercept_message_level: int | None = None, -) -> list[DatabaseMessages]: +) -> list[SessionMessage]: """ 根据提供的过滤器、排序和限制条件查找消息。 diff --git a/src/services/generator_service.py b/src/services/generator_service.py index 7a84b911..3587a244 100644 --- a/src/services/generator_service.py +++ b/src/services/generator_service.py @@ -16,14 +16,14 @@ from src.chat.replyer.group_generator import DefaultReplyer from src.chat.replyer.private_generator import PrivateReplyer from src.chat.replyer.replyer_manager import replyer_manager from src.chat.utils.utils import process_llm_response -from src.common.data_models.message_data_model import ReplySetModel +from src.common.data_models.message_component_data_model import MessageSequence, TextComponent from src.common.logger import get_logger from src.core.types import ActionInfo if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.llm_data_model import LLMGenerationDataModel + from src.chat.message_receive.message import SessionMessage install(extra_lines=3) @@ -67,7 +67,7 @@ async def generate_reply( chat_stream: Optional[BotChatSession] = None, chat_id: Optional[str] = None, action_data: Optional[Dict[str, Any]] = None, - reply_message: Optional["DatabaseMessages"] = None, + reply_message: Optional["SessionMessage"] = None, think_level: int = 1, extra_info: str = "", reply_reason: str = "", @@ -126,15 +126,17 @@ async def generate_reply( if not success: logger.warning("[GeneratorService] 回复生成失败") return False, None - reply_set: Optional[ReplySetModel] = None + reply_set: Optional[MessageSequence] = None if content := llm_response.content: processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) llm_response.processed_output = processed_response - reply_set = ReplySetModel() + reply_set = MessageSequence(components=[]) for text in processed_response: - reply_set.add_text_content(text) + reply_set.components.append(TextComponent(text)) llm_response.reply_set = reply_set - logger.debug(f"[GeneratorService] 回复生成成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") + logger.debug( + f"[GeneratorService] 回复生成成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项" + ) try: PlanReplyLogger.log_reply( @@ -196,12 +198,14 @@ async def rewrite_reply( reason=reason, reply_to=reply_to, ) - reply_set: Optional[ReplySetModel] = None + reply_set: Optional[MessageSequence] = None if success and llm_response and (content := llm_response.content): reply_set = process_human_text(content, enable_splitter, enable_chinese_typo) llm_response.reply_set = reply_set if success: - logger.info(f"[GeneratorService] 重写回复成功,生成了 {len(reply_set) if reply_set else 0} 个回复项") + logger.info( + f"[GeneratorService] 重写回复成功,生成了 {len(reply_set.components) if reply_set else 0} 个回复项" + ) else: logger.warning("[GeneratorService] 重写回复失败") @@ -215,16 +219,16 @@ async def rewrite_reply( return False, None -def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[ReplySetModel]: +def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> Optional[MessageSequence]: """将文本处理为更拟人化的文本""" if not isinstance(content, str): raise ValueError("content 必须是字符串类型") try: - reply_set = ReplySetModel() + reply_set = MessageSequence(components=[]) processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo) for text in processed_response: - reply_set.add_text_content(text) + reply_set.components.append(TextComponent(text)) return reply_set diff --git a/src/services/message_service.py b/src/services/message_service.py index 7b175dfe..3a19431a 100644 --- a/src/services/message_service.py +++ b/src/services/message_service.py @@ -1,34 +1,21 @@ -""" -消息服务模块 - -提供消息查询和构建成字符串的核心功能。 -""" +"""消息服务模块。""" +import re import time -from typing import Any, Dict, List, Optional, Tuple +from datetime import datetime +from typing import Any, List, Optional, Tuple from sqlmodel import col, select -from src.chat.utils.chat_message_builder import ( - build_readable_messages, - build_readable_messages_with_list, - get_person_id_list, - get_raw_msg_before_timestamp, - get_raw_msg_before_timestamp_with_chat, - get_raw_msg_before_timestamp_with_users, - get_raw_msg_by_timestamp, - get_raw_msg_by_timestamp_random, - get_raw_msg_by_timestamp_with_chat, - get_raw_msg_by_timestamp_with_chat_inclusive, - get_raw_msg_by_timestamp_with_chat_users, - get_raw_msg_by_timestamp_with_users, - num_new_messages_since, - num_new_messages_since_with_users, -) -from src.chat.utils.utils import is_bot_self -from src.common.data_models.database_data_model import DatabaseMessages +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.action_record_data_model import MaiActionRecord from src.common.database.database import get_db_session -from src.common.database.database_model import Images, ImageType +from src.common.database.database_model import ActionRecord, Images, ImageType +from src.common.message_repository import count_messages, find_messages +from src.common.utils.math_utils import translate_timestamp_to_human_readable +from src.common.utils.utils_action import ActionUtils +from src.chat.utils.utils import is_bot_self +from src.config.config import global_config # ============================================================================= @@ -36,16 +23,62 @@ from src.common.database.database_model import Images, ImageType # ============================================================================= +def _build_time_range_filter(start_time: float, end_time: float) -> dict[str, Any]: + return { + "time": { + "$gte": start_time, + "$lte": end_time, + } + } + + +def _build_readable_line( + message: SessionMessage, + *, + replace_bot_name: bool, + timestamp_mode: Optional[str], + show_message_id_prefix: bool, +) -> str: + plain_text = (message.processed_plain_text or "").strip() + if replace_bot_name and global_config.bot.nickname: + plain_text = plain_text.replace(global_config.bot.nickname, "你") + user_name = ( + message.message_info.user_info.user_cardname + or message.message_info.user_info.user_nickname + or message.message_info.user_info.user_id + ) + prefix: List[str] = [] + if timestamp_mode: + prefix.append(f"[{translate_timestamp_to_human_readable(message.timestamp.timestamp(), mode=timestamp_mode)}]") + if show_message_id_prefix: + prefix.append(f"[消息ID: {message.message_id}]") + prefix.append(f"{user_name}说:") + return " ".join(prefix) + plain_text + + +def _normalize_messages(messages: List[SessionMessage]) -> List[SessionMessage]: + normalized: List[SessionMessage] = [] + for message in messages: + if not message.processed_plain_text: + message.processed_plain_text = message.display_message or "" + normalized.append(message) + return normalized + + def get_messages_by_time( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: raise ValueError("limit 不能为负数") - if filter_mai: - return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)) - return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) + messages = find_messages( + message_filter=_build_time_range_filter(start_time, end_time), + limit=limit, + limit_mode=limit_mode, + filter_bot=filter_mai, + ) + return _normalize_messages(messages) def get_messages_by_time_in_chat( @@ -57,7 +90,7 @@ def get_messages_by_time_in_chat( filter_mai: bool = False, filter_command: bool = False, filter_intercept_message_level: Optional[int] = None, -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -66,16 +99,18 @@ def get_messages_by_time_in_chat( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return get_raw_msg_by_timestamp_with_chat( - chat_id=chat_id, - timestamp_start=start_time, - timestamp_end=end_time, + messages = find_messages( + message_filter={ + "chat_id": chat_id, + **_build_time_range_filter(start_time, end_time), + }, limit=limit, limit_mode=limit_mode, filter_bot=filter_mai, filter_command=filter_command, filter_intercept_message_level=filter_intercept_message_level, ) + return _normalize_messages(messages) def get_messages_by_time_in_chat_inclusive( @@ -87,7 +122,7 @@ def get_messages_by_time_in_chat_inclusive( filter_mai: bool = False, filter_command: bool = False, filter_intercept_message_level: Optional[int] = None, -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -96,19 +131,21 @@ def get_messages_by_time_in_chat_inclusive( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - messages = get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=chat_id, - timestamp_start=start_time, - timestamp_end=end_time, + messages = find_messages( + message_filter={ + "chat_id": chat_id, + "time": { + "$gte": start_time, + "$lte": end_time, + }, + }, limit=limit, limit_mode=limit_mode, filter_bot=filter_mai, filter_command=filter_command, filter_intercept_message_level=filter_intercept_message_level, ) - if filter_mai: - return filter_mai_messages(messages) - return messages + return _normalize_messages(messages) def get_messages_by_time_in_chat_for_users( @@ -118,7 +155,7 @@ def get_messages_by_time_in_chat_for_users( person_ids: List[str], limit: int = 0, limit_mode: str = "latest", -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: @@ -127,39 +164,64 @@ def get_messages_by_time_in_chat_for_users( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode) + messages = find_messages( + message_filter={ + "chat_id": chat_id, + "time": { + "$gte": start_time, + "$lte": end_time, + }, + "user_id": {"$in": person_ids}, + }, + limit=limit, + limit_mode=limit_mode, + ) + return _normalize_messages(messages) def get_random_chat_messages( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: raise ValueError("limit 不能为负数") - if filter_mai: - return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)) - return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode) + return get_messages_by_time(start_time, end_time, limit, limit_mode, filter_mai) def get_messages_by_time_for_users( start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)): raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: raise ValueError("limit 不能为负数") - return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) + messages = find_messages( + message_filter={ + "time": { + "$gte": start_time, + "$lte": end_time, + }, + "user_id": {"$in": person_ids}, + }, + limit=limit, + limit_mode=limit_mode, + ) + return _normalize_messages(messages) -def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[DatabaseMessages]: +def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[SessionMessage]: if not isinstance(timestamp, (int, float)): raise ValueError("timestamp 必须是数字类型") if limit < 0: raise ValueError("limit 不能为负数") - if filter_mai: - return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit)) - return get_raw_msg_before_timestamp(timestamp, limit) + messages = find_messages( + message_filter={"time": {"$lt": timestamp}}, + limit=limit, + limit_mode="latest", + filter_bot=filter_mai, + ) + return _normalize_messages(messages) def get_messages_before_time_in_chat( @@ -168,7 +230,7 @@ def get_messages_before_time_in_chat( limit: int = 0, filter_mai: bool = False, filter_intercept_message_level: Optional[int] = None, -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(timestamp, (int, float)): raise ValueError("timestamp 必须是数字类型") if limit < 0: @@ -177,30 +239,40 @@ def get_messages_before_time_in_chat( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - messages = get_raw_msg_before_timestamp_with_chat( - chat_id=chat_id, - timestamp=timestamp, + messages = find_messages( + message_filter={ + "chat_id": chat_id, + "time": {"$lt": timestamp}, + }, limit=limit, + limit_mode="latest", + filter_bot=filter_mai, filter_intercept_message_level=filter_intercept_message_level, ) - if filter_mai: - return filter_mai_messages(messages) - return messages + return _normalize_messages(messages) def get_messages_before_time_for_users( timestamp: float, person_ids: List[str], limit: int = 0 -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(timestamp, (int, float)): raise ValueError("timestamp 必须是数字类型") if limit < 0: raise ValueError("limit 不能为负数") - return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit) + messages = find_messages( + message_filter={ + "time": {"$lt": timestamp}, + "user_id": {"$in": person_ids}, + }, + limit=limit, + limit_mode="latest", + ) + return _normalize_messages(messages) def get_recent_messages( chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False -) -> List[DatabaseMessages]: +) -> List[SessionMessage]: if not isinstance(hours, (int, float)) or hours < 0: raise ValueError("hours 不能是负数") if not isinstance(limit, int) or limit < 0: @@ -211,9 +283,7 @@ def get_recent_messages( raise ValueError("chat_id 必须是字符串类型") now = time.time() start_time = now - hours * 3600 - if filter_mai: - return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)) - return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode) + return get_messages_by_time_in_chat(chat_id, start_time, now, limit, limit_mode, filter_mai) # ============================================================================= @@ -228,7 +298,13 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return num_new_messages_since(chat_id, start_time, end_time) + message_filter: dict[str, Any] = { + "chat_id": chat_id, + "time": {"$gt": start_time}, + } + if end_time is not None: + message_filter["time"]["$lte"] = end_time + return count_messages(message_filter) def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: @@ -238,7 +314,13 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids) + return count_messages( + { + "chat_id": chat_id, + "time": {"$gt": start_time, "$lte": end_time}, + "user_id": {"$in": person_ids}, + } + ) # ============================================================================= @@ -246,8 +328,45 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa # ============================================================================= +def build_readable_messages( + messages: List[SessionMessage], + replace_bot_name: bool = True, + timestamp_mode: str = "relative", + read_mark: float = 0.0, + truncate: bool = False, + show_actions: bool = False, +) -> str: + normalized_messages = _normalize_messages(messages) + lines: List[str] = [] + unread_mark_added = False + for message in normalized_messages: + if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark: + lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---") + unread_mark_added = True + line = _build_readable_line( + message, + replace_bot_name=replace_bot_name, + timestamp_mode=timestamp_mode, + show_message_id_prefix=False, + ) + if truncate and len(line) > 200: + line = f"{line[:200]}......(内容太长了)" + lines.append(line) + if show_actions and normalized_messages: + action_lines = build_readable_actions( + get_actions_by_timestamp_with_chat( + normalized_messages[0].session_id, + normalized_messages[0].timestamp.timestamp(), + normalized_messages[-1].timestamp.timestamp(), + ) + ) + if action_lines: + lines.append(action_lines) + return "\n".join(lines) + + def build_readable_messages_to_str( - messages: List[DatabaseMessages], + messages: List[SessionMessage], replace_bot_name: bool = True, timestamp_mode: str = "relative", read_mark: float = 0.0, @@ -257,17 +376,71 @@ def build_readable_messages_to_str( return build_readable_messages(messages, replace_bot_name, timestamp_mode, read_mark, truncate, show_actions) +def build_readable_messages_with_id( + messages: List[SessionMessage], + replace_bot_name: bool = True, + timestamp_mode: str = "relative", + read_mark: float = 0.0, + truncate: bool = False, + show_actions: bool = False, +) -> Tuple[str, List[Tuple[str, SessionMessage]]]: + normalized_messages = _normalize_messages(messages) + lines: List[str] = [] + message_id_list: List[Tuple[str, SessionMessage]] = [] + unread_mark_added = False + for message in normalized_messages: + if read_mark and not unread_mark_added and message.timestamp.timestamp() > read_mark: + lines.append("--- 以上消息是你已经看过,请关注以下未读的新消息 ---") + unread_mark_added = True + line = _build_readable_line( + message, + replace_bot_name=replace_bot_name, + timestamp_mode=timestamp_mode, + show_message_id_prefix=True, + ) + if truncate and len(line) > 200: + line = f"{line[:200]}......(内容太长了)" + lines.append(line) + message_id_list.append((message.message_id, message)) + if show_actions and normalized_messages: + action_lines = build_readable_actions( + get_actions_by_timestamp_with_chat( + normalized_messages[0].session_id, + normalized_messages[0].timestamp.timestamp(), + normalized_messages[-1].timestamp.timestamp(), + ) + ) + if action_lines: + lines.append(action_lines) + return "\n".join(lines), message_id_list + + async def build_readable_messages_with_details( - messages: List[DatabaseMessages], + messages: List[SessionMessage], replace_bot_name: bool = True, timestamp_mode: str = "relative", truncate: bool = False, ) -> Tuple[str, List[Tuple[float, str, str]]]: - return await build_readable_messages_with_list(messages, replace_bot_name, timestamp_mode, truncate) + normalized_messages = _normalize_messages(messages) + message_list = [ + ( + message.timestamp.timestamp(), + message.message_info.user_info.user_id, + message.processed_plain_text or "", + ) + for message in normalized_messages + ] + return build_readable_messages(normalized_messages, replace_bot_name, timestamp_mode, truncate=truncate), message_list -async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]: - return await get_person_id_list(messages) +async def get_person_ids_from_messages(messages: List[Any]) -> List[str]: + person_ids: List[str] = [] + for message in messages: + if isinstance(message, SessionMessage): + person_ids.append(message.message_info.user_info.user_id) + elif isinstance(message, dict) and (user_id := message.get("user_id")): + person_ids.append(str(user_id)) + return person_ids # ============================================================================= @@ -275,9 +448,145 @@ async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[s # ============================================================================= -def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessages]: +def filter_mai_messages(messages: List[SessionMessage]) -> List[SessionMessage]: """从消息列表中移除麦麦的消息""" - return [msg for msg in messages if not is_bot_self(msg.user_info.platform, msg.user_info.user_id)] + return [ + msg + for msg in messages + if not is_bot_self(msg.platform, msg.message_info.user_info.user_id) + ] + + +def get_raw_msg_by_timestamp( + timestamp_start: float, + timestamp_end: float, + limit: int = 0, + limit_mode: str = "latest", +) -> List[SessionMessage]: + return get_messages_by_time(timestamp_start, timestamp_end, limit, limit_mode) + + +def get_raw_msg_by_timestamp_with_chat( + chat_id: str, + timestamp_start: float, + timestamp_end: float, + limit: int = 0, + limit_mode: str = "latest", + filter_bot: bool = False, + filter_command: bool = False, + filter_intercept_message_level: Optional[int] = None, +) -> List[SessionMessage]: + return get_messages_by_time_in_chat( + chat_id, + timestamp_start, + timestamp_end, + limit, + limit_mode, + filter_bot, + filter_command, + filter_intercept_message_level, + ) + + +def get_raw_msg_by_timestamp_with_chat_inclusive( + chat_id: str, + timestamp_start: float, + timestamp_end: float, + limit: int = 0, + limit_mode: str = "latest", + filter_bot: bool = False, + filter_command: bool = False, + filter_intercept_message_level: Optional[int] = None, +) -> List[SessionMessage]: + return get_messages_by_time_in_chat_inclusive( + chat_id, + timestamp_start, + timestamp_end, + limit, + limit_mode, + filter_bot, + filter_command, + filter_intercept_message_level, + ) + + +def get_raw_msg_by_timestamp_with_chat_users( + chat_id: str, + timestamp_start: float, + timestamp_end: float, + person_ids: List[str], + limit: int = 0, + limit_mode: str = "latest", +) -> List[SessionMessage]: + return get_messages_by_time_in_chat_for_users(chat_id, timestamp_start, timestamp_end, person_ids, limit, limit_mode) + + +def get_raw_msg_by_timestamp_with_users( + timestamp_start: float, + timestamp_end: float, + person_ids: List[str], + limit: int = 0, + limit_mode: str = "latest", +) -> List[SessionMessage]: + return get_messages_by_time_for_users(timestamp_start, timestamp_end, person_ids, limit, limit_mode) + + +def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[SessionMessage]: + return get_messages_before_time(timestamp, limit) + + +def get_raw_msg_before_timestamp_with_chat( + chat_id: str, + timestamp: float, + limit: int = 0, + filter_intercept_message_level: Optional[int] = None, +) -> List[SessionMessage]: + return get_messages_before_time_in_chat(chat_id, timestamp, limit, False, filter_intercept_message_level) + + +def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[SessionMessage]: + return get_messages_before_time_for_users(timestamp, person_ids, limit) + + +def get_raw_msg_by_timestamp_random( + timestamp_start: float, + timestamp_end: float, + limit: int = 0, + limit_mode: str = "latest", +) -> List[SessionMessage]: + return get_random_chat_messages(timestamp_start, timestamp_end, limit, limit_mode) + + +def get_actions_by_timestamp_with_chat(chat_id: str, timestamp_start: float, timestamp_end: float) -> List[MaiActionRecord]: + with get_db_session() as session: + statement = ( + select(ActionRecord) + .where(col(ActionRecord.session_id) == chat_id) + .where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start)) + .where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end)) + .order_by(col(ActionRecord.timestamp)) + ) + return [MaiActionRecord.from_db_instance(item) for item in session.exec(statement).all()] + + +def build_readable_actions(actions: List[MaiActionRecord], timestamp_mode: str = "relative") -> str: + return ActionUtils.build_readable_action_records(actions, timestamp_mode) + + +def replace_user_references(text: str, platform: str, replace_bot_name: bool = False) -> str: + del platform + if not text: + return text + + def _replace(match: re.Match[str]) -> str: + prefix = match.group(1) or "" + user_name = match.group(2) + if replace_bot_name and user_name == global_config.bot.nickname: + user_name = "你" + return f"{prefix}{user_name}" + + text = re.sub(r"(回复|@)?<([^:<>]+):[^<>]+>", _replace, text) + return text def translate_pid_to_description(pid: str) -> str: diff --git a/src/services/send_service.py b/src/services/send_service.py index 29a79ced..76521283 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -6,21 +6,20 @@ import traceback import time -from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple +from typing import Optional, Union, Dict, List, TYPE_CHECKING -from maim_message import MessageBase, BaseMessageInfo, Seg +from maim_message import BaseMessageInfo, GroupInfo as MaimGroupInfo, MessageBase, Seg, UserInfo as MaimUserInfo from src.chat.message_receive.chat_manager import chat_manager as _chat_manager -from src.chat.message_receive.message import MessageSending +from src.chat.message_receive.message import SessionMessage from src.chat.message_receive.uni_message_sender import UniversalMessageSender -from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo -from src.common.data_models.message_data_model import ReplyContentType +from src.common.data_models.mai_message_data_model import MaiMessage +from src.common.data_models.message_component_data_model import DictComponent, MessageSequence from src.common.logger import get_logger from src.config.config import global_config if TYPE_CHECKING: - from src.common.data_models.database_data_model import DatabaseMessages - from src.common.data_models.message_data_model import ForwardNode, ReplyContent, ReplySetModel + from src.chat.message_receive.message import SessionMessage logger = get_logger("send_service") @@ -36,7 +35,7 @@ async def _send_to_target( display_message: str = "", typing: bool = False, set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, + reply_message: Optional["SessionMessage"] = None, storage_message: bool = True, show_log: bool = True, selected_expressions: Optional[List[int]] = None, @@ -60,12 +59,6 @@ async def _send_to_target( current_time = time.time() message_id = f"send_api_{int(current_time * 1000)}" - bot_user_info = UserInfo( - user_id=global_config.bot.qq_account, - user_nickname=global_config.bot.nickname, - ) - - reply_to_platform_id = "" anchor_message: Optional[MaiMessage] = None if reply_message: anchor_message = db_message_to_mai_message(reply_message) @@ -73,31 +66,50 @@ async def _send_to_target( logger.debug( f"[SendService] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}" ) - reply_to_platform_id = f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}" - sender_info = None - if target_stream.context and target_stream.context.message: - sender_info = target_stream.context.message.message_info.user_info + group_info = None + if target_stream.group_id: + group_name = "" + if target_stream.context and target_stream.context.message and target_stream.context.message.message_info.group_info: + group_name = target_stream.context.message.message_info.group_info.group_name + group_info = MaimGroupInfo( + group_id=target_stream.group_id, + group_name=group_name, + platform=target_stream.platform, + ) - bot_message = MessageSending( - message_id=message_id, - session=target_stream, - bot_user_info=bot_user_info, - sender_info=sender_info, + additional_config: dict[str, object] = {} + if selected_expressions is not None: + additional_config["selected_expressions"] = selected_expressions + + maim_message = MessageBase( + message_info=BaseMessageInfo( + platform=target_stream.platform, + message_id=message_id, + time=current_time, + user_info=MaimUserInfo( + user_id=str(global_config.bot.qq_account), + user_nickname=global_config.bot.nickname, + platform=target_stream.platform, + ), + group_info=group_info, + additional_config=additional_config, + ), message_segment=message_segment, - display_message=display_message, - reply=anchor_message, - is_head=True, - is_emoji=(message_segment.type == "emoji"), - thinking_start_time=current_time, - reply_to=reply_to_platform_id, - selected_expressions=selected_expressions, ) + bot_message = SessionMessage.from_maim_message(maim_message) + bot_message.session_id = target_stream.session_id + bot_message.display_message = display_message + bot_message.reply_to = anchor_message.message_id if anchor_message else None + bot_message.is_emoji = message_segment.type == "emoji" + bot_message.is_picture = message_segment.type == "image" + bot_message.is_command = message_segment.type == "command" sent_msg = await message_sender.send_message( bot_message, typing=typing, set_reply=set_reply, + reply_message_id=anchor_message.message_id if anchor_message else None, storage_message=storage_message, show_log=show_log, ) @@ -115,37 +127,9 @@ async def _send_to_target( return False -def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]: +def db_message_to_mai_message(message_obj: "SessionMessage") -> Optional[MaiMessage]: """将数据库消息重建为 MaiMessage 对象,用于回复引用。""" - from datetime import datetime - - from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo - from src.common.data_models.message_component_data_model import MessageSequence - - user_info = UserInfo( - user_id=message_obj.user_info.user_id or "", - user_nickname=message_obj.user_info.user_nickname or "", - user_cardname=message_obj.user_info.user_cardname, - ) - - group_info = None - if message_obj.chat_info.group_info: - group_info = GroupInfo( - group_id=message_obj.chat_info.group_info.group_id or "", - group_name=message_obj.chat_info.group_info.group_name or "", - ) - - msg = MaiMessage( - message_id=message_obj.message_id, - timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(), - ) - msg.message_info = MessageInfo(user_info=user_info, group_info=group_info) - msg.platform = message_obj.chat_info.platform or "" - msg.session_id = message_obj.chat_info.stream_id or "" - msg.processed_plain_text = message_obj.processed_plain_text - msg.raw_message = MessageSequence(components=[]) - msg.initialized = True - return msg + return message_obj.deepcopy() # ============================================================================= @@ -158,7 +142,7 @@ async def text_to_stream( stream_id: str, typing: bool = False, set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, + reply_message: Optional["SessionMessage"] = None, storage_message: bool = True, selected_expressions: Optional[List[int]] = None, ) -> bool: @@ -180,7 +164,7 @@ async def emoji_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, + reply_message: Optional["SessionMessage"] = None, ) -> bool: """向指定流发送表情包""" return await _send_to_target( @@ -199,7 +183,7 @@ async def image_to_stream( stream_id: str, storage_message: bool = True, set_reply: bool = False, - reply_message: Optional["DatabaseMessages"] = None, + reply_message: Optional["SessionMessage"] = None, ) -> bool: """向指定流发送图片""" return await _send_to_target( @@ -236,7 +220,7 @@ async def custom_to_stream( stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["DatabaseMessages"] = None, + reply_message: Optional["SessionMessage"] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, @@ -255,25 +239,27 @@ async def custom_to_stream( async def custom_reply_set_to_stream( - reply_set: "ReplySetModel", + reply_set: MessageSequence, stream_id: str, display_message: str = "", typing: bool = False, - reply_message: Optional["DatabaseMessages"] = None, + reply_message: Optional["SessionMessage"] = None, set_reply: bool = False, storage_message: bool = True, show_log: bool = True, ) -> bool: - """向指定流发送混合型消息集""" + """向指定流发送消息组件序列。""" flag: bool = True - for reply_content in reply_set.reply_data: - status: bool = False - message_seg, need_typing = _parse_content_to_seg(reply_content) + for component in reply_set.components: + if isinstance(component, DictComponent): + message_seg = Seg(type="dict", data=component.data) # type: ignore + else: + message_seg = await component.to_seg() status = await _send_to_target( message_segment=message_seg, stream_id=stream_id, display_message=display_message, - typing=bool(need_typing and typing), + typing=typing, reply_message=reply_message, set_reply=set_reply, storage_message=storage_message, @@ -281,67 +267,7 @@ async def custom_reply_set_to_stream( ) if not status: flag = False - logger.error( - f"[SendService] 发送{repr(reply_content.content_type)}消息失败,消息内容:{str(reply_content.content)[:100]}" - ) + logger.error(f"[SendService] 发送消息组件失败,组件类型:{type(component).__name__}") + set_reply = False return flag - - -def _parse_content_to_seg(reply_content: "ReplyContent") -> Tuple[Seg, bool]: - """把 ReplyContent 转换为 Seg 结构""" - content_type = reply_content.content_type - if content_type == ReplyContentType.TEXT: - text_data: str = reply_content.content # type: ignore - return Seg(type="text", data=text_data), True - elif content_type == ReplyContentType.IMAGE: - return Seg(type="image", data=reply_content.content), False # type: ignore - elif content_type == ReplyContentType.EMOJI: - return Seg(type="emoji", data=reply_content.content), False # type: ignore - elif content_type == ReplyContentType.COMMAND: - return Seg(type="command", data=reply_content.content), False # type: ignore - elif content_type == ReplyContentType.VOICE: - return Seg(type="voice", data=reply_content.content), False # type: ignore - elif content_type == ReplyContentType.HYBRID: - hybrid_message_list_data: List[ReplyContent] = reply_content.content # type: ignore - assert isinstance(hybrid_message_list_data, list), "混合类型内容必须是列表" - sub_seg_list: List[Seg] = [] - for sub_content in hybrid_message_list_data: - sub_content_type = sub_content.content_type - sub_content_data = sub_content.content - - if sub_content_type == ReplyContentType.TEXT: - sub_seg_list.append(Seg(type="text", data=sub_content_data)) # type: ignore - elif sub_content_type == ReplyContentType.IMAGE: - sub_seg_list.append(Seg(type="image", data=sub_content_data)) # type: ignore - elif sub_content_type == ReplyContentType.EMOJI: - sub_seg_list.append(Seg(type="emoji", data=sub_content_data)) # type: ignore - else: - logger.warning(f"[SendService] 混合类型中不支持的子内容类型: {repr(sub_content_type)}") - continue - return Seg(type="seglist", data=sub_seg_list), True - elif content_type == ReplyContentType.FORWARD: - forward_message_list_data: List["ForwardNode"] = reply_content.content # type: ignore - assert isinstance(forward_message_list_data, list), "转发类型内容必须是列表" - forward_message_list: List[Dict] = [] - for forward_node in forward_message_list_data: - message_segment = Seg(type="id", data=forward_node.content) # type: ignore - user_info: Optional[UserInfo] = None - if forward_node.user_id and forward_node.user_nickname: - assert isinstance(forward_node.content, list), "转发节点内容必须是列表" - user_info = UserInfo(user_id=forward_node.user_id, user_nickname=forward_node.user_nickname) - single_node_content: List[Seg] = [] - for sub_content in forward_node.content: - if sub_content.content_type != ReplyContentType.FORWARD: - sub_seg, _ = _parse_content_to_seg(sub_content) - single_node_content.append(sub_seg) - message_segment = Seg(type="seglist", data=single_node_content) - forward_message_list.append( - MessageBase( - message_segment=message_segment, message_info=BaseMessageInfo(user_info=user_info) - ).to_dict() - ) - return Seg(type="forward", data=forward_message_list), False # type: ignore - else: - message_type_in_str = content_type.value if isinstance(content_type, ReplyContentType) else str(content_type) - return Seg(type=message_type_in_str, data=reply_content.content), True # type: ignore