feat:同步本地非算法改动到上游基线

保留反馈纠错、WebUI 与运行时增强。
移除不应提交的 algorithm_redesign 设计目录及其专项测试。
This commit is contained in:
A-Dawn
2026-04-16 13:57:07 +08:00
parent 6c22fdfdf9
commit 21b642d07d
10 changed files with 2244 additions and 34 deletions

View File

@@ -0,0 +1,740 @@
from __future__ import annotations
import asyncio
import inspect
import json
import time
import uuid
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Callable, Dict
import numpy as np
import pytest
import pytest_asyncio
from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine, select
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
from src.chat.heart_flow.heartflow_manager import heartflow_manager
from src.chat.message_receive import bot as bot_module
from src.chat.message_receive.chat_manager import chat_manager
from src.chat.message_receive.bot import chat_bot
from src.common.database import database as database_module
from src.common.database.database_model import PersonInfo, ToolRecord
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.utils.utils_session import SessionUtils
from src.llm_models.payload_content.tool_option import ToolCall
from src.maisaka import reasoning_engine as reasoning_engine_module
from src.maisaka import runtime as runtime_module
from src.maisaka.chat_loop_service import ChatResponse
from src.maisaka.context_messages import AssistantMessage
from src.plugin_runtime import component_query as component_query_module
from src.services import memory_flow_service as memory_flow_service_module
from src.services import memory_service as memory_service_module
from src.services.memory_service import memory_service
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
kernel_module = None # type: ignore[assignment]
KernelSearchRequest = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
heartflow_manager = None # type: ignore[assignment]
bot_module = None # type: ignore[assignment]
chat_manager = None # type: ignore[assignment]
chat_bot = None # type: ignore[assignment]
database_module = None # type: ignore[assignment]
ToolRecord = None # type: ignore[assignment]
PersonInfo = None # type: ignore[assignment]
create_database_migration_bootstrapper = None # type: ignore[assignment]
SessionUtils = None # type: ignore[assignment]
ToolCall = None # type: ignore[assignment]
reasoning_engine_module = None # type: ignore[assignment]
runtime_module = None # type: ignore[assignment]
ChatResponse = None # type: ignore[assignment]
AssistantMessage = None # type: ignore[assignment]
component_query_module = None # type: ignore[assignment]
memory_flow_service_module = None # type: ignore[assignment]
memory_service_module = None # type: ignore[assignment]
memory_service = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
RELATION_QUERY = "测试用户 和 最喜欢的颜色 有什么关系"
class _FakeEmbeddingManager:
    """Deterministic, offline replacement for the embedding adapter.

    Every input is folded byte-by-byte into a fixed-size ``float32`` vector
    and L2-normalised, so equal inputs always yield equal embeddings and no
    external service is contacted.
    """

    def __init__(self, dimension: int = 8) -> None:
        self.default_dimension = dimension

    async def _detect_dimension(self) -> int:
        """Return the configured dimension without probing anything."""
        return self.default_dimension

    async def encode(self, text: Any) -> np.ndarray:
        """Encode one value or a batch of values.

        Args:
            text: A single value (stringified via ``str``) or a list/tuple of
                such values.

        Returns:
            A ``(dimension,)`` vector for a single value, or a
            ``(len(text), dimension)`` matrix for a list/tuple. An empty
            list/tuple yields an empty ``(0, dimension)`` matrix.
        """

        def _encode_one(raw: Any) -> np.ndarray:
            content = str(raw or "")
            vector = np.zeros(self.default_dimension, dtype=np.float32)
            # Spread the UTF-8 bytes round-robin over the vector slots.
            for index, byte in enumerate(content.encode("utf-8")):
                vector[index % self.default_dimension] += float((byte % 17) + 1)
            norm = float(np.linalg.norm(vector))
            # Empty content leaves an all-zero vector; skip the division.
            if norm > 0:
                vector /= norm
            return vector

        if isinstance(text, (list, tuple)):
            if not text:
                # np.stack() rejects an empty sequence, so build the empty
                # batch explicitly instead of raising ValueError.
                return np.zeros((0, self.default_dimension), dtype=np.float32)
            return np.stack([_encode_one(item) for item in text]).astype(np.float32)
        return _encode_one(text).astype(np.float32)
class _KernelBackedRuntimeManager:
    """Host-service stand-in that forwards component calls to a kernel object."""

    def __init__(self, kernel: SDKMemoryKernel) -> None:
        self.kernel = kernel

    async def invoke(self, component_name: str, args: Dict[str, Any] | None, *, timeout_ms: int = 30000):
        """Dispatch ``component_name`` to the wrapped kernel.

        ``search_memory`` is translated into a ``KernelSearchRequest``; any
        other name is looked up on the kernel and called with ``args`` as
        keyword arguments, awaiting the result when it is awaitable.
        ``timeout_ms`` is accepted but unused.
        """
        del timeout_ms
        params = args or {}

        if component_name != "search_memory":
            target = getattr(self.kernel, component_name)
            outcome = target(**params)
            if inspect.isawaitable(outcome):
                return await outcome
            return outcome

        def _text(key: str, fallback: str = "") -> str:
            # Missing and falsy values both collapse to the fallback string.
            return str(params.get(key, fallback) or fallback)

        request = KernelSearchRequest(
            query=_text("query"),
            limit=int(params.get("limit", 5) or 5),
            mode=_text("mode", "hybrid"),
            chat_id=_text("chat_id"),
            person_id=_text("person_id"),
            time_start=params.get("time_start"),
            time_end=params.get("time_end"),
            respect_filter=bool(params.get("respect_filter", True)),
            user_id=_text("user_id"),
            group_id=_text("group_id"),
        )
        return await self.kernel.search_memory(request)
class _NoopRuntimeManager:
    """Runtime-manager stub whose hooks never abort and echo their kwargs."""

    async def invoke_hook(self, hook_name: str, **kwargs: Any) -> Any:
        """Ignore ``hook_name`` and return a result carrying ``kwargs`` unchanged."""
        del hook_name
        outcome = SimpleNamespace(kwargs=kwargs, aborted=False)
        return outcome
