Refactor chat stream handling to use BotChatSession
- Updated imports and references from ChatStream to BotChatSession across multiple files. - Adjusted method signatures and internal logic to accommodate the new session management. - Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import Seg
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
@@ -45,17 +48,17 @@ logger = get_logger("replyer")
|
||||
class DefaultReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
@@ -132,7 +135,7 @@ class DefaultReplyer:
|
||||
if log_reply:
|
||||
try:
|
||||
PlanReplyLogger.log_reply(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_id=self.chat_stream.session_id,
|
||||
prompt="",
|
||||
output=None,
|
||||
processed_output=None,
|
||||
@@ -202,7 +205,7 @@ class DefaultReplyer:
|
||||
try:
|
||||
if log_reply:
|
||||
PlanReplyLogger.log_reply(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_id=self.chat_stream.session_id,
|
||||
prompt=prompt,
|
||||
output=content,
|
||||
processed_output=None,
|
||||
@@ -259,7 +262,7 @@ class DefaultReplyer:
|
||||
if log_reply:
|
||||
try:
|
||||
PlanReplyLogger.log_reply(
|
||||
chat_id=self.chat_stream.stream_id,
|
||||
chat_id=self.chat_stream.session_id,
|
||||
prompt=prompt or "",
|
||||
output=None,
|
||||
processed_output=None,
|
||||
@@ -353,14 +356,14 @@ class DefaultReplyer:
|
||||
str: 表达习惯信息字符串
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
|
||||
if not use_expression:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id,
|
||||
self.chat_stream.session_id,
|
||||
chat_history,
|
||||
max_num=8,
|
||||
target_message=target,
|
||||
@@ -702,10 +705,11 @@ class DefaultReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
chat_id = SessionUtils.calculate_session_id(
|
||||
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
|
||||
)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -778,7 +782,7 @@ class DefaultReplyer:
|
||||
if available_actions is None:
|
||||
available_actions = {}
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
_is_group_chat = bool(chat_stream.group_info)
|
||||
platform = chat_stream.platform
|
||||
|
||||
@@ -1005,7 +1009,7 @@ class DefaultReplyer:
|
||||
reply_to: str,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
|
||||
@@ -1105,29 +1109,27 @@ class DefaultReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: Optional[MessageRecv] = None,
|
||||
anchor_message: Optional[MaiMessage] = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||
|
||||
return MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
message_id=message_id,
|
||||
session=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=message_segment,
|
||||
reply=anchor_message, # 回复原始锚点
|
||||
reply=anchor_message,
|
||||
is_head=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
|
||||
thinking_start_time=thinking_start_time,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,8 +12,11 @@ from src.common.data_models.info_data_model import ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from maim_message import Seg
|
||||
|
||||
from src.common.data_models.mai_message_data_model import MaiMessage, UserInfo
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.chat.message_receive.uni_message_sender import UniversalMessageSender
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info, is_bot_self
|
||||
@@ -43,18 +46,18 @@ logger = get_logger("replyer")
|
||||
class PrivateReplyer:
|
||||
def __init__(
|
||||
self,
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: BotChatSession,
|
||||
request_type: str = "replyer",
|
||||
):
|
||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||
self.chat_stream = chat_stream
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.session_id)
|
||||
self.heart_fc_sender = UniversalMessageSender()
|
||||
# self.memory_activator = MemoryActivator()
|
||||
|
||||
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
|
||||
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
|
||||
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.session_id, enable_cache=True, cache_ttl=3)
|
||||
|
||||
async def generate_reply_with_context(
|
||||
self,
|
||||
@@ -253,14 +256,14 @@ class PrivateReplyer:
|
||||
str: 表达习惯信息字符串
|
||||
"""
|
||||
# 检查是否允许在此聊天流中使用表达
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.stream_id)
|
||||
use_expression, _, _ = TempMethodsExpression.get_expression_config_for_chat(self.chat_stream.session_id)
|
||||
if not use_expression:
|
||||
return "", []
|
||||
style_habits = []
|
||||
# 使用从处理器传来的选中表达方式
|
||||
# 使用模型预测选择表达方式
|
||||
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
|
||||
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
self.chat_stream.session_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
@@ -550,10 +553,11 @@ class PrivateReplyer:
|
||||
# 判断是否为群聊
|
||||
is_group = stream_type == "group"
|
||||
|
||||
# 使用 ChatManager 提供的接口生成 chat_id,避免在此重复实现逻辑
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.utils.utils_session import SessionUtils
|
||||
|
||||
chat_id = get_chat_manager().get_stream_id(platform, str(id_str), is_group=is_group)
|
||||
chat_id = SessionUtils.calculate_session_id(
|
||||
platform, group_id=str(id_str) if is_group else None, user_id=str(id_str) if not is_group else None
|
||||
)
|
||||
return chat_id, prompt_content
|
||||
|
||||
except (ValueError, IndexError):
|
||||
@@ -624,7 +628,7 @@ class PrivateReplyer:
|
||||
if available_actions is None:
|
||||
available_actions = {}
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
platform = chat_stream.platform
|
||||
|
||||
user_id = "用户ID"
|
||||
@@ -843,7 +847,7 @@ class PrivateReplyer:
|
||||
reply_to: str,
|
||||
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
chat_id = chat_stream.session_id
|
||||
|
||||
sender, target = self._parse_reply_target(reply_to)
|
||||
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
|
||||
@@ -948,29 +952,27 @@ class PrivateReplyer:
|
||||
is_emoji: bool,
|
||||
thinking_start_time: float,
|
||||
display_message: str,
|
||||
anchor_message: Optional[MessageRecv] = None,
|
||||
anchor_message: Optional[MaiMessage] = None,
|
||||
) -> MessageSending:
|
||||
"""构建单个发送消息"""
|
||||
|
||||
bot_user_info = UserInfo(
|
||||
user_id=str(global_config.bot.qq_account),
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.chat_stream.platform,
|
||||
)
|
||||
|
||||
# await anchor_message.process()
|
||||
sender_info = anchor_message.message_info.user_info if anchor_message else None
|
||||
|
||||
return MessageSending(
|
||||
message_id=message_id, # 使用片段的唯一ID
|
||||
chat_stream=self.chat_stream,
|
||||
message_id=message_id,
|
||||
session=self.chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=sender_info,
|
||||
message_segment=message_segment,
|
||||
reply=anchor_message, # 回复原始锚点
|
||||
reply=anchor_message,
|
||||
is_head=reply_to,
|
||||
is_emoji=is_emoji,
|
||||
thinking_start_time=thinking_start_time, # 传递原始思考开始时间
|
||||
thinking_start_time=thinking_start_time,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
from src.chat.replyer.group_generator import DefaultReplyer
|
||||
from src.chat.replyer.private_generator import PrivateReplyer
|
||||
|
||||
@@ -14,7 +14,7 @@ class ReplyerManager:
|
||||
|
||||
def get_replyer(
|
||||
self,
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_stream: Optional[BotChatSession] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer | PrivateReplyer]:
|
||||
@@ -24,7 +24,7 @@ class ReplyerManager:
|
||||
model_configs 仅在首次为某个 chat_id/stream_id 创建实例时有效。
|
||||
后续调用将返回已缓存的实例,忽略 model_configs 参数。
|
||||
"""
|
||||
stream_id = chat_stream.stream_id if chat_stream else chat_id
|
||||
stream_id = chat_stream.session_id if chat_stream else chat_id
|
||||
if not stream_id:
|
||||
logger.warning("[ReplyerManager] 缺少 stream_id,无法获取回复器。")
|
||||
return None
|
||||
@@ -39,15 +39,14 @@ class ReplyerManager:
|
||||
|
||||
target_stream = chat_stream
|
||||
if not target_stream:
|
||||
if chat_manager := get_chat_manager():
|
||||
target_stream = chat_manager.get_stream(stream_id)
|
||||
target_stream = _chat_manager.get_session_by_session_id(stream_id)
|
||||
|
||||
if not target_stream:
|
||||
logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。")
|
||||
return None
|
||||
|
||||
# model_configs 只在此时(初始化时)生效
|
||||
if target_stream.group_info:
|
||||
if target_stream.is_group_session:
|
||||
replyer = DefaultReplyer(
|
||||
chat_stream=target_stream,
|
||||
request_type=request_type,
|
||||
|
||||
Reference in New Issue
Block a user