数据库的信息重构为dataclass

This commit is contained in:
UnCLAS-Prommer
2025-08-17 17:11:32 +08:00
parent d74beef4b4
commit 3481234d2b
18 changed files with 243 additions and 206 deletions

View File

@@ -1,13 +1,10 @@
from enum import Enum
from typing import Optional, Union, Dict, Any, Tuple, List
from typing import Optional
from dataclasses import dataclass, field
@dataclass
class DatabaseUserInfo:
user_platform: str = field(default_factory=str)
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
user_cardname: Optional[str] = None
@@ -84,17 +81,21 @@ class DatabaseMessages:
user_id=self.user_id,
user_nickname=self.user_nickname,
user_cardname=self.user_cardname,
user_platform=self.user_platform,
platform=self.user_platform,
)
if not (self.chat_info_group_id and self.chat_info_group_name):
self.group_info = None
if self.chat_info_group_id and self.chat_info_group_name:
self.group_info = DatabaseGroupInfo(
group_id=self.chat_info_group_id,
group_name=self.chat_info_group_name,
group_platform=self.chat_info_group_platform,
)
chat_user_info = DatabaseUserInfo(
user_id=self.chat_info_user_id,
user_nickname=self.chat_info_user_nickname,
user_cardname=self.chat_info_user_cardname,
user_platform=self.chat_info_user_platform,
platform=self.chat_info_user_platform,
)
self.chat_info = DatabaseChatInfo(
stream_id=self.chat_info_stream_id,

View File

@@ -0,0 +1,10 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TargetPersonInfo:
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
person_id: Optional[str] = None
person_name: Optional[str] = None

View File

@@ -2,19 +2,20 @@ import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
from src.config.config import global_config
from src.config.config import global_config
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import Messages
from src.common.logger import get_logger
logger = get_logger(__name__)
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
def _model_to_instance(model_instance: Model) -> DatabaseMessages:
"""
将 Peewee 模型实例转换为字典。
"""
return model_instance.__data__
return DatabaseMessages(**model_instance.__data__)
def find_messages(
@@ -24,7 +25,7 @@ def find_messages(
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
) -> List[dict[str, Any]]:
) -> List[DatabaseMessages]:
"""
根据提供的过滤器、排序和限制条件查找消息。
@@ -112,7 +113,7 @@ def find_messages(
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
return [_model_to_dict(msg) for msg in peewee_results]
return [_model_to_instance(msg) for msg in peewee_results]
except Exception as e:
log_message = (
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"