diff --git a/Plan.md b/Plan.md new file mode 100644 index 00000000..c05d47b2 --- /dev/null +++ b/Plan.md @@ -0,0 +1,9 @@ +Context 在消息接收的时候就进行解析,不再放到 MaiMessage 里面,由消息注册的时候直接进去注册 +- [ ] 实现`update_chat_context`方法,主要关注`format_info` + + +1. **预计不对发送的时候进行`accept_format`的格式判断**,希望所有消息适配器接收的时候做一下不兼容内容主动丢弃 +2. 在发送消息的时候进行`accept_format`的判断,判断不兼容内容是否存在,如果存在则丢弃掉 + +- [ ] 实现 status_api + diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 19fd22d8..c23921c1 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -1,18 +1,9 @@ from asyncio import Task -from datetime import datetime -from maim_message import ( - MessageBase, - UserInfo as MaimUserInfo, - GroupInfo as MaimGroupInfo, - BaseMessageInfo as MaimBaseMessageInfo, - Seg, -) from rich.traceback import install from sqlmodel import select -from typing import List, Dict, Optional, Tuple, Sequence, TYPE_CHECKING +from typing import List, Dict, Tuple, Sequence import asyncio -import time from src.common.logger import get_logger from src.common.database.database import get_db_session @@ -28,10 +19,6 @@ from src.common.data_models.message_component_data_model import ( ForwardNodeComponent, StandardMessageComponents, ) -from src.common.utils.utils_message import MessageUtils - -if TYPE_CHECKING: - from src.chat.message_receive.chat_manager import BotChatSession install(extra_lines=3) @@ -220,166 +207,3 @@ class SessionMessage(MaiMessage): else: processed_texts.append(result) return " ".join(processed_texts) - - -class MessageSending(MaiMessage): - """发送状态的消息类,继承 MaiMessage 基类。 - - 用于构建、处理和发送机器人的回复消息。 - 复用 MaiMessage 的 to_maim_message() 和 to_db_instance() 方法, - 额外管理发送专属的会话信息和控制字段。 - """ - - def __init__( - self, - message_id: str, - session: "BotChatSession", - bot_user_info: UserInfo, - message_segment: Seg, - sender_info: Optional[UserInfo] = None, - reply: Optional[MaiMessage] = None, - display_message: str = "", - is_head: bool = False, - is_emoji: 
bool = False, - thinking_start_time: float = 0, - reply_to: Optional[str] = None, - selected_expressions: Optional[List[int]] = None, - ): - # 初始化 MaiMessage 基类 - super().__init__(message_id=message_id, timestamp=datetime.now()) - - # 发送专属字段 - self.session = session - self.sender_info = sender_info - self.message_segment = message_segment - self.reply = reply - self.is_head = is_head - self.thinking_start_time = thinking_start_time - self.selected_expressions = selected_expressions - self.reply_to_message_id: Optional[str] = reply.message_id if reply else None - self.interest_value: float = 0.0 - - # 填充 MaiMessage 标准字段 - self.platform = session.platform - self.session_id = session.session_id - self.is_emoji = is_emoji - self.reply_to = reply_to - self.display_message = display_message - self.processed_plain_text = "" - - # 构建 message_info:DB 存储时 user_info 始终为 bot 信息 - # 私聊/群聊的 user_info 差异仅在 to_maim_message() 覆写中处理 - group_info = self._resolve_group_info() - self.message_info = MessageInfo(user_info=bot_user_info, group_info=group_info) - - # bot_user_info 单独保存,to_maim_message 覆写时还需要 - self.bot_user_info = bot_user_info - - # 将 Seg 转换为 MessageSequence,供基类的 to_db_instance / to_maim_message 使用 - self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq( - MessageBase(message_info=None, message_segment=message_segment) - ) - - self.initialized = True - - def _resolve_group_info(self) -> Optional[GroupInfo]: - """从 session 中解析群信息""" - if not self.session.group_id: - return None - group_name = "" - if ( - self.session.context - and self.session.context.message - and self.session.context.message.message_info.group_info - ): - group_name = self.session.context.message.message_info.group_info.group_name - return GroupInfo(group_id=self.session.group_id, group_name=group_name) - - async def process(self) -> None: - """处理消息段,生成 processed_plain_text(使用 SessionMessage 的组件处理能力)""" - # 同步 message_segment → raw_message(插件可能修改了 message_segment) - self.raw_message = 
MessageUtils.from_maim_message_segments_to_MaiSeq( - MessageBase(message_info=None, message_segment=self.message_segment) - ) - if self.raw_message and self.raw_message.components: - tasks = [self._process_component(c) for c in self.raw_message.components] - results = await asyncio.gather(*tasks, return_exceptions=True) - texts = [] - for r in results: - if isinstance(r, BaseException): - logger.error(f"处理发送消息组件时发生错误: {r}") - elif r: - texts.append(r) - self.processed_plain_text = " ".join(texts) - - async def _process_component(self, component: StandardMessageComponents) -> str: - """简单处理单个标准组件为纯文本描述""" - if isinstance(component, TextComponent): - return component.text - elif isinstance(component, ImageComponent): - return "[图片]" - elif isinstance(component, EmojiComponent): - return "[表情包]" - elif isinstance(component, VoiceComponent): - return "[语音]" - elif isinstance(component, AtComponent): - return f"[@{component.target_user_id}]" - elif isinstance(component, ReplyComponent): - return "" - else: - return f"[{type(component).__name__}]" - - def build_reply(self) -> None: - """构建回复消息段,在 message_segment 前插入 reply 段""" - if self.reply: - self.reply_to_message_id = self.reply.message_id - self.message_segment = Seg( - type="seglist", - data=[ - Seg(type="reply", data=self.reply.message_id), - self.message_segment, - ], - ) - # 同步更新 raw_message - self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq( - MessageBase(message_info=None, message_segment=self.message_segment) - ) - - async def to_maim_message(self) -> MessageBase: - """覆写基类方法:发送消息需要特殊处理 user_info(私聊/群聊差异)""" - maim_bot_user_info = MaimUserInfo( - user_id=self.bot_user_info.user_id, - user_nickname=self.bot_user_info.user_nickname, - user_cardname=self.bot_user_info.user_cardname, - platform=self.platform, - ) - - maim_group_info = None - if self.message_info.group_info: - maim_group_info = MaimGroupInfo( - group_id=self.message_info.group_info.group_id, - 
group_name=self.message_info.group_info.group_name, - platform=self.platform, - ) - - # 私聊时 user_info 填接收者信息(sender_info),群聊时填 bot - if maim_group_info is None and self.sender_info: - msg_user_info = MaimUserInfo( - user_id=self.sender_info.user_id, - user_nickname=self.sender_info.user_nickname, - user_cardname=self.sender_info.user_cardname, - platform=self.platform, - ) - else: - msg_user_info = maim_bot_user_info - - maim_msg_info = MaimBaseMessageInfo( - platform=self.platform, - message_id=self.message_id, - time=time.time(), - group_info=maim_group_info, - user_info=msg_user_info, - ) - - msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message) - return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments)) diff --git a/src/chat/message_receive/message_old.py b/src/chat/message_receive/message_old.py deleted file mode 100644 index b9f44d5c..00000000 --- a/src/chat/message_receive/message_old.py +++ /dev/null @@ -1,561 +0,0 @@ -import time -import asyncio -import urllib3 - -from abc import abstractmethod -from dataclasses import dataclass -from rich.traceback import install -from typing import Optional, Any, List -from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.utils.utils_image import get_image_manager -from src.common.utils.utils_voice import get_voice_text -from .chat_stream import ChatStream - -install(extra_lines=3) - -logger = get_logger("chat_message") - -# 禁用SSL警告 -urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - -# VLM 处理并发限制(避免同时处理太多图片导致卡死) -_vlm_semaphore = asyncio.Semaphore(3) - -# 这个类是消息数据类,用于存储和管理消息数据。 -# 它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 -# 它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。 - - -@dataclass -class Message(MessageBase): - chat_stream: "ChatStream" = None # type: ignore - reply: Optional["Message"] 
= None - processed_plain_text: str = "" - - def __init__( - self, - message_id: str, - chat_stream: "ChatStream", - user_info: UserInfo, - message_segment: Optional[Seg] = None, - timestamp: Optional[float] = None, - reply: Optional["MessageRecv"] = None, - processed_plain_text: str = "", - ): - # 使用传入的时间戳或当前时间 - current_timestamp = timestamp if timestamp is not None else round(time.time(), 3) - # 构造基础消息信息 - message_info = BaseMessageInfo( - platform=chat_stream.platform, - message_id=message_id, - time=current_timestamp, - group_info=chat_stream.group_info, - user_info=user_info, - ) - - # 调用父类初始化 - super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore - - self.chat_stream = chat_stream - # 文本处理相关属性 - self.processed_plain_text = processed_plain_text - - # 回复消息 - self.reply = reply - - # async def _process_message_segments(self, segment: Seg) -> str: - # # sourcery skip: remove-unnecessary-else, swap-if-else-branches - # """递归处理消息段,转换为文字描述 - - # Args: - # segment: 要处理的消息段 - - # Returns: - # str: 处理后的文本 - # """ - # if segment.type == "seglist": - # # 处理消息段列表 - 使用并行处理提升性能 - # tasks = [self._process_message_segments(seg) for seg in segment.data] # type: ignore - # results = await asyncio.gather(*tasks, return_exceptions=True) - # segments_text = [] - # for result in results: - # if isinstance(result, Exception): - # logger.error(f"处理消息段时出错: {result}") - # continue - # if result: - # segments_text.append(result) - # return " ".join(segments_text) - # elif segment.type == "forward": - # # 处理转发消息 - 使用并行处理 - # async def process_forward_node(node_dict): - # message = MessageBase.from_dict(node_dict) # type: ignore - # processed_text = await self._process_message_segments(message.message_segment) - # if processed_text: - # return f"{global_config.bot.nickname}: {processed_text}" - # return None - - # tasks = [process_forward_node(node_dict) for node_dict in segment.data] - # results = await asyncio.gather(*tasks, 
return_exceptions=True) - # segments_text = [] - # for result in results: - # if isinstance(result, Exception): - # logger.error(f"处理转发节点时出错: {result}") - # continue - # if result: - # segments_text.append(result) - # return "[合并消息]: " + "\n-- ".join(segments_text) - # else: - # # 处理单个消息段 - # return await self._process_single_segment(segment) # type: ignore - - # @abstractmethod - # async def _process_single_segment(self, segment) -> str: - # pass - - -@dataclass -class MessageRecv(Message): - """接收消息类,用于处理从MessageCQ序列化的消息""" - - def __init__(self, message_dict: dict[str, Any]): - """从MessageCQ的字典初始化 - - Args: - message_dict: MessageCQ序列化后的字典 - """ - self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {})) - self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) - self.raw_message = message_dict.get("raw_message") - self.processed_plain_text = message_dict.get("processed_plain_text", "") - self.is_emoji = False - self.has_emoji = False - self.is_picid = False - self.has_picid = False - self.is_voice = False - self.is_mentioned = None - self.is_at = False - self.reply_probability_boost = 0.0 - self.is_notify = False - - self.is_command = False - self.intercept_message_level = 0 - - self.priority_mode = "interest" - self.priority_info = None - self.interest_value: float = None # type: ignore - - self.key_words = [] - self.key_words_lite = [] - - # 兼容适配器通过 additional_config 传入的 @ 标记 - try: - msg_info_dict = message_dict.get("message_info", {}) - add_cfg = msg_info_dict.get("additional_config") or {} - if isinstance(add_cfg, dict) and add_cfg.get("at_bot"): - # 标记为被提及,提高后续回复优先级 - self.is_mentioned = True # type: ignore - except Exception: - pass - - def update_chat_stream(self, chat_stream: "ChatStream"): - self.chat_stream = chat_stream - - # async def process(self) -> None: - # """处理消息内容,生成纯文本和详细文本 - - # 这个方法必须在创建实例后显式调用,因为它包含异步操作。 - # """ - # # print(f"self.message_segment: {self.message_segment}") - # 
self.processed_plain_text = await self._process_message_segments(self.message_segment) - - # async def _process_single_segment(self, segment: Seg) -> str: - # """处理单个消息段 - - # Args: - # segment: 消息段 - - # Returns: - # str: 处理后的文本 - # """ - # try: - # if segment.type == "text": - # self.is_picid = False - # self.is_emoji = False - # return segment.data # type: ignore - # elif segment.type == "image": - # # 如果是base64图片数据 - # if isinstance(segment.data, str): - # self.has_picid = True - # self.is_picid = True - # self.is_emoji = False - # image_manager = get_image_manager() - # # 使用 semaphore 限制 VLM 并发,避免同时处理太多图片 - # async with _vlm_semaphore: - # _, processed_text = await image_manager.process_image(segment.data) - # return processed_text - # return "[发了一张图片,网卡了加载不出来]" - # elif segment.type == "emoji": - # self.has_emoji = True - # self.is_emoji = True - # self.is_picid = False - # self.is_voice = False - # if isinstance(segment.data, str): - # # 使用 semaphore 限制 VLM 并发 - # async with _vlm_semaphore: - # return await get_image_manager().get_emoji_description(segment.data) - # return "[发了一个表情包,网卡了加载不出来]" - # elif segment.type == "voice": - # self.is_picid = False - # self.is_emoji = False - # self.is_voice = True - # if isinstance(segment.data, str): - # return await get_voice_text(segment.data) - # return "[发了一段语音,网卡了加载不出来]" - # elif segment.type == "mention_bot": - # self.is_picid = False - # self.is_emoji = False - # self.is_voice = False - # self.is_mentioned = float(segment.data) # type: ignore - # return "" - # elif segment.type == "priority_info": - # self.is_picid = False - # self.is_emoji = False - # self.is_voice = False - # if isinstance(segment.data, dict): - # # 处理优先级信息 - # self.priority_mode = "priority" - # self.priority_info = segment.data - # """ - # { - # 'message_type': 'vip', # vip or normal - # 'message_priority': 1.0, # 优先级,大为优先,float - # } - # """ - # return "" - # elif segment.type == "video_card": - # # 处理视频卡片消息 - # self.is_picid = False - # 
self.is_emoji = False - # self.is_voice = False - # if isinstance(segment.data, dict): - # file_name = segment.data.get("file", "未知视频") - # file_size = segment.data.get("file_size", "") - # url = segment.data.get("url", "") - # text = f"[视频: {file_name}" - # if file_size: - # text += f", 大小: {file_size}字节" - # text += "]" - # if url: - # text += f" 链接: {url}" - # return text - # return "[视频]" - # elif segment.type == "music_card": - # # 处理音乐卡片消息 - # self.is_picid = False - # self.is_emoji = False - # self.is_voice = False - # if isinstance(segment.data, dict): - # title = segment.data.get("title", "未知歌曲") - # singer = segment.data.get("singer", "") - # tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐" - # jump_url = segment.data.get("jump_url", "") - # music_url = segment.data.get("music_url", "") - # text = f"[音乐: {title}" - # if singer: - # text += f" - {singer}" - # if tag: - # text += f" ({tag})" - # text += "]" - # if jump_url: - # text += f" 跳转链接: {jump_url}" - # if music_url: - # text += f" 音乐链接: {music_url}" - # return text - # return "[音乐]" - # elif segment.type == "miniapp_card": - # # 处理小程序分享卡片(如B站视频分享) - # self.is_picid = False - # self.is_emoji = False - # self.is_voice = False - # if isinstance(segment.data, dict): - # title = segment.data.get("title", "") # 小程序名称 - # desc = segment.data.get("desc", "") # 内容描述 - # source_url = segment.data.get("source_url", "") # 原始链接 - # url = segment.data.get("url", "") # 小程序链接 - # text = "[小程序分享" - # if title: - # text += f" - {title}" - # text += "]" - # if desc: - # text += f" {desc}" - # if source_url: - # text += f" 链接: {source_url}" - # elif url: - # text += f" 链接: {url}" - # return text - # return "[小程序分享]" - # else: - # return "" - # except Exception as e: - # logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") - # return f"[处理失败的{segment.type}消息]" - - -@dataclass -class MessageProcessBase(Message): - """消息处理基类,用于处理中和发送中的消息""" - - def __init__( - self, - message_id: str, - chat_stream: 
"ChatStream", - bot_user_info: UserInfo, - message_segment: Optional[Seg] = None, - reply: Optional["MessageRecv"] = None, - thinking_start_time: float = 0, - timestamp: Optional[float] = None, - ): - # 调用父类初始化,传递时间戳 - super().__init__( - message_id=message_id, - timestamp=timestamp, - chat_stream=chat_stream, - user_info=bot_user_info, - message_segment=message_segment, - reply=reply, - ) - - # 处理状态相关属性 - self.thinking_start_time = thinking_start_time - self.thinking_time = 0 - - # def update_thinking_time(self) -> float: - # """更新思考时间""" - # self.thinking_time = round(time.time() - self.thinking_start_time, 2) - # return self.thinking_time - - # async def _process_single_segment(self, segment: Seg) -> str: - # """处理单个消息段 - - # Args: - # segment: 要处理的消息段 - - # Returns: - # str: 处理后的文本 - # """ - # try: - # if segment.type == "text": - # return segment.data # type: ignore - # elif segment.type == "image": - # # 如果是base64图片数据 - # if isinstance(segment.data, str): - # return await get_image_manager().get_image_description(segment.data) - # return "[图片,网卡了加载不出来]" - # elif segment.type == "emoji": - # if isinstance(segment.data, str): - # return await get_image_manager().get_emoji_tag(segment.data) - # return "[表情,网卡了加载不出来]" - # elif segment.type == "voice": - # if isinstance(segment.data, str): - # return await get_voice_text(segment.data) - # return "[发了一段语音,网卡了加载不出来]" - # elif segment.type == "at": - # return f"[@{segment.data}]" - # elif segment.type == "reply": - # if self.reply and hasattr(self.reply, "processed_plain_text"): - # # print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}") - # # print(f"reply: {self.reply}") - # return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore - # return "" - # else: - # return f"[{segment.type}:{str(segment.data)}]" - # except Exception as e: - # logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 
数据: {segment.data}") - # return f"[处理失败的{segment.type}消息]" - - # def _generate_detailed_text(self) -> str: - # """生成详细文本,包含时间和用户信息""" - # # time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time)) - # timestamp = self.message_info.time - # user_info = self.message_info.user_info - - # name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore - # return f"[{timestamp}],{name} 说:{self.processed_plain_text}\n" - - -@dataclass -class MessageSending(MessageProcessBase): - """发送状态的消息类""" - - def __init__( - self, - message_id: str, - chat_stream: "ChatStream", - bot_user_info: UserInfo, - sender_info: UserInfo | None, # 用来记录发送者信息 - message_segment: Seg, - display_message: str = "", - reply: Optional["MessageRecv"] = None, - is_head: bool = False, - is_emoji: bool = False, - thinking_start_time: float = 0, - apply_set_reply_logic: bool = False, - reply_to: Optional[str] = None, - selected_expressions: Optional[List[int]] = None, - ): - # 调用父类初始化 - super().__init__( - message_id=message_id, - chat_stream=chat_stream, - bot_user_info=bot_user_info, - message_segment=message_segment, - reply=reply, - thinking_start_time=thinking_start_time, - ) - - # 发送状态特有属性 - self.sender_info = sender_info - self.reply_to_message_id = reply.message_info.message_id if reply else None - self.is_head = is_head - self.is_emoji = is_emoji - self.apply_set_reply_logic = apply_set_reply_logic - - self.reply_to = reply_to - - # 用于显示发送内容与显示不一致的情况 - self.display_message = display_message - - self.interest_value = 0.0 - - self.selected_expressions = selected_expressions - - def build_reply(self): - """设置回复消息""" - if self.reply: - self.reply_to_message_id = self.reply.message_info.message_id - self.message_segment = Seg( - type="seglist", - data=[ - Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore - self.message_segment, - ], - ) - - async def process(self) -> None: - 
"""处理消息内容,生成纯文本和详细文本""" - if self.message_segment: - self.processed_plain_text = await self._process_message_segments(self.message_segment) - - # def to_dict(self): - # ret = super().to_dict() - # ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict() - # return ret - - # def is_private_message(self) -> bool: - # """判断是否为私聊消息""" - # return self.message_info.group_info is None or self.message_info.group_info.group_id is None - - -# @dataclass -# class MessageSet: -# """消息集合类,可以存储多个发送消息""" - -# def __init__(self, chat_stream: "ChatStream", message_id: str): -# self.chat_stream = chat_stream -# self.message_id = message_id -# self.messages: list[MessageSending] = [] -# self.time = round(time.time(), 3) # 保留3位小数 - -# def add_message(self, message: MessageSending) -> None: -# """添加消息到集合""" -# if not isinstance(message, MessageSending): -# raise TypeError("MessageSet只能添加MessageSending类型的消息") -# self.messages.append(message) -# self.messages.sort(key=lambda x: x.message_info.time) # type: ignore - -# def get_message_by_index(self, index: int) -> Optional[MessageSending]: -# """通过索引获取消息""" -# return self.messages[index] if 0 <= index < len(self.messages) else None - -# def get_message_by_time(self, target_time: float) -> Optional[MessageSending]: -# """获取最接近指定时间的消息""" -# if not self.messages: -# return None - -# left, right = 0, len(self.messages) - 1 -# while left < right: -# mid = (left + right) // 2 -# if self.messages[mid].message_info.time < target_time: # type: ignore -# left = mid + 1 -# else: -# right = mid - -# return self.messages[left] - -# def clear_messages(self) -> None: -# """清空所有消息""" -# self.messages.clear() - -# def remove_message(self, message: MessageSending) -> bool: -# """移除指定消息""" -# if message in self.messages: -# self.messages.remove(message) -# return True -# return False - -# def __str__(self) -> str: -# return f"MessageSet(id={self.message_id}, count={len(self.messages)})" - -# def __len__(self) -> int: -# return 
len(self.messages) - - -# def message_recv_from_dict(message_dict: dict) -> MessageRecv: -# return MessageRecv(message_dict) - - -# def message_from_db_dict(db_dict: dict) -> MessageRecv: -# """从数据库字典创建MessageRecv实例""" -# # 转换扁平的数据库字典为嵌套结构 -# message_info_dict = { -# "platform": db_dict.get("chat_info_platform"), -# "message_id": db_dict.get("message_id"), -# "time": db_dict.get("time"), -# "group_info": { -# "platform": db_dict.get("chat_info_group_platform"), -# "group_id": db_dict.get("chat_info_group_id"), -# "group_name": db_dict.get("chat_info_group_name"), -# }, -# "user_info": { -# "platform": db_dict.get("user_platform"), -# "user_id": db_dict.get("user_id"), -# "user_nickname": db_dict.get("user_nickname"), -# "user_cardname": db_dict.get("user_cardname"), -# }, -# } - -# processed_text = db_dict.get("processed_plain_text", "") - -# # 构建 MessageRecv 需要的字典 -# recv_dict = { -# "message_info": message_info_dict, -# "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段 -# "raw_message": None, # 数据库中未存储原始消息 -# "processed_plain_text": processed_text, -# } - -# # 创建 MessageRecv 实例 -# msg = MessageRecv(recv_dict) - -# # 从数据库字典中填充其他可选字段 -# msg.interest_value = db_dict.get("interest_value", 0.0) -# msg.is_mentioned = db_dict.get("is_mentioned") -# msg.priority_mode = db_dict.get("priority_mode", "interest") -# msg.priority_info = db_dict.get("priority_info") -# msg.is_emoji = db_dict.get("is_emoji", False) -# msg.is_picid = db_dict.get("is_picid", False) - -# return msg diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 9346259c..4170d4f5 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -1,13 +1,14 @@ -import asyncio -import traceback - from rich.traceback import install -from maim_message import Seg +from typing import Optional + +import asyncio + from src.common.message_server.api import get_global_api from 
src.common.logger import get_logger from src.common.database.database import get_db_session -from src.chat.message_receive.message_old import MessageSending +from src.chat.message_receive.message import SessionMessage +from src.common.data_models.message_component_data_model import ReplyComponent from src.chat.utils.utils import truncate_message from src.chat.utils.utils import calculate_typing_time @@ -21,267 +22,267 @@ _webui_chat_broadcaster = None # 虚拟群 ID 前缀(与 chat_routes.py 保持一致) VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_" +# TODO: 重构完成后完成webui相关 +# def get_webui_chat_broadcaster(): +# """获取 WebUI 聊天室广播器""" +# global _webui_chat_broadcaster +# if _webui_chat_broadcaster is None: +# try: +# from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM -def get_webui_chat_broadcaster(): - """获取 WebUI 聊天室广播器""" - global _webui_chat_broadcaster - if _webui_chat_broadcaster is None: - try: - from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM - - _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM) - except ImportError: - _webui_chat_broadcaster = (None, None) - return _webui_chat_broadcaster +# _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM) +# except ImportError: +# _webui_chat_broadcaster = (None, None) +# return _webui_chat_broadcaster -def is_webui_virtual_group(group_id: str) -> bool: - """检查是否是 WebUI 虚拟群""" - return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) +# def is_webui_virtual_group(group_id: str) -> bool: +# """检查是否是 WebUI 虚拟群""" +# return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX) -def parse_message_segments(segment) -> list: - """解析消息段,转换为 WebUI 可用的格式 +# def parse_message_segments(segment) -> list: +# """解析消息段,转换为 WebUI 可用的格式 - 参考 NapCat 适配器的消息解析逻辑 +# 参考 NapCat 适配器的消息解析逻辑 - Args: - segment: Seg 消息段对象 +# Args: +# segment: Seg 消息段对象 - Returns: - list: 消息段列表,每个元素为 {"type": "...", "data": ...} - """ +# Returns: +# list: 消息段列表,每个元素为 {"type": "...", "data": ...} +# """ - 
result = [] +# result = [] - if segment is None: - return result +# if segment is None: +# return result - if segment.type == "seglist": - # 处理消息段列表 - if segment.data: - for seg in segment.data: - result.extend(parse_message_segments(seg)) - elif segment.type == "text": - # 文本消息 - if segment.data: - result.append({"type": "text", "data": segment.data}) - elif segment.type == "image": - # 图片消息(base64) - if segment.data: - result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"}) - elif segment.type == "emoji": - # 表情包消息(base64) - if segment.data: - result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"}) - elif segment.type == "imageurl": - # 图片链接消息 - if segment.data: - result.append({"type": "image", "data": segment.data}) - elif segment.type == "face": - # 原生表情 - result.append({"type": "face", "data": segment.data}) - elif segment.type == "voice": - # 语音消息(base64) - if segment.data: - result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"}) - elif segment.type == "voiceurl": - # 语音链接 - if segment.data: - result.append({"type": "voice", "data": segment.data}) - elif segment.type == "video": - # 视频消息(base64) - if segment.data: - result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"}) - elif segment.type == "videourl": - # 视频链接 - if segment.data: - result.append({"type": "video", "data": segment.data}) - elif segment.type == "music": - # 音乐消息 - result.append({"type": "music", "data": segment.data}) - elif segment.type == "file": - # 文件消息 - result.append({"type": "file", "data": segment.data}) - elif segment.type == "reply": - # 回复消息 - result.append({"type": "reply", "data": segment.data}) - elif segment.type == "forward": - # 转发消息 - forward_items = [] - if segment.data: - for item in segment.data: - forward_items.append( - { - "content": parse_message_segments(item.get("message_segment", {})) - if isinstance(item, dict) - else [] - } - ) - result.append({"type": 
"forward", "data": forward_items}) - else: - # 未知类型,尝试作为文本处理 - if segment.data: - result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)}) +# if segment.type == "seglist": +# # 处理消息段列表 +# if segment.data: +# for seg in segment.data: +# result.extend(parse_message_segments(seg)) +# elif segment.type == "text": +# # 文本消息 +# if segment.data: +# result.append({"type": "text", "data": segment.data}) +# elif segment.type == "image": +# # 图片消息(base64) +# if segment.data: +# result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"}) +# elif segment.type == "emoji": +# # 表情包消息(base64) +# if segment.data: +# result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"}) +# elif segment.type == "imageurl": +# # 图片链接消息 +# if segment.data: +# result.append({"type": "image", "data": segment.data}) +# elif segment.type == "face": +# # 原生表情 +# result.append({"type": "face", "data": segment.data}) +# elif segment.type == "voice": +# # 语音消息(base64) +# if segment.data: +# result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"}) +# elif segment.type == "voiceurl": +# # 语音链接 +# if segment.data: +# result.append({"type": "voice", "data": segment.data}) +# elif segment.type == "video": +# # 视频消息(base64) +# if segment.data: +# result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"}) +# elif segment.type == "videourl": +# # 视频链接 +# if segment.data: +# result.append({"type": "video", "data": segment.data}) +# elif segment.type == "music": +# # 音乐消息 +# result.append({"type": "music", "data": segment.data}) +# elif segment.type == "file": +# # 文件消息 +# result.append({"type": "file", "data": segment.data}) +# elif segment.type == "reply": +# # 回复消息 +# result.append({"type": "reply", "data": segment.data}) +# elif segment.type == "forward": +# # 转发消息 +# forward_items = [] +# if segment.data: +# for item in segment.data: +# forward_items.append( +# { +# "content": 
parse_message_segments(item.get("message_segment", {})) +# if isinstance(item, dict) +# else [] +# } +# ) +# result.append({"type": "forward", "data": forward_items}) +# else: +# # 未知类型,尝试作为文本处理 +# if segment.data: +# result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)}) - return result +# return result -async def _send_message(message: MessageSending, show_log=True) -> bool: - """合并后的消息发送函数,包含WS发送和日志记录""" - message_preview = truncate_message(message.processed_plain_text, max_length=200) - platform = message.platform - group_id = message.session.group_id +# async def _send_message(message: MessageSending, show_log=True) -> bool: +# """合并后的消息发送函数,包含WS发送和日志记录""" +# message_preview = truncate_message(message.processed_plain_text, max_length=200) +# platform = message.platform +# group_id = message.session.group_id - try: - # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 - chat_manager, webui_platform = get_webui_chat_broadcaster() - is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id) +# try: +# # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 +# chat_manager, webui_platform = get_webui_chat_broadcaster() +# is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id) - if is_webui_message and chat_manager is not None: - # WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播 - import time - from src.config.config import global_config +# if is_webui_message and chat_manager is not None: +# # WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播 +# import time +# from src.config.config import global_config - # 解析消息段,获取富文本内容 - message_segments = parse_message_segments(message.message_segment) +# # 解析消息段,获取富文本内容 +# message_segments = parse_message_segments(message.message_segment) - # 判断消息类型 - # 如果只有一个文本段,使用简单的 text 类型 - # 否则使用 rich 类型,包含完整的消息段 - if len(message_segments) == 1 and message_segments[0].get("type") == "text": - message_type = "text" - segments = None - else: - message_type = "rich" - segments = message_segments +# # 
判断消息类型 +# # 如果只有一个文本段,使用简单的 text 类型 +# # 否则使用 rich 类型,包含完整的消息段 +# if len(message_segments) == 1 and message_segments[0].get("type") == "text": +# message_type = "text" +# segments = None +# else: +# message_type = "rich" +# segments = message_segments - await chat_manager.broadcast( - { - "type": "bot_message", - "content": message.processed_plain_text, - "message_type": message_type, - "segments": segments, # 富文本消息段 - "timestamp": time.time(), - "group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签 - "sender": { - "name": global_config.bot.nickname, - "avatar": None, - "is_bot": True, - }, - } - ) +# await chat_manager.broadcast( +# { +# "type": "bot_message", +# "content": message.processed_plain_text, +# "message_type": message_type, +# "segments": segments, # 富文本消息段 +# "timestamp": time.time(), +# "group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签 +# "sender": { +# "name": global_config.bot.nickname, +# "avatar": None, +# "is_bot": True, +# }, +# } +# ) - # 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库 - # 无需手动保存 +# # 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库 +# # 无需手动保存 - if show_log: - if is_webui_virtual_group(group_id): - logger.info(f"已将消息 '{message_preview}' 发往 WebUI 虚拟群 (平台: {platform})") - else: - logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室") - return True +# if show_log: +# if is_webui_virtual_group(group_id): +# logger.info(f"已将消息 '{message_preview}' 发往 WebUI 虚拟群 (平台: {platform})") +# else: +# logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室") +# return True - # Fallback 逻辑: 尝试通过 API Server 发送 - async def send_with_new_api(legacy_exception=None): - try: - from src.config.config import global_config +# # Fallback 逻辑: 尝试通过 API Server 发送 +# async def send_with_new_api(legacy_exception=None): +# try: +# from src.config.config import global_config - # 如果未开启 API Server,直接跳过 Fallback - if not global_config.maim_message.enable_api_server: - logger.debug("[API Server Fallback] API Server未开启,跳过fallback") - if legacy_exception: - raise legacy_exception - 
return False +# # 如果未开启 API Server,直接跳过 Fallback +# if not global_config.maim_message.enable_api_server: +# logger.debug("[API Server Fallback] API Server未开启,跳过fallback") +# if legacy_exception: +# raise legacy_exception +# return False - global_api = get_global_api() - extra_server = getattr(global_api, "extra_server", None) +# global_api = get_global_api() +# extra_server = getattr(global_api, "extra_server", None) - if not extra_server: - logger.warning("[API Server Fallback] extra_server不存在") - if legacy_exception: - raise legacy_exception - return False +# if not extra_server: +# logger.warning("[API Server Fallback] extra_server不存在") +# if legacy_exception: +# raise legacy_exception +# return False - if not extra_server.is_running(): - logger.warning("[API Server Fallback] extra_server未运行") - if legacy_exception: - raise legacy_exception - return False +# if not extra_server.is_running(): +# logger.warning("[API Server Fallback] extra_server未运行") +# if legacy_exception: +# raise legacy_exception +# return False - # Fallback: 使用极其简单的 Platform -> API Key 映射 - # 只有收到过该平台的消息,我们才知道该平台的 API Key,才能回传消息 - platform_map = getattr(global_api, "platform_map", {}) - logger.debug(f"[API Server Fallback] platform_map: {platform_map}, 目标平台: '{platform}'") - target_api_key = platform_map.get(platform) +# # Fallback: 使用极其简单的 Platform -> API Key 映射 +# # 只有收到过该平台的消息,我们才知道该平台的 API Key,才能回传消息 +# platform_map = getattr(global_api, "platform_map", {}) +# logger.debug(f"[API Server Fallback] platform_map: {platform_map}, 目标平台: '{platform}'") +# target_api_key = platform_map.get(platform) - if not target_api_key: - logger.warning(f"[API Server Fallback] 未找到平台'{platform}'的API Key映射") - if legacy_exception: - raise legacy_exception - return False +# if not target_api_key: +# logger.warning(f"[API Server Fallback] 未找到平台'{platform}'的API Key映射") +# if legacy_exception: +# raise legacy_exception +# return False - # 使用 MessageConverter 转换为 API 消息 - from maim_message import MessageConverter 
+# # 使用 MessageConverter 转换为 API 消息 +# from maim_message import MessageConverter - # 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异 - message_base = await message.to_maim_message() +# # 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异 +# message_base = await message.to_maim_message() - api_message = MessageConverter.to_api_send( - message=message_base, - api_key=target_api_key, - platform=platform, - ) +# api_message = MessageConverter.to_api_send( +# message=message_base, +# api_key=target_api_key, +# platform=platform, +# ) - # 直接调用 Server 的 send_message 接口,它会自动处理路由 - logger.debug("[API Server Fallback] 正在通过extra_server发送消息...") - results = await extra_server.send_message(api_message) - logger.debug(f"[API Server Fallback] 发送结果: {results}") +# # 直接调用 Server 的 send_message 接口,它会自动处理路由 +# logger.debug("[API Server Fallback] 正在通过extra_server发送消息...") +# results = await extra_server.send_message(api_message) +# logger.debug(f"[API Server Fallback] 发送结果: {results}") - # 检查是否有任何连接发送成功 - if any(results.values()): - if show_log: - logger.info( - f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})" - ) - return True - else: - logger.warning(f"[API Server Fallback] 没有连接发送成功, results={results}") - except Exception as e: - logger.error(f"[API Server Fallback] 发生异常: {e}") - import traceback +# # 检查是否有任何连接发送成功 +# if any(results.values()): +# if show_log: +# logger.info( +# f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})" +# ) +# return True +# else: +# logger.warning(f"[API Server Fallback] 没有连接发送成功, results={results}") +# except Exception as e: +# logger.error(f"[API Server Fallback] 发生异常: {e}") +# import traceback - logger.debug(traceback.format_exc()) +# logger.debug(traceback.format_exc()) - # 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常 - if legacy_exception: - raise legacy_exception - return False +# # 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常 +# if legacy_exception: +# raise 
legacy_exception +# return False - try: - message_base = await message.to_maim_message() - send_result = await get_global_api().send_message(message_base) - if send_result: - if show_log: - logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'") - return True - else: - # Legacy API 返回 False (发送失败但未报错),尝试 Fallback - fallback_result = await send_with_new_api() - if fallback_result and show_log: - # Fallback成功的日志已在send_with_new_api中打印 - pass - return fallback_result +# try: +# message_base = await message.to_maim_message() +# send_result = await get_global_api().send_message(message_base) +# if send_result: +# if show_log: +# logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'") +# return True +# else: +# # Legacy API 返回 False (发送失败但未报错),尝试 Fallback +# fallback_result = await send_with_new_api() +# if fallback_result and show_log: +# # Fallback成功的日志已在send_with_new_api中打印 +# pass +# return fallback_result - except Exception as legacy_e: - # Legacy API 抛出异常,尝试 Fallback - # 如果 Fallback 也失败,将重新抛出 legacy_e - return await send_with_new_api(legacy_exception=legacy_e) +# except Exception as legacy_e: +# # Legacy API 抛出异常,尝试 Fallback +# # 如果 Fallback 也失败,将重新抛出 legacy_e +# return await send_with_new_api(legacy_exception=legacy_e) - except Exception as e: - logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}") - traceback.print_exc() - raise e # 重新抛出其他异常 +# except Exception as e: +# logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}") +# traceback.print_exc() +# raise e # 重新抛出其他异常 class UniversalMessageSender: @@ -291,21 +292,26 @@ class UniversalMessageSender: pass async def send_message( - self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True + self, + message: "SessionMessage", + typing: bool = False, + set_reply: bool = False, + reply_message_id: Optional[str] = None, + storage_message: bool = True, + show_log: bool = True, ): """ 处理、发送并存储一条消息。 参数: - 
message: MessageSending 对象,待发送的消息。 + message: MessageSession 对象,待发送的消息。 typing: 是否模拟打字等待。 + set_reply: 是否构建回复引用消息。 + 用法: - typing=True 时,发送前会有打字等待。 """ - if not message.session: - logger.error("消息缺少 session,无法发送") - raise ValueError("消息缺少 session,无法发送") if not message.message_id: logger.error("消息缺少 message_id,无法发送") raise ValueError("消息缺少 message_id,无法发送") @@ -315,66 +321,62 @@ class UniversalMessageSender: try: if set_reply: - message.build_reply() - logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...") + if not reply_message_id: + raise ValueError("set_reply=True 时必须提供 reply_message_id") + message.raw_message.components.insert(0, ReplyComponent(reply_message_id)) - from src.core.event_bus import event_bus - from src.chat.event_helpers import build_event_message - from src.core.types import EventType + # TODO: fix + # from src.core.event_bus import event_bus + # from src.chat.event_helpers import build_event_message + # from src.core.types import EventType - _event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id) - continue_flag, modified_message = await event_bus.emit( - EventType.POST_SEND_PRE_PROCESS, _event_msg - ) - if not continue_flag: - logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") - return False - if modified_message: - if modified_message._modify_flags.modify_message_segments: - message.message_segment = Seg(type="seglist", data=modified_message.message_segments) - if modified_message._modify_flags.modify_plain_text: - logger.warning(f"[{chat_id}] 插件修改了消息的纯文本内容,可能导致此内容被覆盖。") - message.processed_plain_text = modified_message.plain_text + # _event_msg = build_event_message(EventType.POST_SEND_PRE_PROCESS, message=message, stream_id=chat_id) + # continue_flag, modified_message = await event_bus.emit(EventType.POST_SEND_PRE_PROCESS, _event_msg) + # if not continue_flag: + # logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") + # 
return False + # if modified_message: + # if modified_message._modify_flags.modify_message_segments: + # message.message_segment = Seg(type="seglist", data=modified_message.message_segments) + # if modified_message._modify_flags.modify_plain_text: + # logger.warning(f"[{chat_id}] 插件修改了消息的纯文本内容,可能导致此内容被覆盖。") + # message.processed_plain_text = modified_message.plain_text await message.process() - _event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id) - continue_flag, modified_message = await event_bus.emit( - EventType.POST_SEND, _event_msg - ) - if not continue_flag: - logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") - return False - if modified_message: - if modified_message._modify_flags.modify_message_segments: - message.message_segment = Seg(type="seglist", data=modified_message.message_segments) - if modified_message._modify_flags.modify_plain_text: - message.processed_plain_text = modified_message.plain_text + # TODO: fix + # _event_msg = build_event_message(EventType.POST_SEND, message=message, stream_id=chat_id) + # continue_flag, modified_message = await event_bus.emit(EventType.POST_SEND, _event_msg) + # if not continue_flag: + # logger.info(f"[{chat_id}] 消息发送被插件取消: {str(message.message_segment)[:100]}...") + # return False + # if modified_message: + # if modified_message._modify_flags.modify_message_segments: + # message.message_segment = Seg(type="seglist", data=modified_message.message_segments) + # if modified_message._modify_flags.modify_plain_text: + # message.processed_plain_text = modified_message.plain_text if typing: typing_time = calculate_typing_time( - input_string=message.processed_plain_text, - thinking_start_time=message.thinking_start_time, + input_string=message.processed_plain_text, # type: ignore is_emoji=message.is_emoji, ) await asyncio.sleep(typing_time) - sent_msg = await _send_message(message, show_log=show_log) - if not sent_msg: - return False + # sent_msg = await 
_send_message(message, show_log=show_log) + # if not sent_msg: + # return False - _event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id) - continue_flag, modified_message = await event_bus.emit( - EventType.AFTER_SEND, _event_msg - ) - if not continue_flag: - logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...") - return True - if modified_message: - if modified_message._modify_flags.modify_message_segments: - message.message_segment = Seg(type="seglist", data=modified_message.message_segments) - if modified_message._modify_flags.modify_plain_text: - message.processed_plain_text = modified_message.plain_text + # _event_msg = build_event_message(EventType.AFTER_SEND, message=message, stream_id=chat_id) + # continue_flag, modified_message = await event_bus.emit(EventType.AFTER_SEND, _event_msg) + # if not continue_flag: + # logger.info(f"[{chat_id}] 消息发送后续处理被插件取消: {str(message.message_segment)[:100]}...") + # return True + # if modified_message: + # if modified_message._modify_flags.modify_message_segments: + # message.message_segment = Seg(type="seglist", data=modified_message.message_segments) + # if modified_message._modify_flags.modify_plain_text: + # message.processed_plain_text = modified_message.plain_text if storage_message: with get_db_session() as db_session: diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 86ac7da4..f4d92eeb 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,3 +1,4 @@ +# TODO: 完全删除此文件,将所有方法该合并的合并。 import time import random import re @@ -19,7 +20,7 @@ install(extra_lines=3) logger = get_logger("chat_message_builder") -def replace_user_references( +def replace_user_references( # TODO: 整合此函数 content: Optional[str], platform: str, name_resolver: Optional[Callable[[str, str], str]] = None, @@ -262,102 +263,103 @@ def get_actions_by_timestamp_with_chat_inclusive( return 
[action.model_dump() for action in actions] -def get_raw_msg_by_timestamp_random( - timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" -) -> List[DatabaseMessages]: - """ - 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 - """ - # 获取所有消息,只取chat_id字段 - all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end) - if not all_msgs: - return [] - # 随机选一条 - msg = random.choice(all_msgs) - chat_id = msg.chat_id - timestamp_start = msg.time - # 用 chat_id 获取该聊天在指定时间戳范围内的消息 - return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") +# TODO: 整合为统一函数,由参数控制(仿照build_readable_message) +# def get_raw_msg_by_timestamp_random( +# timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" +# ) -> List[DatabaseMessages]: +# """ +# 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 +# """ +# # 获取所有消息,只取chat_id字段 +# all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end) +# if not all_msgs: +# return [] +# # 随机选一条 +# msg = random.choice(all_msgs) +# chat_id = msg.chat_id +# timestamp_start = msg.time +# # 用 chat_id 获取该聊天在指定时间戳范围内的消息 +# return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") -def get_raw_msg_by_timestamp_with_users( - timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" -) -> List[DatabaseMessages]: - """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 - limit: 限制返回的消息数量,0为不限制 - limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 - """ - filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}} - # 只有当 limit 为 0 时才应用外部 sort - sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) +# def get_raw_msg_by_timestamp_with_users( +# timestamp_start: 
float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" +# ) -> List[DatabaseMessages]: +# """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 +# limit: 限制返回的消息数量,0为不限制 +# limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录。默认为 'latest'。 +# """ +# filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}} +# # 只有当 limit 为 0 时才应用外部 sort +# sort_order = [("time", 1)] if limit == 0 else None +# return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[DatabaseMessages]: - """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 - limit: 限制返回的消息数量,0为不限制 - """ - filter_query = {"time": {"$lt": timestamp}} - sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) +# def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[DatabaseMessages]: +# """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 +# limit: 限制返回的消息数量,0为不限制 +# """ +# filter_query = {"time": {"$lt": timestamp}} +# sort_order = [("time", 1)] +# return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def get_raw_msg_before_timestamp_with_chat( - chat_id: str, timestamp: float, limit: int = 0, filter_intercept_message_level: Optional[int] = None -) -> List[DatabaseMessages]: - """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 - limit: 限制返回的消息数量,0为不限制 - """ - filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} - sort_order = [("time", 1)] - return find_messages( - message_filter=filter_query, - sort=sort_order, - limit=limit, - filter_intercept_message_level=filter_intercept_message_level, - ) +# def get_raw_msg_before_timestamp_with_chat( +# chat_id: str, timestamp: float, limit: int = 0, filter_intercept_message_level: Optional[int] = None +# ) -> List[DatabaseMessages]: +# """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 +# limit: 限制返回的消息数量,0为不限制 +# """ +# 
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}} +# sort_order = [("time", 1)] +# return find_messages( +# message_filter=filter_query, +# sort=sort_order, +# limit=limit, +# filter_intercept_message_level=filter_intercept_message_level, +# ) -def get_raw_msg_before_timestamp_with_users( - timestamp: float, person_ids: List[str], limit: int = 0 -) -> List[DatabaseMessages]: - """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 - limit: 限制返回的消息数量,0为不限制 - """ - filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}} - sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) +# def get_raw_msg_before_timestamp_with_users( +# timestamp: float, person_ids: List[str], limit: int = 0 +# ) -> List[DatabaseMessages]: +# """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 +# limit: 限制返回的消息数量,0为不限制 +# """ +# filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}} +# sort_order = [("time", 1)] +# return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: - """ - 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 - 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 - """ - # 确定有效的结束时间戳 - _timestamp_end = timestamp_end if timestamp_end is not None else time.time() +# def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: +# """ +# 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 +# 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 +# """ +# # 确定有效的结束时间戳 +# _timestamp_end = timestamp_end if timestamp_end is not None else time.time() - # 确保 timestamp_start < _timestamp_end - if timestamp_start >= _timestamp_end: - # logger.warning(f"timestamp_start ({timestamp_start}) must be less than _timestamp_end ({_timestamp_end}). 
Returning 0.") - return 0 # 起始时间大于等于结束时间,没有新消息 +# # 确保 timestamp_start < _timestamp_end +# if timestamp_start >= _timestamp_end: +# # logger.warning(f"timestamp_start ({timestamp_start}) must be less than _timestamp_end ({_timestamp_end}). Returning 0.") +# return 0 # 起始时间大于等于结束时间,没有新消息 - filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}} - return count_messages(message_filter=filter_query) +# filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}} +# return count_messages(message_filter=filter_query) -def num_new_messages_since_with_users( - chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str] -) -> int: - """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息""" - if not person_ids: # 保持空列表检查 - return 0 - filter_query = { - "chat_id": chat_id, - "time": {"$gt": timestamp_start, "$lt": timestamp_end}, - "user_id": {"$in": person_ids}, - } - return count_messages(message_filter=filter_query) +# def num_new_messages_since_with_users( +# chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str] +# ) -> int: +# """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息""" +# if not person_ids: # 保持空列表检查 +# return 0 +# filter_query = { +# "chat_id": chat_id, +# "time": {"$gt": timestamp_start, "$lt": timestamp_end}, +# "user_id": {"$in": person_ids}, +# } +# return count_messages(message_filter=filter_query) def _build_readable_messages_internal( @@ -563,40 +565,41 @@ def _build_readable_messages_internal( ) -def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: - # sourcery skip: use-contextlib-suppress - """ - 构建图片映射信息字符串,显示图片的具体描述内容 +# 由MessageUtils._extract_pictures_from_message替代 +# def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: +# # sourcery skip: use-contextlib-suppress +# """ +# 构建图片映射信息字符串,显示图片的具体描述内容 - Args: - pic_id_mapping: 图片ID到显示名称的映射字典 +# Args: +# pic_id_mapping: 图片ID到显示名称的映射字典 - Returns: - 格式化的映射信息字符串 - """ - if not pic_id_mapping: - return 
"" +# Returns: +# 格式化的映射信息字符串 +# """ +# if not pic_id_mapping: +# return "" - mapping_lines = [] +# mapping_lines = [] - # 按图片编号排序 - sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", ""))) +# # 按图片编号排序 +# sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", ""))) - for pic_id, display_name in sorted_items: - # 从数据库中获取图片描述 - description = "内容正在阅读,请稍等" - try: - with get_db_session() as session: - image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None - if image and image.description: - description = image.description - except Exception: - # 如果查询失败,保持默认描述 - pass +# for pic_id, display_name in sorted_items: +# # 从数据库中获取图片描述 +# description = "内容正在阅读,请稍等" +# try: +# with get_db_session() as session: +# image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None +# if image and image.description: +# description = image.description +# except Exception: +# # 如果查询失败,保持默认描述 +# pass - mapping_lines.append(f"[{display_name}] 的内容:{description}") +# mapping_lines.append(f"[{display_name}] 的内容:{description}") - return "\n".join(mapping_lines) +# return "\n".join(mapping_lines) def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "relative") -> str: @@ -646,68 +649,69 @@ def build_readable_actions(actions: List[DatabaseActionRecords], mode: str = "re return "\n".join(output_lines) -async def build_readable_messages_with_list( - messages: List[DatabaseMessages], - replace_bot_name: bool = True, - timestamp_mode: str = "relative", - truncate: bool = False, - pic_single: bool = False, -) -> Tuple[str, List[Tuple[float, str, str]]]: - """ - 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 - 允许通过参数控制格式化行为。 - """ - formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( - messages, - replace_bot_name, - timestamp_mode, - truncate, - pic_id_mapping=None, - pic_counter=1, - show_pic=True, - message_id_list=None, - pic_single=pic_single, - 
long_time_notice=False, - ) +# 由MessageUtils里面的build_readable_message替代 +# async def build_readable_messages_with_list( +# messages: List[DatabaseMessages], +# replace_bot_name: bool = True, +# timestamp_mode: str = "relative", +# truncate: bool = False, +# pic_single: bool = False, +# ) -> Tuple[str, List[Tuple[float, str, str]]]: +# """ +# 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 +# 允许通过参数控制格式化行为。 +# """ +# formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( +# messages, +# replace_bot_name, +# timestamp_mode, +# truncate, +# pic_id_mapping=None, +# pic_counter=1, +# show_pic=True, +# message_id_list=None, +# pic_single=pic_single, +# long_time_notice=False, +# ) - if not pic_single: - if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): - formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" +# if not pic_single: +# if pic_mapping_info := build_pic_mapping_info(pic_id_mapping): +# formatted_string = f"{pic_mapping_info}\n\n{formatted_string}" - return formatted_string, details_list +# return formatted_string, details_list +# 由MessageUtils里面的build_readable_message替代 +# def build_readable_messages_with_id( +# messages: List[DatabaseMessages], +# replace_bot_name: bool = True, +# timestamp_mode: str = "relative", +# read_mark: float = 0.0, +# truncate: bool = False, +# show_actions: bool = False, +# show_pic: bool = True, +# remove_emoji_stickers: bool = False, +# pic_single: bool = False, +# ) -> Tuple[str, List[Tuple[str, DatabaseMessages]]]: +# """ +# 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 +# 允许通过参数控制格式化行为。 +# """ +# message_id_list = assign_message_ids(messages) -def build_readable_messages_with_id( - messages: List[DatabaseMessages], - replace_bot_name: bool = True, - timestamp_mode: str = "relative", - read_mark: float = 0.0, - truncate: bool = False, - show_actions: bool = False, - show_pic: bool = True, - remove_emoji_stickers: bool = False, - pic_single: bool = False, -) -> Tuple[str, List[Tuple[str, 
DatabaseMessages]]]: - """ - 将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。 - 允许通过参数控制格式化行为。 - """ - message_id_list = assign_message_ids(messages) +# formatted_string = build_readable_messages( +# messages=messages, +# replace_bot_name=replace_bot_name, +# timestamp_mode=timestamp_mode, +# truncate=truncate, +# show_actions=show_actions, +# show_pic=show_pic, +# read_mark=read_mark, +# message_id_list=message_id_list, +# remove_emoji_stickers=remove_emoji_stickers, +# pic_single=pic_single, +# ) - formatted_string = build_readable_messages( - messages=messages, - replace_bot_name=replace_bot_name, - timestamp_mode=timestamp_mode, - truncate=truncate, - show_actions=show_actions, - show_pic=show_pic, - read_mark=read_mark, - message_id_list=message_id_list, - remove_emoji_stickers=remove_emoji_stickers, - pic_single=pic_single, - ) - - return formatted_string, message_id_list +# return formatted_string, message_id_list def build_readable_messages( @@ -903,111 +907,112 @@ def build_readable_messages( return "".join(result_parts) -async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: bool = False) -> str: - """ - 构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。 - 处理 回复 和 @ 字段,将bbb映射为匿名占位符。 - """ - if not messages: - logger.warning("没有消息,无法构建匿名消息") - return "" +# 由MessageUtils里面的build_readable_message替代 +# async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: bool = False) -> str: +# """ +# 构建匿名可读消息,将不同人的名称转为唯一占位符(A、B、C...),bot自己用SELF。 +# 处理 回复 和 @ 字段,将bbb映射为匿名占位符。 +# """ +# if not messages: +# logger.warning("没有消息,无法构建匿名消息") +# return "" - person_map = {} - current_char = ord("A") - output_lines = [] +# person_map = {} +# current_char = ord("A") +# output_lines = [] - # 图片ID映射字典 - pic_id_mapping = {} - pic_counter = 1 +# # 图片ID映射字典 +# pic_id_mapping = {} +# pic_counter = 1 - def process_pic_ids(content: str) -> str: - """处理内容中的图片ID,将其替换为[图片x]格式""" - nonlocal pic_counter +# def process_pic_ids(content: str) -> str: +# 
"""处理内容中的图片ID,将其替换为[图片x]格式""" +# nonlocal pic_counter - # 匹配 [picid:xxxxx] 格式 - pic_pattern = r"\[picid:([^\]]+)\]" +# # 匹配 [picid:xxxxx] 格式 +# pic_pattern = r"\[picid:([^\]]+)\]" - def replace_pic_id(match): - nonlocal pic_counter - pic_id = match.group(1) +# def replace_pic_id(match): +# nonlocal pic_counter +# pic_id = match.group(1) - if pic_id not in pic_id_mapping: - pic_id_mapping[pic_id] = f"图片{pic_counter}" - pic_counter += 1 +# if pic_id not in pic_id_mapping: +# pic_id_mapping[pic_id] = f"图片{pic_counter}" +# pic_counter += 1 - return f"[{pic_id_mapping[pic_id]}]" +# return f"[{pic_id_mapping[pic_id]}]" - return re.sub(pic_pattern, replace_pic_id, content) +# return re.sub(pic_pattern, replace_pic_id, content) - def get_anon_name(platform, user_id): - # print(f"get_anon_name: platform:{platform}, user_id:{user_id}") - # print(f"global_config.bot.qq_account:{global_config.bot.qq_account}") +# def get_anon_name(platform, user_id): +# # print(f"get_anon_name: platform:{platform}, user_id:{user_id}") +# # print(f"global_config.bot.qq_account:{global_config.bot.qq_account}") - if (platform == "qq" and user_id == global_config.bot.qq_account) or ( - platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "") - ): - # print("SELF11111111111111") - return "SELF" - try: - person_id = get_person_id(platform, user_id) - except Exception as _e: - person_id = None - if not person_id: - return "?" - if person_id not in person_map: - nonlocal current_char - person_map[person_id] = chr(current_char) - current_char += 1 - return person_map[person_id] +# if (platform == "qq" and user_id == global_config.bot.qq_account) or ( +# platform == "telegram" and user_id == getattr(global_config.bot, "telegram_account", "") +# ): +# # print("SELF11111111111111") +# return "SELF" +# try: +# person_id = get_person_id(platform, user_id) +# except Exception as _e: +# person_id = None +# if not person_id: +# return "?" 
+# if person_id not in person_map: +# nonlocal current_char +# person_map[person_id] = chr(current_char) +# current_char += 1 +# return person_map[person_id] - for i, msg in enumerate(messages): - try: - platform = msg.chat_info.platform - user_id = msg.user_info.user_id - content = msg.display_message or msg.processed_plain_text or "" +# for i, msg in enumerate(messages): +# try: +# platform = msg.chat_info.platform +# user_id = msg.user_info.user_id +# content = msg.display_message or msg.processed_plain_text or "" - # 处理图片ID - content = process_pic_ids(content) +# # 处理图片ID +# content = process_pic_ids(content) - anon_name = get_anon_name(platform, user_id) - # print(f"anon_name:{anon_name}") +# anon_name = get_anon_name(platform, user_id) +# # print(f"anon_name:{anon_name}") - # 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器 - def anon_name_resolver(platform: str, user_id: str) -> str: - try: - return get_anon_name(platform, user_id) - except Exception: - return "?" +# # 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器 +# def anon_name_resolver(platform: str, user_id: str) -> str: +# try: +# return get_anon_name(platform, user_id) +# except Exception: +# return "?" 
- content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False) +# content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False) - # 构建消息头,如果启用show_ids则添加序号 - if show_ids: - header = f"[{i + 1}] {anon_name}说 " - else: - header = f"{anon_name}说 " +# # 构建消息头,如果启用show_ids则添加序号 +# if show_ids: +# header = f"[{i + 1}] {anon_name}说 " +# else: +# header = f"{anon_name}说 " - output_lines.append(header) - stripped_line = content.strip() - if stripped_line: - if stripped_line.endswith("。"): - stripped_line = stripped_line[:-1] - output_lines.append(f"{stripped_line}") - # print(f"output_lines:{output_lines}") - output_lines.append("\n") - except Exception: - continue +# output_lines.append(header) +# stripped_line = content.strip() +# if stripped_line: +# if stripped_line.endswith("。"): +# stripped_line = stripped_line[:-1] +# output_lines.append(f"{stripped_line}") +# # print(f"output_lines:{output_lines}") +# output_lines.append("\n") +# except Exception: +# continue - # 在最前面添加图片映射信息 - final_output_lines = [] - pic_mapping_info = build_pic_mapping_info(pic_id_mapping) - if pic_mapping_info: - final_output_lines.append(pic_mapping_info) - final_output_lines.append("\n\n") +# # 在最前面添加图片映射信息 +# final_output_lines = [] +# pic_mapping_info = build_pic_mapping_info(pic_id_mapping) +# if pic_mapping_info: +# final_output_lines.append(pic_mapping_info) +# final_output_lines.append("\n\n") - final_output_lines.extend(output_lines) - formatted_string = "".join(final_output_lines).strip() - return formatted_string +# final_output_lines.extend(output_lines) +# formatted_string = "".join(final_output_lines).strip() +# return formatted_string async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 81097184..4f9e0362 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -523,7 +523,7 @@ def process_llm_response(text: 
str, enable_splitter: bool = True, enable_chinese def calculate_typing_time( input_string: str, - thinking_start_time: float, + # thinking_start_time: float, chinese_time: float = 0.3, english_time: float = 0.15, is_emoji: bool = False, @@ -556,8 +556,8 @@ def calculate_typing_time( if is_emoji: total_time = 1 - if time.time() - thinking_start_time > 10: - total_time = 1 + # if time.time() - thinking_start_time > 10: + # total_time = 1 # print(f"thinking_start_time:{thinking_start_time}") # print(f"nowtime:{time.time()}") diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 63e6bdce..15d422b9 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,7 +1,7 @@ -import json -from dataclasses import dataclass +# import json +# from dataclasses import dataclass -from . import BaseDataModel +# from . import BaseDataModel # @dataclass @@ -208,33 +208,33 @@ from . import BaseDataModel # } -@dataclass(init=False) -class DatabaseActionRecords(BaseDataModel): - def __init__( - self, - action_id: str, - time: float, - action_name: str, - action_data: str, - action_done: bool, - action_build_into_prompt: bool, - action_prompt_display: str, - chat_id: str, - chat_info_stream_id: str, - chat_info_platform: str, - action_reasoning: str, - ): - self.action_id = action_id - self.time = time - self.action_name = action_name - if isinstance(action_data, str): - self.action_data = json.loads(action_data) - else: - raise ValueError("action_data must be a JSON string") - self.action_done = action_done - self.action_build_into_prompt = action_build_into_prompt - self.action_prompt_display = action_prompt_display - self.chat_id = chat_id - self.chat_info_stream_id = chat_info_stream_id - self.chat_info_platform = chat_info_platform - self.action_reasoning = action_reasoning +# @dataclass(init=False) +# class DatabaseActionRecords(BaseDataModel): +# def __init__( +# self, 
+# action_id: str, +# time: float, +# action_name: str, +# action_data: str, +# action_done: bool, +# action_build_into_prompt: bool, +# action_prompt_display: str, +# chat_id: str, +# chat_info_stream_id: str, +# chat_info_platform: str, +# action_reasoning: str, +# ): +# self.action_id = action_id +# self.time = time +# self.action_name = action_name +# if isinstance(action_data, str): +# self.action_data = json.loads(action_data) +# else: +# raise ValueError("action_data must be a JSON string") +# self.action_done = action_done +# self.action_build_into_prompt = action_build_into_prompt +# self.action_prompt_display = action_prompt_display +# self.chat_id = chat_id +# self.chat_info_stream_id = chat_info_stream_id +# self.chat_info_platform = chat_info_platform +# self.action_reasoning = action_reasoning diff --git a/src/common/data_models/info_data_model.py b/src/common/data_models/info_data_model.py index ce0781e1..672953b0 100644 --- a/src/common/data_models/info_data_model.py +++ b/src/common/data_models/info_data_model.py @@ -1,27 +1,28 @@ -from dataclasses import dataclass, field -from typing import Optional, Dict, TYPE_CHECKING -from . import BaseDataModel +# from dataclasses import dataclass, field +# from typing import Optional, Dict, TYPE_CHECKING +# from . 
import BaseDataModel -if TYPE_CHECKING: - from .database_data_model import DatabaseMessages - from src.core.types import ActionInfo +# if TYPE_CHECKING: +# from .database_data_model import DatabaseMessages +# from src.core.types import ActionInfo + + +# # @dataclass +# # class TargetPersonInfo(BaseDataModel): +# # platform: str = field(default_factory=str) +# # user_id: str = field(default_factory=str) +# # user_nickname: str = field(default_factory=str) +# # person_id: Optional[str] = None +# # person_name: Optional[str] = None # @dataclass -# class TargetPersonInfo(BaseDataModel): -# platform: str = field(default_factory=str) -# user_id: str = field(default_factory=str) -# user_nickname: str = field(default_factory=str) -# person_id: Optional[str] = None -# person_name: Optional[str] = None - - -@dataclass -class ActionPlannerInfo(BaseDataModel): - action_type: str = field(default_factory=str) - reasoning: Optional[str] = None - action_data: Optional[Dict] = None - action_message: Optional["DatabaseMessages"] = None - available_actions: Optional[Dict[str, "ActionInfo"]] = None - loop_start_time: Optional[float] = None - action_reasoning: Optional[str] = None +# class ActionPlannerInfo(BaseDataModel): +# action_type: str = field(default_factory=str) +# reasoning: Optional[str] = None +# action_data: Optional[Dict] = None +# action_message: Optional["DatabaseMessages"] = None +# available_actions: Optional[Dict[str, "ActionInfo"]] = None +# loop_start_time: Optional[float] = None +# action_reasoning: Optional[str] = None +# TODO: 重构 \ No newline at end of file diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py index 68068cda..567eefc6 100644 --- a/src/common/data_models/llm_data_model.py +++ b/src/common/data_models/llm_data_model.py @@ -1,22 +1,23 @@ -from dataclasses import dataclass -from typing import Optional, List, TYPE_CHECKING, Dict, Any +# from dataclasses import dataclass +# from typing import Optional, List, 
TYPE_CHECKING, Dict, Any -from . import BaseDataModel +# from . import BaseDataModel -if TYPE_CHECKING: - from src.common.data_models.message_data_model import ReplySetModel - from src.llm_models.payload_content.tool_option import ToolCall +# if TYPE_CHECKING: +# from src.common.data_models.message_data_model import ReplySetModel +# from src.llm_models.payload_content.tool_option import ToolCall -@dataclass -class LLMGenerationDataModel(BaseDataModel): - content: Optional[str] = None - reasoning: Optional[str] = None - model: Optional[str] = None - tool_calls: Optional[List["ToolCall"]] = None - prompt: Optional[str] = None - selected_expressions: Optional[List[int]] = None - reply_set: Optional["ReplySetModel"] = None - timing: Optional[Dict[str, Any]] = None - processed_output: Optional[List[str]] = None - timing_logs: Optional[List[str]] = None +# @dataclass +# class LLMGenerationDataModel(BaseDataModel): +# content: Optional[str] = None +# reasoning: Optional[str] = None +# model: Optional[str] = None +# tool_calls: Optional[List["ToolCall"]] = None +# prompt: Optional[str] = None +# selected_expressions: Optional[List[int]] = None +# reply_set: Optional["ReplySetModel"] = None +# timing: Optional[Dict[str, Any]] = None +# processed_output: Optional[List[str]] = None +# timing_logs: Optional[List[str]] = None +# TODO: 重构 \ No newline at end of file diff --git a/代码备忘.md b/代码备忘.md new file mode 100644 index 00000000..0d806a18 --- /dev/null +++ b/代码备忘.md @@ -0,0 +1,33 @@ +# 代码备忘 +- [ ] 检查EmojiManager的replace_an_emoji_by_llm传入的emoji是否真的是没有注册到db的 +- [ ] According to a comment, MaiMBot's check_types() accesses format_info.accept_format without None check +- [ ] 如果需要更多的消息格式支持,更新列表如下: + - [ ] `src/common/utils/utils_message.py`中的`_parse_maim_message_segment_to_component`函数 + - [ ] `src/common/data_models/message_component_model.py`中: + - [ ] 增加新的消息组件 + - [ ] 看情况修改`StandardMessageComponents`的内容 + - [ ] `MessageSequence`的`_dict_2_item`和`_item_2_dict`函数 +- [ ] 
**取消了从chat_manager获取ChatSession时候的deepcopy,看看会不会有问题** + +# 迁移脚本备忘 +- [ ] 迁移env到新版的bot_config管理 +- [ ] 对于旧的消息,需要重新计算其Hash(md5 -> sha256),做好映射防止消息丢失 +- [ ] PersonInfo的group_nickname名字改为group_cardname,做好映射防止数据丢失,同时存储的方式从`[{"group_id": str, "group_nick_name": str}]` -> `[{"group_id": str, "group_cardname": str}]` +- [ ] Expression中的`up_content`被移除了 +- [ ] Jargon的chat_id(session_id_list,格式为`[["session_id", session_count]]`)现在改为session_id_dict(格式为`{"session_id": session_count}`),做好映射防止数据丢失 + +# 插件开发备忘 +- [ ] 求各位插件开发者不要在Dict里面塞一堆乱七八糟的东西,免得数据库存储的时候一团糟 + +# Hack备忘 +- [ ] 对于不符合内容审查要求的表情包,无法注册到数据库内,因此面对相同的非法表情包时,会导致反复识别(再次识别时仍有成功注册的可能)。 + - [ ] 考虑到数据库记录表情包不合规判定有大模型误判的风险,因此保留现有的无法注册的情况,在再次遇到的时候重新识别。 +- [ ] 目前在匿名化build message的时候,如果一个被回复的消息包含了一个转发消息组件,那么这个转发消息组件中的用户信息是不会被匿名化的,后续需要修复这个问题。(有时候感觉用正则是对的) + - [ ] 可以考虑将消息保存的时候就将消息中的用户信息匿名化,这样在后续的处理过程中就不需要担心匿名化的问题了,同时也可以避免在build message的时候进行复杂的递归处理,同时还要保存匿名映射表。 + +# 计算备忘 +- [ ] emoji的emotion比较是基于编辑距离的,考虑更换为基于语义的比较(比如使用emoji的embedding进行比较),以提高准确性和鲁棒性 +- [ ] expression的相似度比较是基于LCS的(Ratcliff-Obershelp算法),考虑更换为基于语义的比较(比如使用embedding进行比较),以提高准确性和鲁棒性 +- [ ] 为了保持代码的简洁性,HFC在任何情况下都会初始化ExpressionReflector,ExpressionLearner,JargonMiner实例,不论配置文件中是否在此聊天流启用了它们。 + - [ ] 可优化方向:将其置为Optional,在不启用的情况下不进行初始化 + - [ ] 当配置文件重载时,重新分析所有启用判定,所有HFC进行并行检查,将启用的进行实例化;不启用的实例则移除引用,释放内存。 \ No newline at end of file