添加对 peewee 的旧数据库的兼容层,初步重构插件的 database API
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from rich.traceback import install
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import inspect as sqlalchemy_inspect
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing import TYPE_CHECKING, Generator
|
||||
|
||||
@@ -131,3 +132,59 @@ def get_db() -> Generator[Session, None, None]:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class _AtomicContext:
|
||||
def __init__(self) -> None:
|
||||
self._session: Session | None = None
|
||||
|
||||
def __enter__(self) -> Session:
|
||||
self._session = SessionLocal()
|
||||
self._session.begin()
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if exc_type is None:
|
||||
self._session.commit()
|
||||
else:
|
||||
self._session.rollback()
|
||||
finally:
|
||||
self._session.close()
|
||||
|
||||
|
||||
class DatabaseCompat:
|
||||
"""兼容旧 db 调用接口(Peewee 风格),底层使用 SQLAlchemy。"""
|
||||
|
||||
def connect(self, reuse_if_open: bool = True) -> None:
|
||||
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
|
||||
_ = reuse_if_open
|
||||
|
||||
def create_tables(self, models: list[type], safe: bool = True) -> None:
|
||||
_ = safe
|
||||
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
|
||||
if not tables:
|
||||
return
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
SQLModel.metadata.create_all(engine, tables=tables)
|
||||
|
||||
def atomic(self) -> _AtomicContext:
|
||||
return _AtomicContext()
|
||||
|
||||
def execute_sql(self, sql: str):
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(sql))
|
||||
conn.commit()
|
||||
return result
|
||||
|
||||
def table_exists(self, model: type) -> bool:
|
||||
if not hasattr(model, "__tablename__"):
|
||||
return False
|
||||
inspector = sqlalchemy_inspect(engine)
|
||||
return inspector.has_table(model.__tablename__)
|
||||
|
||||
|
||||
db = DatabaseCompat()
|
||||
|
||||
@@ -304,7 +304,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
|
||||
"websockets",
|
||||
"httpcore",
|
||||
"requests",
|
||||
"peewee",
|
||||
"sqlalchemy",
|
||||
"openai",
|
||||
"uvicorn",
|
||||
"jieba",
|
||||
@@ -876,19 +876,19 @@ def initialize_logging(verbose: bool = True):
|
||||
"""手动初始化日志系统,确保所有logger都使用正确的配置
|
||||
|
||||
在应用程序的早期调用此函数,确保所有模块都使用统一的日志配置
|
||||
|
||||
|
||||
Args:
|
||||
verbose: 是否输出详细的初始化信息。默认为 True。
|
||||
在 Runner 进程中可以设置为 False 以避免重复的初始化日志。
|
||||
"""
|
||||
global LOG_CONFIG, _logging_initialized
|
||||
|
||||
|
||||
# 防止重复初始化(在同一进程内)
|
||||
if _logging_initialized:
|
||||
return
|
||||
|
||||
|
||||
_logging_initialized = True
|
||||
|
||||
|
||||
LOG_CONFIG = load_log_config()
|
||||
# print(LOG_CONFIG)
|
||||
configure_third_party_loggers()
|
||||
@@ -941,16 +941,16 @@ def cleanup_old_logs():
|
||||
|
||||
def start_log_cleanup_task(verbose: bool = True):
|
||||
"""启动日志清理任务
|
||||
|
||||
|
||||
Args:
|
||||
verbose: 是否输出启动信息。默认为 True。
|
||||
"""
|
||||
global _cleanup_task_started
|
||||
|
||||
|
||||
# 防止重复启动清理任务
|
||||
if _cleanup_task_started:
|
||||
return
|
||||
|
||||
|
||||
_cleanup_task_started = True
|
||||
|
||||
def cleanup_task():
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import traceback
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -11,11 +10,15 @@ from src.common.logger import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _model_to_instance(model_instance: Model) -> DatabaseMessages:
|
||||
def _model_to_instance(model_instance: Any) -> DatabaseMessages:
|
||||
"""
|
||||
将 Peewee 模型实例转换为字典。
|
||||
"""
|
||||
return DatabaseMessages(**model_instance.__data__)
|
||||
if isinstance(model_instance, dict):
|
||||
return DatabaseMessages(**model_instance)
|
||||
if hasattr(model_instance, "model_dump"):
|
||||
return DatabaseMessages(**model_instance.model_dump())
|
||||
return DatabaseMessages(**model_instance.__dict__)
|
||||
|
||||
|
||||
def find_messages(
|
||||
@@ -92,14 +95,17 @@ def find_messages(
|
||||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
query = query.order_by("time").limit(limit)
|
||||
peewee_results = list(query)
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
query = query.order_by("-time").limit(limit)
|
||||
latest_results_peewee = list(query)
|
||||
# 将结果按时间正序排列
|
||||
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
|
||||
peewee_results = sorted(
|
||||
latest_results_peewee,
|
||||
key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0),
|
||||
)
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
@@ -108,9 +114,9 @@ def find_messages(
|
||||
if hasattr(Messages, field_name):
|
||||
field = getattr(Messages, field_name)
|
||||
if direction == 1: # ASC
|
||||
peewee_sort_terms.append(field.asc())
|
||||
peewee_sort_terms.append(field_name)
|
||||
elif direction == -1: # DESC
|
||||
peewee_sort_terms.append(field.desc())
|
||||
peewee_sort_terms.append(f"-{field_name}")
|
||||
else:
|
||||
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user