Refactor chat stream handling to use BotChatSession
- Updated imports and references from ChatStream to BotChatSession across multiple files. - Adjusted method signatures and internal logic to accommodate the new session management. - Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
@@ -4,10 +4,10 @@ MaiBot模块系统
|
||||
"""
|
||||
|
||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager
|
||||
|
||||
# 导出主要组件供外部使用
|
||||
__all__ = [
|
||||
"get_chat_manager",
|
||||
"chat_manager",
|
||||
"emoji_manager",
|
||||
]
|
||||
|
||||
@@ -7,7 +7,7 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw
|
||||
|
||||
# from src.config.config import global_config
|
||||
from typing import Dict, Any, Optional
|
||||
from src.chat.message_receive.message import Message
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage
|
||||
from .pfc_types import ConversationState
|
||||
from .pfc import ChatObserver, GoalAnalyzer
|
||||
from .message_sender import DirectMessageSender
|
||||
@@ -16,9 +16,8 @@ from .action_planner import ActionPlanner
|
||||
from .observation_info import ObservationInfo
|
||||
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
|
||||
from .reply_generator import ReplyGenerator
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from maim_message import UserInfo
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
||||
from .waiter import Waiter
|
||||
|
||||
@@ -60,7 +59,7 @@ class Conversation:
|
||||
self.direct_sender = DirectMessageSender(self.private_name)
|
||||
|
||||
# 获取聊天流信息
|
||||
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
|
||||
self.chat_stream = _chat_manager.get_session_by_session_id(self.stream_id)
|
||||
|
||||
self.stop_action_planner = False
|
||||
except Exception as e:
|
||||
@@ -265,34 +264,34 @@ class Conversation:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||
"""将消息字典转换为Message对象"""
|
||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> MaiMessage:
|
||||
"""将消息字典转换为MaiMessage对象"""
|
||||
from datetime import datetime as dt
|
||||
from src.common.data_models.mai_message_data_model import UserInfo as MaiUserInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import MessageSequence
|
||||
|
||||
try:
|
||||
# 尝试从 msg_dict 直接获取 chat_stream,如果失败则从全局 get_chat_manager 获取
|
||||
chat_info = msg_dict.get("chat_info")
|
||||
if chat_info and isinstance(chat_info, dict):
|
||||
chat_stream = ChatStream.from_dict(chat_info)
|
||||
elif self.chat_stream: # 使用实例变量中的 chat_stream
|
||||
chat_stream = self.chat_stream
|
||||
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
|
||||
chat_stream = get_chat_manager().get_stream(self.stream_id)
|
||||
if not chat_stream:
|
||||
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")
|
||||
|
||||
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
|
||||
|
||||
return Message(
|
||||
message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID
|
||||
chat_stream=chat_stream, # 使用确定的 chat_stream
|
||||
time=msg_dict.get("time", time.time()), # 提供默认时间
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_dict.get("processed_plain_text", ""),
|
||||
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
|
||||
user_info_dict = msg_dict.get("user_info", {})
|
||||
user_info = MaiUserInfo(
|
||||
user_id=user_info_dict.get("user_id", ""),
|
||||
user_nickname=user_info_dict.get("user_nickname", ""),
|
||||
user_cardname=user_info_dict.get("user_cardname"),
|
||||
)
|
||||
|
||||
msg = MaiMessage(
|
||||
message_id=msg_dict.get("message_id", f"gen_{time.time()}"),
|
||||
timestamp=dt.fromtimestamp(msg_dict.get("time", time.time())),
|
||||
)
|
||||
msg.message_info = MessageInfo(user_info=user_info)
|
||||
msg.platform = user_info_dict.get("platform", "")
|
||||
msg.session_id = self.stream_id
|
||||
msg.processed_plain_text = msg_dict.get("processed_plain_text", "")
|
||||
msg.raw_message = MessageSequence(components=[])
|
||||
msg.initialized = True
|
||||
return msg
|
||||
except Exception as e:
|
||||
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
|
||||
# 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题
|
||||
raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e
|
||||
raise ValueError(f"无法将字典转换为 MaiMessage 对象: {e}") from e
|
||||
|
||||
async def _handle_action(
|
||||
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
|
||||
@@ -687,7 +686,7 @@ class Conversation:
|
||||
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
|
||||
return
|
||||
if not self.chat_stream:
|
||||
logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。")
|
||||
logger.error(f"[私聊][{self.private_name}]会话未初始化,无法发送回复。")
|
||||
return
|
||||
|
||||
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import Message, MessageSending
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from maim_message import Seg
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -19,18 +21,17 @@ class DirectMessageSender:
|
||||
|
||||
def __init__(self, private_name: str):
|
||||
self.private_name = private_name
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: BotChatSession,
|
||||
content: str,
|
||||
reply_to_message: Optional[Message] = None,
|
||||
reply_to_message: Optional[MaiMessage] = None,
|
||||
) -> None:
|
||||
"""发送消息到聊天流
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流
|
||||
chat_stream: 聊天会话
|
||||
content: 消息内容
|
||||
reply_to_message: 要回复的消息(可选)
|
||||
"""
|
||||
@@ -42,18 +43,22 @@ class DirectMessageSender:
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=chat_stream.platform,
|
||||
)
|
||||
|
||||
# 用当前时间作为message_id,和之前那套sender一样
|
||||
message_id = f"dm{round(time.time(), 2)}"
|
||||
|
||||
# 构建发送者信息(私聊时为接收者)
|
||||
sender_info = None
|
||||
if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
|
||||
sender_info = reply_to_message.message_info.user_info
|
||||
|
||||
# 构建消息对象
|
||||
message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
session=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
|
||||
sender_info=sender_info,
|
||||
message_segment=segments,
|
||||
reply=reply_to_message,
|
||||
is_head=True,
|
||||
@@ -61,17 +66,11 @@ class DirectMessageSender:
|
||||
thinking_start_time=time.time(),
|
||||
)
|
||||
|
||||
# 处理消息
|
||||
await message.process()
|
||||
|
||||
# 发送消息(直接调用底层 API)
|
||||
from src.chat.message_receive.uni_message_sender import _send_message
|
||||
|
||||
sent = await _send_message(message, show_log=True)
|
||||
# 发送消息
|
||||
message_sender = UniversalMessageSender()
|
||||
sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
|
||||
|
||||
if sent:
|
||||
# 存储消息
|
||||
await self.storage.store_message(message, chat_stream)
|
||||
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
|
||||
else:
|
||||
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")
|
||||
|
||||
@@ -9,7 +9,7 @@ from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.message_data_model import ReplyContentType
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.brain_chat.brain_planner import BrainPlanner
|
||||
@@ -73,10 +73,10 @@ class BrainChatting:
|
||||
"""
|
||||
# 基础属性
|
||||
self.stream_id: str = chat_id # 聊天流ID
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore
|
||||
if not self.chat_stream:
|
||||
raise ValueError(f"无法找到聊天流: {self.stream_id}")
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]"
|
||||
|
||||
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
@@ -38,7 +38,7 @@ install(extra_lines=3)
|
||||
class BrainPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
|
||||
@@ -5,9 +5,8 @@ import time
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
from maim_message.message_base import GroupInfo
|
||||
|
||||
from src.common.message_repository import count_messages
|
||||
|
||||
@@ -121,28 +120,24 @@ def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None)
|
||||
|
||||
|
||||
async def send_typing():
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
chat = await _chat_manager.get_or_create_session(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
user_id="114514",
|
||||
group_id="114514",
|
||||
)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
|
||||
message_type="state", content="typing", stream_id=chat.session_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
async def stop_typing():
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
chat = await _chat_manager.get_or_create_session(
|
||||
platform="amaidesu_default",
|
||||
user_info=None,
|
||||
group_info=group_info,
|
||||
user_id="114514",
|
||||
group_id="114514",
|
||||
)
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
|
||||
message_type="state", content="stop_typing", stream_id=chat.session_id, storage_message=False
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import Dict, Any
|
||||
from src.common.logger import get_logger
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
from src.chat.message_receive.message_old import MessageRecv
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
@@ -41,14 +40,14 @@ class ChatBot:
|
||||
|
||||
self._started = True
|
||||
|
||||
async def _create_pfc_chat(self, message: MessageRecv):
|
||||
async def _create_pfc_chat(self, message: SessionMessage):
|
||||
"""创建或获取PFC对话实例
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
chat_id = str(message.chat_stream.stream_id)
|
||||
chat_id = message.session_id
|
||||
private_name = str(message.message_info.user_info.user_nickname)
|
||||
|
||||
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
|
||||
@@ -177,12 +176,12 @@ class ChatBot:
|
||||
logger.error(f"[新运行时] 执行命令 {matched.full_name} 异常: {e}", exc_info=True)
|
||||
return True, str(e), True
|
||||
|
||||
async def handle_notice_message(self, message: MessageRecv):
|
||||
if message.message_info.message_id == "notice":
|
||||
async def handle_notice_message(self, message: SessionMessage):
|
||||
if message.message_id == "notice":
|
||||
message.is_notify = True
|
||||
logger.debug("notice消息")
|
||||
try:
|
||||
seg = message.message_segment
|
||||
seg = getattr(message, "message_segment", None) # SessionMessage 没有 message_segment
|
||||
mi = message.message_info
|
||||
sub_type = None
|
||||
scene = None
|
||||
@@ -246,10 +245,8 @@ class ChatBot:
|
||||
return
|
||||
mmc_message_id = message_data.get("echo")
|
||||
actual_message_id = message_data.get("actual_id")
|
||||
if MessageStorage.update_message(mmc_message_id, actual_message_id):
|
||||
logger.debug(f"更新消息ID成功: {mmc_message_id} -> {actual_message_id}")
|
||||
else:
|
||||
logger.warning(f"更新消息ID失败: {mmc_message_id} -> {actual_message_id}")
|
||||
# TODO: Implement message ID update in new architecture
|
||||
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
|
||||
@@ -1,14 +1,23 @@
|
||||
from asyncio import Task
|
||||
from datetime import datetime
|
||||
from maim_message import (
|
||||
MessageBase,
|
||||
UserInfo as MaimUserInfo,
|
||||
GroupInfo as MaimGroupInfo,
|
||||
BaseMessageInfo as MaimBaseMessageInfo,
|
||||
Seg,
|
||||
)
|
||||
from rich.traceback import install
|
||||
from sqlmodel import select
|
||||
from typing import List, Dict, Tuple, Sequence
|
||||
from typing import List, Dict, Optional, Tuple, Sequence, TYPE_CHECKING
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import Messages
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo
|
||||
from src.common.data_models.message_component_data_model import (
|
||||
TextComponent,
|
||||
ImageComponent,
|
||||
@@ -19,6 +28,10 @@ from src.common.data_models.message_component_data_model import (
|
||||
ForwardNodeComponent,
|
||||
StandardMessageComponents,
|
||||
)
|
||||
from src.common.utils.utils_message import MessageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
@@ -207,3 +220,166 @@ class SessionMessage(MaiMessage):
|
||||
else:
|
||||
processed_texts.append(result)
|
||||
return " ".join(processed_texts)
|
||||
|
||||
|
||||
class MessageSending(MaiMessage):
|
||||
"""发送状态的消息类,继承 MaiMessage 基类。
|
||||
|
||||
用于构建、处理和发送机器人的回复消息。
|
||||
复用 MaiMessage 的 to_maim_message() 和 to_db_instance() 方法,
|
||||
额外管理发送专属的会话信息和控制字段。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
session: "BotChatSession",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Seg,
|
||||
sender_info: Optional[UserInfo] = None,
|
||||
reply: Optional[MaiMessage] = None,
|
||||
display_message: str = "",
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
reply_to: Optional[str] = None,
|
||||
selected_expressions: Optional[List[int]] = None,
|
||||
):
|
||||
# 初始化 MaiMessage 基类
|
||||
super().__init__(message_id=message_id, timestamp=datetime.now())
|
||||
|
||||
# 发送专属字段
|
||||
self.session = session
|
||||
self.sender_info = sender_info
|
||||
self.message_segment = message_segment
|
||||
self.reply = reply
|
||||
self.is_head = is_head
|
||||
self.thinking_start_time = thinking_start_time
|
||||
self.selected_expressions = selected_expressions
|
||||
self.reply_to_message_id: Optional[str] = reply.message_id if reply else None
|
||||
self.interest_value: float = 0.0
|
||||
|
||||
# 填充 MaiMessage 标准字段
|
||||
self.platform = session.platform
|
||||
self.session_id = session.session_id
|
||||
self.is_emoji = is_emoji
|
||||
self.reply_to = reply_to
|
||||
self.display_message = display_message
|
||||
self.processed_plain_text = ""
|
||||
|
||||
# 构建 message_info:DB 存储时 user_info 始终为 bot 信息
|
||||
# 私聊/群聊的 user_info 差异仅在 to_maim_message() 覆写中处理
|
||||
group_info = self._resolve_group_info()
|
||||
self.message_info = MessageInfo(user_info=bot_user_info, group_info=group_info)
|
||||
|
||||
# bot_user_info 单独保存,to_maim_message 覆写时还需要
|
||||
self.bot_user_info = bot_user_info
|
||||
|
||||
# 将 Seg 转换为 MessageSequence,供基类的 to_db_instance / to_maim_message 使用
|
||||
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
|
||||
MessageBase(message_info=None, message_segment=message_segment)
|
||||
)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
def _resolve_group_info(self) -> Optional[GroupInfo]:
|
||||
"""从 session 中解析群信息"""
|
||||
if not self.session.group_id:
|
||||
return None
|
||||
group_name = ""
|
||||
if (
|
||||
self.session.context
|
||||
and self.session.context.message
|
||||
and self.session.context.message.message_info.group_info
|
||||
):
|
||||
group_name = self.session.context.message.message_info.group_info.group_name
|
||||
return GroupInfo(group_id=self.session.group_id, group_name=group_name)
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息段,生成 processed_plain_text(使用 SessionMessage 的组件处理能力)"""
|
||||
# 同步 message_segment → raw_message(插件可能修改了 message_segment)
|
||||
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
|
||||
MessageBase(message_info=None, message_segment=self.message_segment)
|
||||
)
|
||||
if self.raw_message and self.raw_message.components:
|
||||
tasks = [self._process_component(c) for c in self.raw_message.components]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
texts = []
|
||||
for r in results:
|
||||
if isinstance(r, BaseException):
|
||||
logger.error(f"处理发送消息组件时发生错误: {r}")
|
||||
elif r:
|
||||
texts.append(r)
|
||||
self.processed_plain_text = " ".join(texts)
|
||||
|
||||
async def _process_component(self, component: StandardMessageComponents) -> str:
|
||||
"""简单处理单个标准组件为纯文本描述"""
|
||||
if isinstance(component, TextComponent):
|
||||
return component.text
|
||||
elif isinstance(component, ImageComponent):
|
||||
return "[图片]"
|
||||
elif isinstance(component, EmojiComponent):
|
||||
return "[表情包]"
|
||||
elif isinstance(component, VoiceComponent):
|
||||
return "[语音]"
|
||||
elif isinstance(component, AtComponent):
|
||||
return f"[@{component.target_user_id}]"
|
||||
elif isinstance(component, ReplyComponent):
|
||||
return ""
|
||||
else:
|
||||
return f"[{type(component).__name__}]"
|
||||
|
||||
def build_reply(self) -> None:
|
||||
"""构建回复消息段,在 message_segment 前插入 reply 段"""
|
||||
if self.reply:
|
||||
self.reply_to_message_id = self.reply.message_id
|
||||
self.message_segment = Seg(
|
||||
type="seglist",
|
||||
data=[
|
||||
Seg(type="reply", data=self.reply.message_id),
|
||||
self.message_segment,
|
||||
],
|
||||
)
|
||||
# 同步更新 raw_message
|
||||
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
|
||||
MessageBase(message_info=None, message_segment=self.message_segment)
|
||||
)
|
||||
|
||||
async def to_maim_message(self) -> MessageBase:
|
||||
"""覆写基类方法:发送消息需要特殊处理 user_info(私聊/群聊差异)"""
|
||||
maim_bot_user_info = MaimUserInfo(
|
||||
user_id=self.bot_user_info.user_id,
|
||||
user_nickname=self.bot_user_info.user_nickname,
|
||||
user_cardname=self.bot_user_info.user_cardname,
|
||||
platform=self.platform,
|
||||
)
|
||||
|
||||
maim_group_info = None
|
||||
if self.message_info.group_info:
|
||||
maim_group_info = MaimGroupInfo(
|
||||
group_id=self.message_info.group_info.group_id,
|
||||
group_name=self.message_info.group_info.group_name,
|
||||
platform=self.platform,
|
||||
)
|
||||
|
||||
# 私聊时 user_info 填接收者信息(sender_info),群聊时填 bot
|
||||
if maim_group_info is None and self.sender_info:
|
||||
msg_user_info = MaimUserInfo(
|
||||
user_id=self.sender_info.user_id,
|
||||
user_nickname=self.sender_info.user_nickname,
|
||||
user_cardname=self.sender_info.user_cardname,
|
||||
platform=self.platform,
|
||||
)
|
||||
else:
|
||||
msg_user_info = maim_bot_user_info
|
||||
|
||||
maim_msg_info = MaimBaseMessageInfo(
|
||||
platform=self.platform,
|
||||
message_id=self.message_id,
|
||||
time=time.time(),
|
||||
group_info=maim_group_info,
|
||||
user_info=msg_user_info,
|
||||
)
|
||||
|
||||
msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
|
||||
return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments))
|
||||
|
||||
@@ -6,8 +6,8 @@ from maim_message import Seg
|
||||
|
||||
from src.common.message_server.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import get_db_session
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
|
||||
@@ -130,8 +130,8 @@ def parse_message_segments(segment) -> list:
|
||||
async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
"""合并后的消息发送函数,包含WS发送和日志记录"""
|
||||
message_preview = truncate_message(message.processed_plain_text, max_length=200)
|
||||
platform = message.message_info.platform
|
||||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||||
platform = message.platform
|
||||
group_id = message.session.group_id
|
||||
|
||||
try:
|
||||
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
|
||||
@@ -221,33 +221,14 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
raise legacy_exception
|
||||
return False
|
||||
|
||||
# 使用 MessageConverter 转换 Legacy MessageBase 到 APIMessageBase
|
||||
# 发送场景:MaiMBot 发送回复消息给外部用户
|
||||
# group_info/user_info 是消息接收者信息,放入 receiver_info
|
||||
# 使用 MessageConverter 转换为 API 消息
|
||||
from maim_message import MessageConverter
|
||||
|
||||
# 修复 API Server Fallback 模式下的 user_info 问题
|
||||
# 在 Legacy 模式下,MessageSending.to_dict() 的第 454 行会将 user_info 替换为 chat_stream.user_info
|
||||
# 但在 API Server Fallback 模式下,MessageConverter.to_api_send() 直接访问 message 对象,不调用 to_dict()
|
||||
# 需要手动应用相同的变通方案:在私聊场景下,user_info 应该是接收者(sender_info)
|
||||
message_for_conversion = message
|
||||
if hasattr(message, "message_info") and message.message_info.group_info is None:
|
||||
# 私聊场景:group_info 为 None
|
||||
# user_info 应该是接收者,从 chat_stream.user_info 或 sender_info 获取
|
||||
temp_dict = message.to_dict()
|
||||
if (
|
||||
hasattr(message, "chat_stream")
|
||||
and message.chat_stream
|
||||
and hasattr(message.chat_stream, "user_info")
|
||||
):
|
||||
temp_dict["message_info"]["user_info"] = message.chat_stream.user_info.to_dict()
|
||||
# 重新构建 MessageBase 对象(不保留 sender_info 等扩展属性)
|
||||
from maim_message import MessageBase
|
||||
|
||||
message_for_conversion = MessageBase.from_dict(temp_dict)
|
||||
# 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
|
||||
message_base = await message.to_maim_message()
|
||||
|
||||
api_message = MessageConverter.to_api_send(
|
||||
message=message_for_conversion,
|
||||
message=message_base,
|
||||
api_key=target_api_key,
|
||||
platform=platform,
|
||||
)
|
||||
@@ -278,10 +259,11 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
return False
|
||||
|
||||
try:
|
||||
send_result = await get_global_api().send_message(message)
|
||||
message_base = await message.to_maim_message()
|
||||
send_result = await get_global_api().send_message(message_base)
|
||||
if send_result:
|
||||
if show_log:
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
|
||||
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
|
||||
return True
|
||||
else:
|
||||
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback
|
||||
@@ -297,7 +279,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||
return await send_with_new_api(legacy_exception=legacy_e)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}")
|
||||
traceback.print_exc()
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
@@ -306,7 +288,7 @@ class UniversalMessageSender:
|
||||
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
|
||||
|
||||
def __init__(self):
|
||||
self.storage = MessageStorage()
|
||||
pass
|
||||
|
||||
async def send_message(
|
||||
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
|
||||
@@ -321,15 +303,15 @@ class UniversalMessageSender:
|
||||
用法:
|
||||
- typing=True 时,发送前会有打字等待。
|
||||
"""
|
||||
if not message.chat_stream:
|
||||
logger.error("消息缺少 chat_stream,无法发送")
|
||||
raise ValueError("消息缺少 chat_stream,无法发送")
|
||||
if not message.message_info or not message.message_info.message_id:
|
||||
logger.error("消息缺少 message_info 或 message_id,无法发送")
|
||||
raise ValueError("消息缺少 message_info 或 message_id,无法发送")
|
||||
if not message.session:
|
||||
logger.error("消息缺少 session,无法发送")
|
||||
raise ValueError("消息缺少 session,无法发送")
|
||||
if not message.message_id:
|
||||
logger.error("消息缺少 message_id,无法发送")
|
||||
raise ValueError("消息缺少 message_id,无法发送")
|
||||
|
||||
chat_id = message.chat_stream.stream_id
|
||||
message_id = message.message_info.message_id
|
||||
chat_id = message.session_id
|
||||
message_id = message.message_id
|
||||
|
||||
try:
|
||||
if set_reply:
|
||||
@@ -391,7 +373,8 @@ class UniversalMessageSender:
|
||||
message.processed_plain_text = modified_message.plain_text
|
||||
|
||||
if storage_message:
|
||||
await self.storage.store_message(message, message.chat_stream)
|
||||
with get_db_session() as db_session:
|
||||
db_session.add(message.to_db_instance())
|
||||
|
||||
return sent_msg
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
@@ -35,7 +35,7 @@ class ActionManager:
|
||||
action_reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: BotChatSession,
|
||||
log_prefix: str,
|
||||
shutting_down: bool = False,
|
||||
action_message: Optional[DatabaseMessages] = None,
|
||||
|
||||
@@ -4,15 +4,12 @@ from typing import List, Dict, TYPE_CHECKING, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("action_manager")
|
||||
|
||||
|
||||
@@ -27,8 +24,8 @@ class ActionModifier:
|
||||
def __init__(self, action_manager: ActionManager, chat_id: str):
|
||||
"""初始化动作处理器"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.chat_id) # type: ignore
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.action_manager = action_manager
|
||||
|
||||
@@ -121,7 +118,7 @@ class ActionModifier:
|
||||
available_actions_text = "、".join(available_actions) if available_actions else "无"
|
||||
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
|
||||
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
|
||||
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: BotChatSession):
|
||||
type_mismatched_actions: List[Tuple[str, str]] = []
|
||||
for action_name, action_info in all_actions.items():
|
||||
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
import traceback
|
||||
import random
|
||||
import re
|
||||
import contextlib
|
||||
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
|
||||
from collections import OrderedDict
|
||||
from rich.traceback import install
|
||||
@@ -21,7 +22,7 @@ from src.chat.utils.chat_message_builder import (
|
||||
)
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.apis.message_api import translate_pid_to_description
|
||||
@@ -39,7 +40,7 @@ install(extra_lines=3)
|
||||
class ActionPlanner:
|
||||
def __init__(self, chat_id: str, action_manager: ActionManager):
|
||||
self.chat_id = chat_id
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
|
||||
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
|
||||
self.action_manager = action_manager
|
||||
# LLM规划器配置
|
||||
self.planner_llm = LLMRequest(
|
||||
@@ -80,7 +81,7 @@ class ActionPlanner:
|
||||
if not text:
|
||||
return text
|
||||
|
||||
id_to_message = {msg_id: msg for msg_id, msg in message_id_list}
|
||||
id_to_message = dict(message_id_list)
|
||||
|
||||
# 匹配m后带2-4位数字,前后不是字母数字下划线
|
||||
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
|
||||
@@ -223,7 +224,7 @@ class ActionPlanner:
|
||||
action_data=action_data,
|
||||
action_message=target_message,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning if extracted_reasoning else None,
|
||||
action_reasoning=extracted_reasoning or None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -238,7 +239,7 @@ class ActionPlanner:
|
||||
action_data={},
|
||||
action_message=None,
|
||||
available_actions=available_actions_dict,
|
||||
action_reasoning=extracted_reasoning if extracted_reasoning else None,
|
||||
action_reasoning=extracted_reasoning or None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -292,8 +293,7 @@ class ActionPlanner:
|
||||
if new_words:
|
||||
for word in new_words:
|
||||
if isinstance(word, str):
|
||||
word = word.strip()
|
||||
if word:
|
||||
if word := word.strip():
|
||||
cleaned_new_words.append(word)
|
||||
|
||||
# 获取缓存中的黑话列表
|
||||
@@ -351,10 +351,9 @@ class ActionPlanner:
|
||||
break
|
||||
|
||||
# 如果当前 plan 的 reply 没有提取,移除最老的1个
|
||||
if not has_extracted_unknown_words:
|
||||
if len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存")
|
||||
if not has_extracted_unknown_words and len(self.unknown_words_cache) > 0:
|
||||
self.unknown_words_cache.popitem(last=False)
|
||||
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话,移除最老的1个缓存")
|
||||
|
||||
# 对于每个 reply action,合并缓存和新提取的黑话
|
||||
for action in actions:
|
||||
@@ -363,10 +362,7 @@ class ActionPlanner:
|
||||
new_words = action_data.get("unknown_words")
|
||||
|
||||
# 合并新提取的和缓存的黑话列表
|
||||
merged_words = self._merge_unknown_words_with_cache(new_words)
|
||||
|
||||
# 更新 action_data
|
||||
if merged_words:
|
||||
if merged_words := self._merge_unknown_words_with_cache(new_words):
|
||||
action_data["unknown_words"] = merged_words
|
||||
logger.debug(
|
||||
f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个,"
|
||||
@@ -449,15 +445,12 @@ class ActionPlanner:
|
||||
# 如果有强制回复消息,确保回复该消息
|
||||
if force_reply_message:
|
||||
# 检查是否已经有回复该消息的 action
|
||||
has_reply_to_force_message = False
|
||||
for action in actions:
|
||||
if (
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
):
|
||||
has_reply_to_force_message = True
|
||||
break
|
||||
has_reply_to_force_message = any(
|
||||
action.action_type == "reply"
|
||||
and action.action_message
|
||||
and action.action_message.message_id == force_reply_message.message_id
|
||||
for action in actions
|
||||
)
|
||||
|
||||
# 如果没有回复该消息,强制添加回复 action
|
||||
if not has_reply_to_force_message:
|
||||
@@ -532,13 +525,10 @@ class ActionPlanner:
|
||||
# 从后往前遍历,收集最新的记录
|
||||
for reasoning, timestamp, content in reversed(self.plan_log):
|
||||
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
|
||||
# 这是action记录
|
||||
if len(action_records) < max_action_records:
|
||||
action_records.append((reasoning, timestamp, content, "action"))
|
||||
else:
|
||||
# 这是执行结果记录
|
||||
if len(execution_records) < max_execution_records:
|
||||
execution_records.append((reasoning, timestamp, content, "execution"))
|
||||
elif len(execution_records) < max_execution_records:
|
||||
execution_records.append((reasoning, timestamp, content, "execution"))
|
||||
|
||||
# 合并所有记录并按时间戳排序
|
||||
all_records = action_records + execution_records
|
||||
@@ -700,15 +690,9 @@ class ActionPlanner:
|
||||
param_text = param_text.rstrip("\n")
|
||||
|
||||
# 构建要求文本
|
||||
require_text = ""
|
||||
for require_item in action_info.action_require:
|
||||
require_text += f"- {require_item}\n"
|
||||
require_text = require_text.rstrip("\n")
|
||||
require_text = "\n".join(f"- {require_item}" for require_item in action_info.action_require)
|
||||
|
||||
if not action_info.parallel_action:
|
||||
parallel_text = "(当选择这个动作时,请不要选择其他动作)"
|
||||
else:
|
||||
parallel_text = ""
|
||||
parallel_text = "" if action_info.parallel_action else "(当选择这个动作时,请不要选择其他动作)"
|
||||
|
||||
# 获取动作提示模板并填充
|
||||
using_action_prompt = prompt_manager.get_prompt("action")
|
||||
@@ -864,20 +848,15 @@ class ActionPlanner:
|
||||
# 尝试按行分割,每行可能是一个JSON对象
|
||||
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
|
||||
for line in lines:
|
||||
try:
|
||||
# 尝试解析每一行作为独立的JSON对象
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
json_obj = json.loads(repair_json(line))
|
||||
if isinstance(json_obj, dict):
|
||||
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
|
||||
if json_obj:
|
||||
json_objects.append(json_obj)
|
||||
elif isinstance(json_obj, list):
|
||||
for item in json_obj:
|
||||
if isinstance(item, dict) and item:
|
||||
json_objects.append(item)
|
||||
except json.JSONDecodeError:
|
||||
# 如果单行解析失败,尝试将整个块作为一个JSON对象或数组
|
||||
pass
|
||||
|
||||
# 如果按行解析没有成功(或只得到空字典),尝试将整个块作为一个JSON对象或数组
|
||||
if not json_objects:
|
||||
|
||||
@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import Seg
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
@@ -45,17 +48,17 @@ logger = get_logger("replyer")
|
||||
class DefaultReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
@@ -132,7 +135,7 @@ class DefaultReplyer:
|
||||
if log_reply:
|
||||
try:
|
||||
PlanReplyLogger.log_reply(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_id=self.chat_stream.session_id,
|
||||
prompt="",
|
||||
output=None,
|
||||
processed_output=None,
|
||||
@@ -202,7 +205,7 @@ class DefaultReplyer:
|
||||
try:
|
||||
if log_reply:
|
||||
PlanReplyLogger.log_reply(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_id=self.chat_stream.session_id,
|
||||
prompt=prompt,
|
||||
output=content,
|
||||
processed_output=None,
|
||||
@@ -259,7 +262,7 @@ class DefaultReplyer:
|
||||
if log_reply:
|
||||
try:
|
||||
PlanReplyLogger.log_reply(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_id=self.chat_stream.session_id,
|
||||
prompt=prompt or "",
|
||||
output=None,
|
||||
processed_output=None,
|
||||
@@ -353,14 +356,14 @@ class DefaultReplyer:
|
||||
str: 表达习惯信息字符串
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
|
||||
if not use_expression:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id,
|
||||
self.chat_stream.session_id,
|
||||
chat_history,
|
||||
max_num=8,
|
||||
target_message=target,
|
||||
@@ -702,10 +705,11 @@ class DefaultReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
chat_id = SessionUtils.calculate_session_id(
|
||||
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
|
||||
)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -778,7 +782,7 @@ class DefaultReplyer:
|
||||
if available_actions is None:
|
||||
available_actions = {}
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
_is_group_chat = bool(chat_stream.group_info)
|
||||
platform = chat_stream.platform
|
||||
|
||||
@@ -1005,7 +1009,7 @@ class DefaultReplyer:
|
||||
reply_to: str,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
@@ -1105,29 +1109,27 @@ class DefaultReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: Optional[MessageRecv] = None,
|
||||
anchor_message: Optional[MaiMessage] = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||
|
||||
return MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
message_id=message_id,
|
||||
session=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=message_segment,
|
||||
reply=anchor_message, # 回复原始锚点
|
||||
reply=anchor_message,
|
||||
is_head=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
|
||||
thinking_start_time=thinking_start_time,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import Seg
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
@@ -43,18 +46,18 @@ logger = get_logger("replyer")
|
||||
class PrivateReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
@@ -253,14 +256,14 @@ class PrivateReplyer:
|
||||
str: 表达习惯信息字符串
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
|
||||
if not use_expression:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
self.chat_stream.session_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
@@ -550,10 +553,11 @@ class PrivateReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
chat_id = SessionUtils.calculate_session_id(
|
||||
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
|
||||
)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -624,7 +628,7 @@ class PrivateReplyer:
|
||||
if available_actions is None:
|
||||
available_actions = {}
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
platform = chat_stream.platform
|
||||
|
||||
user_id = "用户ID"
|
||||
@@ -843,7 +847,7 @@ class PrivateReplyer:
|
||||
reply_to: str,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
@@ -948,29 +952,27 @@ class PrivateReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: Optional[MessageRecv] = None,
|
||||
anchor_message: Optional[MaiMessage] = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||
|
||||
return MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
message_id=message_id,
|
||||
session=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=message_segment,
|
||||
reply=anchor_message, # 回复原始锚点
|
||||
reply=anchor_message,
|
||||
is_head=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
|
||||
thinking_start_time=thinking_start_time,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
@@ -14,7 +14,7 @@ class ReplyerManager:
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
@@ -24,7 +24,7 @@ class ReplyerManager:
|
||||
model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。
|
||||
后续调用将返回已缓存的实例,忽略 model_configs 参数。
|
||||
"""
|
||||
stream_id = chat_stream.stream_id if chat_stream else chat_id
|
||||
stream_id = chat_stream.session_id if chat_stream else chat_id
|
||||
if not stream_id:
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取回复器。")
|
||||
return None
|
||||
@@ -39,15 +39,14 @@ class ReplyerManager:
|
||||
|
||||
target_stream = chat_stream
|
||||
if not target_stream:
|
||||
if chat_manager := get_chat_manager():
|
||||
target_stream = chat_manager.get_stream(stream_id)
|
||||
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
||||
|
||||
if not target_stream:
|
||||
logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。")
|
||||
return None
|
||||
|
||||
# model_configs 只在此时(初始化时)生效
|
||||
if target_stream.group_info:
|
||||
if target_stream.is_group_session:
|
||||
replyer = DefaultReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
|
||||
@@ -61,9 +61,12 @@ class TempMethodsExpression:
|
||||
str: 生成的聊天流ID(哈希值)
|
||||
"""
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
if is_group:
|
||||
return SessionUtils.calculate_session_id(platform, group_id=str(id_str))
|
||||
else:
|
||||
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
|
||||
except Exception as e:
|
||||
logger.error(f"生成聊天流ID失败: {e}")
|
||||
return None
|
||||
|
||||
@@ -1051,18 +1051,13 @@ class StatisticOutputTask(AsyncTask):
|
||||
"""从chat_id获取显示名称"""
|
||||
try:
|
||||
# 首先尝试从chat_stream获取真实群组名称
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
if chat_id in chat_manager.streams:
|
||||
stream = chat_manager.streams[chat_id]
|
||||
if stream.group_info and hasattr(stream.group_info, "group_name"):
|
||||
group_name = stream.group_info.group_name
|
||||
if group_name and group_name.strip():
|
||||
return group_name.strip()
|
||||
elif stream.user_info and hasattr(stream.user_info, "user_nickname"):
|
||||
user_name = stream.user_info.user_nickname
|
||||
if chat_id in _stat_chat_manager.sessions:
|
||||
session = _stat_chat_manager.sessions[chat_id]
|
||||
name = _stat_chat_manager.get_session_name(chat_id)
|
||||
if name and name.strip():
|
||||
return name.strip()
|
||||
if user_name and user_name.strip():
|
||||
return user_name.strip()
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ from typing import Optional, Tuple, List, TYPE_CHECKING
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import SessionMessage
|
||||
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import Person
|
||||
from .typo_generator import ChineseTypoGenerator
|
||||
@@ -114,10 +114,10 @@ def is_bot_self(platform: str, user_id: str) -> bool:
|
||||
return user_id_str == qq_account
|
||||
|
||||
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
|
||||
def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, float]:
|
||||
"""检查消息是否提到了机器人(统一多平台实现)"""
|
||||
text = message.processed_plain_text or ""
|
||||
platform = getattr(message.message_info, "platform", "") or ""
|
||||
platform = message.platform or ""
|
||||
|
||||
# 获取各平台账号
|
||||
platforms_list = getattr(global_config.bot, "platforms", []) or []
|
||||
@@ -696,15 +696,23 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
chat_target_info = None
|
||||
|
||||
try:
|
||||
if chat_stream := get_chat_manager().get_stream(chat_id):
|
||||
if chat_stream.group_info:
|
||||
if chat_stream := _chat_manager.get_session_by_session_id(chat_id):
|
||||
if chat_stream.is_group_session:
|
||||
is_group_chat = True
|
||||
chat_target_info = None # Explicitly None for group chat
|
||||
elif chat_stream.user_info: # It's a private chat
|
||||
elif chat_stream.user_id: # It's a private chat
|
||||
is_group_chat = False
|
||||
user_info = chat_stream.user_info
|
||||
platform: str = chat_stream.platform
|
||||
user_id: str = user_info.user_id # type: ignore
|
||||
user_id: str = chat_stream.user_id
|
||||
|
||||
# Try to get nickname from context
|
||||
user_nickname = None
|
||||
if (
|
||||
chat_stream.context
|
||||
and chat_stream.context.message
|
||||
and chat_stream.context.message.message_info.user_info
|
||||
):
|
||||
user_nickname = chat_stream.context.message.message_info.user_info.user_nickname
|
||||
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
|
||||
|
||||
@@ -712,7 +720,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
target_info = TargetPersonInfo(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_nickname=user_info.user_nickname, # type: ignore
|
||||
user_nickname=user_nickname, # type: ignore
|
||||
person_id=None,
|
||||
person_name=None,
|
||||
)
|
||||
@@ -721,7 +729,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
|
||||
try:
|
||||
person = Person(platform=platform, user_id=user_id)
|
||||
if not person.is_known:
|
||||
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
|
||||
logger.warning(f"用户 {user_nickname} 尚未认识")
|
||||
# 如果用户尚未认识,则返回False和None
|
||||
return False, None
|
||||
if person.person_id:
|
||||
|
||||
Reference in New Issue
Block a user