feat: 增强数据库服务,添加类型转换以支持更灵活的查询

This commit is contained in:
DrSmoothl
2026-03-31 08:21:53 +08:00
parent 5ac088ded8
commit ea4cea39f2
2 changed files with 247 additions and 26 deletions

View File

@@ -20,7 +20,7 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.know_u.knowledge_store import get_knowledge_store
from src.learners.jargon_explainer import search_jargon
from src.llm_models.exceptions import ReqAbortException
@@ -417,6 +417,230 @@ class MaisakaReasoningEngine:
metadata={"anchor_message": anchor_message},
)
@staticmethod
def _normalize_tool_record_value(value: Any) -> Any:
"""将工具记录中的任意值规范化为可序列化结构。
Args:
value: 原始值。
Returns:
Any: 适合写入 JSON 的规范化结果。
"""
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, dict):
normalized_dict: dict[str, Any] = {}
for key, item in value.items():
normalized_dict[str(key)] = MaisakaReasoningEngine._normalize_tool_record_value(item)
return normalized_dict
if isinstance(value, (list, tuple, set)):
return [MaisakaReasoningEngine._normalize_tool_record_value(item) for item in value]
if isinstance(value, bytes):
return f"<bytes:{len(value)}>"
if hasattr(value, "model_dump"):
try:
return MaisakaReasoningEngine._normalize_tool_record_value(value.model_dump())
except Exception:
return str(value)
if hasattr(value, "__dict__"):
try:
return MaisakaReasoningEngine._normalize_tool_record_value(dict(value.__dict__))
except Exception:
return str(value)
return str(value)
@staticmethod
def _truncate_tool_record_text(text: str, max_length: int = 180) -> str:
"""截断工具记录中的展示文本。
Args:
text: 原始文本。
max_length: 最长保留字符数。
Returns:
str: 截断后的文本。
"""
normalized_text = text.strip()
if len(normalized_text) <= max_length:
return normalized_text
return f"{normalized_text[: max_length - 1]}"
def _build_tool_record_payload(
self,
invocation: ToolInvocation,
result: ToolExecutionResult,
tool_spec: Optional[ToolSpec],
) -> dict[str, Any]:
"""构造统一工具落库数据。
Args:
invocation: 工具调用对象。
result: 工具执行结果。
tool_spec: 对应的工具声明。
Returns:
dict[str, Any]: 可直接写入数据库的工具记录数据。
"""
payload: dict[str, Any] = {
"call_id": invocation.call_id,
"session_id": invocation.session_id,
"stream_id": invocation.stream_id,
"arguments": self._normalize_tool_record_value(invocation.arguments),
"success": result.success,
"content": result.content,
"error_message": result.error_message,
"history_content": result.get_history_content(),
"structured_content": self._normalize_tool_record_value(result.structured_content),
"metadata": self._normalize_tool_record_value(result.metadata),
}
if tool_spec is not None:
payload["provider_name"] = tool_spec.provider_name
payload["provider_type"] = tool_spec.provider_type
payload["brief_description"] = tool_spec.brief_description
payload["detailed_description"] = tool_spec.detailed_description
payload["title"] = tool_spec.title
return payload
def _build_tool_display_prompt(
self,
invocation: ToolInvocation,
result: ToolExecutionResult,
tool_spec: Optional[ToolSpec],
) -> str:
"""构造展示给历史回放与 UI 的工具摘要。
Args:
invocation: 工具调用对象。
result: 工具执行结果。
tool_spec: 对应的工具声明。
Returns:
str: 用于展示的工具摘要文本。
"""
custom_display_prompt = result.metadata.get("record_display_prompt")
if isinstance(custom_display_prompt, str) and custom_display_prompt.strip():
return custom_display_prompt.strip()
structured_content = (
result.structured_content
if isinstance(result.structured_content, dict)
else {}
)
history_content = self._truncate_tool_record_text(result.get_history_content(), max_length=200)
normalized_args = self._normalize_tool_record_value(invocation.arguments)
if invocation.tool_name == "reply":
target_user_name = str(structured_content.get("target_user_name") or "对方").strip() or "对方"
reply_text = str(structured_content.get("reply_text") or "").strip()
if result.success and reply_text:
return f"你对{target_user_name}进行了回复:{reply_text}"
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=120)
return f"你尝试回复消息 {target_message_id or 'unknown'},但失败了:{error_text}"
if invocation.tool_name == "send_emoji":
description = str(structured_content.get("description") or "").strip()
emotion_list = structured_content.get("emotion")
if isinstance(emotion_list, list):
emotion_text = "".join(str(item).strip() for item in emotion_list if str(item).strip())
else:
emotion_text = ""
if result.success and description:
if emotion_text:
return f"你发送了表情包:{description}(情绪:{emotion_text}"
return f"你发送了表情包:{description}"
return f"你尝试发送表情包,但失败了:{self._truncate_tool_record_text(result.error_message or history_content, 120)}"
if invocation.tool_name == "wait":
wait_seconds = invocation.arguments.get("seconds", 30)
return f"你让当前对话先等待 {wait_seconds} 秒。"
if invocation.tool_name == "stop":
return "你暂停了当前对话循环,等待新的外部消息。"
if invocation.tool_name == "query_jargon":
words = invocation.arguments.get("words", [])
if isinstance(words, list):
words_text = "".join(str(item).strip() for item in words if str(item).strip())
else:
words_text = ""
if words_text:
return f"你查询了这些黑话或词条:{words_text}"
return "你查询了一次黑话或词条信息。"
if invocation.tool_name == "query_person_info":
person_name = str(invocation.arguments.get("person_name") or "").strip()
if person_name:
return f"你查询了人物信息:{person_name}"
return "你查询了一次人物信息。"
brief_description = ""
if tool_spec is not None:
brief_description = tool_spec.brief_description.strip()
if normalized_args:
arguments_text = self._truncate_tool_record_text(
json.dumps(normalized_args, ensure_ascii=False),
max_length=160,
)
else:
arguments_text = "{}"
if result.success:
if brief_description:
return f"{brief_description} 参数={arguments_text};结果:{history_content or '执行成功'}"
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};结果:{history_content or '执行成功'}"
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=160)
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};执行失败:{error_text}"
async def _store_tool_execution_record(
self,
invocation: ToolInvocation,
result: ToolExecutionResult,
tool_spec: Optional[ToolSpec],
) -> None:
"""将工具执行结果落库到统一工具记录表。
Args:
invocation: 工具调用对象。
result: 工具执行结果。
tool_spec: 对应的工具声明。
"""
if self._runtime.chat_stream is None:
logger.debug(
f"{self._runtime.log_prefix} 当前没有 chat_stream跳过工具记录存储: "
f"工具={invocation.tool_name}"
)
return
builtin_prompt = ""
if tool_spec is not None:
builtin_prompt = tool_spec.build_llm_description()
try:
await database_api.store_tool_info(
chat_stream=self._runtime.chat_stream,
builtin_prompt=builtin_prompt,
display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec),
tool_id=invocation.call_id,
tool_data=self._build_tool_record_payload(invocation, result, tool_spec),
tool_name=invocation.tool_name,
tool_reasoning=invocation.reasoning,
)
except Exception:
logger.exception(
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
)
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
"""将统一工具执行结果写回 Maisaka 历史。
@@ -435,6 +659,7 @@ class MaisakaReasoningEngine:
timestamp=datetime.now(),
tool_call_id=tool_call.call_id,
tool_name=tool_call.func_name,
success=result.success,
)
)
@@ -614,20 +839,29 @@ class MaisakaReasoningEngine:
if self._runtime._tool_registry is None:
for tool_call in tool_calls:
self._append_tool_execution_result(
tool_call,
ToolExecutionResult(
invocation = self._build_tool_invocation(tool_call, latest_thought)
result = ToolExecutionResult(
tool_name=tool_call.func_name,
success=False,
error_message="统一工具注册表尚未初始化。",
),
)
await self._store_tool_execution_record(invocation, result, None)
self._append_tool_execution_result(tool_call, result)
return False
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
tool_spec_map = {
tool_spec.name: tool_spec
for tool_spec in await self._runtime._tool_registry.list_tools()
}
for tool_call in tool_calls:
invocation = self._build_tool_invocation(tool_call, latest_thought)
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
await self._store_tool_execution_record(
invocation,
result,
tool_spec_map.get(invocation.tool_name),
)
self._append_tool_execution_result(tool_call, result)
if not result.success and tool_call.func_name == "reply":
@@ -1015,19 +1249,6 @@ class MaisakaReasoningEngine:
or target_user_info.user_nickname
or target_user_info.user_id
)
if self._runtime.chat_stream is not None:
await database_api.store_tool_info(
chat_stream=self._runtime.chat_stream,
display_prompt=f"你对{target_user_name}进行了回复:{combined_reply_text}",
tool_data={
"msg_id": target_message_id,
"quote": quote_reply,
"reply_text": combined_reply_text,
"reply_segments": reply_segments,
},
tool_name="reply",
tool_reasoning=latest_thought,
)
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
reply_timestamp = datetime.now()

View File

@@ -4,7 +4,7 @@ import json
import time
import traceback
from datetime import datetime
from typing import Any, Optional
from typing import Any, Optional, cast
from sqlalchemy import delete, func, select
from sqlmodel import SQLModel
@@ -65,7 +65,7 @@ async def db_save(
record = None
if key_field and key_value is not None:
key_column = _get_model_field(model_class, key_field)
record = session.exec(select(model_class).where(key_column == key_value)).first()
record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first()
if record is None:
record = model_class(**data)
@@ -99,7 +99,7 @@ async def db_get(
statement = _apply_order_by(statement, model_class, order_by)
if limit:
statement = statement.limit(limit)
results = session.exec(statement).all()
results = session.exec(cast(Any, statement)).all()
data = [_to_dict(item) for item in results]
if single_result:
return data[0] if data else None
@@ -116,7 +116,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
statement = select(model_class)
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
records = session.exec(statement).all()
records = session.exec(cast(Any, statement)).all()
for record in records:
for field_name, value in data.items():
_get_model_field(model_class, field_name)
@@ -149,7 +149,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
statement = select(func.count()).select_from(model_class)
if conditions := _build_filters(model_class, filters):
statement = statement.where(*conditions)
result = session.exec(statement).one()
result = session.exec(cast(Any, statement)).one()
return int(result or 0)
except Exception as e:
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")