feat: 增强数据库服务,添加类型转换以支持更灵活的查询

This commit is contained in:
DrSmoothl
2026-03-31 08:21:53 +08:00
parent 5ac088ded8
commit ea4cea39f2
2 changed files with 247 additions and 26 deletions

View File

@@ -20,7 +20,7 @@ from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.know_u.knowledge_store import get_knowledge_store
from src.learners.jargon_explainer import search_jargon
from src.llm_models.exceptions import ReqAbortException
@@ -417,6 +417,230 @@ class MaisakaReasoningEngine:
metadata={"anchor_message": anchor_message},
)
@staticmethod
def _normalize_tool_record_value(value: Any) -> Any:
"""将工具记录中的任意值规范化为可序列化结构。
Args:
value: 原始值。
Returns:
Any: 适合写入 JSON 的规范化结果。
"""
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, datetime):
return value.isoformat()
if isinstance(value, dict):
normalized_dict: dict[str, Any] = {}
for key, item in value.items():
normalized_dict[str(key)] = MaisakaReasoningEngine._normalize_tool_record_value(item)
return normalized_dict
if isinstance(value, (list, tuple, set)):
return [MaisakaReasoningEngine._normalize_tool_record_value(item) for item in value]
if isinstance(value, bytes):
return f"<bytes:{len(value)}>"
if hasattr(value, "model_dump"):
try:
return MaisakaReasoningEngine._normalize_tool_record_value(value.model_dump())
except Exception:
return str(value)
if hasattr(value, "__dict__"):
try:
return MaisakaReasoningEngine._normalize_tool_record_value(dict(value.__dict__))
except Exception:
return str(value)
return str(value)
@staticmethod
def _truncate_tool_record_text(text: str, max_length: int = 180) -> str:
"""截断工具记录中的展示文本。
Args:
text: 原始文本。
max_length: 最长保留字符数。
Returns:
str: 截断后的文本。
"""
normalized_text = text.strip()
if len(normalized_text) <= max_length:
return normalized_text
return f"{normalized_text[: max_length - 1]}"
def _build_tool_record_payload(
self,
invocation: ToolInvocation,
result: ToolExecutionResult,
tool_spec: Optional[ToolSpec],
) -> dict[str, Any]:
"""构造统一工具落库数据。
Args:
invocation: 工具调用对象。
result: 工具执行结果。
tool_spec: 对应的工具声明。
Returns:
dict[str, Any]: 可直接写入数据库的工具记录数据。
"""
payload: dict[str, Any] = {
"call_id": invocation.call_id,
"session_id": invocation.session_id,
"stream_id": invocation.stream_id,
"arguments": self._normalize_tool_record_value(invocation.arguments),
"success": result.success,
"content": result.content,
"error_message": result.error_message,
"history_content": result.get_history_content(),
"structured_content": self._normalize_tool_record_value(result.structured_content),
"metadata": self._normalize_tool_record_value(result.metadata),
}
if tool_spec is not None:
payload["provider_name"] = tool_spec.provider_name
payload["provider_type"] = tool_spec.provider_type
payload["brief_description"] = tool_spec.brief_description
payload["detailed_description"] = tool_spec.detailed_description
payload["title"] = tool_spec.title
return payload
def _build_tool_display_prompt(
self,
invocation: ToolInvocation,
result: ToolExecutionResult,
tool_spec: Optional[ToolSpec],
) -> str:
"""构造展示给历史回放与 UI 的工具摘要。
Args:
invocation: 工具调用对象。
result: 工具执行结果。
tool_spec: 对应的工具声明。
Returns:
str: 用于展示的工具摘要文本。
"""
custom_display_prompt = result.metadata.get("record_display_prompt")
if isinstance(custom_display_prompt, str) and custom_display_prompt.strip():
return custom_display_prompt.strip()
structured_content = (
result.structured_content
if isinstance(result.structured_content, dict)
else {}
)
history_content = self._truncate_tool_record_text(result.get_history_content(), max_length=200)
normalized_args = self._normalize_tool_record_value(invocation.arguments)
if invocation.tool_name == "reply":
target_user_name = str(structured_content.get("target_user_name") or "对方").strip() or "对方"
reply_text = str(structured_content.get("reply_text") or "").strip()
if result.success and reply_text:
return f"你对{target_user_name}进行了回复:{reply_text}"
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=120)
return f"你尝试回复消息 {target_message_id or 'unknown'},但失败了:{error_text}"
if invocation.tool_name == "send_emoji":
description = str(structured_content.get("description") or "").strip()
emotion_list = structured_content.get("emotion")
if isinstance(emotion_list, list):
emotion_text = "".join(str(item).strip() for item in emotion_list if str(item).strip())
else:
emotion_text = ""
if result.success and description:
if emotion_text:
return f"你发送了表情包:{description}(情绪:{emotion_text}"
return f"你发送了表情包:{description}"
return f"你尝试发送表情包,但失败了:{self._truncate_tool_record_text(result.error_message or history_content, 120)}"
if invocation.tool_name == "wait":
wait_seconds = invocation.arguments.get("seconds", 30)
return f"你让当前对话先等待 {wait_seconds} 秒。"
if invocation.tool_name == "stop":
return "你暂停了当前对话循环,等待新的外部消息。"
if invocation.tool_name == "query_jargon":
words = invocation.arguments.get("words", [])
if isinstance(words, list):
words_text = "".join(str(item).strip() for item in words if str(item).strip())
else:
words_text = ""
if words_text:
return f"你查询了这些黑话或词条:{words_text}"
return "你查询了一次黑话或词条信息。"
if invocation.tool_name == "query_person_info":
person_name = str(invocation.arguments.get("person_name") or "").strip()
if person_name:
return f"你查询了人物信息:{person_name}"
return "你查询了一次人物信息。"
brief_description = ""
if tool_spec is not None:
brief_description = tool_spec.brief_description.strip()
if normalized_args:
arguments_text = self._truncate_tool_record_text(
json.dumps(normalized_args, ensure_ascii=False),
max_length=160,
)
else:
arguments_text = "{}"
if result.success:
if brief_description:
return f"{brief_description} 参数={arguments_text};结果:{history_content or '执行成功'}"
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};结果:{history_content or '执行成功'}"
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=160)
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};执行失败:{error_text}"
async def _store_tool_execution_record(
self,
invocation: ToolInvocation,
result: ToolExecutionResult,
tool_spec: Optional[ToolSpec],
) -> None:
"""将工具执行结果落库到统一工具记录表。
Args:
invocation: 工具调用对象。
result: 工具执行结果。
tool_spec: 对应的工具声明。
"""
if self._runtime.chat_stream is None:
logger.debug(
f"{self._runtime.log_prefix} 当前没有 chat_stream跳过工具记录存储: "
f"工具={invocation.tool_name}"
)
return
builtin_prompt = ""
if tool_spec is not None:
builtin_prompt = tool_spec.build_llm_description()
try:
await database_api.store_tool_info(
chat_stream=self._runtime.chat_stream,
builtin_prompt=builtin_prompt,
display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec),
tool_id=invocation.call_id,
tool_data=self._build_tool_record_payload(invocation, result, tool_spec),
tool_name=invocation.tool_name,
tool_reasoning=invocation.reasoning,
)
except Exception:
logger.exception(
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
)
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
"""将统一工具执行结果写回 Maisaka 历史。
@@ -435,6 +659,7 @@ class MaisakaReasoningEngine:
timestamp=datetime.now(),
tool_call_id=tool_call.call_id,
tool_name=tool_call.func_name,
success=result.success,
)
)
@@ -614,20 +839,29 @@ class MaisakaReasoningEngine:
if self._runtime._tool_registry is None:
for tool_call in tool_calls:
self._append_tool_execution_result(
tool_call,
ToolExecutionResult(
tool_name=tool_call.func_name,
success=False,
error_message="统一工具注册表尚未初始化。",
),
invocation = self._build_tool_invocation(tool_call, latest_thought)
result = ToolExecutionResult(
tool_name=tool_call.func_name,
success=False,
error_message="统一工具注册表尚未初始化。",
)
await self._store_tool_execution_record(invocation, result, None)
self._append_tool_execution_result(tool_call, result)
return False
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
tool_spec_map = {
tool_spec.name: tool_spec
for tool_spec in await self._runtime._tool_registry.list_tools()
}
for tool_call in tool_calls:
invocation = self._build_tool_invocation(tool_call, latest_thought)
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
await self._store_tool_execution_record(
invocation,
result,
tool_spec_map.get(invocation.tool_name),
)
self._append_tool_execution_result(tool_call, result)
if not result.success and tool_call.func_name == "reply":
@@ -1015,19 +1249,6 @@ class MaisakaReasoningEngine:
or target_user_info.user_nickname
or target_user_info.user_id
)
if self._runtime.chat_stream is not None:
await database_api.store_tool_info(
chat_stream=self._runtime.chat_stream,
display_prompt=f"你对{target_user_name}进行了回复:{combined_reply_text}",
tool_data={
"msg_id": target_message_id,
"quote": quote_reply,
"reply_text": combined_reply_text,
"reply_segments": reply_segments,
},
tool_name="reply",
tool_reasoning=latest_thought,
)
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
reply_timestamp = datetime.now()