def _install_temp_main_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    """Rebind the shared database module to a fresh SQLite file under ``tmp_path``.

    Creates ``<tmp_path>/main_db/MaiBot.db``, builds an engine, session
    factory and migration bootstrapper for it, and patches them (plus the
    path/URL/initialised-flag attributes) onto ``database_module`` through
    ``monkeypatch``.
    """
    storage_dir = (tmp_path / "main_db").resolve()
    storage_dir.mkdir(parents=True, exist_ok=True)
    sqlite_path = storage_dir / "MaiBot.db"
    sqlite_url = f"sqlite:///{sqlite_path}"

    # Best effort: release whatever engine was configured before swapping it.
    try:
        database_module.engine.dispose()
    except Exception:
        pass

    temp_engine = create_engine(
        sqlite_url,
        pool_pre_ping=True,
        connect_args={"check_same_thread": False},
        echo=False,
    )
    session_factory = sessionmaker(
        bind=temp_engine,
        class_=Session,
        autoflush=False,
        autocommit=False,
    )
    migrator = create_database_migration_bootstrapper(temp_engine)

    # Insertion order mirrors the order in which the attributes are patched.
    overrides = {
        "_DB_DIR": storage_dir,
        "_DB_FILE": sqlite_path,
        "DATABASE_URL": sqlite_url,
        "engine": temp_engine,
        "SessionLocal": session_factory,
        "_migration_bootstrapper": migrator,
        "_db_initialized": False,
    }
    for attribute, replacement in overrides.items():
        monkeypatch.setattr(database_module, attribute, replacement, raising=False)
