def _build_tool_ctx(
    *,
    session_id: str = "session-1",
    platform: str = "qq",
    user_id: str = "user-1",
    group_id: str = "",
) -> BuiltinToolRuntimeContext:
    """Assemble a tool runtime context backed by plain namespace stand-ins."""
    stream = SimpleNamespace(platform=platform, user_id=user_id, group_id=group_id)
    fake_runtime = SimpleNamespace(
        session_id=session_id,
        chat_stream=stream,
        log_prefix=f"[{session_id}]",
    )
    return BuiltinToolRuntimeContext(engine=SimpleNamespace(), runtime=fake_runtime)


def _build_invocation(arguments: Dict[str, Any]) -> ToolInvocation:
    """Wrap a copy of *arguments* in a ``query_memory`` invocation."""
    payload = dict(arguments)
    return ToolInvocation(tool_name="query_memory", arguments=payload, call_id="call-query-memory")


def _recording_resolver(calls: list, *, named_id: str, fallback_id: str):
    """Return a person-id resolver fake that logs every call into *calls*.

    The fake answers ``named_id`` when a person name is supplied and
    ``fallback_id`` otherwise.
    """

    def _resolve(
        *,
        person_name: str = "",
        platform: str = "",
        user_id: Any = None,
        strict_known: bool = False,
    ) -> str:
        del strict_known
        calls.append({"person_name": person_name, "platform": platform, "user_id": user_id})
        return named_id if person_name else fallback_id

    return _resolve


def _canned_search(outcome: MemorySearchResult, recorder: Dict[str, Any] | None = None):
    """Return an async search fake that yields *outcome* and optionally records its input."""

    async def _search(query: str, **kwargs: Any) -> MemorySearchResult:
        if recorder is not None:
            recorder["query"] = query
            recorder["search_kwargs"] = dict(kwargs)
        return outcome

    return _search


def _resolve_nothing(**kwargs: Any) -> str:
    """Resolver fake that never finds a person."""
    del kwargs
    return ""


@pytest.fixture(autouse=True)
def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None:
    """Pin the tool module's config to a default limit of 5 for every test."""
    stub_config = SimpleNamespace(maisaka=SimpleNamespace(memory_query_default_limit=5))
    monkeypatch.setattr(query_memory_tool, "global_config", stub_config)


@pytest.mark.asyncio
async def test_query_memory_rejects_empty_query_and_time(monkeypatch: pytest.MonkeyPatch) -> None:
    """A call with neither a query nor a time bound must fail before searching."""

    async def _forbidden_search(query: str, **kwargs: Any) -> MemorySearchResult:
        del query, kwargs
        raise AssertionError("参数校验失败时不应调用 memory_service.search")

    monkeypatch.setattr(query_memory_tool.memory_service, "search", _forbidden_search)

    blank_arguments = {"query": "", "time_start": "", "time_end": ""}
    result = await query_memory_tool.handle_tool(_build_tool_ctx(), _build_invocation(blank_arguments))

    assert result.success is False
    assert "query_memory 需要提供 query" in result.error_message


@pytest.mark.asyncio
async def test_query_memory_private_chat_auto_sets_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """In a private chat the peer's person id is resolved and forwarded to the search."""
    resolve_calls: list = []
    seen: Dict[str, Any] = {}
    outcome = MemorySearchResult(
        summary="检索摘要",
        hits=[MemoryHit(content="Alice 喜欢咖啡", score=0.91)],
    )

    monkeypatch.setattr(
        query_memory_tool,
        "resolve_person_id_for_memory",
        _recording_resolver(resolve_calls, named_id="pid-private-auto", fallback_id="pid-private-auto"),
    )
    monkeypatch.setattr(query_memory_tool.memory_service, "search", _canned_search(outcome, seen))

    tool_ctx = _build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id="")
    result = await query_memory_tool.handle_tool(tool_ctx, _build_invocation({"query": "Alice 的喜好"}))

    assert result.success is True
    assert seen["query"] == "Alice 的喜好"
    assert resolve_calls[-1] == {
        "person_name": "",
        "platform": "qq",
        "user_id": "alice",
    }
    forwarded = seen["search_kwargs"]
    assert forwarded["chat_id"] == "private-session"
    assert forwarded["user_id"] == "alice"
    assert forwarded["group_id"] == ""
    assert forwarded["person_id"] == "pid-private-auto"
    assert isinstance(result.structured_content, dict)
    assert result.structured_content["person_id"] == "pid-private-auto"


