数据库的信息重构为dataclass
This commit is contained in:
@@ -1,13 +1,10 @@
|
||||
from enum import Enum
|
||||
|
||||
from typing import Optional, Union, Dict, Any, Tuple, List
|
||||
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatabaseUserInfo:
|
||||
user_platform: str = field(default_factory=str)
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
user_cardname: Optional[str] = None
|
||||
@@ -84,17 +81,21 @@ class DatabaseMessages:
|
||||
user_id=self.user_id,
|
||||
user_nickname=self.user_nickname,
|
||||
user_cardname=self.user_cardname,
|
||||
user_platform=self.user_platform,
|
||||
platform=self.user_platform,
|
||||
)
|
||||
|
||||
if not (self.chat_info_group_id and self.chat_info_group_name):
|
||||
self.group_info = None
|
||||
if self.chat_info_group_id and self.chat_info_group_name:
|
||||
self.group_info = DatabaseGroupInfo(
|
||||
group_id=self.chat_info_group_id,
|
||||
group_name=self.chat_info_group_name,
|
||||
group_platform=self.chat_info_group_platform,
|
||||
)
|
||||
|
||||
chat_user_info = DatabaseUserInfo(
|
||||
user_id=self.chat_info_user_id,
|
||||
user_nickname=self.chat_info_user_nickname,
|
||||
user_cardname=self.chat_info_user_cardname,
|
||||
user_platform=self.chat_info_user_platform,
|
||||
platform=self.chat_info_user_platform,
|
||||
)
|
||||
self.chat_info = DatabaseChatInfo(
|
||||
stream_id=self.chat_info_stream_id,
|
||||
|
||||
10
src/common/data_models/info_data_model.py
Normal file
10
src/common/data_models/info_data_model.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@dataclass
|
||||
class TargetPersonInfo:
|
||||
platform: str = field(default_factory=str)
|
||||
user_id: str = field(default_factory=str)
|
||||
user_nickname: str = field(default_factory=str)
|
||||
person_id: Optional[str] = None
|
||||
person_name: Optional[str] = None
|
||||
@@ -2,19 +2,20 @@ import traceback
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
from src.config.config import global_config
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.database.database_model import Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
|
||||
def _model_to_instance(model_instance: Model) -> DatabaseMessages:
|
||||
"""
|
||||
将 Peewee 模型实例转换为字典。
|
||||
"""
|
||||
return model_instance.__data__
|
||||
return DatabaseMessages(**model_instance.__data__)
|
||||
|
||||
|
||||
def find_messages(
|
||||
@@ -24,7 +25,7 @@ def find_messages(
|
||||
limit_mode: str = "latest",
|
||||
filter_bot=False,
|
||||
filter_command=False,
|
||||
) -> List[dict[str, Any]]:
|
||||
) -> List[DatabaseMessages]:
|
||||
"""
|
||||
根据提供的过滤器、排序和限制条件查找消息。
|
||||
|
||||
@@ -112,7 +113,7 @@ def find_messages(
|
||||
query = query.order_by(*peewee_sort_terms)
|
||||
peewee_results = list(query)
|
||||
|
||||
return [_model_to_dict(msg) for msg in peewee_results]
|
||||
return [_model_to_instance(msg) for msg in peewee_results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
|
||||
Reference in New Issue
Block a user