炸 service 层 x 2,把能归类为现有重构好的模块的都归类过去
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import Any, Optional
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.chat.message_receive.chat_manager import BotChatSession
|
||||
from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import ActionRecord
|
||||
from src.common.logger import get_logger
|
||||
@@ -23,12 +24,10 @@ def _to_dict(record: Any) -> dict[str, Any]:
|
||||
return record
|
||||
if hasattr(record, "model_dump"):
|
||||
return record.model_dump()
|
||||
if hasattr(record, "__dict__"):
|
||||
return dict(record.__dict__)
|
||||
return {}
|
||||
return dict(record.__dict__) if hasattr(record, "__dict__") else {}
|
||||
|
||||
|
||||
def _get_model_field(model_class: type[SQLModel], field_name: str):
|
||||
def _get_model_field(model_class: type[SQLModel], field_name: str) -> Any:
|
||||
field = getattr(model_class, field_name, None)
|
||||
if field is None:
|
||||
raise ValueError(f"{model_class.__name__} 不存在字段 {field_name}")
|
||||
@@ -41,7 +40,7 @@ def _build_filters(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
return [_get_model_field(model_class, field_name) == value for field_name, value in filters.items()]
|
||||
|
||||
|
||||
def _apply_order_by(statement, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None):
|
||||
def _apply_order_by(statement: Any, model_class: type[SQLModel], order_by: Optional[str | list[str]] = None) -> Any:
|
||||
if not order_by:
|
||||
return statement
|
||||
|
||||
@@ -60,7 +59,7 @@ async def db_save(
|
||||
data: dict[str, Any],
|
||||
key_field: Optional[str] = None,
|
||||
key_value: Optional[Any] = None,
|
||||
):
|
||||
) -> Optional[dict[str, Any]]:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
record = None
|
||||
@@ -91,12 +90,11 @@ async def db_get(
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str | list[str]] = None,
|
||||
single_result: bool = False,
|
||||
):
|
||||
) -> Optional[dict[str, Any]] | list[dict[str, Any]]:
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
statement = _apply_order_by(statement, model_class, order_by)
|
||||
if limit:
|
||||
@@ -116,8 +114,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = select(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
records = session.exec(statement).all()
|
||||
for record in records:
|
||||
@@ -136,8 +133,7 @@ async def db_delete(model_class: type[SQLModel], filters: Optional[dict[str, Any
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
statement = delete(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
result = session.exec(statement)
|
||||
return result.rowcount or 0
|
||||
@@ -151,8 +147,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
try:
|
||||
with get_db_session(auto_commit=False) as session:
|
||||
statement = select(func.count()).select_from(model_class)
|
||||
conditions = _build_filters(model_class, filters)
|
||||
if conditions:
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
result = session.exec(statement).one()
|
||||
return int(result or 0)
|
||||
@@ -163,18 +158,15 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
chat_stream: BotChatSession,
|
||||
builtin_prompt: Optional[str] = None,
|
||||
display_prompt: str = "",
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_data: Optional[dict[str, Any]] = None,
|
||||
action_name: str = "",
|
||||
action_reasoning: str = "",
|
||||
):
|
||||
) -> Optional[dict[str, Any]]:
|
||||
try:
|
||||
if chat_stream is None:
|
||||
raise ValueError("store_action_info 需要 chat_stream")
|
||||
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"timestamp": datetime.now(),
|
||||
|
||||
Reference in New Issue
Block a user