Refactor chat stream handling to use BotChatSession

- Updated imports and references from ChatStream to BotChatSession across multiple files.
- Adjusted method signatures and internal logic to accommodate the new session management.
- Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
DrSmoothl
2026-03-07 00:57:33 +08:00
parent 8712fc0d05
commit 2e3dd44ee9
43 changed files with 706 additions and 563 deletions

View File

@@ -19,10 +19,10 @@ sys.path.insert(0, project_root)
try: try:
from src.common.database.database_model import ChatStreams from src.common.database.database_model import ChatStreams
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _script_chat_manager
except ImportError: except ImportError:
ChatStreams = None ChatStreams = None
get_chat_manager = None _script_chat_manager = None
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:

View File

@@ -23,7 +23,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import initialize_logging, get_logger from src.common.logger import initialize_logging, get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.database_model import LLMUsage from src.common.database.database_model import LLMUsage
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession
from maim_message import UserInfo, GroupInfo from maim_message import UserInfo, GroupInfo
logger = get_logger("test_memory_retrieval") logger = get_logger("test_memory_retrieval")

View File

@@ -12,7 +12,7 @@ from src.chat.utils.chat_message_builder import (
build_anonymous_messages, build_anonymous_messages,
) )
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.bw_learner.learner_utils import ( from src.bw_learner.learner_utils import (
filter_message_content, filter_message_content,
is_bot_message, is_bot_message,
@@ -42,8 +42,8 @@ class ExpressionLearner:
) )
self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化 self.check_model: Optional[LLMRequest] = None # 检查用的 LLM 实例,延迟初始化
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
# 学习锁,防止并发执行学习任务 # 学习锁,防止并发执行学习任务
self._learning_lock = asyncio.Lock() self._learning_lock = asyncio.Lock()

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils import weighted_sample from src.bw_learner.learner_utils import weighted_sample
from src.chat.message_receive.chat_stream import get_chat_manager from src.common.utils.utils_session import SessionUtils
from src.chat.utils.common_utils import TempMethodsExpression from src.chat.utils.common_utils import TempMethodsExpression
logger = get_logger("expression_selector") logger = get_logger("expression_selector")
@@ -50,8 +50,9 @@ class ExpressionSelector:
id_str = parts[1] id_str = parts[1]
stream_type = parts[2] stream_type = parts[2]
is_group = stream_type == "group" is_group = stream_type == "group"
# 统一通过 chat_manager 生成 stream_id避免各处自行实现哈希逻辑 return SessionUtils.calculate_session_id(
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) platform, group_id=str(id_str) if is_group else None, user_id=None if is_group else str(id_str)
)
except Exception: except Exception:
return None return None
@@ -127,8 +128,7 @@ class ExpressionSelector:
logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择") logger.info(f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择")
# 完全没有高 count 样本时退化为全量随机抽样不进入LLM流程 # 完全没有高 count 样本时退化为全量随机抽样不进入LLM流程
fallback_num = min(3, max_num) if max_num > 0 else 3 fallback_num = min(3, max_num) if max_num > 0 else 3
fallback_selected = self._random_expressions(chat_id, fallback_num) if fallback_selected := self._random_expressions(chat_id, fallback_num):
if fallback_selected:
self.update_expressions_last_active_time(fallback_selected) self.update_expressions_last_active_time(fallback_selected)
selected_ids = [expr["id"] for expr in fallback_selected] selected_ids = [expr["id"] for expr in fallback_selected]
logger.info( logger.info(
@@ -199,12 +199,7 @@ class ExpressionSelector:
] ]
# 随机抽样 # 随机抽样
if style_exprs: return weighted_sample(style_exprs, total_num) if style_exprs else []
selected_style = weighted_sample(style_exprs, total_num)
else:
selected_style = []
return selected_style
except Exception as e: except Exception as e:
logger.error(f"随机选择表达方式失败: {e}") logger.error(f"随机选择表达方式失败: {e}")

View File

@@ -10,7 +10,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Jargon from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.bw_learner.learner_utils import ( from src.bw_learner.learner_utils import (
parse_chat_id_list, parse_chat_id_list,
@@ -99,9 +99,9 @@ class JargonMiner:
) )
# 初始化stream_name作为类属性避免重复提取 # 初始化stream_name作为类属性避免重复提取
chat_manager = get_chat_manager() chat_manager = _chat_manager
stream_name = chat_manager.get_stream_name(self.chat_id) stream_name = chat_manager.get_session_name(self.chat_id)
self.stream_name = stream_name if stream_name else self.chat_id self.stream_name = stream_name or self.chat_id
self.cache_limit = 50 self.cache_limit = 50
self.cache: OrderedDict[str, None] = OrderedDict() self.cache: OrderedDict[str, None] = OrderedDict()

View File

@@ -2,7 +2,7 @@ import time
import asyncio import asyncio
from typing import List, Any from typing import List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.chat.utils.common_utils import TempMethodsExpression from src.chat.utils.common_utils import TempMethodsExpression
from src.bw_learner.expression_learner import expression_learner_manager from src.bw_learner.expression_learner import expression_learner_manager
@@ -18,8 +18,8 @@ class MessageRecorder:
def __init__(self, chat_id: str) -> None: def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_stream = _chat_manager.get_session_by_session_id(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id self.chat_name = _chat_manager.get_session_name(chat_id) or chat_id
# 维护每个chat的上次提取时间 # 维护每个chat的上次提取时间
self.last_extraction_time: float = time.time() self.last_extraction_time: float = time.time()

View File

@@ -5,20 +5,19 @@ from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.config.config import model_config from src.config.config import model_config
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat,
build_readable_messages, build_readable_messages,
) )
if TYPE_CHECKING:
pass
logger = get_logger("reflect_tracker") logger = get_logger("reflect_tracker")
class ReflectTracker: class ReflectTracker:
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float): def __init__(self, chat_stream: BotChatSession, expression: Expression, created_time: float):
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.expression = expression self.expression = expression
self.created_time = created_time self.created_time = created_time
@@ -42,7 +41,7 @@ class ReflectTracker:
# Fetch messages since creation # Fetch messages since creation
msg_list = get_raw_msg_by_timestamp_with_chat( msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_stream.stream_id, chat_id=self.chat_stream.session_id,
timestamp_start=self.created_time, timestamp_start=self.created_time,
timestamp_end=time.time(), timestamp_end=time.time(),
) )
@@ -90,10 +89,7 @@ class ReflectTracker:
from json_repair import repair_json from json_repair import repair_json
json_pattern = r"```json\s*(.*?)\s*```" json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL) matches = re.findall(json_pattern, response, re.DOTALL) or [response]
if not matches:
# Try to parse raw response if no code block
matches = [response]
json_obj = json.loads(repair_json(matches[0])) json_obj = json.loads(repair_json(matches[0]))
@@ -122,10 +118,7 @@ class ReflectTracker:
self.expression.style = corrected_style self.expression.style = corrected_style
# 如果拒绝但未更新,标记为 rejected=1 # 如果拒绝但未更新,标记为 rejected=1
if not has_update: self.expression.rejected = not has_update
self.expression.rejected = True
else:
self.expression.rejected = False
self.expression.save() self.expression.save()

View File

@@ -4,10 +4,10 @@ MaiBot模块系统
""" """
from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager
# 导出主要组件供外部使用 # 导出主要组件供外部使用
__all__ = [ __all__ = [
"get_chat_manager", "chat_manager",
"emoji_manager", "emoji_manager",
] ]

View File

@@ -7,7 +7,7 @@ from src.chat.utils.chat_message_builder import build_readable_messages, get_raw
# from src.config.config import global_config # from src.config.config import global_config
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
from src.chat.message_receive.message import Message from src.common.data_models.mai_message_data_model import MaiMessage
from .pfc_types import ConversationState from .pfc_types import ConversationState
from .pfc import ChatObserver, GoalAnalyzer from .pfc import ChatObserver, GoalAnalyzer
from .message_sender import DirectMessageSender from .message_sender import DirectMessageSender
@@ -16,9 +16,8 @@ from .action_planner import ActionPlanner
from .observation_info import ObservationInfo from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
from .reply_generator import ReplyGenerator from .reply_generator import ReplyGenerator
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from maim_message import UserInfo from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
from .pfc_KnowledgeFetcher import KnowledgeFetcher from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter from .waiter import Waiter
@@ -60,7 +59,7 @@ class Conversation:
self.direct_sender = DirectMessageSender(self.private_name) self.direct_sender = DirectMessageSender(self.private_name)
# 获取聊天流信息 # 获取聊天流信息
self.chat_stream = get_chat_manager().get_stream(self.stream_id) self.chat_stream = _chat_manager.get_session_by_session_id(self.stream_id)
self.stop_action_planner = False self.stop_action_planner = False
except Exception as e: except Exception as e:
@@ -265,34 +264,34 @@ class Conversation:
return True return True
return False return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: def _convert_to_message(self, msg_dict: Dict[str, Any]) -> MaiMessage:
"""将消息字典转换为Message对象""" """将消息字典转换为MaiMessage对象"""
from datetime import datetime as dt
from src.common.data_models.mai_message_data_model import UserInfo as MaiUserInfo, MessageInfo
from src.common.data_models.message_component_data_model import MessageSequence
try: try:
# 尝试从 msg_dict 直接获取 chat_stream如果失败则从全局 get_chat_manager 获取 user_info_dict = msg_dict.get("user_info", {})
chat_info = msg_dict.get("chat_info") user_info = MaiUserInfo(
if chat_info and isinstance(chat_info, dict): user_id=user_info_dict.get("user_id", ""),
chat_stream = ChatStream.from_dict(chat_info) user_nickname=user_info_dict.get("user_nickname", ""),
elif self.chat_stream: # 使用实例变量中的 chat_stream user_cardname=user_info_dict.get("user_cardname"),
chat_stream = self.chat_stream
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
chat_stream = get_chat_manager().get_stream(self.stream_id)
if not chat_stream:
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
return Message(
message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID
chat_stream=chat_stream, # 使用确定的 chat_stream
time=msg_dict.get("time", time.time()), # 提供默认时间
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
) )
msg = MaiMessage(
message_id=msg_dict.get("message_id", f"gen_{time.time()}"),
timestamp=dt.fromtimestamp(msg_dict.get("time", time.time())),
)
msg.message_info = MessageInfo(user_info=user_info)
msg.platform = user_info_dict.get("platform", "")
msg.session_id = self.stream_id
msg.processed_plain_text = msg_dict.get("processed_plain_text", "")
msg.raw_message = MessageSequence(components=[])
msg.initialized = True
return msg
except Exception as e: except Exception as e:
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}") logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
# 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题 raise ValueError(f"无法将字典转换为 MaiMessage 对象: {e}") from e
raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e
async def _handle_action( async def _handle_action(
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
@@ -687,7 +686,7 @@ class Conversation:
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。") logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
return return
if not self.chat_stream: if not self.chat_stream:
logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。") logger.error(f"[私聊][{self.private_name}]会话未初始化,无法发送回复。")
return return
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content) await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)

View File

@@ -1,10 +1,12 @@
import time import time
from typing import Optional from typing import Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream from maim_message import Seg
from src.chat.message_receive.message import Message, MessageSending
from maim_message import UserInfo, Seg from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.config.config import global_config from src.config.config import global_config
from rich.traceback import install from rich.traceback import install
@@ -19,18 +21,17 @@ class DirectMessageSender:
def __init__(self, private_name: str): def __init__(self, private_name: str):
self.private_name = private_name self.private_name = private_name
self.storage = MessageStorage()
async def send_message( async def send_message(
self, self,
chat_stream: ChatStream, chat_stream: BotChatSession,
content: str, content: str,
reply_to_message: Optional[Message] = None, reply_to_message: Optional[MaiMessage] = None,
) -> None: ) -> None:
"""发送消息到聊天流 """发送消息到聊天流
Args: Args:
chat_stream: 聊天 chat_stream: 聊天会话
content: 消息内容 content: 消息内容
reply_to_message: 要回复的消息(可选) reply_to_message: 要回复的消息(可选)
""" """
@@ -42,18 +43,22 @@ class DirectMessageSender:
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.bot.qq_account, user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname, user_nickname=global_config.bot.nickname,
platform=chat_stream.platform,
) )
# 用当前时间作为message_id和之前那套sender一样 # 用当前时间作为message_id和之前那套sender一样
message_id = f"dm{round(time.time(), 2)}" message_id = f"dm{round(time.time(), 2)}"
# 构建发送者信息(私聊时为接收者)
sender_info = None
if reply_to_message and reply_to_message.message_info and reply_to_message.message_info.user_info:
sender_info = reply_to_message.message_info.user_info
# 构建消息对象 # 构建消息对象
message = MessageSending( message = MessageSending(
message_id=message_id, message_id=message_id,
chat_stream=chat_stream, session=chat_stream,
bot_user_info=bot_user_info, bot_user_info=bot_user_info,
sender_info=reply_to_message.message_info.user_info if reply_to_message else None, sender_info=sender_info,
message_segment=segments, message_segment=segments,
reply=reply_to_message, reply=reply_to_message,
is_head=True, is_head=True,
@@ -61,17 +66,11 @@ class DirectMessageSender:
thinking_start_time=time.time(), thinking_start_time=time.time(),
) )
# 处理消息 # 发送消息
await message.process() message_sender = UniversalMessageSender()
sent = await message_sender.send_message(message, typing=False, set_reply=False, storage_message=True)
# 发送消息(直接调用底层 API
from src.chat.message_receive.uni_message_sender import _send_message
sent = await _send_message(message, show_log=True)
if sent: if sent:
# 存储消息
await self.storage.store_message(message, chat_stream)
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}") logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
else: else:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败") logger.error(f"[私聊][{self.private_name}]PFC消息发送失败")

