feat: 增强数据库服务,添加类型转换以支持更灵活的查询
This commit is contained in:
@@ -20,7 +20,7 @@ from src.common.database.database import get_db_session
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation
|
||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||
from src.know_u.knowledge_store import get_knowledge_store
|
||||
from src.learners.jargon_explainer import search_jargon
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
@@ -417,6 +417,230 @@ class MaisakaReasoningEngine:
|
||||
metadata={"anchor_message": anchor_message},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_record_value(value: Any) -> Any:
|
||||
"""将工具记录中的任意值规范化为可序列化结构。
|
||||
|
||||
Args:
|
||||
value: 原始值。
|
||||
|
||||
Returns:
|
||||
Any: 适合写入 JSON 的规范化结果。
|
||||
"""
|
||||
|
||||
if value is None or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
if isinstance(value, dict):
|
||||
normalized_dict: dict[str, Any] = {}
|
||||
for key, item in value.items():
|
||||
normalized_dict[str(key)] = MaisakaReasoningEngine._normalize_tool_record_value(item)
|
||||
return normalized_dict
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
return [MaisakaReasoningEngine._normalize_tool_record_value(item) for item in value]
|
||||
if isinstance(value, bytes):
|
||||
return f"<bytes:{len(value)}>"
|
||||
if hasattr(value, "model_dump"):
|
||||
try:
|
||||
return MaisakaReasoningEngine._normalize_tool_record_value(value.model_dump())
|
||||
except Exception:
|
||||
return str(value)
|
||||
if hasattr(value, "__dict__"):
|
||||
try:
|
||||
return MaisakaReasoningEngine._normalize_tool_record_value(dict(value.__dict__))
|
||||
except Exception:
|
||||
return str(value)
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _truncate_tool_record_text(text: str, max_length: int = 180) -> str:
|
||||
"""截断工具记录中的展示文本。
|
||||
|
||||
Args:
|
||||
text: 原始文本。
|
||||
max_length: 最长保留字符数。
|
||||
|
||||
Returns:
|
||||
str: 截断后的文本。
|
||||
"""
|
||||
|
||||
normalized_text = text.strip()
|
||||
if len(normalized_text) <= max_length:
|
||||
return normalized_text
|
||||
return f"{normalized_text[: max_length - 1]}…"
|
||||
|
||||
def _build_tool_record_payload(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
result: ToolExecutionResult,
|
||||
tool_spec: Optional[ToolSpec],
|
||||
) -> dict[str, Any]:
|
||||
"""构造统一工具落库数据。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用对象。
|
||||
result: 工具执行结果。
|
||||
tool_spec: 对应的工具声明。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 可直接写入数据库的工具记录数据。
|
||||
"""
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"call_id": invocation.call_id,
|
||||
"session_id": invocation.session_id,
|
||||
"stream_id": invocation.stream_id,
|
||||
"arguments": self._normalize_tool_record_value(invocation.arguments),
|
||||
"success": result.success,
|
||||
"content": result.content,
|
||||
"error_message": result.error_message,
|
||||
"history_content": result.get_history_content(),
|
||||
"structured_content": self._normalize_tool_record_value(result.structured_content),
|
||||
"metadata": self._normalize_tool_record_value(result.metadata),
|
||||
}
|
||||
if tool_spec is not None:
|
||||
payload["provider_name"] = tool_spec.provider_name
|
||||
payload["provider_type"] = tool_spec.provider_type
|
||||
payload["brief_description"] = tool_spec.brief_description
|
||||
payload["detailed_description"] = tool_spec.detailed_description
|
||||
payload["title"] = tool_spec.title
|
||||
return payload
|
||||
|
||||
def _build_tool_display_prompt(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
result: ToolExecutionResult,
|
||||
tool_spec: Optional[ToolSpec],
|
||||
) -> str:
|
||||
"""构造展示给历史回放与 UI 的工具摘要。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用对象。
|
||||
result: 工具执行结果。
|
||||
tool_spec: 对应的工具声明。
|
||||
|
||||
Returns:
|
||||
str: 用于展示的工具摘要文本。
|
||||
"""
|
||||
|
||||
custom_display_prompt = result.metadata.get("record_display_prompt")
|
||||
if isinstance(custom_display_prompt, str) and custom_display_prompt.strip():
|
||||
return custom_display_prompt.strip()
|
||||
|
||||
structured_content = (
|
||||
result.structured_content
|
||||
if isinstance(result.structured_content, dict)
|
||||
else {}
|
||||
)
|
||||
history_content = self._truncate_tool_record_text(result.get_history_content(), max_length=200)
|
||||
normalized_args = self._normalize_tool_record_value(invocation.arguments)
|
||||
|
||||
if invocation.tool_name == "reply":
|
||||
target_user_name = str(structured_content.get("target_user_name") or "对方").strip() or "对方"
|
||||
reply_text = str(structured_content.get("reply_text") or "").strip()
|
||||
if result.success and reply_text:
|
||||
return f"你对{target_user_name}进行了回复:{reply_text}"
|
||||
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
|
||||
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=120)
|
||||
return f"你尝试回复消息 {target_message_id or 'unknown'},但失败了:{error_text}"
|
||||
|
||||
if invocation.tool_name == "send_emoji":
|
||||
description = str(structured_content.get("description") or "").strip()
|
||||
emotion_list = structured_content.get("emotion")
|
||||
if isinstance(emotion_list, list):
|
||||
emotion_text = "、".join(str(item).strip() for item in emotion_list if str(item).strip())
|
||||
else:
|
||||
emotion_text = ""
|
||||
if result.success and description:
|
||||
if emotion_text:
|
||||
return f"你发送了表情包:{description}(情绪:{emotion_text})"
|
||||
return f"你发送了表情包:{description}"
|
||||
return f"你尝试发送表情包,但失败了:{self._truncate_tool_record_text(result.error_message or history_content, 120)}"
|
||||
|
||||
if invocation.tool_name == "wait":
|
||||
wait_seconds = invocation.arguments.get("seconds", 30)
|
||||
return f"你让当前对话先等待 {wait_seconds} 秒。"
|
||||
|
||||
if invocation.tool_name == "stop":
|
||||
return "你暂停了当前对话循环,等待新的外部消息。"
|
||||
|
||||
if invocation.tool_name == "query_jargon":
|
||||
words = invocation.arguments.get("words", [])
|
||||
if isinstance(words, list):
|
||||
words_text = "、".join(str(item).strip() for item in words if str(item).strip())
|
||||
else:
|
||||
words_text = ""
|
||||
if words_text:
|
||||
return f"你查询了这些黑话或词条:{words_text}"
|
||||
return "你查询了一次黑话或词条信息。"
|
||||
|
||||
if invocation.tool_name == "query_person_info":
|
||||
person_name = str(invocation.arguments.get("person_name") or "").strip()
|
||||
if person_name:
|
||||
return f"你查询了人物信息:{person_name}"
|
||||
return "你查询了一次人物信息。"
|
||||
|
||||
brief_description = ""
|
||||
if tool_spec is not None:
|
||||
brief_description = tool_spec.brief_description.strip()
|
||||
|
||||
if normalized_args:
|
||||
arguments_text = self._truncate_tool_record_text(
|
||||
json.dumps(normalized_args, ensure_ascii=False),
|
||||
max_length=160,
|
||||
)
|
||||
else:
|
||||
arguments_text = "{}"
|
||||
|
||||
if result.success:
|
||||
if brief_description:
|
||||
return f"{brief_description} 参数={arguments_text};结果:{history_content or '执行成功'}"
|
||||
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};结果:{history_content or '执行成功'}"
|
||||
|
||||
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=160)
|
||||
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};执行失败:{error_text}"
|
||||
|
||||
async def _store_tool_execution_record(
|
||||
self,
|
||||
invocation: ToolInvocation,
|
||||
result: ToolExecutionResult,
|
||||
tool_spec: Optional[ToolSpec],
|
||||
) -> None:
|
||||
"""将工具执行结果落库到统一工具记录表。
|
||||
|
||||
Args:
|
||||
invocation: 工具调用对象。
|
||||
result: 工具执行结果。
|
||||
tool_spec: 对应的工具声明。
|
||||
"""
|
||||
|
||||
if self._runtime.chat_stream is None:
|
||||
logger.debug(
|
||||
f"{self._runtime.log_prefix} 当前没有 chat_stream,跳过工具记录存储: "
|
||||
f"工具={invocation.tool_name}"
|
||||
)
|
||||
return
|
||||
|
||||
builtin_prompt = ""
|
||||
if tool_spec is not None:
|
||||
builtin_prompt = tool_spec.build_llm_description()
|
||||
|
||||
try:
|
||||
await database_api.store_tool_info(
|
||||
chat_stream=self._runtime.chat_stream,
|
||||
builtin_prompt=builtin_prompt,
|
||||
display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec),
|
||||
tool_id=invocation.call_id,
|
||||
tool_data=self._build_tool_record_payload(invocation, result, tool_spec),
|
||||
tool_name=invocation.tool_name,
|
||||
tool_reasoning=invocation.reasoning,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
|
||||
)
|
||||
|
||||
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
|
||||
"""将统一工具执行结果写回 Maisaka 历史。
|
||||
|
||||
@@ -435,6 +659,7 @@ class MaisakaReasoningEngine:
|
||||
timestamp=datetime.now(),
|
||||
tool_call_id=tool_call.call_id,
|
||||
tool_name=tool_call.func_name,
|
||||
success=result.success,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -614,20 +839,29 @@ class MaisakaReasoningEngine:
|
||||
|
||||
if self._runtime._tool_registry is None:
|
||||
for tool_call in tool_calls:
|
||||
self._append_tool_execution_result(
|
||||
tool_call,
|
||||
ToolExecutionResult(
|
||||
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
||||
result = ToolExecutionResult(
|
||||
tool_name=tool_call.func_name,
|
||||
success=False,
|
||||
error_message="统一工具注册表尚未初始化。",
|
||||
),
|
||||
)
|
||||
await self._store_tool_execution_record(invocation, result, None)
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
return False
|
||||
|
||||
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
|
||||
tool_spec_map = {
|
||||
tool_spec.name: tool_spec
|
||||
for tool_spec in await self._runtime._tool_registry.list_tools()
|
||||
}
|
||||
for tool_call in tool_calls:
|
||||
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
||||
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
||||
await self._store_tool_execution_record(
|
||||
invocation,
|
||||
result,
|
||||
tool_spec_map.get(invocation.tool_name),
|
||||
)
|
||||
self._append_tool_execution_result(tool_call, result)
|
||||
|
||||
if not result.success and tool_call.func_name == "reply":
|
||||
@@ -1015,19 +1249,6 @@ class MaisakaReasoningEngine:
|
||||
or target_user_info.user_nickname
|
||||
or target_user_info.user_id
|
||||
)
|
||||
if self._runtime.chat_stream is not None:
|
||||
await database_api.store_tool_info(
|
||||
chat_stream=self._runtime.chat_stream,
|
||||
display_prompt=f"你对{target_user_name}进行了回复:{combined_reply_text}",
|
||||
tool_data={
|
||||
"msg_id": target_message_id,
|
||||
"quote": quote_reply,
|
||||
"reply_text": combined_reply_text,
|
||||
"reply_segments": reply_segments,
|
||||
},
|
||||
tool_name="reply",
|
||||
tool_reasoning=latest_thought,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
|
||||
reply_timestamp = datetime.now()
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlmodel import SQLModel
|
||||
@@ -65,7 +65,7 @@ async def db_save(
|
||||
record = None
|
||||
if key_field and key_value is not None:
|
||||
key_column = _get_model_field(model_class, key_field)
|
||||
record = session.exec(select(model_class).where(key_column == key_value)).first()
|
||||
record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first()
|
||||
|
||||
if record is None:
|
||||
record = model_class(**data)
|
||||
@@ -99,7 +99,7 @@ async def db_get(
|
||||
statement = _apply_order_by(statement, model_class, order_by)
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
results = session.exec(statement).all()
|
||||
results = session.exec(cast(Any, statement)).all()
|
||||
data = [_to_dict(item) for item in results]
|
||||
if single_result:
|
||||
return data[0] if data else None
|
||||
@@ -116,7 +116,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
|
||||
statement = select(model_class)
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
records = session.exec(statement).all()
|
||||
records = session.exec(cast(Any, statement)).all()
|
||||
for record in records:
|
||||
for field_name, value in data.items():
|
||||
_get_model_field(model_class, field_name)
|
||||
@@ -149,7 +149,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
||||
statement = select(func.count()).select_from(model_class)
|
||||
if conditions := _build_filters(model_class, filters):
|
||||
statement = statement.where(*conditions)
|
||||
result = session.exec(statement).one()
|
||||
result = session.exec(cast(Any, statement)).one()
|
||||
return int(result or 0)
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
|
||||
|
||||
Reference in New Issue
Block a user