Files
mai-bot/src/maisaka/builtin_tool/query_memory.py
2026-04-04 02:50:08 +08:00

254 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""query_memory 内置工具。"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.person_info.person_info import resolve_person_id_for_memory
from src.services.memory_service import MemorySearchResult, memory_service
from .context import BuiltinToolRuntimeContext
logger = get_logger("maisaka_builtin_query_memory")
_ALLOWED_QUERY_MODES = {"search", "time", "hybrid", "episode", "aggregate"}
def get_tool_spec(*, enabled: bool = True) -> ToolSpec:
"""获取 query_memory 工具声明。"""
return ToolSpec(
name="query_memory",
brief_description="检索 A_memorix 长期记忆并返回可读结果。",
detailed_description=(
"参数说明:\n"
"- querystring可选。要检索的关键词或问题。\n"
"- limitinteger可选。返回条数默认使用系统配置值。\n"
"- modestring可选。search/time/hybrid/episode/aggregate。\n"
"- person_namestring可选。人物名优先用于解析并过滤 person_id。\n"
"- time_startstring可选。起始时间可填写时间戳或可解析时间文本。\n"
"- time_endstring可选。结束时间可填写时间戳或可解析时间文本。\n"
"- respect_filterboolean可选。是否应用聊天过滤配置默认 true。"
),
parameters_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "要检索的关键词或问题。",
},
"limit": {
"type": "integer",
"description": "返回条数,默认使用系统配置值。",
},
"mode": {
"type": "string",
"description": "检索模式search/time/hybrid/episode/aggregate。",
"enum": sorted(_ALLOWED_QUERY_MODES),
"default": "search",
},
"person_name": {
"type": "string",
"description": "人物名称,可选。提供后优先按人物过滤。",
},
"time_start": {
"type": "string",
"description": "起始时间,可填写时间戳或可解析时间文本。",
},
"time_end": {
"type": "string",
"description": "结束时间,可填写时间戳或可解析时间文本。",
},
"respect_filter": {
"type": "boolean",
"description": "是否应用聊天过滤配置。",
"default": True,
},
},
},
provider_name="maisaka_builtin",
provider_type="builtin",
enabled=enabled,
)
def _normalize_optional_time(raw_value: Any) -> str | float | None:
"""归一化可选时间参数。"""
if raw_value is None:
return None
if isinstance(raw_value, str):
time_text = raw_value.strip()
if not time_text:
return None
return time_text
if isinstance(raw_value, (float, int)):
return float(raw_value)
time_text = str(raw_value).strip()
if not time_text:
return None
return time_text
def _resolve_person_id(
*,
person_name: str,
platform: str,
user_id: str,
group_id: str,
) -> Tuple[str, str]:
"""按约定顺序解析长期记忆检索使用的 person_id。"""
clean_person_name = str(person_name or "").strip()
if clean_person_name:
person_id = resolve_person_id_for_memory(
person_name=clean_person_name,
platform=platform,
user_id=user_id,
)
if person_id:
return person_id, clean_person_name
if not group_id and platform and user_id:
person_id = resolve_person_id_for_memory(
platform=platform,
user_id=user_id,
)
if person_id:
return person_id, clean_person_name
return "", clean_person_name
def _build_success_content(result: MemorySearchResult, *, limit: int) -> str:
"""构造工具成功时的可读内容。"""
summary = str(result.summary or "").strip()
snippet = result.to_text(limit=max(1, int(limit)))
if result.hits:
if summary and snippet:
return f"{summary}\n{snippet}"
if summary:
return summary
if snippet:
return snippet
return "已找到匹配的长期记忆。"
if result.filtered:
return "当前请求被聊天过滤策略跳过,未执行长期记忆检索。"
return "未找到匹配的长期记忆。"
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 query_memory 内置工具。"""
del context
runtime = tool_ctx.runtime
chat_stream = runtime.chat_stream
clean_query = str(invocation.arguments.get("query") or "").strip()
mode = str(invocation.arguments.get("mode") or "search").strip().lower() or "search"
if mode not in _ALLOWED_QUERY_MODES:
return tool_ctx.build_failure_result(
invocation.tool_name,
f"不支持的检索模式:{mode}。可选值search/time/hybrid/episode/aggregate。",
)
default_limit = max(1, int(getattr(global_config.maisaka, "memory_query_default_limit", 5) or 5))
try:
limit = int(invocation.arguments.get("limit", default_limit) or default_limit)
except (TypeError, ValueError):
limit = default_limit
limit = max(1, min(limit, 20))
time_start = _normalize_optional_time(invocation.arguments.get("time_start"))
time_end = _normalize_optional_time(invocation.arguments.get("time_end"))
if not clean_query and time_start is None and time_end is None:
return tool_ctx.build_failure_result(
invocation.tool_name,
"query_memory 需要提供 query或至少提供 time_start/time_end 中的一个。",
)
session_id = str(runtime.session_id or "").strip()
platform = str(chat_stream.platform or "").strip()
user_id = str(chat_stream.user_id or "").strip()
group_id = str(chat_stream.group_id or "").strip()
person_id, person_name = _resolve_person_id(
person_name=str(invocation.arguments.get("person_name") or ""),
platform=platform,
user_id=user_id,
group_id=group_id,
)
respect_filter = bool(invocation.arguments.get("respect_filter", True))
logger.info(
f"{runtime.log_prefix} 触发长期记忆检索工具: "
f"mode={mode} query={clean_query!r} person_name={person_name!r} person_id={person_id!r}"
)
try:
result = await memory_service.search(
clean_query,
limit=limit,
mode=mode,
chat_id=session_id,
person_id=person_id,
time_start=time_start,
time_end=time_end,
respect_filter=respect_filter,
user_id=user_id,
group_id=group_id,
)
except Exception as exc:
logger.exception(f"{runtime.log_prefix} 长期记忆检索执行异常: {exc}")
return tool_ctx.build_failure_result(
invocation.tool_name,
f"长期记忆检索失败:{exc}",
)
structured_content: Dict[str, Any] = result.to_dict()
structured_content.update(
{
"query": clean_query,
"mode": mode,
"limit": limit,
"chat_id": session_id,
"person_name": person_name,
"person_id": person_id,
"time_start": time_start,
"time_end": time_end,
"respect_filter": respect_filter,
"user_id": user_id,
"group_id": group_id,
}
)
if not result.success:
error_message = str(result.error or "").strip() or "长期记忆检索失败。"
return tool_ctx.build_failure_result(
invocation.tool_name,
error_message,
structured_content=structured_content,
)
content = _build_success_content(result, limit=limit)
if clean_query:
display_prompt = f"你查询了长期记忆:{clean_query}"
else:
display_prompt = "你按时间范围查询了长期记忆。"
return tool_ctx.build_success_result(
invocation.tool_name,
content,
structured_content=structured_content,
metadata={"record_display_prompt": display_prompt},
)