def _build_chat_response(content: str, tool_calls: list[ToolCall]) -> ChatResponse:
    """Wrap ``content`` and ``tool_calls`` in a ``ChatResponse`` with zeroed counters."""
    assistant_message = AssistantMessage(
        content=content,
        timestamp=datetime.now(),
        tool_calls=tool_calls,
    )
    # No real model call happens, so every usage/size counter is zero.
    zeroed_counters = {
        "selected_history_count": 0,
        "prompt_tokens": 0,
        "built_message_count": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
    return ChatResponse(
        content=content,
        tool_calls=tool_calls,
        request_messages=[],
        raw_message=assistant_message,
        tool_count=len(tool_calls),
        prompt_section=None,
        **zeroed_counters,
    )
def _build_message_data(
    *,
    content: str,
    platform: str,
    user_id: str,
    user_name: str,
    group_id: str,
    group_name: str,
) -> Dict[str, Any]:
    """Build a raw group-message payload that addresses the bot.

    The payload gets a fresh UUID4 ``message_id`` and the current
    ``time.time()``; ``content`` is used as the single text segment, the raw
    message and the processed plain text. ``user_name`` fills both the
    nickname and the card name.
    """
    message_id = str(uuid.uuid4())
    sent_at = time.time()

    group_info = {"group_id": group_id, "group_name": group_name, "platform": platform}
    user_info = {
        "user_id": user_id,
        "user_nickname": user_name,
        "user_cardname": user_name,
        "platform": platform,
    }
    message_info = {
        "platform": platform,
        "message_id": message_id,
        "time": sent_at,
        "group_info": group_info,
        "user_info": user_info,
        "additional_config": {"at_bot": True},
    }
    text_segment = {"type": "text", "data": content}

    return {
        "message_info": message_info,
        "message_segment": {"type": "seglist", "data": [text_segment]},
        "raw_message": content,
        "processed_plain_text": content,
    }
async def _wait_until(
    predicate: Callable[[], Any],
    *,
    timeout_seconds: float = 10.0,
    interval_seconds: float = 0.05,
    description: str,
) -> Any:
    """Poll ``predicate`` until it yields a truthy value and return that value.

    ``predicate`` may be sync or return an awaitable. The effective timeout
    is never shorter than 0.5 seconds. Raises ``AssertionError`` naming
    ``description`` once the deadline passes without a truthy result.
    """
    loop = asyncio.get_running_loop()
    give_up_at = loop.time() + max(0.5, float(timeout_seconds))

    while loop.time() < give_up_at:
        outcome = predicate()
        if inspect.isawaitable(outcome):
            outcome = await outcome
        if not outcome:
            await asyncio.sleep(interval_seconds)
            continue
        return outcome

    raise AssertionError(f"等待超时: {description}")
def _load_feedback_tasks(kernel: SDKMemoryKernel) -> list[Dict[str, Any]]:
    """Return every feedback task known to the kernel's metadata store, ordered by id.

    Task ids are read from ``memory_feedback_tasks`` and each one is resolved
    through ``get_feedback_task``; ids that resolve to ``None`` are dropped.
    """
    store = kernel.metadata_store
    assert store is not None
    result = store.get_connection().cursor().execute(
        "SELECT query_tool_id FROM memory_feedback_tasks ORDER BY id"
    )
    resolved = (
        store.get_feedback_task(str(record["query_tool_id"] or ""))
        for record in result.fetchall()
    )
    return [task for task in resolved if task is not None]
def _load_feedback_action_types(kernel: SDKMemoryKernel) -> list[str]:
    """Return the ``action_type`` of every feedback action log row, ordered by id.

    ``NULL``/falsy action types are reported as empty strings.
    """
    assert kernel.metadata_store is not None
    connection = kernel.metadata_store.get_connection()
    result = connection.cursor().execute(
        "SELECT action_type FROM memory_feedback_action_logs ORDER BY id"
    )
    action_types: list[str] = []
    for record in result.fetchall():
        action_types.append(str(record["action_type"] or ""))
    return action_types
def _load_query_memory_tool_records(session_id: str) -> list[Dict[str, Any]]:
    """Return the ``query_memory`` tool records of a session as plain dicts.

    Rows are ordered by timestamp; string columns are normalised with
    ``str(value or "")`` while ``timestamp`` is passed through untouched.
    """
    records: list[Dict[str, Any]] = []
    with database_module.get_db_session() as db:
        query = select(ToolRecord)
        query = query.where(ToolRecord.session_id == session_id)
        query = query.where(ToolRecord.tool_name == "query_memory")
        query = query.order_by(ToolRecord.timestamp)
        fetched = list(db.exec(query).all())
        # Copy the columns while the session is still open.
        for entry in fetched:
            records.append(
                {
                    "tool_id": str(entry.tool_id or ""),
                    "session_id": str(entry.session_id or ""),
                    "tool_name": str(entry.tool_name or ""),
                    "tool_data": str(entry.tool_data or ""),
                    "timestamp": entry.timestamp,
                }
            )
        return records
def _seed_person_info(*, person_id: str, person_name: str, session_info: Dict[str, Any]) -> None:
    """Insert one known ``PersonInfo`` row built from ``session_info`` and commit it.

    ``session_info`` must provide ``platform``, ``user_id``, ``user_name``
    and ``group_id``; ``person_name`` doubles as the group card name.
    """
    card_entries = [{"group_id": str(session_info["group_id"]), "group_cardname": person_name}]
    with database_module.get_db_session() as db:
        person = PersonInfo(
            is_known=True,
            person_id=person_id,
            person_name=person_name,
            platform=str(session_info["platform"]),
            user_id=str(session_info["user_id"]),
            user_nickname=str(session_info["user_name"]),
            group_cardname=json.dumps(card_entries, ensure_ascii=False),
            know_counts=1,
            first_known_time=datetime.now(),
            last_known_time=datetime.now(),
        )
        db.add(person)
        db.commit()
@pytest_asyncio.fixture
async def chat_feedback_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Assemble an isolated chat + memory-kernel environment for feedback tests.

    Setup:
        * redirects the main database to a temporary SQLite file and clears
          the chat/heartflow registries;
        * stubs the plugin runtime manager, command/tool lookup, MCP, the
          message trigger threshold and the memory automation hook;
        * replaces the embedding adapter and its self-check with
          deterministic fakes and builds an ``SDKMemoryKernel`` under
          ``tmp_path`` with all feedback-correction getters overridden;
        * replaces the feedback classifier, episode segmentation, pending
          episode batch processing, timing gate and planner with fakes;
        * seeds one ``PersonInfo`` row for the test user.

    Yields:
        A dict with ``kernel``, ``session_id``, ``session_info``,
        ``person_id`` and ``planner_calls`` (the texts seen by the fake
        planner, in call order).

    Teardown stops all heartflow chats, clears the registries, shuts the
    kernel down and always disposes the temporary database engine, even when
    an earlier teardown step raises.
    """
    _install_temp_main_database(monkeypatch, tmp_path)
    chat_manager.sessions.clear()
    chat_manager.last_messages.clear()
    heartflow_manager.heartflow_chat_list.clear()

    # --- chat pipeline stubs -------------------------------------------------
    noop_runtime_manager = _NoopRuntimeManager()
    monkeypatch.setattr(bot_module.ChatBot, "_get_runtime_manager", staticmethod(lambda: noop_runtime_manager))
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "find_command_by_text",
        lambda text: None,
    )
    monkeypatch.setattr(
        component_query_module.component_query_service,
        "get_llm_available_tool_specs",
        lambda: {},
    )
    monkeypatch.setattr(runtime_module.global_config.mcp, "enable", False, raising=False)
    monkeypatch.setattr(
        runtime_module.MaisakaHeartFlowChatting,
        "_get_message_trigger_threshold",
        lambda self: 1,
    )

    async def _noop_on_incoming_message(message: Any) -> None:
        del message

    monkeypatch.setattr(
        memory_flow_service_module.memory_automation_service,
        "on_incoming_message",
        _noop_on_incoming_message,
    )

    # --- embedding fakes -----------------------------------------------------
    fake_embedding_manager = _FakeEmbeddingManager(dimension=8)

    async def _fake_runtime_self_check(
        *,
        config: Any,
        sample_text: str,
        vector_store: Any,
        embedding_manager: Any,
    ) -> Dict[str, Any]:
        del config, sample_text, vector_store, embedding_manager
        return {
            "ok": True,
            "message": "ok",
            "checked_at": time.time(),
            "encoded_dimension": fake_embedding_manager.default_dimension,
        }

    monkeypatch.setattr(
        kernel_module,
        "create_embedding_api_adapter",
        lambda **kwargs: fake_embedding_manager,
    )
    monkeypatch.setattr(
        kernel_module,
        "run_embedding_runtime_self_check",
        _fake_runtime_self_check,
    )

    # --- kernel --------------------------------------------------------------
    kernel = SDKMemoryKernel(
        plugin_root=tmp_path / "plugin_root",
        config={
            "storage": {"data_dir": str((tmp_path / "a_memorix_data").resolve())},
            "advanced": {"enable_auto_save": False},
            "embedding": {"dimension": fake_embedding_manager.default_dimension},
            "memory": {"base_decay_interval_hours": 24},
            "person_profile": {"refresh_interval_minutes": 5},
        },
    )

    # Each entry replaces the kernel's zero-argument getter of that name with
    # one returning the listed value. Insertion order is the patch order.
    feedback_overrides: Dict[str, Any] = {
        "_feedback_cfg_enabled": True,
        "_feedback_cfg_window_hours": 0.0004,  # 0.0004 h == 1.44 s
        "_feedback_cfg_check_interval_seconds": 0.2,
        "_feedback_cfg_batch_size": 10,
        "_feedback_cfg_max_messages": 10,
        "_feedback_cfg_auto_apply_threshold": 0.85,
        "_feedback_cfg_prefilter_enabled": True,
        "_feedback_cfg_paragraph_mark_enabled": True,
        "_feedback_cfg_paragraph_hard_filter_enabled": True,
        "_feedback_cfg_profile_refresh_enabled": True,
        "_feedback_cfg_profile_force_refresh_on_read": True,
        "_feedback_cfg_episode_rebuild_enabled": True,
        "_feedback_cfg_episode_query_block_enabled": True,
        "_feedback_cfg_reconcile_interval_seconds": 0.2,
        "_feedback_cfg_reconcile_batch_size": 10,
    }
    for getter_name, fixed_value in feedback_overrides.items():
        # Bind the value as a default to avoid the late-binding closure pitfall.
        monkeypatch.setattr(kernel, getter_name, lambda _value=fixed_value: _value)

    monkeypatch.setattr(
        kernel_module.global_config.memory,
        "feedback_correction_paragraph_hard_filter_enabled",
        True,
        raising=False,
    )
    monkeypatch.setattr(
        kernel_module.global_config.memory,
        "feedback_correction_episode_query_block_enabled",
        True,
        raising=False,
    )

    async def _fake_classify_feedback(
        *,
        query_tool_id: str,
        query_text: str,
        hit_briefs: list[Dict[str, Any]],
        feedback_messages: list[str],
    ) -> Dict[str, Any]:
        """Always decide "correct", targeting the first relation hit if any."""
        del query_tool_id, query_text, feedback_messages
        target_hash = ""
        for item in hit_briefs:
            if str(item.get("type", "") or "").strip() == "relation":
                target_hash = str(item.get("hash", "") or "").strip()
                break
        # No relation hit: fall back to whatever hit came first.
        if not target_hash and hit_briefs:
            target_hash = str(hit_briefs[0].get("hash", "") or "").strip()
        return {
            "decision": "correct",
            "confidence": 0.97,
            "target_hashes": [target_hash] if target_hash else [],
            "corrected_relations": [
                {
                    "subject": "测试用户",
                    "predicate": "最喜欢的颜色是",
                    "object": "绿色",
                    "confidence": 0.99,
                }
            ],
            "reason": "用户明确纠正为绿色",
        }

    monkeypatch.setattr(kernel, "_classify_feedback", _fake_classify_feedback)
    await kernel.initialize()

    async def _force_episode_fallback(**kwargs: Any) -> Dict[str, Any]:
        raise RuntimeError("force_fallback_for_test")

    monkeypatch.setattr(
        kernel.episode_service.segmentation_service,
        "segment",
        _force_episode_fallback,
    )
    monkeypatch.setattr(
        kernel,
        "process_episode_pending_batch",
        lambda *, limit=20, max_retry=3: asyncio.sleep(
            0,
            result={"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0},
        ),
    )

    host_manager = _KernelBackedRuntimeManager(kernel)
    monkeypatch.setattr(memory_service_module, "a_memorix_host_service", host_manager)

    # --- reasoning engine fakes ---------------------------------------------
    planner_calls: list[str] = []

    async def _fake_timing_gate(self, anchor_message: Any):
        del self, anchor_message
        return "continue", _build_chat_response("直接进入 planner。", []), []

    async def _fake_planner(self, *, tool_definitions: list[dict[str, Any]] | None = None) -> ChatResponse:
        """Issue one ``query_memory`` call per recall message, otherwise stop."""
        del tool_definitions
        latest_message = self._runtime.message_cache[-1]
        latest_text = str(latest_message.processed_plain_text or "")
        planner_calls.append(latest_text)

        # Remember which messages already triggered a query so that a
        # follow-up planner round for the same message ends the turn.
        handled_message_ids = getattr(self._runtime, "_test_query_message_ids", None)
        if handled_message_ids is None:
            handled_message_ids = set()
            setattr(self._runtime, "_test_query_message_ids", handled_message_ids)

        if latest_message.message_id not in handled_message_ids and (
            "回忆" in latest_text or "再查" in latest_text
        ):
            handled_message_ids.add(latest_message.message_id)
            tool_call = ToolCall(
                call_id=f"query-{uuid.uuid4().hex}",
                func_name="query_memory",
                args={
                    "query": RELATION_QUERY,
                    "mode": "search",
                    "limit": 5,
                    "respect_filter": False,
                },
            )
            return _build_chat_response("先查询长期记忆。", [tool_call])

        stop_call = ToolCall(
            call_id=f"stop-{uuid.uuid4().hex}",
            func_name="no_reply",
            args={},
        )
        return _build_chat_response("当前轮次结束。", [stop_call])

    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_timing_gate",
        _fake_timing_gate,
    )
    monkeypatch.setattr(
        reasoning_engine_module.MaisakaReasoningEngine,
        "_run_interruptible_planner",
        _fake_planner,
    )

    # --- session identity ----------------------------------------------------
    session_info = {
        "platform": "unit_test_chat",
        "user_id": "user_feedback_flow",
        "user_name": "反馈测试用户",
        "group_id": "group_feedback_flow",
        "group_name": "反馈纠错测试群",
    }
    person_id = "person_feedback_flow"
    session_id = SessionUtils.calculate_session_id(
        session_info["platform"],
        user_id=session_info["user_id"],
        group_id=session_info["group_id"],
    )
    _seed_person_info(person_id=person_id, person_name="测试用户", session_info=session_info)

    try:
        yield {
            "kernel": kernel,
            "session_id": session_id,
            "session_info": session_info,
            "person_id": person_id,
            "planner_calls": planner_calls,
        }
    finally:
        try:
            for key, chat in list(heartflow_manager.heartflow_chat_list.items()):
                # Stopping a chat is best effort; one failure must not keep
                # the remaining chats alive.
                try:
                    await chat.stop()
                except Exception:
                    pass
                heartflow_manager.heartflow_chat_list.pop(key, None)
            chat_manager.sessions.clear()
            chat_manager.last_messages.clear()
            await kernel.shutdown()
        finally:
            # Runs even when kernel.shutdown() raises, so the engine bound to
            # the temporary SQLite file is not left open.
            try:
                database_module.engine.dispose()
            except Exception:
                pass
@pytest.mark.asyncio
async def test_feedback_correction_real_chat_flow(chat_feedback_env) -> None:
    """End-to-end check that a user correction replaces a stored relation.

    Flow: seed "blue" as the favourite colour, confirm it is visible through
    search, profile and episode queries, send a recall message (which makes
    the fake planner call ``query_memory``), send a correction message, then
    verify that the feedback task is applied and that search, profile,
    episode and a second ``query_memory`` call all report "green" and no
    longer "blue".
    """
    kernel = chat_feedback_env["kernel"]
    session_id = chat_feedback_env["session_id"]
    session_info = chat_feedback_env["session_info"]
    person_id = chat_feedback_env["person_id"]
    planner_calls = chat_feedback_env["planner_calls"]

    # Polling predicates. Each one queries its source exactly once per poll
    # and returns the very object it inspected.
    def _query_records_at_least(minimum: int) -> Callable[[], Any]:
        def _poll() -> Any:
            records = _load_query_memory_tool_records(session_id)
            return records if len(records) >= minimum else None

        return _poll

    def _first_feedback_task() -> Any:
        tasks = _load_feedback_tasks(kernel)
        return tasks[0] if tasks else None

    def _completed_profile_refresh() -> Any:
        request = kernel.metadata_store.get_person_profile_refresh_request(person_id)
        if request and request.get("status") == "done":
            return request
        return None

    # --- seed the original (wrong) fact --------------------------------------
    write_result = await memory_service.ingest_text(
        external_id=f"test:feedback-seed:{uuid.uuid4().hex}",
        source_type="chat_summary",
        text="测试用户 最喜欢的颜色是 蓝色",
        chat_id=session_id,
        relations=[
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
                "confidence": 1.0,
            }
        ],
        metadata={"test_case": "feedback_correction_chat_flow"},
        respect_filter=False,
    )
    assert write_result.success is True

    pre_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_search.hits
    assert any("蓝色" in hit.content for hit in pre_search.hits)

    pre_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    pre_profile_text = pre_profile.summary + "\n" + json.dumps(pre_profile.evidence, ensure_ascii=False)
    assert "蓝色" in pre_profile_text

    seed_source = f"chat_summary:{session_id}"
    rebuild_result = await kernel.rebuild_episodes_for_sources([seed_source])
    assert rebuild_result["rebuilt"] >= 1

    pre_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert pre_episode.hits
    assert any("蓝色" in hit.content for hit in pre_episode.hits)

    # --- first chat turn: recall triggers query_memory -----------------------
    await chat_bot.message_process(
        _build_message_data(
            content="请帮我回忆一下,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )
    await _wait_until(
        lambda: planner_calls[0] if planner_calls else None,
        description="planner 收到首条聊天消息",
    )
    first_query_records = await _wait_until(
        _query_records_at_least(1),
        description="首条 query_memory 工具记录生成",
    )
    assert first_query_records

    first_task = await _wait_until(
        _first_feedback_task,
        description="首个反馈任务入队",
    )
    assert first_task["status"] == "pending"
    first_hits = list((first_task.get("query_snapshot") or {}).get("hits") or [])
    assert first_hits
    assert any("蓝色" in str(item.get("content", "") or "") for item in first_hits)

    # --- second chat turn: the user corrects the fact ------------------------
    await chat_bot.message_process(
        _build_message_data(
            content="不对,测试用户最喜欢的颜色不是蓝色,是绿色。",
            **session_info,
        )
    )

    terminal_statuses = {"applied", "skipped", "error"}

    def _finalized_feedback_task() -> Any:
        task = kernel.metadata_store.get_feedback_task(first_task["query_tool_id"])
        if task and task.get("status") in terminal_statuses:
            return task
        return None

    finalized_task = await _wait_until(
        _finalized_feedback_task,
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="反馈任务进入终态",
    )
    assert finalized_task["status"] == "applied", finalized_task
    assert finalized_task["decision_payload"]["decision"] == "correct"
    assert finalized_task["decision_payload"]["apply_result"]["applied"] is True

    corrected_hashes = list(
        (finalized_task["decision_payload"].get("apply_result") or {}).get("relation_hashes") or []
    )
    assert corrected_hashes
    corrected_hash = str(corrected_hashes[0] or "")
    relation_status = kernel.metadata_store.get_relation_status_batch([corrected_hash]).get(corrected_hash, {})
    assert bool(relation_status.get("is_inactive")) is True

    action_types = _load_feedback_action_types(kernel)
    assert "classification" in action_types
    assert "forget_relation" in action_types
    assert "ingest_correction" in action_types
    assert "mark_stale_paragraph" in action_types
    assert "enqueue_episode_rebuild" in action_types
    assert "enqueue_profile_refresh" in action_types

    # --- the corrected fact is visible everywhere ----------------------------
    direct_post_search = await memory_service.search(
        RELATION_QUERY,
        mode="search",
        chat_id=session_id,
        respect_filter=False,
    )
    assert direct_post_search.hits
    post_contents = "\n".join(hit.content for hit in direct_post_search.hits)
    assert "绿色" in post_contents
    assert "蓝色" not in post_contents

    profile_refresh_request = await _wait_until(
        _completed_profile_refresh,
        timeout_seconds=12.0,
        interval_seconds=0.1,
        description="人物画像刷新完成",
    )
    assert profile_refresh_request["status"] == "done"

    post_profile = await memory_service.get_person_profile(person_id, chat_id=session_id, limit=10)
    post_profile_text = post_profile.summary + "\n" + json.dumps(post_profile.evidence, ensure_ascii=False)
    assert "绿色" in post_profile_text
    assert "蓝色" not in post_profile_text

    async def _latest_episode_result():
        result = await memory_service.search(
            "绿色",
            mode="episode",
            chat_id=session_id,
            respect_filter=False,
        )
        if not result.hits:
            return None
        contents = "\n".join(hit.content for hit in result.hits)
        if "绿色" in contents and "蓝色" not in contents:
            return result
        return None

    post_episode = await _wait_until(
        _latest_episode_result,
        timeout_seconds=12.0,
        interval_seconds=0.2,
        description="episode 重建后返回修正结果",
    )
    assert post_episode is not None

    stale_episode = await memory_service.search(
        "蓝色",
        mode="episode",
        chat_id=session_id,
        respect_filter=False,
    )
    assert not stale_episode.hits

    # --- third chat turn: a fresh query_memory call sees the correction ------
    await chat_bot.message_process(
        _build_message_data(
            content="再查一次,测试用户最喜欢的颜色是什么?",
            **session_info,
        )
    )
    tool_records = await _wait_until(
        _query_records_at_least(2),
        timeout_seconds=10.0,
        interval_seconds=0.1,
        description="第二次 query_memory 工具记录生成",
    )
    latest_tool_data = json.loads(str(tool_records[-1].get("tool_data") or "{}"))
    latest_hits = list((latest_tool_data.get("structured_content") or {}).get("hits") or [])
    assert latest_hits
    latest_contents = "\n".join(str(item.get("content", "") or "") for item in latest_hits)
    assert "绿色" in latest_contents
    assert "蓝色" not in latest_contents

View File

@@ -0,0 +1,396 @@
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict
import pytest
IMPORT_ERROR: str | None = None
try:
from src.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index
from src.A_memorix.core.runtime import sdk_memory_kernel as kernel_module
from src.A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel
except SystemExit as exc:
IMPORT_ERROR = f"config initialization exited during import: {exc}"
SparseBM25Config = None # type: ignore[assignment]
SparseBM25Index = None # type: ignore[assignment]
kernel_module = None # type: ignore[assignment]
SDKMemoryKernel = None # type: ignore[assignment]
pytestmark = pytest.mark.skipif(IMPORT_ERROR is not None, reason=IMPORT_ERROR or "")
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_delegates_to_metadata_store(monkeypatch: pytest.MonkeyPatch) -> None:
    """With feedback correction enabled, enqueueing is delegated to the metadata store.

    Also pins ``due_at`` to the query time plus the configured 12-hour window.
    """
    echoed_fields = ("query_tool_id", "session_id", "query_timestamp", "due_at", "query_snapshot")
    received: Dict[str, Any] = {}

    def _record_enqueue(**kwargs):
        # Remember what the kernel passed and echo it back like a stored row.
        received.update(kwargs)
        stored: Dict[str, Any] = {"id": 1}
        for field in echoed_fields:
            stored[field] = kwargs[field]
        return stored

    memory_settings = SimpleNamespace(
        feedback_correction_enabled=True,
        feedback_correction_window_hours=12.0,
    )
    monkeypatch.setattr(kernel_module, "global_config", SimpleNamespace(memory=memory_settings))

    asked_at = datetime(2026, 4, 9, 10, 30, 0)
    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=_record_enqueue)

    outcome = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-1",
        session_id="session-1",
        query_timestamp=asked_at,
        structured_content={"query": "Alice 喜欢什么", "hits": [{"hash": "relation-1"}]},
    )

    assert outcome["success"] is True
    assert outcome["queued"] is True
    assert received["query_tool_id"] == "tool-query-1"
    assert received["session_id"] == "session-1"
    snapshot = received["query_snapshot"]
    assert snapshot["query"] == "Alice 喜欢什么"
    assert snapshot["hits"] == [{"hash": "relation-1"}]
    expected_due = asked_at.timestamp() + 12 * 3600
    assert received["due_at"] == pytest.approx(expected_due, rel=0, abs=1e-6)
@pytest.mark.asyncio
async def test_kernel_enqueue_feedback_task_skipped_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
    """With feedback correction disabled, enqueueing reports failure with a reason."""
    disabled_config = SimpleNamespace(
        memory=SimpleNamespace(feedback_correction_enabled=False),
    )
    monkeypatch.setattr(kernel_module, "global_config", disabled_config)

    def _echo_enqueue(**kwargs):
        return kwargs

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(enqueue_feedback_task=_echo_enqueue)

    outcome = await kernel.enqueue_feedback_task(
        query_tool_id="tool-query-2",
        session_id="session-1",
        query_timestamp=datetime.now(),
        structured_content={"hits": [{"hash": "relation-1"}]},
    )

    assert outcome["success"] is False
    assert outcome["reason"] == "feedback_correction_disabled"
@pytest.mark.asyncio
async def test_apply_feedback_decision_resolves_paragraph_targets() -> None:
    """A decision that targets a paragraph hash is applied to its linked relation.

    Every collaborator of ``_apply_feedback_decision`` is replaced by a
    recording stub, and the test checks both the returned payload and the
    calls the stubs observed.
    """
    action_logs: list[Dict[str, Any]] = []
    forgotten_hashes: list[str] = []
    ingested_payloads: list[Dict[str, Any]] = []
    stale_marks: list[Dict[str, Any]] = []
    episode_sources: list[str] = []
    profile_refresh_ids: list[str] = []

    def _blue_relation_rows() -> list[Dict[str, Any]]:
        # Fresh objects on every call, exactly like an inline literal.
        return [
            {
                "hash": "relation-1",
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "蓝色",
            }
        ]

    # --- metadata store stubs ------------------------------------------------
    def _get_paragraph_relations(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return _blue_relation_rows()
        return []

    def _get_relation_status_batch(hashes):
        statuses = {}
        for hash_value in hashes:
            statuses[str(hash_value)] = {"is_inactive": str(hash_value) in forgotten_hashes}
        return statuses

    def _get_paragraph_hashes_by_relation_hashes(hashes):
        if "relation-1" in hashes:
            return {"relation-1": ["paragraph-1"]}
        return {}

    def _upsert_paragraph_stale_relation_mark(**kwargs):
        stale_marks.append(kwargs)
        return kwargs

    def _enqueue_episode_source_rebuild(source, reason=""):
        episode_sources.append(source)
        return True

    def _enqueue_person_profile_refresh(**kwargs):
        profile_refresh_ids.append(kwargs["person_id"])
        return kwargs

    def _get_paragraph(paragraph_hash):
        if paragraph_hash == "paragraph-1":
            return {"hash": "paragraph-1", "source": "chat_feedback_test_seed:session-1"}
        return None

    def _append_feedback_action_log(**kwargs):
        action_logs.append(kwargs)

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = SimpleNamespace(
        get_paragraph_relations=_get_paragraph_relations,
        get_relation_status_batch=_get_relation_status_batch,
        get_paragraph_hashes_by_relation_hashes=_get_paragraph_hashes_by_relation_hashes,
        upsert_paragraph_stale_relation_mark=_upsert_paragraph_stale_relation_mark,
        enqueue_episode_source_rebuild=_enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=_enqueue_person_profile_refresh,
        get_paragraph=_get_paragraph,
        append_feedback_action_log=_append_feedback_action_log,
    )

    # --- kernel method stubs -------------------------------------------------
    def _always_enabled():
        return True

    def _auto_apply_threshold():
        return 0.85

    def _apply_relation_action(*, action, hashes, strength=1.0):
        forgotten_hashes.extend([str(item) for item in hashes])
        return {"success": True, "action": action, "hashes": list(hashes), "strength": strength}

    def _related_person_ids(**kwargs):
        return ["person-1"]

    def _relation_rows_by_hashes(relation_hashes, include_inactive=False):
        return _blue_relation_rows()

    async def _fake_ingest_feedback_relations(**kwargs):
        ingested_payloads.append(kwargs)
        return {"success": True, "stored_ids": ["relation-2"]}

    kernel._feedback_cfg_auto_apply_threshold = _auto_apply_threshold  # type: ignore[method-assign]
    kernel._apply_v5_relation_action = _apply_relation_action  # type: ignore[method-assign]
    kernel._feedback_cfg_paragraph_mark_enabled = _always_enabled  # type: ignore[method-assign]
    kernel._feedback_cfg_episode_rebuild_enabled = _always_enabled  # type: ignore[method-assign]
    kernel._feedback_cfg_profile_refresh_enabled = _always_enabled  # type: ignore[method-assign]
    kernel._resolve_feedback_related_person_ids = _related_person_ids  # type: ignore[method-assign]
    kernel._query_relation_rows_by_hashes = _relation_rows_by_hashes  # type: ignore[method-assign]
    kernel._ingest_feedback_relations = _fake_ingest_feedback_relations  # type: ignore[method-assign]

    # --- exercise ------------------------------------------------------------
    decision = {
        "decision": "correct",
        "confidence": 0.97,
        "target_hashes": ["paragraph-1"],
        "corrected_relations": [
            {
                "subject": "测试用户",
                "predicate": "最喜欢的颜色是",
                "object": "绿色",
                "confidence": 0.99,
            }
        ],
        "reason": "用户明确纠正为绿色",
    }
    hit_map = {
        "paragraph-1": {
            "hash": "paragraph-1",
            "type": "paragraph",
            "content": "测试用户 最喜欢的颜色是 蓝色",
            "linked_relation_hashes": ["relation-1"],
        }
    }
    payload = await kernel._apply_feedback_decision(
        task_id=1,
        query_tool_id="tool-query-1",
        session_id="session-1",
        decision=decision,
        hit_map=hit_map,
    )

    # --- verify --------------------------------------------------------------
    assert payload["applied"] is True
    assert payload["relation_hashes"] == ["relation-1"]
    assert forgotten_hashes == ["relation-1"]
    assert ingested_payloads[0]["relation_hashes"] == ["relation-1"]
    assert payload["stale_paragraph_hashes"] == ["paragraph-1"]
    rebuild_sources = payload["episode_rebuild_sources"]
    assert "chat_feedback_test_seed:session-1" in rebuild_sources
    assert "chat_summary:session-1" in rebuild_sources
    assert payload["profile_refresh_person_ids"] == ["person-1"]
    assert stale_marks[0]["paragraph_hash"] == "paragraph-1"
    logged_action_types = {item["action_type"] for item in action_logs}
    assert logged_action_types == {
        "forget_relation",
        "ingest_correction",
        "mark_stale_paragraph",
        "enqueue_episode_rebuild",
        "enqueue_profile_refresh",
    }
def test_filter_active_relation_hits_removes_inactive_relations() -> None:
    """Only active relations and paragraphs without inactive evidence survive filtering.

    Expected survivors: the active relation, the paragraph with no relations
    or marks, and the paragraph whose stale mark points at an active relation.
    """

    def _get_relation_status_batch(hashes):
        return {
            "r-active": {"is_inactive": False},
            "r-inactive": {"is_inactive": True},
            "r-para-inactive": {"is_inactive": True},
            "r-stale-active": {"is_inactive": False},
            "r-stale-inactive": {"is_inactive": True},
        }

    def _get_paragraph_relations(paragraph_hash):
        if paragraph_hash == "p-inactive":
            return [{"hash": "r-para-inactive"}]
        return []

    def _get_paragraph_stale_relation_marks_batch(paragraph_hashes):
        return {
            "p-stale": [{"relation_hash": "r-stale-inactive"}],
            "p-restored": [{"relation_hash": "r-stale-active"}],
        }

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel._feedback_cfg_paragraph_hard_filter_enabled = lambda: True  # type: ignore[method-assign]
    kernel.metadata_store = SimpleNamespace(
        get_relation_status_batch=_get_relation_status_batch,
        get_paragraph_relations=_get_paragraph_relations,
        get_paragraph_stale_relation_marks_batch=_get_paragraph_stale_relation_marks_batch,
    )

    hit_specs = (
        ("r-active", "relation", "A 喜欢 B"),
        ("r-inactive", "relation", "A 讨厌 B"),
        ("p-1", "paragraph", "段落证据"),
        ("p-inactive", "paragraph", "失活段落证据"),
        ("p-stale", "paragraph", "被标脏段落"),
        ("p-restored", "paragraph", "恢复可见段落"),
    )
    hits = [
        {"hash": hash_value, "type": hit_type, "content": content}
        for hash_value, hit_type, content in hit_specs
    ]

    filtered = kernel._filter_active_relation_hits(hits)

    surviving_hashes = [item["hash"] for item in filtered]
    assert surviving_hashes == ["r-active", "p-1", "p-restored"]
def test_sparse_relation_search_requests_active_only() -> None:
    """``search_relations`` must ask the metadata store to exclude inactive relations."""
    seen_kwargs: Dict[str, Any] = {}

    class _RecordingStore:
        """Metadata-store stub that records the BM25 search arguments."""

        def fts_search_relations_bm25(self, **kwargs):
            seen_kwargs.update(kwargs)
            return []

    bm25_config = SparseBM25Config(enabled=True, lazy_load=False)
    index = SparseBM25Index(
        metadata_store=_RecordingStore(),  # type: ignore[arg-type]
        config=bm25_config,
    )
    # Mark the index as loaded with a placeholder connection object.
    index._loaded = True
    index._conn = object()  # type: ignore[assignment]

    hits = index.search_relations("测试纠错", k=5)

    assert hits == []
    assert seen_kwargs["include_inactive"] is False
@pytest.mark.asyncio
async def test_feedback_task_rollback_restores_snapshots_and_requeues_followups() -> None:
    """Run ``_rollback_feedback_task`` against an in-memory metadata store stub.

    The stub records every call into local lists and dicts. The assertions
    check the relation statuses, the deleted paragraph and stale mark, the
    re-queued episode source and profile, the final rollback status, and the
    action-log types that were written.
    """
    action_logs: list[Dict[str, Any]] = []
    queued_sources: list[str] = []
    queued_profiles: list[str] = []
    deleted_marks: list[tuple[str, str]] = []
    deleted_paragraphs: list[str] = []

    # Relation state as it stands before the rollback runs.
    relation_statuses: Dict[str, Dict[str, Any]] = {
        "rel-old": {
            "is_inactive": True,
            "weight": 0.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": 1.0,
        },
        "rel-new": {
            "is_inactive": False,
            "weight": 1.0,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": None,
        },
    }

    forgotten_relation = {
        "hash": "rel-old",
        "subject": "测试用户",
        "predicate": "最喜欢的颜色是",
        "object": "蓝色",
        "before_status": {
            "is_inactive": False,
            "weight": 0.8,
            "is_pinned": False,
            "protected_until": 0.0,
            "last_reinforced": None,
            "inactive_since": None,
        },
    }
    corrected_relation = {
        "hash": "rel-new",
        "subject": "测试用户",
        "predicate": "最喜欢的颜色是",
        "object": "绿色",
        "existed_before": False,
        "before_status": {},
    }
    current_task: Dict[str, Any] = {
        "id": 1,
        "query_tool_id": "tool-query-rollback",
        "session_id": "session-1",
        "status": "applied",
        "rollback_status": "none",
        "query_snapshot": {"query": "测试用户最喜欢的颜色是什么"},
        "decision_payload": {"decision": "correct", "confidence": 0.97},
        "rollback_plan": {
            "forgotten_relations": [forgotten_relation],
            "corrected_write": {
                "paragraph_hashes": ["paragraph-new"],
                "corrected_relations": [corrected_relation],
            },
            "stale_marks": [{"paragraph_hash": "paragraph-old", "relation_hash": "rel-old"}],
            "episode_sources": ["chat_summary:session-1"],
            "profile_person_ids": ["person-1"],
        },
    }

    class _Conn:
        # Connection stand-in: cursor() and execute() return the same object.
        def cursor(self):
            return self

        def execute(self, *_args, **_kwargs):
            return self

        def commit(self):
            return None

    def get_feedback_task_by_id(task_id):
        if int(task_id) == 1:
            return current_task
        return None

    def mark_feedback_task_rollback_running(**kwargs):
        current_task.update({"rollback_status": "running"})
        return current_task

    def finalize_feedback_task_rollback(**kwargs):
        current_task.update(
            {
                "rollback_status": kwargs["rollback_status"],
                "rollback_result": kwargs.get("rollback_result") or {},
                "rollback_error": kwargs.get("rollback_error", ""),
            }
        )
        return current_task

    def get_relation_status_batch(hashes):
        # Return copies so callers cannot mutate the tracked state directly.
        return {
            hash_value: dict(relation_statuses[hash_value])
            for hash_value in hashes
            if hash_value in relation_statuses
        }

    def restore_relation_status_from_snapshot(hash_value, snapshot):
        relation_statuses[hash_value] = dict(snapshot)
        return dict(snapshot)

    def append_feedback_action_log(**kwargs):
        action_logs.append(kwargs)

    def mark_as_deleted(hashes, type_):
        deleted_paragraphs.extend(list(hashes))
        return len(list(hashes))

    def get_paragraph(paragraph_hash):
        return {"hash": paragraph_hash, "source": "chat_summary:session-1"}

    def get_connection():
        return _Conn()

    def delete_external_memory_refs_by_paragraphs(hashes):
        return [
            {"paragraph_hash": str(hash_value), "external_id": f"external:{hash_value}"}
            for hash_value in hashes
        ]

    def update_relations_protection(hashes, **kwargs):
        return None

    def mark_relations_inactive(hashes, inactive_since=None):
        # Returns a list holding one None per processed hash.
        outcomes = []
        for hash_value in hashes:
            relation_statuses[hash_value] = {
                **relation_statuses.get(hash_value, {}),
                "is_inactive": True,
                "inactive_since": inactive_since,
            }
            outcomes.append(None)
        return outcomes

    def delete_paragraph_stale_relation_marks(marks):
        deleted_marks.extend(list(marks))
        return len(list(marks))

    def enqueue_episode_source_rebuild(source, reason=''):
        queued_sources.append(source)
        return True

    def enqueue_person_profile_refresh(**kwargs):
        queued_profiles.append(kwargs["person_id"])
        return kwargs

    def list_feedback_action_logs(task_id):
        if int(task_id) == 1:
            return action_logs
        return []

    metadata_store = SimpleNamespace(
        get_feedback_task_by_id=get_feedback_task_by_id,
        mark_feedback_task_rollback_running=mark_feedback_task_rollback_running,
        finalize_feedback_task_rollback=finalize_feedback_task_rollback,
        get_relation_status_batch=get_relation_status_batch,
        restore_relation_status_from_snapshot=restore_relation_status_from_snapshot,
        append_feedback_action_log=append_feedback_action_log,
        mark_as_deleted=mark_as_deleted,
        get_paragraph=get_paragraph,
        get_connection=get_connection,
        delete_external_memory_refs_by_paragraphs=delete_external_memory_refs_by_paragraphs,
        update_relations_protection=update_relations_protection,
        mark_relations_inactive=mark_relations_inactive,
        delete_paragraph_stale_relation_marks=delete_paragraph_stale_relation_marks,
        enqueue_episode_source_rebuild=enqueue_episode_source_rebuild,
        enqueue_person_profile_refresh=enqueue_person_profile_refresh,
        list_feedback_action_logs=list_feedback_action_logs,
    )

    kernel = SDKMemoryKernel(plugin_root=Path("."), config={})
    kernel.metadata_store = metadata_store

    async def _noop_initialize() -> None:
        return None

    # Replace initialize, graph rebuild and persist with no-op stand-ins.
    kernel.initialize = _noop_initialize  # type: ignore[method-assign]
    kernel._rebuild_graph_from_metadata = lambda: None  # type: ignore[method-assign]
    kernel._persist = lambda: None  # type: ignore[method-assign]

    payload = await kernel._rollback_feedback_task(
        task_id=1,
        requested_by="pytest",
        reason="manual rollback",
    )

    assert payload["success"] is True
    assert relation_statuses["rel-old"]["is_inactive"] is False
    assert relation_statuses["rel-new"]["is_inactive"] is True
    assert deleted_paragraphs == ["paragraph-new"]
    assert deleted_marks == [("paragraph-old", "rel-old")]
    assert queued_sources == ["chat_summary:session-1"]
    assert queued_profiles == ["person-1"]
    assert current_task["rollback_status"] == "rolled_back"

    logged_action_types = {item["action_type"] for item in action_logs}
    assert logged_action_types >= {
        "rollback_restore_relation",
        "rollback_revert_corrected_relation",
        "rollback_delete_correction_paragraph",
        "rollback_clear_stale_mark",
        "rollback_enqueue_episode_rebuild",
        "rollback_enqueue_profile_refresh",
    }

View File

@@ -638,3 +638,36 @@ def test_delete_operation_routes(client: TestClient, monkeypatch):
assert list_response.json()["count"] == 1
assert get_response.status_code == 200
assert get_response.json()["operation"]["operation_id"] == "del-1"
def test_feedback_correction_routes(client: TestClient, monkeypatch):
    """Exercise the list, get and rollback feedback-correction routes.

    ``memory_service.feedback_admin`` is monkeypatched with a fake that checks
    the keyword arguments for each action and returns a canned payload. The
    assertions check each response's status code and one payload field.
    """
    expected_kwargs = {
        "list": {"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
        "get": {"task_id": 11},
        "rollback": {"task_id": 11, "requested_by": "tester", "reason": "manual revert"},
    }
    canned_payloads = {
        "list": {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1},
        "get": {"success": True, "task": {"task_id": 11, "query_text": "what color", "action_logs": []}},
        "rollback": {"success": True, "result": {"restored_relation_hashes": ["rel-1"]}},
    }

    async def fake_feedback_admin(*, action: str, **kwargs):
        # An action outside the table fails the test with the action name.
        if action not in expected_kwargs:
            raise AssertionError(action)
        assert kwargs == expected_kwargs[action]
        return canned_payloads[action]

    monkeypatch.setattr(memory_router_module.memory_service, "feedback_admin", fake_feedback_admin)

    listing = client.get(
        "/api/webui/memory/feedback-corrections",
        params={"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"},
    )
    detail = client.get("/api/webui/memory/feedback-corrections/11")
    rollback = client.post(
        "/api/webui/memory/feedback-corrections/11/rollback",
        json={"requested_by": "tester", "reason": "manual revert"},
    )

    assert listing.status_code == 200
    assert listing.json()["items"][0]["task_id"] == 11
    assert detail.status_code == 200
    assert detail.json()["task"]["task_id"] == 11
    assert rollback.status_code == 200
    assert rollback.json()["result"]["restored_relation_hashes"] == ["rel-1"]