fix:收敛A_Memorix最小回归修复
最小修复聊天摘要写回游标恢复、摘要元数据透传、webui反馈参数解析、embedding批次缓存索引、图存储清理与配置默认值回归,并补齐针对性回归测试,确保问题解决且不影响现有逻辑。
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import pickle
|
||||||
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlmodel import Session, create_engine
|
from sqlmodel import Session, create_engine
|
||||||
@@ -394,6 +395,19 @@ async def test_text_to_stream_triggers_real_chat_summary_writeback(
|
|||||||
assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
|
assert "我最近买了一条绿色围巾。" in captured_prompts[-1]
|
||||||
assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
|
assert "好的,我会记住你最近买了绿色围巾。" in captured_prompts[-1]
|
||||||
assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
|
assert any("绿色围巾" in str(item.get("content", "") or "") for item in paragraphs)
|
||||||
|
assert any(
|
||||||
|
int(
|
||||||
|
(
|
||||||
|
pickle.loads(item.get("metadata"))
|
||||||
|
if isinstance(item.get("metadata"), (bytes, bytearray))
|
||||||
|
else item.get("metadata")
|
||||||
|
or {}
|
||||||
|
).get("trigger_message_count", 0)
|
||||||
|
or 0
|
||||||
|
)
|
||||||
|
== 2
|
||||||
|
for item in paragraphs
|
||||||
|
)
|
||||||
assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
|
assert service.chat_summary_writeback._states["test-session"].last_trigger_message_count == 2
|
||||||
finally:
|
finally:
|
||||||
await service.shutdown()
|
await service.shutdown()
|
||||||
|
|||||||
@@ -164,3 +164,28 @@ async def test_runtime_self_check_reports_requested_dimension_without_explicit_o
|
|||||||
assert report["detected_dimension"] == 384
|
assert report["detected_dimension"] == 384
|
||||||
assert report["encoded_dimension"] == 384
|
assert report["encoded_dimension"] == 384
|
||||||
assert manager.encode_calls == ["A_Memorix runtime self check"]
|
assert manager.encode_calls == ["A_Memorix runtime self check"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_encode_batch_keeps_batch_local_indexes_when_cache_hits_previous_batch(monkeypatch):
|
||||||
|
adapter = EmbeddingAPIAdapter(default_dimension=4, enable_cache=True)
|
||||||
|
adapter._dimension = 4
|
||||||
|
adapter._dimension_detected = True
|
||||||
|
|
||||||
|
async def fake_detect_dimension() -> int:
|
||||||
|
return 4
|
||||||
|
|
||||||
|
async def fake_get_embedding_direct(text: str, dimensions: int | None = None):
|
||||||
|
del dimensions
|
||||||
|
base = float(ord(str(text)[0]))
|
||||||
|
return [base, base + 1.0, base + 2.0, base + 3.0]
|
||||||
|
|
||||||
|
monkeypatch.setattr(adapter, "_detect_dimension", fake_detect_dimension)
|
||||||
|
monkeypatch.setattr(adapter, "_get_embedding_direct", fake_get_embedding_direct)
|
||||||
|
|
||||||
|
embeddings = await adapter.encode(["A", "B", "A", "C"], batch_size=2)
|
||||||
|
|
||||||
|
assert embeddings.shape == (4, 4)
|
||||||
|
assert np.array_equal(embeddings[0], embeddings[2])
|
||||||
|
assert embeddings[1][0] == float(ord("B"))
|
||||||
|
assert embeddings[3][0] == float(ord("C"))
|
||||||
|
|||||||
@@ -62,3 +62,21 @@ def test_graph_store_load_resets_stale_adjacency_when_metadata_is_empty(tmp_path
|
|||||||
assert reloaded.num_nodes == 0
|
assert reloaded.num_nodes == 0
|
||||||
assert reloaded.num_edges == 0
|
assert reloaded.num_edges == 0
|
||||||
assert reloaded.get_nodes() == []
|
assert reloaded.get_nodes() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_graph_store_load_clears_stale_edge_hash_map_when_metadata_is_empty(tmp_path: Path) -> None:
|
||||||
|
data_dir = tmp_path / "graph_data"
|
||||||
|
store = GraphStore(data_dir=data_dir)
|
||||||
|
store.add_edges([("Alice", "Bob")], relation_hashes=["rel-1"])
|
||||||
|
store.save()
|
||||||
|
|
||||||
|
metadata_path = data_dir / "graph_metadata.pkl"
|
||||||
|
empty_metadata = _build_empty_graph_metadata()
|
||||||
|
empty_metadata["edge_hash_map"] = {(0, 1): {"rel-1"}}
|
||||||
|
with metadata_path.open("wb") as handle:
|
||||||
|
pickle.dump(empty_metadata, handle)
|
||||||
|
|
||||||
|
reloaded = GraphStore(data_dir=data_dir)
|
||||||
|
reloaded.load()
|
||||||
|
|
||||||
|
assert reloaded.has_edge_hash_map() is False
|
||||||
|
|||||||
@@ -59,7 +59,16 @@ async def test_chat_summary_writeback_service_triggers_when_threshold_reached(mo
|
|||||||
events.append(("ingest_summary", kwargs))
|
events.append(("ingest_summary", kwargs))
|
||||||
return SimpleNamespace(success=True, detail="ok")
|
return SimpleNamespace(success=True, detail="ok")
|
||||||
|
|
||||||
|
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||||
|
del self, session_id, total_message_count
|
||||||
|
return 0
|
||||||
|
|
||||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
memory_flow_module.ChatSummaryWritebackService,
|
||||||
|
"_load_last_trigger_message_count",
|
||||||
|
fake_load_last_trigger_message_count,
|
||||||
|
)
|
||||||
|
|
||||||
service = memory_flow_module.ChatSummaryWritebackService()
|
service = memory_flow_module.ChatSummaryWritebackService()
|
||||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||||
@@ -100,7 +109,16 @@ async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(m
|
|||||||
called = True
|
called = True
|
||||||
return SimpleNamespace(success=True, detail="ok")
|
return SimpleNamespace(success=True, detail="ok")
|
||||||
|
|
||||||
|
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||||
|
del self, session_id, total_message_count
|
||||||
|
return 0
|
||||||
|
|
||||||
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
memory_flow_module.ChatSummaryWritebackService,
|
||||||
|
"_load_last_trigger_message_count",
|
||||||
|
fake_load_last_trigger_message_count,
|
||||||
|
)
|
||||||
|
|
||||||
service = memory_flow_module.ChatSummaryWritebackService()
|
service = memory_flow_module.ChatSummaryWritebackService()
|
||||||
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||||
@@ -110,6 +128,116 @@ async def test_chat_summary_writeback_service_skips_when_threshold_not_reached(m
|
|||||||
assert called is False
|
assert called is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_summary_writeback_service_restores_previous_trigger_count(monkeypatch):
|
||||||
|
events: list[tuple[str, object]] = []
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
memory_flow_module,
|
||||||
|
"global_config",
|
||||||
|
SimpleNamespace(
|
||||||
|
memory=SimpleNamespace(
|
||||||
|
chat_summary_writeback_enabled=True,
|
||||||
|
chat_summary_writeback_message_threshold=3,
|
||||||
|
chat_summary_writeback_context_length=7,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 8)
|
||||||
|
|
||||||
|
async def fake_ingest_summary(**kwargs):
|
||||||
|
events.append(("ingest_summary", kwargs))
|
||||||
|
return SimpleNamespace(success=True, detail="ok")
|
||||||
|
|
||||||
|
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||||
|
del self, session_id, total_message_count
|
||||||
|
return 5
|
||||||
|
|
||||||
|
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
memory_flow_module.ChatSummaryWritebackService,
|
||||||
|
"_load_last_trigger_message_count",
|
||||||
|
fake_load_last_trigger_message_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
service = memory_flow_module.ChatSummaryWritebackService()
|
||||||
|
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||||
|
|
||||||
|
await service._handle_message(message)
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
_, payload = events[0]
|
||||||
|
assert payload["external_id"] == "chat_auto_summary:session-1:8"
|
||||||
|
assert service._states["session-1"].last_trigger_message_count == 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_summary_writeback_service_falls_back_to_current_count_for_legacy_summary(monkeypatch):
|
||||||
|
called = False
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
memory_flow_module,
|
||||||
|
"global_config",
|
||||||
|
SimpleNamespace(
|
||||||
|
memory=SimpleNamespace(
|
||||||
|
chat_summary_writeback_enabled=True,
|
||||||
|
chat_summary_writeback_message_threshold=3,
|
||||||
|
chat_summary_writeback_context_length=7,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(memory_flow_module, "count_messages", lambda **kwargs: 5)
|
||||||
|
|
||||||
|
async def fake_ingest_summary(**kwargs):
|
||||||
|
nonlocal called
|
||||||
|
called = True
|
||||||
|
return SimpleNamespace(success=True, detail="ok")
|
||||||
|
|
||||||
|
async def fake_load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
|
||||||
|
del self, session_id, total_message_count
|
||||||
|
return 5
|
||||||
|
|
||||||
|
monkeypatch.setattr(memory_flow_module.memory_service, "ingest_summary", fake_ingest_summary)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
memory_flow_module.ChatSummaryWritebackService,
|
||||||
|
"_load_last_trigger_message_count",
|
||||||
|
fake_load_last_trigger_message_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
service = memory_flow_module.ChatSummaryWritebackService()
|
||||||
|
message = SimpleNamespace(session_id="session-1", session=SimpleNamespace(user_id="user-1", group_id="group-1"))
|
||||||
|
|
||||||
|
await service._handle_message(message)
|
||||||
|
|
||||||
|
assert called is False
|
||||||
|
assert service._states["session-1"].last_trigger_message_count == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_summary_writeback_service_loads_trigger_count_from_summary_metadata(monkeypatch):
|
||||||
|
class FakeMetadataStore:
|
||||||
|
@staticmethod
|
||||||
|
def get_paragraphs_by_source(source: str):
|
||||||
|
assert source == "chat_summary:session-1"
|
||||||
|
return [
|
||||||
|
{"created_at": 1.0, "metadata": {"trigger_message_count": 3}},
|
||||||
|
{"created_at": 2.0, "metadata": {"trigger_message_count": 6}},
|
||||||
|
]
|
||||||
|
|
||||||
|
class FakeRuntimeManager:
|
||||||
|
@staticmethod
|
||||||
|
async def _ensure_kernel():
|
||||||
|
return SimpleNamespace(metadata_store=FakeMetadataStore())
|
||||||
|
|
||||||
|
monkeypatch.setattr(memory_flow_module.memory_service_module, "a_memorix_host_service", FakeRuntimeManager())
|
||||||
|
|
||||||
|
service = memory_flow_module.ChatSummaryWritebackService()
|
||||||
|
|
||||||
|
restored = await service._load_last_trigger_message_count(session_id="session-1", total_message_count=8)
|
||||||
|
|
||||||
|
assert restored == 6
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_memory_automation_service_auto_starts_and_delegates():
|
async def test_memory_automation_service_auto_starts_and_delegates():
|
||||||
events: list[tuple[str, str]] = []
|
events: list[tuple[str, str]] = []
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ def test_resolve_static_path_prefers_installed_dashboard_package(monkeypatch, tm
|
|||||||
def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
|
def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
|
||||||
dashboard_dist = tmp_path / "dashboard" / "dist"
|
dashboard_dist = tmp_path / "dashboard" / "dist"
|
||||||
dashboard_dist.mkdir(parents=True)
|
dashboard_dist.mkdir(parents=True)
|
||||||
|
(dashboard_dist / "index.html").write_text("<html></html>", encoding="utf-8")
|
||||||
|
|
||||||
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
||||||
|
|
||||||
@@ -91,6 +92,26 @@ def test_resolve_static_path_uses_dashboard_dist(monkeypatch, tmp_path) -> None:
|
|||||||
assert resolved_path == dashboard_dist
|
assert resolved_path == dashboard_dist
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_static_path_falls_back_to_package_when_dashboard_dist_has_no_index(monkeypatch, tmp_path) -> None:
|
||||||
|
dashboard_dist = tmp_path / "dashboard" / "dist"
|
||||||
|
dashboard_dist.mkdir(parents=True)
|
||||||
|
|
||||||
|
package_dist = tmp_path / "site-packages" / "maibot_dashboard" / "dist"
|
||||||
|
package_dist.mkdir(parents=True)
|
||||||
|
|
||||||
|
class _DashboardModule:
|
||||||
|
@staticmethod
|
||||||
|
def get_dist_path() -> Path:
|
||||||
|
return package_dist
|
||||||
|
|
||||||
|
monkeypatch.setattr(webui_app, "_get_project_root", lambda: tmp_path)
|
||||||
|
|
||||||
|
with patch.object(webui_app, "import_module", return_value=_DashboardModule()):
|
||||||
|
resolved_path = webui_app._resolve_static_path()
|
||||||
|
|
||||||
|
assert resolved_path == package_dist
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
|
def test_resolve_safe_static_file_path_allows_regular_static_file(tmp_path) -> None:
|
||||||
static_path = tmp_path / "dist"
|
static_path = tmp_path / "dist"
|
||||||
asset_path = static_path / "assets" / "app.js"
|
asset_path = static_path / "assets" / "app.js"
|
||||||
|
|||||||
@@ -643,7 +643,12 @@ def test_delete_operation_routes(client: TestClient, monkeypatch):
|
|||||||
def test_feedback_correction_routes(client: TestClient, monkeypatch):
|
def test_feedback_correction_routes(client: TestClient, monkeypatch):
|
||||||
async def fake_feedback_admin(*, action: str, **kwargs):
|
async def fake_feedback_admin(*, action: str, **kwargs):
|
||||||
if action == "list":
|
if action == "list":
|
||||||
assert kwargs == {"limit": 7, "status": "applied", "rollback_status": "none", "query": "green"}
|
assert kwargs == {
|
||||||
|
"limit": 7,
|
||||||
|
"statuses": ["applied"],
|
||||||
|
"rollback_statuses": ["none"],
|
||||||
|
"query": "green",
|
||||||
|
}
|
||||||
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
|
return {"success": True, "items": [{"task_id": 11, "query_text": "what color"}], "count": 1}
|
||||||
if action == "get":
|
if action == "get":
|
||||||
assert kwargs == {"task_id": 11}
|
assert kwargs == {"task_id": 11}
|
||||||
|
|||||||
@@ -385,19 +385,19 @@ class EmbeddingAPIAdapter:
|
|||||||
|
|
||||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||||
|
|
||||||
async def encode_with_semaphore(text: str, index: int):
|
async def encode_with_semaphore(text: str, batch_index: int, absolute_index: int):
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
embedding = await self._get_embedding_direct(text, dimensions=dimensions)
|
embedding = await self._get_embedding_direct(text, dimensions=dimensions)
|
||||||
if embedding is None:
|
if embedding is None:
|
||||||
raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空")
|
raise RuntimeError(f"文本 {absolute_index} 编码失败:embedding 返回为空")
|
||||||
vector = self._validate_embedding_vector(
|
vector = self._validate_embedding_vector(
|
||||||
embedding,
|
embedding,
|
||||||
source=f"文本 {index}",
|
source=f"文本 {absolute_index}",
|
||||||
)
|
)
|
||||||
return index, vector
|
return batch_index, vector
|
||||||
|
|
||||||
tasks = [
|
tasks = [
|
||||||
encode_with_semaphore(text, offset + index)
|
encode_with_semaphore(text, index, offset + index)
|
||||||
for index, text in uncached_items
|
for index, text in uncached_items
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|||||||
@@ -803,6 +803,7 @@ class SDKMemoryKernel:
|
|||||||
context_length: Optional[int] = None,
|
context_length: Optional[int] = None,
|
||||||
include_personality: Optional[bool] = None,
|
include_personality: Optional[bool] = None,
|
||||||
time_end: Optional[float] = None,
|
time_end: Optional[float] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
await self.initialize()
|
await self.initialize()
|
||||||
assert self.summary_importer
|
assert self.summary_importer
|
||||||
@@ -811,6 +812,7 @@ class SDKMemoryKernel:
|
|||||||
context_length=context_length,
|
context_length=context_length,
|
||||||
include_personality=include_personality,
|
include_personality=include_personality,
|
||||||
time_end=time_end,
|
time_end=time_end,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])])
|
await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])])
|
||||||
@@ -854,6 +856,12 @@ class SDKMemoryKernel:
|
|||||||
context_length=self._optional_int(summary_meta.get("context_length")),
|
context_length=self._optional_int(summary_meta.get("context_length")),
|
||||||
include_personality=summary_meta.get("include_personality"),
|
include_personality=summary_meta.get("include_personality"),
|
||||||
time_end=time_end,
|
time_end=time_end,
|
||||||
|
metadata={
|
||||||
|
**summary_meta,
|
||||||
|
"external_id": external_token,
|
||||||
|
"chat_id": str(chat_id or "").strip(),
|
||||||
|
"source_type": "chat_summary",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
result.setdefault("external_id", external_id)
|
result.setdefault("external_id", external_id)
|
||||||
result.setdefault("chat_id", chat_id)
|
result.setdefault("chat_id", chat_id)
|
||||||
|
|||||||
@@ -1294,6 +1294,7 @@ class GraphStore:
|
|||||||
if current_n == 0:
|
if current_n == 0:
|
||||||
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
|
logger.warning("检测到空图元数据但邻接矩阵仍然存在,已重置为空图。")
|
||||||
self._adjacency = None
|
self._adjacency = None
|
||||||
|
self._edge_hash_map = defaultdict(set)
|
||||||
elif current_n > adj_n:
|
elif current_n > adj_n:
|
||||||
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
|
logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 正在自动修复...")
|
||||||
self._expand_adjacency_matrix(current_n - adj_n)
|
self._expand_adjacency_matrix(current_n - adj_n)
|
||||||
@@ -1305,6 +1306,14 @@ class GraphStore:
|
|||||||
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
|
self._adjacency = csc_matrix((current_n, current_n), dtype=np.float32)
|
||||||
else:
|
else:
|
||||||
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
|
self._adjacency = csr_matrix((current_n, current_n), dtype=np.float32)
|
||||||
|
self._edge_hash_map = defaultdict(
|
||||||
|
set,
|
||||||
|
{
|
||||||
|
(src_idx, dst_idx): set(hashes)
|
||||||
|
for (src_idx, dst_idx), hashes in self._edge_hash_map.items()
|
||||||
|
if src_idx < current_n and dst_idx < current_n
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
self._adjacency_dirty = True
|
self._adjacency_dirty = True
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -178,6 +178,27 @@ class MetadataStore:
|
|||||||
int(knowledge_type_result.get("normalized", 0) or 0),
|
int(knowledge_type_result.get("normalized", 0) or 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _ensure_memory_feedback_task_columns(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
"""补齐 memory_feedback_tasks 历史库缺失的 rollback_* 列。"""
|
||||||
|
cursor.execute("PRAGMA table_info(memory_feedback_tasks)")
|
||||||
|
feedback_task_columns = {row[1] for row in cursor.fetchall()}
|
||||||
|
feedback_task_migrations = {
|
||||||
|
"rollback_status": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_status TEXT DEFAULT 'none'",
|
||||||
|
"rollback_plan_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_plan_json TEXT",
|
||||||
|
"rollback_result_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_result_json TEXT",
|
||||||
|
"rollback_error": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_error TEXT",
|
||||||
|
"rollback_requested_by": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_by TEXT",
|
||||||
|
"rollback_reason": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_reason TEXT",
|
||||||
|
"rollback_requested_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_at REAL",
|
||||||
|
"rolled_back_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rolled_back_at REAL",
|
||||||
|
}
|
||||||
|
for col, sql in feedback_task_migrations.items():
|
||||||
|
if col not in feedback_task_columns:
|
||||||
|
try:
|
||||||
|
cursor.execute(sql)
|
||||||
|
except sqlite3.OperationalError as e:
|
||||||
|
logger.warning(f"Schema迁移失败 (memory_feedback_tasks.{col}): {e}")
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""关闭数据库连接"""
|
"""关闭数据库连接"""
|
||||||
if self._conn:
|
if self._conn:
|
||||||
@@ -641,24 +662,7 @@ class MetadataStore:
|
|||||||
CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested
|
CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested
|
||||||
ON person_profile_refresh_queue(requested_at DESC)
|
ON person_profile_refresh_queue(requested_at DESC)
|
||||||
""")
|
""")
|
||||||
cursor.execute("PRAGMA table_info(memory_feedback_tasks)")
|
self._ensure_memory_feedback_task_columns(cursor)
|
||||||
feedback_task_columns = {row[1] for row in cursor.fetchall()}
|
|
||||||
feedback_task_migrations = {
|
|
||||||
"rollback_status": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_status TEXT DEFAULT 'none'",
|
|
||||||
"rollback_plan_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_plan_json TEXT",
|
|
||||||
"rollback_result_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_result_json TEXT",
|
|
||||||
"rollback_error": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_error TEXT",
|
|
||||||
"rollback_requested_by": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_by TEXT",
|
|
||||||
"rollback_reason": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_reason TEXT",
|
|
||||||
"rollback_requested_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_at REAL",
|
|
||||||
"rolled_back_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rolled_back_at REAL",
|
|
||||||
}
|
|
||||||
for col, sql in feedback_task_migrations.items():
|
|
||||||
if col not in feedback_task_columns:
|
|
||||||
try:
|
|
||||||
cursor.execute(sql)
|
|
||||||
except sqlite3.OperationalError as e:
|
|
||||||
logger.warning(f"Schema迁移失败 (memory_feedback_tasks.{col}): {e}")
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS external_memory_refs (
|
CREATE TABLE IF NOT EXISTS external_memory_refs (
|
||||||
external_id TEXT PRIMARY KEY,
|
external_id TEXT PRIMARY KEY,
|
||||||
@@ -953,6 +957,7 @@ class MetadataStore:
|
|||||||
CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested
|
CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested
|
||||||
ON person_profile_refresh_queue(requested_at DESC)
|
ON person_profile_refresh_queue(requested_at DESC)
|
||||||
""")
|
""")
|
||||||
|
self._ensure_memory_feedback_task_columns(cursor)
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS external_memory_refs (
|
CREATE TABLE IF NOT EXISTS external_memory_refs (
|
||||||
external_id TEXT PRIMARY KEY,
|
external_id TEXT PRIMARY KEY,
|
||||||
@@ -3945,6 +3950,8 @@ class MetadataStore:
|
|||||||
"episodes", "episode_paragraphs",
|
"episodes", "episode_paragraphs",
|
||||||
"episode_rebuild_sources", "episode_pending_paragraphs",
|
"episode_rebuild_sources", "episode_pending_paragraphs",
|
||||||
"paragraph_vector_backfill",
|
"paragraph_vector_backfill",
|
||||||
|
"memory_feedback_tasks", "memory_feedback_action_logs",
|
||||||
|
"paragraph_stale_relation_marks", "person_profile_refresh_queue",
|
||||||
]
|
]
|
||||||
for table in tables:
|
for table in tables:
|
||||||
cursor.execute(f"DELETE FROM {table}")
|
cursor.execute(f"DELETE FROM {table}")
|
||||||
|
|||||||
@@ -225,6 +225,7 @@ class SummaryImporter:
|
|||||||
context_length: Optional[int] = None,
|
context_length: Optional[int] = None,
|
||||||
include_personality: Optional[bool] = None,
|
include_personality: Optional[bool] = None,
|
||||||
time_end: Optional[float] = None,
|
time_end: Optional[float] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
从指定的聊天流中提取记录并执行总结导入
|
从指定的聊天流中提取记录并执行总结导入
|
||||||
@@ -327,7 +328,14 @@ class SummaryImporter:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 6. 执行导入
|
# 6. 执行导入
|
||||||
await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta)
|
await self._execute_import(
|
||||||
|
summary_text,
|
||||||
|
entities,
|
||||||
|
relations,
|
||||||
|
stream_id,
|
||||||
|
time_meta=time_meta,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
# 7. 持久化
|
# 7. 持久化
|
||||||
self.vector_store.save()
|
self.vector_store.save()
|
||||||
@@ -393,6 +401,7 @@ class SummaryImporter:
|
|||||||
relations: List[Dict[str, str]],
|
relations: List[Dict[str, str]],
|
||||||
stream_id: str,
|
stream_id: str,
|
||||||
time_meta: Optional[Dict[str, Any]] = None,
|
time_meta: Optional[Dict[str, Any]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
"""将数据写入存储"""
|
"""将数据写入存储"""
|
||||||
# 获取默认知识类型
|
# 获取默认知识类型
|
||||||
@@ -407,6 +416,7 @@ class SummaryImporter:
|
|||||||
hash_value = self.metadata_store.add_paragraph(
|
hash_value = self.metadata_store.add_paragraph(
|
||||||
content=summary,
|
content=summary,
|
||||||
source=f"chat_summary:{stream_id}",
|
source=f"chat_summary:{stream_id}",
|
||||||
|
metadata=metadata,
|
||||||
knowledge_type=knowledge_type.value,
|
knowledge_type=knowledge_type.value,
|
||||||
time_meta=time_meta,
|
time_meta=time_meta,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -375,30 +375,6 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]:
|
|||||||
"memory_feedback_tasks rollback columns missing under current schema version",
|
"memory_feedback_tasks rollback columns missing under current schema version",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif not has_stale_marks:
|
|
||||||
checks.append(
|
|
||||||
CheckItem(
|
|
||||||
"CP-15",
|
|
||||||
"error",
|
|
||||||
"paragraph_stale_relation_marks table missing under current schema version",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif not has_profile_refresh_queue:
|
|
||||||
checks.append(
|
|
||||||
CheckItem(
|
|
||||||
"CP-16",
|
|
||||||
"error",
|
|
||||||
"person_profile_refresh_queue table missing under current schema version",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif not has_feedback_rollback_status or not has_feedback_rollback_plan:
|
|
||||||
checks.append(
|
|
||||||
CheckItem(
|
|
||||||
"CP-17",
|
|
||||||
"error",
|
|
||||||
"memory_feedback_tasks rollback columns missing under current schema version",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if _sqlite_table_exists(conn, "relations"):
|
if _sqlite_table_exists(conn, "relations"):
|
||||||
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
|
row = conn.execute("SELECT COUNT(*) FROM relations").fetchone()
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ class VisualConfig(ConfigBase):
|
|||||||
__ui_icon__ = "image"
|
__ui_icon__ = "image"
|
||||||
|
|
||||||
planner_mode: Literal["text", "multimodal", "auto"] = Field(
|
planner_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||||
default="text",
|
default="auto",
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"x-widget": "select",
|
"x-widget": "select",
|
||||||
"x-icon": "image",
|
"x-icon": "image",
|
||||||
@@ -155,7 +155,7 @@ class VisualConfig(ConfigBase):
|
|||||||
"""Planner 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
"""Planner 视觉模式:text 仅文本,multimodal 强制多模态,auto 按模型能力自动选择"""
|
||||||
|
|
||||||
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
replyer_mode: Literal["text", "multimodal", "auto"] = Field(
|
||||||
default="text",
|
default="auto",
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"x-widget": "select",
|
"x-widget": "select",
|
||||||
"x-icon": "git-branch",
|
"x-icon": "git-branch",
|
||||||
|
|||||||
@@ -1134,7 +1134,14 @@ class MaisakaReasoningEngine:
|
|||||||
tool_name=invocation.tool_name,
|
tool_name=invocation.tool_name,
|
||||||
tool_reasoning=invocation.reasoning,
|
tool_reasoning=invocation.reasoning,
|
||||||
)
|
)
|
||||||
if invocation.tool_name == "query_memory" and isinstance(saved_record, dict):
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if invocation.tool_name == "query_memory" and isinstance(saved_record, dict):
|
||||||
|
try:
|
||||||
enqueue_payload = await memory_service.enqueue_feedback_task(
|
enqueue_payload = await memory_service.enqueue_feedback_task(
|
||||||
query_tool_id=str(saved_record.get("tool_id") or invocation.call_id or "").strip(),
|
query_tool_id=str(saved_record.get("tool_id") or invocation.call_id or "").strip(),
|
||||||
session_id=str(saved_record.get("session_id") or self._runtime.chat_stream.session_id or "").strip(),
|
session_id=str(saved_record.get("session_id") or self._runtime.chat_stream.session_id or "").strip(),
|
||||||
@@ -1143,15 +1150,16 @@ class MaisakaReasoningEngine:
|
|||||||
if isinstance(tool_record_payload.get("structured_content"), dict)
|
if isinstance(tool_record_payload.get("structured_content"), dict)
|
||||||
else {},
|
else {},
|
||||||
)
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
f"{self._runtime.log_prefix} 反馈纠错任务入队失败: tool_call_id={invocation.call_id}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
if not bool(enqueue_payload.get("success")):
|
if not bool(enqueue_payload.get("success")):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"{self._runtime.log_prefix} 反馈纠错任务未入队: "
|
f"{self._runtime.log_prefix} 反馈纠错任务未入队: "
|
||||||
f"tool_call_id={invocation.call_id} reason={enqueue_payload.get('reason', '')}"
|
f"tool_call_id={invocation.call_id} reason={enqueue_payload.get('reason', '')}"
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
|
def _append_tool_execution_result(self, tool_call: ToolCall, result: ToolExecutionResult) -> None:
|
||||||
"""将统一工具执行结果写回 Maisaka 历史。
|
"""将统一工具执行结果写回 Maisaka 历史。
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ from typing import Any, List, Optional
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import pickle
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
from src.services import memory_service as memory_service_module
|
||||||
from src.chat.utils.utils import is_bot_self
|
from src.chat.utils.utils import is_bot_self
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import count_messages, find_messages
|
from src.common.message_repository import count_messages, find_messages
|
||||||
@@ -281,7 +283,17 @@ class ChatSummaryWritebackService:
|
|||||||
return
|
return
|
||||||
|
|
||||||
threshold = self._message_threshold()
|
threshold = self._message_threshold()
|
||||||
state = self._states.setdefault(session_id, ChatSummaryWritebackState())
|
state = self._states.get(session_id)
|
||||||
|
if state is None:
|
||||||
|
restored_count = await self._load_last_trigger_message_count(
|
||||||
|
session_id=session_id,
|
||||||
|
total_message_count=total_message_count,
|
||||||
|
)
|
||||||
|
state = ChatSummaryWritebackState(
|
||||||
|
last_trigger_message_count=restored_count,
|
||||||
|
last_trigger_time=time.time() if restored_count > 0 else 0.0,
|
||||||
|
)
|
||||||
|
self._states[session_id] = state
|
||||||
pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
|
pending_message_count = max(0, total_message_count - state.last_trigger_message_count)
|
||||||
if pending_message_count < threshold:
|
if pending_message_count < threshold:
|
||||||
return
|
return
|
||||||
@@ -324,6 +336,64 @@ class ChatSummaryWritebackService:
|
|||||||
getattr(result, "detail", ""),
|
getattr(result, "detail", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _load_last_trigger_message_count(self, *, session_id: str, total_message_count: int) -> int:
    """Restore the summary trigger cursor from already persisted chat summaries.

    Paragraphs stored under the ``chat_summary:<session_id>`` source are looked
    up, the one with the greatest ``created_at`` is selected, and its
    ``trigger_message_count`` metadata value is used as the cursor. This keeps a
    restarted service from summarising the same messages again.

    Args:
        session_id: Chat session whose summaries are inspected.
        total_message_count: Current message count of the session; the
            restored cursor never exceeds it.

    Returns:
        The stored trigger count capped at ``total_message_count``;
        ``total_message_count`` itself when the newest summary carries no
        usable count; ``0`` when the store is unavailable, no summary exists,
        or any error occurs while loading.
    """
    try:
        host_service = getattr(memory_service_module, "a_memorix_host_service", None)
        kernel_factory = getattr(host_service, "_ensure_kernel", None)
        if not callable(kernel_factory):
            return 0

        kernel = await kernel_factory()
        store = getattr(kernel, "metadata_store", None)
        if store is None:
            return 0

        summaries = store.get_paragraphs_by_source(f"chat_summary:{session_id}")
        if not summaries:
            return 0

        newest = max(summaries, key=self._paragraph_created_at)
        newest_metadata = self._paragraph_metadata(newest)
        recorded_count = self._coerce_positive_int(newest_metadata.get("trigger_message_count"))
        if recorded_count <= 0:
            # Legacy summaries have no trigger count. Aligning with the current
            # total is the only fallback available and at least avoids writing
            # a near-duplicate summary right after a restart.
            return total_message_count
        return min(total_message_count, recorded_count)
    except Exception as exc:
        # Best effort: a failed restore degrades to "no cursor" instead of
        # breaking the writeback flow.
        logger.debug("恢复聊天摘要写回游标失败: session_id=%s error=%s", session_id, exc)
        return 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _paragraph_created_at(paragraph: dict[str, Any]) -> float:
|
||||||
|
try:
|
||||||
|
return float(paragraph.get("created_at") or 0.0)
|
||||||
|
except Exception:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _paragraph_metadata(paragraph: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
metadata = paragraph.get("metadata")
|
||||||
|
if isinstance(metadata, dict):
|
||||||
|
return metadata
|
||||||
|
if isinstance(metadata, (bytes, bytearray)):
|
||||||
|
try:
|
||||||
|
parsed = pickle.loads(metadata)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
return parsed if isinstance(parsed, dict) else {}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_positive_int(value: Any) -> int:
|
||||||
|
try:
|
||||||
|
number = int(value or 0)
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
return max(0, number)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolve_session_id(message: Any) -> str:
|
def _resolve_session_id(message: Any) -> str:
|
||||||
return str(
|
return str(
|
||||||
|
|||||||
@@ -208,7 +208,7 @@ def _resolve_static_path() -> Path | None:
|
|||||||
# 开发环境优先允许复用仓库里的现成 dist
|
# 开发环境优先允许复用仓库里的现成 dist
|
||||||
base_dir = _get_project_root()
|
base_dir = _get_project_root()
|
||||||
static_path = base_dir / "dashboard" / "dist"
|
static_path = base_dir / "dashboard" / "dist"
|
||||||
if static_path.exists():
|
if static_path.is_dir() and (static_path / "index.html").exists():
|
||||||
return static_path
|
return static_path
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -365,11 +365,13 @@ async def _profile_delete_override(person_id: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
async def _feedback_list(limit: int, status: str, rollback_status: str, query: str) -> dict:
    """List feedback tasks through ``memory_service.feedback_admin``.

    ``status`` and ``rollback_status`` are treated as comma-separated strings.
    Each is split, its items are stripped of surrounding whitespace, empty
    items are dropped, and the resulting lists are forwarded as ``statuses``
    and ``rollback_statuses``.

    Args:
        limit: Forwarded unchanged as ``limit``.
        status: Comma-separated status filter; may be empty.
        rollback_status: Comma-separated rollback-status filter; may be empty.
        query: Forwarded unchanged as ``query``.

    Returns:
        Whatever ``memory_service.feedback_admin`` returns for ``action="list"``.
    """

    def _split_csv(raw: str) -> list[str]:
        # Falsy input becomes an empty list; blank items are discarded.
        pieces = (piece.strip() for piece in str(raw or "").split(","))
        return [piece for piece in pieces if piece]

    status_filters = _split_csv(status)
    rollback_filters = _split_csv(rollback_status)
    return await memory_service.feedback_admin(
        action="list",
        limit=limit,
        statuses=status_filters,
        rollback_statuses=rollback_filters,
        query=query,
    )
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user