feat: 增强数据库服务,添加类型转换以支持更灵活的查询
This commit is contained in:
@@ -20,7 +20,7 @@ from src.common.database.database import get_db_session
|
|||||||
from src.common.database.database_model import PersonInfo
|
from src.common.database.database_model import PersonInfo
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation
|
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||||
from src.know_u.knowledge_store import get_knowledge_store
|
from src.know_u.knowledge_store import get_knowledge_store
|
||||||
from src.learners.jargon_explainer import search_jargon
|
from src.learners.jargon_explainer import search_jargon
|
||||||
from src.llm_models.exceptions import ReqAbortException
|
from src.llm_models.exceptions import ReqAbortException
|
||||||
@@ -417,6 +417,230 @@ class MaisakaReasoningEngine:
|
|||||||
metadata={"anchor_message": anchor_message},
|
metadata={"anchor_message": anchor_message},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_tool_record_value(value: Any) -> Any:
|
||||||
|
"""将工具记录中的任意值规范化为可序列化结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 原始值。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 适合写入 JSON 的规范化结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if value is None or isinstance(value, (str, int, float, bool)):
|
||||||
|
return value
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return value.isoformat()
|
||||||
|
if isinstance(value, dict):
|
||||||
|
normalized_dict: dict[str, Any] = {}
|
||||||
|
for key, item in value.items():
|
||||||
|
normalized_dict[str(key)] = MaisakaReasoningEngine._normalize_tool_record_value(item)
|
||||||
|
return normalized_dict
|
||||||
|
if isinstance(value, (list, tuple, set)):
|
||||||
|
return [MaisakaReasoningEngine._normalize_tool_record_value(item) for item in value]
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
return f"<bytes:{len(value)}>"
|
||||||
|
if hasattr(value, "model_dump"):
|
||||||
|
try:
|
||||||
|
return MaisakaReasoningEngine._normalize_tool_record_value(value.model_dump())
|
||||||
|
except Exception:
|
||||||
|
return str(value)
|
||||||
|
if hasattr(value, "__dict__"):
|
||||||
|
try:
|
||||||
|
return MaisakaReasoningEngine._normalize_tool_record_value(dict(value.__dict__))
|
||||||
|
except Exception:
|
||||||
|
return str(value)
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _truncate_tool_record_text(text: str, max_length: int = 180) -> str:
|
||||||
|
"""截断工具记录中的展示文本。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本。
|
||||||
|
max_length: 最长保留字符数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 截断后的文本。
|
||||||
|
"""
|
||||||
|
|
||||||
|
normalized_text = text.strip()
|
||||||
|
if len(normalized_text) <= max_length:
|
||||||
|
return normalized_text
|
||||||
|
return f"{normalized_text[: max_length - 1]}…"
|
||||||
|
|
||||||
|
def _build_tool_record_payload(
|
||||||
|
self,
|
||||||
|
invocation: ToolInvocation,
|
||||||
|
result: ToolExecutionResult,
|
||||||
|
tool_spec: Optional[ToolSpec],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""构造统一工具落库数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: 工具调用对象。
|
||||||
|
result: 工具执行结果。
|
||||||
|
tool_spec: 对应的工具声明。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict[str, Any]: 可直接写入数据库的工具记录数据。
|
||||||
|
"""
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"call_id": invocation.call_id,
|
||||||
|
"session_id": invocation.session_id,
|
||||||
|
"stream_id": invocation.stream_id,
|
||||||
|
"arguments": self._normalize_tool_record_value(invocation.arguments),
|
||||||
|
"success": result.success,
|
||||||
|
"content": result.content,
|
||||||
|
"error_message": result.error_message,
|
||||||
|
"history_content": result.get_history_content(),
|
||||||
|
"structured_content": self._normalize_tool_record_value(result.structured_content),
|
||||||
|
"metadata": self._normalize_tool_record_value(result.metadata),
|
||||||
|
}
|
||||||
|
if tool_spec is not None:
|
||||||
|
payload["provider_name"] = tool_spec.provider_name
|
||||||
|
payload["provider_type"] = tool_spec.provider_type
|
||||||
|
payload["brief_description"] = tool_spec.brief_description
|
||||||
|
payload["detailed_description"] = tool_spec.detailed_description
|
||||||
|
payload["title"] = tool_spec.title
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _build_tool_display_prompt(
|
||||||
|
self,
|
||||||
|
invocation: ToolInvocation,
|
||||||
|
result: ToolExecutionResult,
|
||||||
|
tool_spec: Optional[ToolSpec],
|
||||||
|
) -> str:
|
||||||
|
"""构造展示给历史回放与 UI 的工具摘要。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: 工具调用对象。
|
||||||
|
result: 工具执行结果。
|
||||||
|
tool_spec: 对应的工具声明。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 用于展示的工具摘要文本。
|
||||||
|
"""
|
||||||
|
|
||||||
|
custom_display_prompt = result.metadata.get("record_display_prompt")
|
||||||
|
if isinstance(custom_display_prompt, str) and custom_display_prompt.strip():
|
||||||
|
return custom_display_prompt.strip()
|
||||||
|
|
||||||
|
structured_content = (
|
||||||
|
result.structured_content
|
||||||
|
if isinstance(result.structured_content, dict)
|
||||||
|
else {}
|
||||||
|
)
|
||||||
|
history_content = self._truncate_tool_record_text(result.get_history_content(), max_length=200)
|
||||||
|
normalized_args = self._normalize_tool_record_value(invocation.arguments)
|
||||||
|
|
||||||
|
if invocation.tool_name == "reply":
|
||||||
|
target_user_name = str(structured_content.get("target_user_name") or "对方").strip() or "对方"
|
||||||
|
reply_text = str(structured_content.get("reply_text") or "").strip()
|
||||||
|
if result.success and reply_text:
|
||||||
|
return f"你对{target_user_name}进行了回复:{reply_text}"
|
||||||
|
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
|
||||||
|
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=120)
|
||||||
|
return f"你尝试回复消息 {target_message_id or 'unknown'},但失败了:{error_text}"
|
||||||
|
|
||||||
|
if invocation.tool_name == "send_emoji":
|
||||||
|
description = str(structured_content.get("description") or "").strip()
|
||||||
|
emotion_list = structured_content.get("emotion")
|
||||||
|
if isinstance(emotion_list, list):
|
||||||
|
emotion_text = "、".join(str(item).strip() for item in emotion_list if str(item).strip())
|
||||||
|
else:
|
||||||
|
emotion_text = ""
|
||||||
|
if result.success and description:
|
||||||
|
if emotion_text:
|
||||||
|
return f"你发送了表情包:{description}(情绪:{emotion_text})"
|
||||||
|
return f"你发送了表情包:{description}"
|
||||||
|
return f"你尝试发送表情包,但失败了:{self._truncate_tool_record_text(result.error_message or history_content, 120)}"
|
||||||
|
|
||||||
|
if invocation.tool_name == "wait":
|
||||||
|
wait_seconds = invocation.arguments.get("seconds", 30)
|
||||||
|
return f"你让当前对话先等待 {wait_seconds} 秒。"
|
||||||
|
|
||||||
|
if invocation.tool_name == "stop":
|
||||||
|
return "你暂停了当前对话循环,等待新的外部消息。"
|
||||||
|
|
||||||
|
if invocation.tool_name == "query_jargon":
|
||||||
|
words = invocation.arguments.get("words", [])
|
||||||
|
if isinstance(words, list):
|
||||||
|
words_text = "、".join(str(item).strip() for item in words if str(item).strip())
|
||||||
|
else:
|
||||||
|
words_text = ""
|
||||||
|
if words_text:
|
||||||
|
return f"你查询了这些黑话或词条:{words_text}"
|
||||||
|
return "你查询了一次黑话或词条信息。"
|
||||||
|
|
||||||
|
if invocation.tool_name == "query_person_info":
|
||||||
|
person_name = str(invocation.arguments.get("person_name") or "").strip()
|
||||||
|
if person_name:
|
||||||
|
return f"你查询了人物信息:{person_name}"
|
||||||
|
return "你查询了一次人物信息。"
|
||||||
|
|
||||||
|
brief_description = ""
|
||||||
|
if tool_spec is not None:
|
||||||
|
brief_description = tool_spec.brief_description.strip()
|
||||||
|
|
||||||
|
if normalized_args:
|
||||||
|
arguments_text = self._truncate_tool_record_text(
|
||||||
|
json.dumps(normalized_args, ensure_ascii=False),
|
||||||
|
max_length=160,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
arguments_text = "{}"
|
||||||
|
|
||||||
|
if result.success:
|
||||||
|
if brief_description:
|
||||||
|
return f"{brief_description} 参数={arguments_text};结果:{history_content or '执行成功'}"
|
||||||
|
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};结果:{history_content or '执行成功'}"
|
||||||
|
|
||||||
|
error_text = self._truncate_tool_record_text(result.error_message or history_content, max_length=160)
|
||||||
|
return f"你调用了工具 {invocation.tool_name},参数={arguments_text};执行失败:{error_text}"
|
||||||
|
|
||||||
|
async def _store_tool_execution_record(
|
||||||
|
self,
|
||||||
|
invocation: ToolInvocation,
|
||||||
|
result: ToolExecutionResult,
|
||||||
|
tool_spec: Optional[ToolSpec],
|
||||||
|
) -> None:
|
||||||
|
"""将工具执行结果落库到统一工具记录表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: 工具调用对象。
|
||||||
|
result: 工具执行结果。
|
||||||
|
tool_spec: 对应的工具声明。
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self._runtime.chat_stream is None:
|
||||||
|
logger.debug(
|
||||||
|
f"{self._runtime.log_prefix} 当前没有 chat_stream,跳过工具记录存储: "
|
||||||
|
f"工具={invocation.tool_name}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
builtin_prompt = ""
|
||||||
|
if tool_spec is not None:
|
||||||
|
builtin_prompt = tool_spec.build_llm_description()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await database_api.store_tool_info(
|
||||||
|
chat_stream=self._runtime.chat_stream,
|
||||||
|
builtin_prompt=builtin_prompt,
|
||||||
|
display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec),
|
||||||
|
tool_id=invocation.call_id,
|
||||||
|
tool_data=self._build_tool_record_payload(invocation, result, tool_spec),
|
||||||
|
tool_name=invocation.tool_name,
|
||||||
|
tool_reasoning=invocation.reasoning,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
|
||||||
|
)
|
||||||
|
|
||||||
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
|
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
|
||||||
"""将统一工具执行结果写回 Maisaka 历史。
|
"""将统一工具执行结果写回 Maisaka 历史。
|
||||||
|
|
||||||
@@ -435,6 +659,7 @@ class MaisakaReasoningEngine:
|
|||||||
timestamp=datetime.now(),
|
timestamp=datetime.now(),
|
||||||
tool_call_id=tool_call.call_id,
|
tool_call_id=tool_call.call_id,
|
||||||
tool_name=tool_call.func_name,
|
tool_name=tool_call.func_name,
|
||||||
|
success=result.success,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -614,20 +839,29 @@ class MaisakaReasoningEngine:
|
|||||||
|
|
||||||
if self._runtime._tool_registry is None:
|
if self._runtime._tool_registry is None:
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
self._append_tool_execution_result(
|
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
||||||
tool_call,
|
result = ToolExecutionResult(
|
||||||
ToolExecutionResult(
|
tool_name=tool_call.func_name,
|
||||||
tool_name=tool_call.func_name,
|
success=False,
|
||||||
success=False,
|
error_message="统一工具注册表尚未初始化。",
|
||||||
error_message="统一工具注册表尚未初始化。",
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
await self._store_tool_execution_record(invocation, result, None)
|
||||||
|
self._append_tool_execution_result(tool_call, result)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
|
execution_context = self._build_tool_execution_context(latest_thought, anchor_message)
|
||||||
|
tool_spec_map = {
|
||||||
|
tool_spec.name: tool_spec
|
||||||
|
for tool_spec in await self._runtime._tool_registry.list_tools()
|
||||||
|
}
|
||||||
for tool_call in tool_calls:
|
for tool_call in tool_calls:
|
||||||
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
invocation = self._build_tool_invocation(tool_call, latest_thought)
|
||||||
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
result = await self._runtime._tool_registry.invoke(invocation, execution_context)
|
||||||
|
await self._store_tool_execution_record(
|
||||||
|
invocation,
|
||||||
|
result,
|
||||||
|
tool_spec_map.get(invocation.tool_name),
|
||||||
|
)
|
||||||
self._append_tool_execution_result(tool_call, result)
|
self._append_tool_execution_result(tool_call, result)
|
||||||
|
|
||||||
if not result.success and tool_call.func_name == "reply":
|
if not result.success and tool_call.func_name == "reply":
|
||||||
@@ -1015,19 +1249,6 @@ class MaisakaReasoningEngine:
|
|||||||
or target_user_info.user_nickname
|
or target_user_info.user_nickname
|
||||||
or target_user_info.user_id
|
or target_user_info.user_id
|
||||||
)
|
)
|
||||||
if self._runtime.chat_stream is not None:
|
|
||||||
await database_api.store_tool_info(
|
|
||||||
chat_stream=self._runtime.chat_stream,
|
|
||||||
display_prompt=f"你对{target_user_name}进行了回复:{combined_reply_text}",
|
|
||||||
tool_data={
|
|
||||||
"msg_id": target_message_id,
|
|
||||||
"quote": quote_reply,
|
|
||||||
"reply_text": combined_reply_text,
|
|
||||||
"reply_segments": reply_segments,
|
|
||||||
},
|
|
||||||
tool_name="reply",
|
|
||||||
tool_reasoning=latest_thought,
|
|
||||||
)
|
|
||||||
|
|
||||||
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
|
bot_name = global_config.bot.nickname.strip() or "MaiSaka"
|
||||||
reply_timestamp = datetime.now()
|
reply_timestamp = datetime.now()
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, select
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
@@ -65,7 +65,7 @@ async def db_save(
|
|||||||
record = None
|
record = None
|
||||||
if key_field and key_value is not None:
|
if key_field and key_value is not None:
|
||||||
key_column = _get_model_field(model_class, key_field)
|
key_column = _get_model_field(model_class, key_field)
|
||||||
record = session.exec(select(model_class).where(key_column == key_value)).first()
|
record = session.exec(cast(Any, select(model_class).where(key_column == key_value))).first()
|
||||||
|
|
||||||
if record is None:
|
if record is None:
|
||||||
record = model_class(**data)
|
record = model_class(**data)
|
||||||
@@ -99,7 +99,7 @@ async def db_get(
|
|||||||
statement = _apply_order_by(statement, model_class, order_by)
|
statement = _apply_order_by(statement, model_class, order_by)
|
||||||
if limit:
|
if limit:
|
||||||
statement = statement.limit(limit)
|
statement = statement.limit(limit)
|
||||||
results = session.exec(statement).all()
|
results = session.exec(cast(Any, statement)).all()
|
||||||
data = [_to_dict(item) for item in results]
|
data = [_to_dict(item) for item in results]
|
||||||
if single_result:
|
if single_result:
|
||||||
return data[0] if data else None
|
return data[0] if data else None
|
||||||
@@ -116,7 +116,7 @@ async def db_update(model_class: type[SQLModel], data: dict[str, Any], filters:
|
|||||||
statement = select(model_class)
|
statement = select(model_class)
|
||||||
if conditions := _build_filters(model_class, filters):
|
if conditions := _build_filters(model_class, filters):
|
||||||
statement = statement.where(*conditions)
|
statement = statement.where(*conditions)
|
||||||
records = session.exec(statement).all()
|
records = session.exec(cast(Any, statement)).all()
|
||||||
for record in records:
|
for record in records:
|
||||||
for field_name, value in data.items():
|
for field_name, value in data.items():
|
||||||
_get_model_field(model_class, field_name)
|
_get_model_field(model_class, field_name)
|
||||||
@@ -149,7 +149,7 @@ async def db_count(model_class: type[SQLModel], filters: Optional[dict[str, Any]
|
|||||||
statement = select(func.count()).select_from(model_class)
|
statement = select(func.count()).select_from(model_class)
|
||||||
if conditions := _build_filters(model_class, filters):
|
if conditions := _build_filters(model_class, filters):
|
||||||
statement = statement.where(*conditions)
|
statement = statement.where(*conditions)
|
||||||
result = session.exec(statement).one()
|
result = session.exec(cast(Any, statement)).one()
|
||||||
return int(result or 0)
|
return int(result or 0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
|
logger.error(f"[DatabaseService] 统计数据库记录出错: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user