fix:收敛A_Memorix最小回归修复

最小修复聊天摘要写回游标恢复、摘要元数据透传、webui反馈参数解析、embedding批次缓存索引、图存储清理与配置默认值回归,并补齐针对性回归测试,确保问题解决且不影响现有逻辑。
This commit is contained in:
A-Dawn
2026-04-16 20:28:54 +08:00
parent 322309bef9
commit 6bfccf90a3
17 changed files with 361 additions and 60 deletions

View File

@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List
import asyncio import asyncio
import inspect import inspect
import json import json
import pickle
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlmodel import Session, create_engine from sqlmodel import Session, create_engine
@@ -394,6 +395,19 @@ async def test_text_to_stream_triggers_real_chat_summary_writeback(
assert "我最近买了一条绿色围巾。" in captured_prompts[-1] assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1] assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs) assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
assert any(
int(
(
pickle.loads(item.get("metadata"))
if isinstance(item.get("metadata"), (bytes, bytearray))
else item.get("metadata")
or {}
).get("trigger_message_count", 0)
or 0
)
== 2
for item in paragraphs
)
assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2 assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
finally: finally:
await service.shutdown() await service.shutdown()

View File

@@ -164,3 +164,28 @@ async def test_runtime_self_check_reports_requested_dimension_without_explicit_o
assert report["detected_dimension"] == 384 assert report["detected_dimension"] == 384
assert report["encoded_dimension"] == 384 assert report["encoded_dimension"] == 384
assert manager.encode_calls == ["A_Memorix runtime self check"] assert manager.encode_calls == ["A_Memorix runtime self check"]
@pytest.mark.asyncio
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
    """Regression: a cache hit from an earlier batch must not shift batch-local result indexes."""
    adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True)
    # Pre-mark the dimension as detected so encode() does not hit the real detection path.
    adapter._dimension = 4
    adapter._dimension_detected = True

    async def fake_detect_dimension() -> int:
        return 4

    async def fake_get_embedding_direct(text: str, dimensions: int | None = None):
        del dimensions
        # Deterministic per-text vector derived from the first character's code point.
        base = float(ord(str(text)[0]))
        return [base, base + 1.0, base + 2.0, base + 3.0]

    monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension)
    monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct)

    # batch_size=2 places the duplicate "A" in the second batch, where it hits the
    # cache entry written by the first batch.
    embeddings = await adapter.encode(["A", "B", "A", "C"], batch_size=2)

    assert embeddings.shape == (4, 4)
    # The cached "A" must land at its own position, identical to the first "A".
    assert np.array_equal(embeddings[0], embeddings[2])
    assert embeddings[1][0] == float(ord("B"))
    assert embeddings[3][0] == float(ord("C"))

View File

@@ -62,3 +62,21 @@ def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path
assert reloaded.num_nodes == 0 assert reloaded.num_nodes == 0
assert reloaded.num_edges == 0 assert reloaded.num_edges == 0
assert reloaded.get_nodes() == [] assert reloaded.get_nodes() == []
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
    """Regression: loading empty graph metadata must also drop a stale edge-hash map."""
    data_dir = tmp_path / "graph_data"
    store = GraphStore(data_dir=data_dir)
    store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
    store.save()

    # Overwrite the saved metadata with an "empty graph" payload that still carries
    # a leftover edge_hash_map entry, simulating stale on-disk state.
    metadata_path = data_dir / "graph_metadata.pkl"
    empty_metadata = _build_empty_graph_metadata()
    empty_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}}
    with metadata_path.open("wb") as handle:
        pickle.dump(empty_metadata, handle)

    reloaded = GraphStore(data_dir=data_dir)
    reloaded.load()
    # The stale map must be cleared along with the rest of the empty graph.
    assert reloaded.has_edge_hash_map() is False

View File