View File

@@ -9,7 +9,7 @@ from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.info_data_model import ActionPlannerInfo from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.message_data_model import ReplyContentType from src.common.data_models.message_data_model import ReplyContentType
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.brain_chat.brain_planner import BrainPlanner from src.chat.brain_chat.brain_planner import BrainPlanner
@@ -73,10 +73,10 @@ class BrainChatting:
""" """
# 基础属性 # 基础属性
self.stream_id: str = chat_id # 聊天流ID self.stream_id: str = chat_id # 聊天流ID
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.stream_id) # type: ignore self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.stream_id) # type: ignore
if not self.chat_stream: if not self.chat_stream:
raise ValueError(f"无法找到聊天流: {self.stream_id}") raise ValueError(f"无法找到聊天流: {self.stream_id}")
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.stream_id) or self.stream_id}]" self.log_prefix = f"[{_chat_manager.get_session_name(self.stream_id) or self.stream_id}]"
self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id) self.expression_learner = expression_learner_manager.get_expression_learner(self.stream_id)

View File

@@ -22,7 +22,7 @@ from src.chat.utils.chat_message_builder import (
) )
from src.chat.utils.utils import get_chat_type_and_target_info from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
@@ -38,7 +38,7 @@ install(extra_lines=3)
class BrainPlanner: class BrainPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager): def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id self.chat_id = chat_id
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
self.action_manager = action_manager self.action_manager = action_manager
# LLM规划器配置 # LLM规划器配置
self.planner_llm = LLMRequest( self.planner_llm = LLMRequest(

View File

@@ -5,9 +5,8 @@ import time
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
from maim_message.message_base import GroupInfo
from src.common.message_repository import count_messages from src.common.message_repository import count_messages
@@ -121,28 +120,24 @@ def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None)
async def send_typing(): async def send_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") chat = await _chat_manager.get_or_create_session(
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default", platform="amaidesu_default",
user_info=None, user_id="114514",
group_info=group_info, group_id="114514",
) )
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="typing", stream_id=chat.stream_id, storage_message=False message_type="state", content="typing", stream_id=chat.session_id, storage_message=False
) )
async def stop_typing(): async def stop_typing():
group_info = GroupInfo(platform="amaidesu_default", group_id="114514", group_name="内心") chat = await _chat_manager.get_or_create_session(
chat = await get_chat_manager().get_or_create_stream(
platform="amaidesu_default", platform="amaidesu_default",
user_info=None, user_id="114514",
group_info=group_info, group_id="114514",
) )
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="state", content="stop_typing", stream_id=chat.stream_id, storage_message=False message_type="state", content="stop_typing", stream_id=chat.session_id, storage_message=False
) )

View File

