炸 service 层

This commit is contained in:
DrSmoothl
2026-03-14 00:13:35 +08:00
parent 898fab6de9
commit 43c5b34623
13 changed files with 1408 additions and 1736 deletions

View File

@@ -1,13 +1,16 @@
"""数据库服务模块
提供数据库操作相关的核心功能。
"""
"""数据库服务模块"""
import json
import time
import traceback
from datetime import datetime
from typing import Any, Optional
from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
from src.common.database.database import get_db_session
from src.common.database.database_model import ActionRecord
from src.common.logger import get_logger
logger = get_logger("database_service")
@@ -25,73 +28,57 @@ def _to_dict(record: Any) -> dict[str, Any]:
return {}
async def db_query(
model_class,
data: Optional[dict[str, Any]] = None,
query_type: str = "get",
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[list[str]] = None,
single_result: bool = False,
def _get_model_field(model_class: type[SQLModel], field_name: str):
field = getattr(model_class, field_name, None)
if field is None:
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
return field
def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> list[Any]:
if not filters:
return []
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None):
if not order_by:
return statement
order_fields = [order_by] if isinstance(order_by, str) else order_by
clauses = []
for item in order_fields:
descending = item.startswith("-")
field_name = item[1:] if descending else item
field = _get_model_field(model_class, field_name)
clauses.append(field.desc() if descending else field.asc())
return statement.order_by(*clauses)
async def db_save(
model_class: type[SQLModel],
data: dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None,
):
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
with get_db_session() as session:
record = None
if key_field and key_value is not None:
key_column = _get_model_field(model_class, key_field)
record = session.exec(select(model_class).where(key_column == key_value)).first()
if query_type == "get":
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(*order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
if record is None:
record = model_class(**data)
else:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
setattr(record, field_name, value)
if query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
record = model_class.create(**data)
session.add(record)
session.flush()
session.refresh(record)
return _to_dict(record)
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
if query_type == "delete":
return model_class.delete().where(*query.stmt._where_criteria).execute()
return query.count()
except Exception as e:
logger.error(f"[DatabaseService] 数据库操作出错: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
try:
if key_field and key_value is not None:
record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
if record is not None:
for field, value in data.items():
setattr(record, field, value)
record.save()
return _to_dict(record)
new_record = model_class.create(**data)
return _to_dict(new_record)
except Exception as e:
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
traceback.print_exc()
@@ -99,68 +86,108 @@ async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] =
async def db_get(
model_class,
model_class: type[SQLModel],
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
order_by: Optional[str | list[str]] = None,
single_result: bool = False,
):
try:
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if order_by:
query = query.order_by(order_by)
if limit:
query = query.limit(limit)
results = list(query.dicts())
if single_result:
return results[0] if results else None
return results
with get_db_session(auto_commit=False) as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
statement = _apply_order_by(statement, model_class, order_by)
if limit:
statement = statement.limit(limit)
results = session.exec(statement).all()
data = [_to_dict(item) for item in results]
if single_result:
return data[0] if data else None
return data
except Exception as e:
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
traceback.print_exc()
return None if single_result else []
async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session() as session:
statement = select(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
records = session.exec(statement).all()
for record in records:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
setattr(record, field_name, value)
session.add(record)
return len(records)
except Exception as e:
logger.error(f"[DatabaseService] 更新数据库记录出错: {e}")
traceback.print_exc()
return 0
async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session() as session:
statement = delete(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
result = session.exec(statement)
return result.rowcount or 0
except Exception as e:
logger.error(f"[DatabaseService] 删除数据库记录出错: {e}")
traceback.print_exc()
return 0
async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
try:
with get_db_session(auto_commit=False) as session:
statement = select(func.count()).select_from(model_class)
conditions = _build_filters(model_class, filters)
if conditions:
statement = statement.where(*conditions)
result = session.exec(statement).one()
return int(result or 0)
except Exception as e:
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
traceback.print_exc()
return 0
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
builtin_prompt: Optional[str] = None,
display_prompt: str = "",
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
action_reasoning: str = "",
):
try:
from src.common.database.database_model import ActionRecords
if chat_stream is None:
raise ValueError("store_action_info 需要 chat_stream")
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"timestamp": datetime.now(),
"session_id": chat_stream.session_id,
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_reasoning": action_reasoning,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
"action_builtin_prompt": builtin_prompt,
"action_display_prompt": display_prompt,
}
if chat_stream:
record_data.update(
{
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
}
)
else:
record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
ActionRecord, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseService] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")