Merge pull request #1580 from A-Dawn/r-dev

feat:将A_memorix接入maisaka
This commit is contained in:
Dawn ARC
2026-04-04 02:52:00 +08:00
committed by GitHub
6 changed files with 548 additions and 0 deletions

View File

@@ -0,0 +1,241 @@
from types import SimpleNamespace
from typing import Any, Dict
import pytest
from src.core.tooling import ToolInvocation
from src.maisaka.builtin_tool import query_memory as query_memory_tool
from src.maisaka.builtin_tool.context import BuiltinToolRuntimeContext
from src.services.memory_service import MemoryHit, MemorySearchResult
def _build_tool_ctx(
*,
session_id: str = "session-1",
platform: str = "qq",
user_id: str = "user-1",
group_id: str = "",
) -> BuiltinToolRuntimeContext:
runtime = SimpleNamespace(
session_id=session_id,
chat_stream=SimpleNamespace(
platform=platform,
user_id=user_id,
group_id=group_id,
),
log_prefix=f"[{session_id}]",
)
return BuiltinToolRuntimeContext(engine=SimpleNamespace(), runtime=runtime)
def _build_invocation(arguments: Dict[str, Any]) -> ToolInvocation:
return ToolInvocation(
tool_name="query_memory",
arguments=dict(arguments),
call_id="call-query-memory",
)
@pytest.fixture(autouse=True)
def _patch_maisaka_config(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
query_memory_tool,
"global_config",
SimpleNamespace(maisaka=SimpleNamespace(memory_query_default_limit=5)),
)
@pytest.mark.asyncio
async def test_query_memory_rejects_empty_query_and_time(monkeypatch: pytest.MonkeyPatch) -> None:
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
_ = kwargs
raise AssertionError("参数校验失败时不应调用 memory_service.search")
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(),
_build_invocation({"query": "", "time_start": "", "time_end": ""}),
)
assert result.success is False
assert "query_memory 需要提供 query" in result.error_message
@pytest.mark.asyncio
async def test_query_memory_private_chat_auto_sets_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
captured: Dict[str, Any] = {}
def fake_resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Any = None,
strict_known: bool = False,
) -> str:
_ = strict_known
captured["resolve_args"] = {
"person_name": person_name,
"platform": platform,
"user_id": user_id,
}
return "pid-private-auto"
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
captured["query"] = query
captured["search_kwargs"] = dict(kwargs)
return MemorySearchResult(
summary="检索摘要",
hits=[MemoryHit(content="Alice 喜欢咖啡", score=0.91)],
)
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
_build_invocation({"query": "Alice 的喜好"}),
)
assert result.success is True
assert captured["query"] == "Alice 的喜好"
assert captured["resolve_args"] == {
"person_name": "",
"platform": "qq",
"user_id": "alice",
}
assert captured["search_kwargs"]["chat_id"] == "private-session"
assert captured["search_kwargs"]["user_id"] == "alice"
assert captured["search_kwargs"]["group_id"] == ""
assert captured["search_kwargs"]["person_id"] == "pid-private-auto"
assert isinstance(result.structured_content, dict)
assert result.structured_content["person_id"] == "pid-private-auto"
@pytest.mark.asyncio
async def test_query_memory_group_chat_does_not_attach_default_person_id(monkeypatch: pytest.MonkeyPatch) -> None:
call_counter = {"resolve": 0}
captured_kwargs: Dict[str, Any] = {}
def fake_resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Any = None,
strict_known: bool = False,
) -> str:
_ = person_name
_ = platform
_ = user_id
_ = strict_known
call_counter["resolve"] += 1
return "unexpected-person-id"
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
captured_kwargs.update(kwargs)
return MemorySearchResult(summary="", hits=[])
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(session_id="group-session", platform="qq", user_id="alice", group_id="group-1"),
_build_invocation({"query": "群聊上下文"}),
)
assert result.success is True
assert call_counter["resolve"] == 0
assert captured_kwargs["chat_id"] == "group-session"
assert captured_kwargs["group_id"] == "group-1"
assert captured_kwargs["person_id"] == ""
@pytest.mark.asyncio
async def test_query_memory_search_failure_is_returned(monkeypatch: pytest.MonkeyPatch) -> None:
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
_ = kwargs
return MemorySearchResult(success=False, error="boom")
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
result = await query_memory_tool.handle_tool(
_build_tool_ctx(),
_build_invocation({"query": "测试失败透传"}),
)
assert result.success is False
assert result.error_message == "boom"
assert isinstance(result.structured_content, dict)
assert result.structured_content["success"] is False
@pytest.mark.asyncio
async def test_query_memory_prefers_person_name_resolution(monkeypatch: pytest.MonkeyPatch) -> None:
captured: Dict[str, Any] = {"resolve_calls": []}
def fake_resolve_person_id_for_memory(
*,
person_name: str = "",
platform: str = "",
user_id: Any = None,
strict_known: bool = False,
) -> str:
_ = strict_known
captured["resolve_calls"].append(
{
"person_name": person_name,
"platform": platform,
"user_id": user_id,
}
)
if person_name:
return "pid-by-name"
return "pid-private-auto"
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
captured["search_kwargs"] = dict(kwargs)
return MemorySearchResult(summary="", hits=[MemoryHit(content="命中1")])
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", fake_resolve_person_id_for_memory)
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
result = await query_memory_tool.handle_tool(
_build_tool_ctx(session_id="private-session", platform="qq", user_id="alice", group_id=""),
_build_invocation({"query": "小明资料", "person_name": "小明"}),
)
assert result.success is True
assert captured["resolve_calls"][0] == {
"person_name": "小明",
"platform": "qq",
"user_id": "alice",
}
assert captured["search_kwargs"]["person_id"] == "pid-by-name"
assert result.structured_content["person_name"] == "小明"
assert result.structured_content["person_id"] == "pid-by-name"
@pytest.mark.asyncio
async def test_query_memory_no_hit_returns_readable_message(monkeypatch: pytest.MonkeyPatch) -> None:
async def fake_search(query: str, **kwargs: Any) -> MemorySearchResult:
_ = query
_ = kwargs
return MemorySearchResult(summary="", hits=[])
monkeypatch.setattr(query_memory_tool.memory_service, "search", fake_search)
monkeypatch.setattr(query_memory_tool, "resolve_person_id_for_memory", lambda **kwargs: "")
result = await query_memory_tool.handle_tool(
_build_tool_ctx(),
_build_invocation({"query": "不存在的记忆"}),
)
assert result.success is True
assert "未找到匹配的长期记忆" in result.content
assert isinstance(result.structured_content, dict)
assert result.structured_content["query"] == "不存在的记忆"

