254 lines
8.8 KiB
Python
254 lines
8.8 KiB
Python
"""query_memory 内置工具。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from typing import Any, Dict, Optional, Tuple
|
||
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config
|
||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||
from src.person_info.person_info import resolve_person_id_for_memory
|
||
from src.services.memory_service import MemorySearchResult, memory_service
|
||
|
||
from .context import BuiltinToolRuntimeContext
|
||
|
||
logger = get_logger("maisaka_builtin_query_memory")
|
||
|
||
_ALLOWED_QUERY_MODES = {"search", "time", "hybrid", "episode", "aggregate"}
|
||
|
||
|
||
def get_tool_spec(*, enabled: bool = True) -> ToolSpec:
|
||
"""获取 query_memory 工具声明。"""
|
||
|
||
return ToolSpec(
|
||
name="query_memory",
|
||
brief_description="检索 A_memorix 长期记忆并返回可读结果。",
|
||
detailed_description=(
|
||
"参数说明:\n"
|
||
"- query:string,可选。要检索的关键词或问题。\n"
|
||
"- limit:integer,可选。返回条数,默认使用系统配置值。\n"
|
||
"- mode:string,可选。search/time/hybrid/episode/aggregate。\n"
|
||
"- person_name:string,可选。人物名,优先用于解析并过滤 person_id。\n"
|
||
"- time_start:string,可选。起始时间,可填写时间戳或可解析时间文本。\n"
|
||
"- time_end:string,可选。结束时间,可填写时间戳或可解析时间文本。\n"
|
||
"- respect_filter:boolean,可选。是否应用聊天过滤配置,默认 true。"
|
||
),
|
||
parameters_schema={
|
||
"type": "object",
|
||
"properties": {
|
||
"query": {
|
||
"type": "string",
|
||
"description": "要检索的关键词或问题。",
|
||
},
|
||
"limit": {
|
||
"type": "integer",
|
||
"description": "返回条数,默认使用系统配置值。",
|
||
},
|
||
"mode": {
|
||
"type": "string",
|
||
"description": "检索模式:search/time/hybrid/episode/aggregate。",
|
||
"enum": sorted(_ALLOWED_QUERY_MODES),
|
||
"default": "search",
|
||
},
|
||
"person_name": {
|
||
"type": "string",
|
||
"description": "人物名称,可选。提供后优先按人物过滤。",
|
||
},
|
||
"time_start": {
|
||
"type": "string",
|
||
"description": "起始时间,可填写时间戳或可解析时间文本。",
|
||
},
|
||
"time_end": {
|
||
"type": "string",
|
||
"description": "结束时间,可填写时间戳或可解析时间文本。",
|
||
},
|
||
"respect_filter": {
|
||
"type": "boolean",
|
||
"description": "是否应用聊天过滤配置。",
|
||
"default": True,
|
||
},
|
||
},
|
||
},
|
||
provider_name="maisaka_builtin",
|
||
provider_type="builtin",
|
||
enabled=enabled,
|
||
)
|
||
|
||
|
||
def _normalize_optional_time(raw_value: Any) -> str | float | None:
|
||
"""归一化可选时间参数。"""
|
||
|
||
if raw_value is None:
|
||
return None
|
||
if isinstance(raw_value, str):
|
||
time_text = raw_value.strip()
|
||
if not time_text:
|
||
return None
|
||
return time_text
|
||
if isinstance(raw_value, (float, int)):
|
||
return float(raw_value)
|
||
|
||
time_text = str(raw_value).strip()
|
||
if not time_text:
|
||
return None
|
||
return time_text
|
||
|
||
|
||
def _resolve_person_id(
|
||
*,
|
||
person_name: str,
|
||
platform: str,
|
||
user_id: str,
|
||
group_id: str,
|
||
) -> Tuple[str, str]:
|
||
"""按约定顺序解析长期记忆检索使用的 person_id。"""
|
||
|
||
clean_person_name = str(person_name or "").strip()
|
||
if clean_person_name:
|
||
person_id = resolve_person_id_for_memory(
|
||
person_name=clean_person_name,
|
||
platform=platform,
|
||
user_id=user_id,
|
||
)
|
||
if person_id:
|
||
return person_id, clean_person_name
|
||
|
||
if not group_id and platform and user_id:
|
||
person_id = resolve_person_id_for_memory(
|
||
platform=platform,
|
||
user_id=user_id,
|
||
)
|
||
if person_id:
|
||
return person_id, clean_person_name
|
||
|
||
return "", clean_person_name
|
||
|
||
|
||
def _build_success_content(result: MemorySearchResult, *, limit: int) -> str:
|
||
"""构造工具成功时的可读内容。"""
|
||
|
||
summary = str(result.summary or "").strip()
|
||
snippet = result.to_text(limit=max(1, int(limit)))
|
||
|
||
if result.hits:
|
||
if summary and snippet:
|
||
return f"{summary}\n{snippet}"
|
||
if summary:
|
||
return summary
|
||
if snippet:
|
||
return snippet
|
||
return "已找到匹配的长期记忆。"
|
||
|
||
if result.filtered:
|
||
return "当前请求被聊天过滤策略跳过,未执行长期记忆检索。"
|
||
return "未找到匹配的长期记忆。"
|
||
|
||
|
||
async def handle_tool(
|
||
tool_ctx: BuiltinToolRuntimeContext,
|
||
invocation: ToolInvocation,
|
||
context: Optional[ToolExecutionContext] = None,
|
||
) -> ToolExecutionResult:
|
||
"""执行 query_memory 内置工具。"""
|
||
|
||
del context
|
||
runtime = tool_ctx.runtime
|
||
chat_stream = runtime.chat_stream
|
||
|
||
clean_query = str(invocation.arguments.get("query") or "").strip()
|
||
mode = str(invocation.arguments.get("mode") or "search").strip().lower() or "search"
|
||
if mode not in _ALLOWED_QUERY_MODES:
|
||
return tool_ctx.build_failure_result(
|
||
invocation.tool_name,
|
||
f"不支持的检索模式:{mode}。可选值:search/time/hybrid/episode/aggregate。",
|
||
)
|
||
|
||
default_limit = max(1, int(getattr(global_config.maisaka, "memory_query_default_limit", 5) or 5))
|
||
try:
|
||
limit = int(invocation.arguments.get("limit", default_limit) or default_limit)
|
||
except (TypeError, ValueError):
|
||
limit = default_limit
|
||
limit = max(1, min(limit, 20))
|
||
|
||
time_start = _normalize_optional_time(invocation.arguments.get("time_start"))
|
||
time_end = _normalize_optional_time(invocation.arguments.get("time_end"))
|
||
if not clean_query and time_start is None and time_end is None:
|
||
return tool_ctx.build_failure_result(
|
||
invocation.tool_name,
|
||
"query_memory 需要提供 query,或至少提供 time_start/time_end 中的一个。",
|
||
)
|
||
|
||
session_id = str(runtime.session_id or "").strip()
|
||
platform = str(chat_stream.platform or "").strip()
|
||
user_id = str(chat_stream.user_id or "").strip()
|
||
group_id = str(chat_stream.group_id or "").strip()
|
||
person_id, person_name = _resolve_person_id(
|
||
person_name=str(invocation.arguments.get("person_name") or ""),
|
||
platform=platform,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
)
|
||
respect_filter = bool(invocation.arguments.get("respect_filter", True))
|
||
|
||
logger.info(
|
||
f"{runtime.log_prefix} 触发长期记忆检索工具: "
|
||
f"mode={mode} query={clean_query!r} person_name={person_name!r} person_id={person_id!r}"
|
||
)
|
||
try:
|
||
result = await memory_service.search(
|
||
clean_query,
|
||
limit=limit,
|
||
mode=mode,
|
||
chat_id=session_id,
|
||
person_id=person_id,
|
||
time_start=time_start,
|
||
time_end=time_end,
|
||
respect_filter=respect_filter,
|
||
user_id=user_id,
|
||
group_id=group_id,
|
||
)
|
||
except Exception as exc:
|
||
logger.exception(f"{runtime.log_prefix} 长期记忆检索执行异常: {exc}")
|
||
return tool_ctx.build_failure_result(
|
||
invocation.tool_name,
|
||
f"长期记忆检索失败:{exc}",
|
||
)
|
||
|
||
structured_content: Dict[str, Any] = result.to_dict()
|
||
structured_content.update(
|
||
{
|
||
"query": clean_query,
|
||
"mode": mode,
|
||
"limit": limit,
|
||
"chat_id": session_id,
|
||
"person_name": person_name,
|
||
"person_id": person_id,
|
||
"time_start": time_start,
|
||
"time_end": time_end,
|
||
"respect_filter": respect_filter,
|
||
"user_id": user_id,
|
||
"group_id": group_id,
|
||
}
|
||
)
|
||
|
||
if not result.success:
|
||
error_message = str(result.error or "").strip() or "长期记忆检索失败。"
|
||
return tool_ctx.build_failure_result(
|
||
invocation.tool_name,
|
||
error_message,
|
||
structured_content=structured_content,
|
||
)
|
||
|
||
content = _build_success_content(result, limit=limit)
|
||
if clean_query:
|
||
display_prompt = f"你查询了长期记忆:{clean_query}"
|
||
else:
|
||
display_prompt = "你按时间范围查询了长期记忆。"
|
||
|
||
return tool_ctx.build_success_result(
|
||
invocation.tool_name,
|
||
content,
|
||
structured_content=structured_content,
|
||
metadata={"record_display_prompt": display_prompt},
|
||
)
|