Refactor chat stream handling to use BotChatSession

- Updated imports and references from ChatStream to BotChatSession across multiple files.
- Adjusted method signatures and internal logic to accommodate the new session management.
- Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
DrSmoothl
2026-03-07 00:57:33 +08:00
parent 8712fc0d05
commit 2e3dd44ee9
43 changed files with 706 additions and 563 deletions

View File

@@ -19,10 +19,10 @@ sys.path.insert(0, project_root)
try:
from src.common.database.database_model import ChatStreams
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _script_chat_manager
except ImportError:
ChatStreams = None
get_chat_manager = None
_script_chat_manager = None
def get_chat_name(chat_id: str) -> str:

View File

@@ -23,7 +23,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db
from src.common.database.database_model import LLMUsage
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval")

View File

@@ -12,7 +12,7 @@ from src.chat.utils.chat_message_builder import (
build_anonymous_messages,
)
from src.prompt.prompt_manager import prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
@@ -42,8 +42,8 @@ class ExpressionLearner:
)
self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
# 学习锁,防止并发执行学习任务
self._learning_lock = asyncio.Lock()

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils import weighted_sample
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector")
@@ -50,8 +50,9 @@ class ExpressionSelector:
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
# 统一通过 chat_manager 生成 stream_id避免各处自行实现哈希逻辑
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
return SessionUtils.calculate_session_id(
platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str)
)
except Exception:
return None
@@ -127,8 +128,7 @@ class ExpressionSelector:
logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择")
# 完全没有高 count 样本时退化为全量随机抽样不进入LLM流程
fallback_num = min(3, max_num) if max_num > 0 else 3
fallback_selected = self._random_expressions(chat_id, fallback_num)
if fallback_selected:
if fallback_selected := self._random_expressions(chat_id, fallback_num):
self.update_expressions_last_active_time(fallback_selected)
selected_ids = [expr["id"] for expr in fallback_selected]
logger.info(
@@ -199,12 +199,7 @@ class ExpressionSelector:
]
# 随机抽样
if style_exprs:
selected_style = weighted_sample(style_exprs, total_num)
else:
selected_style = []
return selected_style
return weighted_sample(style_exprs, total_num) if style_exprs else []
except Exception as e:
logger.error(f"随机选择表达方式失败: {e}")

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils import (
parse_chat_id_list,
@@ -99,9 +99,9 @@ class JargonMiner:
)
# 初始化stream_name作为类属性避免重复提取
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(self.chat_id)
self.stream_name = stream_name if stream_name else self.chat_id
chat_manager = _chat_manager
stream_name = chat_manager.get_session_name(self.chat_id)
self.stream_name = stream_name or self.chat_id
self.cache_limit = 50
self.cache: OrderedDict[str, None] = OrderedDict()

View File

@@ -2,7 +2,7 @@ import time
import asyncio
from typing import List, Any
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression
from src.bw_learner.expression_learner import expression_learner_manager
@@ -18,8 +18,8 @@ class MessageRecorder:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
# 维护每个chat的上次提取时间
self.last_extraction_time: float = time.time()

View File

@@ -5,20 +5,19 @@ from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager
from src.config.config import model_config
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
build_readable_messages,
)
if TYPE_CHECKING:
pass
logger = get_logger("reflect_tracker")
class ReflectTracker:
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
def __init__(self, chat_stream: BotChatSession, expression: Expression, created_time: float):
self.chat_stream = chat_stream
self.expression = expression
self.created_time = created_time
@@ -42,7 +41,7 @@ class ReflectTracker:
# Fetch messages since creation
msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
timestamp_start=self.created_time,
timestamp_end=time.time(),
)
@@ -90,10 +89,7 @@ class ReflectTracker:
from json_repair import repair_json
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if not matches:
# Try to parse raw response if no code block
matches = [response]
matches = re.findall(json_pattern, response, re.DOTALL) or [response]
json_obj = json.loads(repair_json(matches[0]))
@@ -122,10 +118,7 @@ class ReflectTracker:
self.expression.style = corrected_style
# 如果拒绝但未更新,标记为 rejected=1
if not has_update:
self.expression.rejected = True
else:
self.expression.rejected = False
self.expression.rejected = not has_update
self.expression.save()

View File

@@ -4,10 +4,10 @@ MaiBot模块系统
"""
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager
# 导出主要组件供外部使用
__all__ = [
"get_chat_manager",
"chat_manager",
"emoji_manager",
]

View File

@@ -7,7 +7,7 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw
# from src.config.config import global_config
from typing import Dict, Any, Optional
from src.chat.message_receive.message import Message
from src.common.data_models.mai_message_data_model import MaiMessage
from .pfc_types import ConversationState
from .pfc import ChatObserver, GoalAnalyzer
from .message_sender import DirectMessageSender
@@ -16,9 +16,8 @@ from .action_planner import ActionPlanner
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
from .reply_generator import ReplyGenerator
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter
@@ -60,7 +59,7 @@ class Conversation:
self.direct_sender = DirectMessageSender(self.private_name)
# 获取聊天流信息
self.chat_stream = get_chat_manager().get_stream(self.stream_id)
self.chat_stream = _chat_manager.get_session_by_session_id(self.stream_id)
self.stop_action_planner = False
except Exception as e:
@@ -265,34 +264,34 @@ class Conversation:
return True
return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> MaiMessage:
"""将消息字典转换为MaiMessage对象"""
from datetime import datetime as dt
from src.common.data_models.mai_message_data_model import UserInfo as MaiUserInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
try:
# 尝试从 msg_dict 直接获取 chat_stream如果失败则从全局 get_chat_manager 获取
chat_info = msg_dict.get("chat_info")
if chat_info and isinstance(chat_info, dict):
chat_stream = ChatStream.from_dict(chat_info)
elif self.chat_stream: # 使用实例变量中的 chat_stream
chat_stream = self.chat_stream
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
chat_stream = get_chat_manager().get_stream(self.stream_id)
if not chat_stream:
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
return Message(
message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID
chat_stream=chat_stream, # 使用确定的 chat_stream
time=msg_dict.get("time", time.time()), # 提供默认时间
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
user_info_dict = msg_dict.get("user_info", {})
user_info = MaiUserInfo(
user_id=user_info_dict.get("user_id", ""),
user_nickname=user_info_dict.get("user_nickname", ""),
user_cardname=user_info_dict.get("user_cardname"),
)
msg = MaiMessage(
message_id=msg_dict.get("message_id", f"gen_{time.time()}"),
timestamp=dt.fromtimestamp(msg_dict.get("time", time.time())),
)
msg.message_info = MessageInfo(user_info=user_info)
msg.platform = user_info_dict.get("platform", "")
msg.session_id = self.stream_id
msg.processed_plain_text = msg_dict.get("processed_plain_text", "")
msg.raw_message = MessageSequence(components=[])
msg.initialized = True
return msg
except Exception as e:
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
# 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题
raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e
raise ValueError(f"无法将字典转换为 MaiMessage 对象: {e}") from e
async def _handle_action(
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
@@ -687,7 +686,7 @@ class Conversation:
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
return
if not self.chat_stream:
logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。")
logger.error(f"[私聊][{self.private_name}]会话未初始化,无法发送回复。")
return
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)

View File

@@ -1,10 +1,12 @@
import time
from typing import Optional
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import Message, MessageSending
from maim_message import UserInfo, Seg
from src.chat.message_receive.storage import MessageStorage
from maim_message import Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.config.config import global_config
from rich.traceback import install
@@ -19,18 +21,17 @@ class DirectMessageSender:
def __init__(self, private_name: str):
self.private_name = private_name
self.storage = MessageStorage()
async def send_message(
self,
chat_stream: ChatStream,
chat_stream: BotChatSession,
content: str,
reply_to_message: Optional[Message] = None,
reply_to_message: Optional[MaiMessage] = None,
) -> None:
"""发送消息到聊天流
Args:
chat_stream: 聊天
chat_stream: 聊天会话
content: 消息内容
reply_to_message: 要回复的消息(可选)
"""
@@ -42,18 +43,22 @@ class DirectMessageSender:
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=chat_stream.platform,
)
# 用当前时间作为message_id和之前那套sender一样
message_id = f"dm{round(time.time(), 2)}"
# 构建发送者信息(私聊时为接收者)
sender_info = None
if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
sender_info = reply_to_message.message_info.user_info
# 构建消息对象
message = MessageSending(
message_id=message_id,
chat_stream=chat_stream,
session=chat_stream,
bot_user_info=bot_user_info,
sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
sender_info=sender_info,
message_segment=segments,
reply=reply_to_message,
is_head=True,
@@ -61,17 +66,11 @@ class DirectMessageSender:
thinking_start_time=time.time(),
)
# 处理消息
await message.process()
# 发送消息(直接调用底层 API
from src.chat.message_receive.uni_message_sender import _send_message
sent = await _send_message(message, show_log=True)
# 发送消息
message_sender = UniversalMessageSender()
sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
if sent:
# 存储消息
await self.storage.store_message(message, chat_stream)
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
else:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")

View File

@@ -9,7 +9,7 @@ from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.brain_chat.brain_planner import BrainPlanner
@@ -73,10 +73,10 @@ class BrainChatting:
"""
# 基础属性
self.stream_id: str = chat_id # 聊天流ID
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore
if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {self.stream_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]"
self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]"
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)

