重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject
This commit is contained in:
@@ -2,13 +2,15 @@ import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from sqlmodel import select, col
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import ChatStreams # 新增导入
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ChatSession
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
if TYPE_CHECKING:
|
||||
@@ -76,7 +78,7 @@ class ChatStream:
|
||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||
self.saved = False
|
||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
||||
self.context: Optional[ChatMessageContext] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
@@ -95,10 +97,13 @@ class ChatStream:
|
||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||
|
||||
if user_info is None:
|
||||
raise ValueError("user_info is required to build ChatStream")
|
||||
|
||||
return cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
user_info=user_info, # type: ignore
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data,
|
||||
)
|
||||
@@ -128,12 +133,7 @@ class ChatManager:
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 确保 ChatStreams 表存在
|
||||
db.create_tables([ChatStreams], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||
get_db_session()
|
||||
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
@@ -161,8 +161,13 @@ class ChatManager:
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
platform = message.message_info.platform or ""
|
||||
if not platform:
|
||||
raise ValueError("platform is required for ChatStream")
|
||||
if message.message_info.user_info is None and message.message_info.group_info is None:
|
||||
raise ValueError("user_info or group_info is required for ChatStream")
|
||||
stream_id = self._generate_stream_id(
|
||||
message.message_info.platform, # type: ignore
|
||||
platform,
|
||||
message.message_info.user_info,
|
||||
message.message_info.group_info,
|
||||
)
|
||||
@@ -176,12 +181,18 @@ class ChatManager:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
if group_info is None and user_info is None:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [platform, str(group_info.group_id)]
|
||||
else:
|
||||
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
||||
if user_info is None:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
if user_info.user_id is None:
|
||||
raise ValueError("user_id is required for private stream")
|
||||
components = [platform, str(user_info.user_id), "private"]
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
key = "_".join(components)
|
||||
@@ -231,33 +242,35 @@ class ChatManager:
|
||||
|
||||
# 检查数据库中是否存在
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
||||
with get_db_session() as session:
|
||||
statement = select(ChatSession).where(col(ChatSession.session_id) == s_id)
|
||||
return session.exec(statement).first()
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
|
||||
if model_instance:
|
||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"platform": model_instance.platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
"user_nickname": "",
|
||||
"user_cardname": "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"platform": model_instance.platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
"group_name": "",
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"stream_id": model_instance.session_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
"create_time": model_instance.created_timestamp.timestamp(),
|
||||
"last_active_time": model_instance.last_active_timestamp.timestamp(),
|
||||
}
|
||||
stream = ChatStream.from_dict(data_for_from_dict)
|
||||
# 更新用户信息和群组信息
|
||||
@@ -329,20 +342,26 @@ class ChatManager:
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict["platform"],
|
||||
"create_time": s_data_dict["create_time"],
|
||||
"last_active_time": s_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
}
|
||||
with get_db_session() as session:
|
||||
statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"])
|
||||
record = session.exec(statement).first()
|
||||
if record is None:
|
||||
record = ChatSession(
|
||||
session_id=s_data_dict["stream_id"],
|
||||
platform=s_data_dict["platform"],
|
||||
user_id=user_info_d["user_id"] if user_info_d else None,
|
||||
group_id=group_info_d["group_id"] if group_info_d else None,
|
||||
created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]),
|
||||
last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]),
|
||||
)
|
||||
session.add(record)
|
||||
return
|
||||
|
||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||
record.platform = s_data_dict["platform"]
|
||||
record.user_id = user_info_d["user_id"] if user_info_d else None
|
||||
record.group_id = group_info_d["group_id"] if group_info_d else None
|
||||
record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"])
|
||||
record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"])
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||
@@ -361,30 +380,32 @@ class ChatManager:
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
loaded_streams_data = []
|
||||
for model_instance in ChatStreams.select():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
with get_db_session() as session:
|
||||
statement = select(ChatSession)
|
||||
for model_instance in session.exec(statement).all():
|
||||
user_info_data = {
|
||||
"platform": model_instance.platform,
|
||||
"user_id": model_instance.user_id or "",
|
||||
"user_nickname": "",
|
||||
"user_cardname": "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": "",
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.session_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.created_timestamp.timestamp(),
|
||||
"last_active_time": model_instance.last_active_timestamp.timestamp(),
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
return loaded_streams_data
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user