重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

This commit is contained in:
DrSmoothl
2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
29 changed files with 2459 additions and 1737 deletions

View File

@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Self, TypeVar, Generic, TYPE_CHECKING
from dataclasses import is_dataclass
from typing import Any, Dict, Self, TypeVar, Generic, TYPE_CHECKING
import copy
@@ -15,9 +16,23 @@ class BaseDataModel:
return copy.deepcopy(self)
def transform_class_to_dict(obj: Any) -> Dict[str, Any]:
if obj is None:
return {}
if is_dataclass(obj):
return obj.__dict__
if hasattr(obj, "dict"):
return obj.dict()
if hasattr(obj, "model_dump"):
return obj.model_dump()
if hasattr(obj, "__dict__"):
return obj.__dict__
return {"value": obj}
class BaseDatabaseDataModel(ABC, Generic[T]):
@abstractmethod
@classmethod
@abstractmethod
def from_db_instance(cls, db_record: T) -> Self:
"""从数据库实例创建数据模型对象"""
raise NotImplementedError

View File

@@ -0,0 +1,79 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple, Union
from . import BaseDataModel
class ReplyContentType(Enum):
    """Kinds of content a single reply segment can carry.

    Each member's value is the lowercase string form of its name.
    """

    TEXT = "text"
    IMAGE = "image"
    EMOJI = "emoji"
    COMMAND = "command"
    VOICE = "voice"
    HYBRID = "hybrid"
    FORWARD = "forward"

    def __str__(self) -> str:
        """Render the member as its raw string value, e.g. ``"text"``."""
        return str(self.value)
@dataclass
class ReplyContent:
    """One segment of a reply: a content type paired with its payload."""

    # Either a ``ReplyContentType`` member or a raw string type name.
    content_type: ReplyContentType | str
    # Payload for this segment; its shape depends on ``content_type``.
    content: Any
@dataclass
class ForwardNode:
    """A single node of a forwarded-message reply.

    A node is either a reference to an existing message (``content`` holds the
    message id as a string and the user fields stay ``None``) or a newly
    created node (user fields set, ``content`` holds a list of
    ``ReplyContent``).
    """

    user_id: Optional[str] = None
    user_nickname: Optional[str] = None
    content: Union[str, List[ReplyContent], None] = None

    @classmethod
    def construct_as_id_reference(cls, message_id: str) -> "ForwardNode":
        """Build a node that only references an existing message by its id."""
        return cls(user_id=None, user_nickname=None, content=message_id)

    @classmethod
    def construct_as_created_node(
        cls,
        user_id: str,
        user_nickname: str,
        content: List[ReplyContent],
    ) -> "ForwardNode":
        """Build a node carrying its own sender information and content list."""
        return cls(user_id, user_nickname, content)
class ReplySetModel(BaseDataModel):
    """Ordered collection of ``ReplyContent`` segments forming one reply."""

    def __init__(self) -> None:
        """Start with an empty list of reply segments."""
        self.reply_data: List[ReplyContent] = []

    def __len__(self) -> int:
        """Return the number of top-level reply segments."""
        segments = self.reply_data
        return len(segments)

    def add_text_content(self, text: str) -> None:
        """Append a TEXT segment holding ``text``."""
        segment = ReplyContent(content_type=ReplyContentType.TEXT, content=text)
        self.reply_data.append(segment)

    def add_voice_content(self, voice_base64: str) -> None:
        """Append a VOICE segment holding ``voice_base64``."""
        segment = ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64)
        self.reply_data.append(segment)

    def add_hybrid_content_by_raw(self, message_tuple_list: Iterable[Tuple[ReplyContentType | str, str]]) -> None:
        """Append one HYBRID segment built from ``(content_type, content)`` pairs.

        Each pair becomes a nested ``ReplyContent``; string type names that
        match a ``ReplyContentType`` value are converted to the enum member.
        """
        parts: List[ReplyContent] = [
            ReplyContent(content_type=self._normalize_content_type(kind), content=payload)
            for kind, payload in message_tuple_list
        ]
        self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=parts))

    def add_forward_content(self, forward_nodes: List[ForwardNode]) -> None:
        """Append a FORWARD segment whose content is ``forward_nodes``."""
        segment = ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_nodes)
        self.reply_data.append(segment)

    @staticmethod
    def _normalize_content_type(content_type: ReplyContentType | str) -> ReplyContentType | str:
        """Map a string type name to its enum member when one matches.

        Enum members, unknown strings and non-string values are returned
        unchanged.
        """
        if isinstance(content_type, str) and not isinstance(content_type, ReplyContentType):
            return next(
                (member for member in ReplyContentType if member.value == content_type),
                content_type,
            )
        return content_type

View File

@@ -1,11 +1,12 @@
from rich.traceback import install
from pathlib import Path
from contextlib import contextmanager
from sqlalchemy.orm import sessionmaker
from pathlib import Path
from typing import Generator, TYPE_CHECKING
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlmodel import create_engine, Session
from typing import TYPE_CHECKING, Generator
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, Session, create_engine
if TYPE_CHECKING:
from sqlite3 import Connection as SQLite3Connection
@@ -53,6 +54,19 @@ SessionLocal = sessionmaker(
class_=Session,
)
# Process-wide flag: flipped to True once the tables have been created.
_db_initialized = False


def initialize_database() -> None:
    """Create the database directory and all SQLModel tables, at most once.

    Later calls return immediately once a previous call has completed
    successfully. If table creation raises, the flag stays ``False`` and the
    next call tries again.
    """
    global _db_initialized
    if not _db_initialized:
        _DB_DIR.mkdir(parents=True, exist_ok=True)

        # Imported only for its side effects (hence the F401 suppression);
        # presumably this registers the table models on ``SQLModel.metadata``
        # before ``create_all`` runs -- TODO confirm.
        import src.common.database.database_model  # noqa: F401

        SQLModel.metadata.create_all(engine)
        _db_initialized = True
@contextmanager
def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
@@ -87,6 +101,7 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
- auto_commit=True 时,成功执行完会自动提交
- auto_commit=False 时,需要手动调用 session.commit()
"""
initialize_database()
session = SessionLocal()
try:
yield session
@@ -120,6 +135,7 @@ def get_db() -> Generator[Session, None, None]:
Yields:
Session: SQLAlchemy 数据库会话
"""
initialize_database()
session = SessionLocal()
try:
yield session

View File

@@ -1,35 +1,153 @@
import traceback
from datetime import datetime
from typing import Any
from typing import List, Any, Optional
import json
from src.config.config import global_config
from src.common.data_models.database_data_model import DatabaseMessages
from sqlalchemy import func
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger(__name__)
def _model_to_instance(model_instance: Any) -> DatabaseMessages:
    """Build a ``DatabaseMessages`` from a dict or a model-like object.

    The keyword arguments are taken from, in order of preference: the dict
    itself, the result of ``model_dump()`` when the object provides it, or the
    object's ``__dict__``.
    """
    if isinstance(model_instance, dict):
        payload = model_instance
    elif hasattr(model_instance, "model_dump"):
        payload = model_instance.model_dump()
    else:
        payload = model_instance.__dict__
    return DatabaseMessages(**payload)
# Maps the key names accepted in message filters and sort terms to the
# corresponding ``Messages`` model attributes. Some columns are reachable under
# two names ("time"/"timestamp", "chat_id"/"session_id", "is_picid"/"is_picture").
FIELD_MAP: dict[str, Any] = {
    "time": Messages.timestamp,  # alias of "timestamp"
    "timestamp": Messages.timestamp,
    "chat_id": Messages.session_id,  # alias of "session_id"
    "session_id": Messages.session_id,
    "user_id": Messages.user_id,
    "message_id": Messages.message_id,
    "group_id": Messages.group_id,
    "platform": Messages.platform,
    "is_command": Messages.is_command,
    "is_mentioned": Messages.is_mentioned,
    "is_at": Messages.is_at,
    "is_emoji": Messages.is_emoji,
    "is_picid": Messages.is_picture,  # alias of "is_picture"
    "is_picture": Messages.is_picture,
    "reply_to": Messages.reply_to,
}
def _parse_additional_config(message: Messages) -> dict[str, Any]:
if not message.additional_config:
return {}
try:
parsed = json.loads(message.additional_config)
except (json.JSONDecodeError, TypeError):
return {}
if isinstance(parsed, dict):
return parsed
return {}
def _normalize_optional_str(value: object) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, ensure_ascii=False)
except (TypeError, ValueError):
return str(value)
def _message_to_instance(message: Messages) -> DatabaseMessages:
    """Convert a ``Messages`` row into a ``DatabaseMessages`` data model.

    Values that have no dedicated attribute on ``message`` are read from its
    JSON ``additional_config``. ``message.timestamp`` is converted to epoch
    seconds for the ``time`` field, and the chat-info create / last-active
    times are filled with ``0.0``.
    """
    extras = _parse_additional_config(message)

    def extra_str(key: str) -> str | None:
        # Values from additional_config may be non-string JSON; normalise them.
        return _normalize_optional_str(extras.get(key))

    raw_timestamp = message.timestamp
    epoch_seconds = (
        raw_timestamp.timestamp() if isinstance(raw_timestamp, datetime) else float(raw_timestamp)
    )

    return DatabaseMessages(
        # Core message fields
        message_id=message.message_id,
        time=epoch_seconds,
        chat_id=message.session_id,
        reply_to=message.reply_to,
        interest_value=extras.get("interest_value"),
        key_words=extra_str("key_words"),
        key_words_lite=extra_str("key_words_lite"),
        is_mentioned=message.is_mentioned,
        is_at=message.is_at,
        reply_probability_boost=extras.get("reply_probability_boost"),
        processed_plain_text=message.processed_plain_text,
        display_message=message.display_message,
        priority_mode=extra_str("priority_mode"),
        priority_info=extra_str("priority_info"),
        additional_config=message.additional_config,
        is_emoji=message.is_emoji,
        is_picid=message.is_picture,
        is_command=message.is_command,
        intercept_message_level=extras.get("intercept_message_level", 0),
        is_notify=message.is_notify,
        selected_expressions=extra_str("selected_expressions"),
        # Sender
        user_id=message.user_id,
        user_nickname=message.user_nickname,
        user_cardname=message.user_cardname,
        user_platform=message.platform,
        # Chat info, filled from the same row
        chat_info_group_id=message.group_id,
        chat_info_group_name=message.group_name,
        chat_info_group_platform=message.platform,
        chat_info_user_id=message.user_id,
        chat_info_user_nickname=message.user_nickname,
        chat_info_user_cardname=message.user_cardname,
        chat_info_user_platform=message.platform,
        chat_info_stream_id=message.session_id,
        chat_info_platform=message.platform,
        chat_info_create_time=0.0,
        chat_info_last_active_time=0.0,
    )
def _coerce_datetime(value: Any) -> Any:
if isinstance(value, (int, float)):
return datetime.fromtimestamp(value)
return value
def _cast_value_for_field(field: Any, value: Any) -> Any:
    """Adapt ``value`` for comparison against ``field``.

    Only the ``Messages.timestamp`` column gets special handling (via
    ``_coerce_datetime``); values for every other field pass through unchanged.
    """
    return _coerce_datetime(value) if field is Messages.timestamp else value
def _ensure_list(value: Any) -> list[Any]:
if value is None:
return []
if isinstance(value, list):
return value
if isinstance(value, tuple):
return list(value)
if isinstance(value, set):
return list(value)
return [value]
def _resolve_field(field_name: str) -> Any | None:
    """Look up the ``Messages`` attribute for a filter or sort key.

    ``FIELD_MAP`` entries take precedence; otherwise an attribute of
    ``Messages`` with that exact name is used. Returns ``None`` when neither
    exists.
    """
    try:
        return FIELD_MAP[field_name]
    except KeyError:
        return getattr(Messages, field_name, None)
def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
sort: list[tuple[str, int]] | None = None,
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: int | None = None,
) -> list[DatabaseMessages]:
"""
根据提供的过滤器、排序和限制条件查找消息。
@@ -43,92 +161,79 @@ def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
query = Messages.select()
# 应用过滤器
conditions: list[Any] = []
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else:
field = _resolve_field(key)
if field is None:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
# 排除 id 为 "notice" 的消息
query = query.where(Messages.message_id != "notice")
continue
if isinstance(value, dict):
for op, op_value in value.items():
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
if op == "$gt":
conditions.append(field > coerced_value)
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
conditions.append(Messages.message_id != "notice")
if filter_bot:
query = query.where(Messages.user_id != global_config.bot.qq_account)
conditions.append(Messages.user_id != global_config.bot.qq_account)
if filter_command:
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
query = query.where(~Messages.is_command)
if filter_intercept_message_level is not None:
# 过滤掉所有 intercept_message_level > filter_intercept_message_level 的消息
query = query.where(Messages.intercept_message_level <= filter_intercept_message_level)
conditions.append(Messages.is_command == False) # noqa: E712
statement = select(Messages).where(*conditions)
if limit > 0:
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by("time").limit(limit)
peewee_results = list(query)
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by("-time").limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(
latest_results_peewee,
key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0),
)
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
with get_db_session() as session:
results = list(session.exec(statement).all())
else:
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
with get_db_session() as session:
results = list(session.exec(statement).all())
results = list(reversed(results))
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
peewee_sort_terms = []
order_terms: list[Any] = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field_name)
elif direction == -1: # DESC
peewee_sort_terms.append(f"-{field_name}")
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
sort_field = _resolve_field(field_name)
if sort_field is None:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
continue
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
if order_terms:
statement = statement.order_by(*order_terms)
with get_db_session() as session:
results = list(session.exec(statement).all())
return [_model_to_instance(msg) for msg in peewee_results]
if filter_intercept_message_level is not None:
filtered_results = []
for msg in results:
config = _parse_additional_config(msg)
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
filtered_results.append(msg)
results = filtered_results
return [_message_to_instance(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@@ -146,54 +251,42 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。
"""
try:
query = Messages.select()
# 应用过滤器
conditions: list[Any] = []
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
# 直接相等比较
conditions.append(field == value)
else:
field = _resolve_field(key)
if field is None:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
continue
if isinstance(value, dict):
for op, op_value in value.items():
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
if op == "$gt":
conditions.append(field > coerced_value)
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
# 排除 id 为 "notice" 的消息
query = query.where(Messages.message_id != "notice")
count = query.count()
return count
conditions.append(Messages.message_id != "notice")
statement = select(func.count()).select_from(Messages).where(*conditions)
with get_db_session() as session:
result = session.exec(statement).one()
return int(result or 0)
except Exception as e:
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
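# More database helpers for the messages table (e.g. find_one_message, insert_message) can be added here.
# Note: with SQLModel, inserts go through a session: session.add(instance) inside get_db_session().
# A single message can be fetched with session.exec(select(Messages).where(...)).first().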