@@ -59,7 +59,16 @@ async def test_chat_summary_writeback_service_triggers_when_threshold_reached(mo
events.append(("ingest_summary", kwargs)) events.append(("ingest_summary", kwargs))
return SimpleNamespace(success=True, detail="ok") return SimpleNamespace(success=True, detail="ok")
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
del self, session_id, total_message_count
return 0
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
monkeypatch.setattr(
memory_flow_module.ChatSummaryWritebackService,
"_load_last_trigger_message_count",
fake_load_last_trigger_message_count,
)
service = memory_flow_module.ChatSummaryWritebackService() service = memory_flow_module.ChatSummaryWritebackService()
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1")) message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
@@ -100,7 +109,16 @@ async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(m
called = True called = True
return SimpleNamespace(success=True, detail="ok") return SimpleNamespace(success=True, detail="ok")
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
del self, session_id, total_message_count
return 0
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary) monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
monkeypatch.setattr(
memory_flow_module.ChatSummaryWritebackService,
"_load_last_trigger_message_count",
fake_load_last_trigger_message_count,
)
service = memory_flow_module.ChatSummaryWritebackService() service = memory_flow_module.ChatSummaryWritebackService()
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1")) message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
@@ -110,6 +128,116 @@ async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(m
assert called is False assert called is False
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
    """A restored cursor of 5 with 8 total messages (threshold 3) triggers exactly one summary."""
    events: list[tuple[str, object]] = []
    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(
            memory=SimpleNamespace(
                chat_summary_writeback_enabled=True,
                chat_summary_writeback_message_threshold=3,
                chat_summary_writeback_context_length=7,
            )
        ),
    )
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)

    async def fake_ingest_summary(**kwargs):
        events.append(("ingest_summary", kwargs))
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        # Simulate a cursor persisted before a service restart.
        return 5

    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )
    service = memory_flow_module.ChatSummaryWritebackService()
    message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))

    await service._handle_message(message)

    # 8 - 5 = 3 pending messages reaches the threshold, so one ingest fires and
    # the cursor advances to the current total count.
    assert len(events) == 1
    _, payload = events[0]
    assert payload["external_id"] == "chat_auto_summary:session-1:8"
    assert service._states["session-1"].last_trigger_message_count == 8
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
    """When the restored cursor equals the current count, no new summary is triggered."""
    called = False
    monkeypatch.setattr(
        memory_flow_module,
        "global_config",
        SimpleNamespace(
            memory=SimpleNamespace(
                chat_summary_writeback_enabled=True,
                chat_summary_writeback_message_threshold=3,
                chat_summary_writeback_context_length=7,
            )
        ),
    )
    monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)

    async def fake_ingest_summary(**kwargs):
        nonlocal called
        called = True
        return SimpleNamespace(success=True, detail="ok")

    async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
        del self, session_id, total_message_count
        # Legacy fallback path: the restored cursor already matches the current count.
        return 5

    monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
    monkeypatch.setattr(
        memory_flow_module.ChatSummaryWritebackService,
        "_load_last_trigger_message_count",
        fake_load_last_trigger_message_count,
    )
    service = memory_flow_module.ChatSummaryWritebackService()
    message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))

    await service._handle_message(message)

    # 5 - 5 = 0 pending messages is below the threshold: no ingest, but the state
    # is still aligned to the current count so nothing re-triggers later.
    assert called is False
    assert service._states["session-1"].last_trigger_message_count == 5
