fix typing, api change
This commit is contained in:
@@ -13,23 +13,29 @@
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
from enum import Enum
|
||||
|
||||
# 导入依赖
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
|
||||
class SpecialTypes(Enum):
|
||||
"""特殊枚举类型"""
|
||||
|
||||
ALL_PLATFORMS = "all_platforms"
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 聊天流列表
|
||||
@@ -37,7 +43,7 @@ class ChatManager:
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform:
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
@@ -45,11 +51,11 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 群聊聊天流列表
|
||||
@@ -57,7 +63,7 @@ class ChatManager:
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform and stream.group_info:
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
@@ -65,11 +71,11 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 私聊聊天流列表
|
||||
@@ -77,7 +83,7 @@ class ChatManager:
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if stream.platform == platform and not stream.group_info:
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
@@ -85,12 +91,14 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]:
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
group_id: 群聊ID
|
||||
platform: 平台,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
@@ -110,12 +118,14 @@ class ChatManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]:
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 平台,默认为"qq"
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
@@ -145,7 +155,7 @@ class ChatManager:
|
||||
str: 聊天类型 ("group", "private", "unknown")
|
||||
"""
|
||||
if not chat_stream:
|
||||
return "unknown"
|
||||
raise ValueError("chat_stream cannot be None")
|
||||
|
||||
if hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
@@ -165,7 +175,7 @@ class ChatManager:
|
||||
return {}
|
||||
|
||||
try:
|
||||
info = {
|
||||
info: Dict[str, Any] = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
@@ -200,9 +210,9 @@ class ChatManager:
|
||||
Dict[str, int]: 包含各种统计信息的字典
|
||||
"""
|
||||
try:
|
||||
all_streams = ChatManager.get_all_streams()
|
||||
group_streams = ChatManager.get_group_streams()
|
||||
private_streams = ChatManager.get_private_streams()
|
||||
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
|
||||
summary = {
|
||||
"total_streams": len(all_streams),
|
||||
@@ -215,7 +225,12 @@ class ChatManager:
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
|
||||
return {"total_streams": 0, "group_streams": 0, "private_streams": 0, "qq_streams": 0}
|
||||
return {
|
||||
"total_streams": 0,
|
||||
"group_streams": 0,
|
||||
"private_streams": 0,
|
||||
"qq_streams": 0,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -223,41 +238,41 @@ class ChatManager:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_stream_by_group_id(group_id, platform)
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq"):
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_stream_by_user_id(user_id, platform)
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
def get_stream_type(chat_stream: ChatStream):
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
def get_stream_info(chat_stream: ChatStream):
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
def get_streams_summary():
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
|
||||
Reference in New Issue
Block a user