diff --git a/src/A_memorix/CHANGELOG.md b/src/A_memorix/CHANGELOG.md index f1a20b9d..a3ddac20 100644 --- a/src/A_memorix/CHANGELOG.md +++ b/src/A_memorix/CHANGELOG.md @@ -8,7 +8,7 @@ - 运行时主目录由 `storage.data_dir` 决定(当前模板默认 `data/a-memorix`); - 部分离线脚本仍以 `data/plugins/a-dawn.a-memorix` 作为默认处理目录。 - 修正文档中的导入示例参数,`memory_import_admin.create_paste` 的 `input_mode` 示例统一为 `text`/`json`。 -- 更新 `README.md` 关于元数据 schema 的描述,和当前代码 `SCHEMA_VERSION = 9` 保持一致。 +- 更新 `README.md` 关于元数据 schema 的描述,和当前代码 `SCHEMA_VERSION = 10` 保持一致。 ## [2.0.0] - 2026-03-18 diff --git a/src/A_memorix/CONFIG_REFERENCE.md b/src/A_memorix/CONFIG_REFERENCE.md index 421571c8..1a79f858 100644 --- a/src/A_memorix/CONFIG_REFERENCE.md +++ b/src/A_memorix/CONFIG_REFERENCE.md @@ -1,6 +1,6 @@ # A_Memorix 配置参考 (v2.0.0) -本文档对应当前仓库代码(`__version__ = 2.0.0`、`SCHEMA_VERSION = 9`)。 +本文档对应当前仓库代码(`__version__ = 2.0.0`、`SCHEMA_VERSION = 10`)。 说明: diff --git a/src/A_memorix/core/retrieval/dual_path.py b/src/A_memorix/core/retrieval/dual_path.py index c149e86d..437f3dd7 100644 --- a/src/A_memorix/core/retrieval/dual_path.py +++ b/src/A_memorix/core/retrieval/dual_path.py @@ -632,7 +632,7 @@ class DualPathRetriever: results: List[RetrievalResult] = [] for row in rows: hash_value = row["hash"] - relation = self.metadata_store.get_relation(hash_value) + relation = self.metadata_store.get_relation(hash_value, include_inactive=False) if relation is None: continue @@ -888,8 +888,8 @@ class DualPathRetriever: entity_name = entity["name"] related_rels = [] - related_rels.extend(self.metadata_store.get_relations(subject=entity_name)) - related_rels.extend(self.metadata_store.get_relations(object=entity_name)) + related_rels.extend(self.metadata_store.get_relations(subject=entity_name, include_inactive=False)) + related_rels.extend(self.metadata_store.get_relations(object=entity_name, include_inactive=False)) for rel in related_rels: if rel["hash"] in seen_relations: @@ -1280,7 +1280,7 @@ class DualPathRetriever: results = [] for hash_value, score in zip(rel_ids, rel_scores): - relation = self.metadata_store.get_relation(hash_value) + relation = self.metadata_store.get_relation(hash_value, include_inactive=False) if relation is None: continue @@ -1378,7 +1378,7 @@ class DualPathRetriever: deduplicated_results.append(result) continue # 检查关系关联的段落是否已存在 - relation = self.metadata_store.get_relation(result.hash_value) + relation = self.metadata_store.get_relation(result.hash_value, include_inactive=False) if relation: # 获取关联的段落 para_rels = self.metadata_store.query(""" diff --git a/src/A_memorix/core/retrieval/graph_relation_recall.py b/src/A_memorix/core/retrieval/graph_relation_recall.py index 9af862f3..46acc3ce 100644 --- a/src/A_memorix/core/retrieval/graph_relation_recall.py +++ b/src/A_memorix/core/retrieval/graph_relation_recall.py @@ -255,7 +255,7 @@ class GraphRelationRecallService: graph_hops: int, graph_seed_entities: Sequence[str], ) -> Optional[GraphRelationCandidate]: - relation = self.metadata_store.get_relation(relation_hash) + relation = self.metadata_store.get_relation(relation_hash, include_inactive=False) if relation is None: return None supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash) diff --git a/src/A_memorix/core/retrieval/sparse_bm25.py b/src/A_memorix/core/retrieval/sparse_bm25.py index 1fef9f80..276e8778 100644 --- a/src/A_memorix/core/retrieval/sparse_bm25.py +++ b/src/A_memorix/core/retrieval/sparse_bm25.py @@ -338,6 +338,7 @@ class SparseBM25Index: match_query=match_query, limit=max(1, int(k)), max_doc_len=self.config.relation_max_doc_len, + include_inactive=False, conn=self._conn, ) out: List[Dict[str, Any]] = [] diff --git a/src/A_memorix/core/runtime/sdk_memory_kernel.py b/src/A_memorix/core/runtime/sdk_memory_kernel.py index 0da4e02d..c006838a 100644 --- a/src/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/src/A_memorix/core/runtime/sdk_memory_kernel.py @@ -6,10 +6,16 @@ import pickle import time import uuid from dataclasses import dataclass +from datetime import datetime, timedelta from pathlib import Path from typing import Any, Callable, Coroutine, Dict, Iterable, List, Optional, Sequence +from json_repair import repair_json + from src.common.logger import get_logger +from src.config.config import global_config +from src.services import message_service as message_api +from src.services.llm_service import LLMServiceClient from ...paths import default_data_dir, resolve_repo_path from ..embedding import create_embedding_api_adapter @@ -180,6 +186,7 @@ class SDKMemoryKernel: "since": None, "last_check": None, } + self._feedback_classifier: Optional[LLMServiceClient] = None def _cfg(self, key: str, default: Any = None) -> Any: current: Any = self.config @@ -1096,7 +1103,7 @@ class SDKMemoryKernel: person=request.person_id or None, source=self._chat_source(request.chat_id), ) - hits = [self._episode_hit(row) for row in rows] + hits = self._filter_episode_hits([self._episode_hit(row) for row in rows]) return {"summary": self._summary(hits), "hits": hits} if mode == "aggregate": @@ -1115,6 +1122,7 @@ class SDKMemoryKernel: for item in hits: item.setdefault("metadata", {}) filtered = self._filter_hits(hits, request.person_id) + filtered = self._filter_user_visible_hits(filtered) return {"summary": self._summary(filtered), "hits": filtered} query_type = mode @@ -1148,24 +1156,80 @@ class SDKMemoryKernel: hits = [self._retrieval_result_hit(item) for item in result.results] filtered = self._filter_hits(hits, request.person_id) + filtered = self._filter_user_visible_hits(filtered) return {"summary": self._summary(filtered), "hits": filtered} - async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: - del chat_id - await self.initialize() + @staticmethod + def _empty_person_profile_response(*, person_id: str = "", person_name: str = "") -> Dict[str, Any]: + return { + "summary": "", + "traits": [], + "evidence": [], + "person_id": str(person_id or "").strip(), + "person_name": str(person_name or "").strip(), + "profile_source": "", + "has_manual_override": False, + } + + async def _query_person_profile_with_feedback_refresh( + self, + *, + person_id: str = "", + person_keyword: str = "", + limit: int = 10, + force_refresh: bool = False, + source_note: str, + ) -> Dict[str, Any]: assert self.metadata_store is not None assert self.person_profile_service is not None - self._mark_person_active(person_id) - profile = await self.person_profile_service.query_person_profile( - person_id=person_id, - top_k=max(4, int(limit or 10)), - source_note="sdk_memory_kernel.get_person_profile", - ) - if not profile.get("success"): - return {"summary": "", "traits": [], "evidence": []} - evidence = [] - for hash_value in profile.get("evidence_ids", [])[: max(1, int(limit))]: + pid = str(person_id or "").strip() + if not pid and person_keyword: + pid = self.person_profile_service.resolve_person_id(str(person_keyword or "").strip()) + + dirty_request = self.metadata_store.get_person_profile_refresh_request(pid) if pid else None + should_force_refresh = bool(force_refresh) + if ( + pid + and self._feedback_cfg_profile_refresh_enabled() + and self._feedback_cfg_profile_force_refresh_on_read() + and isinstance(dirty_request, dict) + and str(dirty_request.get("status", "") or "").strip().lower() in {"pending", "running", "failed"} + ): + should_force_refresh = True + + profile = await self.person_profile_service.query_person_profile( + person_id=pid, + person_keyword=str(person_keyword or "").strip(), + top_k=max(1, int(limit or 10)), + force_refresh=should_force_refresh, + source_note=source_note, + ) + payload = profile if isinstance(profile, dict) else {"success": False, "error": "invalid profile payload"} + if dirty_request: + payload["feedback_refresh_request"] = dirty_request + if should_force_refresh and dirty_request and not bool(payload.get("success")): + payload.setdefault("error", "feedback_refresh_failed") + payload["feedback_refresh_failed"] = True + return payload + + def _build_person_profile_response( + self, + profile: Dict[str, Any], + *, + requested_person_id: str, + limit: int, + ) -> Dict[str, Any]: + assert self.metadata_store is not None + if not bool(profile.get("success")): + return self._empty_person_profile_response( + person_id=str(profile.get("person_id", "") or requested_person_id), + person_name=str(profile.get("person_name", "") or ""), + ) + + evidence: List[Dict[str, Any]] = [] + evidence_limit = max(1, int(limit or 10)) + for hash_value in profile.get("evidence_ids", [])[:evidence_limit]: paragraph = self.metadata_store.get_paragraph(hash_value) if paragraph is not None: evidence.append( @@ -1198,18 +1262,32 @@ class SDKMemoryKernel: } ) + evidence = self._filter_user_visible_hits(evidence) text = str(profile.get("profile_text", "") or "").strip() traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8] return { "summary": text, "traits": traits, "evidence": evidence, - "person_id": str(profile.get("person_id", "") or person_id), + "person_id": str(profile.get("person_id", "") or requested_person_id), "person_name": str(profile.get("person_name", "") or ""), "profile_source": str(profile.get("profile_source", "") or "auto_snapshot"), "has_manual_override": bool(profile.get("has_manual_override", False)), } + async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: + del chat_id + await self.initialize() + assert self.metadata_store is not None + assert self.person_profile_service is not None + self._mark_person_active(person_id) + profile = await self._query_person_profile_with_feedback_refresh( + person_id=person_id, + limit=max(4, int(limit or 10)), + source_note="sdk_memory_kernel.get_person_profile", + ) + return self._build_person_profile_response(profile, requested_person_id=person_id, limit=limit) + async def refresh_person_profile(self, person_id: str, limit: int = 10, *, mark_active: bool = True) -> Dict[str, Any]: await self.initialize() assert self.person_profile_service @@ -1298,12 +1376,22 @@ class SDKMemoryKernel: "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" )[0]["c"] backfill = self._paragraph_vector_backfill_counts() + episode_rebuild_summary = self.metadata_store.get_episode_source_rebuild_summary() + episode_rebuild_counts = episode_rebuild_summary.get("counts", {}) if isinstance(episode_rebuild_summary, dict) else {} return { "paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "episode_pending": int(pending or 0), + "stale_paragraph_marks": int(stats.get("stale_paragraph_mark_count", 0) or 0), + "profile_refresh_pending": int(stats.get("person_profile_refresh_pending_count", 0) or 0), + "profile_refresh_failed": int(stats.get("person_profile_refresh_failed_count", 0) or 0), + "episode_rebuild_pending": int( + (episode_rebuild_counts.get("pending", 0) or 0) + + (episode_rebuild_counts.get("running", 0) or 0) + + (episode_rebuild_counts.get("failed", 0) or 0) + ), "paragraph_vector_backfill_pending": int(backfill.get("pending", 0) or 0), "paragraph_vector_backfill_failed": int(backfill.get("failed", 0) or 0), "last_maintenance_at": self._last_maintenance_at, @@ -1559,15 +1647,27 @@ class SDKMemoryKernel: act = str(action or "").strip().lower() if act == "query": - profile = await self.person_profile_service.query_person_profile( + profile = await self._query_person_profile_with_feedback_refresh( person_id=str(kwargs.get("person_id", "") or "").strip(), person_keyword=str(kwargs.get("person_keyword", "") or kwargs.get("keyword", "") or "").strip(), - top_k=max(1, int(kwargs.get("limit", kwargs.get("top_k", 12)) or 12)), + limit=max(1, int(kwargs.get("limit", kwargs.get("top_k", 12)) or 12)), force_refresh=bool(kwargs.get("force_refresh", False)), source_note="sdk_memory_kernel.memory_profile_admin.query", ) return profile if isinstance(profile, dict) else {"success": False, "error": "invalid profile payload"} + if act == "status": + summary = self.metadata_store.get_person_profile_refresh_summary( + failed_limit=max(1, int(kwargs.get("limit", 20) or 20)) + ) + return {"success": True, **summary} + + if act == "process_pending": + result = await self._process_feedback_profile_refresh_batch( + limit=max(1, int(kwargs.get("limit", self._feedback_cfg_reconcile_batch_size()) or self._feedback_cfg_reconcile_batch_size())) + ) + return {"success": True, **result} + if act == "list": limit = max(1, int(kwargs.get("limit", 50) or 50)) rows = self.metadata_store.query( @@ -1621,6 +1721,39 @@ class SDKMemoryKernel: return {"success": False, "error": f"不支持的 profile action: {act}"} + async def memory_feedback_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store is not None + + act = str(action or "").strip().lower() + if act == "list": + items = self.metadata_store.list_feedback_tasks( + limit=max(1, int(kwargs.get("limit", 50) or 50)), + statuses=self._tokens(kwargs.get("status") or kwargs.get("statuses")), + rollback_statuses=self._tokens(kwargs.get("rollback_status") or kwargs.get("rollback_statuses")), + query=str(kwargs.get("query", "") or "").strip(), + ) + return { + "success": True, + "items": [self._build_feedback_task_summary(task) for task in items], + "count": len(items), + } + + if act == "get": + task = self.metadata_store.get_feedback_task_by_id(int(kwargs.get("task_id", 0) or 0)) + if task is None: + return {"success": False, "error": "反馈纠错任务不存在"} + return {"success": True, "task": self._build_feedback_task_detail(task)} + + if act == "rollback": + return await self._rollback_feedback_task( + task_id=int(kwargs.get("task_id", 0) or 0), + requested_by=str(kwargs.get("requested_by", "") or "").strip(), + reason=str(kwargs.get("reason", "") or "").strip(), + ) + + return {"success": False, "error": f"不支持的 feedback action: {act}"} + async def memory_runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: await self.initialize() act = str(action or "").strip().lower() @@ -1992,7 +2125,7 @@ class SDKMemoryKernel: person=request.person_id or None, source=self._chat_source(request.chat_id), ) - hits = [self._episode_hit(row) for row in rows] + hits = self._filter_episode_hits([self._episode_hit(row) for row in rows]) return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"} def _persist(self) -> None: @@ -2012,6 +2145,8 @@ class SDKMemoryKernel: self._ensure_background_task("paragraph_vector_backfill", self._paragraph_vector_backfill_loop) self._ensure_background_task("memory_maintenance", self._memory_maintenance_loop) self._ensure_background_task("person_profile_refresh", self._person_profile_refresh_loop) + self._ensure_background_task("feedback_correction", self._feedback_correction_loop) + self._ensure_background_task("feedback_correction_reconcile", self._feedback_correction_reconcile_loop) def _ensure_background_task( self, @@ -2141,6 +2276,1622 @@ class SDKMemoryKernel: except Exception as exc: logger.warning(f"person_profile_refresh loop 异常: {exc}") + @staticmethod + def _relation_status_is_inactive(status: Optional[Dict[str, Any]]) -> bool: + if status is None: + return True + return bool(status.get("is_inactive")) + + def _load_paragraph_stale_marks( + self, + paragraph_hashes: Sequence[str], + ) -> tuple[Dict[str, List[Dict[str, Any]]], Dict[str, Dict[str, Any]]]: + if self.metadata_store is None: + return {}, {} + normalized = self._tokens(paragraph_hashes) + if not normalized: + return {}, {} + marks_by_paragraph = self.metadata_store.get_paragraph_stale_relation_marks_batch(normalized) + relation_hashes = self._tokens( + mark.get("relation_hash", "") + for marks in marks_by_paragraph.values() + for mark in marks + if isinstance(mark, dict) + ) + status_map = self.metadata_store.get_relation_status_batch(relation_hashes) if relation_hashes else {} + return marks_by_paragraph, status_map + + def _paragraph_hidden_by_stale_marks( + self, + paragraph_hash: str, + *, + marks_by_paragraph: Optional[Dict[str, List[Dict[str, Any]]]] = None, + relation_status_map: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> bool: + token = str(paragraph_hash or "").strip() + if not token or self.metadata_store is None or not self._feedback_cfg_paragraph_hard_filter_enabled(): + return False + + marks_map = marks_by_paragraph if isinstance(marks_by_paragraph, dict) else {} + status_map = relation_status_map if isinstance(relation_status_map, dict) else {} + if not marks_map: + marks_map, status_map = self._load_paragraph_stale_marks([token]) + elif not status_map: + relation_hashes = self._tokens( + mark.get("relation_hash", "") + for mark in marks_map.get(token, []) + if isinstance(mark, dict) + ) + status_map = self.metadata_store.get_relation_status_batch(relation_hashes) if relation_hashes else {} + + for mark in marks_map.get(token, []): + relation_hash = str((mark or {}).get("relation_hash", "") or "").strip() + if not relation_hash: + continue + if self._relation_status_is_inactive(status_map.get(relation_hash)): + return True + return False + + def _filter_episode_hits(self, hits: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + if self.metadata_store is None or not self._feedback_cfg_episode_query_block_enabled(): + return hits + filtered: List[Dict[str, Any]] = [] + for item in hits: + if str(item.get("type", "") or "").strip() != "episode": + filtered.append(item) + continue + metadata = item.get("metadata") if isinstance(item.get("metadata"), dict) else {} + source = str(metadata.get("source", "") or item.get("source", "") or "").strip() + if source and self.metadata_store.is_episode_source_query_blocked(source): + continue + filtered.append(item) + return filtered + + def _filter_user_visible_hits(self, hits: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + return self._filter_active_relation_hits(self._filter_episode_hits(hits)) + + def _resolve_feedback_related_person_ids( + self, + *, + old_relation_rows: Sequence[Dict[str, Any]], + corrected_relations: Sequence[Dict[str, Any]], + ) -> List[str]: + candidates = self._tokens( + value + for row in list(old_relation_rows) + list(corrected_relations) + if isinstance(row, dict) + for value in (row.get("subject"), row.get("object")) + ) + resolved: List[str] = [] + seen = set() + for candidate in candidates: + person_id = PersonProfileService.resolve_person_id(candidate) + if not person_id or person_id in seen: + continue + seen.add(person_id) + resolved.append(person_id) + return resolved + + def _mark_feedback_stale_paragraphs( + self, + *, + task_id: int, + query_tool_id: str, + relation_hashes: Sequence[str], + reason: str, + ) -> Dict[str, List[str]]: + if self.metadata_store is None or not self._feedback_cfg_paragraph_mark_enabled(): + return {} + + relation_tokens = self._tokens(relation_hashes) + paragraph_map = self.metadata_store.get_paragraph_hashes_by_relation_hashes(relation_tokens) + for relation_hash, paragraph_hashes in paragraph_map.items(): + for paragraph_hash in paragraph_hashes: + self.metadata_store.upsert_paragraph_stale_relation_mark( + paragraph_hash=paragraph_hash, + relation_hash=relation_hash, + query_tool_id=query_tool_id, + task_id=task_id, + reason=reason, + ) + return paragraph_map + + def _enqueue_feedback_episode_rebuilds( + self, + *, + paragraph_hashes: Sequence[str], + session_id: str, + include_correction_source: bool, + ) -> List[str]: + if self.metadata_store is None or not self._feedback_cfg_episode_rebuild_enabled(): + return [] + + sources = self._tokens( + row.get("source", "") + for row in self._load_paragraph_rows(paragraph_hashes) + if isinstance(row, dict) + ) + correction_source = self._chat_source(session_id) + if include_correction_source and correction_source: + sources = self._merge_tokens(sources, [correction_source]) + + queued: List[str] = [] + for source in sources: + if self.metadata_store.enqueue_episode_source_rebuild(source, reason="feedback_correction"): + queued.append(source) + return queued + + def _enqueue_feedback_profile_refreshes( + self, + *, + person_ids: Sequence[str], + query_tool_id: str, + ) -> List[str]: + if self.metadata_store is None or not self._feedback_cfg_profile_refresh_enabled(): + return [] + queued: List[str] = [] + for person_id in self._tokens(person_ids): + payload = self.metadata_store.enqueue_person_profile_refresh( + person_id=person_id, + reason="feedback_correction", + source_query_tool_id=query_tool_id, + ) + if isinstance(payload, dict): + queued.append(person_id) + return queued + + @staticmethod + def _feedback_affected_counts(task: Dict[str, Any]) -> Dict[str, int]: + decision_payload = task.get("decision_payload") if isinstance(task.get("decision_payload"), dict) else {} + apply_result = decision_payload.get("apply_result") if isinstance(decision_payload.get("apply_result"), dict) else {} + rollback_plan = task.get("rollback_plan") if isinstance(task.get("rollback_plan"), dict) else {} + corrected_write = rollback_plan.get("corrected_write") if isinstance(rollback_plan.get("corrected_write"), dict) else {} + return { + "relations": len(list(apply_result.get("relation_hashes") or rollback_plan.get("forgotten_relations") or [])), + "stale_paragraphs": len(list(apply_result.get("stale_paragraph_hashes") or rollback_plan.get("stale_marks") or [])), + "episode_sources": len(list(apply_result.get("episode_rebuild_sources") or rollback_plan.get("episode_sources") or [])), + "profile_person_ids": len(list(apply_result.get("profile_refresh_person_ids") or rollback_plan.get("profile_person_ids") or [])), + "correction_paragraphs": len(list(corrected_write.get("paragraph_hashes") or [])), + "corrected_relations": len(list(corrected_write.get("corrected_relations") or [])), + } + + def _build_feedback_rollback_plan_summary(self, rollback_plan: Dict[str, Any]) -> Dict[str, Any]: + corrected_write = rollback_plan.get("corrected_write") if isinstance(rollback_plan.get("corrected_write"), dict) else {} + return { + "forgotten_relations": list(rollback_plan.get("forgotten_relations") or []), + "corrected_write": corrected_write, + "stale_marks": list(rollback_plan.get("stale_marks") or []), + "episode_sources": self._tokens(rollback_plan.get("episode_sources")), + "profile_person_ids": self._tokens(rollback_plan.get("profile_person_ids")), + "affected_counts": { + "forgotten_relations": len(list(rollback_plan.get("forgotten_relations") or [])), + "corrected_relations": len(list(corrected_write.get("corrected_relations") or [])), + "correction_paragraphs": len(list(corrected_write.get("paragraph_hashes") or [])), + "stale_marks": len(list(rollback_plan.get("stale_marks") or [])), + "episode_sources": len(self._tokens(rollback_plan.get("episode_sources"))), + "profile_person_ids": len(self._tokens(rollback_plan.get("profile_person_ids"))), + }, + } + + def _build_feedback_task_summary(self, task: Dict[str, Any]) -> Dict[str, Any]: + query_snapshot = task.get("query_snapshot") if isinstance(task.get("query_snapshot"), dict) else {} + decision_payload = task.get("decision_payload") if isinstance(task.get("decision_payload"), dict) else {} + return { + "task_id": int(task.get("id", 0) or 0), + "query_tool_id": str(task.get("query_tool_id", "") or "").strip(), + "session_id": str(task.get("session_id", "") or "").strip(), + "query_text": str(query_snapshot.get("query", "") or "").strip(), + "query_timestamp": task.get("query_timestamp"), + "task_status": str(task.get("status", "") or "").strip().lower(), + "decision": str(decision_payload.get("decision", "") or "").strip().lower(), + "decision_confidence": float(decision_payload.get("confidence", 0.0) or 0.0), + "feedback_message_count": int(decision_payload.get("feedback_message_count", 0) or 0), + "rollback_status": str(task.get("rollback_status", "") or "none").strip().lower() or "none", + "affected_counts": self._feedback_affected_counts(task), + "created_at": task.get("created_at"), + "updated_at": task.get("updated_at"), + } + + def _build_feedback_task_detail(self, task: Dict[str, Any]) -> Dict[str, Any]: + detail = self._build_feedback_task_summary(task) + detail.update( + { + "query_snapshot": task.get("query_snapshot") if isinstance(task.get("query_snapshot"), dict) else {}, + "decision_payload": task.get("decision_payload") if isinstance(task.get("decision_payload"), dict) else {}, + "rollback_plan_summary": self._build_feedback_rollback_plan_summary( + task.get("rollback_plan") if isinstance(task.get("rollback_plan"), dict) else {} + ), + "rollback_result": task.get("rollback_result") if isinstance(task.get("rollback_result"), dict) else {}, + "rollback_error": str(task.get("rollback_error", "") or "").strip(), + "rollback_requested_by": str(task.get("rollback_requested_by", "") or "").strip(), + "rollback_reason": str(task.get("rollback_reason", "") or "").strip(), + "rollback_requested_at": task.get("rollback_requested_at"), + "rolled_back_at": task.get("rolled_back_at"), + "action_logs": self.metadata_store.list_feedback_action_logs(int(task.get("id", 0) or 0)) + if self.metadata_store is not None + else [], + } + ) + return detail + + def _soft_delete_feedback_correction_paragraphs(self, paragraph_hashes: Sequence[str]) -> Dict[str, Any]: + assert self.metadata_store is not None + hashes = self._tokens(paragraph_hashes) + if not hashes: + return {"deleted_hashes": [], "deleted_external_refs": []} + + paragraph_rows = {hash_value: self.metadata_store.get_paragraph(hash_value) for hash_value in hashes} + self.metadata_store.mark_as_deleted(hashes, "paragraph") + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + cursor.execute( + f"DELETE FROM paragraph_entities WHERE paragraph_hash IN ({','.join(['?'] * len(hashes))})", + tuple(hashes), + ) + cursor.execute( + f"DELETE FROM paragraph_relations WHERE paragraph_hash IN ({','.join(['?'] * len(hashes))})", + tuple(hashes), + ) + conn.commit() + deleted_external_refs = self.metadata_store.delete_external_memory_refs_by_paragraphs(hashes) + return { + "deleted_hashes": hashes, + "paragraph_rows": paragraph_rows, + "deleted_external_refs": deleted_external_refs, + } + + async def _rollback_feedback_task( + self, + *, + task_id: int, + requested_by: str, + reason: str, + ) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store is not None + + task = self.metadata_store.get_feedback_task_by_id(task_id) + if task is None: + return {"success": False, "error": "反馈纠错任务不存在"} + if str(task.get("status", "") or "").strip().lower() != "applied": + return {"success": False, "error": "仅 applied 的反馈纠错任务允许回退"} + rollback_status = str(task.get("rollback_status", "") or "none").strip().lower() + if rollback_status == "rolled_back": + return { + "success": True, + "already_rolled_back": True, + "task": self._build_feedback_task_detail(task), + "result": task.get("rollback_result") if isinstance(task.get("rollback_result"), dict) else {}, + } + if rollback_status == "running": + return {"success": False, "error": "该反馈纠错任务正在回退中", "task": self._build_feedback_task_detail(task)} + + query_tool_id = str(task.get("query_tool_id", "") or "").strip() + rollback_plan = task.get("rollback_plan") if isinstance(task.get("rollback_plan"), dict) else {} + if not rollback_plan: + running_task = self.metadata_store.mark_feedback_task_rollback_running( + task_id=task_id, + requested_by=requested_by, + reason=reason, + ) + if running_task is None: + latest_task = self.metadata_store.get_feedback_task_by_id(task_id) + latest_status = str((latest_task or {}).get("rollback_status", "") or "none").strip().lower() + if latest_status == "running": + return { + "success": False, + "error": "该反馈纠错任务正在回退中", + "task": self._build_feedback_task_detail(latest_task) if isinstance(latest_task, dict) else None, + } + if latest_status == "rolled_back": + return { + "success": True, + "already_rolled_back": True, + "task": self._build_feedback_task_detail(latest_task) if isinstance(latest_task, dict) else None, + "result": (latest_task or {}).get("rollback_result") if isinstance((latest_task or {}).get("rollback_result"), dict) else {}, + } + return { + "success": False, + "error": "无法进入回退状态", + "task": self._build_feedback_task_detail(latest_task) if isinstance(latest_task, dict) else None, + } + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_error", + reason="rollback_plan_missing", + ) + failed = self.metadata_store.finalize_feedback_task_rollback( + task_id=task_id, + rollback_status="error", + rollback_error="rollback_plan_missing", + ) + return {"success": False, "error": "缺少 rollback_plan,无法回退", "task": failed} + + running_task = self.metadata_store.mark_feedback_task_rollback_running( + task_id=task_id, + requested_by=requested_by, + reason=reason, + ) + if running_task is None: + latest_task = self.metadata_store.get_feedback_task_by_id(task_id) + latest_status = str((latest_task or {}).get("rollback_status", "") or "none").strip().lower() + if latest_status == "running": + return { + "success": False, + "error": "该反馈纠错任务正在回退中", + "task": self._build_feedback_task_detail(latest_task) if isinstance(latest_task, dict) else None, + } + if latest_status == "rolled_back": + return { + "success": True, + "already_rolled_back": True, + "task": self._build_feedback_task_detail(latest_task) if isinstance(latest_task, dict) else None, + "result": (latest_task or {}).get("rollback_result") if isinstance((latest_task or {}).get("rollback_result"), dict) else {}, + } + return { + "success": False, + "error": "无法进入回退状态", + "task": self._build_feedback_task_detail(latest_task) if isinstance(latest_task, dict) else None, + } + + result: Dict[str, Any] = { + "task_id": task_id, + "query_tool_id": query_tool_id, + "restored_relation_hashes": [], + "reverted_corrected_relation_hashes": [], + "deleted_correction_paragraph_hashes": [], + "cleared_stale_mark_count": 0, + "episode_sources_queued": [], + "profile_person_ids_queued": [], + "warnings": [], + } + try: + forgotten_relations = rollback_plan.get("forgotten_relations") if isinstance(rollback_plan.get("forgotten_relations"), list) else [] + for item in forgotten_relations: + if not isinstance(item, dict): + continue + relation_hash = str(item.get("hash", "") or "").strip() + snapshot = item.get("before_status") if isinstance(item.get("before_status"), dict) else {} + if not relation_hash or not snapshot: + continue + before_status = self.metadata_store.get_relation_status_batch([relation_hash]).get(relation_hash, {}) + after_status = self.metadata_store.restore_relation_status_from_snapshot(relation_hash, snapshot) + if after_status is None: + result["warnings"].append(f"restore_old_relation_failed:{relation_hash}") + continue + result["restored_relation_hashes"].append(relation_hash) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_restore_relation", + target_hash=relation_hash, + before_payload=before_status, + after_payload=after_status, + reason=reason, + ) + + corrected_write = rollback_plan.get("corrected_write") if isinstance(rollback_plan.get("corrected_write"), dict) else {} + correction_paragraph_hashes = self._tokens(corrected_write.get("paragraph_hashes")) + deleted_paragraphs = self._soft_delete_feedback_correction_paragraphs(correction_paragraph_hashes) + result["deleted_correction_paragraph_hashes"] = deleted_paragraphs.get("deleted_hashes", []) + paragraph_rows = deleted_paragraphs.get("paragraph_rows") if isinstance(deleted_paragraphs.get("paragraph_rows"), dict) else {} + deleted_external_refs = deleted_paragraphs.get("deleted_external_refs") if isinstance(deleted_paragraphs.get("deleted_external_refs"), list) else [] + deleted_ref_map: Dict[str, List[Dict[str, Any]]] = {} + for ref in deleted_external_refs: + if not isinstance(ref, dict): + continue + paragraph_hash = str(ref.get("paragraph_hash", "") or "").strip() + if not paragraph_hash: + continue + deleted_ref_map.setdefault(paragraph_hash, []).append(ref) + for paragraph_hash in result["deleted_correction_paragraph_hashes"]: + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_delete_correction_paragraph", + target_hash=paragraph_hash, + before_payload={ + "paragraph": paragraph_rows.get(paragraph_hash) if isinstance(paragraph_rows.get(paragraph_hash), dict) else {}, + "external_refs": deleted_ref_map.get(paragraph_hash, []), + }, + reason=reason, + ) + + corrected_relations = corrected_write.get("corrected_relations") if isinstance(corrected_write.get("corrected_relations"), list) else [] + for item in corrected_relations: + if not isinstance(item, dict): + continue + relation_hash = str(item.get("hash", "") or "").strip() + if not relation_hash: + continue + before_status = self.metadata_store.get_relation_status_batch([relation_hash]).get(relation_hash, {}) + if bool(item.get("existed_before")): + snapshot = item.get("before_status") if isinstance(item.get("before_status"), dict) else {} + after_status = self.metadata_store.restore_relation_status_from_snapshot(relation_hash, snapshot) + else: + self.metadata_store.update_relations_protection([relation_hash], protected_until=0.0, is_pinned=False) + self.metadata_store.mark_relations_inactive([relation_hash], inactive_since=time.time()) + after_status = self.metadata_store.get_relation_status_batch([relation_hash]).get(relation_hash) + if after_status is None: + result["warnings"].append(f"revert_corrected_relation_failed:{relation_hash}") + continue + result["reverted_corrected_relation_hashes"].append(relation_hash) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_revert_corrected_relation", + target_hash=relation_hash, + before_payload=before_status, + after_payload=after_status, + reason=reason, + ) + + stale_marks_raw = rollback_plan.get("stale_marks") if isinstance(rollback_plan.get("stale_marks"), list) else [] + stale_marks: List[tuple[str, str]] = [] + for item in stale_marks_raw: + if not isinstance(item, dict): + continue + paragraph_hash = str(item.get("paragraph_hash", "") or "").strip() + relation_hash = str(item.get("relation_hash", "") or "").strip() + if not paragraph_hash or not relation_hash: + continue + stale_marks.append((paragraph_hash, relation_hash)) + result["cleared_stale_mark_count"] = self.metadata_store.delete_paragraph_stale_relation_marks(stale_marks) + for paragraph_hash, relation_hash in stale_marks: + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_clear_stale_mark", + target_hash=paragraph_hash, + after_payload={"relation_hash": relation_hash}, + reason=reason, + ) + + for source in self._tokens(rollback_plan.get("episode_sources")): + if self.metadata_store.enqueue_episode_source_rebuild(source, reason="feedback_correction_rollback"): + result["episode_sources_queued"].append(source) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_enqueue_episode_rebuild", + target_hash=source, + reason=reason, + ) + + for person_id in self._tokens(rollback_plan.get("profile_person_ids")): + payload = self.metadata_store.enqueue_person_profile_refresh( + person_id=person_id, + reason="feedback_correction_rollback", + source_query_tool_id=query_tool_id, + ) + if not isinstance(payload, dict): + continue + result["profile_person_ids_queued"].append(person_id) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_enqueue_profile_refresh", + target_hash=person_id, + reason=reason, + ) + + self._rebuild_graph_from_metadata() + self._persist() + final_task = self.metadata_store.finalize_feedback_task_rollback( + task_id=task_id, + rollback_status="rolled_back", + rollback_result=result, + ) + return {"success": True, "result": result, "task": self._build_feedback_task_detail(final_task or running_task)} + except Exception as exc: + logger.warning(f"反馈纠错回退失败: task_id={task_id} err={exc}", exc_info=True) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="rollback_error", + reason=str(exc), + after_payload=result if result else None, + ) + final_task = self.metadata_store.finalize_feedback_task_rollback( + task_id=task_id, + rollback_status="error", + rollback_result=result if result else None, + rollback_error=str(exc), + ) + return { + "success": False, + "error": str(exc), + "result": result, + "task": self._build_feedback_task_detail(final_task or running_task), + } + + async def _process_feedback_profile_refresh_batch(self, *, limit: int) -> Dict[str, Any]: + if self.metadata_store is None or self.person_profile_service is None: + return {"processed": 0, "refreshed": 0, "failed": 0, "items": [], "failures": []} + + rows = self.metadata_store.fetch_person_profile_refresh_batch( + limit=max(1, int(limit or 1)), + max_retry=max(1, int(self._cfg("person_profile.max_retry", 3) or 3)), + ) + items: List[Dict[str, Any]] = [] + failures: List[Dict[str, Any]] = [] + for row in rows: + person_id = str(row.get("person_id", "") or "").strip() + requested_at = row.get("requested_at") + if not person_id: + continue + if not self.metadata_store.mark_person_profile_refresh_running(person_id, requested_at=requested_at): + continue + try: + profile = await self.refresh_person_profile( + person_id, + limit=max(4, int(self._cfg("person_profile.top_k_evidence", 12) or 12)), + mark_active=False, + ) + if isinstance(profile, dict) and bool(profile.get("success")): + self.metadata_store.mark_person_profile_refresh_done(person_id, requested_at=requested_at) + items.append( + { + "person_id": person_id, + "profile_version": int(profile.get("profile_version", 0) or 0), + "profile_source": str(profile.get("profile_source", "") or ""), + } + ) + else: + error = str((profile or {}).get("error", "") or "person profile refresh failed") + self.metadata_store.mark_person_profile_refresh_failed(person_id, error, requested_at=requested_at) + failures.append({"person_id": person_id, "error": error}) + except Exception as exc: + error = str(exc)[:500] + self.metadata_store.mark_person_profile_refresh_failed(person_id, error, requested_at=requested_at) + failures.append({"person_id": person_id, "error": error}) + return { + "processed": len(items) + len(failures), + "refreshed": len(items), + "failed": len(failures), + "items": items, + "failures": failures, + } + + async def _process_feedback_episode_rebuild_batch(self, *, limit: int) -> Dict[str, Any]: + if self.metadata_store is None or self.episode_service is None: + return {"processed": 0, "rebuilt": 0, "failed": 0, "items": [], "failures": []} + + rows = self.metadata_store.fetch_episode_source_rebuild_batch( + limit=max(1, int(limit or 1)), + max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3) or 3)), + ) + items: List[Dict[str, Any]] = [] + failures: List[Dict[str, Any]] = [] + for row in rows: + source = str(row.get("source", "") or "").strip() + requested_at = row.get("requested_at") + if not source: + continue + if not self.metadata_store.mark_episode_source_running(source, requested_at=requested_at): + continue + try: + result = await self.episode_service.rebuild_source(source) + self.metadata_store.mark_episode_source_done(source, requested_at=requested_at) + items.append(result if isinstance(result, dict) else {"source": source}) + except Exception as exc: + error = str(exc)[:500] + self.metadata_store.mark_episode_source_failed(source, error, requested_at=requested_at) + failures.append({"source": source, "error": error}) + if items or failures: + self._persist() + return { + "processed": len(items) + len(failures), + "rebuilt": len(items), + "failed": len(failures), + "items": items, + "failures": failures, + } + + async def _feedback_correction_reconcile_loop(self) -> None: + try: + while not self._background_stopping: + await asyncio.sleep(self._feedback_cfg_reconcile_interval_seconds()) + if self._background_stopping: + break + if self.metadata_store is None or not self._feedback_cfg_enabled(): + continue + batch_size = self._feedback_cfg_reconcile_batch_size() + if self._feedback_cfg_profile_refresh_enabled(): + await self._process_feedback_profile_refresh_batch(limit=batch_size) + if self._feedback_cfg_episode_rebuild_enabled(): + await self._process_feedback_episode_rebuild_batch(limit=batch_size) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"feedback_correction_reconcile loop 异常: {exc}") + + @staticmethod + def _coerce_datetime(value: Any) -> Optional[datetime]: + if isinstance(value, datetime): + return value + if isinstance(value, (int, float)): + try: + return datetime.fromtimestamp(float(value)) + except Exception: + return None + text = str(value or "").strip() + if not text: + return None + try: + return datetime.fromisoformat(text) + except Exception: + return None + + @staticmethod + def _feedback_signal_tokens() -> tuple[str, ...]: + return ( + "不对", + "错了", + "你记错", + "记错了", + "不是", + "并不是", + "纠正", + "更正", + "改成", + "应该是", + "实际是", + "说反了", + ) + + @classmethod + def _feedback_contains_signal(cls, text: str) -> bool: + content = str(text or "").strip().lower() + if not content: + return False + return any(token in content for token in cls._feedback_signal_tokens()) + + @staticmethod + def _feedback_noise(text: str) -> bool: + content = str(text or "").strip() + if not content: + return True + if SDKMemoryKernel._feedback_contains_signal(content): + return False + if len(content) <= 2: + return True + markers = ( + "哈哈", + "好的", + "收到", + "谢谢", + "嗯嗯", + "晚安", + "早安", + "拜拜", + "在吗", + ) + return len(content) <= 8 and any(marker in content for marker in markers) + + @staticmethod + def _safe_json_loads(raw: Any) -> Dict[str, Any]: + if isinstance(raw, dict): + return raw + text = str(raw or "").strip() + if not text: + return {} + try: + repaired = repair_json(text) + payload = json.loads(repaired) if isinstance(repaired, str) else repaired + except Exception: + payload = None + return payload if isinstance(payload, dict) else {} + + @staticmethod + def _feedback_cfg_enabled() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_enabled", False)) + + @staticmethod + def _feedback_cfg_window_hours() -> float: + memory_cfg = getattr(global_config, "memory", None) + return max(0.1, float(getattr(memory_cfg, "feedback_correction_window_hours", 12.0) or 12.0)) + + @staticmethod + def _feedback_cfg_check_interval_seconds() -> float: + memory_cfg = getattr(global_config, "memory", None) + minutes = max(1, int(getattr(memory_cfg, "feedback_correction_check_interval_minutes", 30) or 30)) + return float(minutes) * 60.0 + + @staticmethod + def _feedback_cfg_batch_size() -> int: + memory_cfg = getattr(global_config, "memory", None) + return max(1, int(getattr(memory_cfg, "feedback_correction_batch_size", 20) or 20)) + + @staticmethod + def _feedback_cfg_auto_apply_threshold() -> float: + memory_cfg = getattr(global_config, "memory", None) + value = float(getattr(memory_cfg, "feedback_correction_auto_apply_threshold", 0.85) or 0.85) + return min(1.0, max(0.0, value)) + + @staticmethod + def _feedback_cfg_max_messages() -> int: + memory_cfg = getattr(global_config, "memory", None) + return max(1, int(getattr(memory_cfg, "feedback_correction_max_feedback_messages", 30) or 30)) + + @staticmethod + def _feedback_cfg_prefilter_enabled() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_prefilter_enabled", True)) + + @staticmethod + def _feedback_cfg_paragraph_mark_enabled() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_paragraph_mark_enabled", True)) + + @staticmethod + def _feedback_cfg_paragraph_hard_filter_enabled() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_paragraph_hard_filter_enabled", True)) + + @staticmethod + def _feedback_cfg_profile_refresh_enabled() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_profile_refresh_enabled", True)) + + @staticmethod + def _feedback_cfg_profile_force_refresh_on_read() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_profile_force_refresh_on_read", True)) + + @staticmethod + def _feedback_cfg_episode_rebuild_enabled() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_episode_rebuild_enabled", True)) + + @staticmethod + def _feedback_cfg_episode_query_block_enabled() -> bool: + memory_cfg = getattr(global_config, "memory", None) + return bool(getattr(memory_cfg, "feedback_correction_episode_query_block_enabled", True)) + + @staticmethod + def _feedback_cfg_reconcile_interval_seconds() -> float: + memory_cfg = getattr(global_config, "memory", None) + minutes = max(1, int(getattr(memory_cfg, "feedback_correction_reconcile_interval_minutes", 5) or 5)) + return float(minutes) * 60.0 + + @staticmethod + def _feedback_cfg_reconcile_batch_size() -> int: + memory_cfg = getattr(global_config, "memory", None) + return max(1, int(getattr(memory_cfg, "feedback_correction_reconcile_batch_size", 20) or 20)) + + @classmethod + def _feedback_cfg_window_label(cls) -> str: + hours = cls._feedback_cfg_window_hours() + if abs(hours - round(hours)) < 1e-9: + return f"{int(round(hours))}h" + return f"{hours:.2f}h" + + async def enqueue_feedback_task( + self, + *, + query_tool_id: str, + session_id: str, + query_timestamp: Any = None, + structured_content: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + if not self._feedback_cfg_enabled(): + return {"success": False, "queued": False, "reason": "feedback_correction_disabled"} + if self.metadata_store is None: + return {"success": False, "queued": False, "reason": "metadata_store_unavailable"} + + clean_tool_id = str(query_tool_id or "").strip() + clean_session_id = str(session_id or "").strip() + if not clean_tool_id or not clean_session_id: + return {"success": False, "queued": False, "reason": "missing_required_fields"} + + content = structured_content if isinstance(structured_content, dict) else {} + hits = content.get("hits") + if not isinstance(hits, list) or not hits: + return {"success": False, "queued": False, "reason": "no_hits"} + + query_time = self._coerce_datetime(query_timestamp) or datetime.now() + due_at = query_time + timedelta(hours=self._feedback_cfg_window_hours()) + saved = self.metadata_store.enqueue_feedback_task( + query_tool_id=clean_tool_id, + session_id=clean_session_id, + query_timestamp=query_time.timestamp(), + due_at=due_at.timestamp(), + query_snapshot=content, + ) + if not isinstance(saved, dict): + return {"success": False, "queued": False, "reason": "db_save_failed"} + + logger.debug( + "反馈纠错任务入队: query_tool_id=%s due_at=%s", + clean_tool_id, + due_at.isoformat(), + ) + return { + "success": True, + "queued": True, + "query_tool_id": clean_tool_id, + "due_at": due_at.isoformat(), + "task": saved, + } + + @staticmethod + def _extract_feedback_messages( + *, + session_id: str, + query_time: datetime, + due_time: datetime, + max_messages: int, + ) -> List[str]: + raw_messages = message_api.get_messages_by_time_in_chat( + chat_id=session_id, + start_time=query_time.timestamp(), + end_time=due_time.timestamp(), + limit=max(1, int(max_messages) * 4), + limit_mode="latest", + filter_mai=True, + filter_command=True, + ) + collected: List[str] = [] + seen = set() + for item in raw_messages: + text = str(getattr(item, "processed_plain_text", "") or "").strip() + if SDKMemoryKernel._feedback_noise(text): + continue + if text in seen: + continue + seen.add(text) + collected.append(text) + if len(collected) > max_messages: + collected = collected[-max_messages:] + return collected + + def _build_feedback_hit_briefs(self, hits: List[Dict[str, Any]], *, limit: int = 12) -> List[Dict[str, Any]]: + briefs: List[Dict[str, Any]] = [] + for raw in hits[: max(1, int(limit))]: + if not isinstance(raw, dict): + continue + metadata = raw.get("metadata") if isinstance(raw.get("metadata"), dict) else {} + subject = str(metadata.get("subject", "") or "").strip() + predicate = str(metadata.get("predicate", "") or "").strip() + obj = str(metadata.get("object", "") or "").strip() + linked_relation_hashes: List[str] = [] + linked_relation_texts: List[str] = [] + + item_type = str(raw.get("type", "") or "").strip() + item_hash = str(raw.get("hash", "") or "").strip() + if item_type == "paragraph" and item_hash and self.metadata_store is not None: + linked_relations = self.metadata_store.get_paragraph_relations(item_hash) + for relation in linked_relations: + relation_hash = str(relation.get("hash", "") or "").strip() + if not relation_hash or relation_hash in linked_relation_hashes: + continue + linked_relation_hashes.append(relation_hash) + rel_subject = str(relation.get("subject", "") or "").strip() + rel_predicate = str(relation.get("predicate", "") or "").strip() + rel_object = str(relation.get("object", "") or "").strip() + relation_text = self._format_relation_text(rel_subject, rel_predicate, rel_object) + if relation_text: + linked_relation_texts.append(relation_text) + if not (subject and predicate and obj): + subject = rel_subject + predicate = rel_predicate + obj = rel_object + briefs.append( + { + "hash": item_hash, + "type": item_type, + "content": str(raw.get("content", "") or "").strip(), + "subject": subject, + "predicate": predicate, + "object": obj, + "linked_relation_hashes": linked_relation_hashes[:6], + "linked_relation_texts": linked_relation_texts[:3], + } + ) + return briefs + + @staticmethod + def _should_invoke_feedback_classifier(feedback_messages: List[str]) -> bool: + if not feedback_messages: + return False + lowered = "\n".join(feedback_messages).lower() + return any(token in lowered for token in SDKMemoryKernel._feedback_signal_tokens()) + + async def _classify_feedback( + self, + *, + query_tool_id: str, + query_text: str, + hit_briefs: List[Dict[str, Any]], + feedback_messages: List[str], + ) -> Dict[str, Any]: + prompt = ( + "你是长期记忆纠错分类器。" + "你会根据“记忆检索命中列表”和“用户后续反馈”判断是否需要修正记忆。" + "请严格输出 JSON 对象,不要输出解释文字。\n\n" + f"query_tool_id: {query_tool_id}\n" + f"原查询: {query_text}\n" + f"候选命中: {json.dumps(hit_briefs, ensure_ascii=False)}\n" + f"反馈消息: {json.dumps(feedback_messages, ensure_ascii=False)}\n\n" + "输出 JSON schema:\n" + "{" + "\"decision\":\"confirm|reject|correct|supplement|none\"," + "\"confidence\":0.0," + "\"target_hashes\":[\"命中列表中的 hash\"]," + "\"corrected_relations\":[{\"subject\":\"\",\"predicate\":\"\",\"object\":\"\",\"confidence\":1.0}]," + "\"reason\":\"\"" + "}\n" + "约束:\n" + "1. 只有当反馈明确指向错误时才输出 reject/correct。\n" + "2. target_hashes 必须来自候选命中 hash。\n" + "3. corrected_relations 仅在 decision=correct 时填写,且必须是明确三元组。\n" + "4. 不确定时输出 decision=none, confidence<=0.5。" + ) + try: + if self._feedback_classifier is None: + self._feedback_classifier = LLMServiceClient( + task_name="utils", + request_type="memory_feedback_correction", + ) + response = await self._feedback_classifier.generate_response(prompt) + payload = self._safe_json_loads(getattr(response, "response", "")) + except Exception as exc: + logger.warning(f"反馈分类器调用失败: {exc}") + payload = {} + return payload + + @staticmethod + def _normalize_feedback_decision( + payload: Dict[str, Any], + *, + hit_hashes: Sequence[str], + ) -> Dict[str, Any]: + allowed = {"confirm", "reject", "correct", "supplement", "none"} + decision = str(payload.get("decision", "") or "").strip().lower() + if decision not in allowed: + decision = "none" + try: + confidence = float(payload.get("confidence", 0.0) or 0.0) + except (TypeError, ValueError): + confidence = 0.0 + confidence = min(1.0, max(0.0, confidence)) + + valid_hashes = {str(item or "").strip() for item in hit_hashes if str(item or "").strip()} + target_hashes_raw = payload.get("target_hashes") + if isinstance(target_hashes_raw, str): + target_hashes_candidates = [target_hashes_raw] + elif isinstance(target_hashes_raw, list): + target_hashes_candidates = target_hashes_raw + else: + target_hashes_candidates = [] + target_hashes = [ + str(item or "").strip() + for item in target_hashes_candidates + if str(item or "").strip() in valid_hashes + ] + + corrected_relations: List[Dict[str, Any]] = [] + raw_relations = payload.get("corrected_relations") + if isinstance(raw_relations, list): + for item in raw_relations: + if not isinstance(item, dict): + continue + subject = str(item.get("subject", "") or "").strip() + predicate = str(item.get("predicate", "") or "").strip() + obj = str(item.get("object", "") or "").strip() + if not (subject and predicate and obj): + continue + try: + rel_conf = float(item.get("confidence", 1.0) or 1.0) + except (TypeError, ValueError): + rel_conf = 1.0 + corrected_relations.append( + { + "subject": subject, + "predicate": predicate, + "object": obj, + "confidence": min(1.0, max(0.0, rel_conf)), + } + ) + corrected_relations = corrected_relations[:6] + + return { + "decision": decision, + "confidence": confidence, + "target_hashes": target_hashes, + "corrected_relations": corrected_relations, + "reason": str(payload.get("reason", "") or "").strip(), + "raw": payload, + } + + @staticmethod + def _feedback_apply_result_status(apply_result: Dict[str, Any]) -> str: + if bool(apply_result.get("applied")): + return "applied" + + reason = str(apply_result.get("reason", "") or "").strip().lower() + if reason in {"low_confidence", "no_relation_targets"} or reason.startswith("decision_"): + return "skipped" + return "error" + + def _restore_feedback_relations_from_snapshots( + self, + *, + task_id: int, + query_tool_id: str, + relation_hashes: Sequence[str], + snapshots: Dict[str, Dict[str, Any]], + current_statuses: Optional[Dict[str, Dict[str, Any]]] = None, + reason: str, + ) -> Dict[str, List[str]]: + assert self.metadata_store is not None + + restored_hashes: List[str] = [] + failed_hashes: List[str] = [] + status_map = current_statuses if isinstance(current_statuses, dict) else {} + + for relation_hash in self._tokens(relation_hashes): + snapshot = snapshots.get(relation_hash) if isinstance(snapshots, dict) else None + if not isinstance(snapshot, dict) or not snapshot: + failed_hashes.append(relation_hash) + continue + + after_status = self.metadata_store.restore_relation_status_from_snapshot(relation_hash, snapshot) + if after_status is None: + failed_hashes.append(relation_hash) + continue + + restored_hashes.append(relation_hash) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="compensate_restore_relation", + target_hash=relation_hash, + before_payload=status_map.get(relation_hash, {}), + after_payload=after_status, + reason=reason, + ) + + if restored_hashes or failed_hashes: + self._rebuild_graph_from_metadata() + self._persist() + + return { + "restored_hashes": restored_hashes, + "failed_hashes": failed_hashes, + } + + async def _ingest_feedback_relations( + self, + *, + query_tool_id: str, + session_id: str, + relation_hashes: List[str], + corrected_relations: List[Dict[str, Any]], + ) -> Dict[str, Any]: + supersedes_hash = relation_hashes[0] if relation_hashes else "" + relation_rows: List[Dict[str, Any]] = [] + for row in corrected_relations: + relation_rows.append( + { + "subject": str(row.get("subject", "") or "").strip(), + "predicate": str(row.get("predicate", "") or "").strip(), + "object": str(row.get("object", "") or "").strip(), + "confidence": float(row.get("confidence", 1.0) or 1.0), + "metadata": { + "supersedes_hash": supersedes_hash, + "supersedes_hashes": relation_hashes, + "from_query_tool_id": query_tool_id, + "feedback_window": self._feedback_cfg_window_label(), + }, + } + ) + plain_text = ";".join( + f"{item['subject']} {item['predicate']} {item['object']}" + for item in relation_rows + if item.get("subject") and item.get("predicate") and item.get("object") + ) + external_id = compute_hash( + "feedback_correction:" + + query_tool_id + + ":" + + json.dumps(relation_rows, ensure_ascii=False, sort_keys=True) + ) + payload = await self.ingest_text( + external_id=external_id, + source_type="chat_summary", + text=plain_text, + chat_id=session_id, + relations=relation_rows, + metadata={ + "from_query_tool_id": query_tool_id, + "feedback_window": self._feedback_cfg_window_label(), + "supersedes_hashes": relation_hashes, + "feedback_correction_source": True, + }, + respect_filter=False, + ) + if isinstance(payload, dict): + stored_ids = self._tokens(payload.get("stored_ids")) + corrected_relation_hashes = stored_ids[1:] + payload["external_id"] = external_id + payload["source"] = self._chat_source(session_id) + payload["paragraph_hashes"] = stored_ids[:1] + payload["corrected_relation_hashes"] = corrected_relation_hashes + base_success = bool(payload.get("success")) if "success" in payload else True + payload["success"] = base_success and bool(corrected_relation_hashes) + if not payload["success"] and not str(payload.get("error", "") or "").strip(): + payload["error"] = "missing_corrected_relations" + return payload + return {"success": False, "error": "invalid_ingest_payload"} + + async def _apply_feedback_decision( + self, + *, + task_id: int, + query_tool_id: str, + session_id: str, + decision: Dict[str, Any], + hit_map: Dict[str, Dict[str, Any]], + ) -> Dict[str, Any]: + threshold = self._feedback_cfg_auto_apply_threshold() + confidence = float(decision.get("confidence", 0.0) or 0.0) + if confidence < threshold: + return { + "applied": False, + "reason": "low_confidence", + "threshold": threshold, + "confidence": confidence, + } + + decision_type = str(decision.get("decision", "none") or "none").strip().lower() + if decision_type not in {"reject", "correct"}: + return { + "applied": False, + "reason": f"decision_{decision_type}_no_auto_apply", + } + + target_hashes = [ + str(item or "").strip() + for item in (decision.get("target_hashes") or []) + if str(item or "").strip() + ] + relation_hashes = self._resolve_feedback_relation_hashes( + target_hashes=target_hashes, + hit_map=hit_map, + ) + if not relation_hashes: + return { + "applied": False, + "reason": "no_relation_targets", + } + + corrected_relations = [ + dict(item) + for item in (decision.get("corrected_relations") or []) + if isinstance(item, dict) + ] + if decision_type == "correct" and not corrected_relations: + return { + "applied": False, + "reason": "missing_corrected_relations", + "relation_hashes": relation_hashes, + "stale_paragraph_hashes": [], + "episode_rebuild_sources": [], + "profile_refresh_person_ids": [], + "rollback_plan_summary": {}, + } + + assert self.metadata_store is not None + old_relation_rows = self._query_relation_rows_by_hashes(relation_hashes, include_inactive=True) + before_status = self.metadata_store.get_relation_status_batch(relation_hashes) + forget_result = self._apply_v5_relation_action(action="forget", hashes=relation_hashes, strength=1.0) + forget_success = bool(forget_result.get("success")) + after_status = self.metadata_store.get_relation_status_batch(relation_hashes) + for hash_value in relation_hashes: + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="forget_relation", + target_hash=hash_value, + before_payload=before_status.get(hash_value) if isinstance(before_status, dict) else {}, + after_payload=after_status.get(hash_value) if isinstance(after_status, dict) else {}, + reason=str(decision.get("reason", "") or ""), + ) + + ingest_result = None + corrected_relation_hash_candidates: List[str] = [] + corrected_relation_specs_by_hash: Dict[str, Dict[str, Any]] = {} + if decision_type == "correct" and corrected_relations and self.metadata_store is not None: + for item in corrected_relations: + try: + relation_hash = self.metadata_store.compute_relation_hash( + str(item.get("subject", "") or "").strip(), + str(item.get("predicate", "") or "").strip(), + str(item.get("object", "") or "").strip(), + ) + except Exception: + continue + if not relation_hash: + continue + corrected_relation_hash_candidates.append(relation_hash) + corrected_relation_specs_by_hash[relation_hash] = { + "subject": str(item.get("subject", "") or "").strip(), + "predicate": str(item.get("predicate", "") or "").strip(), + "object": str(item.get("object", "") or "").strip(), + } + corrected_relation_before_status = ( + self.metadata_store.get_relation_status_batch(corrected_relation_hash_candidates) + if corrected_relation_hash_candidates + else {} + ) + if not forget_success: + return { + "applied": False, + "reason": "forget_failed", + "error": str(forget_result.get("error", "") or "forget_failed"), + "forget": forget_result, + "ingest": ingest_result, + "relation_hashes": relation_hashes, + "stale_paragraph_hashes": [], + "episode_rebuild_sources": [], + "profile_refresh_person_ids": [], + "rollback_plan_summary": {}, + } + + stale_paragraph_map: Dict[str, List[str]] = {} + stale_paragraph_hashes: List[str] = [] + episode_rebuild_sources: List[str] = [] + profile_refresh_person_ids: List[str] = [] + rollback_plan: Dict[str, Any] = {} + if decision_type == "correct" and corrected_relations: + ingest_result = await self._ingest_feedback_relations( + query_tool_id=query_tool_id, + session_id=session_id, + relation_hashes=relation_hashes, + corrected_relations=corrected_relations, + ) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="ingest_correction", + target_hash=relation_hashes[0] if relation_hashes else "", + before_payload={"target_hashes": relation_hashes}, + after_payload=ingest_result, + reason=str(decision.get("reason", "") or ""), + ) + + ingest_success = bool((ingest_result or {}).get("success")) if isinstance(ingest_result, dict) else False + if not ingest_success: + compensation_result = self._restore_feedback_relations_from_snapshots( + task_id=task_id, + query_tool_id=query_tool_id, + relation_hashes=relation_hashes, + snapshots=before_status if isinstance(before_status, dict) else {}, + current_statuses=after_status if isinstance(after_status, dict) else {}, + reason=str(decision.get("reason", "") or "") or "feedback_correction_ingest_failed", + ) + restore_failed_hashes = compensation_result.get("failed_hashes", []) + return { + "applied": False, + "reason": "correction_restore_failed" if restore_failed_hashes else "correction_ingest_failed", + "error": str((ingest_result or {}).get("error", "") or "correction_ingest_failed"), + "forget": forget_result, + "ingest": ingest_result, + "relation_hashes": relation_hashes, + "stale_paragraph_hashes": [], + "episode_rebuild_sources": [], + "profile_refresh_person_ids": [], + "restored_relation_hashes": compensation_result.get("restored_hashes", []), + "restore_failed_hashes": restore_failed_hashes, + "rollback_plan_summary": {}, + } + else: + ingest_success = False + + applied = forget_success if decision_type == "reject" else (forget_success and ingest_success) + if applied: + stale_paragraph_map = self._mark_feedback_stale_paragraphs( + task_id=task_id, + query_tool_id=query_tool_id, + relation_hashes=relation_hashes, + reason=str(decision.get("reason", "") or "") or "feedback_correction", + ) + stale_paragraph_hashes = self._merge_tokens( + *[ + paragraph_hashes + for paragraph_hashes in stale_paragraph_map.values() + if isinstance(paragraph_hashes, list) + ] + ) + episode_rebuild_sources = self._enqueue_feedback_episode_rebuilds( + paragraph_hashes=stale_paragraph_hashes, + session_id=session_id, + include_correction_source=bool(ingest_success), + ) + profile_refresh_person_ids = self._enqueue_feedback_profile_refreshes( + person_ids=self._resolve_feedback_related_person_ids( + old_relation_rows=old_relation_rows, + corrected_relations=corrected_relations, + ), + query_tool_id=query_tool_id, + ) + for relation_hash, paragraph_hashes in stale_paragraph_map.items(): + for paragraph_hash in paragraph_hashes: + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="mark_stale_paragraph", + target_hash=paragraph_hash, + after_payload={"relation_hash": relation_hash}, + reason=str(decision.get("reason", "") or ""), + ) + for source in episode_rebuild_sources: + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="enqueue_episode_rebuild", + target_hash=source, + reason=str(decision.get("reason", "") or ""), + ) + for person_id in profile_refresh_person_ids: + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="enqueue_profile_refresh", + target_hash=person_id, + reason=str(decision.get("reason", "") or ""), + ) + forgotten_relations = [] + for row in old_relation_rows: + relation_hash = str(row.get("hash", "") or "").strip() + if not relation_hash: + continue + forgotten_relations.append( + { + "hash": relation_hash, + "subject": str(row.get("subject", "") or "").strip(), + "predicate": str(row.get("predicate", "") or "").strip(), + "object": str(row.get("object", "") or "").strip(), + "before_status": before_status.get(relation_hash) if isinstance(before_status, dict) else {}, + } + ) + + corrected_write: Dict[str, Any] = {} + if isinstance(ingest_result, dict): + stored_relation_hashes = self._tokens(ingest_result.get("corrected_relation_hashes")) + corrected_write = { + "external_id": str(ingest_result.get("external_id", "") or "").strip(), + "source": str(ingest_result.get("source", "") or "").strip(), + "paragraph_hashes": self._tokens(ingest_result.get("paragraph_hashes")), + "corrected_relation_hashes": stored_relation_hashes, + "corrected_relations": [ + { + "hash": relation_hash, + **corrected_relation_specs_by_hash.get(relation_hash, {}), + "existed_before": relation_hash in corrected_relation_before_status, + "before_status": corrected_relation_before_status.get(relation_hash, {}), + } + for relation_hash in stored_relation_hashes + ], + } + + rollback_plan = { + "task_id": task_id, + "query_tool_id": query_tool_id, + "session_id": session_id, + "decision_type": decision_type, + "forgotten_relations": forgotten_relations, + "corrected_write": corrected_write, + "stale_marks": [ + {"paragraph_hash": paragraph_hash, "relation_hash": relation_hash} + for relation_hash, paragraph_hashes in stale_paragraph_map.items() + for paragraph_hash in (paragraph_hashes or []) + if str(paragraph_hash or "").strip() + ], + "episode_sources": episode_rebuild_sources, + "profile_person_ids": profile_refresh_person_ids, + "created_at": time.time(), + } + update_rollback_plan = getattr(self.metadata_store, "update_feedback_task_rollback_plan", None) + if callable(update_rollback_plan): + update_rollback_plan( + task_id=task_id, + rollback_plan=rollback_plan, + ) + return { + "applied": applied, + "forget": forget_result, + "ingest": ingest_result, + "relation_hashes": relation_hashes, + "stale_paragraph_hashes": stale_paragraph_hashes, + "episode_rebuild_sources": episode_rebuild_sources, + "profile_refresh_person_ids": profile_refresh_person_ids, + "rollback_plan_summary": self._build_feedback_rollback_plan_summary(rollback_plan) if rollback_plan else {}, + } + + def _resolve_feedback_relation_hashes( + self, + *, + target_hashes: Sequence[str], + hit_map: Dict[str, Dict[str, Any]], + ) -> List[str]: + resolved: List[str] = [] + seen: set[str] = set() + for target_hash in target_hashes: + token = str(target_hash or "").strip() + if not token: + continue + hit = hit_map.get(token) if isinstance(hit_map, dict) else None + item_type = str((hit or {}).get("type", "") or "").strip() + if item_type == "relation": + if token not in seen: + seen.add(token) + resolved.append(token) + continue + if item_type != "paragraph": + continue + + linked_candidates = self._tokens((hit or {}).get("linked_relation_hashes")) + if not linked_candidates and self.metadata_store is not None: + for relation in self.metadata_store.get_paragraph_relations(token): + linked_hash = str(relation.get("hash", "") or "").strip() + if linked_hash: + linked_candidates.append(linked_hash) + + for linked_hash in linked_candidates: + if linked_hash in seen: + continue + seen.add(linked_hash) + resolved.append(linked_hash) + return resolved + + async def _process_feedback_task(self, task: Dict[str, Any]) -> None: + task_id = int(task.get("id") or 0) + query_tool_id = str(task.get("query_tool_id", "") or "").strip() + if task_id <= 0 or not query_tool_id: + return + + assert self.metadata_store is not None + self.metadata_store.mark_feedback_task_running(task_id) + + decision_payload: Dict[str, Any] = {} + session_id = str(task.get("session_id", "") or "").strip() + try: + structured = task.get("query_snapshot") if isinstance(task.get("query_snapshot"), dict) else {} + if not session_id: + session_id = str(structured.get("chat_id", "") or "").strip() + if not session_id: + raise RuntimeError("反馈任务缺少 session_id") + hits_raw = structured.get("hits") + if not isinstance(hits_raw, list) or not hits_raw: + decision_payload = {"decision": "none", "confidence": 1.0, "reason": "no_hits"} + self.metadata_store.finalize_feedback_task( + task_id=task_id, + status="skipped", + decision_payload=decision_payload, + ) + return + + query_timestamp = self._coerce_datetime(task.get("query_timestamp")) or datetime.now() + due_at = self._coerce_datetime(task.get("due_at")) or ( + query_timestamp + timedelta(hours=self._feedback_cfg_window_hours()) + ) + if due_at <= query_timestamp: + due_at = query_timestamp + timedelta(hours=self._feedback_cfg_window_hours()) + + feedback_messages = self._extract_feedback_messages( + session_id=session_id, + query_time=query_timestamp, + due_time=due_at, + max_messages=self._feedback_cfg_max_messages(), + ) + if not feedback_messages: + decision_payload = {"decision": "none", "confidence": 1.0, "reason": "no_feedback_messages"} + self.metadata_store.finalize_feedback_task( + task_id=task_id, + status="skipped", + decision_payload=decision_payload, + ) + return + + if self._feedback_cfg_prefilter_enabled() and not self._should_invoke_feedback_classifier(feedback_messages): + decision_payload = {"decision": "none", "confidence": 1.0, "reason": "prefilter_skipped"} + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="skip", + reason="prefilter_skipped", + after_payload={"feedback_messages": feedback_messages}, + ) + self.metadata_store.finalize_feedback_task( + task_id=task_id, + status="skipped", + decision_payload=decision_payload, + ) + return + + hit_briefs = self._build_feedback_hit_briefs(hits_raw) + hit_map = {str(item.get("hash", "") or "").strip(): item for item in hit_briefs if str(item.get("hash", "") or "").strip()} + raw_decision = await self._classify_feedback( + query_tool_id=query_tool_id, + query_text=str(structured.get("query", "") or ""), + hit_briefs=hit_briefs, + feedback_messages=feedback_messages, + ) + decision_payload = self._normalize_feedback_decision(raw_decision, hit_hashes=list(hit_map.keys())) + decision_payload["feedback_message_count"] = len(feedback_messages) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="classification", + after_payload=decision_payload, + reason=str(decision_payload.get("reason", "") or ""), + ) + + apply_result = await self._apply_feedback_decision( + task_id=task_id, + query_tool_id=query_tool_id, + session_id=session_id, + decision=decision_payload, + hit_map=hit_map, + ) + decision_payload["apply_result"] = apply_result + final_status = self._feedback_apply_result_status(apply_result) + self.metadata_store.finalize_feedback_task( + task_id=task_id, + status=final_status, + decision_payload=decision_payload, + last_error=str(apply_result.get("error", "") or "") if final_status == "error" else "", + ) + except Exception as exc: + logger.warning(f"反馈纠错任务处理失败: task_id={task_id} err={exc}", exc_info=True) + self.metadata_store.append_feedback_action_log( + task_id=task_id, + query_tool_id=query_tool_id, + action_type="error", + reason=str(exc), + after_payload=decision_payload if decision_payload else None, + ) + self.metadata_store.finalize_feedback_task( + task_id=task_id, + status="error", + decision_payload=decision_payload if decision_payload else None, + last_error=str(exc), + ) + + async def _feedback_correction_loop(self) -> None: + try: + while not self._background_stopping: + interval_seconds = self._feedback_cfg_check_interval_seconds() + if not self._feedback_cfg_enabled(): + await asyncio.sleep(interval_seconds) + continue + if self.metadata_store is None: + await asyncio.sleep(interval_seconds) + continue + tasks = self.metadata_store.fetch_due_feedback_tasks( + limit=self._feedback_cfg_batch_size(), + now=datetime.now().timestamp(), + ) + if not tasks: + await asyncio.sleep(interval_seconds) + continue + for task in tasks: + if self._background_stopping: + break + if not isinstance(task, dict): + continue + await self._process_feedback_task(task) + await asyncio.sleep(2.0) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"feedback_correction loop 异常: {exc}") + async def _memory_maintenance_loop(self) -> None: try: while not self._background_stopping: @@ -2506,18 +4257,24 @@ class SDKMemoryKernel: ] ).strip() - def _query_relation_rows_by_hashes(self, relation_hashes: Sequence[str]) -> List[Dict[str, Any]]: + def _query_relation_rows_by_hashes( + self, + relation_hashes: Sequence[str], + *, + include_inactive: bool = False, + ) -> List[Dict[str, Any]]: assert self.metadata_store is not None hashes = [str(item or "").strip() for item in relation_hashes if str(item or "").strip()] if not hashes: return [] placeholders = ",".join(["?"] * len(hashes)) + inactive_clause = "" if include_inactive else "AND (is_inactive IS NULL OR is_inactive = 0)" rows = self.metadata_store.query( f""" SELECT hash, subject, predicate, object, confidence, created_at, source_paragraph FROM relations WHERE hash IN ({placeholders}) - AND (is_inactive IS NULL OR is_inactive = 0) + {inactive_clause} """, tuple(hashes), ) @@ -2645,6 +4402,16 @@ class SDKMemoryKernel: paragraph_hash = str(row.get("hash", "") or "").strip() entities = self.metadata_store.get_paragraph_entities(paragraph_hash) relations = self.metadata_store.get_paragraph_relations(paragraph_hash) + stale_marks_map, stale_status_map = self._load_paragraph_stale_marks([paragraph_hash]) + stale_marks = [ + { + **mark, + "relation_inactive": self._relation_status_is_inactive( + stale_status_map.get(str(mark.get("relation_hash", "") or "").strip()) + ), + } + for mark in stale_marks_map.get(paragraph_hash, []) + ] return { "hash": paragraph_hash, "content": str(row.get("content", "") or ""), @@ -2663,6 +4430,8 @@ class SDKMemoryKernel: ) for relation in relations ], + "is_stale": bool(stale_marks), + "stale_relation_marks": stale_marks, } @staticmethod @@ -3379,6 +5148,86 @@ class SDKMemoryKernel: filtered.append(item) return filtered or hits + def _filter_active_relation_hits(self, hits: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + if self.metadata_store is None: + return hits + relation_hashes: List[str] = [] + paragraph_relation_cache: Dict[str, List[str]] = {} + paragraph_hashes: List[str] = [] + seen_relation_hashes: set[str] = set() + + for item in hits: + item_type = str(item.get("type", "") or "").strip() + item_hash = str(item.get("hash", "") or "").strip() + if item_type == "relation" and item_hash and item_hash not in seen_relation_hashes: + seen_relation_hashes.add(item_hash) + relation_hashes.append(item_hash) + continue + if item_type != "paragraph" or not item_hash: + continue + paragraph_hashes.append(item_hash) + linked_relations = self.metadata_store.get_paragraph_relations(item_hash) + linked_hashes: List[str] = [] + for relation in linked_relations: + linked_hash = str(relation.get("hash", "") or "").strip() + if not linked_hash or linked_hash in seen_relation_hashes: + continue + seen_relation_hashes.add(linked_hash) + relation_hashes.append(linked_hash) + linked_hashes.append(linked_hash) + if linked_hashes: + paragraph_relation_cache[item_hash] = linked_hashes + + marks_by_paragraph, _ = self._load_paragraph_stale_marks(paragraph_hashes) + stale_relation_hashes = self._tokens( + mark.get("relation_hash", "") + for marks in marks_by_paragraph.values() + for mark in marks + if isinstance(mark, dict) + ) + for relation_hash in stale_relation_hashes: + if relation_hash in seen_relation_hashes: + continue + seen_relation_hashes.add(relation_hash) + relation_hashes.append(relation_hash) + + if not relation_hashes and not marks_by_paragraph: + return hits + + status_map = self.metadata_store.get_relation_status_batch(relation_hashes) + filtered: List[Dict[str, Any]] = [] + for item in hits: + item_type = str(item.get("type", "") or "").strip() + if item_type == "paragraph": + paragraph_hash = str(item.get("hash", "") or "").strip() + if self._paragraph_hidden_by_stale_marks( + paragraph_hash, + marks_by_paragraph=marks_by_paragraph, + relation_status_map=status_map, + ): + continue + linked_hashes = paragraph_relation_cache.get(paragraph_hash, []) + if not linked_hashes: + filtered.append(item) + continue + if any( + not bool((status_map.get(linked_hash) or {}).get("is_inactive")) + for linked_hash in linked_hashes + ): + filtered.append(item) + continue + if item_type != "relation": + filtered.append(item) + continue + hash_value = str(item.get("hash", "") or "").strip() + status = status_map.get(hash_value) if isinstance(status_map, dict) else None + if status is None: + continue + if bool(status.get("is_inactive")): + continue + filtered.append(item) + return filtered + def _resolve_relation_hashes(self, target: str) -> List[str]: assert self.metadata_store token = str(target or "").strip() diff --git a/src/A_memorix/core/storage/metadata_store.py b/src/A_memorix/core/storage/metadata_store.py index e70cba4a..b79c23e6 100644 --- a/src/A_memorix/core/storage/metadata_store.py +++ b/src/A_memorix/core/storage/metadata_store.py @@ -11,7 +11,7 @@ import uuid import re from datetime import datetime from pathlib import Path -from typing import Optional, Union, List, Dict, Any, Tuple +from typing import Optional, Union, List, Dict, Any, Tuple, Sequence from src.common.logger import get_logger from ..utils.hash import compute_hash, normalize_text @@ -34,7 +34,8 @@ except Exception: logger = get_logger("A_Memorix.MetadataStore") -SCHEMA_VERSION = 9 +SCHEMA_VERSION = 10 +RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION = 9 class MetadataStore: @@ -130,7 +131,7 @@ class MetadataStore: logger.warning(f"初始化 FTS schema 失败,将跳过 BM25 检索: {e}") def _assert_schema_compatible(self, db_existed: bool) -> None: - """vNext 运行时只做 schema 版本校验,不做隐式迁移。""" + """运行时执行 post-1.0 自动迁移;legacy/vNext 仍要求离线迁移。""" cursor = self._conn.cursor() cursor.execute( "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" @@ -147,12 +148,36 @@ class MetadataStore: cursor.execute("SELECT MAX(version) FROM schema_migrations") row = cursor.fetchone() version = int(row[0]) if row and row[0] is not None else 0 + if version < SCHEMA_VERSION and version >= RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION: + self._run_runtime_auto_migration(current_version=version) + cursor.execute("SELECT MAX(version) FROM schema_migrations") + row = cursor.fetchone() + version = int(row[0]) if row and row[0] is not None else 0 if version != SCHEMA_VERSION: raise RuntimeError( f"metadata schema 版本不匹配: current={version}, expected={SCHEMA_VERSION}。" " 请执行 scripts/release_vnext_migrate.py migrate。" ) + def _run_runtime_auto_migration(self, *, current_version: int) -> None: + """对 1.0 之后的已版本化库执行轻量自动迁移。""" + logger.info( + "检测到 metadata schema 需要运行时自动迁移: current=%s, target=%s", + current_version, + SCHEMA_VERSION, + ) + self._migrate_schema() + alias_result = self.rebuild_relation_hash_aliases() + knowledge_type_result = self.normalize_paragraph_knowledge_types() + self.set_schema_version(SCHEMA_VERSION) + logger.info( + "metadata schema 运行时自动迁移完成: %s -> %s, alias_inserted=%s, knowledge_normalized=%s", + current_version, + SCHEMA_VERSION, + int(alias_result.get("inserted", 0) or 0), + int(knowledge_type_result.get("normalized", 0) or 0), + ) + def close(self) -> None: """关闭数据库连接""" if self._conn: @@ -511,6 +536,129 @@ class MetadataStore: CREATE INDEX IF NOT EXISTS idx_paragraph_vector_backfill_status_updated ON paragraph_vector_backfill(status, updated_at) """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_feedback_tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + query_tool_id TEXT NOT NULL UNIQUE, + session_id TEXT NOT NULL, + query_timestamp REAL NOT NULL, + due_at REAL NOT NULL, + status TEXT DEFAULT 'pending', + attempt_count INTEGER DEFAULT 0, + query_snapshot_json TEXT, + decision_json TEXT, + last_error TEXT, + rollback_status TEXT DEFAULT 'none', + rollback_plan_json TEXT, + rollback_result_json TEXT, + rollback_error TEXT, + rollback_requested_by TEXT, + rollback_reason TEXT, + rollback_requested_at REAL, + rolled_back_at REAL, + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_tasks_status_due + ON memory_feedback_tasks(status, due_at, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_tasks_session_query + ON memory_feedback_tasks(session_id, query_timestamp DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_feedback_action_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + query_tool_id TEXT NOT NULL, + action_type TEXT NOT NULL, + target_hash TEXT, + before_json TEXT, + after_json TEXT, + reason TEXT, + created_at REAL NOT NULL, + FOREIGN KEY (task_id) REFERENCES memory_feedback_tasks(id) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_action_logs_task + ON memory_feedback_action_logs(task_id, created_at ASC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_action_logs_query + ON memory_feedback_action_logs(query_tool_id, created_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_action_logs_target + ON memory_feedback_action_logs(target_hash) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_stale_relation_marks ( + paragraph_hash TEXT NOT NULL, + relation_hash TEXT NOT NULL, + query_tool_id TEXT, + task_id INTEGER, + reason TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + PRIMARY KEY (paragraph_hash, relation_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, + FOREIGN KEY (relation_hash) REFERENCES relations(hash) ON DELETE CASCADE, + FOREIGN KEY (task_id) REFERENCES memory_feedback_tasks(id) ON DELETE SET NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_stale_relation_marks_paragraph + ON paragraph_stale_relation_marks(paragraph_hash, updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_stale_relation_marks_relation + ON paragraph_stale_relation_marks(relation_hash, updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_stale_relation_marks_updated + ON paragraph_stale_relation_marks(updated_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_refresh_queue ( + person_id TEXT PRIMARY KEY, + status TEXT DEFAULT 'pending', + reason TEXT, + source_query_tool_id TEXT, + retry_count INTEGER DEFAULT 0, + last_error TEXT, + requested_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_status_updated + ON person_profile_refresh_queue(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested + ON person_profile_refresh_queue(requested_at DESC) + """) + cursor.execute("PRAGMA table_info(memory_feedback_tasks)") + feedback_task_columns = {row[1] for row in cursor.fetchall()} + feedback_task_migrations = { + "rollback_status": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_status TEXT DEFAULT 'none'", + "rollback_plan_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_plan_json TEXT", + "rollback_result_json": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_result_json TEXT", + "rollback_error": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_error TEXT", + "rollback_requested_by": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_by TEXT", + "rollback_reason": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_reason TEXT", + "rollback_requested_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rollback_requested_at REAL", + "rolled_back_at": "ALTER TABLE memory_feedback_tasks ADD COLUMN rolled_back_at REAL", + } + for col, sql in feedback_task_migrations.items(): + if col not in feedback_task_columns: + try: + cursor.execute(sql) + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (memory_feedback_tasks.{col}): {e}") cursor.execute(""" CREATE TABLE IF NOT EXISTS external_memory_refs ( external_id TEXT PRIMARY KEY, @@ -700,6 +848,111 @@ class MetadataStore: CREATE INDEX IF NOT EXISTS idx_paragraph_vector_backfill_status_updated ON paragraph_vector_backfill(status, updated_at) """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_feedback_tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + query_tool_id TEXT NOT NULL UNIQUE, + session_id TEXT NOT NULL, + query_timestamp REAL NOT NULL, + due_at REAL NOT NULL, + status TEXT DEFAULT 'pending', + attempt_count INTEGER DEFAULT 0, + query_snapshot_json TEXT, + decision_json TEXT, + last_error TEXT, + rollback_status TEXT DEFAULT 'none', + rollback_plan_json TEXT, + rollback_result_json TEXT, + rollback_error TEXT, + rollback_requested_by TEXT, + rollback_reason TEXT, + rollback_requested_at REAL, + rolled_back_at REAL, + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_tasks_status_due + ON memory_feedback_tasks(status, due_at, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_tasks_session_query + ON memory_feedback_tasks(session_id, query_timestamp DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_feedback_action_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_id INTEGER NOT NULL, + query_tool_id TEXT NOT NULL, + action_type TEXT NOT NULL, + target_hash TEXT, + before_json TEXT, + after_json TEXT, + reason TEXT, + created_at REAL NOT NULL, + FOREIGN KEY (task_id) REFERENCES memory_feedback_tasks(id) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_action_logs_task + ON memory_feedback_action_logs(task_id, created_at ASC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_action_logs_query + ON memory_feedback_action_logs(query_tool_id, created_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_feedback_action_logs_target + ON memory_feedback_action_logs(target_hash) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_stale_relation_marks ( + paragraph_hash TEXT NOT NULL, + relation_hash TEXT NOT NULL, + query_tool_id TEXT, + task_id INTEGER, + reason TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL, + PRIMARY KEY (paragraph_hash, relation_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, + FOREIGN KEY (relation_hash) REFERENCES relations(hash) ON DELETE CASCADE, + FOREIGN KEY (task_id) REFERENCES memory_feedback_tasks(id) ON DELETE SET NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_stale_relation_marks_paragraph + ON paragraph_stale_relation_marks(paragraph_hash, updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_stale_relation_marks_relation + ON paragraph_stale_relation_marks(relation_hash, updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_stale_relation_marks_updated + ON paragraph_stale_relation_marks(updated_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_refresh_queue ( + person_id TEXT PRIMARY KEY, + status TEXT DEFAULT 'pending', + reason TEXT, + source_query_tool_id TEXT, + retry_count INTEGER DEFAULT 0, + last_error TEXT, + requested_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_status_updated + ON person_profile_refresh_queue(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_refresh_queue_requested + ON person_profile_refresh_queue(requested_at DESC) + """) cursor.execute(""" CREATE TABLE IF NOT EXISTS external_memory_refs ( external_id TEXT PRIMARY KEY, @@ -1464,6 +1717,7 @@ class MetadataStore: match_query: str, limit: int = 20, max_doc_len: int = 512, + include_inactive: bool = True, conn: Optional[sqlite3.Connection] = None, ) -> List[Dict[str, Any]]: """使用 FTS5 + bm25 执行关系全文检索。""" @@ -1472,9 +1726,10 @@ class MetadataStore: c = self._resolve_conn(conn) cur = c.cursor() + active_clause = "" if include_inactive else " AND (r.is_inactive IS NULL OR r.is_inactive = 0)" try: cur.execute( - """ + f""" SELECT r.hash, r.subject, @@ -1484,6 +1739,7 @@ class MetadataStore: FROM relations_fts JOIN relations r ON r.hash = relations_fts.relation_hash WHERE relations_fts MATCH ? + {active_clause} ORDER BY bm25_score ASC LIMIT ? """, @@ -1852,18 +2108,7 @@ class MetadataStore: Returns: 关系哈希值 """ - # 1. 规范化输入 - s_canon = self._canonicalize_name(subject) - p_canon = self._canonicalize_name(predicate) - o_canon = self._canonicalize_name(obj) - - if not all([s_canon, p_canon, o_canon]): - raise ValueError("Relation components cannot be empty") - - # 2. 计算组合哈希 - # 公式: md5(s|p|o) - relation_key = f"{s_canon}|{p_canon}|{o_canon}" - hash_value = compute_hash(relation_key) + hash_value = self.compute_relation_hash(subject, predicate, obj) now = datetime.now().timestamp() @@ -1906,6 +2151,23 @@ class MetadataStore: logger.warning(f"添加关系异常: {e}") return hash_value + def compute_relation_hash(self, subject: str, predicate: str, obj: str) -> str: + """ + 计算 relation 的稳定 hash,不执行写入。 + """ + # 1. 规范化输入 + s_canon = self._canonicalize_name(subject) + p_canon = self._canonicalize_name(predicate) + o_canon = self._canonicalize_name(obj) + + if not all([s_canon, p_canon, o_canon]): + raise ValueError("Relation components cannot be empty") + + # 2. 计算组合哈希 + # 公式: md5(s|p|o) + relation_key = f"{s_canon}|{p_canon}|{o_canon}" + return compute_hash(relation_key) + def link_paragraph_relation( self, paragraph_hash: str, @@ -2125,7 +2387,7 @@ class MetadataStore: return self._row_to_dict(row, "entity") return None - def get_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: + def get_relation(self, hash_value: str, include_inactive: bool = True) -> Optional[Dict[str, Any]]: """ 获取关系 @@ -2136,9 +2398,22 @@ class MetadataStore: 关系信息字典,不存在则返回None """ cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM relations WHERE hash = ? - """, (hash_value,)) + if include_inactive: + cursor.execute( + """ + SELECT * FROM relations WHERE hash = ? + """, + (hash_value,), + ) + else: + cursor.execute( + """ + SELECT * FROM relations + WHERE hash = ? + AND (is_inactive IS NULL OR is_inactive = 0) + """, + (hash_value,), + ) row = cursor.fetchone() if row: @@ -2164,6 +2439,44 @@ class MetadataStore: return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + def get_paragraph_hashes_by_relation_hashes( + self, + relation_hashes: List[str], + ) -> Dict[str, List[str]]: + normalized: List[str] = [] + seen = set() + for item in relation_hashes or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + if not normalized: + return {} + + placeholders = ",".join(["?"] * len(normalized)) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT pr.relation_hash, pr.paragraph_hash + FROM paragraph_relations pr + JOIN paragraphs p ON p.hash = pr.paragraph_hash + WHERE pr.relation_hash IN ({placeholders}) + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + ORDER BY pr.relation_hash ASC, p.updated_at DESC, p.created_at DESC, pr.paragraph_hash ASC + """, + tuple(normalized), + ) + grouped: Dict[str, List[str]] = {token: [] for token in normalized} + for row in cursor.fetchall(): + relation_hash = str(row["relation_hash"] or "").strip() + paragraph_hash = str(row["paragraph_hash"] or "").strip() + if not relation_hash or not paragraph_hash: + continue + if paragraph_hash not in grouped.setdefault(relation_hash, []): + grouped[relation_hash].append(paragraph_hash) + return grouped + def get_paragraph_entities(self, paragraph_hash: str) -> List[Dict[str, Any]]: """ 获取段落的所有实体 @@ -2217,6 +2530,7 @@ class MetadataStore: subject: Optional[str] = None, predicate: Optional[str] = None, object: Optional[str] = None, + include_inactive: bool = True, ) -> List[Dict[str, Any]]: """ 查询关系(大小写不敏感) @@ -2242,6 +2556,8 @@ class MetadataStore: if object: conditions.append("LOWER(object) = ?") params.append(self._canonicalize_name(object)) + if not include_inactive: + conditions.append("(is_inactive IS NULL OR is_inactive = 0)") sql = "SELECT * FROM relations" if conditions: @@ -3092,6 +3408,19 @@ class MetadataStore: cursor.execute("SELECT COUNT(*) FROM relations") stats["relation_count"] = cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM paragraph_stale_relation_marks") + stats["stale_paragraph_mark_count"] = cursor.fetchone()[0] + + cursor.execute( + "SELECT COUNT(*) FROM person_profile_refresh_queue WHERE status IN ('pending', 'running', 'failed')" + ) + stats["person_profile_refresh_pending_count"] = cursor.fetchone()[0] + + cursor.execute( + "SELECT COUNT(*) FROM person_profile_refresh_queue WHERE status = 'failed'" + ) + stats["person_profile_refresh_failed_count"] = cursor.fetchone()[0] + # 总词数 cursor.execute("SELECT SUM(word_count) FROM paragraphs") result = cursor.fetchone()[0] @@ -3860,6 +4189,46 @@ class MetadataStore: def restore_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: """兼容旧调用名:恢复关系。""" return self.restore_relation_metadata(hash_value) + + def restore_relation_status_from_snapshot( + self, + hash_value: str, + snapshot: Optional[Dict[str, Any]], + ) -> Optional[Dict[str, Any]]: + token = str(hash_value or "").strip() + if not token or not isinstance(snapshot, dict): + return None + + current = self.get_relation_status_batch([token]).get(token) + if current is None: + restored = self.restore_relation(token) + if restored is None: + return None + + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE relations + SET is_inactive = ?, + confidence = ?, + is_pinned = ?, + protected_until = ?, + last_reinforced = ?, + inactive_since = ? + WHERE hash = ? + """, + ( + 1 if bool(snapshot.get("is_inactive")) else 0, + float(snapshot.get("weight", 0.0) or 0.0), + 1 if bool(snapshot.get("is_pinned")) else 0, + self._as_optional_float(snapshot.get("protected_until")), + self._as_optional_float(snapshot.get("last_reinforced")), + self._as_optional_float(snapshot.get("inactive_since")), + token, + ), + ) + self._conn.commit() + return self.get_relation_status_batch([token]).get(token) def get_protected_relations_hashes(self) -> List[str]: """获取所有受保护关系的哈希 (Pinned 或 Protected Until > Now)""" @@ -4864,7 +5233,7 @@ class MetadataStore: "failed": failed, } - def get_live_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: + def get_live_paragraphs_by_source(self, source: str, *, exclude_stale: bool = False) -> List[Dict[str, Any]]: """获取指定 source 下所有 live paragraphs。""" token = self._normalize_episode_source(source) if not token: @@ -4880,7 +5249,35 @@ class MetadataStore: """, (token,), ) - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + rows = [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + if not exclude_stale: + return rows + paragraph_hashes = [str(row.get("hash", "") or "").strip() for row in rows if str(row.get("hash", "") or "").strip()] + marks_by_paragraph = self.get_paragraph_stale_relation_marks_batch(paragraph_hashes) if paragraph_hashes else {} + relation_hashes: List[str] = [] + seen = set() + for marks in marks_by_paragraph.values(): + for mark in marks: + relation_hash = str(mark.get("relation_hash", "") or "").strip() + if not relation_hash or relation_hash in seen: + continue + seen.add(relation_hash) + relation_hashes.append(relation_hash) + status_map = self.get_relation_status_batch(relation_hashes) if relation_hashes else {} + + filtered: List[Dict[str, Any]] = [] + for row in rows: + paragraph_hash = str(row.get("hash", "") or "").strip() + marks = marks_by_paragraph.get(paragraph_hash, []) + if any( + status_map.get(str(mark.get("relation_hash", "") or "").strip()) is None + or bool((status_map.get(str(mark.get("relation_hash", "") or "").strip()) or {}).get("is_inactive")) + for mark in marks + if str(mark.get("relation_hash", "") or "").strip() + ): + continue + filtered.append(row) + return filtered def list_episode_sources_for_rebuild(self) -> List[str]: """列出全量重建涉及的 source(live paragraphs + stale episodes)。""" @@ -5354,6 +5751,838 @@ class MetadataStore: counts[status] = int(row["count"] or 0) return counts + def _feedback_task_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + data["query_snapshot"] = self._json_loads(data.pop("query_snapshot_json", None), {}) + data["decision_payload"] = self._json_loads(data.get("decision_json"), {}) + data["rollback_status"] = str(data.get("rollback_status", "") or "none").strip().lower() or "none" + data["rollback_plan"] = self._json_loads(data.pop("rollback_plan_json", None), {}) + data["rollback_result"] = self._json_loads(data.pop("rollback_result_json", None), {}) + data["rollback_error"] = str(data.get("rollback_error", "") or "").strip() + data["rollback_requested_by"] = str(data.get("rollback_requested_by", "") or "").strip() + data["rollback_reason"] = str(data.get("rollback_reason", "") or "").strip() + return data + + def _feedback_action_log_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + data["id"] = int(data.get("id", 0) or 0) + data["task_id"] = int(data.get("task_id", 0) or 0) + data["query_tool_id"] = str(data.get("query_tool_id", "") or "").strip() + data["action_type"] = str(data.get("action_type", "") or "").strip() + data["target_hash"] = str(data.get("target_hash", "") or "").strip() + data["reason"] = str(data.get("reason", "") or "").strip() + data["before_payload"] = self._json_loads(data.pop("before_json", None), {}) + data["after_payload"] = self._json_loads(data.pop("after_json", None), {}) + return data + + def get_feedback_task(self, query_tool_id: str) -> Optional[Dict[str, Any]]: + token = str(query_tool_id or "").strip() + if not token: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT * + FROM memory_feedback_tasks + WHERE query_tool_id = ? + LIMIT 1 + """, + (token,), + ) + row = cursor.fetchone() + return self._feedback_task_row_to_dict(row) if row is not None else None + + def get_feedback_task_by_id(self, task_id: int) -> Optional[Dict[str, Any]]: + if int(task_id or 0) <= 0: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT * + FROM memory_feedback_tasks + WHERE id = ? + LIMIT 1 + """, + (int(task_id),), + ) + row = cursor.fetchone() + return self._feedback_task_row_to_dict(row) if row is not None else None + + def list_feedback_tasks( + self, + *, + limit: int = 50, + statuses: Optional[List[str]] = None, + rollback_statuses: Optional[List[str]] = None, + query: str = "", + ) -> List[Dict[str, Any]]: + safe_limit = max(1, int(limit or 50)) + params: List[Any] = [] + conditions: List[str] = [] + + normalized_statuses = [ + str(item or "").strip().lower() + for item in (statuses or []) + if str(item or "").strip().lower() in {"pending", "running", "applied", "skipped", "error"} + ] + if normalized_statuses: + placeholders = ",".join(["?"] * len(normalized_statuses)) + conditions.append(f"LOWER(COALESCE(status, '')) IN ({placeholders})") + params.extend(normalized_statuses) + + normalized_rollback_statuses = [ + str(item or "").strip().lower() + for item in (rollback_statuses or []) + if str(item or "").strip().lower() in {"none", "running", "rolled_back", "error"} + ] + if normalized_rollback_statuses: + placeholders = ",".join(["?"] * len(normalized_rollback_statuses)) + conditions.append(f"LOWER(COALESCE(rollback_status, 'none')) IN ({placeholders})") + params.extend(normalized_rollback_statuses) + + query_token = str(query or "").strip().lower() + if query_token: + like_value = f"%{query_token}%" + conditions.append( + """ + ( + LOWER(COALESCE(query_tool_id, '')) LIKE ? + OR LOWER(COALESCE(session_id, '')) LIKE ? + OR LOWER(COALESCE(query_snapshot_json, '')) LIKE ? + OR LOWER(COALESCE(decision_json, '')) LIKE ? + OR LOWER(COALESCE(last_error, '')) LIKE ? + OR LOWER(COALESCE(rollback_reason, '')) LIKE ? + OR LOWER(COALESCE(rollback_error, '')) LIKE ? + ) + """ + ) + params.extend([like_value] * 7) + + where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else "" + params.append(safe_limit) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT * + FROM memory_feedback_tasks + {where_sql} + ORDER BY query_timestamp DESC, id DESC + LIMIT ? + """, + tuple(params), + ) + return [self._feedback_task_row_to_dict(row) for row in cursor.fetchall()] + + def enqueue_feedback_task( + self, + *, + query_tool_id: str, + session_id: str, + query_timestamp: float, + due_at: float, + query_snapshot: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + tool_token = str(query_tool_id or "").strip() + session_token = str(session_id or "").strip() + if not tool_token or not session_token: + return None + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT OR IGNORE INTO memory_feedback_tasks ( + query_tool_id, session_id, query_timestamp, due_at, status, attempt_count, + query_snapshot_json, decision_json, last_error, created_at, updated_at + ) VALUES (?, ?, ?, ?, 'pending', 0, ?, NULL, NULL, ?, ?) + """, + ( + tool_token, + session_token, + float(query_timestamp), + float(due_at), + self._json_dumps(query_snapshot or {}), + now, + now, + ), + ) + self._conn.commit() + return self.get_feedback_task(tool_token) + + def update_feedback_task_rollback_plan( + self, + *, + task_id: int, + rollback_plan: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + if int(task_id or 0) <= 0: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE memory_feedback_tasks + SET rollback_plan_json = ?, + updated_at = ? + WHERE id = ? + """, + ( + self._json_dumps(rollback_plan or {}), + datetime.now().timestamp(), + int(task_id), + ), + ) + self._conn.commit() + return self.get_feedback_task_by_id(int(task_id)) + + def fetch_due_feedback_tasks( + self, + *, + limit: int = 20, + now: Optional[float] = None, + ) -> List[Dict[str, Any]]: + safe_limit = max(1, int(limit)) + now_ts = self._as_optional_float(now) + if now_ts is None: + now_ts = datetime.now().timestamp() + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT * + FROM memory_feedback_tasks + WHERE due_at <= ? + AND status IN ('pending', 'running') + ORDER BY due_at ASC, id ASC + LIMIT ? + """, + (now_ts, safe_limit), + ) + return [self._feedback_task_row_to_dict(row) for row in cursor.fetchall()] + + def mark_feedback_task_running(self, task_id: int) -> Optional[Dict[str, Any]]: + if int(task_id or 0) <= 0: + return None + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE memory_feedback_tasks + SET status = 'running', + attempt_count = COALESCE(attempt_count, 0) + 1, + updated_at = ? + WHERE id = ? + AND status IN ('pending', 'running') + """, + (now, int(task_id)), + ) + self._conn.commit() + cursor.execute( + """ + SELECT * + FROM memory_feedback_tasks + WHERE id = ? + LIMIT 1 + """, + (int(task_id),), + ) + row = cursor.fetchone() + return self._feedback_task_row_to_dict(row) if row is not None else None + + def finalize_feedback_task( + self, + *, + task_id: int, + status: str, + decision_payload: Optional[Dict[str, Any]] = None, + last_error: str = "", + ) -> Optional[Dict[str, Any]]: + final_status = str(status or "").strip().lower() + if final_status not in {"applied", "skipped", "error"}: + raise ValueError(f"不支持的反馈任务结束状态: {status}") + if int(task_id or 0) <= 0: + return None + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE memory_feedback_tasks + SET status = ?, + decision_json = ?, + last_error = ?, + updated_at = ? + WHERE id = ? + """, + ( + final_status, + self._json_dumps(decision_payload or {}), + str(last_error or "").strip() or None, + now, + int(task_id), + ), + ) + self._conn.commit() + cursor.execute( + """ + SELECT * + FROM memory_feedback_tasks + WHERE id = ? + LIMIT 1 + """, + (int(task_id),), + ) + row = cursor.fetchone() + return self._feedback_task_row_to_dict(row) if row is not None else None + + def mark_feedback_task_rollback_running( + self, + *, + task_id: int, + requested_by: str = "", + reason: str = "", + ) -> Optional[Dict[str, Any]]: + if int(task_id or 0) <= 0: + return None + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE memory_feedback_tasks + SET rollback_status = 'running', + rollback_requested_by = ?, + rollback_reason = ?, + rollback_error = NULL, + rollback_requested_at = ?, + updated_at = ? + WHERE id = ? + AND LOWER(COALESCE(status, '')) = 'applied' + AND LOWER(COALESCE(rollback_status, 'none')) IN ('none', 'error') + """, + ( + str(requested_by or "").strip() or None, + str(reason or "").strip() or None, + now, + now, + int(task_id), + ), + ) + self._conn.commit() + if int(cursor.rowcount or 0) <= 0: + return None + return self.get_feedback_task_by_id(int(task_id)) + + def finalize_feedback_task_rollback( + self, + *, + task_id: int, + rollback_status: str, + rollback_result: Optional[Dict[str, Any]] = None, + rollback_error: str = "", + ) -> Optional[Dict[str, Any]]: + if int(task_id or 0) <= 0: + return None + final_status = str(rollback_status or "").strip().lower() + if final_status not in {"none", "rolled_back", "error"}: + raise ValueError(f"不支持的反馈任务回退状态: {rollback_status}") + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE memory_feedback_tasks + SET rollback_status = ?, + rollback_result_json = ?, + rollback_error = ?, + rolled_back_at = CASE WHEN ? = 'rolled_back' THEN ? ELSE rolled_back_at END, + updated_at = ? + WHERE id = ? + """, + ( + final_status, + self._json_dumps(rollback_result or {}), + str(rollback_error or "").strip() or None, + final_status, + now, + now, + int(task_id), + ), + ) + self._conn.commit() + return self.get_feedback_task_by_id(int(task_id)) + + def append_feedback_action_log( + self, + *, + task_id: int, + query_tool_id: str, + action_type: str, + target_hash: str = "", + before_payload: Optional[Dict[str, Any]] = None, + after_payload: Optional[Dict[str, Any]] = None, + reason: str = "", + ) -> Optional[Dict[str, Any]]: + if int(task_id or 0) <= 0: + return None + query_token = str(query_tool_id or "").strip() + if not query_token: + return None + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO memory_feedback_action_logs ( + task_id, query_tool_id, action_type, target_hash, + before_json, after_json, reason, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + int(task_id), + query_token, + str(action_type or "").strip() or "unknown", + str(target_hash or "").strip() or None, + self._json_dumps(before_payload) if isinstance(before_payload, dict) else None, + self._json_dumps(after_payload) if isinstance(after_payload, dict) else None, + str(reason or "").strip() or None, + now, + ), + ) + self._conn.commit() + return { + "id": int(cursor.lastrowid or 0), + "task_id": int(task_id), + "query_tool_id": query_token, + "action_type": str(action_type or "").strip() or "unknown", + "target_hash": str(target_hash or "").strip(), + "before_json": self._json_dumps(before_payload) if isinstance(before_payload, dict) else None, + "after_json": self._json_dumps(after_payload) if isinstance(after_payload, dict) else None, + "reason": str(reason or "").strip(), + "created_at": now, + } + + def list_feedback_action_logs(self, task_id: int) -> List[Dict[str, Any]]: + if int(task_id or 0) <= 0: + return [] + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT id, task_id, query_tool_id, action_type, target_hash, before_json, after_json, reason, created_at + FROM memory_feedback_action_logs + WHERE task_id = ? + ORDER BY id ASC + """, + (int(task_id),), + ) + return [self._feedback_action_log_row_to_dict(row) for row in cursor.fetchall()] + + def upsert_paragraph_stale_relation_mark( + self, + *, + paragraph_hash: str, + relation_hash: str, + query_tool_id: str = "", + task_id: Optional[int] = None, + reason: str = "", + ) -> Optional[Dict[str, Any]]: + paragraph_token = str(paragraph_hash or "").strip() + relation_token = str(relation_hash or "").strip() + if not paragraph_token or not relation_token: + return None + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO paragraph_stale_relation_marks ( + paragraph_hash, relation_hash, query_tool_id, task_id, reason, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(paragraph_hash, relation_hash) DO UPDATE SET + query_tool_id = excluded.query_tool_id, + task_id = excluded.task_id, + reason = excluded.reason, + updated_at = excluded.updated_at + """, + ( + paragraph_token, + relation_token, + str(query_tool_id or "").strip() or None, + int(task_id) if int(task_id or 0) > 0 else None, + str(reason or "").strip() or None, + now, + now, + ), + ) + self._conn.commit() + return { + "paragraph_hash": paragraph_token, + "relation_hash": relation_token, + "query_tool_id": str(query_tool_id or "").strip(), + "task_id": int(task_id or 0) if int(task_id or 0) > 0 else None, + "reason": str(reason or "").strip(), + "updated_at": now, + } + + def get_paragraph_stale_relation_marks_batch( + self, + paragraph_hashes: Sequence[str], + ) -> Dict[str, List[Dict[str, Any]]]: + normalized: List[str] = [] + seen = set() + for item in paragraph_hashes or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + if not normalized: + return {} + + placeholders = ",".join(["?"] * len(normalized)) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT paragraph_hash, relation_hash, query_tool_id, task_id, reason, created_at, updated_at + FROM paragraph_stale_relation_marks + WHERE paragraph_hash IN ({placeholders}) + ORDER BY updated_at DESC, paragraph_hash ASC, relation_hash ASC + """, + tuple(normalized), + ) + grouped: Dict[str, List[Dict[str, Any]]] = {token: [] for token in normalized} + for row in cursor.fetchall(): + payload = { + "paragraph_hash": str(row["paragraph_hash"] or "").strip(), + "relation_hash": str(row["relation_hash"] or "").strip(), + "query_tool_id": str(row["query_tool_id"] or "").strip(), + "task_id": int(row["task_id"] or 0) if row["task_id"] is not None else None, + "reason": str(row["reason"] or "").strip(), + "created_at": self._as_optional_float(row["created_at"]), + "updated_at": self._as_optional_float(row["updated_at"]), + } + grouped.setdefault(payload["paragraph_hash"], []).append(payload) + return grouped + + def count_paragraph_stale_relation_marks(self) -> int: + cursor = self._conn.cursor() + cursor.execute("SELECT COUNT(*) FROM paragraph_stale_relation_marks") + row = cursor.fetchone() + return int(row[0]) if row and row[0] is not None else 0 + + def delete_paragraph_stale_relation_marks( + self, + marks: Sequence[Tuple[str, str]], + ) -> int: + normalized: List[Tuple[str, str]] = [] + seen: set[Tuple[str, str]] = set() + for paragraph_hash, relation_hash in marks or []: + paragraph_token = str(paragraph_hash or "").strip() + relation_token = str(relation_hash or "").strip() + if not paragraph_token or not relation_token: + continue + key = (paragraph_token, relation_token) + if key in seen: + continue + seen.add(key) + normalized.append(key) + if not normalized: + return 0 + + cursor = self._conn.cursor() + deleted = 0 + for paragraph_hash, relation_hash in normalized: + cursor.execute( + """ + DELETE FROM paragraph_stale_relation_marks + WHERE paragraph_hash = ? AND relation_hash = ? + """, + (paragraph_hash, relation_hash), + ) + deleted += int(cursor.rowcount or 0) + self._conn.commit() + return deleted + + @staticmethod + def _person_profile_refresh_row_to_dict(row: Optional[sqlite3.Row]) -> Optional[Dict[str, Any]]: + if row is None: + return None + payload = dict(row) + payload["person_id"] = str(payload.get("person_id", "") or "").strip() + payload["status"] = str(payload.get("status", "") or "").strip().lower() or "pending" + payload["reason"] = str(payload.get("reason", "") or "").strip() + payload["source_query_tool_id"] = str(payload.get("source_query_tool_id", "") or "").strip() + payload["retry_count"] = int(payload.get("retry_count", 0) or 0) + payload["last_error"] = str(payload.get("last_error", "") or "").strip() + payload["requested_at"] = MetadataStore._as_optional_float(payload.get("requested_at")) + payload["updated_at"] = MetadataStore._as_optional_float(payload.get("updated_at")) + return payload + + def get_person_profile_refresh_request(self, person_id: str) -> Optional[Dict[str, Any]]: + token = str(person_id or "").strip() + if not token: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT person_id, status, reason, source_query_tool_id, retry_count, last_error, requested_at, updated_at + FROM person_profile_refresh_queue + WHERE person_id = ? + LIMIT 1 + """, + (token,), + ) + return self._person_profile_refresh_row_to_dict(cursor.fetchone()) + + def enqueue_person_profile_refresh( + self, + *, + person_id: str, + reason: str = "", + source_query_tool_id: str = "", + ) -> Optional[Dict[str, Any]]: + token = str(person_id or "").strip() + if not token: + return None + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO person_profile_refresh_queue ( + person_id, status, reason, source_query_tool_id, retry_count, last_error, requested_at, updated_at + ) VALUES (?, 'pending', ?, ?, 0, NULL, ?, ?) + ON CONFLICT(person_id) DO UPDATE SET + status = 'pending', + reason = excluded.reason, + source_query_tool_id = excluded.source_query_tool_id, + last_error = NULL, + requested_at = excluded.requested_at, + updated_at = excluded.updated_at + """, + ( + token, + str(reason or "").strip() or None, + str(source_query_tool_id or "").strip() or None, + now, + now, + ), + ) + self._conn.commit() + return self.get_person_profile_refresh_request(token) + + def fetch_person_profile_refresh_batch( + self, + *, + limit: int = 20, + max_retry: int = 3, + ) -> List[Dict[str, Any]]: + safe_limit = max(1, int(limit)) + safe_retry = max(0, int(max_retry)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT person_id, status, reason, source_query_tool_id, retry_count, last_error, requested_at, updated_at + FROM person_profile_refresh_queue + WHERE status = 'pending' + OR (status = 'failed' AND retry_count < ?) + ORDER BY requested_at ASC, updated_at ASC + LIMIT ? + """, + (safe_retry, safe_limit), + ) + return [ + item + for item in ( + self._person_profile_refresh_row_to_dict(row) + for row in cursor.fetchall() + ) + if item is not None + ] + + def mark_person_profile_refresh_running( + self, + person_id: str, + *, + requested_at: Optional[float] = None, + ) -> bool: + token = str(person_id or "").strip() + if not token: + return False + + now = datetime.now().timestamp() + params: List[Any] = [now, token] + sql = """ + UPDATE person_profile_refresh_queue + SET status = 'running', + updated_at = ? + WHERE person_id = ? + AND status IN ('pending', 'failed') + """ + if requested_at is not None: + sql += " AND requested_at = ?" + params.append(float(requested_at)) + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params)) + self._conn.commit() + return cursor.rowcount > 0 + + def mark_person_profile_refresh_done( + self, + person_id: str, + *, + requested_at: Optional[float] = None, + ) -> bool: + token = str(person_id or "").strip() + if not token: + return False + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + if requested_at is None: + cursor.execute( + """ + UPDATE person_profile_refresh_queue + SET status = 'done', + last_error = NULL, + updated_at = ? + WHERE person_id = ? + """, + (now, token), + ) + else: + req_ts = float(requested_at) + cursor.execute( + """ + UPDATE person_profile_refresh_queue + SET status = CASE + WHEN requested_at > ? THEN 'pending' + ELSE 'done' + END, + last_error = NULL, + updated_at = ? + WHERE person_id = ? + """, + (req_ts, now, token), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def mark_person_profile_refresh_failed( + self, + person_id: str, + error: str = "", + *, + requested_at: Optional[float] = None, + ) -> bool: + token = str(person_id or "").strip() + if not token: + return False + + err_text = str(error or "").strip()[:500] + now = datetime.now().timestamp() + cursor = self._conn.cursor() + if requested_at is None: + cursor.execute( + """ + UPDATE person_profile_refresh_queue + SET status = 'failed', + retry_count = COALESCE(retry_count, 0) + 1, + last_error = ?, + updated_at = ? + WHERE person_id = ? + """, + (err_text, now, token), + ) + else: + req_ts = float(requested_at) + cursor.execute( + """ + UPDATE person_profile_refresh_queue + SET status = CASE + WHEN requested_at > ? THEN 'pending' + ELSE 'failed' + END, + retry_count = CASE + WHEN requested_at > ? THEN COALESCE(retry_count, 0) + ELSE COALESCE(retry_count, 0) + 1 + END, + last_error = CASE + WHEN requested_at > ? THEN NULL + ELSE ? + END, + updated_at = ? + WHERE person_id = ? + """, + (req_ts, req_ts, req_ts, err_text, now, token), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def list_person_profile_refresh_requests( + self, + *, + statuses: Optional[List[str]] = None, + limit: int = 100, + ) -> List[Dict[str, Any]]: + safe_limit = max(1, int(limit)) + params: List[Any] = [] + conditions: List[str] = [] + normalized_statuses = [ + str(item or "").strip().lower() + for item in (statuses or []) + if str(item or "").strip().lower() in {"pending", "running", "done", "failed"} + ] + if normalized_statuses: + placeholders = ",".join(["?"] * len(normalized_statuses)) + conditions.append(f"status IN ({placeholders})") + params.extend(normalized_statuses) + + where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else "" + params.append(safe_limit) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT person_id, status, reason, source_query_tool_id, retry_count, last_error, requested_at, updated_at + FROM person_profile_refresh_queue + {where_sql} + ORDER BY updated_at DESC, person_id ASC + LIMIT ? + """, + tuple(params), + ) + return [ + item + for item in ( + self._person_profile_refresh_row_to_dict(row) + for row in cursor.fetchall() + ) + if item is not None + ] + + def get_person_profile_refresh_summary(self, failed_limit: int = 20) -> Dict[str, Any]: + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) AS cnt + FROM person_profile_refresh_queue + GROUP BY status + """ + ) + counts = {"pending": 0, "running": 0, "done": 0, "failed": 0, "total": 0} + for row in cursor.fetchall(): + status = str(row["status"] or "").strip().lower() + cnt = int(row["cnt"] or 0) + counts[status] = counts.get(status, 0) + cnt + counts["total"] += cnt + running = self.list_person_profile_refresh_requests(statuses=["running"], limit=20) + failed = self.list_person_profile_refresh_requests( + statuses=["failed"], + limit=max(1, int(failed_limit)), + ) + return { + "counts": counts, + "running": running, + "failed": failed, + } + def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: data = dict(row) diff --git a/src/A_memorix/core/utils/episode_service.py b/src/A_memorix/core/utils/episode_service.py index ca94dd96..9c66f9c0 100644 --- a/src/A_memorix/core/utils/episode_service.py +++ b/src/A_memorix/core/utils/episode_service.py @@ -18,6 +18,7 @@ from datetime import datetime from typing import Any, Dict, List, Optional, Tuple from src.common.logger import get_logger +from src.config.config import global_config from .episode_segmentation_service import EpisodeSegmentationService from .hash import compute_hash @@ -528,7 +529,11 @@ class EpisodeService: "paragraph_count": 0, } - paragraphs = self.metadata_store.get_live_paragraphs_by_source(token) + memory_cfg = getattr(global_config, "memory", None) + paragraphs = self.metadata_store.get_live_paragraphs_by_source( + token, + exclude_stale=bool(getattr(memory_cfg, "feedback_correction_paragraph_hard_filter_enabled", True)), + ) if not paragraphs: replace_result = self.metadata_store.replace_episodes_for_source(token, []) return { diff --git a/src/A_memorix/core/utils/path_fallback_service.py b/src/A_memorix/core/utils/path_fallback_service.py index 7a802743..c8ef0be8 100644 --- a/src/A_memorix/core/utils/path_fallback_service.py +++ b/src/A_memorix/core/utils/path_fallback_service.py @@ -90,9 +90,9 @@ def find_paths_between_entities( else: pred = "related" direction = "->" - rels = metadata_store.get_relations(subject=u, object=v) + rels = metadata_store.get_relations(subject=u, object=v, include_inactive=False) if not rels: - rels = metadata_store.get_relations(subject=v, object=u) + rels = metadata_store.get_relations(subject=v, object=u, include_inactive=False) direction = "<-" if rels: best_rel = max(rels, key=lambda x: x.get("confidence", 1.0)) @@ -162,4 +162,3 @@ def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResul ) ) return converted - diff --git a/src/A_memorix/core/utils/person_profile_service.py b/src/A_memorix/core/utils/person_profile_service.py index 79797712..f0a7f8c2 100644 --- a/src/A_memorix/core/utils/person_profile_service.py +++ b/src/A_memorix/core/utils/person_profile_service.py @@ -15,6 +15,7 @@ from sqlmodel import select from src.common.logger import get_logger from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo +from src.config.config import global_config from ..embedding import EmbeddingAPIAdapter from ..retrieval import ( @@ -285,11 +286,11 @@ class PersonProfileService: def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]: relation_by_hash: Dict[str, Dict[str, Any]] = {} for alias in aliases: - for rel in self.metadata_store.get_relations(subject=alias): + for rel in self.metadata_store.get_relations(subject=alias, include_inactive=False): h = str(rel.get("hash", "")) if h: relation_by_hash[h] = rel - for rel in self.metadata_store.get_relations(object=alias): + for rel in self.metadata_store.get_relations(object=alias, include_inactive=False): h = str(rel.get("hash", "")) if h: relation_by_hash[h] = rel @@ -342,7 +343,53 @@ class PersonProfileService: "metadata": {}, } ) - return evidence + return self._filter_stale_paragraph_evidence(evidence) + + def _filter_stale_paragraph_evidence( + self, + evidence: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + memory_cfg = getattr(global_config, "memory", None) + if not bool(getattr(memory_cfg, "feedback_correction_paragraph_hard_filter_enabled", True)): + return evidence + paragraph_hashes = [ + str(item.get("hash", "") or "").strip() + for item in evidence + if str(item.get("type", "") or "").strip() == "paragraph" and str(item.get("hash", "") or "").strip() + ] + if not paragraph_hashes: + return evidence + + marks_by_paragraph = self.metadata_store.get_paragraph_stale_relation_marks_batch(paragraph_hashes) + relation_hashes: List[str] = [] + seen = set() + for marks in marks_by_paragraph.values(): + for mark in marks: + relation_hash = str(mark.get("relation_hash", "") or "").strip() + if not relation_hash or relation_hash in seen: + continue + seen.add(relation_hash) + relation_hashes.append(relation_hash) + status_map = self.metadata_store.get_relation_status_batch(relation_hashes) if relation_hashes else {} + + filtered: List[Dict[str, Any]] = [] + for item in evidence: + item_type = str(item.get("type", "") or "").strip() + item_hash = str(item.get("hash", "") or "").strip() + if item_type != "paragraph" or not item_hash: + filtered.append(item) + continue + marks = marks_by_paragraph.get(item_hash, []) + should_hide = any( + status_map.get(str(mark.get("relation_hash", "") or "").strip()) is None + or bool((status_map.get(str(mark.get("relation_hash", "") or "").strip()) or {}).get("is_inactive")) + for mark in marks + if str(mark.get("relation_hash", "") or "").strip() + ) + if should_hide: + continue + filtered.append(item) + return filtered async def _collect_vector_evidence( self, @@ -373,7 +420,7 @@ class PersonProfileService: "metadata": {}, } ) - return fallback[:top_k] + return self._filter_stale_paragraph_evidence(fallback[:top_k]) per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries)))) seen_hash = set() @@ -406,7 +453,7 @@ class PersonProfileService: } ) evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True) - return evidence[:top_k] + return self._filter_stale_paragraph_evidence(evidence[:top_k]) def _build_profile_text( self, diff --git a/src/A_memorix/host_service.py b/src/A_memorix/host_service.py index adbaf91a..b3766dd6 100644 --- a/src/A_memorix/host_service.py +++ b/src/A_memorix/host_service.py @@ -190,6 +190,16 @@ class AMemorixHostService: ) ) + if component_name == "enqueue_feedback_task": + return await kernel.enqueue_feedback_task( + query_tool_id=str(payload.get("query_tool_id", "") or ""), + session_id=str(payload.get("session_id", "") or ""), + query_timestamp=payload.get("query_timestamp"), + structured_content=payload.get("structured_content") + if isinstance(payload.get("structured_content"), dict) + else {}, + ) + if component_name == "ingest_summary": return await kernel.ingest_summary( external_id=str(payload.get("external_id", "") or ""), @@ -251,6 +261,7 @@ class AMemorixHostService: "memory_source_admin": kernel.memory_source_admin, "memory_episode_admin": kernel.memory_episode_admin, "memory_profile_admin": kernel.memory_profile_admin, + "memory_feedback_admin": kernel.memory_feedback_admin, "memory_runtime_admin": kernel.memory_runtime_admin, "memory_import_admin": kernel.memory_import_admin, "memory_tuning_admin": kernel.memory_tuning_admin, diff --git a/src/A_memorix/scripts/release_vnext_migrate.py b/src/A_memorix/scripts/release_vnext_migrate.py index c344e2f1..d41b755f 100644 --- a/src/A_memorix/scripts/release_vnext_migrate.py +++ b/src/A_memorix/scripts/release_vnext_migrate.py @@ -62,7 +62,10 @@ if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): try: from A_memorix.core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore - from A_memorix.core.storage.metadata_store import SCHEMA_VERSION + from A_memorix.core.storage.metadata_store import ( + RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION, + SCHEMA_VERSION, + ) except Exception as e: # pragma: no cover print(f"❌ failed to import storage modules: {e}") raise SystemExit(2) @@ -125,6 +128,14 @@ def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool: return row is not None +def _sqlite_column_exists(conn: sqlite3.Connection, table: str, column: str) -> bool: + try: + rows = conn.execute(f"PRAGMA table_info({table})").fetchall() + except Exception: + return False + return any(str(row[1] or "") == str(column or "") for row in rows) + + def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]: hashes: List[str] = [] if _sqlite_table_exists(conn, "relations"): @@ -152,6 +163,8 @@ def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[st def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]: if not _sqlite_table_exists(conn, "paragraphs"): return [] + if not _sqlite_column_exists(conn, "paragraphs", "knowledge_type"): + return [] allowed = {item.value for item in KnowledgeType} rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall() @@ -288,6 +301,14 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: facts["schema_migrations_exists"] = has_schema_table has_paragraph_backfill = _sqlite_table_exists(conn, "paragraph_vector_backfill") facts["paragraph_vector_backfill_exists"] = has_paragraph_backfill + has_stale_marks = _sqlite_table_exists(conn, "paragraph_stale_relation_marks") + facts["paragraph_stale_relation_marks_exists"] = has_stale_marks + has_profile_refresh_queue = _sqlite_table_exists(conn, "person_profile_refresh_queue") + facts["person_profile_refresh_queue_exists"] = has_profile_refresh_queue + has_feedback_rollback_status = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_status") + facts["memory_feedback_tasks_rollback_status_exists"] = has_feedback_rollback_status + has_feedback_rollback_plan = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_plan_json") + facts["memory_feedback_tasks_rollback_plan_exists"] = has_feedback_rollback_plan if not has_schema_table: checks.append( CheckItem( @@ -300,14 +321,28 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone() version = int(row[0]) if row and row[0] is not None else 0 facts["schema_version"] = version + runtime_auto_migratable = ( + version < SCHEMA_VERSION + and version >= RUNTIME_AUTO_MIGRATION_MIN_SCHEMA_VERSION + ) + facts["schema_runtime_auto_migratable"] = runtime_auto_migratable if version != SCHEMA_VERSION: - checks.append( - CheckItem( - "CP-08", - "error", - f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}", + if runtime_auto_migratable: + checks.append( + CheckItem( + "CP-18", + "warning", + f"schema version behind runtime target: current={version}, expected={SCHEMA_VERSION}; runtime auto migration will handle this update", + ) + ) + else: + checks.append( + CheckItem( + "CP-08", + "error", + f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}", + ) ) - ) elif not has_paragraph_backfill: checks.append( CheckItem( @@ -316,6 +351,30 @@ def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: "paragraph_vector_backfill table missing under current schema version", ) ) + elif not has_stale_marks: + checks.append( + CheckItem( + "CP-15", + "error", + "paragraph_stale_relation_marks table missing under current schema version", + ) + ) + elif not has_profile_refresh_queue: + checks.append( + CheckItem( + "CP-16", + "error", + "person_profile_refresh_queue table missing under current schema version", + ) + ) + elif not has_feedback_rollback_status or not has_feedback_rollback_plan: + checks.append( + CheckItem( + "CP-17", + "error", + "memory_feedback_tasks rollback columns missing under current schema version", + ) + ) if _sqlite_table_exists(conn, "relations"): row = conn.execute("SELECT COUNT(*) FROM relations").fetchone() @@ -616,6 +675,46 @@ def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: "paragraph_vector_backfill table missing after migration", ) ) + has_feedback_tasks = _sqlite_table_exists(conn, "memory_feedback_tasks") + facts["memory_feedback_tasks_exists"] = bool(has_feedback_tasks) + if not has_feedback_tasks: + checks.append( + CheckItem( + "CP-15", + "error", + "memory_feedback_tasks table missing after migration", + ) + ) + has_feedback_logs = _sqlite_table_exists(conn, "memory_feedback_action_logs") + facts["memory_feedback_action_logs_exists"] = bool(has_feedback_logs) + if not has_feedback_logs: + checks.append( + CheckItem( + "CP-16", + "error", + "memory_feedback_action_logs table missing after migration", + ) + ) + has_feedback_rollback_status = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_status") + facts["memory_feedback_tasks_rollback_status_exists"] = bool(has_feedback_rollback_status) + if not has_feedback_rollback_status: + checks.append( + CheckItem( + "CP-17", + "error", + "memory_feedback_tasks.rollback_status missing after migration", + ) + ) + has_feedback_rollback_plan = _sqlite_column_exists(conn, "memory_feedback_tasks", "rollback_plan_json") + facts["memory_feedback_tasks_rollback_plan_exists"] = bool(has_feedback_rollback_plan) + if not has_feedback_rollback_plan: + checks.append( + CheckItem( + "CP-18", + "error", + "memory_feedback_tasks.rollback_plan_json missing after migration", + ) + ) conflicts = _collect_hash_alias_conflicts(conn) invalid_knowledge_types = _collect_invalid_knowledge_types(conn) finally: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 58d577f6..b0077019 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -431,6 +431,153 @@ class MemoryConfig(ConfigBase): }, ) """是否在发送回复后自动提取并写回人物事实到长期记忆""" + + feedback_correction_enabled: bool = Field( + default=False, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "message-circle-warning", + }, + ) + """是否启用反馈驱动的延迟记忆纠错任务""" + + feedback_correction_window_hours: float = Field( + default=12.0, + ge=0.1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "clock-4", + }, + ) + """反馈窗口时长(小时),以 query_memory 执行时间为起点""" + + feedback_correction_check_interval_minutes: int = Field( + default=30, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "timer", + }, + ) + """反馈纠错定时任务轮询间隔(分钟)""" + + feedback_correction_batch_size: int = Field( + default=20, + ge=1, + le=200, + json_schema_extra={ + "x-widget": "input", + "x-icon": "list-ordered", + }, + ) + """反馈纠错每轮最大处理任务数""" + + feedback_correction_auto_apply_threshold: float = Field( + default=0.85, + ge=0.0, + le=1.0, + json_schema_extra={ + "x-widget": "slider", + "x-icon": "gauge", + "step": 0.01, + }, + ) + """自动应用纠错动作的最低置信度阈值""" + + feedback_correction_max_feedback_messages: int = Field( + default=30, + ge=1, + le=200, + json_schema_extra={ + "x-widget": "input", + "x-icon": "messages-square", + }, + ) + """每个纠错任务最多使用的窗口内用户反馈消息数""" + + feedback_correction_prefilter_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "filter", + }, + ) + """是否启用纠错前置预筛(用于减少不必要的模型调用)""" + + feedback_correction_paragraph_mark_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "sticky-note", + }, + ) + """是否为受影响 paragraph 写入已纠正旧事实标记""" + + feedback_correction_paragraph_hard_filter_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "eye-off", + }, + ) + """是否在用户侧查询中硬过滤带有 stale 标记的 paragraph""" + + feedback_correction_profile_refresh_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "user-round-search", + }, + ) + """是否在反馈纠错后将受影响人物画像加入刷新队列""" + + feedback_correction_profile_force_refresh_on_read: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "refresh-ccw", + }, + ) + """人物画像处于脏队列时,读取是否强制刷新而不直接复用旧快照""" + + feedback_correction_episode_rebuild_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "clapperboard", + }, + ) + """是否在反馈纠错后将受影响 source 加入 episode 重建队列""" + + feedback_correction_episode_query_block_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "ban", + }, + ) + """episode source 处于重建队列时,是否对用户侧查询做屏蔽""" + + feedback_correction_reconcile_interval_minutes: int = Field( + default=5, + ge=1, + json_schema_extra={ + "x-widget": "input", + "x-icon": "repeat", + }, + ) + """反馈纠错二阶段一致性后台协调任务轮询间隔(分钟)""" + + feedback_correction_reconcile_batch_size: int = Field( + default=20, + ge=1, + le=200, + json_schema_extra={ + "x-widget": "input", + "x-icon": "list-restart", + }, + ) + """反馈纠错二阶段一致性每轮处理 profile/episode 队列的批大小""" chat_history_topic_check_message_threshold: int = Field( default=80, ge=1, @@ -502,6 +649,39 @@ class MemoryConfig(ConfigBase): raise ValueError( f"chat_history_finalize_message_count 必须至少为1,当前值: {self.chat_history_finalize_message_count}" ) + if self.feedback_correction_window_hours <= 0: + raise ValueError( + f"feedback_correction_window_hours 必须大于0,当前值: {self.feedback_correction_window_hours}" + ) + if self.feedback_correction_check_interval_minutes < 1: + raise ValueError( + "feedback_correction_check_interval_minutes 必须至少为1," + f"当前值: {self.feedback_correction_check_interval_minutes}" + ) + if self.feedback_correction_batch_size < 1: + raise ValueError( + f"feedback_correction_batch_size 必须至少为1,当前值: {self.feedback_correction_batch_size}" + ) + if not 0 <= self.feedback_correction_auto_apply_threshold <= 1: + raise ValueError( + "feedback_correction_auto_apply_threshold 必须在 [0, 1] 之间," + f"当前值: {self.feedback_correction_auto_apply_threshold}" + ) + if self.feedback_correction_max_feedback_messages < 1: + raise ValueError( + "feedback_correction_max_feedback_messages 必须至少为1," + f"当前值: {self.feedback_correction_max_feedback_messages}" + ) + if self.feedback_correction_reconcile_interval_minutes < 1: + raise ValueError( + "feedback_correction_reconcile_interval_minutes 必须至少为1," + f"当前值: {self.feedback_correction_reconcile_interval_minutes}" + ) + if self.feedback_correction_reconcile_batch_size < 1: + raise ValueError( + "feedback_correction_reconcile_batch_size 必须至少为1," + f"当前值: {self.feedback_correction_reconcile_batch_size}" + ) return super().model_post_init(context) diff --git a/src/maisaka/reasoning_engine.py b/src/maisaka/reasoning_engine.py index 1465e246..46678e10 100644 --- a/src/maisaka/reasoning_engine.py +++ b/src/maisaka/reasoning_engine.py @@ -19,6 +19,7 @@ from src.core.tooling import ToolExecutionContext, ToolExecutionResult, ToolInvo from src.llm_models.exceptions import ReqAbortException from src.llm_models.payload_content.tool_option import ToolCall from src.services import database_service as database_api +from src.services.memory_service import memory_service from .builtin_tool import get_action_tool_specs from .builtin_tool import build_builtin_tool_handlers as build_split_builtin_tool_handlers @@ -1013,15 +1014,30 @@ class MaisakaReasoningEngine: builtin_prompt = tool_spec.build_llm_description() try: - await database_api.store_tool_info( + tool_record_payload = self._build_tool_record_payload(invocation, result, tool_spec) + saved_record = await database_api.store_tool_info( chat_stream=self._runtime.chat_stream, builtin_prompt=builtin_prompt, display_prompt=self._build_tool_display_prompt(invocation, result, tool_spec), tool_id=invocation.call_id, - tool_data=self._build_tool_record_payload(invocation, result, tool_spec), + tool_data=tool_record_payload, tool_name=invocation.tool_name, tool_reasoning=invocation.reasoning, ) + if invocation.tool_name == "query_memory" and isinstance(saved_record, dict): + enqueue_payload = await memory_service.enqueue_feedback_task( + query_tool_id=str(saved_record.get("tool_id") or invocation.call_id or "").strip(), + session_id=str(saved_record.get("session_id") or self._runtime.chat_stream.session_id or "").strip(), + query_timestamp=saved_record.get("timestamp"), + structured_content=tool_record_payload.get("structured_content") + if isinstance(tool_record_payload.get("structured_content"), dict) + else {}, + ) + if not bool(enqueue_payload.get("success")): + logger.debug( + f"{self._runtime.log_prefix} 反馈纠错任务未入队: " + f"tool_call_id={invocation.call_id} reason={enqueue_payload.get('reason', '')}" + ) except Exception: logger.exception( f"{self._runtime.log_prefix} 写入工具记录失败: 工具={invocation.tool_name} 调用编号={invocation.call_id}" @@ -1153,4 +1169,3 @@ class MaisakaReasoningEngine: return True, tool_result_summaries, tool_monitor_results return False, tool_result_summaries, tool_monitor_results - diff --git a/src/services/memory_service.py b/src/services/memory_service.py index 34e868b5..e4d0e216 100644 --- a/src/services/memory_service.py +++ b/src/services/memory_service.py @@ -233,6 +233,30 @@ class MemoryService: logger.warning("长期记忆搜索失败: %s", exc) return MemorySearchResult(success=False, error=str(exc)) + async def enqueue_feedback_task( + self, + *, + query_tool_id: str, + session_id: str, + query_timestamp: Any = None, + structured_content: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + try: + payload = await self._invoke( + "enqueue_feedback_task", + { + "query_tool_id": str(query_tool_id or "").strip(), + "session_id": str(session_id or "").strip(), + "query_timestamp": query_timestamp, + "structured_content": structured_content if isinstance(structured_content, dict) else {}, + }, + timeout_ms=10000, + ) + except Exception as exc: + logger.warning("反馈纠错任务入队失败: %s", exc) + return {"success": False, "queued": False, "reason": str(exc)} + return payload if isinstance(payload, dict) else {"success": False, "queued": False, "reason": "invalid_payload"} + async def ingest_summary( self, *, @@ -388,6 +412,13 @@ class MemoryService: logger.warning("画像管理调用失败: %s", exc) return {"success": False, "error": str(exc)} + async def feedback_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_feedback_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("反馈纠错管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: try: return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs) diff --git a/src/webui/app.py b/src/webui/app.py index 76434fa9..5a8c7094 100644 --- a/src/webui/app.py +++ b/src/webui/app.py @@ -205,6 +205,12 @@ def _setup_static_files(app: FastAPI): def _resolve_static_path() -> Path | None: + # 开发环境优先允许复用仓库里的现成 dist + base_dir = _get_project_root() + static_path = base_dir / "dashboard" / "dist" + if static_path.exists(): + return static_path + try: module = import_module("maibot_dashboard") get_dist_path = getattr(module, "get_dist_path", None) @@ -215,11 +221,6 @@ def _resolve_static_path() -> Path | None: except Exception: pass - # 开发环境允许复用仓库里的现成 dist,但不再在用户机器上触发任何前端自愈构建。 - base_dir = _get_project_root() - static_path = base_dir / "dashboard" / "dist" - if static_path.exists(): - return static_path return None diff --git a/src/webui/routers/memory.py b/src/webui/routers/memory.py index 25bb1d9f..7da7bdb7 100644 --- a/src/webui/routers/memory.py +++ b/src/webui/routers/memory.py @@ -124,6 +124,11 @@ class DeletePurgeRequest(BaseModel): limit: int = Field(1000, ge=1, le=5000) +class FeedbackRollbackRequest(BaseModel): + requested_by: str = "webui" + reason: str = "" + + def _build_import_guide_markdown(settings: dict[str, Any]) -> str: path_aliases_raw = settings.get("path_aliases") path_aliases = path_aliases_raw if isinstance(path_aliases_raw, dict) else {} @@ -359,6 +364,29 @@ async def _profile_delete_override(person_id: str) -> dict: return await memory_service.profile_admin(action="delete_override", person_id=person_id) +async def _feedback_list(limit: int, status: str, rollback_status: str, query: str) -> dict: + return await memory_service.feedback_admin( + action="list", + limit=limit, + status=status, + rollback_status=rollback_status, + query=query, + ) + + +async def _feedback_get(task_id: int) -> dict: + return await memory_service.feedback_admin(action="get", task_id=task_id) + + +async def _feedback_rollback(task_id: int, payload: FeedbackRollbackRequest) -> dict: + return await memory_service.feedback_admin( + action="rollback", + task_id=task_id, + requested_by=payload.requested_by, + reason=payload.reason, + ) + + async def _runtime_save() -> dict: return await memory_service.runtime_admin(action="save") @@ -830,6 +858,26 @@ async def delete_memory_profile_override(person_id: str): return await _profile_delete_override(person_id) +@router.get("/feedback-corrections") +async def list_memory_feedback_corrections( + limit: int = Query(50, ge=1, le=200), + status: str = Query(""), + rollback_status: str = Query(""), + query: str = Query(""), +): + return await _feedback_list(limit, status, rollback_status, query) + + +@router.get("/feedback-corrections/{task_id}") +async def get_memory_feedback_correction(task_id: int): + return await _feedback_get(task_id) + + +@router.post("/feedback-corrections/{task_id}/rollback") +async def rollback_memory_feedback_correction(task_id: int, payload: FeedbackRollbackRequest): + return await _feedback_rollback(task_id, payload) + + @router.post("/runtime/save") async def save_memory_runtime(): return await _runtime_save()