@pytest.mark.asyncio
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
    """The restore helper must pick trigger_message_count from the newest summary."""
    class FakeMetadataStore:
        @staticmethod
        def get_paragraphs_by_source(source: str):
            assert source == "chat_summary:session-1"
            # Two summaries: the one with the larger created_at should win.
            return [
                {"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
                {"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
            ]

    class FakeRuntimeManager:
        @staticmethod
        async def _ensure_kernel():
            return SimpleNamespace(metadata_store=FakeMetadataStore())

    monkeypatch.setattr(memory_flow_module.memory_service_module, "a_memorix_host_service", FakeRuntimeManager())
    service = memory_flow_module.ChatSummaryWritebackService()
    restored = await service._load_last_trigger_message_count(session_id="session-1", total_message_count=8)
    # Newest summary (created_at=2.0) carries trigger_message_count=6.
    assert restored == 6
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_memory_automation_service_auto_starts_and_delegates(): async def test_memory_automation_service_auto_starts_and_delegates():
events: list[tuple[str, str]] = [] events: list[tuple[str, str]] = []

View File

@@ -82,6 +82,7 @@ def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tm
def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None: def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
dashboard_dist = tmp_path / "dashboard" / "dist" dashboard_dist = tmp_path / "dashboard" / "dist"
dashboard_dist.mkdir(parents=True) dashboard_dist.mkdir(parents=True)
(dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path) monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
@@ -91,6 +92,26 @@ def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
assert resolved_path == dashboard_dist assert resolved_path == dashboard_dist
def test_resolve_static_path_falls_back_to_package_when_dashboard_dist_has_no_index(monkeypatch, tmp_path) -> None:
    """A dashboard/dist without index.html must be skipped in favor of the installed package dist."""
    dashboard_dist = tmp_path / "dashboard" / "dist"
    dashboard_dist.mkdir(parents=True)  # intentionally left without an index.html
    package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
    package_dist.mkdir(parents=True)

    class _DashboardModule:
        @staticmethod
        def get_dist_path() -> Path:
            return package_dist

    monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
    with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
        resolved_path = webui_app._resolve_static_path()
    assert resolved_path == package_dist
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None: def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
static_path = tmp_path / "dist" static_path = tmp_path / "dist"
asset_path = static_path / "assets" / "app.js" asset_path = static_path / "assets" / "app.js"

View File

@@ -643,7 +643,12 @@ def test_delete_operation_routes(client: TestClient, monkeypatch):
def test_feedback_correction_routes(client: TestClient, monkeypatch): def test_feedback_correction_routes(client: TestClient, monkeypatch):
async def fake_feedback_admin(*, action: str, **kwargs): async def fake_feedback_admin(*, action: str, **kwargs):
if action == "list": if action == "list":
assert kwargs == {"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"} assert kwargs == {
"limit": 7,
"statuses": ["applied"],
"rollback_statuses": ["none"],
"query": "green",
}
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1} return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
if action == "get": if action == "get":
assert kwargs == {"task_id": 11} assert kwargs == {"task_id": 11}

View File

@@ -385,19 +385,19 @@ class EmbeddingAPIAdapter:
semaphore = asyncio.Semaphore(self.max_concurrent) semaphore = asyncio.Semaphore(self.max_concurrent)
async def encode_with_semaphore(text: str, index: int): async def encode_with_semaphore(text: str, batch_index: int, absolute_index: int):
async with semaphore: async with semaphore:
embedding = await self._get_embedding_direct(text, dimensions=dimensions) embedding = await self._get_embedding_direct(text, dimensions=dimensions)
if embedding is None: if embedding is None:
raise RuntimeError(f"文本 {index} 编码失败embedding 返回为空") raise RuntimeError(f"文本 {absolute_index} 编码失败embedding 返回为空")
vector = self._validate_embedding_vector( vector = self._validate_embedding_vector(
embedding, embedding,
source=f"文本 {index}", source=f"文本 {absolute_index}",
) )
return index, vector return batch_index, vector
tasks = [ tasks = [
encode_with_semaphore(text, offset + index) encode_with_semaphore(text, index, offset + index)
for index, text in uncached_items for index, text in uncached_items
] ]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)

View File