@pytest.mark.asyncio
async def test_query_memory_group_chat_does_not_attach_default_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
    """In a group chat without a person name the resolver must not be consulted."""
    resolve_calls: list = []
    seen: Dict[str, Any] = {}

    monkeypatch.setattr(
        query_memory_tool,
        "resolve_person_id_for_memory",
        _recording_resolver(resolve_calls, named_id="unexpected-person-id", fallback_id="unexpected-person-id"),
    )
    monkeypatch.setattr(
        query_memory_tool.memory_service,
        "search",
        _canned_search(MemorySearchResult(summary="", hits=[]), seen),
    )

    tool_ctx = _build_tool_ctx(session_id="group-session", platform="qq", user_id="alice", group_id="group-1")
    result = await query_memory_tool.handle_tool(tool_ctx, _build_invocation({"query": "群聊上下文"}))

    assert result.success is True
    assert len(resolve_calls) == 0
    forwarded = seen["search_kwargs"]
    assert forwarded["chat_id"] == "group-session"
    assert forwarded["group_id"] == "group-1"
    assert forwarded["person_id"] == ""


@pytest.mark.asyncio
async def test_query_memory_search_failure_is_returned(monkeypatch: pytest.MonkeyPatch) -> None:
    """A failed search result is passed through as a failed tool result."""
    failed = MemorySearchResult(success=False, error="boom")
    monkeypatch.setattr(query_memory_tool.memory_service, "search", _canned_search(failed))
    monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", _resolve_nothing)

    result = await query_memory_tool.handle_tool(
        _build_tool_ctx(),
        _build_invocation({"query": "测试失败透传"}),
    )

    assert result.success is False
    assert result.error_message == "boom"
    assert isinstance(result.structured_content, dict)
    assert result.structured_content["success"] is False


@pytest.mark.asyncio
async def test_query_memory_prefers_person_name_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
    """An explicit person name is resolved first and wins over the chat peer."""
    resolve_calls: list = []
    seen: Dict[str, Any] = {}
    outcome = MemorySearchResult(summary="", hits=[MemoryHit(content="命中1")])

    monkeypatch.setattr(
        query_memory_tool,
        "resolve_person_id_for_memory",
        _recording_resolver(resolve_calls, named_id="pid-by-name", fallback_id="pid-private-auto"),
    )
    monkeypatch.setattr(query_memory_tool.memory_service, "search", _canned_search(outcome, seen))

    tool_ctx = _build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id="")
    result = await query_memory_tool.handle_tool(
        tool_ctx,
        _build_invocation({"query": "小明资料", "person_name": "小明"}),
    )

    assert result.success is True
    assert resolve_calls[0] == {
        "person_name": "小明",
        "platform": "qq",
        "user_id": "alice",
    }
    assert seen["search_kwargs"]["person_id"] == "pid-by-name"
    assert result.structured_content["person_name"] == "小明"
    assert result.structured_content["person_id"] == "pid-by-name"


@pytest.mark.asyncio
async def test_query_memory_no_hit_returns_readable_message(monkeypatch: pytest.MonkeyPatch) -> None:
    """An empty hit list still succeeds and yields a human-readable message."""
    empty = MemorySearchResult(summary="", hits=[])
    monkeypatch.setattr(query_memory_tool.memory_service, "search", _canned_search(empty))
    monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", _resolve_nothing)

    result = await query_memory_tool.handle_tool(
        _build_tool_ctx(),
        _build_invocation({"query": "不存在的记忆"}),
    )

    assert result.success is True
    assert "未找到匹配的长期记忆" in result.content
    assert isinstance(result.structured_content, dict)
    assert result.structured_content["query"] == "不存在的记忆"
def test_maisaka_memory_query_config_fields_are_exposed():
    """The MaiSaka memory-query switch and default limit must appear in the config schema."""
    full_schema = ConfigSchemaGenerator.generate_schema(Config)
    maisaka_fields = full_schema["nested"]["maisaka"]["fields"]

    def first_field_named(wanted: str):
        # Same lookup as a bare ``next(...)``: first match wins, no default.
        return next(entry for entry in maisaka_fields if entry["name"] == wanted)

    enable_field = first_field_named("enable_memory_query_tool")
    limit_field = first_field_named("memory_query_default_limit")

    # Boolean switch for enabling the tool.
    assert enable_field["type"] == "boolean"
    assert enable_field.get("x-widget") == "switch"
    assert enable_field.get("x-icon") == "database"

    # Integer input bounded to 1..20.
    assert limit_field["type"] == "integer"
    assert limit_field.get("x-widget") == "input"
    assert limit_field.get("x-icon") == "hash"
    assert limit_field.get("minValue") == 1
    assert limit_field.get("maxValue") == 20
