feat:将A_memorix接入maisaka
This commit is contained in:
241
pytests/test_maisaka_builtin_query_memory.py
Normal file
241
pytests/test_maisaka_builtin_query_memory.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.core.tooling import ToolInvocation
|
||||||
|
from src.maisaka.builtin_tool import query_memory as query_memory_tool
|
||||||
|
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
|
||||||
|
from src.services.memory_service import MemoryHit, MemorySearchResult
|
||||||
|
|
||||||
|
|
||||||
|
def _build_tool_ctx(
|
||||||
|
*,
|
||||||
|
session_id: str = "session-1",
|
||||||
|
platform: str = "qq",
|
||||||
|
user_id: str = "user-1",
|
||||||
|
group_id: str = "",
|
||||||
|
) -> BuiltinToolRuntimeContext:
|
||||||
|
runtime = SimpleNamespace(
|
||||||
|
session_id=session_id,
|
||||||
|
chat_stream=SimpleNamespace(
|
||||||
|
platform=platform,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
),
|
||||||
|
log_prefix=f"[{session_id}]",
|
||||||
|
)
|
||||||
|
return BuiltinToolRuntimeContext(engine=SimpleNamespace(), runtime=runtime)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_invocation(arguments: Dict[str, Any]) -> ToolInvocation:
|
||||||
|
return ToolInvocation(
|
||||||
|
tool_name="query_memory",
|
||||||
|
arguments=dict(arguments),
|
||||||
|
call_id="call-query-memory",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setattr(
|
||||||
|
query_memory_tool,
|
||||||
|
"global_config",
|
||||||
|
SimpleNamespace(maisaka=SimpleNamespace(memory_query_default_limit=5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_memory_rejects_empty_query_and_time(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
||||||
|
_ = query
|
||||||
|
_ = kwargs
|
||||||
|
raise AssertionError("参数校验失败时不应调用 memory_service.search")
|
||||||
|
|
||||||
|
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
||||||
|
|
||||||
|
result = await query_memory_tool.handle_tool(
|
||||||
|
_build_tool_ctx(),
|
||||||
|
_build_invocation({"query": "", "time_start": "", "time_end": ""}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "query_memory 需要提供 query" in result.error_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_memory_private_chat_auto_sets_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
captured: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_resolve_person_id_for_memory(
|
||||||
|
*,
|
||||||
|
person_name: str = "",
|
||||||
|
platform: str = "",
|
||||||
|
user_id: Any = None,
|
||||||
|
strict_known: bool = False,
|
||||||
|
) -> str:
|
||||||
|
_ = strict_known
|
||||||
|
captured["resolve_args"] = {
|
||||||
|
"person_name": person_name,
|
||||||
|
"platform": platform,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
return "pid-private-auto"
|
||||||
|
|
||||||
|
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
||||||
|
captured["query"] = query
|
||||||
|
captured["search_kwargs"] = dict(kwargs)
|
||||||
|
return MemorySearchResult(
|
||||||
|
summary="检索摘要",
|
||||||
|
hits=[MemoryHit(content="Alice 喜欢咖啡", score=0.91)],
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
||||||
|
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
||||||
|
|
||||||
|
result = await query_memory_tool.handle_tool(
|
||||||
|
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
|
||||||
|
_build_invocation({"query": "Alice 的喜好"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert captured["query"] == "Alice 的喜好"
|
||||||
|
assert captured["resolve_args"] == {
|
||||||
|
"person_name": "",
|
||||||
|
"platform": "qq",
|
||||||
|
"user_id": "alice",
|
||||||
|
}
|
||||||
|
assert captured["search_kwargs"]["chat_id"] == "private-session"
|
||||||
|
assert captured["search_kwargs"]["user_id"] == "alice"
|
||||||
|
assert captured["search_kwargs"]["group_id"] == ""
|
||||||
|
assert captured["search_kwargs"]["person_id"] == "pid-private-auto"
|
||||||
|
assert isinstance(result.structured_content, dict)
|
||||||
|
assert result.structured_content["person_id"] == "pid-private-auto"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_memory_group_chat_does_not_attach_default_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
call_counter = {"resolve": 0}
|
||||||
|
captured_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_resolve_person_id_for_memory(
|
||||||
|
*,
|
||||||
|
person_name: str = "",
|
||||||
|
platform: str = "",
|
||||||
|
user_id: Any = None,
|
||||||
|
strict_known: bool = False,
|
||||||
|
) -> str:
|
||||||
|
_ = person_name
|
||||||
|
_ = platform
|
||||||
|
_ = user_id
|
||||||
|
_ = strict_known
|
||||||
|
call_counter["resolve"] += 1
|
||||||
|
return "unexpected-person-id"
|
||||||
|
|
||||||
|
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
||||||
|
_ = query
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
return MemorySearchResult(summary="", hits=[])
|
||||||
|
|
||||||
|
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
||||||
|
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
||||||
|
|
||||||
|
result = await query_memory_tool.handle_tool(
|
||||||
|
_build_tool_ctx(session_id="group-session", platform="qq", user_id="alice", group_id="group-1"),
|
||||||
|
_build_invocation({"query": "群聊上下文"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert call_counter["resolve"] == 0
|
||||||
|
assert captured_kwargs["chat_id"] == "group-session"
|
||||||
|
assert captured_kwargs["group_id"] == "group-1"
|
||||||
|
assert captured_kwargs["person_id"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_memory_search_failure_is_returned(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
||||||
|
_ = query
|
||||||
|
_ = kwargs
|
||||||
|
return MemorySearchResult(success=False, error="boom")
|
||||||
|
|
||||||
|
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
||||||
|
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
|
||||||
|
|
||||||
|
result = await query_memory_tool.handle_tool(
|
||||||
|
_build_tool_ctx(),
|
||||||
|
_build_invocation({"query": "测试失败透传"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.error_message == "boom"
|
||||||
|
assert isinstance(result.structured_content, dict)
|
||||||
|
assert result.structured_content["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_memory_prefers_person_name_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
captured: Dict[str, Any] = {"resolve_calls": []}
|
||||||
|
|
||||||
|
def fake_resolve_person_id_for_memory(
|
||||||
|
*,
|
||||||
|
person_name: str = "",
|
||||||
|
platform: str = "",
|
||||||
|
user_id: Any = None,
|
||||||
|
strict_known: bool = False,
|
||||||
|
) -> str:
|
||||||
|
_ = strict_known
|
||||||
|
captured["resolve_calls"].append(
|
||||||
|
{
|
||||||
|
"person_name": person_name,
|
||||||
|
"platform": platform,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if person_name:
|
||||||
|
return "pid-by-name"
|
||||||
|
return "pid-private-auto"
|
||||||
|
|
||||||
|
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
||||||
|
_ = query
|
||||||
|
captured["search_kwargs"] = dict(kwargs)
|
||||||
|
return MemorySearchResult(summary="", hits=[MemoryHit(content="命中1")])
|
||||||
|
|
||||||
|
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
|
||||||
|
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
||||||
|
|
||||||
|
result = await query_memory_tool.handle_tool(
|
||||||
|
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
|
||||||
|
_build_invocation({"query": "小明资料", "person_name": "小明"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert captured["resolve_calls"][0] == {
|
||||||
|
"person_name": "小明",
|
||||||
|
"platform": "qq",
|
||||||
|
"user_id": "alice",
|
||||||
|
}
|
||||||
|
assert captured["search_kwargs"]["person_id"] == "pid-by-name"
|
||||||
|
assert result.structured_content["person_name"] == "小明"
|
||||||
|
assert result.structured_content["person_id"] == "pid-by-name"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_memory_no_hit_returns_readable_message(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
|
||||||
|
_ = query
|
||||||
|
_ = kwargs
|
||||||
|
return MemorySearchResult(summary="", hits=[])
|
||||||
|
|
||||||
|
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
|
||||||
|
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
|
||||||
|
|
||||||
|
result = await query_memory_tool.handle_tool(
|
||||||
|
_build_tool_ctx(),
|
||||||
|
_build_invocation({"query": "不存在的记忆"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert "未找到匹配的长期记忆" in result.content
|
||||||
|
assert isinstance(result.structured_content, dict)
|
||||||
|
assert result.structured_content["query"] == "不存在的记忆"
|
||||||
@@ -99,6 +99,25 @@ def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
|
|||||||
assert mcp_schema.get("uiParent") == "maisaka"
|
assert mcp_schema.get("uiParent") == "maisaka"
|
||||||
|
|
||||||
|
|
||||||
|
def test_maisaka_memory_query_config_fields_are_exposed():
|
||||||
|
"""MaiSaka 长期记忆检索开关和默认条数应出现在配置 schema 中。"""
|
||||||
|
schema = ConfigSchemaGenerator.generate_schema(Config)
|
||||||
|
maisaka_schema = schema["nested"]["maisaka"]
|
||||||
|
|
||||||
|
enable_field = next(field for field in maisaka_schema["fields"] if field["name"] == "enable_memory_query_tool")
|
||||||
|
limit_field = next(field for field in maisaka_schema["fields"] if field["name"] == "memory_query_default_limit")
|
||||||
|
|
||||||
|
assert enable_field["type"] == "boolean"
|
||||||
|
assert enable_field.get("x-widget") == "switch"
|
||||||
|
assert enable_field.get("x-icon") == "database"
|
||||||
|
|
||||||
|
assert limit_field["type"] == "integer"
|
||||||
|
assert limit_field.get("x-widget") == "input"
|
||||||
|
assert limit_field.get("x-icon") == "hash"
|
||||||
|
assert limit_field.get("minValue") == 1
|
||||||
|
assert limit_field.get("maxValue") == 20
|
||||||
|
|
||||||
|
|
||||||
def test_set_field_is_mapped_as_array():
|
def test_set_field_is_mapped_as_array():
|
||||||
"""set[str] 应映射为前端可识别的 array。"""
|
"""set[str] 应映射为前端可识别的 array。"""
|
||||||
schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig)
|
schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig)
|
||||||
|
|||||||
@@ -1525,6 +1525,26 @@ class MaiSakaConfig(ConfigBase):
|
|||||||
)
|
)
|
||||||
"""每个入站消息的最大内部规划轮数"""
|
"""每个入站消息的最大内部规划轮数"""
|
||||||
|
|
||||||
|
enable_memory_query_tool: bool = Field(
|
||||||
|
default=True,
|
||||||
|
json_schema_extra={
|
||||||
|
"x-widget": "switch",
|
||||||
|
"x-icon": "database",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
"""是否启用 Maisaka 内置长期记忆检索工具 query_memory"""
|
||||||
|
|
||||||
|
memory_query_default_limit: int = Field(
|
||||||
|
default=5,
|
||||||
|
ge=1,
|
||||||
|
le=20,
|
||||||
|
json_schema_extra={
|
||||||
|
"x-widget": "input",
|
||||||
|
"x-icon": "hash",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
"""Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""
|
||||||
|
|
||||||
tool_filter_task_name: str = Field(
|
tool_filter_task_name: str = Field(
|
||||||
default="utils",
|
default="utils",
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from src.config.config import global_config
|
||||||
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||||
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
|
||||||
|
|
||||||
@@ -11,6 +12,8 @@ from .no_reply import get_tool_spec as get_no_reply_tool_spec
|
|||||||
from .no_reply import handle_tool as handle_no_reply_tool
|
from .no_reply import handle_tool as handle_no_reply_tool
|
||||||
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
|
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
|
||||||
from .query_jargon import handle_tool as handle_query_jargon_tool
|
from .query_jargon import handle_tool as handle_query_jargon_tool
|
||||||
|
from .query_memory import get_tool_spec as get_query_memory_tool_spec
|
||||||
|
from .query_memory import handle_tool as handle_query_memory_tool
|
||||||
from .query_person_info import get_tool_spec as get_query_person_info_tool_spec
|
from .query_person_info import get_tool_spec as get_query_person_info_tool_spec
|
||||||
from .query_person_info import handle_tool as handle_query_person_info_tool
|
from .query_person_info import handle_tool as handle_query_person_info_tool
|
||||||
from .reply import get_tool_spec as get_reply_tool_spec
|
from .reply import get_tool_spec as get_reply_tool_spec
|
||||||
@@ -33,6 +36,7 @@ def get_builtin_tool_specs() -> List[ToolSpec]:
|
|||||||
get_reply_tool_spec(),
|
get_reply_tool_spec(),
|
||||||
get_view_complex_message_tool_spec(),
|
get_view_complex_message_tool_spec(),
|
||||||
get_query_jargon_tool_spec(),
|
get_query_jargon_tool_spec(),
|
||||||
|
get_query_memory_tool_spec(enabled=bool(global_config.maisaka.enable_memory_query_tool)),
|
||||||
get_no_reply_tool_spec(),
|
get_no_reply_tool_spec(),
|
||||||
get_send_emoji_tool_spec(),
|
get_send_emoji_tool_spec(),
|
||||||
]
|
]
|
||||||
@@ -46,6 +50,7 @@ def get_all_builtin_tool_specs() -> List[ToolSpec]:
|
|||||||
get_reply_tool_spec(),
|
get_reply_tool_spec(),
|
||||||
get_view_complex_message_tool_spec(),
|
get_view_complex_message_tool_spec(),
|
||||||
get_query_jargon_tool_spec(),
|
get_query_jargon_tool_spec(),
|
||||||
|
get_query_memory_tool_spec(enabled=True),
|
||||||
get_query_person_info_tool_spec(),
|
get_query_person_info_tool_spec(),
|
||||||
get_no_reply_tool_spec(),
|
get_no_reply_tool_spec(),
|
||||||
get_send_emoji_tool_spec(),
|
get_send_emoji_tool_spec(),
|
||||||
@@ -65,6 +70,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
|
|||||||
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
|
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
|
||||||
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
|
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
|
||||||
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
|
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
|
||||||
|
"query_memory": lambda invocation, context=None: handle_query_memory_tool(tool_ctx, invocation, context),
|
||||||
"query_person_info": lambda invocation, context=None: handle_query_person_info_tool(
|
"query_person_info": lambda invocation, context=None: handle_query_person_info_tool(
|
||||||
tool_ctx,
|
tool_ctx,
|
||||||
invocation,
|
invocation,
|
||||||
|
|||||||
253
src/maisaka/builtin_tool/query_memory.py
Normal file
253
src/maisaka/builtin_tool/query_memory.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
"""query_memory 内置工具。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
|
||||||
|
from src.person_info.person_info import resolve_person_id_for_memory
|
||||||
|
from src.services.memory_service import MemorySearchResult, memory_service
|
||||||
|
|
||||||
|
from .context import BuiltinToolRuntimeContext
|
||||||
|
|
||||||
|
logger = get_logger("maisaka_builtin_query_memory")
|
||||||
|
|
||||||
|
_ALLOWED_QUERY_MODES = {"search", "time", "hybrid", "episode", "aggregate"}
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_spec(*, enabled: bool = True) -> ToolSpec:
|
||||||
|
"""获取 query_memory 工具声明。"""
|
||||||
|
|
||||||
|
return ToolSpec(
|
||||||
|
name="query_memory",
|
||||||
|
brief_description="检索 A_memorix 长期记忆并返回可读结果。",
|
||||||
|
detailed_description=(
|
||||||
|
"参数说明:\n"
|
||||||
|
"- query:string,可选。要检索的关键词或问题。\n"
|
||||||
|
"- limit:integer,可选。返回条数,默认使用系统配置值。\n"
|
||||||
|
"- mode:string,可选。search/time/hybrid/episode/aggregate。\n"
|
||||||
|
"- person_name:string,可选。人物名,优先用于解析并过滤 person_id。\n"
|
||||||
|
"- time_start:string,可选。起始时间,可填写时间戳或可解析时间文本。\n"
|
||||||
|
"- time_end:string,可选。结束时间,可填写时间戳或可解析时间文本。\n"
|
||||||
|
"- respect_filter:boolean,可选。是否应用聊天过滤配置,默认 true。"
|
||||||
|
),
|
||||||
|
parameters_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要检索的关键词或问题。",
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "返回条数,默认使用系统配置值。",
|
||||||
|
},
|
||||||
|
"mode": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "检索模式:search/time/hybrid/episode/aggregate。",
|
||||||
|
"enum": sorted(_ALLOWED_QUERY_MODES),
|
||||||
|
"default": "search",
|
||||||
|
},
|
||||||
|
"person_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "人物名称,可选。提供后优先按人物过滤。",
|
||||||
|
},
|
||||||
|
"time_start": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "起始时间,可填写时间戳或可解析时间文本。",
|
||||||
|
},
|
||||||
|
"time_end": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "结束时间,可填写时间戳或可解析时间文本。",
|
||||||
|
},
|
||||||
|
"respect_filter": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否应用聊天过滤配置。",
|
||||||
|
"default": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
provider_name="maisaka_builtin",
|
||||||
|
provider_type="builtin",
|
||||||
|
enabled=enabled,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_optional_time(raw_value: Any) -> str | float | None:
|
||||||
|
"""归一化可选时间参数。"""
|
||||||
|
|
||||||
|
if raw_value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(raw_value, str):
|
||||||
|
time_text = raw_value.strip()
|
||||||
|
if not time_text:
|
||||||
|
return None
|
||||||
|
return time_text
|
||||||
|
if isinstance(raw_value, (float, int)):
|
||||||
|
return float(raw_value)
|
||||||
|
|
||||||
|
time_text = str(raw_value).strip()
|
||||||
|
if not time_text:
|
||||||
|
return None
|
||||||
|
return time_text
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_person_id(
|
||||||
|
*,
|
||||||
|
person_name: str,
|
||||||
|
platform: str,
|
||||||
|
user_id: str,
|
||||||
|
group_id: str,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""按约定顺序解析长期记忆检索使用的 person_id。"""
|
||||||
|
|
||||||
|
clean_person_name = str(person_name or "").strip()
|
||||||
|
if clean_person_name:
|
||||||
|
person_id = resolve_person_id_for_memory(
|
||||||
|
person_name=clean_person_name,
|
||||||
|
platform=platform,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if person_id:
|
||||||
|
return person_id, clean_person_name
|
||||||
|
|
||||||
|
if not group_id and platform and user_id:
|
||||||
|
person_id = resolve_person_id_for_memory(
|
||||||
|
platform=platform,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
if person_id:
|
||||||
|
return person_id, clean_person_name
|
||||||
|
|
||||||
|
return "", clean_person_name
|
||||||
|
|
||||||
|
|
||||||
|
def _build_success_content(result: MemorySearchResult, *, limit: int) -> str:
|
||||||
|
"""构造工具成功时的可读内容。"""
|
||||||
|
|
||||||
|
summary = str(result.summary or "").strip()
|
||||||
|
snippet = result.to_text(limit=max(1, int(limit)))
|
||||||
|
|
||||||
|
if result.hits:
|
||||||
|
if summary and snippet:
|
||||||
|
return f"{summary}\n{snippet}"
|
||||||
|
if summary:
|
||||||
|
return summary
|
||||||
|
if snippet:
|
||||||
|
return snippet
|
||||||
|
return "已找到匹配的长期记忆。"
|
||||||
|
|
||||||
|
if result.filtered:
|
||||||
|
return "当前请求被聊天过滤策略跳过,未执行长期记忆检索。"
|
||||||
|
return "未找到匹配的长期记忆。"
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_tool(
|
||||||
|
tool_ctx: BuiltinToolRuntimeContext,
|
||||||
|
invocation: ToolInvocation,
|
||||||
|
context: Optional[ToolExecutionContext] = None,
|
||||||
|
) -> ToolExecutionResult:
|
||||||
|
"""执行 query_memory 内置工具。"""
|
||||||
|
|
||||||
|
del context
|
||||||
|
runtime = tool_ctx.runtime
|
||||||
|
chat_stream = runtime.chat_stream
|
||||||
|
|
||||||
|
clean_query = str(invocation.arguments.get("query") or "").strip()
|
||||||
|
mode = str(invocation.arguments.get("mode") or "search").strip().lower() or "search"
|
||||||
|
if mode not in _ALLOWED_QUERY_MODES:
|
||||||
|
return tool_ctx.build_failure_result(
|
||||||
|
invocation.tool_name,
|
||||||
|
f"不支持的检索模式:{mode}。可选值:search/time/hybrid/episode/aggregate。",
|
||||||
|
)
|
||||||
|
|
||||||
|
default_limit = max(1, int(getattr(global_config.maisaka, "memory_query_default_limit", 5) or 5))
|
||||||
|
try:
|
||||||
|
limit = int(invocation.arguments.get("limit", default_limit) or default_limit)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
limit = default_limit
|
||||||
|
limit = max(1, min(limit, 20))
|
||||||
|
|
||||||
|
time_start = _normalize_optional_time(invocation.arguments.get("time_start"))
|
||||||
|
time_end = _normalize_optional_time(invocation.arguments.get("time_end"))
|
||||||
|
if not clean_query and time_start is None and time_end is None:
|
||||||
|
return tool_ctx.build_failure_result(
|
||||||
|
invocation.tool_name,
|
||||||
|
"query_memory 需要提供 query,或至少提供 time_start/time_end 中的一个。",
|
||||||
|
)
|
||||||
|
|
||||||
|
session_id = str(runtime.session_id or "").strip()
|
||||||
|
platform = str(chat_stream.platform or "").strip()
|
||||||
|
user_id = str(chat_stream.user_id or "").strip()
|
||||||
|
group_id = str(chat_stream.group_id or "").strip()
|
||||||
|
person_id, person_name = _resolve_person_id(
|
||||||
|
person_name=str(invocation.arguments.get("person_name") or ""),
|
||||||
|
platform=platform,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
respect_filter = bool(invocation.arguments.get("respect_filter", True))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"{runtime.log_prefix} 触发长期记忆检索工具: "
|
||||||
|
f"mode={mode} query={clean_query!r} person_name={person_name!r} person_id={person_id!r}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = await memory_service.search(
|
||||||
|
clean_query,
|
||||||
|
limit=limit,
|
||||||
|
mode=mode,
|
||||||
|
chat_id=session_id,
|
||||||
|
person_id=person_id,
|
||||||
|
time_start=time_start,
|
||||||
|
time_end=time_end,
|
||||||
|
respect_filter=respect_filter,
|
||||||
|
user_id=user_id,
|
||||||
|
group_id=group_id,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception(f"{runtime.log_prefix} 长期记忆检索执行异常: {exc}")
|
||||||
|
return tool_ctx.build_failure_result(
|
||||||
|
invocation.tool_name,
|
||||||
|
f"长期记忆检索失败:{exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
structured_content: Dict[str, Any] = result.to_dict()
|
||||||
|
structured_content.update(
|
||||||
|
{
|
||||||
|
"query": clean_query,
|
||||||
|
"mode": mode,
|
||||||
|
"limit": limit,
|
||||||
|
"chat_id": session_id,
|
||||||
|
"person_name": person_name,
|
||||||
|
"person_id": person_id,
|
||||||
|
"time_start": time_start,
|
||||||
|
"time_end": time_end,
|
||||||
|
"respect_filter": respect_filter,
|
||||||
|
"user_id": user_id,
|
||||||
|
"group_id": group_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
error_message = str(result.error or "").strip() or "长期记忆检索失败。"
|
||||||
|
return tool_ctx.build_failure_result(
|
||||||
|
invocation.tool_name,
|
||||||
|
error_message,
|
||||||
|
structured_content=structured_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
content = _build_success_content(result, limit=limit)
|
||||||
|
if clean_query:
|
||||||
|
display_prompt = f"你查询了长期记忆:{clean_query}"
|
||||||
|
else:
|
||||||
|
display_prompt = "你按时间范围查询了长期记忆。"
|
||||||
|
|
||||||
|
return tool_ctx.build_success_result(
|
||||||
|
invocation.tool_name,
|
||||||
|
content,
|
||||||
|
structured_content=structured_content,
|
||||||
|
metadata={"record_display_prompt": display_prompt},
|
||||||
|
)
|
||||||
@@ -637,6 +637,15 @@ class MaisakaReasoningEngine:
|
|||||||
return f"你查询了人物信息:{person_name}"
|
return f"你查询了人物信息:{person_name}"
|
||||||
return "你查询了一次人物信息。"
|
return "你查询了一次人物信息。"
|
||||||
|
|
||||||
|
if invocation.tool_name == "query_memory":
|
||||||
|
query_text = str(invocation.arguments.get("query") or "").strip()
|
||||||
|
mode = str(invocation.arguments.get("mode") or "search").strip() or "search"
|
||||||
|
hit_items = structured_content.get("hits")
|
||||||
|
hit_count = len(hit_items) if isinstance(hit_items, list) else 0
|
||||||
|
if query_text:
|
||||||
|
return f"你查询了长期记忆:{query_text}(模式:{mode},命中 {hit_count} 条)"
|
||||||
|
return f"你按时间范围查询了一次长期记忆(模式:{mode},命中 {hit_count} 条)。"
|
||||||
|
|
||||||
if invocation.tool_name == "view_complex_message":
|
if invocation.tool_name == "view_complex_message":
|
||||||
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
|
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
|
||||||
if target_message_id:
|
if target_message_id:
|
||||||
|
|||||||
Reference in New Issue
Block a user