@@ -803,6 +803,7 @@ class SDKMemoryKernel:
context_length: Optional[int] = None, context_length: Optional[int] = None,
include_personality: Optional[bool] = None, include_personality: Optional[bool] = None,
time_end: Optional[float] = None, time_end: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
await self.initialize() await self.initialize()
assert self.summary_importer assert self.summary_importer
@@ -811,6 +812,7 @@ class SDKMemoryKernel:
context_length=context_length, context_length=context_length,
include_personality=include_personality, include_personality=include_personality,
time_end=time_end, time_end=time_end,
metadata=metadata,
) )
if success: if success:
await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])]) await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])])
@@ -854,6 +856,12 @@ class SDKMemoryKernel:
context_length=self._optional_int(summary_meta.get("context_length")), context_length=self._optional_int(summary_meta.get("context_length")),
include_personality=summary_meta.get("include_personality"), include_personality=summary_meta.get("include_personality"),
time_end=time_end, time_end=time_end,
metadata={
**summary_meta,
"external_id": external_token,
"chat_id": str(chat_id or "").strip(),
"source_type": "chat_summary",
},
) )
result.setdefault("external_id", external_id) result.setdefault("external_id", external_id)
result.setdefault("chat_id", chat_id) result.setdefault("chat_id", chat_id)

View File

@@ -1294,6 +1294,7 @@ class GraphStore:
if current_n == 0: if current_n == 0:
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。") logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
self._adjacency = None self._adjacency = None
self._edge_hash_map = defaultdict(set)
elif current_n > adj_n: elif current_n > adj_n:
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...") logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
self._expand_adjacency_matrix(current_n - adj_n) self._expand_adjacency_matrix(current_n - adj_n)
@@ -1305,6 +1306,14 @@ class GraphStore:
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32) self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
else: else:
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32) self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
self._edge_hash_map = defaultdict(
set,
{
(src_idx, dst_idx): set(hashes)
for (src_idx, dst_idx), hashes in self._edge_hash_map.items()
if src_idx < current_n and dst_idx < current_n
},
)
self._adjacency_dirty = True self._adjacency_dirty = True
logger.info( logger.info(

View File

@@ -178,6 +178,27 @@ class MetadataStore:
int(knowledge_type_result.get("normalized", 0) or 0), int(knowledge_type_result.get("normalized", 0) or 0),
) )
def _ensure_memory_feedback_task_columns(self, cursor: sqlite3.Cursor) -> None:
    """Backfill rollback_* columns missing from legacy memory_feedback_tasks databases.

    Idempotent: inspects the live schema via PRAGMA table_info and only issues
    ALTER TABLE for columns that are absent, so it is safe to run on every startup.
    """
    cursor.execute("PRAGMA table_info(memory_feedback_tasks)")
    # PRAGMA table_info row layout: (cid, name, type, notnull, dflt_value, pk).
    feedback_task_columns = {row[1] for row in cursor.fetchall()}
    feedback_task_migrations = {
        "rollback_status": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_status TEXT DEFAULT 'none'",
        "rollback_plan_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_plan_json TEXT",
        "rollback_result_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_result_json TEXT",
        "rollback_error": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_error TEXT",
        "rollback_requested_by": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_by TEXT",
        "rollback_reason": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_reason TEXT",
        "rollback_requested_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_at REAL",
        "rolled_back_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rolled_back_at REAL",
    }
    for col, sql in feedback_task_migrations.items():
        if col not in feedback_task_columns:
            try:
                cursor.execute(sql)
            except sqlite3.OperationalError as e:
                # Best-effort migration: log and continue so one failed column
                # does not abort startup.
                logger.warning(f"Schema迁移失败 (memory_feedback_tasks.{col}): {e}")