diff --git a/src/maisaka/builtin_tool/__init__.py b/src/maisaka/builtin_tool/__init__.py index 74d1672f..4ebf150a 100644 --- a/src/maisaka/builtin_tool/__init__.py +++ b/src/maisaka/builtin_tool/__init__.py @@ -3,6 +3,7 @@ from collections.abc import Awaitable, Callable from typing import Dict, List, Optional +from src.config.config import global_config from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec from src.llm_models.payload_content.tool_option import ToolDefinitionInput @@ -11,6 +12,8 @@ from .no_reply import get_tool_spec as get_no_reply_tool_spec from .no_reply import handle_tool as handle_no_reply_tool from .query_jargon import get_tool_spec as get_query_jargon_tool_spec from .query_jargon import handle_tool as handle_query_jargon_tool +from .query_memory import get_tool_spec as get_query_memory_tool_spec +from .query_memory import handle_tool as handle_query_memory_tool from .query_person_info import get_tool_spec as get_query_person_info_tool_spec from .query_person_info import handle_tool as handle_query_person_info_tool from .reply import get_tool_spec as get_reply_tool_spec @@ -33,6 +36,7 @@ def get_builtin_tool_specs() -> List[ToolSpec]: get_reply_tool_spec(), get_view_complex_message_tool_spec(), get_query_jargon_tool_spec(), + get_query_memory_tool_spec(enabled=bool(global_config.maisaka.enable_memory_query_tool)), get_no_reply_tool_spec(), get_send_emoji_tool_spec(), ] @@ -46,6 +50,7 @@ def get_all_builtin_tool_specs() -> List[ToolSpec]: get_reply_tool_spec(), get_view_complex_message_tool_spec(), get_query_jargon_tool_spec(), + get_query_memory_tool_spec(enabled=True), get_query_person_info_tool_spec(), get_no_reply_tool_spec(), get_send_emoji_tool_spec(), @@ -65,6 +70,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str "reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context), "no_reply": lambda invocation, context=None: 
_ALLOWED_QUERY_MODES = {"search", "time", "hybrid", "episode", "aggregate"}

# Hard upper bound on the number of hits a single call may request.
_MAX_QUERY_LIMIT = 20

# Textual spellings accepted for boolean tool arguments (compared lower-cased).
_TRUE_TEXTS = frozenset({"1", "true", "yes", "y", "on"})
_FALSE_TEXTS = frozenset({"0", "false", "no", "n", "off"})


def get_tool_spec(*, enabled: bool = True) -> ToolSpec:
    """Build the declaration of the ``query_memory`` builtin tool.

    Args:
        enabled: Value stored in the spec's ``enabled`` field.

    Returns:
        A ``ToolSpec`` describing the tool's name, descriptions and JSON
        parameter schema.
    """

    return ToolSpec(
        name="query_memory",
        brief_description="检索 A_memorix 长期记忆并返回可读结果。",
        detailed_description=(
            "参数说明:\n"
            "- query:string,可选。要检索的关键词或问题。\n"
            "- limit:integer,可选。返回条数,默认使用系统配置值。\n"
            "- mode:string,可选。search/time/hybrid/episode/aggregate。\n"
            "- person_name:string,可选。人物名,优先用于解析并过滤 person_id。\n"
            "- time_start:string,可选。起始时间,可填写时间戳或可解析时间文本。\n"
            "- time_end:string,可选。结束时间,可填写时间戳或可解析时间文本。\n"
            "- respect_filter:boolean,可选。是否应用聊天过滤配置,默认 true。"
        ),
        parameters_schema={
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "要检索的关键词或问题。",
                },
                "limit": {
                    "type": "integer",
                    "description": "返回条数,默认使用系统配置值。",
                },
                "mode": {
                    "type": "string",
                    "description": "检索模式:search/time/hybrid/episode/aggregate。",
                    "enum": sorted(_ALLOWED_QUERY_MODES),
                    "default": "search",
                },
                "person_name": {
                    "type": "string",
                    "description": "人物名称,可选。提供后优先按人物过滤。",
                },
                "time_start": {
                    "type": "string",
                    "description": "起始时间,可填写时间戳或可解析时间文本。",
                },
                "time_end": {
                    "type": "string",
                    "description": "结束时间,可填写时间戳或可解析时间文本。",
                },
                "respect_filter": {
                    "type": "boolean",
                    "description": "是否应用聊天过滤配置。",
                    "default": True,
                },
            },
        },
        provider_name="maisaka_builtin",
        provider_type="builtin",
        enabled=enabled,
    )