@@ -8,7 +8,6 @@ from typing import Dict, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.utils.utils_message import MessageUtils from src.common.utils.utils_message import MessageUtils
from src.common.utils.utils_session import SessionUtils from src.common.utils.utils_session import SessionUtils
from src.chat.message_receive.message_old import MessageRecv
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.brain_chat.PFC.pfc_manager import PFCManager from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@@ -41,14 +40,14 @@ class ChatBot:
self._started = True self._started = True
async def _create_pfc_chat(self, message: MessageRecv): async def _create_pfc_chat(self, message: SessionMessage):
"""创建或获取PFC对话实例 """创建或获取PFC对话实例
Args: Args:
message: 消息对象 message: 消息对象
""" """
try: try:
chat_id = str(message.chat_stream.stream_id) chat_id = message.session_id
private_name = str(message.message_info.user_info.user_nickname) private_name = str(message.message_info.user_info.user_nickname)
logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}") logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
@@ -177,12 +176,12 @@ class ChatBot:
logger.error(f"[新运行时] 执行命令 {matched.full_name} 异常: {e}", exc_info=True) logger.error(f"[新运行时] 执行命令 {matched.full_name} 异常: {e}", exc_info=True)
return True, str(e), True return True, str(e), True
async def handle_notice_message(self, message: MessageRecv): async def handle_notice_message(self, message: SessionMessage):
if message.message_info.message_id == "notice": if message.message_id == "notice":
message.is_notify = True message.is_notify = True
logger.debug("notice消息") logger.debug("notice消息")
try: try:
seg = message.message_segment seg = getattr(message, "message_segment", None) # SessionMessage 没有 message_segment
mi = message.message_info mi = message.message_info
sub_type = None sub_type = None
scene = None scene = None
@@ -246,10 +245,8 @@ class ChatBot:
return return
mmc_message_id = message_data.get("echo") mmc_message_id = message_data.get("echo")
actual_message_id = message_data.get("actual_id") actual_message_id = message_data.get("actual_id")
if MessageStorage.update_message(mmc_message_id, actual_message_id): # TODO: Implement message ID update in new architecture
logger.debug(f"更新消息ID成功: {mmc_message_id} -> {actual_message_id}") logger.debug(f"收到回送消息ID: {mmc_message_id} -> {actual_message_id}")
else:
logger.warning(f"更新消息ID失败: {mmc_message_id} -> {actual_message_id}")
async def message_process(self, message_data: Dict[str, Any]) -> None: async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息 """处理转化后的统一格式消息

View File

@@ -1,14 +1,23 @@
from asyncio import Task from asyncio import Task
from datetime import datetime
from maim_message import (
MessageBase,
UserInfo as MaimUserInfo,
GroupInfo as MaimGroupInfo,
BaseMessageInfo as MaimBaseMessageInfo,
Seg,
)
from rich.traceback import install from rich.traceback import install
from sqlmodel import select from sqlmodel import select
from typing import List, Dict, Tuple, Sequence from typing import List, Dict, Optional, Tuple, Sequence, TYPE_CHECKING
import asyncio import asyncio
import time
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import Messages from src.common.database.database_model import Messages
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo, GroupInfo, MessageInfo
from src.common.data_models.message_component_data_model import ( from src.common.data_models.message_component_data_model import (
TextComponent, TextComponent,
ImageComponent, ImageComponent,
@@ -19,6 +28,10 @@ from src.common.data_models.message_component_data_model import (
ForwardNodeComponent, ForwardNodeComponent,
StandardMessageComponents, StandardMessageComponents,
) )
from src.common.utils.utils_message import MessageUtils
if TYPE_CHECKING:
from src.chat.message_receive.chat_manager import BotChatSession
install(extra_lines=3) install(extra_lines=3)
@@ -207,3 +220,166 @@ class SessionMessage(MaiMessage):
else: else:
processed_texts.append(result) processed_texts.append(result)
return " ".join(processed_texts) return " ".join(processed_texts)
class MessageSending(MaiMessage):
"""发送状态的消息类,继承 MaiMessage 基类。
用于构建、处理和发送机器人的回复消息。
复用 MaiMessage 的 to_maim_message() 和 to_db_instance() 方法,
额外管理发送专属的会话信息和控制字段。
"""
def __init__(
self,
message_id: str,
session: "BotChatSession",
bot_user_info: UserInfo,
message_segment: Seg,
sender_info: Optional[UserInfo] = None,
reply: Optional[MaiMessage] = None,
display_message: str = "",
is_head: bool = False,
is_emoji: bool = False,
thinking_start_time: float = 0,
reply_to: Optional[str] = None,
selected_expressions: Optional[List[int]] = None,
):
# 初始化 MaiMessage 基类
super().__init__(message_id=message_id, timestamp=datetime.now())
# 发送专属字段
self.session = session
self.sender_info = sender_info
self.message_segment = message_segment
self.reply = reply
self.is_head = is_head
self.thinking_start_time = thinking_start_time
self.selected_expressions = selected_expressions
self.reply_to_message_id: Optional[str] = reply.message_id if reply else None
self.interest_value: float = 0.0
# 填充 MaiMessage 标准字段
self.platform = session.platform
self.session_id = session.session_id
self.is_emoji = is_emoji
self.reply_to = reply_to
self.display_message = display_message
self.processed_plain_text = ""
# 构建 message_infoDB 存储时 user_info 始终为 bot 信息
# 私聊/群聊的 user_info 差异仅在 to_maim_message() 覆写中处理
group_info = self._resolve_group_info()
self.message_info = MessageInfo(user_info=bot_user_info, group_info=group_info)
# bot_user_info 单独保存to_maim_message 覆写时还需要
self.bot_user_info = bot_user_info
# 将 Seg 转换为 MessageSequence供基类的 to_db_instance / to_maim_message 使用
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
MessageBase(message_info=None, message_segment=message_segment)
)
self.initialized = True
def _resolve_group_info(self) -> Optional[GroupInfo]:
"""从 session 中解析群信息"""
if not self.session.group_id:
return None
group_name = ""
if (
self.session.context
and self.session.context.message
and self.session.context.message.message_info.group_info
):
group_name = self.session.context.message.message_info.group_info.group_name
return GroupInfo(group_id=self.session.group_id, group_name=group_name)
async def process(self) -> None:
"""处理消息段,生成 processed_plain_text使用 SessionMessage 的组件处理能力)"""
# 同步 message_segment → raw_message插件可能修改了 message_segment
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
MessageBase(message_info=None, message_segment=self.message_segment)
)
if self.raw_message and self.raw_message.components:
tasks = [self._process_component(c) for c in self.raw_message.components]
results = await asyncio.gather(*tasks, return_exceptions=True)
texts = []
for r in results:
if isinstance(r, BaseException):
logger.error(f"处理发送消息组件时发生错误: {r}")
elif r:
texts.append(r)
self.processed_plain_text = " ".join(texts)
async def _process_component(self, component: StandardMessageComponents) -> str:
"""简单处理单个标准组件为纯文本描述"""
if isinstance(component, TextComponent):
return component.text
elif isinstance(component, ImageComponent):
return "[图片]"
elif isinstance(component, EmojiComponent):
return "[表情包]"
elif isinstance(component, VoiceComponent):
return "[语音]"
elif isinstance(component, AtComponent):
return f"[@{component.target_user_id}]"
elif isinstance(component, ReplyComponent):
return ""
else:
return f"[{type(component).__name__}]"
def build_reply(self) -> None:
"""构建回复消息段,在 message_segment 前插入 reply 段"""
if self.reply:
self.reply_to_message_id = self.reply.message_id
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_id),
self.message_segment,
],
)
# 同步更新 raw_message
self.raw_message = MessageUtils.from_maim_message_segments_to_MaiSeq(
MessageBase(message_info=None, message_segment=self.message_segment)
)
async def to_maim_message(self) -> MessageBase:
"""覆写基类方法:发送消息需要特殊处理 user_info私聊/群聊差异)"""
maim_bot_user_info = MaimUserInfo(
user_id=self.bot_user_info.user_id,
user_nickname=self.bot_user_info.user_nickname,
user_cardname=self.bot_user_info.user_cardname,
platform=self.platform,
)
maim_group_info = None
if self.message_info.group_info:
maim_group_info = MaimGroupInfo(
group_id=self.message_info.group_info.group_id,
group_name=self.message_info.group_info.group_name,
platform=self.platform,
)
# 私聊时 user_info 填接收者信息sender_info群聊时填 bot
if maim_group_info is None and self.sender_info:
msg_user_info = MaimUserInfo(
user_id=self.sender_info.user_id,
user_nickname=self.sender_info.user_nickname,
user_cardname=self.sender_info.user_cardname,
platform=self.platform,
)
else:
msg_user_info = maim_bot_user_info
maim_msg_info = MaimBaseMessageInfo(
platform=self.platform,
message_id=self.message_id,
time=time.time(),
group_info=maim_group_info,
user_info=msg_user_info,
)
msg_segments = await MessageUtils.from_MaiSeq_to_maim_message_segments(self.raw_message)
return MessageBase(message_info=maim_msg_info, message_segment=Seg(type="seglist", data=msg_segments))

View File

@@ -6,8 +6,8 @@ from maim_message import Seg
from src.common.message_server.api import get_global_api from src.common.message_server.api import get_global_api
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.chat.message_receive.message import MessageSending from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time from src.chat.utils.utils import calculate_typing_time
@@ -130,8 +130,8 @@ def parse_message_segments(segment) -> list:
async def _send_message(message: MessageSending, show_log=True) -> bool: async def _send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录""" """合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text, max_length=200) message_preview = truncate_message(message.processed_plain_text, max_length=200)
platform = message.message_info.platform platform = message.platform
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None group_id = message.session.group_id
try: try:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息 # 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
@@ -221,33 +221,14 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
raise legacy_exception raise legacy_exception
return False return False
# 使用 MessageConverter 转换 Legacy MessageBase 到 APIMessageBase # 使用 MessageConverter 转换为 API 消息
# 发送场景MaiMBot 发送回复消息给外部用户
# group_info/user_info 是消息接收者信息,放入 receiver_info
from maim_message import MessageConverter from maim_message import MessageConverter
# 修复 API Server Fallback 模式下的 user_info 问题 # 新架构:通过 to_maim_message() 转换,内部已处理私聊/群聊的 user_info 差异
# 在 Legacy 模式下MessageSending.to_dict() 的第 454 行会将 user_info 替换为 chat_stream.user_info message_base = await message.to_maim_message()
# 但在 API Server Fallback 模式下MessageConverter.to_api_send() 直接访问 message 对象,不调用 to_dict()
# 需要手动应用相同的变通方案在私聊场景下user_info 应该是接收者sender_info
message_for_conversion = message
if hasattr(message, "message_info") and message.message_info.group_info is None:
# 私聊场景group_info 为 None
# user_info 应该是接收者,从 chat_stream.user_info 或 sender_info 获取
temp_dict = message.to_dict()
if (
hasattr(message, "chat_stream")
and message.chat_stream
and hasattr(message.chat_stream, "user_info")
):
temp_dict["message_info"]["user_info"] = message.chat_stream.user_info.to_dict()
# 重新构建 MessageBase 对象(不保留 sender_info 等扩展属性)
from maim_message import MessageBase
message_for_conversion = MessageBase.from_dict(temp_dict)
api_message = MessageConverter.to_api_send( api_message = MessageConverter.to_api_send(
message=message_for_conversion, message=message_base,
api_key=target_api_key, api_key=target_api_key,
platform=platform, platform=platform,
) )
@@ -278,10 +259,11 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
return False return False
try: try:
send_result = await get_global_api().send_message(message) message_base = await message.to_maim_message()
send_result = await get_global_api().send_message(message_base)
if send_result: if send_result:
if show_log: if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'") logger.info(f"已将消息 '{message_preview}' 发往平台'{message.platform}'")
return True return True
else: else:
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback # Legacy API 返回 False (发送失败但未报错),尝试 Fallback
@@ -297,7 +279,7 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
return await send_with_new_api(legacy_exception=legacy_e) return await send_with_new_api(legacy_exception=legacy_e)
except Exception as e: except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}") logger.error(f"发送消息 '{message_preview}' 发往平台'{message.platform}' 失败: {str(e)}")
traceback.print_exc() traceback.print_exc()
raise e # 重新抛出其他异常 raise e # 重新抛出其他异常
@@ -306,7 +288,7 @@ class UniversalMessageSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。""" """管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
def __init__(self): def __init__(self):
self.storage = MessageStorage() pass
async def send_message( async def send_message(
self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True self, message: MessageSending, typing=False, set_reply=False, storage_message=True, show_log=True
@@ -321,15 +303,15 @@ class UniversalMessageSender:
用法: 用法:
- typing=True 时,发送前会有打字等待。 - typing=True 时,发送前会有打字等待。
""" """
if not message.chat_stream: if not message.session:
logger.error("消息缺少 chat_stream,无法发送") logger.error("消息缺少 session,无法发送")
raise ValueError("消息缺少 chat_stream,无法发送") raise ValueError("消息缺少 session,无法发送")
if not message.message_info or not message.message_info.message_id: if not message.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送") logger.error("消息缺少 message_id无法发送")
raise ValueError("消息缺少 message_info 或 message_id无法发送") raise ValueError("消息缺少 message_id无法发送")
chat_id = message.chat_stream.stream_id chat_id = message.session_id
message_id = message.message_info.message_id message_id = message.message_id
try: try:
if set_reply: if set_reply:
@@ -391,7 +373,8 @@ class UniversalMessageSender:
message.processed_plain_text = modified_message.plain_text message.processed_plain_text = modified_message.plain_text
if storage_message: if storage_message:
await self.storage.store_message(message, message.chat_stream) with get_db_session() as db_session:
db_session.add(message.to_db_instance())
return sent_msg return sent_msg

View File

@@ -1,6 +1,6 @@
from typing import Dict, Optional, Type from typing import Dict, Optional, Type
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
@@ -35,7 +35,7 @@ class ActionManager:
action_reasoning: str, action_reasoning: str,
cycle_timers: dict, cycle_timers: dict,
thinking_id: str, thinking_id: str,
chat_stream: ChatStream, chat_stream: BotChatSession,
log_prefix: str, log_prefix: str,
shutting_down: bool = False, shutting_down: bool = False,
action_message: Optional[DatabaseMessages] = None, action_message: Optional[DatabaseMessages] = None,

View File

@@ -4,15 +4,12 @@ from typing import List, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
from src.plugin_system.base.component_types import ActionInfo, ActionActivationType from src.plugin_system.base.component_types import ActionInfo, ActionActivationType
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("action_manager") logger = get_logger("action_manager")
@@ -27,8 +24,8 @@ class ActionModifier:
def __init__(self, action_manager: ActionManager, chat_id: str): def __init__(self, action_manager: ActionManager, chat_id: str):
"""初始化动作处理器""" """初始化动作处理器"""
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream: ChatStream = get_chat_manager().get_stream(self.chat_id) # type: ignore self.chat_stream: BotChatSession = _chat_manager.get_session_by_session_id(self.chat_id) # type: ignore
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
self.action_manager = action_manager self.action_manager = action_manager
@@ -121,7 +118,7 @@ class ActionModifier:
available_actions_text = "".join(available_actions) if available_actions else "" available_actions_text = "".join(available_actions) if available_actions else ""
logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}") logger.debug(f"{self.log_prefix} 当前可用动作: {available_actions_text}||移除: {removals_summary}")
def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: ChatMessageContext): def _check_action_associated_types(self, all_actions: Dict[str, ActionInfo], chat_context: BotChatSession):
type_mismatched_actions: List[Tuple[str, str]] = [] type_mismatched_actions: List[Tuple[str, str]] = []
for action_name, action_info in all_actions.items(): for action_name, action_info in all_actions.items():
if action_info.associated_types and not chat_context.check_types(action_info.associated_types): if action_info.associated_types and not chat_context.check_types(action_info.associated_types):

View File

@@ -3,6 +3,7 @@ import time
import traceback import traceback
import random import random
import re import re
import contextlib
from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union from typing import Dict, Optional, Tuple, List, TYPE_CHECKING, Union
from collections import OrderedDict from collections import OrderedDict
from rich.traceback import install from rich.traceback import install
@@ -21,7 +22,7 @@ from src.chat.utils.chat_message_builder import (
) )
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.apis.message_api import translate_pid_to_description from src.plugin_system.apis.message_api import translate_pid_to_description
@@ -39,7 +40,7 @@ install(extra_lines=3)
class ActionPlanner: class ActionPlanner:
def __init__(self, chat_id: str, action_manager: ActionManager): def __init__(self, chat_id: str, action_manager: ActionManager):
self.chat_id = chat_id self.chat_id = chat_id
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" self.log_prefix = f"[{_chat_manager.get_session_name(chat_id) or chat_id}]"
self.action_manager = action_manager self.action_manager = action_manager
# LLM规划器配置 # LLM规划器配置
self.planner_llm = LLMRequest( self.planner_llm = LLMRequest(
@@ -80,7 +81,7 @@ class ActionPlanner:
if not text: if not text:
return text return text
id_to_message = {msg_id: msg for msg_id, msg in message_id_list} id_to_message = dict(message_id_list)
# 匹配m后带2-4位数字前后不是字母数字下划线 # 匹配m后带2-4位数字前后不是字母数字下划线
pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])" pattern = r"(?<![A-Za-z0-9_])m\d{2,4}(?![A-Za-z0-9_])"
@@ -223,7 +224,7 @@ class ActionPlanner:
action_data=action_data, action_data=action_data,
action_message=target_message, action_message=target_message,
available_actions=available_actions_dict, available_actions=available_actions_dict,
action_reasoning=extracted_reasoning if extracted_reasoning else None, action_reasoning=extracted_reasoning or None,
) )
) )
@@ -238,7 +239,7 @@ class ActionPlanner:
action_data={}, action_data={},
action_message=None, action_message=None,
available_actions=available_actions_dict, available_actions=available_actions_dict,
action_reasoning=extracted_reasoning if extracted_reasoning else None, action_reasoning=extracted_reasoning or None,
) )
) )
@@ -292,8 +293,7 @@ class ActionPlanner:
if new_words: if new_words:
for word in new_words: for word in new_words:
if isinstance(word, str): if isinstance(word, str):
word = word.strip() if word := word.strip():
if word:
cleaned_new_words.append(word) cleaned_new_words.append(word)
# 获取缓存中的黑话列表 # 获取缓存中的黑话列表
@@ -351,10 +351,9 @@ class ActionPlanner:
break break
# 如果当前 plan 的 reply 没有提取移除最老的1个 # 如果当前 plan 的 reply 没有提取移除最老的1个
if not has_extracted_unknown_words: if not has_extracted_unknown_words and len(self.unknown_words_cache) > 0:
if len(self.unknown_words_cache) > 0: self.unknown_words_cache.popitem(last=False)
self.unknown_words_cache.popitem(last=False) logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")
logger.debug(f"{self.log_prefix}当前 plan 的 reply 没有提取黑话移除最老的1个缓存")
# 对于每个 reply action合并缓存和新提取的黑话 # 对于每个 reply action合并缓存和新提取的黑话
for action in actions: for action in actions:
@@ -363,10 +362,7 @@ class ActionPlanner:
new_words = action_data.get("unknown_words") new_words = action_data.get("unknown_words")
# 合并新提取的和缓存的黑话列表 # 合并新提取的和缓存的黑话列表
merged_words = self._merge_unknown_words_with_cache(new_words) if merged_words := self._merge_unknown_words_with_cache(new_words):
# 更新 action_data
if merged_words:
action_data["unknown_words"] = merged_words action_data["unknown_words"] = merged_words
logger.debug( logger.debug(
f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个," f"{self.log_prefix}合并黑话:新提取 {len(new_words) if new_words else 0} 个,"
@@ -449,15 +445,12 @@ class ActionPlanner:
# 如果有强制回复消息,确保回复该消息 # 如果有强制回复消息,确保回复该消息
if force_reply_message: if force_reply_message:
# 检查是否已经有回复该消息的 action # 检查是否已经有回复该消息的 action
has_reply_to_force_message = False has_reply_to_force_message = any(
for action in actions: action.action_type == "reply"
if ( and action.action_message
action.action_type == "reply" and action.action_message.message_id == force_reply_message.message_id
and action.action_message for action in actions
and action.action_message.message_id == force_reply_message.message_id )
):
has_reply_to_force_message = True
break
# 如果没有回复该消息,强制添加回复 action # 如果没有回复该消息,强制添加回复 action
if not has_reply_to_force_message: if not has_reply_to_force_message:
@@ -532,13 +525,10 @@ class ActionPlanner:
# 从后往前遍历,收集最新的记录 # 从后往前遍历,收集最新的记录
for reasoning, timestamp, content in reversed(self.plan_log): for reasoning, timestamp, content in reversed(self.plan_log):
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content): if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
# 这是action记录
if len(action_records) < max_action_records: if len(action_records) < max_action_records:
action_records.append((reasoning, timestamp, content, "action")) action_records.append((reasoning, timestamp, content, "action"))
else: elif len(execution_records) < max_execution_records:
# 这是执行结果记录 execution_records.append((reasoning, timestamp, content, "execution"))
if len(execution_records) < max_execution_records:
execution_records.append((reasoning, timestamp, content, "execution"))
# 合并所有记录并按时间戳排序 # 合并所有记录并按时间戳排序
all_records = action_records + execution_records all_records = action_records + execution_records
@@ -700,15 +690,9 @@ class ActionPlanner:
param_text = param_text.rstrip("\n") param_text = param_text.rstrip("\n")
# 构建要求文本 # 构建要求文本
require_text = "" require_text = "\n".join(f"- {require_item}" for require_item in action_info.action_require)
for require_item in action_info.action_require:
require_text += f"- {require_item}\n"
require_text = require_text.rstrip("\n")
if not action_info.parallel_action: parallel_text = "" if action_info.parallel_action else "(当选择这个动作时,请不要选择其他动作)"
parallel_text = "(当选择这个动作时,请不要选择其他动作)"
else:
parallel_text = ""
# 获取动作提示模板并填充 # 获取动作提示模板并填充
using_action_prompt = prompt_manager.get_prompt("action") using_action_prompt = prompt_manager.get_prompt("action")
@@ -864,20 +848,15 @@ class ActionPlanner:
# 尝试按行分割每行可能是一个JSON对象 # 尝试按行分割每行可能是一个JSON对象
lines = [line.strip() for line in json_str.split("\n") if line.strip()] lines = [line.strip() for line in json_str.split("\n") if line.strip()]
for line in lines: for line in lines:
try: with contextlib.suppress(json.JSONDecodeError):
# 尝试解析每一行作为独立的JSON对象
json_obj = json.loads(repair_json(line)) json_obj = json.loads(repair_json(line))
if isinstance(json_obj, dict): if isinstance(json_obj, dict):
# 过滤掉空字典,避免单个 { 字符被错误修复为 {} 的情况
if json_obj: if json_obj:
json_objects.append(json_obj) json_objects.append(json_obj)
elif isinstance(json_obj, list): elif isinstance(json_obj, list):
for item in json_obj: for item in json_obj:
if isinstance(item, dict) and item: if isinstance(item, dict) and item:
json_objects.append(item) json_objects.append(item)
except json.JSONDecodeError:
# 如果单行解析失败尝试将整个块作为一个JSON对象或数组
pass
# 如果按行解析没有成功或只得到空字典尝试将整个块作为一个JSON对象或数组 # 如果按行解析没有成功或只得到空字典尝试将整个块作为一个JSON对象或数组
if not json_objects: if not json_objects:

View File

@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending from maim_message import Seg
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
@@ -45,17 +48,17 @@ logger = get_logger("replyer")
class DefaultReplyer: class DefaultReplyer:
def __init__( def __init__(
self, self,
chat_stream: ChatStream, chat_stream: BotChatSession,
request_type: str = "replyer", request_type: str = "replyer",
): ):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender() self.heart_fc_sender = UniversalMessageSender()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖 from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
async def generate_reply_with_context( async def generate_reply_with_context(
self, self,
@@ -132,7 +135,7 @@ class DefaultReplyer:
if log_reply: if log_reply:
try: try:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id, chat_id=self.chat_stream.session_id,
prompt="", prompt="",
output=None, output=None,
processed_output=None, processed_output=None,
@@ -202,7 +205,7 @@ class DefaultReplyer:
try: try:
if log_reply: if log_reply:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id, chat_id=self.chat_stream.session_id,
prompt=prompt, prompt=prompt,
output=content, output=content,
processed_output=None, processed_output=None,
@@ -259,7 +262,7 @@ class DefaultReplyer:
if log_reply: if log_reply:
try: try:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=self.chat_stream.stream_id, chat_id=self.chat_stream.session_id,
prompt=prompt or "", prompt=prompt or "",
output=None, output=None,
processed_output=None, processed_output=None,
@@ -353,14 +356,14 @@ class DefaultReplyer:
str: 表达习惯信息字符串 str: 表达习惯信息字符串
""" """
# 检查是否允许在此聊天流中使用表达 # 检查是否允许在此聊天流中使用表达
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
if not use_expression: if not use_expression:
return "", [] return "", []
style_habits = [] style_habits = []
# 使用从处理器传来的选中表达方式 # 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式 # 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, self.chat_stream.session_id,
chat_history, chat_history,
max_num=8, max_num=8,
target_message=target, target_message=target,
@@ -702,10 +705,11 @@ class DefaultReplyer:
# 判断是否为群聊 # 判断是否为群聊
is_group = stream_type == "group" is_group = stream_type == "group"
# 使用 ChatManager 提供的接口生成 chat_id避免在此重复实现逻辑 from src.common.utils.utils_session import SessionUtils
from src.chat.message_receive.chat_stream import get_chat_manager
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) chat_id = SessionUtils.calculate_session_id(
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
)
return chat_id, prompt_content return chat_id, prompt_content
except (ValueError, IndexError): except (ValueError, IndexError):
@@ -778,7 +782,7 @@ class DefaultReplyer:
if available_actions is None: if available_actions is None:
available_actions = {} available_actions = {}
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.session_id
_is_group_chat = bool(chat_stream.group_info) _is_group_chat = bool(chat_stream.group_info)
platform = chat_stream.platform platform = chat_stream.platform
@@ -1005,7 +1009,7 @@ class DefaultReplyer:
reply_to: str, reply_to: str,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.session_id
sender, target = self._parse_reply_target(reply_to) sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -1105,29 +1109,27 @@ class DefaultReplyer:
is_emoji: bool, is_emoji: bool,
thinking_start_time: float, thinking_start_time: float,
display_message: str, display_message: str,
anchor_message: Optional[MessageRecv] = None, anchor_message: Optional[MaiMessage] = None,
) -> MessageSending: ) -> MessageSending:
"""构建单个发送消息""" """构建单个发送消息"""
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=str(global_config.bot.qq_account), user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname, user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
) )
# await anchor_message.process()
sender_info = anchor_message.message_info.user_info if anchor_message else None sender_info = anchor_message.message_info.user_info if anchor_message else None
return MessageSending( return MessageSending(
message_id=message_id, # 使用片段的唯一ID message_id=message_id,
chat_stream=self.chat_stream, session=self.chat_stream,
bot_user_info=bot_user_info, bot_user_info=bot_user_info,
sender_info=sender_info, sender_info=sender_info,
message_segment=message_segment, message_segment=message_segment,
reply=anchor_message, # 回复原始锚点 reply=anchor_message,
is_head=reply_to, is_head=reply_to,
is_emoji=is_emoji, is_emoji=is_emoji,
thinking_start_time=thinking_start_time, # 传递原始思考开始时间 thinking_start_time=thinking_start_time,
display_message=display_message, display_message=display_message,
) )

View File

@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.llm_data_model import LLMGenerationDataModel from src.common.data_models.llm_data_model import LLMGenerationDataModel
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending from maim_message import Seg
from src.chat.message_receive.chat_stream import ChatStream
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
@@ -43,18 +46,18 @@ logger = get_logger("replyer")
class PrivateReplyer: class PrivateReplyer:
def __init__( def __init__(
self, self,
chat_stream: ChatStream, chat_stream: BotChatSession,
request_type: str = "replyer", request_type: str = "replyer",
): ):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
self.heart_fc_sender = UniversalMessageSender() self.heart_fc_sender = UniversalMessageSender()
# self.memory_activator = MemoryActivator() # self.memory_activator = MemoryActivator()
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖 from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
async def generate_reply_with_context( async def generate_reply_with_context(
self, self,
@@ -253,14 +256,14 @@ class PrivateReplyer:
str: 表达习惯信息字符串 str: 表达习惯信息字符串
""" """
# 检查是否允许在此聊天流中使用表达 # 检查是否允许在此聊天流中使用表达
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id) use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
if not use_expression: if not use_expression:
return "", [] return "", []
style_habits = [] style_habits = []
# 使用从处理器传来的选中表达方式 # 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式 # 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions( selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason self.chat_stream.session_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
) )
if selected_expressions: if selected_expressions:
@@ -550,10 +553,11 @@ class PrivateReplyer:
# 判断是否为群聊 # 判断是否为群聊
is_group = stream_type == "group" is_group = stream_type == "group"
# 使用 ChatManager 提供的接口生成 chat_id避免在此重复实现逻辑 from src.common.utils.utils_session import SessionUtils
from src.chat.message_receive.chat_stream import get_chat_manager
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) chat_id = SessionUtils.calculate_session_id(
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
)
return chat_id, prompt_content return chat_id, prompt_content
except (ValueError, IndexError): except (ValueError, IndexError):
@@ -624,7 +628,7 @@ class PrivateReplyer:
if available_actions is None: if available_actions is None:
available_actions = {} available_actions = {}
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.session_id
platform = chat_stream.platform platform = chat_stream.platform
user_id = "用户ID" user_id = "用户ID"
@@ -843,7 +847,7 @@ class PrivateReplyer:
reply_to: str, reply_to: str,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.session_id
sender, target = self._parse_reply_target(reply_to) sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True) target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@@ -948,29 +952,27 @@ class PrivateReplyer:
is_emoji: bool, is_emoji: bool,
thinking_start_time: float, thinking_start_time: float,
display_message: str, display_message: str,
anchor_message: Optional[MessageRecv] = None, anchor_message: Optional[MaiMessage] = None,
) -> MessageSending: ) -> MessageSending:
"""构建单个发送消息""" """构建单个发送消息"""
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=str(global_config.bot.qq_account), user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname, user_nickname=global_config.bot.nickname,
platform=self.chat_stream.platform,
) )
# await anchor_message.process()
sender_info = anchor_message.message_info.user_info if anchor_message else None sender_info = anchor_message.message_info.user_info if anchor_message else None
return MessageSending( return MessageSending(
message_id=message_id, # 使用片段的唯一ID message_id=message_id,
chat_stream=self.chat_stream, session=self.chat_stream,
bot_user_info=bot_user_info, bot_user_info=bot_user_info,
sender_info=sender_info, sender_info=sender_info,
message_segment=message_segment, message_segment=message_segment,
reply=anchor_message, # 回复原始锚点 reply=anchor_message,
is_head=reply_to, is_head=reply_to,
is_emoji=is_emoji, is_emoji=is_emoji,
thinking_start_time=thinking_start_time, # 传递原始思考开始时间 thinking_start_time=thinking_start_time,
display_message=display_message, display_message=display_message,
) )

View File

@@ -1,7 +1,7 @@
from typing import Dict, Optional from typing import Dict, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
from src.chat.replyer.group_generator import DefaultReplyer from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer from src.chat.replyer.private_generator import PrivateReplyer
@@ -14,7 +14,7 @@ class ReplyerManager:
def get_replyer( def get_replyer(
self, self,
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]: ) -> Optional[DefaultReplyer | PrivateReplyer]:
@@ -24,7 +24,7 @@ class ReplyerManager:
model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。 model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。
后续调用将返回已缓存的实例,忽略 model_configs 参数。 后续调用将返回已缓存的实例,忽略 model_configs 参数。
""" """
stream_id = chat_stream.stream_id if chat_stream else chat_id stream_id = chat_stream.session_id if chat_stream else chat_id
if not stream_id: if not stream_id:
logger.warning("[ReplyerManager] 缺少 stream_id无法获取回复器。") logger.warning("[ReplyerManager] 缺少 stream_id无法获取回复器。")
return None return None
@@ -39,15 +39,14 @@ class ReplyerManager:
target_stream = chat_stream target_stream = chat_stream
if not target_stream: if not target_stream:
if chat_manager := get_chat_manager(): target_stream = _chat_manager.get_session_by_session_id(stream_id)
target_stream = chat_manager.get_stream(stream_id)
if not target_stream: if not target_stream:
logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。") logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。")
return None return None
# model_configs 只在此时(初始化时)生效 # model_configs 只在此时(初始化时)生效
if target_stream.group_info: if target_stream.is_group_session:
replyer = DefaultReplyer( replyer = DefaultReplyer(
chat_stream=target_stream, chat_stream=target_stream,
request_type=request_type, request_type=request_type,

View File

@@ -61,9 +61,12 @@ class TempMethodsExpression:
str: 生成的聊天流ID哈希值 str: 生成的聊天流ID哈希值
""" """
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager from src.common.utils.utils_session import SessionUtils
return get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group) if is_group:
return SessionUtils.calculate_session_id(platform, group_id=str(id_str))
else:
return SessionUtils.calculate_session_id(platform, user_id=str(id_str))
except Exception as e: except Exception as e:
logger.error(f"生成聊天流ID失败: {e}") logger.error(f"生成聊天流ID失败: {e}")
return None return None

View File

@@ -1051,18 +1051,13 @@ class StatisticOutputTask(AsyncTask):
"""从chat_id获取显示名称""" """从chat_id获取显示名称"""
try: try:
# 首先尝试从chat_stream获取真实群组名称 # 首先尝试从chat_stream获取真实群组名称
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _stat_chat_manager
chat_manager = get_chat_manager() if chat_id in _stat_chat_manager.sessions:
session = _stat_chat_manager.sessions[chat_id]
if chat_id in chat_manager.streams: name = _stat_chat_manager.get_session_name(chat_id)
stream = chat_manager.streams[chat_id] if name and name.strip():
if stream.group_info and hasattr(stream.group_info, "group_name"): return name.strip()
group_name = stream.group_info.group_name
if group_name and group_name.strip():
return group_name.strip()
elif stream.user_info and hasattr(stream.user_info, "user_nickname"):
user_name = stream.user_info.user_nickname
if user_name and user_name.strip(): if user_name and user_name.strip():
return user_name.strip() return user_name.strip()

View File

@@ -12,8 +12,8 @@ from typing import Optional, Tuple, List, TYPE_CHECKING
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import SessionMessage
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import Person from src.person_info.person_info import Person
from .typo_generator import ChineseTypoGenerator from .typo_generator import ChineseTypoGenerator
@@ -114,10 +114,10 @@ def is_bot_self(platform: str, user_id: str) -> bool:
return user_id_str == qq_account return user_id_str == qq_account
def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, bool, float]: def is_mentioned_bot_in_message(message: SessionMessage) -> tuple[bool, bool, float]:
"""检查消息是否提到了机器人(统一多平台实现)""" """检查消息是否提到了机器人(统一多平台实现)"""
text = message.processed_plain_text or "" text = message.processed_plain_text or ""
platform = getattr(message.message_info, "platform", "") or "" platform = message.platform or ""
# 获取各平台账号 # 获取各平台账号
platforms_list = getattr(global_config.bot, "platforms", []) or [] platforms_list = getattr(global_config.bot, "platforms", []) or []
@@ -696,15 +696,23 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
chat_target_info = None chat_target_info = None
try: try:
if chat_stream := get_chat_manager().get_stream(chat_id): if chat_stream := _chat_manager.get_session_by_session_id(chat_id):
if chat_stream.group_info: if chat_stream.is_group_session:
is_group_chat = True is_group_chat = True
chat_target_info = None # Explicitly None for group chat chat_target_info = None # Explicitly None for group chat
elif chat_stream.user_info: # It's a private chat elif chat_stream.user_id: # It's a private chat
is_group_chat = False is_group_chat = False
user_info = chat_stream.user_info
platform: str = chat_stream.platform platform: str = chat_stream.platform
user_id: str = user_info.user_id # type: ignore user_id: str = chat_stream.user_id
# Try to get nickname from context
user_nickname = None
if (
chat_stream.context
and chat_stream.context.message
and chat_stream.context.message.message_info.user_info
):
user_nickname = chat_stream.context.message.message_info.user_info.user_nickname
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题 from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
@@ -712,7 +720,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
target_info = TargetPersonInfo( target_info = TargetPersonInfo(
platform=platform, platform=platform,
user_id=user_id, user_id=user_id,
user_nickname=user_info.user_nickname, # type: ignore user_nickname=user_nickname, # type: ignore
person_id=None, person_id=None,
person_name=None, person_name=None,
) )
@@ -721,7 +729,7 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
try: try:
person = Person(platform=platform, user_id=user_id) person = Person(platform=platform, user_id=user_id)
if not person.is_known: if not person.is_known:
logger.warning(f"用户 {user_info.user_nickname} 尚未认识") logger.warning(f"用户 {user_nickname} 尚未认识")
# 如果用户尚未认识则返回False和None # 如果用户尚未认识则返回False和None
return False, None return False, None
if person.person_id: if person.person_id:

View File

@@ -1,4 +1,6 @@
from pathlib import Path
from PIL import Image as PILImage, ImageSequence from PIL import Image as PILImage, ImageSequence
from typing import Optional, Union
import base64 import base64
import io import io
@@ -102,3 +104,30 @@ class ImageUtils:
logger.error("输入的图片字节数据无效") logger.error("输入的图片字节数据无效")
raise ValueError("输入的图片字节数据无效") raise ValueError("输入的图片字节数据无效")
return base64.b64encode(image_bytes).decode("utf-8") return base64.b64encode(image_bytes).decode("utf-8")
@staticmethod
def image_path_to_base64(image_path: Union[str, Path]) -> Optional[str]:
"""读取图片文件并转换为 Base64 编码字符串"""
try:
path = Path(image_path)
if not path.exists():
logger.error(f"图片文件不存在: {path}")
return None
image_bytes = path.read_bytes()
return base64.b64encode(image_bytes).decode("utf-8")
except Exception as e:
logger.error(f"读取图片文件失败: {e}")
return None
@staticmethod
def base64_to_image(base64_str: str, save_path: Union[str, Path]) -> bool:
"""将 Base64 编码字符串解码并保存为图片文件"""
try:
image_bytes = base64.b64decode(base64_str)
path = Path(save_path)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(image_bytes)
return True
except Exception as e:
logger.error(f"保存图片文件失败: {e}")
return False

View File

@@ -7,7 +7,7 @@ from src.config.config import global_config, model_config
from src.llm_models.payload_content.message import RoleType, Message from src.llm_models.payload_content.message import RoleType, Message
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager from src.common.utils.utils_session import SessionUtils
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
logger = get_logger("dream_generator") logger = get_logger("dream_generator")
@@ -178,10 +178,9 @@ async def generate_dream_summary(
logger.warning(f"[dream][梦境总结] dream_send 平台或用户ID为空当前值: {dream_send_raw!r}") logger.warning(f"[dream][梦境总结] dream_send 平台或用户ID为空当前值: {dream_send_raw!r}")
else: else:
# 默认为私聊会话 # 默认为私聊会话
stream_id = get_chat_manager().get_stream_id( stream_id = SessionUtils.calculate_session_id(
platform=platform, platform=platform,
id=str(user_id), user_id=str(user_id),
is_group=False,
) )
if not stream_id: if not stream_id:
logger.error( logger.error(

View File

@@ -8,7 +8,7 @@ from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
# from src.chat.utils.token_statistics import TokenStatisticsTask # from src.chat.utils.token_statistics import TokenStatisticsTask
from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager
from src.config.config import config_manager, global_config from src.config.config import config_manager, global_config
from src.chat.message_receive.bot import chat_bot from src.chat.message_receive.bot import chat_bot
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -119,8 +119,8 @@ class MainSystem:
logger.info("表情包管理器初始化成功") logger.info("表情包管理器初始化成功")
# 初始化聊天管理器 # 初始化聊天管理器
await get_chat_manager()._initialize() await chat_manager.initialize()
asyncio.create_task(get_chat_manager()._auto_save_task()) asyncio.create_task(chat_manager.regularly_save_sessions())
logger.info("聊天管理器初始化成功") logger.info("聊天管理器初始化成功")

View File

@@ -21,7 +21,7 @@ from src.plugin_system.apis import message_api
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.utils import is_bot_self from src.chat.utils.utils import is_bot_self
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
logger = get_logger("chat_history_summarizer") logger = get_logger("chat_history_summarizer")
@@ -100,7 +100,7 @@ class ChatHistorySummarizer:
def _get_chat_display_name(self) -> str: def _get_chat_display_name(self) -> str:
"""获取聊天显示名称""" """获取聊天显示名称"""
try: try:
chat_name = get_chat_manager().get_stream_name(self.chat_id) chat_name = _chat_manager.get_session_name(self.chat_id)
if chat_name: if chat_name:
return chat_name return chat_name
# 如果获取失败使用简化的chat_id显示 # 如果获取失败使用简化的chat_id显示

View File

@@ -1,3 +1,4 @@
import contextlib
import time import time
import json import json
import asyncio import asyncio
@@ -12,7 +13,7 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import ThinkingQuestion from src.common.database.database_model import ThinkingQuestion
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon from src.bw_learner.jargon_explainer import retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval") logger = get_logger("memory_retrieval")
@@ -133,10 +134,10 @@ async def _react_agent_solve_question(
Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时) Tuple[bool, str, List[Dict[str, Any]], bool]: (是否找到答案, 答案内容, 思考步骤列表, 是否超时)
""" """
start_time = time.time() start_time = time.time()
collected_info = initial_info if initial_info else "" collected_info = initial_info or ""
# 构造日志前缀:[聊天流名称],用于在日志中标识聊天流 # 构造日志前缀:[聊天流名称],用于在日志中标识聊天流
try: try:
chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id chat_name = _chat_manager.get_session_name(chat_id) or chat_id
except Exception: except Exception:
chat_name = chat_id chat_name = chat_id
react_log_prefix = f"[{chat_name}] " react_log_prefix = f"[{chat_name}] "
@@ -235,7 +236,7 @@ async def _react_agent_solve_question(
# head_prompt应该只构建一次使用初始的collected_info后续迭代都复用同一个 # head_prompt应该只构建一次使用初始的collected_info后续迭代都复用同一个
if first_head_prompt is None: if first_head_prompt is None:
# 第一次构建使用初始的collected_info即initial_info # 第一次构建使用初始的collected_info即initial_info
initial_collected_info = initial_info if initial_info else "" initial_collected_info = initial_info or ""
# 根据配置选择使用哪个 prompt # 根据配置选择使用哪个 prompt
prompt_name = ( prompt_name = (
"memory_retrieval_react_prompt_head_lpmm" "memory_retrieval_react_prompt_head_lpmm"
@@ -362,7 +363,7 @@ async def _react_agent_solve_question(
return information return information
except (json.JSONDecodeError, ValueError, TypeError): except (json.JSONDecodeError, ValueError, TypeError):
# 如果JSON解析失败尝试在文本中查找JSON对象 # 如果JSON解析失败尝试在文本中查找JSON对象
try: with contextlib.suppress(json.JSONDecodeError, ValueError, TypeError):
# 查找第一个 { 和最后一个 } 之间的内容更健壮的JSON提取 # 查找第一个 { 和最后一个 } 之间的内容更健壮的JSON提取
first_brace = text.find("{") first_brace = text.find("{")
if first_brace != -1: if first_brace != -1:
@@ -384,8 +385,6 @@ async def _react_agent_solve_question(
if isinstance(data, dict) and "return_information" in data: if isinstance(data, dict) and "return_information" in data:
information = data.get("information", "") information = data.get("information", "")
return information return information
except (json.JSONDecodeError, ValueError, TypeError):
pass
return None return None
@@ -679,7 +678,7 @@ async def _react_agent_solve_question(
evaluation_prompt_template.add_context("bot_name", bot_name) evaluation_prompt_template.add_context("bot_name", bot_name)
evaluation_prompt_template.add_context("time_now", time_now) evaluation_prompt_template.add_context("time_now", time_now)
evaluation_prompt_template.add_context("chat_history", chat_history) evaluation_prompt_template.add_context("chat_history", chat_history)
evaluation_prompt_template.add_context("collected_info", collected_info if collected_info else "暂无信息") evaluation_prompt_template.add_context("collected_info", collected_info or "暂无信息")
evaluation_prompt_template.add_context("current_iteration", str(current_iteration)) evaluation_prompt_template.add_context("current_iteration", str(current_iteration))
evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations)) evaluation_prompt_template.add_context("remaining_iterations", str(remaining_iterations))
evaluation_prompt_template.add_context("max_iterations", str(max_iterations)) evaluation_prompt_template.add_context("max_iterations", str(max_iterations))
@@ -800,8 +799,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
if not records: if not records:
return "" return ""
history_lines = [] history_lines = ["最近已查询的问题和结果:"]
history_lines.append("最近已查询的问题和结果:")
for record in records: for record in records:
status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案" status = "✓ 已找到答案" if record.found_answer else "✗ 未找到答案"
@@ -813,8 +811,7 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
if len(record.answer) > 100: if len(record.answer) > 100:
answer_preview += "..." answer_preview += "..."
history_lines.append(f"- 问题:{record.question}") history_lines.extend([f"- 问题:{record.question}", f" 状态:{status}"])
history_lines.append(f" 状态:{status}")
if answer_preview: if answer_preview:
history_lines.append(f" 答案:{answer_preview}") history_lines.append(f" 答案:{answer_preview}")
history_lines.append("") # 空行分隔 history_lines.append("") # 空行分隔
@@ -855,12 +852,11 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
if not records: if not records:
return [] return []
found_answers = [] return [
for record in records: f"问题:{record.question}\n答案:{record.answer}"
if record.answer: for record in records
found_answers.append(f"问题:{record.question}\n答案:{record.answer}") if record.answer
]
return found_answers
except Exception as e: except Exception as e:
logger.error(f"获取最近已找到答案的记录失败: {e}") logger.error(f"获取最近已找到答案的记录失败: {e}")
@@ -892,8 +888,7 @@ def _store_thinking_back(
.order_by(col(ThinkingQuestion.updated_timestamp).desc()) .order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(1) .limit(1)
) )
record = session.exec(statement).first() if record := session.exec(statement).first():
if record:
record.context = context record.context = context
record.found_answer = found_answer record.found_answer = found_answer
record.answer = answer record.answer = answer
@@ -957,10 +952,7 @@ async def _process_memory_retrieval(
if is_timeout: if is_timeout:
logger.info("ReAct Agent超时不返回结果") logger.info("ReAct Agent超时不返回结果")
if found_answer and answer: return answer if found_answer and answer else None
return answer
return None
async def build_memory_retrieval_prompt( async def build_memory_retrieval_prompt(
@@ -1013,8 +1005,7 @@ async def build_memory_retrieval_prompt(
cleaned_concepts = [] cleaned_concepts = []
for word in unknown_words: for word in unknown_words:
if isinstance(word, str): if isinstance(word, str):
cleaned = word.strip() if cleaned := word.strip():
if cleaned:
cleaned_concepts.append(cleaned) cleaned_concepts.append(cleaned)
if cleaned_concepts: if cleaned_concepts:
# 对匹配到的概念进行jargon检索作为初始信息 # 对匹配到的概念进行jargon检索作为初始信息

View File

@@ -30,9 +30,8 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
return chat_ids return chat_ids
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager from src.common.utils.utils_session import SessionUtils
chat_manager = get_chat_manager()
for blacklist_item in blacklist: for blacklist_item in blacklist:
if not isinstance(blacklist_item, str): if not isinstance(blacklist_item, str):
continue continue
@@ -51,7 +50,10 @@ def _parse_blacklist_to_chat_ids(blacklist: list[str]) -> Set[str]:
is_group = stream_type == "group" is_group = stream_type == "group"
# 转换为chat_id # 转换为chat_id
chat_id = chat_manager.get_stream_id(platform, str(id_str), is_group=is_group) if is_group:
chat_id = SessionUtils.calculate_session_id(platform, group_id=str(id_str))
else:
chat_id = SessionUtils.calculate_session_id(platform, user_id=str(id_str))
if chat_id: if chat_id:
chat_ids.add(chat_id) chat_ids.add(chat_id)
else: else:
@@ -225,9 +227,9 @@ async def search_chat_history(
if keyword: if keyword:
keyword_matched = False keyword_matched = False
# 解析多个关键词(支持空格、逗号等分隔符) # 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword) keywords_list = parse_keywords_string(keyword) or (
if not keywords_list: [keyword.strip()] if keyword.strip() else []
keywords_list = [keyword.strip()] if keyword.strip() else [] )
# 转换为小写以便匹配 # 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()] keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]

View File

@@ -16,7 +16,7 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
logger = get_logger("person_info") logger = get_logger("person_info")
@@ -818,22 +818,22 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
chat_id: 聊天ID chat_id: 聊天ID
""" """
try: try:
# 从chat_id获取chat_stream # 从 chat_id 获取 session
chat_stream = get_chat_manager().get_stream(chat_id) session = _chat_manager.get_session_by_session_id(chat_id)
if not chat_stream: if not session:
logger.warning(f"无法获取chat_stream for chat_id: {chat_id}") logger.warning(f"无法获取session for chat_id: {chat_id}")
return return
platform = chat_stream.platform platform = session.platform
# 尝试从person_name查找person_id # 尝试从person_name查找person_id
# 首先尝试通过person_name查找 # 首先尝试通过person_name查找
person_id = get_person_id_by_person_name(person_name) person_id = get_person_id_by_person_name(person_name)
if not person_id: if not person_id:
# 如果通过person_name找不到尝试从chat_stream获取user_info # 如果通过person_name找不到尝试从 session 获取 user_id
if platform and chat_stream.user_info and chat_stream.user_info.user_id: if platform and session.user_id:
user_id = chat_stream.user_info.user_id user_id = session.user_id
person_id = get_person_id(platform, user_id) person_id = get_person_id(platform, user_id)
else: else:
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")

View File

@@ -16,7 +16,7 @@ from typing import List, Dict, Any, Optional
from enum import Enum from enum import Enum
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
logger = get_logger("chat_api") logger = get_logger("chat_api")
@@ -31,7 +31,7 @@ class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理""" """聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod @staticmethod
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""获取所有聊天流 """获取所有聊天流
@@ -39,7 +39,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns: Returns:
List[ChatStream]: 聊天流列表 List[BotChatSession]: 聊天流列表
Raises: Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
@@ -48,7 +48,7 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in _chat_manager.sessions.items():
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform: if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream) streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流") logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
@@ -57,7 +57,7 @@ class ChatManager:
return streams return streams
@staticmethod @staticmethod
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""获取所有群聊聊天流 """获取所有群聊聊天流
@@ -65,14 +65,14 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns: Returns:
List[ChatStream]: 群聊聊天流列表 List[BotChatSession]: 群聊聊天流列表
""" """
if not isinstance(platform, (str, SpecialTypes)): if not isinstance(platform, (str, SpecialTypes)):
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info: if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
streams.append(stream) streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流") logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e: except Exception as e:
@@ -80,7 +80,7 @@ class ChatManager:
return streams return streams
@staticmethod @staticmethod
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""获取所有私聊聊天流 """获取所有私聊聊天流
@@ -88,7 +88,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns: Returns:
List[ChatStream]: 私聊聊天流列表 List[BotChatSession]: 私聊聊天流列表
Raises: Raises:
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型 TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
@@ -97,8 +97,10 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in _chat_manager.sessions.items():
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info: if (
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
) and not stream.is_group_session:
streams.append(stream) streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流") logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e: except Exception as e:
@@ -108,7 +110,7 @@ class ChatManager:
@staticmethod @staticmethod
def get_group_stream_by_group_id( def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq" group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流 """根据群ID获取聊天流
Args: Args:
@@ -116,7 +118,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns: Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises: Raises:
ValueError: 如果 group_id 为空字符串 ValueError: 如果 group_id 为空字符串
@@ -129,11 +131,11 @@ class ChatManager:
if not group_id: if not group_id:
raise ValueError("group_id 不能为空") raise ValueError("group_id 不能为空")
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in _chat_manager.sessions.items():
if ( if (
stream.group_info stream.is_group_session
and str(stream.group_info.group_id) == str(group_id) and str(stream.group_id) == str(group_id)
and stream.platform == platform and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
): ):
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流") logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
return stream return stream
@@ -145,7 +147,7 @@ class ChatManager:
@staticmethod @staticmethod
def get_private_stream_by_user_id( def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq" user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast ) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流 """根据用户ID获取私聊流
Args: Args:
@@ -153,7 +155,7 @@ class ChatManager:
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流 platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns: Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None Optional[BotChatSession]: 聊天流对象如果未找到返回None
Raises: Raises:
ValueError: 如果 user_id 为空字符串 ValueError: 如果 user_id 为空字符串
@@ -166,11 +168,11 @@ class ChatManager:
if not user_id: if not user_id:
raise ValueError("user_id 不能为空") raise ValueError("user_id 不能为空")
try: try:
for _, stream in get_chat_manager().streams.items(): for _, stream in _chat_manager.sessions.items():
if ( if (
not stream.group_info not stream.is_group_session
and str(stream.user_info.user_id) == str(user_id) and str(stream.user_id) == str(user_id)
and stream.platform == platform and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
): ):
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流") logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
return stream return stream
@@ -180,7 +182,7 @@ class ChatManager:
return None return None
@staticmethod @staticmethod
def get_stream_type(chat_stream: ChatStream) -> str: def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型 """获取聊天流类型
Args: Args:
@@ -190,20 +192,18 @@ class ChatManager:
str: 聊天类型 ("group", "private", "unknown") str: 聊天类型 ("group", "private", "unknown")
Raises: Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型 TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空 ValueError: 如果 chat_stream 为空
""" """
if not isinstance(chat_stream, ChatStream): if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 ChatStream 类型") raise TypeError("chat_stream 必须是 BotChatSession 类型")
if not chat_stream: if not chat_stream:
raise ValueError("chat_stream 不能为 None") raise ValueError("chat_stream 不能为 None")
if hasattr(chat_stream, "group_info"): return "group" if chat_stream.is_group_session else "private"
return "group" if chat_stream.group_info else "private"
return "unknown"
@staticmethod @staticmethod
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流详细信息 """获取聊天流详细信息
Args: Args:
@@ -213,36 +213,34 @@ class ChatManager:
Dict ({str: Any}): 聊天流信息字典 Dict ({str: Any}): 聊天流信息字典
Raises: Raises:
TypeError: 如果 chat_stream 不是 ChatStream 类型 TypeError: 如果 chat_stream 不是 BotChatSession 类型
ValueError: 如果 chat_stream 为空 ValueError: 如果 chat_stream 为空
""" """
if not chat_stream: if not chat_stream:
raise ValueError("chat_stream 不能为 None") raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, ChatStream): if not isinstance(chat_stream, BotChatSession):
raise TypeError("chat_stream 必须是 ChatStream 类型") raise TypeError("chat_stream 必须是 BotChatSession 类型")
try: try:
info: Dict[str, Any] = { info: Dict[str, Any] = {
"stream_id": chat_stream.stream_id, "session_id": chat_stream.session_id,
"platform": chat_stream.platform, "platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream), "type": ChatManager.get_stream_type(chat_stream),
} }
if chat_stream.group_info: if chat_stream.is_group_session:
info.update( info["group_id"] = chat_stream.group_id
{ # Try to get group name from context
"group_id": chat_stream.group_info.group_id, if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"), info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
} else:
) info["group_name"] = "未知群聊"
else:
if chat_stream.user_info: info["user_id"] = chat_stream.user_id
info.update( if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
{ info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
"user_id": chat_stream.user_info.user_id, else:
"user_name": chat_stream.user_info.user_nickname, info["user_name"] = "未知用户"
}
)
return info return info
except Exception as e: except Exception as e:
@@ -285,37 +283,37 @@ class ChatManager:
# ============================================================================= # =============================================================================
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取所有聊天流的便捷函数""" """获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform) return ChatManager.get_all_streams(platform)
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取群聊聊天流的便捷函数""" """获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform) return ChatManager.get_group_streams(platform)
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]: def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
"""获取私聊聊天流的便捷函数""" """获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform) return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据群ID获取聊天流的便捷函数""" """根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform) return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]: def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
"""根据用户ID获取私聊流的便捷函数""" """根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform) return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream) -> str: def get_stream_type(chat_stream: BotChatSession) -> str:
"""获取聊天流类型的便捷函数""" """获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream) return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]: def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
"""获取聊天流信息的便捷函数""" """获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream) return ChatManager.get_stream_info(chat_stream)

