214 lines
7.7 KiB
Python
214 lines
7.7 KiB
Python
"""数据库服务模块。"""
|
|
|
|
import json
|
|
import time
|
|
import traceback
|
|
from datetime import datetime
|
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
|
|
from sqlalchemy import delete, func
|
|
from sqlmodel import SQLModel, select
|
|
|
|
from src.common.database.database import get_db_session
|
|
from src.common.database.database_model import ToolRecord
|
|
from src.common.logger import get_logger
|
|
|
|
if TYPE_CHECKING:
|
|
from src.chat.message_receive.chat_manager import BotChatSession
|
|
|
|
logger = get_logger("database_service")
|
|
|
|
|
|
def _to_dict(record: Any) -> dict[str, Any]:
|
|
if record is None:
|
|
return {}
|
|
if isinstance(record, dict):
|
|
return record
|
|
if hasattr(record, "model_dump"):
|
|
return record.model_dump()
|
|
return dict(record.__dict__) if hasattr(record, "__dict__") else {}
|
|
|
|
|
|
def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any:
|
|
field = getattr(model_class, field_name, None)
|
|
if field is None:
|
|
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
|
|
return field
|
|
|
|
|
|
def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> list[Any]:
|
|
if not filters:
|
|
return []
|
|
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
|
|
|
|
|
|
def _apply_order_by(statement: Any, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None) -> Any:
|
|
if not order_by:
|
|
return statement
|
|
|
|
order_fields = [order_by] if isinstance(order_by, str) else order_by
|
|
clauses = []
|
|
for item in order_fields:
|
|
descending = item.startswith("-")
|
|
field_name = item[1:] if descending else item
|
|
field = _get_model_field(model_class, field_name)
|
|
clauses.append(field.desc() if descending else field.asc())
|
|
return statement.order_by(*clauses)
|
|
|
|
|
|
async def db_save(
|
|
model_class: type[SQLModel],
|
|
data: dict[str, Any],
|
|
key_field: Optional[str] = None,
|
|
key_value: Optional[Any] = None,
|
|
) -> Optional[dict[str, Any]]:
|
|
try:
|
|
with get_db_session() as session:
|
|
record = None
|
|
if key_field and key_value is not None:
|
|
key_column = _get_model_field(model_class, key_field)
|
|
record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first()
|
|
|
|
if record is None:
|
|
record = model_class(**data)
|
|
else:
|
|
for field_name, value in data.items():
|
|
_get_model_field(model_class, field_name)
|
|
setattr(record, field_name, value)
|
|
|
|
session.add(record)
|
|
session.flush()
|
|
session.refresh(record)
|
|
return _to_dict(record)
|
|
except Exception as e:
|
|
logger.error(f"[DatabaseService] 保存数据库记录出错: {e}")
|
|
traceback.print_exc()
|
|
return None
|
|
|
|
|
|
async def db_get(
|
|
model_class: type[SQLModel],
|
|
filters: Optional[dict[str, Any]] = None,
|
|
limit: Optional[int] = None,
|
|
order_by: Optional[str | list[str]] = None,
|
|
single_result: bool = False,
|
|
) -> Optional[dict[str, Any]] | list[dict[str, Any]]:
|
|
try:
|
|
with get_db_session(auto_commit=False) as session:
|
|
statement = select(model_class)
|
|
if conditions := _build_filters(model_class, filters):
|
|
statement = statement.where(*conditions)
|
|
statement = _apply_order_by(statement, model_class, order_by)
|
|
if limit:
|
|
statement = statement.limit(limit)
|
|
results = session.exec(cast(Any, statement)).all()
|
|
data = [_to_dict(item) for item in results]
|
|
if single_result:
|
|
return data[0] if data else None
|
|
return data
|
|
except Exception as e:
|
|
logger.error(f"[DatabaseService] 获取数据库记录出错: {e}")
|
|
traceback.print_exc()
|
|
return None if single_result else []
|
|
|
|
|
|
async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters: Optional[dict[str, Any]] = None) -> int:
|
|
try:
|
|
with get_db_session() as session:
|
|
statement = select(model_class)
|
|
if conditions := _build_filters(model_class, filters):
|
|
statement = statement.where(*conditions)
|
|
records = session.exec(cast(Any, statement)).all()
|
|
for record in records:
|
|
for field_name, value in data.items():
|
|
_get_model_field(model_class, field_name)
|
|
setattr(record, field_name, value)
|
|
session.add(record)
|
|
return len(records)
|
|
except Exception as e:
|
|
logger.error(f"[DatabaseService] 更新数据库记录出错: {e}")
|
|
traceback.print_exc()
|
|
return 0
|
|
|
|
|
|
async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
|
|
try:
|
|
with get_db_session() as session:
|
|
statement = delete(model_class)
|
|
if conditions := _build_filters(model_class, filters):
|
|
statement = statement.where(*conditions)
|
|
result = session.exec(statement)
|
|
return result.rowcount or 0
|
|
except Exception as e:
|
|
logger.error(f"[DatabaseService] 删除数据库记录出错: {e}")
|
|
traceback.print_exc()
|
|
return 0
|
|
|
|
|
|
async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]] = None) -> int:
|
|
try:
|
|
with get_db_session(auto_commit=False) as session:
|
|
statement = select(func.count()).select_from(model_class)
|
|
if conditions := _build_filters(model_class, filters):
|
|
statement = statement.where(*conditions)
|
|
result = session.exec(cast(Any, statement)).one()
|
|
return int(result or 0)
|
|
except Exception as e:
|
|
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
|
|
traceback.print_exc()
|
|
return 0
|
|
|
|
|
|
async def store_tool_info(
|
|
chat_stream: "BotChatSession",
|
|
builtin_prompt: Optional[str] = None,
|
|
display_prompt: str = "",
|
|
tool_id: str = "",
|
|
tool_data: Optional[dict[str, Any]] = None,
|
|
tool_name: str = "",
|
|
tool_reasoning: str = "",
|
|
) -> Optional[dict[str, Any]]:
|
|
try:
|
|
record_data = {
|
|
"tool_id": tool_id or str(int(time.time() * 1000000)),
|
|
"timestamp": datetime.now(),
|
|
"session_id": chat_stream.session_id,
|
|
"tool_name": tool_name,
|
|
"tool_data": json.dumps(tool_data or {}, ensure_ascii=False),
|
|
"tool_reasoning": tool_reasoning,
|
|
"tool_builtin_prompt": builtin_prompt,
|
|
"tool_display_prompt": display_prompt,
|
|
}
|
|
|
|
saved_record = await db_save(ToolRecord, data=record_data, key_field="tool_id", key_value=record_data["tool_id"])
|
|
if saved_record:
|
|
logger.debug(f"[DatabaseService] 成功存储工具信息: {tool_name} (ID: {record_data['tool_id']})")
|
|
else:
|
|
logger.error(f"[DatabaseService] 存储工具信息失败: {tool_name}")
|
|
return saved_record
|
|
except Exception as e:
|
|
logger.error(f"[DatabaseService] 存储工具信息时发生错误: {e}")
|
|
traceback.print_exc()
|
|
return None
|
|
|
|
|
|
async def store_action_info(
|
|
chat_stream: "BotChatSession",
|
|
builtin_prompt: Optional[str] = None,
|
|
display_prompt: str = "",
|
|
thinking_id: str = "",
|
|
action_data: Optional[dict[str, Any]] = None,
|
|
action_name: str = "",
|
|
action_reasoning: str = "",
|
|
) -> Optional[dict[str, Any]]:
|
|
"""兼容旧接口,内部转发到 ``store_tool_info``。"""
|
|
return await store_tool_info(
|
|
chat_stream=chat_stream,
|
|
builtin_prompt=builtin_prompt,
|
|
display_prompt=display_prompt,
|
|
tool_id=thinking_id,
|
|
tool_data=action_data,
|
|
tool_name=action_name,
|
|
tool_reasoning=action_reasoning,
|
|
)
|