def _coerce_bool(raw_value: Any, *, default: bool) -> bool:
    """Interpret a loosely-typed tool argument as a boolean.

    ``None``, blank strings and unrecognised strings yield ``default`` so that
    a missing or malformed value can never flip a safety-relevant switch.
    Real booleans are returned as-is; other non-string values use normal
    truthiness.
    """

    if raw_value is None:
        return default
    if isinstance(raw_value, bool):
        return raw_value
    if isinstance(raw_value, str):
        text = raw_value.strip().lower()
        if text in _TRUE_TEXTS:
            return True
        if text in _FALSE_TEXTS:
            return False
        return default
    return bool(raw_value)


def _resolve_limit(raw_limit: Any, default_limit: int) -> int:
    """Parse the requested hit count and clamp it to ``1.._MAX_QUERY_LIMIT``.

    Falsy or unparsable values (including NaN and infinities) fall back to
    ``default_limit``.
    """

    try:
        limit = int(raw_limit or default_limit)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")); ValueError: int(float("nan")) / bad text.
        limit = default_limit
    return max(1, min(limit, _MAX_QUERY_LIMIT))


def _normalize_optional_time(raw_value: Any) -> str | float | None:
    """Normalise an optional time argument.

    Returns:
        ``None`` when the value is absent, blank, a boolean, or a non-finite
        number; a ``float`` for finite numeric input; otherwise the stripped
        string form of the value.
    """

    import math  # local import: the only stdlib dependency of this helper

    # bool is a subclass of int; True/False are never meaningful timestamps.
    if raw_value is None or isinstance(raw_value, bool):
        return None
    if isinstance(raw_value, (float, int)):
        try:
            timestamp = float(raw_value)
        except OverflowError:
            # Integer too large to be represented as a float.
            return None
        return timestamp if math.isfinite(timestamp) else None

    time_text = str(raw_value).strip()
    return time_text or None


def _resolve_person_id(
    *,
    person_name: str,
    platform: str,
    user_id: str,
    group_id: str,
) -> Tuple[str, str]:
    """Resolve the ``person_id`` used to filter the memory search.

    Resolution order:
        1. If a person name is given, resolve by name (with platform/user_id).
        2. Otherwise, or if step 1 found nothing, and the chat has no group id
           while platform and user id are known, resolve the chat peer.

    Returns:
        ``(person_id, cleaned_person_name)``; ``person_id`` is ``""`` when
        nothing could be resolved.
    """

    clean_person_name = str(person_name or "").strip()
    if clean_person_name:
        person_id = resolve_person_id_for_memory(
            person_name=clean_person_name,
            platform=platform,
            user_id=user_id,
        )
        if person_id:
            return person_id, clean_person_name

    # NOTE(review): when an explicit name fails to resolve in a private chat,
    # this falls back to the chat peer's id while still reporting the
    # requested name — confirm this pairing is intended.
    if not group_id and platform and user_id:
        person_id = resolve_person_id_for_memory(
            platform=platform,
            user_id=user_id,
        )
        if person_id:
            return person_id, clean_person_name

    return "", clean_person_name


def _build_success_content(result: MemorySearchResult, *, limit: int) -> str:
    """Compose the human-readable text for a successful search.

    With hits, the summary and the rendered hit text are joined (whichever is
    non-empty); without hits, a fixed message distinguishes a filtered request
    from a genuine miss.
    """

    summary = str(result.summary or "").strip()
    snippet = result.to_text(limit=max(1, int(limit)))

    if result.hits:
        if summary and snippet:
            return f"{summary}\n{snippet}"
        if summary:
            return summary
        if snippet:
            return snippet
        return "已找到匹配的长期记忆。"

    if result.filtered:
        return "当前请求被聊天过滤策略跳过,未执行长期记忆检索。"
    return "未找到匹配的长期记忆。"


