diff --git a/pytests/message_test/session_message_test.py b/pytests/message_test/session_message_test.py index 14e362b4..a48a3f8b 100644 --- a/pytests/message_test/session_message_test.py +++ b/pytests/message_test/session_message_test.py @@ -19,7 +19,6 @@ if TYPE_CHECKING: AtComponent, ReplyComponent, ForwardNodeComponent, - StandardMessageComponents, ) @@ -122,6 +121,9 @@ def setup_mocks(monkeypatch): db_mod = _stub_module("src.common.database.database") db_mod.get_db_session = get_db_session db_mod.get_manual_db_session = get_manual_db_session + + db_model_mod = _stub_module("src.common.database.database_model") + db_model_mod.Messages = None # 可以根据需要添加更多的属性或方法 emoji_manager_mod = _stub_module("src.chat.emoji_system.emoji_manager") emoji_manager_mod.emoji_manager = None # 可以根据需要添加更多的属性或方法 diff --git a/src/common/data_models/message_component_data_model.py b/src/common/data_models/message_component_data_model.py index f54dee0a..3b511eab 100644 --- a/src/common/data_models/message_component_data_model.py +++ b/src/common/data_models/message_component_data_model.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from copy import deepcopy from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo -from typing import Optional, List, Union, Dict, Any, Sequence +from typing import Optional, List, Union, Dict, Any import asyncio import hashlib @@ -239,7 +239,7 @@ class ForwardComponent(BaseMessageComponentModel): self, user_nickname: str, message_id: str, - content: Sequence[StandardMessageComponents], + content: List[StandardMessageComponents], user_id: Optional[str] = None, user_cardname: Optional[str] = None, ): @@ -247,7 +247,7 @@ class ForwardComponent(BaseMessageComponentModel): """转发节点的发送者昵称""" self.message_id: str = message_id """转发节点的消息ID""" - self.content: Sequence[StandardMessageComponents] = content + self.content: List[StandardMessageComponents] = content """消息内容""" self.user_id: Optional[str] = user_id """转发节点的发送者ID,可能为 None""" @@ -264,7 +264,7 @@ class ForwardComponent(BaseMessageComponentModel): class MessageSequence: """消息组件序列,包含一个消息中的所有组件,按照顺序排列""" - def __init__(self, components: Sequence[StandardMessageComponents]): + def __init__(self, components: List[StandardMessageComponents]): """ 创建一个消息组件序列 @@ -274,7 +274,42 @@ class MessageSequence: 因此也可以包含多个`ReplyComponent`组件(例如回复多条消息)。 如果需要对组件进行去重或校验,还请在使用时自行处理。 """ - self.components: Sequence[StandardMessageComponents] = components + self.components: List[StandardMessageComponents] = components + + """链式调用的接口,方便在创建消息组件序列时逐步追加组件""" + + def text(self, text: str) -> "MessageSequence": + """在消息组件序列末尾追加一个文本组件""" + self.components.append(TextComponent(text)) + return self + + def image(self, binary_data: bytes, content: Optional[str] = None): + """在消息组件序列末尾追加一个图片组件""" + hash_str = hashlib.sha256(binary_data).hexdigest() + self.components.append(ImageComponent(binary_hash=hash_str, content=content, binary_data=binary_data)) + return self + + def emoji(self, binary_data: bytes, content: Optional[str] = None): + """在消息组件序列末尾追加一个表情组件""" + hash_str = hashlib.sha256(binary_data).hexdigest() + self.components.append(EmojiComponent(binary_hash=hash_str, content=content, binary_data=binary_data)) + return self + + def voice(self, binary_data: bytes, content: Optional[str] = None): + """在消息组件序列末尾追加一个语音组件""" + hash_str = hashlib.sha256(binary_data).hexdigest() + self.components.append(VoiceComponent(binary_hash=hash_str, content=content, binary_data=binary_data)) + return self + + def at(self, target_user_id: str): + """在消息组件序列末尾追加一个@组件""" + self.components.append(AtComponent(target_user_id)) + return self + + def reply(self, target_message_id: str): + """在消息组件序列末尾追加一个回复组件""" + self.components.append(ReplyComponent(target_message_id=target_message_id)) + return self def to_dict(self) -> List[Dict[str, Any]]: """将消息序列转换为字典列表格式,便于存储或传输""" @@ -283,7 +318,7 @@ class MessageSequence: @classmethod def from_dict(cls, data: List[Dict[str, Any]]): """从字典列表格式创建消息序列实例""" - components: Sequence[StandardMessageComponents] = [] + components: List[StandardMessageComponents] = [] components.extend(cls._dict_2_item(item) for item in data) return cls(components=components) diff --git a/src/common/utils/utils_message.py b/src/common/utils/utils_message.py index e663de18..f4c5ff2a 100644 --- a/src/common/utils/utils_message.py +++ b/src/common/utils/utils_message.py @@ -1,5 +1,5 @@ from maim_message import MessageBase, Seg -from typing import List, Tuple, Optional, Sequence, TYPE_CHECKING +from typing import List, Tuple, Optional, TYPE_CHECKING import base64 import hashlib @@ -38,7 +38,7 @@ class MessageUtils: def from_maim_message_segments_to_MaiSeq(message: "MessageBase") -> MessageSequence: """从maim_message.MessageBase.message_segment转换为MessageSequence""" raw_msg_seq = message.message_segment - components: Sequence[StandardMessageComponents] = [] + components: List[StandardMessageComponents] = [] if not raw_msg_seq: return MessageSequence(components) if raw_msg_seq.type == "seglist":