View File

@@ -22,7 +22,7 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry
@@ -38,7 +38,7 @@ install(extra_lines=3)
class BrainPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
self.planner_llm = LLMRequest(

View File

@@ -5,9 +5,8 @@ import time
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.apis import send_api
from maim_message.message_base import GroupInfo
from src.common.message_repository import count_messages
@@ -121,28 +120,24 @@ def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None)
async def send_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
chat = await get_chat_manager().get_or_create_stream(
chat = await _chat_manager.get_or_create_session(
platform="amaidesu_default",
user_info=None,
group_info=group_info,
user_id="114514",
group_id="114514",
)
await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False
message_type="state", content="typing", stream_id=chat.session_id, storage_message=False
)
async def stop_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心")
chat = await get_chat_manager().get_or_create_stream(
chat = await _chat_manager.get_or_create_session(
platform="amaidesu_default",
user_info=None,
group_info=group_info,
user_id="114514",
group_id="114514",
)
await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False
message_type="state", content="stop_typing", stream_id=chat.session_id, storage_message=False
)

View File

@@ -8,7 +8,6 @@ from typing import Dict, Any
from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils
from src.chat.message_receive.message_old import MessageRecv
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@@ -41,14 +40,14 @@ class ChatBot:
self._started = True
async def _create_pfc_chat(self, message: MessageRecv):
async def _create_pfc_chat(self, message: SessionMessage):
"""创建或获取PFC对话实例
Args:
message: 消息对象
"""
try:
chat_id = str(message.chat_stream.stream_id)
chat_id = message.session_id
private_name = str(message.message_info.user_info.user_nickname)
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
@@ -177,12 +176,12 @@ class ChatBot:
logger.error(f"[新运行时] 执行命令 {matched.full_name} 异常: {e}", exc_info=True)
return True, str(e), True
async def handle_notice_message(self, message: MessageRecv):
if message.message_info.message_id == "notice":
async def handle_notice_message(self, message: SessionMessage):
if message.message_id == "notice":
message.is_notify = True
logger.debug("notice消息")
try:
seg = message.message_segment
seg = getattr(message, "message_segment", None) # SessionMessage 没有 message_segment
mi = message.message_info
sub_type = None
scene = None
@@ -246,10 +245,8 @@ class ChatBot:
return
mmc_message_id = message_data.get("echo")
actual_message_id = message_data.get("actual_id")
if MessageStorage.update_message(mmc_message_id, actual_message_id):
logger.debug(f"更新消息ID成功: {mmc_message_id} -> {actual_message_id}")
else:
logger.warning(f"更新消息ID失败: {mmc_message_id} -> {actual_message_id}")
# TODO: Implement message ID update in new architecture
logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息

View File

@@ -1,14 +1,23 @@
from asyncio import Task
from datetime import datetime
from maim_message import (
MessageBase,
UserInfo as MaimUserInfo,
GroupInfo as MaimGroupInfo,
BaseMessageInfo as MaimBaseMessageInfo,
Seg,
)
from rich.traceback import install
from sqlmodel import select
from typing import List, Dict, Tuple, Sequence
from typing import List, Dict, Optional, Tuple, Sequence, TYPE_CHECKING
import asyncio
import time
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import (
TextComponent,
ImageComponent,
@@ -19,6 +28,10 @@ from src.common.data_models.message_component_data_model import (
ForwardNodeComponent,
StandardMessageComponents,
)
from src.common.utils.utils_message import MessageUtils
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
install(extra_lines=3)
@@ -207,3 +220,166 @@ class SessionMessage(MaiMessage):
else:
processed_texts.append(result)
return " ".join(processed_texts)
class MessageSending(MaiMessage):
    """Outbound (sending-state) message class, inheriting from the MaiMessage base.

    Used to build, process, and send the bot's reply messages.
    Reuses MaiMessage's to_maim_message() and to_db_instance() methods, while
    additionally managing send-specific session information and control fields.
    """

    def __init__(
        self,
        message_id: str,
        session: "BotChatSession",
        bot_user_info: UserInfo,
        message_segment: Seg,
        sender_info: Optional[UserInfo] = None,
        reply: Optional[MaiMessage] = None,
        display_message: str = "",
        is_head: bool = False,
        is_emoji: bool = False,
        thinking_start_time: float = 0,
        reply_to: Optional[str] = None,
        selected_expressions: Optional[List[int]] = None,
    ):
        # Initialize the MaiMessage base class
        super().__init__(message_id=message_id, timestamp=datetime.now())
        # Send-specific fields
        self.session = session
        self.sender_info = sender_info
        self.message_segment = message_segment
        self.reply = reply
        self.is_head = is_head
        self.thinking_start_time = thinking_start_time
        self.selected_expressions = selected_expressions
        self.reply_to_message_id: Optional[str] = reply.message_id if reply else None
        self.interest_value: float = 0.0
        # Populate the standard MaiMessage fields
        self.platform = session.platform
        self.session_id = session.session_id
        self.is_emoji = is_emoji
        self.reply_to = reply_to
        self.display_message = display_message
        self.processed_plain_text = ""
        # Build message_info (for DB storage, user_info is always the bot's info;
        # the private-chat/group-chat user_info difference is handled only in the
        # to_maim_message() override)
        group_info = self._resolve_group_info()
        self.message_info = MessageInfo(user_info=bot_user_info, group_info=group_info)
        # Keep bot_user_info separately — the to_maim_message() override still needs it
        self.bot_user_info = bot_user_info
        # Convert the Seg into a MessageSequence so the base class's
        # to_db_instance / to_maim_message can consume it
        self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
            MessageBase(message_info=None, message_segment=message_segment)
        )
        self.initialized = True

    def _resolve_group_info(self) -> Optional[GroupInfo]:
        """Resolve group information from the session; returns None for private chats."""
        if not self.session.group_id:
            return None
        group_name = ""
        # Best-effort: recover the group name from the session's cached inbound message,
        # if one is available
        if (
            self.session.context
            and self.session.context.message
            and self.session.context.message.message_info.group_info
        ):
            group_name = self.session.context.message.message_info.group_info.group_name
        return GroupInfo(group_id=self.session.group_id, group_name=group_name)

    async def process(self) -> None:
        """Process the message segments and generate processed_plain_text
        (using SessionMessage's component-handling capabilities)."""
        # Re-sync message_segment -> raw_message (a plugin may have modified message_segment)
        self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
            MessageBase(message_info=None, message_segment=self.message_segment)
        )
        if self.raw_message and self.raw_message.components:
            tasks = [self._process_component(c) for c in self.raw_message.components]
            # return_exceptions=True so one failing component does not abort the others
            results = await asyncio.gather(*tasks, return_exceptions=True)
            texts = []
            for r in results:
                if isinstance(r, BaseException):
                    logger.error(f"处理发送消息组件时发生错误: {r}")
                elif r:
                    texts.append(r)
            self.processed_plain_text = " ".join(texts)

    async def _process_component(self, component: StandardMessageComponents) -> str:
        """Render a single standard component as a simple plain-text description."""
        if isinstance(component, TextComponent):
            return component.text
        elif isinstance(component, ImageComponent):
            return "[图片]"
        elif isinstance(component, EmojiComponent):
            return "[表情包]"
        elif isinstance(component, VoiceComponent):
            return "[语音]"
        elif isinstance(component, AtComponent):
            return f"[@{component.target_user_id}]"
        elif isinstance(component, ReplyComponent):
            # Reply markers contribute no plain text of their own
            return ""
        else:
            # Unknown component types fall back to their class name
            return f"[{type(component).__name__}]"

    def build_reply(self) -> None:
        """Build the reply message segment by prepending a reply Seg to message_segment."""
        if self.reply:
            self.reply_to_message_id = self.reply.message_id
            self.message_segment = Seg(
                type="seglist",
                data=[
                    Seg(type="reply", data=self.reply.message_id),
                    self.message_segment,
                ],
            )
            # Keep raw_message in sync with the updated message_segment
            self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
                MessageBase(message_info=None, message_segment=self.message_segment)
            )

    async def to_maim_message(self) -> MessageBase:
        """Override of the base-class method: outbound messages need special
        user_info handling (private-chat vs group-chat difference)."""
        maim_bot_user_info = MaimUserInfo(
            user_id=self.bot_user_info.user_id,
            user_nickname=self.bot_user_info.user_nickname,
            user_cardname=self.bot_user_info.user_cardname,
            platform=self.platform,
        )
        maim_group_info = None
        if self.message_info.group_info:
            maim_group_info = MaimGroupInfo(
                group_id=self.message_info.group_info.group_id,
                group_name=self.message_info.group_info.group_name,
                platform=self.platform,
            )
        # In private chats, user_info carries the recipient (sender_info);
        # in group chats it carries the bot's own info
        if maim_group_info is None and self.sender_info:
            msg_user_info = MaimUserInfo(
                user_id=self.sender_info.user_id,
                user_nickname=self.sender_info.user_nickname,
                user_cardname=self.sender_info.user_cardname,
                platform=self.platform,
            )
        else:
            msg_user_info = maim_bot_user_info
        maim_msg_info = MaimBaseMessageInfo(
            platform=self.platform,
            message_id=self.message_id,
            time=time.time(),
            group_info=maim_group_info,
            user_info=msg_user_info,
        )
        msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
        return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments))