async def handle_tool(
    tool_ctx: BuiltinToolRuntimeContext,
    invocation: ToolInvocation,
    context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
    """Execute the ``query_memory`` builtin tool.

    Validates the arguments, resolves the person filter from the chat stream,
    calls ``memory_service.search`` and wraps the outcome in a tool result.

    Args:
        tool_ctx: Runtime context; provides the chat stream, session id, log
            prefix and the result builders.
        invocation: Tool call carrying the raw arguments.
        context: Unused; accepted so the handler matches the common signature.

    Returns:
        A failure result for an unsupported mode, for a call with neither a
        query nor a time bound, for a search exception, or for a search result
        flagged unsuccessful; otherwise a success result with readable content
        and the structured search payload.
    """

    del context
    runtime = tool_ctx.runtime
    chat_stream = runtime.chat_stream
    arguments = invocation.arguments or {}

    clean_query = str(arguments.get("query") or "").strip()
    mode = str(arguments.get("mode") or "search").strip().lower() or "search"
    if mode not in _ALLOWED_QUERY_MODES:
        return tool_ctx.build_failure_result(
            invocation.tool_name,
            f"不支持的检索模式:{mode}。可选值:search/time/hybrid/episode/aggregate。",
        )

    default_limit = max(1, int(getattr(global_config.maisaka, "memory_query_default_limit", 5) or 5))
    limit = _resolve_limit(arguments.get("limit", default_limit), default_limit)

    time_start = _normalize_optional_time(arguments.get("time_start"))
    time_end = _normalize_optional_time(arguments.get("time_end"))
    if not clean_query and time_start is None and time_end is None:
        return tool_ctx.build_failure_result(
            invocation.tool_name,
            "query_memory 需要提供 query,或至少提供 time_start/time_end 中的一个。",
        )

    session_id = str(runtime.session_id or "").strip()
    platform = str(chat_stream.platform or "").strip()
    user_id = str(chat_stream.user_id or "").strip()
    group_id = str(chat_stream.group_id or "").strip()
    person_id, person_name = _resolve_person_id(
        person_name=str(arguments.get("person_name") or ""),
        platform=platform,
        user_id=user_id,
        group_id=group_id,
    )
    # A null or malformed value must not silently disable the chat filter.
    respect_filter = _coerce_bool(arguments.get("respect_filter"), default=True)

    logger.info(
        f"{runtime.log_prefix} 触发长期记忆检索工具: "
        f"mode={mode} query={clean_query!r} person_name={person_name!r} person_id={person_id!r}"
    )
    try:
        result = await memory_service.search(
            clean_query,
            limit=limit,
            mode=mode,
            chat_id=session_id,
            person_id=person_id,
            time_start=time_start,
            time_end=time_end,
            respect_filter=respect_filter,
            user_id=user_id,
            group_id=group_id,
        )
    except Exception as exc:
        # Tool boundary: report any backend failure as a tool error instead of
        # letting it propagate to the caller.
        logger.exception(f"{runtime.log_prefix} 长期记忆检索执行异常: {exc}")
        return tool_ctx.build_failure_result(
            invocation.tool_name,
            f"长期记忆检索失败:{exc}",
        )

    # Start from the raw search payload, then overlay the effective request
    # parameters so the normalised values take precedence.
    structured_content: Dict[str, Any] = result.to_dict()
    structured_content.update(
        {
            "query": clean_query,
            "mode": mode,
            "limit": limit,
            "chat_id": session_id,
            "person_name": person_name,
            "person_id": person_id,
            "time_start": time_start,
            "time_end": time_end,
            "respect_filter": respect_filter,
            "user_id": user_id,
            "group_id": group_id,
        }
    )

    if not result.success:
        error_message = str(result.error or "").strip() or "长期记忆检索失败。"
        return tool_ctx.build_failure_result(
            invocation.tool_name,
            error_message,
            structured_content=structured_content,
        )

    content = _build_success_content(result, limit=limit)
    if clean_query:
        display_prompt = f"你查询了长期记忆:{clean_query}"
    else:
        display_prompt = "你按时间范围查询了长期记忆。"

    return tool_ctx.build_success_result(
        invocation.tool_name,
        content,
        structured_content=structured_content,
        metadata={"record_display_prompt": display_prompt},
    )
f"你查询了长期记忆:{query_text}(模式:{mode},命中 {hit_count} 条)" + return f"你按时间范围查询了一次长期记忆(模式:{mode},命中 {hit_count} 条)。" + if invocation.tool_name == "view_complex_message": target_message_id = str(invocation.arguments.get("msg_id") or "").strip() if target_message_id: