- Updated `calculate_session_id` method in `SessionUtils` to include optional `account_id` and `scope` parameters for more granular session ID generation. - Added new environment variables in `plugin_runtime` for external plugin dependencies and global configuration snapshots. - Introduced methods in `RuntimeComponentManagerProtocol` for loading and reloading plugins globally, accommodating external dependencies. - Enhanced `PluginRunnerSupervisor` to manage external available plugin IDs during plugin reloads. - Implemented dependency extraction and management in `PluginRuntimeManager` to handle cross-supervisor dependencies. - Added tests for session ID calculation and message registration in `ChatManager` to ensure correct behavior with new parameters.
280 lines
10 KiB
Python
280 lines
10 KiB
Python
import asyncio
|
||
from datetime import datetime
|
||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||
|
||
from rich.traceback import install
|
||
from sqlmodel import select
|
||
|
||
from src.common.data_models.chat_session_data_model import MaiChatSession
|
||
from src.common.database.database import get_db_session
|
||
from src.common.database.database_model import ChatSession
|
||
from src.common.logger import get_logger
|
||
from src.common.utils.utils_session import SessionUtils
|
||
from src.platform_io.route_key_factory import RouteKeyFactory
|
||
|
||
if TYPE_CHECKING:
|
||
from .message import SessionMessage
|
||
|
||
install(extra_lines=3)
|
||
|
||
logger = get_logger("chat_manager")
|
||
|
||
|
||
class SessionContext:
|
||
"""会话上下文"""
|
||
|
||
def __init__(self, message: "SessionMessage"):
|
||
self.message = message
|
||
self.template_name: Optional[str] = None
|
||
|
||
def update_template(self, template_name: str):
|
||
"""更新当前使用的回复模板"""
|
||
self.template_name = template_name
|
||
|
||
|
||
class BotChatSession(MaiChatSession):
|
||
def __init__(
|
||
self,
|
||
session_id: str,
|
||
platform: str,
|
||
user_id: Optional[str] = None,
|
||
group_id: Optional[str] = None,
|
||
created_timestamp: Optional[datetime] = None,
|
||
last_active_timestamp: Optional[datetime] = None,
|
||
):
|
||
self.context: Optional[SessionContext] = None
|
||
self.accept_format: List[str] = []
|
||
|
||
super().__init__(
|
||
session_id=session_id,
|
||
platform=platform,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
created_timestamp=created_timestamp,
|
||
last_active_timestamp=last_active_timestamp,
|
||
)
|
||
|
||
def check_types(self, types: List[str]) -> bool:
|
||
"""检查消息是否符合可接受类型列表"""
|
||
return all(t in self.accept_format for t in types)
|
||
|
||
def update_active_time(self):
|
||
"""更新最后活跃时间"""
|
||
self.last_active_timestamp = datetime.now()
|
||
|
||
def set_context(self, message: "SessionMessage"):
|
||
"""设置会话上下文"""
|
||
self.context = SessionContext(message=message)
|
||
|
||
|
||
class ChatManager:
|
||
"""聊天管理器,负责管理所有聊天会话"""
|
||
|
||
def __init__(self) -> None:
|
||
self.sessions: Dict[str, BotChatSession] = {} # session_id -> BotChatSession
|
||
self.last_messages: Dict[str, "SessionMessage"] = {} # session_id -> SessionMessage
|
||
|
||
async def initialize(self):
|
||
"""初始化聊天管理器"""
|
||
try:
|
||
await self.load_all_sessions_from_db()
|
||
logger.info(f"已加载 {len(self.sessions)} 个会话记录到内存中")
|
||
except Exception as e:
|
||
logger.error(f"初始化聊天管理器出现错误: {e}")
|
||
|
||
async def get_or_create_session(
|
||
self,
|
||
platform: str,
|
||
user_id: str,
|
||
group_id: Optional[str] = None,
|
||
account_id: Optional[str] = None,
|
||
scope: Optional[str] = None,
|
||
) -> BotChatSession:
|
||
"""获取会话,如果不存在则创建一个新会话;一个封装方法。
|
||
|
||
Args:
|
||
platform: 平台
|
||
user_id: 用户ID
|
||
group_id: 群ID(如果是群聊)
|
||
account_id: 平台账号 ID
|
||
scope: 路由作用域
|
||
Returns:
|
||
return (BotChatSession) 会话对象
|
||
Raises:
|
||
Exception: 获取或创建会话时发生错误
|
||
"""
|
||
session_id = SessionUtils.calculate_session_id(
|
||
platform,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
account_id=account_id,
|
||
scope=scope,
|
||
)
|
||
if session := self.get_session_by_session_id(session_id):
|
||
session.update_active_time()
|
||
return session
|
||
|
||
# 内存没有就找db
|
||
try:
|
||
with get_db_session() as db_session:
|
||
statement = select(ChatSession).filter_by(session_id=session_id).limit(1)
|
||
if result := db_session.exec(statement).first():
|
||
session = BotChatSession.from_db_instance(result)
|
||
self.sessions[session.session_id] = session
|
||
return session
|
||
except Exception as e:
|
||
logger.error(f"从数据库获取会话时发生错误: {e}")
|
||
raise e
|
||
|
||
# 都没有就创建新的
|
||
new_session = BotChatSession(
|
||
session_id=session_id,
|
||
platform=platform,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
)
|
||
self.sessions[new_session.session_id] = new_session
|
||
if new_session.session_id in self.last_messages:
|
||
new_session.set_context(self.last_messages[new_session.session_id])
|
||
self._save_session(new_session)
|
||
return new_session
|
||
|
||
def register_message(self, message: "SessionMessage"):
|
||
platform = message.platform
|
||
if not platform:
|
||
raise ValueError("消息缺少平台信息")
|
||
user_id = message.message_info.user_info.user_id
|
||
group_id = message.message_info.group_info.group_id if message.message_info.group_info else None
|
||
account_id = None
|
||
scope = None
|
||
additional_config = message.message_info.additional_config
|
||
if isinstance(additional_config, dict):
|
||
account_id, scope = RouteKeyFactory.extract_components(additional_config)
|
||
session_id = SessionUtils.calculate_session_id(
|
||
platform,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
account_id=account_id,
|
||
scope=scope,
|
||
)
|
||
message.session_id = session_id # 确保消息的session_id正确设置
|
||
self.last_messages[session_id] = message
|
||
|
||
async def load_all_sessions_from_db(self):
|
||
"""从数据库加载全部会话记录到内存中"""
|
||
self.sessions.clear()
|
||
try:
|
||
await asyncio.to_thread(self._load_sessions_from_db)
|
||
except Exception as e:
|
||
logger.error(f"从数据库加载会话记录时发生错误: {e}")
|
||
self.sessions.clear()
|
||
raise e
|
||
|
||
async def regularly_save_sessions(self, interval_seconds: int = 300):
|
||
"""定期将会话记录保存到数据库中
|
||
|
||
Args:
|
||
interval_seconds: 保存间隔时间,单位为秒,默认为300秒(5分钟)
|
||
"""
|
||
while True:
|
||
await asyncio.sleep(interval_seconds)
|
||
try:
|
||
await asyncio.to_thread(self.save_all_sessions)
|
||
except Exception as e:
|
||
logger.error(f"定期保存会话记录时发生错误: {e}")
|
||
|
||
def save_all_sessions(self):
|
||
"""将内存中的全部会话记录保存到数据库"""
|
||
try:
|
||
for session in self.sessions.values():
|
||
self._save_session(session)
|
||
logger.info(f"共 {len(self.sessions)} 个会话已经保存到数据库中")
|
||
except Exception as e:
|
||
logger.error(f"保存会话记录到数据库时发生错误: {e}")
|
||
raise e
|
||
|
||
def get_session_name(self, session_id: str) -> Optional[str]:
|
||
"""根据会话ID获取会话名称
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
Returns:
|
||
Optional[str]: 会话名称,如果无法获取则返回None
|
||
"""
|
||
session = self.sessions.get(session_id)
|
||
if not session:
|
||
return None
|
||
if session.is_group_session:
|
||
if session.context and session.context.message and session.context.message.message_info.group_info:
|
||
return session.context.message.message_info.group_info.group_name
|
||
elif session.context and session.context.message and session.context.message.message_info.user_info:
|
||
nickname = session.context.message.message_info.user_info.user_nickname
|
||
return f"{nickname}的私聊"
|
||
return None
|
||
|
||
def get_session_by_info(
|
||
self,
|
||
platform: str,
|
||
user_id: Optional[str] = None,
|
||
group_id: Optional[str] = None,
|
||
account_id: Optional[str] = None,
|
||
scope: Optional[str] = None,
|
||
) -> Optional[BotChatSession]:
|
||
"""根据平台、用户ID和群ID获取对应的会话
|
||
|
||
Args:
|
||
platform: 平台
|
||
user_id: 用户ID
|
||
group_id: 群ID(如果是群聊)
|
||
account_id: 平台账号 ID
|
||
scope: 路由作用域
|
||
Returns:
|
||
return (Optional[BotChatSession]): 会话对象,如果不存在则返回None
|
||
"""
|
||
session_id = SessionUtils.calculate_session_id(
|
||
platform,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
account_id=account_id,
|
||
scope=scope,
|
||
)
|
||
return self.get_session_by_session_id(session_id)
|
||
|
||
def get_session_by_session_id(self, session_id: str) -> Optional[BotChatSession]:
|
||
"""根据会话ID获取对应的会话
|
||
|
||
Args:
|
||
session_id: 会话ID
|
||
Returns:
|
||
Optional[BotChatSession]: 会话对象,如果不存在则返回None
|
||
"""
|
||
session = self.sessions.get(session_id)
|
||
if session and session_id in self.last_messages:
|
||
session.set_context(self.last_messages[session_id])
|
||
return session
|
||
|
||
def _load_sessions_from_db(self):
|
||
"""从数据库加载单个会话记录"""
|
||
with get_db_session() as session:
|
||
statements = select(ChatSession)
|
||
for model_instance in session.exec(statements).all():
|
||
bot_chat_session = BotChatSession.from_db_instance(model_instance)
|
||
self.sessions[bot_chat_session.session_id] = bot_chat_session
|
||
if bot_chat_session.session_id in self.last_messages:
|
||
bot_chat_session.set_context(self.last_messages[bot_chat_session.session_id])
|
||
|
||
def _save_session(self, session: BotChatSession):
|
||
"""将会话记录保存到数据库"""
|
||
with get_db_session() as db_session:
|
||
db_instance = session.to_db_instance()
|
||
statement = select(ChatSession).filter_by(session_id=db_instance.session_id).limit(1)
|
||
if result := db_session.exec(statement).first():
|
||
result.created_timestamp = db_instance.created_timestamp
|
||
result.last_active_timestamp = db_instance.last_active_timestamp
|
||
db_session.add(result)
|
||
else:
|
||
db_session.add(db_instance)
|
||
|
||
|
||
chat_manager = ChatManager()
|