View File

@@ -6,8 +6,8 @@ from maim_message import Seg
from src.common.message_server.api import get_global_api
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
@@ -130,8 +130,8 @@ def parse_message_segments(segment) -> list:
async def _send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text, max_length=200)
platform = message.message_info.platform
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
platform = message.platform
group_id = message.session.group_id
try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
@@ -221,33 +221,14 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
raise legacy_exception
return False
# 使用 MessageConverter 转换 Legacy MessageBase 到 APIMessageBase
# 发送场景MaiMBot 发送回复消息给外部用户
# group_info/user_info 是消息接收者信息,放入 receiver_info
# 使用 MessageConverter 转换为 API 消息
from maim_message import MessageConverter
# 修复 API Server Fallback 模式下的 user_info 问题
# 在 Legacy 模式下MessageSending.to_dict() 的第 454 行会将 user_info 替换为 chat_stream.user_info
# 但在 API Server Fallback 模式下MessageConverter.to_api_send() 直接访问 message 对象,不调用 to_dict()
# 需要手动应用相同的变通方案在私聊场景下user_info 应该是接收者sender_info
message_for_conversion = message
if hasattr(message, "message_info") and message.message_info.group_info is None:
# 私聊场景group_info 为 None
# user_info 应该是接收者,从 chat_stream.user_info 或 sender_info 获取
temp_dict = message.to_dict()
if (
hasattr(message, "chat_stream")
and message.chat_stream
and hasattr(message.chat_stream, "user_info")
):
temp_dict["message_info"]["user_info"] = message.chat_stream.user_info.to_dict()
# 重新构建 MessageBase 对象(不保留 sender_info 等扩展属性)
from maim_message import MessageBase
message_for_conversion = MessageBase.from_dict(temp_dict)
# 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
message_base = await message.to_maim_message()
api_message = MessageConverter.to_api_send(
message=message_for_conversion,
message=message_base,
api_key=target_api_key,
platform=platform,
)
@@ -278,10 +259,11 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
return False
try:
send_result = await get_global_api().send_message(message)
message_base = await message.to_maim_message()
send_result = await get_global_api().send_message(message_base)
if send_result:
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
return True
else:
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback
@@ -297,7 +279,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
return await send_with_new_api(legacy_exception=legacy_e)
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}")
traceback.print_exc()
raise e # 重新抛出其他异常
@@ -306,7 +288,7 @@ class UniversalMessageSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
def __init__(self):
self.storage = MessageStorage()
pass
async def send_message(
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
@@ -321,15 +303,15 @@ class UniversalMessageSender:
用法:
- typing=True 时,发送前会有打字等待。
"""
if not message.chat_stream:
logger.error("消息缺少 chat_stream,无法发送")
raise ValueError("消息缺少 chat_stream,无法发送")
if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送")
raise ValueError("消息缺少 message_info 或 message_id无法发送")
if not message.session:
logger.error("消息缺少 session,无法发送")
raise ValueError("消息缺少 session,无法发送")
if not message.message_id:
logger.error("消息缺少 message_id无法发送")
raise ValueError("消息缺少 message_id无法发送")
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id
chat_id = message.session_id
message_id = message.message_id
try:
if set_reply:
@@ -391,7 +373,8 @@ class UniversalMessageSender:
message.processed_plain_text = modified_message.plain_text
if storage_message:
await self.storage.store_message(message, message.chat_stream)
with get_db_session() as db_session:
db_session.add(message.to_db_instance())
return sent_msg

View File

@@ -1,6 +1,6 @@
from typing import Dict, Optional, Type
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.plugin_system.core.component_registry import component_registry
@@ -35,7 +35,7 @@ class ActionManager:
action_reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream,
chat_stream: BotChatSession,
log_prefix: str,
shutting_down: bool = False,
action_message: Optional[DatabaseMessages] = None,

View File

@@ -4,15 +4,12 @@ from typing import List, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("action_manager")
@@ -27,8 +24,8 @@ class ActionModifier:
def __init__(self, action_manager: ActionManager, chat_id: str):
"""初始化动作处理器"""
self.chat_id = chat_id
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.chat_id) # type: ignore
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
self.action_manager = action_manager
@@ -121,7 +118,7 @@ class ActionModifier:
available_actions_text = "".join(available_actions) if available_actions else ""
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext):
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: BotChatSession):
type_mismatched_actions: List[Tuple[str, str]] = []
for action_name, action_info in all_actions.items():
if action_info.associated_types and not chat_context.check_types(action_info.associated_types):

View File

@@ -3,6 +3,7 @@ import time
import traceback
import random
import re
import contextlib
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from collections import OrderedDict
from rich.traceback import install
@@ -21,7 +22,7 @@ from src.chat.utils.chat_message_builder import (
)
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.apis.message_api import translate_pid_to_description
@@ -39,7 +40,7 @@ install(extra_lines=3)
class ActionPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
self.planner_llm = LLMRequest(
@@ -80,7 +81,7 @@ class ActionPlanner:
if not text:
return text
id_to_message = {msg_id: msg for msg_id, msg in message_id_list}
id_to_message = dict(message_id_list)
# 匹配m后带2-4位数字前后不是字母数字下划线
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
@@ -223,7 +224,7 @@ class ActionPlanner:
action_data=action_data,
action_message=target_message,
available_actions=available_actions_dict,
action_reasoning=extracted_reasoning if extracted_reasoning else None,
action_reasoning=extracted_reasoning or None,
)
)
@@ -238,7 +239,7 @@ class ActionPlanner:
action_data={},
action_message=None,
available_actions=available_actions_dict,
action_reasoning=extracted_reasoning if extracted_reasoning else None,
action_reasoning=extracted_reasoning or None,
)
)
@@ -292,8 +293,7 @@ class ActionPlanner:
if new_words:
for word in new_words:
if isinstance(word, str):
word = word.strip()
if word:
if word := word.strip():
cleaned_new_words.append(word)
# 获取缓存中的黑话列表
@@ -351,10 +351,9 @@ class ActionPlanner:
break
# 如果当前 plan 的 reply 没有提取移除最老的1个
if not has_extracted_unknown_words:
if len(self.unknown_words_cache) > 0:
self.unknown_words_cache.popitem(last=False)
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")
if not has_extracted_unknown_words and len(self.unknown_words_cache) > 0:
self.unknown_words_cache.popitem(last=False)
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")
# 对于每个 reply action合并缓存和新提取的黑话
for action in actions:
@@ -363,10 +362,7 @@ class ActionPlanner:
new_words = action_data.get("unknown_words")
# 合并新提取的和缓存的黑话列表
merged_words = self._merge_unknown_words_with_cache(new_words)
# 更新 action_data
if merged_words:
if merged_words := self._merge_unknown_words_with_cache(new_words):
action_data["unknown_words"] = merged_words
logger.debug(
f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个,"
@@ -449,15 +445,12 @@ class ActionPlanner:
# 如果有强制回复消息,确保回复该消息
if force_reply_message:
# 检查是否已经有回复该消息的 action
has_reply_to_force_message = False
for action in actions:
if (
action.action_type == "reply"
and action.action_message
and action.action_message.message_id == force_reply_message.message_id
):
has_reply_to_force_message = True
break
has_reply_to_force_message = any(
action.action_type == "reply"
and action.action_message
and action.action_message.message_id == force_reply_message.message_id
for action in actions
)
# 如果没有回复该消息,强制添加回复 action
if not has_reply_to_force_message:
@@ -532,13 +525,10 @@ class ActionPlanner:
# 从后往前遍历,收集最新的记录
for reasoning, timestamp, content in reversed(self.plan_log):
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
# 这是action记录
if len(action_records) < max_action_records:
action_records.append((reasoning, timestamp, content, "action"))
else:
# 这是执行结果记录
if len(execution_records) < max_execution_records:
execution_records.append((reasoning, timestamp, content, "execution"))
elif len(execution_records) < max_execution_records:
execution_records.append((reasoning, timestamp, content, "execution"))
# 合并所有记录并按时间戳排序
all_records = action_records + execution_records
@@ -700,15 +690,9 @@ class ActionPlanner:
param_text = param_text.rstrip("\n")
# 构建要求文本
require_text = ""
for require_item in action_info.action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip("\n")
require_text = "\n".join(f"- {require_item}" for require_item in action_info.action_require)
if not action_info.parallel_action:
parallel_text = "(当选择这个动作时,请不要选择其他动作)"
else:
parallel_text = ""
parallel_text = "" if action_info.parallel_action else "(当选择这个动作时,请不要选择其他动作)"
# 获取动作提示模板并填充
using_action_prompt = prompt_manager.get_prompt("action")
@@ -864,20 +848,15 @@ class ActionPlanner:
# 尝试按行分割每行可能是一个JSON对象
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
for line in lines:
try:
# 尝试解析每一行作为独立的JSON对象
with contextlib.suppress(json.JSONDecodeError):
json_obj = json.loads(repair_json(line))
if isinstance(json_obj, dict):
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
if json_obj:
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict) and item:
json_objects.append(item)
except json.JSONDecodeError:
# 如果单行解析失败尝试将整个块作为一个JSON对象或数组
pass
# 如果按行解析没有成功或只得到空字典尝试将整个块作为一个JSON对象或数组
if not json_objects:

View File

@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
@@ -45,17 +48,17 @@ logger = get_logger("replyer")
class DefaultReplyer:
def __init__(
self,
chat_stream: ChatStream,
chat_stream: BotChatSession,
request_type: str = "replyer",
):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
async def generate_reply_with_context(
self,
@@ -132,7 +135,7 @@ class DefaultReplyer:
if log_reply:
try:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
prompt="",
output=None,
processed_output=None,
@@ -202,7 +205,7 @@ class DefaultReplyer:
try:
if log_reply:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
prompt=prompt,
output=content,
processed_output=None,
@@ -259,7 +262,7 @@ class DefaultReplyer:
if log_reply:
try:
PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id,
chat_id=self.chat_stream.session_id,
prompt=prompt or "",
output=None,
processed_output=None,
@@ -353,14 +356,14 @@ class DefaultReplyer:
str: 表达习惯信息字符串
"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
if not use_expression:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id,
self.chat_stream.session_id,
chat_history,
max_num=8,
target_message=target,
@@ -702,10 +705,11 @@ class DefaultReplyer:
# 判断是否为群聊
is_group = stream_type == "group"
# 使用 ChatManager 提供的接口生成 chat_id避免在此重复实现逻辑
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
chat_id = SessionUtils.calculate_session_id(
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
)
return chat_id, prompt_content
except (ValueError, IndexError):
@@ -778,7 +782,7 @@ class DefaultReplyer:
if available_actions is None:
available_actions = {}
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
chat_id = chat_stream.session_id
_is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform
@@ -1005,7 +1009,7 @@ class DefaultReplyer:
reply_to: str,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
chat_id = chat_stream.session_id
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -1105,29 +1109,27 @@ class DefaultReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: Optional[MessageRecv] = None,
anchor_message: Optional[MaiMessage] = None,
) -> MessageSending:
"""构建单个发送消息"""
bot_user_info = UserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
)
# await anchor_message.process()
sender_info = anchor_message.message_info.user_info if anchor_message else None
return MessageSending(
message_id=message_id, # 使用片段的唯一ID
chat_stream=self.chat_stream,
message_id=message_id,
session=self.chat_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
message_segment=message_segment,
reply=anchor_message, # 回复原始锚点
reply=anchor_message,
is_head=reply_to,
is_emoji=is_emoji,
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
thinking_start_time=thinking_start_time,
display_message=display_message,
)

View File

@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
from src.chat.message_receive.chat_stream import ChatStream
from maim_message import Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
@@ -43,18 +46,18 @@ logger = get_logger("replyer")
class PrivateReplyer:
def __init__(
self,
chat_stream: ChatStream,
chat_stream: BotChatSession,
request_type: str = "replyer",
):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
async def generate_reply_with_context(
self,
@@ -253,14 +256,14 @@ class PrivateReplyer:
str: 表达习惯信息字符串
"""
# 检查是否允许在此聊天流中使用表达
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
if not use_expression:
return "", []
style_habits = []
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
self.chat_stream.session_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
)
if selected_expressions:
@@ -550,10 +553,11 @@ class PrivateReplyer:
# 判断是否为群聊
is_group = stream_type == "group"
# 使用 ChatManager 提供的接口生成 chat_id避免在此重复实现逻辑
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
chat_id = SessionUtils.calculate_session_id(
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
)
return chat_id, prompt_content
except (ValueError, IndexError):
@@ -624,7 +628,7 @@ class PrivateReplyer:
if available_actions is None:
available_actions = {}
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
chat_id = chat_stream.session_id
platform = chat_stream.platform
user_id = "用户ID"
@@ -843,7 +847,7 @@ class PrivateReplyer:
reply_to: str,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
chat_id = chat_stream.session_id
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -948,29 +952,27 @@ class PrivateReplyer:
is_emoji: bool,
thinking_start_time: float,
display_message: str,
anchor_message: Optional[MessageRecv] = None,
anchor_message: Optional[MaiMessage] = None,
) -> MessageSending:
"""构建单个发送消息"""
bot_user_info = UserInfo(
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
)
# await anchor_message.process()
sender_info = anchor_message.message_info.user_info if anchor_message else None
return MessageSending(
message_id=message_id, # 使用片段的唯一ID
chat_stream=self.chat_stream,
message_id=message_id,
session=self.chat_stream,
bot_user_info=bot_user_info,
sender_info=sender_info,
message_segment=message_segment,
reply=anchor_message, # 回复原始锚点
reply=anchor_message,
is_head=reply_to,
is_emoji=is_emoji,
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
thinking_start_time=thinking_start_time,
display_message=display_message,
)