def close(self) -> None: def close(self) -> None:
"""关闭数据库连接""" """关闭数据库连接"""
if self._conn: if self._conn:
@@ -641,24 +662,7 @@ class MetadataStore:
CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested
ON person_profile_refresh_queue(requested_at DESC) ON person_profile_refresh_queue(requested_at DESC)
""") """)
cursor.execute("PRAGMA table_info(memory_feedback_tasks)") self._ensure_memory_feedback_task_columns(cursor)
feedback_task_columns = {row[1] for row in cursor.fetchall()}
feedback_task_migrations = {
"rollback_status": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_status TEXT DEFAULT 'none'",
"rollback_plan_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_plan_json TEXT",
"rollback_result_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_result_json TEXT",
"rollback_error": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_error TEXT",
"rollback_requested_by": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_by TEXT",
"rollback_reason": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_reason TEXT",
"rollback_requested_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_at REAL",
"rolled_back_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rolled_back_at REAL",
}
for col, sql in feedback_task_migrations.items():
if col not in feedback_task_columns:
try:
cursor.execute(sql)
except sqlite3.OperationalError as e:
logger.warning(f"Schema迁移失败 (memory_feedback_tasks.{col}): {e}")
cursor.execute(""" cursor.execute("""
CREATE TABLE IF NOT EXISTS external_memory_refs ( CREATE TABLE IF NOT EXISTS external_memory_refs (
external_id TEXT PRIMARY KEY, external_id TEXT PRIMARY KEY,
@@ -953,6 +957,7 @@ class MetadataStore:
CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested
ON person_profile_refresh_queue(requested_at DESC) ON person_profile_refresh_queue(requested_at DESC)
""") """)
self._ensure_memory_feedback_task_columns(cursor)
cursor.execute(""" cursor.execute("""
CREATE TABLE IF NOT EXISTS external_memory_refs ( CREATE TABLE IF NOT EXISTS external_memory_refs (
external_id TEXT PRIMARY KEY, external_id TEXT PRIMARY KEY,
@@ -3945,6 +3950,8 @@ class MetadataStore:
"episodes", "episode_paragraphs", "episodes", "episode_paragraphs",
"episode_rebuild_sources", "episode_pending_paragraphs", "episode_rebuild_sources", "episode_pending_paragraphs",
"paragraph_vector_backfill", "paragraph_vector_backfill",
"memory_feedback_tasks", "memory_feedback_action_logs",
"paragraph_stale_relation_marks", "person_profile_refresh_queue",
] ]
for table in tables: for table in tables:
cursor.execute(f"DELETE FROM {table}") cursor.execute(f"DELETE FROM {table}")

View File

@@ -225,6 +225,7 @@ class SummaryImporter:
context_length: Optional[int] = None, context_length: Optional[int] = None,
include_personality: Optional[bool] = None, include_personality: Optional[bool] = None,
time_end: Optional[float] = None, time_end: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
""" """
从指定的聊天流中提取记录并执行总结导入 从指定的聊天流中提取记录并执行总结导入
@@ -327,7 +328,14 @@ class SummaryImporter:
} }
# 6. 执行导入 # 6. 执行导入
await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta) await self._execute_import(
summary_text,
entities,
relations,
stream_id,
time_meta=time_meta,
metadata=metadata,
)
# 7. 持久化 # 7. 持久化
self.vector_store.save() self.vector_store.save()
@@ -393,6 +401,7 @@ class SummaryImporter:
relations: List[Dict[str, str]], relations: List[Dict[str, str]],
stream_id: str, stream_id: str,
time_meta: Optional[Dict[str, Any]] = None, time_meta: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
): ):
"""将数据写入存储""" """将数据写入存储"""
# 获取默认知识类型 # 获取默认知识类型
@@ -407,6 +416,7 @@ class SummaryImporter:
hash_value = self.metadata_store.add_paragraph( hash_value = self.metadata_store.add_paragraph(
content=summary, content=summary,
source=f"chat_summary:{stream_id}", source=f"chat_summary:{stream_id}",
metadata=metadata,
knowledge_type=knowledge_type.value, knowledge_type=knowledge_type.value,
time_meta=time_meta, time_meta=time_meta,
) )

View File

@@ -375,30 +375,6 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
"memory_feedback_tasks rollback columns missing under current schema version", "memory_feedback_tasks rollback columns missing under current schema version",
) )
) )
elif not has_stale_marks:
checks.append(
CheckItem(
"CP-15",
"error",
"paragraph_stale_relation_marks table missing under current schema version",
)
)
elif not has_profile_refresh_queue:
checks.append(
CheckItem(
"CP-16",
"error",
"person_profile_refresh_queue table missing under current schema version",
)
)
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
checks.append(
CheckItem(
"CP-17",
"error",
"memory_feedback_tasks rollback columns missing under current schema version",
)
)
if _sqlite_table_exists(conn, "relations"): if _sqlite_table_exists(conn, "relations"):
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone() row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()

