feat: 增强数据库服务,添加类型转换以支持更灵活的查询

This commit is contained in:
DrSmoothl
2026-03-31 08:21:53 +08:00
parent 5ac088ded8
commit ea4cea39f2
2 changed files with 247 additions and 26 deletions

View File

@@ -4,7 +4,7 @@ import json
import time
import traceback
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
@@ -65,7 +65,7 @@ async def db_save(
record = None
if key_field and key_value is not None:
key_column = _get_model_field(model_class, key_field)
record = session.exec(select(model_class).where(key_column == key_value)).first()
record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first()
if record is None:
record = model_class(**data)
@@ -99,7 +99,7 @@ async def db_get(
statement = _apply_order_by(statement, model_class, order_by)
if limit:
statement = statement.limit(limit)
results = session.exec(statement).all()
results = session.exec(cast(Any, statement)).all()
data = [_to_dict(item) for item in results]
if single_result:
return data[0] if data else None
@@ -116,7 +116,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
statement = select(model_class)
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
records = session.exec(statement).all()
records = session.exec(cast(Any, statement)).all()
for record in records:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
@@ -149,7 +149,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
statement = select(func.count()).select_from(model_class)
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
result = session.exec(statement).one()
result = session.exec(cast(Any, statement)).one()
return int(result or 0)
except Exception as e:
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")