View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
@@ -14,7 +14,7 @@ class ReplyerManager:
def get_replyer(
self,
chat_stream: Optional[ChatStream] = None,
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]:
@@ -24,7 +24,7 @@ class ReplyerManager:
model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。
后续调用将返回已缓存的实例,忽略 model_configs 参数。
"""
stream_id = chat_stream.stream_id if chat_stream else chat_id
stream_id = chat_stream.session_id if chat_stream else chat_id
if not stream_id:
logger.warning("[ReplyerManager] 缺少 stream_id无法获取回复器。")
return None
@@ -39,15 +39,14 @@ class ReplyerManager:
target_stream = chat_stream
if not target_stream:
if chat_manager := get_chat_manager():
target_stream = chat_manager.get_stream(stream_id)
target_stream = _chat_manager.get_session_by_session_id(stream_id)
if not target_stream:
logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。")
return None
# model_configs 只在此时(初始化时)生效
if target_stream.group_info:
if target_stream.is_group_session:
replyer = DefaultReplyer(
chat_stream=target_stream,
request_type=request_type,

View File

@@ -61,9 +61,12 @@ class TempMethodsExpression:
str: 生成的聊天流ID哈希值
"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
if is_group:
return SessionUtils.calculate_session_id(platform, group_id=str(id_str))
else:
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
except Exception as e:
logger.error(f"生成聊天流ID失败: {e}")
return None

View File

@@ -1051,18 +1051,13 @@ class StatisticOutputTask(AsyncTask):
"""从chat_id获取显示名称"""
try:
# 首先尝试从chat_stream获取真实群组名称
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager
chat_manager = get_chat_manager()
if chat_id in chat_manager.streams:
stream = chat_manager.streams[chat_id]
if stream.group_info and hasattr(stream.group_info, "group_name"):
group_name = stream.group_info.group_name
if group_name and group_name.strip():
return group_name.strip()
elif stream.user_info and hasattr(stream.user_info, "user_nickname"):
user_name = stream.user_info.user_nickname
if chat_id in _stat_chat_manager.sessions:
session = _stat_chat_manager.sessions[chat_id]
name = _stat_chat_manager.get_session_name(chat_id)
if name and name.strip():
return name.strip()
if user_name and user_name.strip():
return user_name.strip()

View File

@@ -12,8 +12,8 @@ from typing import Optional, Tuple, List, TYPE_CHECKING
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import Person
from .typo_generator import ChineseTypoGenerator
@@ -114,10 +114,10 @@ def is_bot_self(platform: str, user_id: str) -> bool:
return user_id_str == qq_account
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]:
def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, float]:
"""检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or ""
platform = message.platform or ""
# 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or []
@@ -696,15 +696,23 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
chat_target_info = None
try:
if chat_stream := get_chat_manager().get_stream(chat_id):
if chat_stream.group_info:
if chat_stream := _chat_manager.get_session_by_session_id(chat_id):
if chat_stream.is_group_session:
is_group_chat = True
chat_target_info = None # Explicitly None for group chat
elif chat_stream.user_info: # It's a private chat
elif chat_stream.user_id: # It's a private chat
is_group_chat = False
user_info = chat_stream.user_info
platform: str = chat_stream.platform
user_id: str = user_info.user_id # type: ignore
user_id: str = chat_stream.user_id
# Try to get nickname from context
user_nickname = None
if (
chat_stream.context
and chat_stream.context.message
and chat_stream.context.message.message_info.user_info
):
user_nickname = chat_stream.context.message.message_info.user_info.user_nickname
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
@@ -712,7 +720,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
target_info = TargetPersonInfo(
platform=platform,
user_id=user_id,
user_nickname=user_info.user_nickname, # type: ignore
user_nickname=user_nickname, # type: ignore
person_id=None,
person_name=None,
)
@@ -721,7 +729,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
try:
person = Person(platform=platform, user_id=user_id)
if not person.is_known:
logger.warning(f"用户 {user_info.user_nickname} 尚未认识")
logger.warning(f"用户 {user_nickname} 尚未认识")
# 如果用户尚未认识则返回False和None
return False, None
if person.person_id:

View File

@@ -1,4 +1,6 @@
from pathlib import Path
from PIL import Image as PILImage, ImageSequence
from typing import Optional, Union
import base64
import io
@@ -102,3 +104,30 @@ class ImageUtils:
logger.error("输入的图片字节数据无效")
raise ValueError("输入的图片字节数据无效")
return base64.b64encode(image_bytes).decode("utf-8")
@staticmethod
def image_path_to_base64(image_path: Union[str, Path]) -> Optional[str]:
"""读取图片文件并转换为 Base64 编码字符串"""
try:
path = Path(image_path)
if not path.exists():
logger.error(f"图片文件不存在: {path}")
return None
image_bytes = path.read_bytes()
return base64.b64encode(image_bytes).decode("utf-8")
except Exception as e:
logger.error(f"读取图片文件失败: {e}")
return None
@staticmethod
def base64_to_image(base64_str: str, save_path: Union[str, Path]) -> bool:
"""将 Base64 编码字符串解码并保存为图片文件"""
try:
image_bytes = base64.b64decode(base64_str)
path = Path(save_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(image_bytes)
return True
except Exception as e:
logger.error(f"保存图片文件失败: {e}")
return False

View File

@@ -7,7 +7,7 @@ from src.config.config import global_config, model_config
from src.llm_models.payload_content.message import RoleType, Message
from src.prompt.prompt_manager import prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
from src.plugin_system.apis import send_api
logger = get_logger("dream_generator")
@@ -178,10 +178,9 @@ async def generate_dream_summary(
logger.warning(f"[dream][梦境总结] dream_send 平台或用户ID为空当前值: {dream_send_raw!r}")
else:
# 默认为私聊会话
stream_id = get_chat_manager().get_stream_id(
stream_id = SessionUtils.calculate_session_id(
platform=platform,
id=str(user_id),
is_group=False,
user_id=str(user_id),
)
if not stream_id:
logger.error(

View File

@@ -8,7 +8,7 @@ from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
# from src.chat.utils.token_statistics import TokenStatisticsTask
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager
from src.config.config import config_manager, global_config
from src.chat.message_receive.bot import chat_bot
from src.common.logger import get_logger
@@ -119,8 +119,8 @@ class MainSystem:
logger.info("表情包管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())
await chat_manager.initialize()
asyncio.create_task(chat_manager.regularly_save_sessions())
logger.info("聊天管理器初始化成功")

View File

@@ -21,7 +21,7 @@ from src.plugin_system.apis import message_api
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.utils import is_bot_self
from src.person_info.person_info import Person
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.prompt.prompt_manager import prompt_manager
logger = get_logger("chat_history_summarizer")
@@ -100,7 +100,7 @@ class ChatHistorySummarizer:
def _get_chat_display_name(self) -> str:
"""获取聊天显示名称"""
try:
chat_name = get_chat_manager().get_stream_name(self.chat_id)
chat_name = _chat_manager.get_session_name(self.chat_id)
if chat_name:
return chat_name
# 如果获取失败使用简化的chat_id显示

View File

@@ -1,3 +1,4 @@
import contextlib
import time
import json
import asyncio
@@ -12,7 +13,7 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import ThinkingQuestion
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval")
@@ -133,10 +134,10 @@ async def _react_agent_solve_question(
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
"""
start_time = time.time()
collected_info = initial_info if initial_info else ""
collected_info = initial_info or ""
# 构造日志前缀:[聊天流名称],用于在日志中标识聊天流
try:
chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
chat_name = _chat_manager.get_session_name(chat_id) or chat_id
except Exception:
chat_name = chat_id
react_log_prefix = f"[{chat_name}] "
@@ -235,7 +236,7 @@ async def _react_agent_solve_question(
# head_prompt应该只构建一次使用初始的collected_info后续迭代都复用同一个
if first_head_prompt is None:
# 第一次构建使用初始的collected_info即initial_info
initial_collected_info = initial_info if initial_info else ""
initial_collected_info = initial_info or ""
# 根据配置选择使用哪个 prompt
prompt_name = (
"memory_retrieval_react_prompt_head_lpmm"
@@ -362,7 +363,7 @@ async def _react_agent_solve_question(
return information
except (json.JSONDecodeError, ValueError, TypeError):
# 如果JSON解析失败尝试在文本中查找JSON对象
try:
with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError):
# 查找第一个 { 和最后一个 } 之间的内容更健壮的JSON提取
first_brace = text.find("{")
if first_brace != -1:
@@ -384,8 +385,6 @@ async def _react_agent_solve_question(
if isinstance(data, dict) and "return_information" in data:
information = data.get("information", "")
return information
except (json.JSONDecodeError, ValueError, TypeError):
pass
return None
@@ -679,7 +678,7 @@ async def _react_agent_solve_question(
evaluation_prompt_template.add_context("bot_name", bot_name)
evaluation_prompt_template.add_context("time_now", time_now)
evaluation_prompt_template.add_context("chat_history", chat_history)
evaluation_prompt_template.add_context("collected_info", collected_info if collected_info else "暂无信息")
evaluation_prompt_template.add_context("collected_info", collected_info or "暂无信息")
evaluation_prompt_template.add_context("current_iteration", str(current_iteration))
evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations))
evaluation_prompt_template.add_context("max_iterations", str(max_iterations))
@@ -800,8 +799,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
if not records:
return ""
history_lines = []
history_lines.append("最近已查询的问题和结果:")
history_lines = ["最近已查询的问题和结果:"]
for record in records:
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
@@ -813,8 +811,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
if len(record.answer) > 100:
answer_preview += "..."
history_lines.append(f"- 问题:{record.question}")
history_lines.append(f" 状态:{status}")
history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"])
if answer_preview:
history_lines.append(f" 答案:{answer_preview}")
history_lines.append("") # 空行分隔
@@ -855,12 +852,11 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
if not records:
return []
found_answers = []
for record in records:
if record.answer:
found_answers.append(f"问题:{record.question}\n答案:{record.answer}")
return found_answers
return [
f"问题:{record.question}\n答案:{record.answer}"
for record in records
if record.answer
]
except Exception as e:
logger.error(f"获取最近已找到答案的记录失败: {e}")
@@ -892,8 +888,7 @@ def _store_thinking_back(
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(1)
)
record = session.exec(statement).first()
if record:
if record := session.exec(statement).first():
record.context = context
record.found_answer = found_answer
record.answer = answer
@@ -957,10 +952,7 @@ async def _process_memory_retrieval(
if is_timeout:
logger.info("ReAct Agent超时不返回结果")
if found_answer and answer:
return answer
return None
return answer if found_answer and answer else None
async def build_memory_retrieval_prompt(
@@ -1013,8 +1005,7 @@ async def build_memory_retrieval_prompt(
cleaned_concepts = []
for word in unknown_words:
if isinstance(word, str):
cleaned = word.strip()
if cleaned:
if cleaned := word.strip():
cleaned_concepts.append(cleaned)
if cleaned_concepts:
# 对匹配到的概念进行jargon检索作为初始信息

View File

@@ -30,9 +30,8 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
return chat_ids
try:
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.utils.utils_session import SessionUtils
chat_manager = get_chat_manager()
for blacklist_item in blacklist:
if not isinstance(blacklist_item, str):
continue
@@ -51,7 +50,10 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
is_group = stream_type == "group"
# 转换为chat_id
chat_id = chat_manager.get_stream_id(platform, str(id_str), is_group=is_group)
if is_group:
chat_id = SessionUtils.calculate_session_id(platform, group_id=str(id_str))
else:
chat_id = SessionUtils.calculate_session_id(platform, user_id=str(id_str))
if chat_id:
chat_ids.add(chat_id)
else:
@@ -225,9 +227,9 @@ async def search_chat_history(
if keyword:
keyword_matched = False
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
keywords_list = parse_keywords_string(keyword) or (
[keyword.strip()] if keyword.strip() else []
)
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]

View File

@@ -16,7 +16,7 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
logger = get_logger("person_info")
@@ -818,22 +818,22 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
chat_id: 聊天ID
"""
try:
# 从chat_id获取chat_stream
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
logger.warning(f"无法获取chat_stream for chat_id: {chat_id}")
# 从 chat_id 获取 session
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
logger.warning(f"无法获取session for chat_id: {chat_id}")
return
platform = chat_stream.platform
platform = session.platform
# 尝试从person_name查找person_id
# 首先尝试通过person_name查找
person_id = get_person_id_by_person_name(person_name)
if not person_id:
# 如果通过person_name找不到尝试从chat_stream获取user_info
if platform and chat_stream.user_info and chat_stream.user_info.user_id:
user_id = chat_stream.user_info.user_id
# 如果通过person_name找不到尝试从 session 获取 user_id
if platform and session.user_id:
user_id = session.user_id
person_id = get_person_id(platform, user_id)
else:
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")

View File

@@ -16,7 +16,7 @@ from typing import List, Dict, Any, Optional
from enum import Enum
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
logger = get_logger("chat_api")
@@ -31,7 +31,7 @@ class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有聊天流
@@ -39,7 +39,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 聊天流列表
List[BotChatSession]: 聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
@@ -48,7 +48,7 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in get_chat_manager().streams.items():
for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
@@ -57,7 +57,7 @@ class ChatManager:
return streams
@staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有群聊聊天流
@@ -65,14 +65,14 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 群聊聊天流列表
List[BotChatSession]: 群聊聊天流列表
"""
if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
@@ -80,7 +80,7 @@ class ChatManager:
return streams
@staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend
"""获取所有私聊聊天流
@@ -88,7 +88,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 私聊聊天流列表
List[BotChatSession]: 私聊聊天流列表
Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
@@ -97,8 +97,10 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
for _, stream in _chat_manager.sessions.items():
if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
@@ -108,7 +110,7 @@ class ChatManager:
@staticmethod
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流
Args:
@@ -116,7 +118,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 group_id 为空字符串
@@ -129,11 +131,11 @@ class ChatManager:
if not group_id:
raise ValueError("group_id 不能为空")
try:
for _, stream in get_chat_manager().streams.items():
for _, stream in _chat_manager.sessions.items():
if (
stream.group_info
and str(stream.group_info.group_id) == str(group_id)
and stream.platform == platform
stream.is_group_session
and str(stream.group_id) == str(group_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
return stream
@@ -145,7 +147,7 @@ class ChatManager:
@staticmethod
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流
Args:
@@ -153,7 +155,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises:
ValueError: 如果 user_id 为空字符串
@@ -166,11 +168,11 @@ class ChatManager:
if not user_id:
raise ValueError("user_id 不能为空")
try:
for _, stream in get_chat_manager().streams.items():
for _, stream in _chat_manager.sessions.items():
if (
not stream.group_info
and str(stream.user_info.user_id) == str(user_id)
and stream.platform == platform
not stream.is_group_session
and str(stream.user_id) == str(user_id)
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
):
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
return stream
@@ -180,7 +182,7 @@ class ChatManager:
return None
@staticmethod
def get_stream_type(chat_stream: ChatStream) -> str:
def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型
Args:
@@ -190,20 +192,18 @@ class ChatManager:
str: 聊天类型 ("group", "private", "unknown")
Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型
TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空
"""
if not isinstance(chat_stream, ChatStream):
raise TypeError("chat_stream 必须是 ChatStream 类型")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
return "group" if chat_stream.is_group_session else "private"
@staticmethod
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流详细信息
Args:
@@ -213,36 +213,34 @@ class ChatManager:
Dict ({str: Any}): 聊天流信息字典
Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型
TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空
"""
if not chat_stream:
raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, ChatStream):
raise TypeError("chat_stream 必须是 ChatStream 类型")
if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 BotChatSession 类型")
try:
info: Dict[str, Any] = {
"stream_id": chat_stream.stream_id,
"session_id": chat_stream.session_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
}
if chat_stream.group_info:
info.update(
{
"group_id": chat_stream.group_info.group_id,
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
}
)
if chat_stream.user_info:
info.update(
{
"user_id": chat_stream.user_info.user_id,
"user_name": chat_stream.user_info.user_nickname,
}
)
if chat_stream.is_group_session:
info["group_id"] = chat_stream.group_id
# Try to get group name from context
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
else:
info["group_name"] = "未知群聊"
else:
info["user_id"] = chat_stream.user_id
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
else:
info["user_name"] = "未知用户"
return info
except Exception as e:
@@ -285,37 +283,37 @@ class ChatManager:
# =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream) -> str:
def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)

View File

@@ -16,7 +16,7 @@ import uuid
from typing import Optional, Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.chat.utils.utils_image import image_path_to_base64, base64_to_image
from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config
logger = get_logger("emoji_api")
@@ -56,7 +56,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = image_path_to_base64(emoji_path)
emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
@@ -115,7 +115,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
results = []
for selected_emoji in selected_emojis:
emoji_base64 = image_path_to_base64(str(selected_emoji.full_path))
emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
@@ -174,7 +174,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
# 随机选择匹配的表情包
selected_emoji = random.choice(matching_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
@@ -263,7 +263,7 @@ async def get_all() -> List[Tuple[str, str, str]]:
if emoji_obj.is_deleted:
continue
emoji_base64 = image_path_to_base64(str(emoji_obj.full_path))
emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
@@ -429,7 +429,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
try:
# 解码base64并保存图片
if not base64_to_image(image_base64, temp_file_path):
if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}")
return {
"success": False,

View File

@@ -16,7 +16,7 @@ from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplySetModel
from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo
@@ -38,7 +38,7 @@ logger = get_logger("generator_api")
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]:
@@ -79,7 +79,7 @@ def get_replyer(
async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
@@ -161,7 +161,7 @@ async def generate_reply(
unknown_words=unknown_words,
think_level=think_level,
from_plugin=from_plugin,
stream_id=chat_stream.stream_id if chat_stream else chat_id,
stream_id=chat_stream.session_id if chat_stream else chat_id,
reply_time_point=reply_time_point,
log_reply=False,
)
@@ -181,7 +181,7 @@ async def generate_reply(
# 统一在这里记录最终回复日志(包含分割后的 processed_output
try:
PlanReplyLogger.log_reply(
chat_id=chat_stream.stream_id if chat_stream else (chat_id or ""),
chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
prompt=llm_response.prompt or "",
output=llm_response.content,
processed_output=llm_response.processed_output,
@@ -210,7 +210,7 @@ async def generate_reply(
async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
chat_stream: Optional[BotChatSession] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
enable_splitter: bool = True,
@@ -302,7 +302,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None,
request_type: str = "generator_api",
prompt: str = "",

View File

@@ -26,10 +26,12 @@ from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
from maim_message import Seg
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import MessageSending
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
@@ -77,7 +79,7 @@ async def _send_to_target(
logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}")
# 查找目标聊天流
target_stream = get_chat_manager().get_stream(stream_id)
target_stream = _chat_manager.get_session_by_session_id(stream_id)
if not target_stream:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
return False
@@ -93,27 +95,29 @@ async def _send_to_target(
bot_user_info = UserInfo(
user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname,
platform=target_stream.platform,
)
reply_to_platform_id = ""
anchor_message: Union["MessageRecv", None] = None
anchor_message: Optional[MaiMessage] = None
if reply_message:
anchor_message = db_message_to_message_recv(reply_message)
logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") # type: ignore
anchor_message = db_message_to_mai_message(reply_message)
if anchor_message:
anchor_message.update_chat_stream(target_stream)
assert anchor_message.message_info.user_info, "用户信息缺失"
logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
)
# 构建 sender_info私聊时为接收者信息
sender_info = None
if target_stream.context and target_stream.context.message:
sender_info = target_stream.context.message.message_info.user_info
# 构建发送消息对象
bot_message = MessageSending(
message_id=message_id,
chat_stream=target_stream,
session=target_stream,
bot_user_info=bot_user_info,
sender_info=target_stream.user_info,
sender_info=sender_info,
message_segment=message_segment,
display_message=display_message,
reply=anchor_message,
@@ -146,51 +150,43 @@ async def _send_to_target(
return False
def db_message_to_message_recv(message_obj: "DatabaseMessages") -> MessageRecv:
"""将数据库dict重建为MessageRecv对象
def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
"""将数据库消息重建为 MaiMessage 对象,用于回复引用。
Args:
message_dict: 消息字典
message_obj: 插件系统的 DatabaseMessages 数据对象
Returns:
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
Optional[MaiMessage]: 构建的消息对象,如果信息不足则返回 None
"""
# 构建MessageRecv对象
user_info = {
"platform": message_obj.user_info.platform or "",
"user_id": message_obj.user_info.user_id or "",
"user_nickname": message_obj.user_info.user_nickname or "",
"user_cardname": message_obj.user_info.user_cardname or "",
}
from datetime import datetime
from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
group_info = {}
user_info = UserInfo(
user_id=message_obj.user_info.user_id or "",
user_nickname=message_obj.user_info.user_nickname or "",
user_cardname=message_obj.user_info.user_cardname,
)
group_info = None
if message_obj.chat_info.group_info:
group_info = {
"platform": message_obj.chat_info.group_info.group_platform or "",
"group_id": message_obj.chat_info.group_info.group_id or "",
"group_name": message_obj.chat_info.group_info.group_name or "",
}
group_info = GroupInfo(
group_id=message_obj.chat_info.group_info.group_id or "",
group_name=message_obj.chat_info.group_info.group_name or "",
)
format_info = {"content_format": "", "accept_format": ""}
template_info = {"template_items": {}}
message_info = {
"platform": message_obj.chat_info.platform or "",
"message_id": message_obj.message_id,
"time": message_obj.time,
"group_info": group_info,
"user_info": user_info,
"additional_config": message_obj.additional_config,
"format_info": format_info,
"template_info": template_info,
}
message_dict_recv = {
"message_info": message_info,
"raw_message": message_obj.processed_plain_text,
"processed_plain_text": message_obj.processed_plain_text,
}
return MessageRecv(message_dict_recv)
msg = MaiMessage(
message_id=message_obj.message_id,
timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(),
)
msg.message_info = MessageInfo(user_info=user_info, group_info=group_info)
msg.platform = message_obj.chat_info.platform or ""
msg.session_id = message_obj.chat_info.stream_id or ""
msg.processed_plain_text = message_obj.processed_plain_text
msg.raw_message = MessageSequence(components=[])
msg.initialized = True
return msg
# =============================================================================

View File

@@ -5,12 +5,12 @@ from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]:
def get_tool_instance(tool_name: str, chat_stream: Optional["BotChatSession"] = None) -> Optional[BaseTool]:
"""获取公开工具实例
Args:

View File

@@ -6,7 +6,7 @@ from typing import Tuple, Optional, TYPE_CHECKING, Dict, List
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api
@@ -36,7 +36,7 @@ class BaseAction(ABC):
action_reasoning: str,
cycle_timers: dict,
thinking_id: str,
chat_stream: ChatStream,
chat_stream: BotChatSession,
plugin_config: Optional[dict] = None,
action_message: Optional["DatabaseMessages"] = None,
**kwargs,
@@ -92,7 +92,7 @@ class BaseAction(ABC):
# 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id
self.chat_id = self.chat_stream.session_id
self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解)

View File

@@ -3,7 +3,7 @@ from typing import Dict, Tuple, Optional, TYPE_CHECKING, List
from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.message import SessionMessage
from src.plugin_system.apis import send_api
if TYPE_CHECKING:
@@ -31,7 +31,7 @@ class BaseCommand(ABC):
command_pattern: str = r""
"""命令匹配的正则表达式"""
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
def __init__(self, message: SessionMessage, plugin_config: Optional[dict] = None):
"""初始化Command组件
Args:
@@ -107,14 +107,14 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.text_to_stream(
text=content,
stream_id=chat_stream.stream_id,
stream_id=session_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
@@ -135,14 +135,14 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.image_to_stream(
image_base64,
chat_stream.stream_id,
session_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
@@ -166,13 +166,13 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.emoji_to_stream(
emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message
emoji_base64, session_id, set_reply=set_reply, reply_message=reply_message
)
async def send_command(
@@ -195,9 +195,9 @@ class BaseCommand(ABC):
"""
try:
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
# 构造命令数据
@@ -205,7 +205,7 @@ class BaseCommand(ABC):
success = await send_api.command_to_stream(
command=command_data,
stream_id=chat_stream.stream_id,
stream_id=session_id,
storage_message=storage_message,
display_message=display_message,
)
@@ -229,15 +229,15 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.custom_to_stream(
message_type="voice",
content=voice_base64,
stream_id=chat_stream.stream_id,
stream_id=session_id,
typing=False,
set_reply=False,
reply_message=None,
@@ -262,15 +262,15 @@ class BaseCommand(ABC):
reply_message: 回复的消息对象
storage_message: 是否存储消息到数据库
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=chat_stream.stream_id,
stream_id=session_id,
typing=typing,
set_reply=set_reply,
reply_message=reply_message,
@@ -293,9 +293,9 @@ class BaseCommand(ABC):
Returns:
bool: 是否发送成功
"""
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = []
@@ -318,7 +318,7 @@ class BaseCommand(ABC):
reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream(
reply_set=reply_set,
stream_id=chat_stream.stream_id,
stream_id=session_id,
storage_message=storage_message,
set_reply=False,
reply_message=None,
@@ -349,15 +349,15 @@ class BaseCommand(ABC):
bool: 是否发送成功
"""
# 获取聊天流信息
chat_stream = self.message.chat_stream
if not chat_stream or not hasattr(chat_stream, "stream_id"):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
session_id = self.message.session_id
if not session_id:
logger.error(f"{self.log_prefix} 缺少session_id")
return False
return await send_api.custom_to_stream(
message_type=message_type,
content=content,
stream_id=chat_stream.stream_id,
stream_id=session_id,
display_message=display_message,
typing=typing,
set_reply=set_reply,

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.chat_manager import BotChatSession
install(extra_lines=3)
@@ -32,7 +32,7 @@ class BaseTool(ABC):
available_for_llm: bool = False
"""是否可供LLM使用"""
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None):
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["BotChatSession"] = None):
"""初始化工具基类
Args:
@@ -47,7 +47,7 @@ class BaseTool(ABC):
# 获取聊天流对象
self.chat_stream = chat_stream
self.chat_id = self.chat_stream.stream_id if self.chat_stream else None
self.chat_id = self.chat_stream.session_id if self.chat_stream else None
self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None
@classmethod

View File

@@ -2,8 +2,8 @@ import asyncio
import contextlib
from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv, MessageSending
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageSending, SessionMessage
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
from src.plugin_system.base.base_events_handler import BaseEventHandler
@@ -72,7 +72,7 @@ class EventsManager:
async def handle_mai_events(
self,
event_type: EventType | str,
message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
@@ -87,7 +87,7 @@ class EventsManager:
# 1. 准备消息
transformed_message = self._prepare_message(
event_type, message, llm_prompt, llm_response, stream_id, action_usage
event_type, message, llm_prompt, llm_response, stream_id, action_usage # type: ignore[arg-type]
)
if transformed_message:
transformed_message = transformed_message.deepcopy()
@@ -134,7 +134,7 @@ class EventsManager:
async def handle_workflow_message(
self,
message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None,
@@ -248,11 +248,13 @@ class EventsManager:
def _transform_event_message(
self,
message: MessageRecv | MessageSending,
message: SessionMessage | MessageSending,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages:
"""转换事件消息格式"""
from maim_message import Seg
# 直接赋值部分内容
transformed_message = MaiMessages(
llm_prompt=llm_prompt,
@@ -260,45 +262,62 @@ class EventsManager:
llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None,
raw_message=message.raw_message,
additional_data=message.message_info.additional_config or {},
raw_message=message.processed_plain_text or "",
additional_data={},
)
# 消息段处理
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
if isinstance(message, MessageSending):
if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
else:
transformed_message.message_segments = [message.message_segment]
# SessionMessage: 使用 processed_plain_text 构造简单段
transformed_message.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
# stream_id 处理
if hasattr(message, "chat_stream") and message.chat_stream:
transformed_message.stream_id = message.chat_stream.stream_id
transformed_message.stream_id = message.session_id if hasattr(message, "session_id") else ""
# 处理后文本
transformed_message.plain_text = message.processed_plain_text
# 基本信息
if hasattr(message, "message_info") and message.message_info:
if message.message_info.platform:
transformed_message.message_base_info["platform"] = message.message_info.platform
if isinstance(message, MessageSending):
transformed_message.message_base_info["platform"] = message.platform
if message.session.group_id:
transformed_message.is_group_message = True
group_name = ""
if message.session.context and message.session.context.message and message.session.context.message.message_info.group_info:
group_name = message.session.context.message.message_info.group_info.group_name
transformed_message.message_base_info.update({
"group_id": message.session.group_id,
"group_name": group_name,
})
transformed_message.message_base_info.update({
"user_id": message.bot_user_info.user_id,
"user_cardname": message.bot_user_info.user_cardname,
"user_nickname": message.bot_user_info.user_nickname,
})
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
elif hasattr(message, "message_info") and message.message_info:
if message.platform:
transformed_message.message_base_info["platform"] = message.platform
if message.message_info.group_info:
transformed_message.is_group_message = True
transformed_message.message_base_info.update(
{
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
}
)
transformed_message.message_base_info.update({
"group_id": message.message_info.group_info.group_id,
"group_name": message.message_info.group_info.group_name,
})
if message.message_info.user_info:
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
transformed_message.message_base_info.update(
{
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
}
)
transformed_message.message_base_info.update({
"user_id": message.message_info.user_info.user_id,
"user_cardname": message.message_info.user_info.user_cardname,
"user_nickname": message.message_info.user_info.user_nickname,
})
return transformed_message
@@ -306,9 +325,9 @@ class EventsManager:
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
) -> MaiMessages:
"""从流ID构建消息"""
chat_stream = get_chat_manager().get_stream(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id}聊天流"
message = chat_stream.context.get_last_message()
session = _chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id}会话"
message = session.context.message
return self._transform_event_message(message, llm_prompt, llm_response)
def _transform_event_without_message(
@@ -319,8 +338,8 @@ class EventsManager:
action_usage: Optional[List[str]] = None,
) -> MaiMessages:
"""没有message对象时进行转换"""
chat_stream = get_chat_manager().get_stream(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id}聊天流"
session = _chat_manager.get_session_by_session_id(stream_id)
assert session, f"未找到流ID为 {stream_id}会话"
return MaiMessages(
stream_id=stream_id,
llm_prompt=llm_prompt,
@@ -328,8 +347,8 @@ class EventsManager:
llm_response_reasoning=(llm_response.reasoning if llm_response else None),
llm_response_model=(llm_response.model if llm_response else None),
llm_response_tool_call=(llm_response.tool_calls if llm_response else None),
is_group_message=(not (not chat_stream.group_info)),
is_private_message=(not chat_stream.group_info),
is_group_message=session.is_group_session,
is_private_message=not session.is_group_session,
action_usage=action_usage,
additional_data={"response_is_processed": True},
)
@@ -373,7 +392,7 @@ class EventsManager:
def _prepare_message(
self,
event_type: EventType | str,
message: Optional[MessageRecv | MessageSending | MaiMessages] = None,
message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger
logger = get_logger("tool_use")
@@ -28,8 +28,8 @@ class ToolExecutor:
cache_ttl: 缓存生存时间(周期数)
"""
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")

View File

@@ -11,7 +11,7 @@ from sqlmodel import col, select, delete
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.expression")
@@ -118,14 +118,11 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = get_chat_manager().get_stream(chat_id)
if not chat_stream:
session = _chat_manager.get_session_by_session_id(chat_id)
if not session:
return chat_id
if chat_stream.group_info and chat_stream.group_info.group_name:
return chat_stream.group_info.group_name
if chat_stream.user_info and chat_stream.user_info.user_nickname:
return chat_stream.user_info.user_nickname
return chat_id
name = _chat_manager.get_session_name(chat_id)
return name or chat_id
except Exception:
return chat_id
@@ -134,15 +131,9 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称"""
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
chat_manager = get_chat_manager()
for chat_id in chat_ids:
chat_stream = chat_manager.get_stream(chat_id)
if not chat_stream:
continue
if chat_stream.group_info and chat_stream.group_info.group_name:
result[chat_id] = chat_stream.group_info.group_name
elif chat_stream.user_info and chat_stream.user_info.user_nickname:
result[chat_id] = chat_stream.user_info.user_nickname
if name := _chat_manager.get_session_name(chat_id):
result[chat_id] = name
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
@@ -179,17 +170,14 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat
verify_auth_token(maibot_session, authorization)
chat_list = []
for stream_id, stream in get_chat_manager().streams.items():
chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None
if not chat_name and stream.user_info and stream.user_info.user_nickname:
chat_name = stream.user_info.user_nickname
chat_name = chat_name or stream_id
for session_id, session in _chat_manager.sessions.items():
chat_name = _chat_manager.get_session_name(session_id) or session_id
chat_list.append(
ChatInfo(
chat_id=stream_id,
chat_id=session_id,
chat_name=chat_name,
platform=stream.platform,
is_group=bool(stream.group_info and stream.group_info.group_id),
platform=session.platform,
is_group=session.is_group_session,
)
)
@@ -495,11 +483,10 @@ async def batch_delete_expressions(
# 查找所有要删除的表达方式
with get_db_session() as session:
statements = select(Expression.id).where(col(Expression.id).in_(request.ids))
found_ids = [expr_id for expr_id in session.exec(statements).all()]
found_ids = list(session.exec(statements).all())
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
if not_found_ids:
if not_found_ids := set(request.ids) - set(found_ids):
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
@@ -800,7 +787,7 @@ async def batch_review_expressions(
session.add(db_expression)
results.append(
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")
BatchReviewResultItem(id=item.id, success=True, message="拒绝" if item.rejected else "通过")
)
succeeded += 1