View File

@@ -146,7 +146,7 @@ class VisualConfig(ConfigBase):
__ui_icon__ = "image" __ui_icon__ = "image"
planner_mode: Literal["text", "multimodal", "auto"] = Field( planner_mode: Literal["text", "multimodal", "auto"] = Field(
default="text", default="auto",
json_schema_extra={ json_schema_extra={
"x-widget": "select", "x-widget": "select",
"x-icon": "image", "x-icon": "image",
@@ -155,7 +155,7 @@ class VisualConfig(ConfigBase):
"""Planner 视觉模式text 仅文本multimodal 强制多模态auto 按模型能力自动选择""" """Planner 视觉模式text 仅文本multimodal 强制多模态auto 按模型能力自动选择"""
replyer_mode: Literal["text", "multimodal", "auto"] = Field( replyer_mode: Literal["text", "multimodal", "auto"] = Field(
default="text", default="auto",
json_schema_extra={ json_schema_extra={
"x-widget": "select", "x-widget": "select",
"x-icon": "git-branch", "x-icon": "git-branch",

View File

@@ -1134,7 +1134,14 @@ class MaisakaReasoningEngine:
tool_name=invocation.tool_name, tool_name=invocation.tool_name,
tool_reasoning=invocation.reasoning, tool_reasoning=invocation.reasoning,
) )
if invocation.tool_name == "query_memory" and isinstance(saved_record, dict): except Exception:
logger.exception(
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
)
return
if invocation.tool_name == "query_memory" and isinstance(saved_record, dict):
try:
enqueue_payload = await memory_service.enqueue_feedback_task( enqueue_payload = await memory_service.enqueue_feedback_task(
query_tool_id=str(saved_record.get("tool_id") or invocation.call_id or "").strip(), query_tool_id=str(saved_record.get("tool_id") or invocation.call_id or "").strip(),
session_id=str(saved_record.get("session_id") or self._runtime.chat_stream.session_id or "").strip(), session_id=str(saved_record.get("session_id") or self._runtime.chat_stream.session_id or "").strip(),
@@ -1143,15 +1150,16 @@ class MaisakaReasoningEngine:
if isinstance(tool_record_payload.get("structured_content"), dict) if isinstance(tool_record_payload.get("structured_content"), dict)
else {}, else {},
) )
except Exception:
logger.exception(
f"{self._runtime.log_prefix} 反馈纠错任务入队失败: tool_call_id={invocation.call_id}"
)
else:
if not bool(enqueue_payload.get("success")): if not bool(enqueue_payload.get("success")):
logger.debug( logger.debug(
f"{self._runtime.log_prefix} 反馈纠错任务未入队: " f"{self._runtime.log_prefix} 反馈纠错任务未入队: "
f"tool_call_id={invocation.call_id} reason={enqueue_payload.get('reason', '')}" f"tool_call_id={invocation.call_id} reason={enqueue_payload.get('reason', '')}"
) )
except Exception:
logger.exception(
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
)
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None: def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
"""将统一工具执行结果写回 Maisaka 历史。 """将统一工具执行结果写回 Maisaka 历史。

View File

@@ -6,10 +6,12 @@ from typing import Any, List, Optional
import asyncio import asyncio
import json import json
import pickle
import time import time
from json_repair import repair_json from json_repair import repair_json
from src.services import memory_service as memory_service_module
from src.chat.utils.utils import is_bot_self from src.chat.utils.utils import is_bot_self
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages from src.common.message_repository import count_messages, find_messages
@@ -281,7 +283,17 @@ class ChatSummaryWritebackService:
return return
threshold = self._message_threshold() threshold = self._message_threshold()
state = self._states.setdefault(session_id, ChatSummaryWritebackState()) state = self._states.get(session_id)
if state is None:
restored_count = await self._load_last_trigger_message_count(
session_id=session_id,
total_message_count=total_message_count,
)
state = ChatSummaryWritebackState(
last_trigger_message_count=restored_count,
last_trigger_time=time.time() if restored_count > 0 else 0.0,
)
self._states[session_id] = state
pending_message_count = max(0, total_message_count - state.last_trigger_message_count) pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
if pending_message_count < threshold: if pending_message_count < threshold:
return return
@@ -324,6 +336,64 @@ class ChatSummaryWritebackService:
getattr(result, "detail", ""), getattr(result, "detail", ""),
) )
async def _load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
    """Restore the writeback trigger cursor from persisted chat summaries.

    Avoids duplicate summaries after a service restart. Returns 0 (fresh cursor)
    whenever the runtime kernel, metadata store, or stored summaries are
    unavailable; any unexpected error is swallowed and logged at debug level so
    restoration stays best-effort.
    """
    try:
        # The host service and its kernel accessor may be absent in some deployments.
        runtime_manager = getattr(memory_service_module, "a_memorix_host_service", None)
        ensure_kernel = getattr(runtime_manager, "_ensure_kernel", None)
        if not callable(ensure_kernel):
            return 0
        kernel = await ensure_kernel()
        metadata_store = getattr(kernel, "metadata_store", None)
        if metadata_store is None:
            return 0
        paragraphs = metadata_store.get_paragraphs_by_source(f"chat_summary:{session_id}")
        if not paragraphs:
            return 0
        # Use the most recently created summary as the authoritative cursor source.
        latest_paragraph = max(paragraphs, key=self._paragraph_created_at)
        metadata = self._paragraph_metadata(latest_paragraph)
        trigger_message_count = self._coerce_positive_int(metadata.get("trigger_message_count"))
        if trigger_message_count > 0:
            # Clamp to the current total so the cursor never runs ahead of reality.
            return min(total_message_count, trigger_message_count)
        # Legacy summary data has no trigger count; the best we can do is align
        # with the current count so a restart does not immediately write a
        # near-duplicate summary.
        return total_message_count
    except Exception as exc:
        logger.debug("恢复聊天摘要写回游标失败: session_id=%s error=%s", session_id, exc)
        return 0
@staticmethod
def _paragraph_created_at(paragraph: dict[str, Any]) -> float:
try:
return float(paragraph.get("created_at") or 0.0)
except Exception:
return 0.0
@staticmethod
def _paragraph_metadata(paragraph: dict[str, Any]) -> dict[str, Any]:
metadata = paragraph.get("metadata")
if isinstance(metadata, dict):
return metadata
if isinstance(metadata, (bytes, bytearray)):
try:
parsed = pickle.loads(metadata)
except Exception:
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
@staticmethod
def _coerce_positive_int(value: Any) -> int:
try:
number = int(value or 0)
except Exception:
return 0
return max(0, number)
@staticmethod @staticmethod
def _resolve_session_id(message: Any) -> str: def _resolve_session_id(message: Any) -> str:
return str( return str(

View File

@@ -208,7 +208,7 @@ def _resolve_static_path() -> Path | None:
# 开发环境优先允许复用仓库里的现成 dist # 开发环境优先允许复用仓库里的现成 dist
base_dir = _get_project_root() base_dir = _get_project_root()
static_path = base_dir / "dashboard" / "dist" static_path = base_dir / "dashboard" / "dist"
if static_path.exists(): if static_path.is_dir() and (static_path / "index.html").exists():
return static_path return static_path
try: try:

View File

@@ -365,11 +365,13 @@ async def _profile_delete_override(person_id: str) -> dict:
async def _feedback_list(limit: int, status: str, rollback_status: str, query: str) -> dict: async def _feedback_list(limit: int, status: str, rollback_status: str, query: str) -> dict:
statuses = [item.strip() for item in str(status or "").split(",") if item.strip()]
rollback_statuses = [item.strip() for item in str(rollback_status or "").split(",") if item.strip()]
return await memory_service.feedback_admin( return await memory_service.feedback_admin(
action="list", action="list",
limit=limit, limit=limit,
status=status, statuses=statuses,
rollback_status=rollback_status, rollback_statuses=rollback_statuses,
query=query, query=query,
) )