Refactor chat stream handling to use BotChatSession
- Updated imports and references from ChatStream to BotChatSession across multiple files. - Adjusted method signatures and internal logic to accommodate the new session management. - Ensured compatibility with existing functionality while improving code clarity and maintainability.
This commit is contained in:
@@ -16,7 +16,7 @@ from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.chat_manager import BotChatSession, chat_manager as _chat_manager
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
@@ -31,7 +31,7 @@ class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有聊天流
|
||||
|
||||
@@ -39,7 +39,7 @@ class ChatManager:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 聊天流列表
|
||||
List[BotChatSession]: 聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
@@ -48,7 +48,7 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
@@ -57,7 +57,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
@@ -65,14 +65,14 @@ class ChatManager:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 群聊聊天流列表
|
||||
List[BotChatSession]: 群聊聊天流列表
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.is_group_session:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
@@ -80,7 +80,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
@@ -88,7 +88,7 @@ class ChatManager:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 私聊聊天流列表
|
||||
List[BotChatSession]: 私聊聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
@@ -97,8 +97,10 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (
|
||||
platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
||||
) and not stream.is_group_session:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
@@ -108,7 +110,7 @@ class ChatManager:
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
@@ -116,7 +118,7 @@ class ChatManager:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
Optional[BotChatSession]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 group_id 为空字符串
|
||||
@@ -129,11 +131,11 @@ class ChatManager:
|
||||
if not group_id:
|
||||
raise ValueError("group_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(group_id)
|
||||
and stream.platform == platform
|
||||
stream.is_group_session
|
||||
and str(stream.group_id) == str(group_id)
|
||||
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
|
||||
return stream
|
||||
@@ -145,7 +147,7 @@ class ChatManager:
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
) -> Optional[BotChatSession]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
@@ -153,7 +155,7 @@ class ChatManager:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
Optional[BotChatSession]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 user_id 为空字符串
|
||||
@@ -166,11 +168,11 @@ class ChatManager:
|
||||
if not user_id:
|
||||
raise ValueError("user_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
for _, stream in _chat_manager.sessions.items():
|
||||
if (
|
||||
not stream.group_info
|
||||
and str(stream.user_info.user_id) == str(user_id)
|
||||
and stream.platform == platform
|
||||
not stream.is_group_session
|
||||
and str(stream.user_id) == str(user_id)
|
||||
and (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform)
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
|
||||
return stream
|
||||
@@ -180,7 +182,7 @@ class ChatManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
def get_stream_type(chat_stream: BotChatSession) -> str:
|
||||
"""获取聊天流类型
|
||||
|
||||
Args:
|
||||
@@ -190,20 +192,18 @@ class ChatManager:
|
||||
str: 聊天类型 ("group", "private", "unknown")
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
TypeError: 如果 chat_stream 不是 BotChatSession 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
if not isinstance(chat_stream, BotChatSession):
|
||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
|
||||
if hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
return "unknown"
|
||||
return "group" if chat_stream.is_group_session else "private"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
||||
"""获取聊天流详细信息
|
||||
|
||||
Args:
|
||||
@@ -213,36 +213,34 @@ class ChatManager:
|
||||
Dict ({str: Any}): 聊天流信息字典
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
TypeError: 如果 chat_stream 不是 BotChatSession 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
if not isinstance(chat_stream, BotChatSession):
|
||||
raise TypeError("chat_stream 必须是 BotChatSession 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"session_id": chat_stream.session_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
}
|
||||
|
||||
if chat_stream.group_info:
|
||||
info.update(
|
||||
{
|
||||
"group_id": chat_stream.group_info.group_id,
|
||||
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
|
||||
}
|
||||
)
|
||||
|
||||
if chat_stream.user_info:
|
||||
info.update(
|
||||
{
|
||||
"user_id": chat_stream.user_info.user_id,
|
||||
"user_name": chat_stream.user_info.user_nickname,
|
||||
}
|
||||
)
|
||||
if chat_stream.is_group_session:
|
||||
info["group_id"] = chat_stream.group_id
|
||||
# Try to get group name from context
|
||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.group_info:
|
||||
info["group_name"] = chat_stream.context.message.message_info.group_info.group_name or "未知群聊"
|
||||
else:
|
||||
info["group_name"] = "未知群聊"
|
||||
else:
|
||||
info["user_id"] = chat_stream.user_id
|
||||
if chat_stream.context and chat_stream.context.message and chat_stream.context.message.message_info.user_info:
|
||||
info["user_name"] = chat_stream.context.message.message_info.user_info.user_nickname
|
||||
else:
|
||||
info["user_name"] = "未知用户"
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
@@ -285,37 +283,37 @@ class ChatManager:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[BotChatSession]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[BotChatSession]:
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
def get_stream_type(chat_stream: BotChatSession) -> str:
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: BotChatSession) -> Dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user