View File

@@ -99,6 +99,25 @@ def test_maisaka_is_host_tab_and_mcp_is_attached_to_it():
assert mcp_schema.get("uiParent") == "maisaka"
def test_maisaka_memory_query_config_fields_are_exposed():
"""MaiSaka 长期记忆检索开关和默认条数应出现在配置 schema 中。"""
schema = ConfigSchemaGenerator.generate_schema(Config)
maisaka_schema = schema["nested"]["maisaka"]
enable_field = next(field for field in maisaka_schema["fields"] if field["name"] == "enable_memory_query_tool")
limit_field = next(field for field in maisaka_schema["fields"] if field["name"] == "memory_query_default_limit")
assert enable_field["type"] == "boolean"
assert enable_field.get("x-widget") == "switch"
assert enable_field.get("x-icon") == "database"
assert limit_field["type"] == "integer"
assert limit_field.get("x-widget") == "input"
assert limit_field.get("x-icon") == "hash"
assert limit_field.get("minValue") == 1
assert limit_field.get("maxValue") == 20
def test_set_field_is_mapped_as_array():
"""set[str] 应映射为前端可识别的 array。"""
schema = ConfigSchemaGenerator.generate_schema(MessageReceiveConfig)

View File

@@ -1522,6 +1522,26 @@ class MaiSakaConfig(ConfigBase):
)
"""每个入站消息的最大内部规划轮数"""
enable_memory_query_tool: bool = Field(
default=True,
json_schema_extra={
"x-widget": "switch",
"x-icon": "database",
},
)
"""是否启用 Maisaka 内置长期记忆检索工具 query_memory"""
memory_query_default_limit: int = Field(
default=5,
ge=1,
le=20,
json_schema_extra={
"x-widget": "input",
"x-icon": "hash",
},
)
"""Maisaka 内置长期记忆检索工具 query_memory 的默认返回条数"""
tool_filter_task_name: str = Field(
default="utils",
json_schema_extra={

View File

@@ -3,6 +3,7 @@
from collections.abc import Awaitable, Callable
from typing import Dict, List, Optional
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.llm_models.payload_content.tool_option import ToolDefinitionInput
@@ -11,6 +12,8 @@ from .no_reply import get_tool_spec as get_no_reply_tool_spec
from .no_reply import handle_tool as handle_no_reply_tool
from .query_jargon import get_tool_spec as get_query_jargon_tool_spec
from .query_jargon import handle_tool as handle_query_jargon_tool
from .query_memory import get_tool_spec as get_query_memory_tool_spec
from .query_memory import handle_tool as handle_query_memory_tool
from .query_person_info import get_tool_spec as get_query_person_info_tool_spec
from .query_person_info import handle_tool as handle_query_person_info_tool
from .reply import get_tool_spec as get_reply_tool_spec
@@ -33,6 +36,7 @@ def get_builtin_tool_specs() -> List[ToolSpec]:
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
get_query_memory_tool_spec(enabled=bool(global_config.maisaka.enable_memory_query_tool)),
get_no_reply_tool_spec(),
get_send_emoji_tool_spec(),
]
@@ -46,6 +50,7 @@ def get_all_builtin_tool_specs() -> List[ToolSpec]:
get_reply_tool_spec(),
get_view_complex_message_tool_spec(),
get_query_jargon_tool_spec(),
get_query_memory_tool_spec(enabled=True),
get_query_person_info_tool_spec(),
get_no_reply_tool_spec(),
get_send_emoji_tool_spec(),
@@ -65,6 +70,7 @@ def build_builtin_tool_handlers(tool_ctx: BuiltinToolRuntimeContext) -> Dict[str
"reply": lambda invocation, context=None: handle_reply_tool(tool_ctx, invocation, context),
"no_reply": lambda invocation, context=None: handle_no_reply_tool(tool_ctx, invocation, context),
"query_jargon": lambda invocation, context=None: handle_query_jargon_tool(tool_ctx, invocation, context),
"query_memory": lambda invocation, context=None: handle_query_memory_tool(tool_ctx, invocation, context),
"query_person_info": lambda invocation, context=None: handle_query_person_info_tool(
tool_ctx,
invocation,

View File

@@ -0,0 +1,253 @@
"""query_memory 内置工具。"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
from src.common.logger import get_logger
from src.config.config import global_config
from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvocation, ToolSpec
from src.person_info.person_info import resolve_person_id_for_memory
from src.services.memory_service import MemorySearchResult, memory_service
from .context import BuiltinToolRuntimeContext
logger = get_logger("maisaka_builtin_query_memory")
_ALLOWED_QUERY_MODES = {"search", "time", "hybrid", "episode", "aggregate"}
def get_tool_spec(*, enabled: bool = True) -> ToolSpec:
"""获取 query_memory 工具声明。"""
return ToolSpec(
name="query_memory",
brief_description="检索 A_memorix 长期记忆并返回可读结果。",
detailed_description=(
"参数说明:\n"
"- querystring可选。要检索的关键词或问题。\n"
"- limitinteger可选。返回条数默认使用系统配置值。\n"
"- modestring可选。search/time/hybrid/episode/aggregate。\n"
"- person_namestring可选。人物名优先用于解析并过滤 person_id。\n"
"- time_startstring可选。起始时间可填写时间戳或可解析时间文本。\n"
"- time_endstring可选。结束时间可填写时间戳或可解析时间文本。\n"
"- respect_filterboolean可选。是否应用聊天过滤配置默认 true。"
),
parameters_schema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "要检索的关键词或问题。",
},
"limit": {
"type": "integer",
"description": "返回条数,默认使用系统配置值。",
},
"mode": {
"type": "string",
"description": "检索模式search/time/hybrid/episode/aggregate。",
"enum": sorted(_ALLOWED_QUERY_MODES),
"default": "search",
},
"person_name": {
"type": "string",
"description": "人物名称,可选。提供后优先按人物过滤。",
},
"time_start": {
"type": "string",
"description": "起始时间,可填写时间戳或可解析时间文本。",
},
"time_end": {
"type": "string",
"description": "结束时间,可填写时间戳或可解析时间文本。",
},
"respect_filter": {
"type": "boolean",
"description": "是否应用聊天过滤配置。",
"default": True,
},
},
},
provider_name="maisaka_builtin",
provider_type="builtin",
enabled=enabled,
)
def _normalize_optional_time(raw_value: Any) -> str | float | None:
"""归一化可选时间参数。"""
if raw_value is None:
return None
if isinstance(raw_value, str):
time_text = raw_value.strip()
if not time_text:
return None
return time_text
if isinstance(raw_value, (float, int)):
return float(raw_value)
time_text = str(raw_value).strip()
if not time_text:
return None
return time_text
def _resolve_person_id(
*,
person_name: str,
platform: str,
user_id: str,
group_id: str,
) -> Tuple[str, str]:
"""按约定顺序解析长期记忆检索使用的 person_id。"""
clean_person_name = str(person_name or "").strip()
if clean_person_name:
person_id = resolve_person_id_for_memory(
person_name=clean_person_name,
platform=platform,
user_id=user_id,
)
if person_id:
return person_id, clean_person_name
if not group_id and platform and user_id:
person_id = resolve_person_id_for_memory(
platform=platform,
user_id=user_id,
)
if person_id:
return person_id, clean_person_name
return "", clean_person_name
def _build_success_content(result: MemorySearchResult, *, limit: int) -> str:
"""构造工具成功时的可读内容。"""
summary = str(result.summary or "").strip()
snippet = result.to_text(limit=max(1, int(limit)))
if result.hits:
if summary and snippet:
return f"{summary}\n{snippet}"
if summary:
return summary
if snippet:
return snippet
return "已找到匹配的长期记忆。"
if result.filtered:
return "当前请求被聊天过滤策略跳过,未执行长期记忆检索。"
return "未找到匹配的长期记忆。"
async def handle_tool(
tool_ctx: BuiltinToolRuntimeContext,
invocation: ToolInvocation,
context: Optional[ToolExecutionContext] = None,
) -> ToolExecutionResult:
"""执行 query_memory 内置工具。"""
del context
runtime = tool_ctx.runtime
chat_stream = runtime.chat_stream
clean_query = str(invocation.arguments.get("query") or "").strip()
mode = str(invocation.arguments.get("mode") or "search").strip().lower() or "search"
if mode not in _ALLOWED_QUERY_MODES:
return tool_ctx.build_failure_result(
invocation.tool_name,
f"不支持的检索模式:{mode}。可选值search/time/hybrid/episode/aggregate。",
)
default_limit = max(1, int(getattr(global_config.maisaka, "memory_query_default_limit", 5) or 5))
try:
limit = int(invocation.arguments.get("limit", default_limit) or default_limit)
except (TypeError, ValueError):
limit = default_limit
limit = max(1, min(limit, 20))
time_start = _normalize_optional_time(invocation.arguments.get("time_start"))
time_end = _normalize_optional_time(invocation.arguments.get("time_end"))
if not clean_query and time_start is None and time_end is None:
return tool_ctx.build_failure_result(
invocation.tool_name,
"query_memory 需要提供 query或至少提供 time_start/time_end 中的一个。",
)
session_id = str(runtime.session_id or "").strip()
platform = str(chat_stream.platform or "").strip()
user_id = str(chat_stream.user_id or "").strip()
group_id = str(chat_stream.group_id or "").strip()
person_id, person_name = _resolve_person_id(
person_name=str(invocation.arguments.get("person_name") or ""),
platform=platform,
user_id=user_id,
group_id=group_id,
)
respect_filter = bool(invocation.arguments.get("respect_filter", True))
logger.info(
f"{runtime.log_prefix} 触发长期记忆检索工具: "
f"mode={mode} query={clean_query!r} person_name={person_name!r} person_id={person_id!r}"
)
try:
result = await memory_service.search(
clean_query,
limit=limit,
mode=mode,
chat_id=session_id,
person_id=person_id,
time_start=time_start,
time_end=time_end,
respect_filter=respect_filter,
user_id=user_id,
group_id=group_id,
)
except Exception as exc:
logger.exception(f"{runtime.log_prefix} 长期记忆检索执行异常: {exc}")
return tool_ctx.build_failure_result(
invocation.tool_name,
f"长期记忆检索失败:{exc}",
)
structured_content: Dict[str, Any] = result.to_dict()
structured_content.update(
{
"query": clean_query,
"mode": mode,
"limit": limit,
"chat_id": session_id,
"person_name": person_name,
"person_id": person_id,
"time_start": time_start,
"time_end": time_end,
"respect_filter": respect_filter,
"user_id": user_id,
"group_id": group_id,
}
)
if not result.success:
error_message = str(result.error or "").strip() or "长期记忆检索失败。"
return tool_ctx.build_failure_result(
invocation.tool_name,
error_message,
structured_content=structured_content,
)
content = _build_success_content(result, limit=limit)
if clean_query:
display_prompt = f"你查询了长期记忆:{clean_query}"
else:
display_prompt = "你按时间范围查询了长期记忆。"
return tool_ctx.build_success_result(
invocation.tool_name,
content,
structured_content=structured_content,
metadata={"record_display_prompt": display_prompt},
)

View File

@@ -649,6 +649,15 @@ class MaisakaReasoningEngine:
return f"你查询了人物信息:{person_name}"
return "你查询了一次人物信息。"
if invocation.tool_name == "query_memory":
query_text = str(invocation.arguments.get("query") or "").strip()
mode = str(invocation.arguments.get("mode") or "search").strip() or "search"
hit_items = structured_content.get("hits")
hit_count = len(hit_items) if isinstance(hit_items, list) else 0
if query_text:
return f"你查询了长期记忆:{query_text}(模式:{mode},命中 {hit_count} 条)"
return f"你按时间范围查询了一次长期记忆(模式:{mode},命中 {hit_count} 条)。"
if invocation.tool_name == "view_complex_message":
target_message_id = str(invocation.arguments.get("msg_id") or "").strip()
if target_message_id: