重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

This commit is contained in:
DrSmoothl
2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
29 changed files with 2459 additions and 1737 deletions

View File

@@ -2,13 +2,15 @@ import asyncio
import hashlib
import time
import copy
from datetime import datetime
from typing import Dict, Optional, TYPE_CHECKING
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from sqlmodel import select, col
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession
# 避免循环导入,使用 TYPE_CHECKING 进行类型提示
if TYPE_CHECKING:
@@ -76,7 +78,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
self.context: Optional[ChatMessageContext] = None
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -95,10 +97,13 @@ class ChatStream:
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
if user_info is None:
raise ValueError("user_info is required to build ChatStream")
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info, # type: ignore
user_info=user_info,
group_info=group_info,
data=data,
)
@@ -128,12 +133,7 @@ class ChatManager:
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
get_db_session()
self._initialized = True
# 在事件循环中启动初始化
@@ -161,8 +161,13 @@ class ChatManager:
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
platform = message.message_info.platform or ""
if not platform:
raise ValueError("platform is required for ChatStream")
if message.message_info.user_info is None and message.message_info.group_info is None:
raise ValueError("user_info or group_info is required for ChatStream")
stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore
platform,
message.message_info.user_info,
message.message_info.group_info,
)
@@ -176,12 +181,18 @@ class ChatManager:
"""生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info is None and user_info is None:
raise ValueError("用户信息或群组信息必须提供")
if group_info:
# 组合关键信息
components = [platform, str(group_info.group_id)]
else:
components = [platform, str(user_info.user_id), "private"] # type: ignore
if user_info is None:
raise ValueError("用户信息或群组信息必须提供")
if user_info.user_id is None:
raise ValueError("user_id is required for private stream")
components = [platform, str(user_info.user_id), "private"]
# 使用MD5生成唯一ID
key = "_".join(components)
@@ -231,33 +242,35 @@ class ChatManager:
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
with get_db_session() as session:
statement = select(ChatSession).where(col(ChatSession.session_id) == s_id)
return session.exec(statement).first()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"platform": model_instance.platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
"user_nickname": "",
"user_cardname": "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"platform": model_instance.platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
"group_name": "",
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"stream_id": model_instance.session_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
"create_time": model_instance.created_timestamp.timestamp(),
"last_active_time": model_instance.last_active_timestamp.timestamp(),
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
@@ -329,20 +342,26 @@ class ChatManager:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
with get_db_session() as session:
statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"])
record = session.exec(statement).first()
if record is None:
record = ChatSession(
session_id=s_data_dict["stream_id"],
platform=s_data_dict["platform"],
user_id=user_info_d["user_id"] if user_info_d else None,
group_id=group_info_d["group_id"] if group_info_d else None,
created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]),
last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]),
)
session.add(record)
return
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
record.platform = s_data_dict["platform"]
record.user_id = user_info_d["user_id"] if user_info_d else None
record.group_id = group_info_d["group_id"] if group_info_d else None
record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"])
record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"])
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
@@ -361,30 +380,32 @@ class ChatManager:
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
with get_db_session() as session:
statement = select(ChatSession)
for model_instance in session.exec(statement).all():
user_info_data = {
"platform": model_instance.platform,
"user_id": model_instance.user_id or "",
"user_nickname": "",
"user_cardname": "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.platform,
"group_id": model_instance.group_id,
"group_name": "",
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
data_for_from_dict = {
"stream_id": model_instance.session_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.created_timestamp.timestamp(),
"last_active_time": model_instance.last_active_timestamp.timestamp(),
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try: