fix typing, api change

This commit is contained in:
UnCLASPrommer
2025-07-15 00:57:43 +08:00
parent eae399fb95
commit af02f2ab57
5 changed files with 110 additions and 77 deletions

View File

@@ -13,23 +13,29 @@
"""
from typing import List, Dict, Any, Optional
from src.common.logger import get_logger
from enum import Enum
# 导入依赖
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
logger = get_logger("chat_api")
class SpecialTypes(Enum):
"""特殊枚举类型"""
ALL_PLATFORMS = "all_platforms"
class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有聊天流
Args:
platform: 平台筛选,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 聊天流列表
@@ -37,7 +43,7 @@ class ChatManager:
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform:
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的聊天流")
except Exception as e:
@@ -45,11 +51,11 @@ class ChatManager:
return streams
@staticmethod
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有群聊聊天流
Args:
platform: 平台筛选,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 群聊聊天流列表
@@ -57,7 +63,7 @@ class ChatManager:
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform and stream.group_info:
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的群聊流")
except Exception as e:
@@ -65,11 +71,11 @@ class ChatManager:
return streams
@staticmethod
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
"""获取所有私聊聊天流
Args:
platform: 平台筛选,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
List[ChatStream]: 私聊聊天流列表
@@ -77,7 +83,7 @@ class ChatManager:
streams = []
try:
for _, stream in get_chat_manager().streams.items():
if stream.platform == platform and not stream.group_info:
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
streams.append(stream)
logger.debug(f"[ChatAPI] 获取到 {len(streams)}{platform} 平台的私聊流")
except Exception as e:
@@ -85,12 +91,14 @@ class ChatManager:
return streams
@staticmethod
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_group_stream_by_group_id(
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]:
"""根据群ID获取聊天流
Args:
group_id: 群聊ID
platform: 平台,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
@@ -110,12 +118,14 @@ class ChatManager:
return None
@staticmethod
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_private_stream_by_user_id(
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
) -> Optional[ChatStream]:
"""根据用户ID获取私聊流
Args:
user_id: 用户ID
platform: 平台,默认为"qq"
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
Returns:
Optional[ChatStream]: 聊天流对象如果未找到返回None
@@ -145,7 +155,7 @@ class ChatManager:
str: 聊天类型 ("group", "private", "unknown")
"""
if not chat_stream:
return "unknown"
raise ValueError("chat_stream cannot be None")
if hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
@@ -165,7 +175,7 @@ class ChatManager:
return {}
try:
info = {
info: Dict[str, Any] = {
"stream_id": chat_stream.stream_id,
"platform": chat_stream.platform,
"type": ChatManager.get_stream_type(chat_stream),
@@ -200,9 +210,9 @@ class ChatManager:
Dict[str, int]: 包含各种统计信息的字典
"""
try:
all_streams = ChatManager.get_all_streams()
group_streams = ChatManager.get_group_streams()
private_streams = ChatManager.get_private_streams()
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
summary = {
"total_streams": len(all_streams),
@@ -215,7 +225,12 @@ class ChatManager:
return summary
except Exception as e:
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
return {"total_streams": 0, "group_streams": 0, "private_streams": 0, "qq_streams": 0}
return {
"total_streams": 0,
"group_streams": 0,
"private_streams": 0,
"qq_streams": 0,
}
# =============================================================================
@@ -223,41 +238,41 @@ class ChatManager:
# =============================================================================
def get_all_streams(platform: str = "qq") -> List[ChatStream]:
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq"):
"""获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform)
def get_group_streams(platform: str = "qq") -> List[ChatStream]:
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq"):
"""获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform)
def get_private_streams(platform: str = "qq") -> List[ChatStream]:
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq"):
"""获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq"):
"""根据群ID获取聊天流的便捷函数"""
return ChatManager.get_stream_by_group_id(group_id, platform)
return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: str = "qq") -> Optional[ChatStream]:
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq"):
"""根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_stream_by_user_id(user_id, platform)
return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream) -> str:
def get_stream_type(chat_stream: ChatStream):
"""获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
def get_stream_info(chat_stream: ChatStream):
"""获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream)
def get_streams_summary() -> Dict[str, int]:
def get_streams_summary():
"""获取聊天流统计摘要的便捷函数"""
return ChatManager.get_streams_summary()