View File

@@ -16,7 +16,7 @@ import uuid
from typing import Optional, Tuple, List, Dict, Any from typing import Optional, Tuple, List, Dict, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.chat.utils.utils_image import image_path_to_base64, base64_to_image from src.common.utils.utils_image import ImageUtils
from src.config.config import global_config from src.config.config import global_config
logger = get_logger("emoji_api") logger = get_logger("emoji_api")
@@ -56,7 +56,7 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
emoji_path = str(emoji_obj.full_path) emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "" matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = image_path_to_base64(emoji_path) emoji_base64 = ImageUtils.image_path_to_base64(emoji_path)
if not emoji_base64: if not emoji_base64:
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}") logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
@@ -115,7 +115,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
results = [] results = []
for selected_emoji in selected_emojis: for selected_emoji in selected_emojis:
emoji_base64 = image_path_to_base64(str(selected_emoji.full_path)) emoji_base64 = ImageUtils.image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64: if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}") logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
@@ -174,7 +174,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
# 随机选择匹配的表情包 # 随机选择匹配的表情包
selected_emoji = random.choice(matching_emojis) selected_emoji = random.choice(matching_emojis)
emoji_base64 = image_path_to_base64(selected_emoji.full_path) emoji_base64 = ImageUtils.image_path_to_base64(selected_emoji.full_path)
if not emoji_base64: if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}") logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
@@ -263,7 +263,7 @@ async def get_all() -> List[Tuple[str, str, str]]:
if emoji_obj.is_deleted: if emoji_obj.is_deleted:
continue continue
emoji_base64 = image_path_to_base64(str(emoji_obj.full_path)) emoji_base64 = ImageUtils.image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64: if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}") logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
@@ -429,7 +429,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
try: try:
# 解码base64并保存图片 # 解码base64并保存图片
if not base64_to_image(image_base64, temp_file_path): if not ImageUtils.base64_to_image(image_base64, temp_file_path):
logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}") logger.error(f"[EmojiAPI] 无法保存base64图片到文件: {temp_file_path}")
return { return {
"success": False, "success": False,

View File

@@ -16,7 +16,7 @@ from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplySetModel from src.common.data_models.message_data_model import ReplySetModel
from src.chat.replyer.group_generator import DefaultReplyer from src.chat.replyer.group_generator import DefaultReplyer
from src.chat.replyer.private_generator import PrivateReplyer from src.chat.replyer.private_generator import PrivateReplyer
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession
from src.chat.utils.utils import process_llm_response from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo from src.plugin_system.base.component_types import ActionInfo
@@ -38,7 +38,7 @@ logger = get_logger("generator_api")
def get_replyer( def get_replyer(
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> Optional[DefaultReplyer | PrivateReplyer]: ) -> Optional[DefaultReplyer | PrivateReplyer]:
@@ -79,7 +79,7 @@ def get_replyer(
async def generate_reply( async def generate_reply(
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None, action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None, reply_message: Optional["DatabaseMessages"] = None,
@@ -161,7 +161,7 @@ async def generate_reply(
unknown_words=unknown_words, unknown_words=unknown_words,
think_level=think_level, think_level=think_level,
from_plugin=from_plugin, from_plugin=from_plugin,
stream_id=chat_stream.stream_id if chat_stream else chat_id, stream_id=chat_stream.session_id if chat_stream else chat_id,
reply_time_point=reply_time_point, reply_time_point=reply_time_point,
log_reply=False, log_reply=False,
) )
@@ -181,7 +181,7 @@ async def generate_reply(
# 统一在这里记录最终回复日志(包含分割后的 processed_output # 统一在这里记录最终回复日志(包含分割后的 processed_output
try: try:
PlanReplyLogger.log_reply( PlanReplyLogger.log_reply(
chat_id=chat_stream.stream_id if chat_stream else (chat_id or ""), chat_id=chat_stream.session_id if chat_stream else (chat_id or ""),
prompt=llm_response.prompt or "", prompt=llm_response.prompt or "",
output=llm_response.content, output=llm_response.content,
processed_output=llm_response.processed_output, processed_output=llm_response.processed_output,
@@ -210,7 +210,7 @@ async def generate_reply(
async def rewrite_reply( async def rewrite_reply(
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[BotChatSession] = None,
reply_data: Optional[Dict[str, Any]] = None, reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
enable_splitter: bool = True, enable_splitter: bool = True,
@@ -302,7 +302,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
async def generate_response_custom( async def generate_response_custom(
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[BotChatSession] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
request_type: str = "generator_api", request_type: str = "generator_api",
prompt: str = "", prompt: str = "",

View File

@@ -26,10 +26,12 @@ from typing import Optional, Union, Dict, List, TYPE_CHECKING, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType from src.common.data_models.message_data_model import ReplyContentType
from src.config.config import global_config from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.chat.message_receive.uni_message_sender import UniversalMessageSender from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.message_receive.message import MessageSending, MessageRecv from maim_message import Seg
from maim_message import Seg, UserInfo, MessageBase, BaseMessageInfo
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
from src.chat.message_receive.message import MessageSending
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
@@ -77,7 +79,7 @@ async def _send_to_target(
logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}") logger.debug(f"[SendAPI] 发送{message_segment.type}消息到 {stream_id}")
# 查找目标聊天流 # 查找目标聊天流
target_stream = get_chat_manager().get_stream(stream_id) target_stream = _chat_manager.get_session_by_session_id(stream_id)
if not target_stream: if not target_stream:
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}") logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
return False return False
@@ -93,27 +95,29 @@ async def _send_to_target(
bot_user_info = UserInfo( bot_user_info = UserInfo(
user_id=global_config.bot.qq_account, user_id=global_config.bot.qq_account,
user_nickname=global_config.bot.nickname, user_nickname=global_config.bot.nickname,
platform=target_stream.platform,
) )
reply_to_platform_id = "" reply_to_platform_id = ""
anchor_message: Union["MessageRecv", None] = None anchor_message: Optional[MaiMessage] = None
if reply_message: if reply_message:
anchor_message = db_message_to_message_recv(reply_message) anchor_message = db_message_to_mai_message(reply_message)
logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}") # type: ignore
if anchor_message: if anchor_message:
anchor_message.update_chat_stream(target_stream) logger.debug(f"[SendAPI] 找到匹配的回复消息,发送者: {anchor_message.message_info.user_info.user_id}")
assert anchor_message.message_info.user_info, "用户信息缺失"
reply_to_platform_id = ( reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" f"{anchor_message.platform}:{anchor_message.message_info.user_info.user_id}"
) )
# 构建 sender_info私聊时为接收者信息
sender_info = None
if target_stream.context and target_stream.context.message:
sender_info = target_stream.context.message.message_info.user_info
# 构建发送消息对象 # 构建发送消息对象
bot_message = MessageSending( bot_message = MessageSending(
message_id=message_id, message_id=message_id,
chat_stream=target_stream, session=target_stream,
bot_user_info=bot_user_info, bot_user_info=bot_user_info,
sender_info=target_stream.user_info, sender_info=sender_info,
message_segment=message_segment, message_segment=message_segment,
display_message=display_message, display_message=display_message,
reply=anchor_message, reply=anchor_message,
@@ -146,51 +150,43 @@ async def _send_to_target(
return False return False
def db_message_to_message_recv(message_obj: "DatabaseMessages") -> MessageRecv: def db_message_to_mai_message(message_obj: "DatabaseMessages") -> Optional[MaiMessage]:
"""将数据库dict重建为MessageRecv对象 """将数据库消息重建为 MaiMessage 对象,用于回复引用。
Args: Args:
message_dict: 消息字典 message_obj: 插件系统的 DatabaseMessages 数据对象
Returns: Returns:
Optional[MessageRecv]: 找到的消息,如果没找到则返回None Optional[MaiMessage]: 构建的消息对象,如果信息不足则返回 None
""" """
# 构建MessageRecv对象 from datetime import datetime
user_info = { from src.common.data_models.mai_message_data_model import GroupInfo, MessageInfo
"platform": message_obj.user_info.platform or "", from src.common.data_models.message_component_data_model import MessageSequence
"user_id": message_obj.user_info.user_id or "",
"user_nickname": message_obj.user_info.user_nickname or "",
"user_cardname": message_obj.user_info.user_cardname or "",
}
group_info = {} user_info = UserInfo(
user_id=message_obj.user_info.user_id or "",
user_nickname=message_obj.user_info.user_nickname or "",
user_cardname=message_obj.user_info.user_cardname,
)
group_info = None
if message_obj.chat_info.group_info: if message_obj.chat_info.group_info:
group_info = { group_info = GroupInfo(
"platform": message_obj.chat_info.group_info.group_platform or "", group_id=message_obj.chat_info.group_info.group_id or "",
"group_id": message_obj.chat_info.group_info.group_id or "", group_name=message_obj.chat_info.group_info.group_name or "",
"group_name": message_obj.chat_info.group_info.group_name or "", )
}
format_info = {"content_format": "", "accept_format": ""} msg = MaiMessage(
template_info = {"template_items": {}} message_id=message_obj.message_id,
timestamp=datetime.fromtimestamp(message_obj.time) if message_obj.time else datetime.now(),
message_info = { )
"platform": message_obj.chat_info.platform or "", msg.message_info = MessageInfo(user_info=user_info, group_info=group_info)
"message_id": message_obj.message_id, msg.platform = message_obj.chat_info.platform or ""
"time": message_obj.time, msg.session_id = message_obj.chat_info.stream_id or ""
"group_info": group_info, msg.processed_plain_text = message_obj.processed_plain_text
"user_info": user_info, msg.raw_message = MessageSequence(components=[])
"additional_config": message_obj.additional_config, msg.initialized = True
"format_info": format_info, return msg
"template_info": template_info,
}
message_dict_recv = {
"message_info": message_info,
"raw_message": message_obj.processed_plain_text,
"processed_plain_text": message_obj.processed_plain_text,
}
return MessageRecv(message_dict_recv)
# ============================================================================= # =============================================================================

View File

@@ -5,12 +5,12 @@ from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger from src.common.logger import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession
logger = get_logger("tool_api") logger = get_logger("tool_api")
def get_tool_instance(tool_name: str, chat_stream: Optional["ChatStream"] = None) -> Optional[BaseTool]: def get_tool_instance(tool_name: str, chat_stream: Optional["BotChatSession"] = None) -> Optional[BaseTool]:
"""获取公开工具实例 """获取公开工具实例
Args: Args:

View File

@@ -6,7 +6,7 @@ from typing import Tuple, Optional, TYPE_CHECKING, Dict, List
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api from src.plugin_system.apis import send_api, database_api, message_api
@@ -36,7 +36,7 @@ class BaseAction(ABC):
action_reasoning: str, action_reasoning: str,
cycle_timers: dict, cycle_timers: dict,
thinking_id: str, thinking_id: str,
chat_stream: ChatStream, chat_stream: BotChatSession,
plugin_config: Optional[dict] = None, plugin_config: Optional[dict] = None,
action_message: Optional["DatabaseMessages"] = None, action_message: Optional["DatabaseMessages"] = None,
**kwargs, **kwargs,
@@ -92,7 +92,7 @@ class BaseAction(ABC):
# 获取聊天流对象 # 获取聊天流对象
self.chat_stream = chat_stream or kwargs.get("chat_stream") self.chat_stream = chat_stream or kwargs.get("chat_stream")
self.chat_id = self.chat_stream.stream_id self.chat_id = self.chat_stream.session_id
self.platform = getattr(self.chat_stream, "platform", None) self.platform = getattr(self.chat_stream, "platform", None)
# 初始化基础信息(带类型注解) # 初始化基础信息(带类型注解)

View File

@@ -3,7 +3,7 @@ from typing import Dict, Tuple, Optional, TYPE_CHECKING, List
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode from src.common.data_models.message_data_model import ReplyContentType, ReplyContent, ReplySetModel, ForwardNode
from src.plugin_system.base.component_types import CommandInfo, ComponentType from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import SessionMessage
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -31,7 +31,7 @@ class BaseCommand(ABC):
command_pattern: str = r"" command_pattern: str = r""
"""命令匹配的正则表达式""" """命令匹配的正则表达式"""
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None): def __init__(self, message: SessionMessage, plugin_config: Optional[dict] = None):
"""初始化Command组件 """初始化Command组件
Args: Args:
@@ -107,14 +107,14 @@ class BaseCommand(ABC):
bool: 是否发送成功 bool: 是否发送成功
""" """
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
return await send_api.text_to_stream( return await send_api.text_to_stream(
text=content, text=content,
stream_id=chat_stream.stream_id, stream_id=session_id,
set_reply=set_reply, set_reply=set_reply,
reply_message=reply_message, reply_message=reply_message,
storage_message=storage_message, storage_message=storage_message,
@@ -135,14 +135,14 @@ class BaseCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
return await send_api.image_to_stream( return await send_api.image_to_stream(
image_base64, image_base64,
chat_stream.stream_id, session_id,
set_reply=set_reply, set_reply=set_reply,
reply_message=reply_message, reply_message=reply_message,
storage_message=storage_message, storage_message=storage_message,
@@ -166,13 +166,13 @@ class BaseCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
return await send_api.emoji_to_stream( return await send_api.emoji_to_stream(
emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message emoji_base64, session_id, set_reply=set_reply, reply_message=reply_message
) )
async def send_command( async def send_command(
@@ -195,9 +195,9 @@ class BaseCommand(ABC):
""" """
try: try:
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
# 构造命令数据 # 构造命令数据
@@ -205,7 +205,7 @@ class BaseCommand(ABC):
success = await send_api.command_to_stream( success = await send_api.command_to_stream(
command=command_data, command=command_data,
stream_id=chat_stream.stream_id, stream_id=session_id,
storage_message=storage_message, storage_message=storage_message,
display_message=display_message, display_message=display_message,
) )
@@ -229,15 +229,15 @@ class BaseCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
return await send_api.custom_to_stream( return await send_api.custom_to_stream(
message_type="voice", message_type="voice",
content=voice_base64, content=voice_base64,
stream_id=chat_stream.stream_id, stream_id=session_id,
typing=False, typing=False,
set_reply=False, set_reply=False,
reply_message=None, reply_message=None,
@@ -262,15 +262,15 @@ class BaseCommand(ABC):
reply_message: 回复的消息对象 reply_message: 回复的消息对象
storage_message: 是否存储消息到数据库 storage_message: 是否存储消息到数据库
""" """
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
reply_set = ReplySetModel() reply_set = ReplySetModel()
reply_set.add_hybrid_content_by_raw(message_tuple_list) reply_set.add_hybrid_content_by_raw(message_tuple_list)
return await send_api.custom_reply_set_to_stream( return await send_api.custom_reply_set_to_stream(
reply_set=reply_set, reply_set=reply_set,
stream_id=chat_stream.stream_id, stream_id=session_id,
typing=typing, typing=typing,
set_reply=set_reply, set_reply=set_reply,
reply_message=reply_message, reply_message=reply_message,
@@ -293,9 +293,9 @@ class BaseCommand(ABC):
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
reply_set = ReplySetModel() reply_set = ReplySetModel()
forward_message_nodes: List[ForwardNode] = [] forward_message_nodes: List[ForwardNode] = []
@@ -318,7 +318,7 @@ class BaseCommand(ABC):
reply_set.add_forward_content(forward_message_nodes) reply_set.add_forward_content(forward_message_nodes)
return await send_api.custom_reply_set_to_stream( return await send_api.custom_reply_set_to_stream(
reply_set=reply_set, reply_set=reply_set,
stream_id=chat_stream.stream_id, stream_id=session_id,
storage_message=storage_message, storage_message=storage_message,
set_reply=False, set_reply=False,
reply_message=None, reply_message=None,
@@ -349,15 +349,15 @@ class BaseCommand(ABC):
bool: 是否发送成功 bool: 是否发送成功
""" """
# 获取聊天流信息 # 获取聊天流信息
chat_stream = self.message.chat_stream session_id = self.message.session_id
if not chat_stream or not hasattr(chat_stream, "stream_id"): if not session_id:
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id") logger.error(f"{self.log_prefix} 缺少session_id")
return False return False
return await send_api.custom_to_stream( return await send_api.custom_to_stream(
message_type=message_type, message_type=message_type,
content=content, content=content,
stream_id=chat_stream.stream_id, stream_id=session_id,
display_message=display_message, display_message=display_message,
typing=typing, typing=typing,
set_reply=set_reply, set_reply=set_reply,

View File

@@ -6,7 +6,7 @@ from src.common.logger import get_logger
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
if TYPE_CHECKING: if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_manager import BotChatSession
install(extra_lines=3) install(extra_lines=3)
@@ -32,7 +32,7 @@ class BaseTool(ABC):
available_for_llm: bool = False available_for_llm: bool = False
"""是否可供LLM使用""" """是否可供LLM使用"""
def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["ChatStream"] = None): def __init__(self, plugin_config: Optional[dict] = None, chat_stream: Optional["BotChatSession"] = None):
"""初始化工具基类 """初始化工具基类
Args: Args:
@@ -47,7 +47,7 @@ class BaseTool(ABC):
# 获取聊天流对象 # 获取聊天流对象
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.chat_id = self.chat_stream.stream_id if self.chat_stream else None self.chat_id = self.chat_stream.session_id if self.chat_stream else None
self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None self.platform = getattr(self.chat_stream, "platform", None) if self.chat_stream else None
@classmethod @classmethod

View File

@@ -2,8 +2,8 @@ import asyncio
import contextlib import contextlib
from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING from typing import List, Dict, Optional, Type, Tuple, TYPE_CHECKING
from src.chat.message_receive.message import MessageRecv, MessageSending from src.chat.message_receive.message import MessageSending, SessionMessage
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages, CustomEventHandlerResult
from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_events_handler import BaseEventHandler
@@ -72,7 +72,7 @@ class EventsManager:
async def handle_mai_events( async def handle_mai_events(
self, self,
event_type: EventType | str, event_type: EventType | str,
message: Optional[MessageRecv | MessageSending | MaiMessages] = None, message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None, llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None, llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,
@@ -87,7 +87,7 @@ class EventsManager:
# 1. 准备消息 # 1. 准备消息
transformed_message = self._prepare_message( transformed_message = self._prepare_message(
event_type, message, llm_prompt, llm_response, stream_id, action_usage event_type, message, llm_prompt, llm_response, stream_id, action_usage # type: ignore[arg-type]
) )
if transformed_message: if transformed_message:
transformed_message = transformed_message.deepcopy() transformed_message = transformed_message.deepcopy()
@@ -134,7 +134,7 @@ class EventsManager:
async def handle_workflow_message( async def handle_workflow_message(
self, self,
message: Optional[MessageRecv | MessageSending | MaiMessages] = None, message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None, action_usage: Optional[List[str]] = None,
context: Optional[WorkflowContext] = None, context: Optional[WorkflowContext] = None,
@@ -248,11 +248,13 @@ class EventsManager:
def _transform_event_message( def _transform_event_message(
self, self,
message: MessageRecv | MessageSending, message: SessionMessage | MessageSending,
llm_prompt: Optional[str] = None, llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None, llm_response: Optional["LLMGenerationDataModel"] = None,
) -> MaiMessages: ) -> MaiMessages:
"""转换事件消息格式""" """转换事件消息格式"""
from maim_message import Seg
# 直接赋值部分内容 # 直接赋值部分内容
transformed_message = MaiMessages( transformed_message = MaiMessages(
llm_prompt=llm_prompt, llm_prompt=llm_prompt,
@@ -260,45 +262,62 @@ class EventsManager:
llm_response_reasoning=llm_response.reasoning if llm_response else None, llm_response_reasoning=llm_response.reasoning if llm_response else None,
llm_response_model=llm_response.model if llm_response else None, llm_response_model=llm_response.model if llm_response else None,
llm_response_tool_call=llm_response.tool_calls if llm_response else None, llm_response_tool_call=llm_response.tool_calls if llm_response else None,
raw_message=message.raw_message, raw_message=message.processed_plain_text or "",
additional_data=message.message_info.additional_config or {}, additional_data={},
) )
# 消息段处理 # 消息段处理
if message.message_segment.type == "seglist": if isinstance(message, MessageSending):
transformed_message.message_segments = list(message.message_segment.data) # type: ignore if message.message_segment.type == "seglist":
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
else:
transformed_message.message_segments = [message.message_segment]
else: else:
transformed_message.message_segments = [message.message_segment] # SessionMessage: 使用 processed_plain_text 构造简单段
transformed_message.message_segments = [Seg(type="text", data=message.processed_plain_text or "")]
# stream_id 处理 # stream_id 处理
if hasattr(message, "chat_stream") and message.chat_stream: transformed_message.stream_id = message.session_id if hasattr(message, "session_id") else ""
transformed_message.stream_id = message.chat_stream.stream_id
# 处理后文本 # 处理后文本
transformed_message.plain_text = message.processed_plain_text transformed_message.plain_text = message.processed_plain_text
# 基本信息 # 基本信息
if hasattr(message, "message_info") and message.message_info: if isinstance(message, MessageSending):
if message.message_info.platform: transformed_message.message_base_info["platform"] = message.platform
transformed_message.message_base_info["platform"] = message.message_info.platform if message.session.group_id:
transformed_message.is_group_message = True
group_name = ""
if message.session.context and message.session.context.message and message.session.context.message.message_info.group_info:
group_name = message.session.context.message.message_info.group_info.group_name
transformed_message.message_base_info.update({
"group_id": message.session.group_id,
"group_name": group_name,
})
transformed_message.message_base_info.update({
"user_id": message.bot_user_info.user_id,
"user_cardname": message.bot_user_info.user_cardname,
"user_nickname": message.bot_user_info.user_nickname,
})
if not transformed_message.is_group_message:
transformed_message.is_private_message = True
elif hasattr(message, "message_info") and message.message_info:
if message.platform:
transformed_message.message_base_info["platform"] = message.platform
if message.message_info.group_info: if message.message_info.group_info:
transformed_message.is_group_message = True transformed_message.is_group_message = True
transformed_message.message_base_info.update( transformed_message.message_base_info.update({
{ "group_id": message.message_info.group_info.group_id,
"group_id": message.message_info.group_info.group_id, "group_name": message.message_info.group_info.group_name,
"group_name": message.message_info.group_info.group_name, })
}
)
if message.message_info.user_info: if message.message_info.user_info:
if not transformed_message.is_group_message: if not transformed_message.is_group_message:
transformed_message.is_private_message = True transformed_message.is_private_message = True
transformed_message.message_base_info.update( transformed_message.message_base_info.update({
{ "user_id": message.message_info.user_info.user_id,
"user_id": message.message_info.user_info.user_id, "user_cardname": message.message_info.user_info.user_cardname,
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称 "user_nickname": message.message_info.user_info.user_nickname,
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名) })
}
)
return transformed_message return transformed_message
@@ -306,9 +325,9 @@ class EventsManager:
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
) -> MaiMessages: ) -> MaiMessages:
"""从流ID构建消息""" """从流ID构建消息"""
chat_stream = get_chat_manager().get_stream(stream_id) session = _chat_manager.get_session_by_session_id(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id}聊天流" assert session, f"未找到流ID为 {stream_id}会话"
message = chat_stream.context.get_last_message() message = session.context.message
return self._transform_event_message(message, llm_prompt, llm_response) return self._transform_event_message(message, llm_prompt, llm_response)
def _transform_event_without_message( def _transform_event_without_message(
@@ -319,8 +338,8 @@ class EventsManager:
action_usage: Optional[List[str]] = None, action_usage: Optional[List[str]] = None,
) -> MaiMessages: ) -> MaiMessages:
"""没有message对象时进行转换""" """没有message对象时进行转换"""
chat_stream = get_chat_manager().get_stream(stream_id) session = _chat_manager.get_session_by_session_id(stream_id)
assert chat_stream, f"未找到流ID为 {stream_id}聊天流" assert session, f"未找到流ID为 {stream_id}会话"
return MaiMessages( return MaiMessages(
stream_id=stream_id, stream_id=stream_id,
llm_prompt=llm_prompt, llm_prompt=llm_prompt,
@@ -328,8 +347,8 @@ class EventsManager:
llm_response_reasoning=(llm_response.reasoning if llm_response else None), llm_response_reasoning=(llm_response.reasoning if llm_response else None),
llm_response_model=(llm_response.model if llm_response else None), llm_response_model=(llm_response.model if llm_response else None),
llm_response_tool_call=(llm_response.tool_calls if llm_response else None), llm_response_tool_call=(llm_response.tool_calls if llm_response else None),
is_group_message=(not (not chat_stream.group_info)), is_group_message=session.is_group_session,
is_private_message=(not chat_stream.group_info), is_private_message=not session.is_group_session,
action_usage=action_usage, action_usage=action_usage,
additional_data={"response_is_processed": True}, additional_data={"response_is_processed": True},
) )
@@ -373,7 +392,7 @@ class EventsManager:
def _prepare_message( def _prepare_message(
self, self,
event_type: EventType | str, event_type: EventType | str,
message: Optional[MessageRecv | MessageSending | MaiMessages] = None, message: Optional[SessionMessage | MessageSending | MaiMessages] = None,
llm_prompt: Optional[str] = None, llm_prompt: Optional[str] = None,
llm_response: Optional["LLMGenerationDataModel"] = None, llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None, stream_id: Optional[str] = None,

View File

@@ -7,7 +7,7 @@ from src.llm_models.utils_model import LLMRequest
from src.llm_models.payload_content import ToolCall from src.llm_models.payload_content import ToolCall
from src.config.config import global_config, model_config from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager from src.prompt.prompt_manager import prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("tool_use") logger = get_logger("tool_use")
@@ -28,8 +28,8 @@ class ToolExecutor:
cache_ttl: 缓存生存时间(周期数) cache_ttl: 缓存生存时间(周期数)
""" """
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(self.chat_id) self.chat_stream = _chat_manager.get_session_by_session_id(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" self.log_prefix = f"[{_chat_manager.get_session_name(self.chat_id) or self.chat_id}]"
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")

View File

@@ -11,7 +11,7 @@ from sqlmodel import col, select, delete
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import get_db_session from src.common.database.database import get_db_session
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_manager import chat_manager as _chat_manager
from src.webui.core import verify_auth_token_from_cookie_or_header from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.expression") logger = get_logger("webui.expression")
@@ -118,14 +118,11 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
def get_chat_name(chat_id: str) -> str: def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称""" """根据 chat_id 获取聊天名称"""
try: try:
chat_stream = get_chat_manager().get_stream(chat_id) session = _chat_manager.get_session_by_session_id(chat_id)
if not chat_stream: if not session:
return chat_id return chat_id
if chat_stream.group_info and chat_stream.group_info.group_name: name = _chat_manager.get_session_name(chat_id)
return chat_stream.group_info.group_name return name or chat_id
if chat_stream.user_info and chat_stream.user_info.user_nickname:
return chat_stream.user_info.user_nickname
return chat_id
except Exception: except Exception:
return chat_id return chat_id
@@ -134,15 +131,9 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称""" """批量获取聊天名称"""
result = {cid: cid for cid in chat_ids} # 默认值为原始ID result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try: try:
chat_manager = get_chat_manager()
for chat_id in chat_ids: for chat_id in chat_ids:
chat_stream = chat_manager.get_stream(chat_id) if name := _chat_manager.get_session_name(chat_id):
if not chat_stream: result[chat_id] = name
continue
if chat_stream.group_info and chat_stream.group_info.group_name:
result[chat_id] = chat_stream.group_info.group_name
elif chat_stream.user_info and chat_stream.user_info.user_nickname:
result[chat_id] = chat_stream.user_info.user_nickname
except Exception as e: except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}") logger.warning(f"批量获取聊天名称失败: {e}")
return result return result
@@ -179,17 +170,14 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat
verify_auth_token(maibot_session, authorization) verify_auth_token(maibot_session, authorization)
chat_list = [] chat_list = []
for stream_id, stream in get_chat_manager().streams.items(): for session_id, session in _chat_manager.sessions.items():
chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None chat_name = _chat_manager.get_session_name(session_id) or session_id
if not chat_name and stream.user_info and stream.user_info.user_nickname:
chat_name = stream.user_info.user_nickname
chat_name = chat_name or stream_id
chat_list.append( chat_list.append(
ChatInfo( ChatInfo(
chat_id=stream_id, chat_id=session_id,
chat_name=chat_name, chat_name=chat_name,
platform=stream.platform, platform=session.platform,
is_group=bool(stream.group_info and stream.group_info.group_id), is_group=session.is_group_session,
) )
) )
@@ -495,11 +483,10 @@ async def batch_delete_expressions(
# 查找所有要删除的表达方式 # 查找所有要删除的表达方式
with get_db_session() as session: with get_db_session() as session:
statements = select(Expression.id).where(col(Expression.id).in_(request.ids)) statements = select(Expression.id).where(col(Expression.id).in_(request.ids))
found_ids = [expr_id for expr_id in session.exec(statements).all()] found_ids = list(session.exec(statements).all())
# 检查是否有未找到的ID # 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids) if not_found_ids := set(request.ids) - set(found_ids):
if not_found_ids:
logger.warning(f"部分表达方式未找到: {not_found_ids}") logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除 # 执行批量删除
@@ -800,7 +787,7 @@ async def batch_review_expressions(
session.add(db_expression) session.add(db_expression)
results.append( results.append(
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝") BatchReviewResultItem(id=item.id, success=True, message="拒绝" if item.rejected else "通过")
) )
succeeded += 1 succeeded += 1