Refactor chat stream handling to use BotChatSession
- Updated imports and references from ChatStream to BotChatSession across multiple files. - Adjusted method signatures and internal logic to accommodate the new session management. - Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
@@ -1,14 +1,23 @@
|
||||
from asyncio import Task
|
||||
from datetime import datetime
|
||||
from maim_message import (
|
||||
MessageBase,
|
||||
UserInfo as MaimUserInfo,
|
||||
GroupInfo as MaimGroupInfo,
|
||||
BaseMessageInfo as MaimBaseMessageInfo,
|
||||
Seg,
|
||||
)
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import List, Dict, Tuple, Sequence
|
||||
from typing import List, Dict, Optional, Tuple, Sequence, TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
TextComponent,
|
||||
ImageComponent,
|
||||
@@ -19,6 +28,10 @@ from src.common.data_models.message_component_data_model import (
|
||||
ForwardNodeComponent,
|
||||
StandardMessageComponents,
|
||||
)
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -207,3 +220,166 @@ class SessionMessage(MaiMessage):
|
||||
else:
|
||||
processed_texts.append(result)
|
||||
return " ".join(processed_texts)
|
||||
|
||||
|
||||
class MessageSending(MaiMessage):
|
||||
"""发送状态的消息类,继承 MaiMessage 基类。
|
||||
|
||||
用于构建、处理和发送机器人的回复消息。
|
||||
复用 MaiMessage 的 to_maim_message() 和 to_db_instance() 方法,
|
||||
额外管理发送专属的会话信息和控制字段。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
session: "BotChatSession",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Seg,
|
||||
sender_info: Optional[UserInfo] = None,
|
||||
reply: Optional[MaiMessage] = None,
|
||||
display_message: str = "",
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
reply_to: Optional[str] = None,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
):
|
||||
# 初始化 MaiMessage 基类
|
||||
super().__init__(message_id=message_id, timestamp=datetime.now())
|
||||
|
||||
# 发送专属字段
|
||||
self.session = session
|
||||
self.sender_info = sender_info
|
||||
self.message_segment = message_segment
|
||||
self.reply = reply
|
||||
self.is_head = is_head
|
||||
self.thinking_start_time = thinking_start_time
|
||||
self.selected_expressions = selected_expressions
|
||||
self.reply_to_message_id: Optional[str] = reply.message_id if reply else None
|
||||
self.interest_value: float = 0.0
|
||||
|
||||
# 填充 MaiMessage 标准字段
|
||||
self.platform = session.platform
|
||||
self.session_id = session.session_id
|
||||
self.is_emoji = is_emoji
|
||||
self.reply_to = reply_to
|
||||
self.display_message = display_message
|
||||
self.processed_plain_text = ""
|
||||
|
||||
# 构建 message_info:DB 存储时 user_info 始终为 bot 信息
|
||||
# 私聊/群聊的 user_info 差异仅在 to_maim_message() 覆写中处理
|
||||
group_info = self._resolve_group_info()
|
||||
self.message_info = MessageInfo(user_info=bot_user_info, group_info=group_info)
|
||||
|
||||
# bot_user_info 单独保存,to_maim_message 覆写时还需要
|
||||
self.bot_user_info = bot_user_info
|
||||
|
||||
# 将 Seg 转换为 MessageSequence,供基类的 to_db_instance / to_maim_message 使用
|
||||
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
|
||||
MessageBase(message_info=None, message_segment=message_segment)
|
||||
)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def _resolve_group_info(self) -> Optional[GroupInfo]:
|
||||
"""从 session 中解析群信息"""
|
||||
if not self.session.group_id:
|
||||
return None
|
||||
group_name = ""
|
||||
if (
|
||||
self.session.context
|
||||
and self.session.context.message
|
||||
and self.session.context.message.message_info.group_info
|
||||
):
|
||||
group_name = self.session.context.message.message_info.group_info.group_name
|
||||
return GroupInfo(group_id=self.session.group_id, group_name=group_name)
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息段,生成 processed_plain_text(使用 SessionMessage 的组件处理能力)"""
|
||||
# 同步 message_segment → raw_message(插件可能修改了 message_segment)
|
||||
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
|
||||
MessageBase(message_info=None, message_segment=self.message_segment)
|
||||
)
|
||||
if self.raw_message and self.raw_message.components:
|
||||
tasks = [self._process_component(c) for c in self.raw_message.components]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
texts = []
|
||||
for r in results:
|
||||
if isinstance(r, BaseException):
|
||||
logger.error(f"处理发送消息组件时发生错误: {r}")
|
||||
elif r:
|
||||
texts.append(r)
|
||||
self.processed_plain_text = " ".join(texts)
|
||||
|
||||
async def _process_component(self, component: StandardMessageComponents) -> str:
|
||||
"""简单处理单个标准组件为纯文本描述"""
|
||||
if isinstance(component, TextComponent):
|
||||
return component.text
|
||||
elif isinstance(component, ImageComponent):
|
||||
return "[图片]"
|
||||
elif isinstance(component, EmojiComponent):
|
||||
return "[表情包]"
|
||||
elif isinstance(component, VoiceComponent):
|
||||
return "[语音]"
|
||||
elif isinstance(component, AtComponent):
|
||||
return f"[@{component.target_user_id}]"
|
||||
elif isinstance(component, ReplyComponent):
|
||||
return ""
|
||||
else:
|
||||
return f"[{type(component).__name__}]"
|
||||
|
||||
def build_reply(self) -> None:
|
||||
"""构建回复消息段,在 message_segment 前插入 reply 段"""
|
||||
if self.reply:
|
||||
self.reply_to_message_id = self.reply.message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_id),
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
# 同步更新 raw_message
|
||||
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
|
||||
MessageBase(message_info=None, message_segment=self.message_segment)
|
||||
)
|
||||
|
||||
async def to_maim_message(self) -> MessageBase:
|
||||
"""覆写基类方法:发送消息需要特殊处理 user_info(私聊/群聊差异)"""
|
||||
maim_bot_user_info = MaimUserInfo(
|
||||
user_id=self.bot_user_info.user_id,
|
||||
user_nickname=self.bot_user_info.user_nickname,
|
||||
user_cardname=self.bot_user_info.user_cardname,
|
||||
platform=self.platform,
|
||||
)
|
||||
|
||||
maim_group_info = None
|
||||
if self.message_info.group_info:
|
||||
maim_group_info = MaimGroupInfo(
|
||||
group_id=self.message_info.group_info.group_id,
|
||||
group_name=self.message_info.group_info.group_name,
|
||||
platform=self.platform,
|
||||
)
|
||||
|
||||
# 私聊时 user_info 填接收者信息(sender_info),群聊时填 bot
|
||||
if maim_group_info is None and self.sender_info:
|
||||
msg_user_info = MaimUserInfo(
|
||||
user_id=self.sender_info.user_id,
|
||||
user_nickname=self.sender_info.user_nickname,
|
||||
user_cardname=self.sender_info.user_cardname,
|
||||
platform=self.platform,
|
||||
)
|
||||
else:
|
||||
msg_user_info = maim_bot_user_info
|
||||
|
||||
maim_msg_info = MaimBaseMessageInfo(
|
||||
platform=self.platform,
|
||||
message_id=self.message_id,
|
||||
time=time.time(),
|
||||
group_info=maim_group_info,
|
||||
user_info=msg_user_info,
|
||||
)
|
||||
|
||||
msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
|
||||
return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments))
|
||||
|
||||
Reference in New Issue
Block a user