From 999e7246e25881430eb649f40d354ef870697359 Mon Sep 17 00:00:00 2001 From: DawnARC Date: Wed, 18 Mar 2026 21:33:15 +0800 Subject: [PATCH 01/14] feat: add A_Memorix memory plugin MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce the A_Memorix plugin (v2.0.0), a lightweight long-term memory provider. Add the plugin manifest and entry point (AMemorixPlugin) together with the full set of core capabilities: embedding (hash-based EmbeddingAPIAdapter, EmbeddingManager, presets), retrieval (dual-path retriever, PageRank, graph relation recall, BM25 sparse index, threshold and fusion configuration), the storage and metadata layer, plus a large set of utilities and migration/conversion scripts. Also update .gitignore to allow /plugins/A_memorix. This change lays the groundwork for unified memory ingestion, retrieval, analysis, and maintenance in the host application. --- .gitignore | 1 + plugins/A_memorix/__init__.py | 12 + plugins/A_memorix/_manifest.json | 62 + plugins/A_memorix/core/__init__.py | 84 + plugins/A_memorix/core/embedding/__init__.py | 18 + .../A_memorix/core/embedding/api_adapter.py | 174 + plugins/A_memorix/core/embedding/manager.py | 510 ++ plugins/A_memorix/core/embedding/presets.py | 72 + plugins/A_memorix/core/retrieval/__init__.py | 54 + plugins/A_memorix/core/retrieval/dual_path.py | 1796 ++++++ .../core/retrieval/graph_relation_recall.py | 272 + plugins/A_memorix/core/retrieval/pagerank.py | 482 ++ .../A_memorix/core/retrieval/sparse_bm25.py | 402 ++ plugins/A_memorix/core/retrieval/threshold.py | 450 ++ plugins/A_memorix/core/runtime/__init__.py | 8 + .../core/runtime/sdk_memory_kernel.py | 579 ++ plugins/A_memorix/core/storage/__init__.py | 53 + plugins/A_memorix/core/storage/graph_store.py | 1434 +++++ .../A_memorix/core/storage/knowledge_types.py | 183 + .../A_memorix/core/storage/metadata_store.py | 5225 +++++++++++++++++ .../A_memorix/core/storage/type_detection.py | 137 + .../A_memorix/core/storage/vector_store.py | 776 +++ plugins/A_memorix/core/strategies/__init__.py | 0 plugins/A_memorix/core/strategies/base.py | 89 + plugins/A_memorix/core/strategies/factual.py | 98 + .../A_memorix/core/strategies/narrative.py | 126 + plugins/A_memorix/core/strategies/quote.py | 52 + plugins/A_memorix/core/utils/__init__.py | 33 + .../core/utils/aggregate_query_service.py | 360 ++ .../core/utils/episode_retrieval_service.py | 182 + plugins/A_memorix/core/utils/hash.py | 129 + .../A_memorix/core/utils/import_payloads.py | 110 + plugins/A_memorix/core/utils/io.py | 84 + plugins/A_memorix/core/utils/matcher.py | 89 + plugins/A_memorix/core/utils/monitor.py | 189 + .../core/utils/path_fallback_service.py | 165 + .../core/utils/person_profile_service.py | 495 ++ .../A_memorix/core/utils/plugin_id_policy.py | 27 + plugins/A_memorix/core/utils/quantization.py | 344 ++ .../A_memorix/core/utils/relation_query.py | 121 + .../core/utils/relation_write_service.py | 164 + .../core/utils/runtime_self_check.py | 197 + .../core/utils/search_postprocess.py | 90 + plugins/A_memorix/core/utils/time_parser.py | 170 + plugins/A_memorix/plugin.py | 207 + plugins/A_memorix/scripts/convert_lpmm.py | 535 ++ .../A_memorix/scripts/migrate_chat_history.py | 110 + .../scripts/migrate_person_memory_points.py | 120 + 48 files changed, 17070 insertions(+) create mode 100644 plugins/A_memorix/__init__.py create mode 100644 plugins/A_memorix/_manifest.json create mode 100644 plugins/A_memorix/core/__init__.py create mode 100644 plugins/A_memorix/core/embedding/__init__.py create mode 100644 plugins/A_memorix/core/embedding/api_adapter.py create mode 100644 plugins/A_memorix/core/embedding/manager.py create mode 100644 plugins/A_memorix/core/embedding/presets.py create mode 100644 plugins/A_memorix/core/retrieval/__init__.py create mode 100644 
plugins/A_memorix/core/retrieval/dual_path.py create mode 100644 plugins/A_memorix/core/retrieval/graph_relation_recall.py create mode 100644 plugins/A_memorix/core/retrieval/pagerank.py create mode 100644 plugins/A_memorix/core/retrieval/sparse_bm25.py create mode 100644 plugins/A_memorix/core/retrieval/threshold.py create mode 100644 plugins/A_memorix/core/runtime/__init__.py create mode 100644 plugins/A_memorix/core/runtime/sdk_memory_kernel.py create mode 100644 plugins/A_memorix/core/storage/__init__.py create mode 100644 plugins/A_memorix/core/storage/graph_store.py create mode 100644 plugins/A_memorix/core/storage/knowledge_types.py create mode 100644 plugins/A_memorix/core/storage/metadata_store.py create mode 100644 plugins/A_memorix/core/storage/type_detection.py create mode 100644 plugins/A_memorix/core/storage/vector_store.py create mode 100644 plugins/A_memorix/core/strategies/__init__.py create mode 100644 plugins/A_memorix/core/strategies/base.py create mode 100644 plugins/A_memorix/core/strategies/factual.py create mode 100644 plugins/A_memorix/core/strategies/narrative.py create mode 100644 plugins/A_memorix/core/strategies/quote.py create mode 100644 plugins/A_memorix/core/utils/__init__.py create mode 100644 plugins/A_memorix/core/utils/aggregate_query_service.py create mode 100644 plugins/A_memorix/core/utils/episode_retrieval_service.py create mode 100644 plugins/A_memorix/core/utils/hash.py create mode 100644 plugins/A_memorix/core/utils/import_payloads.py create mode 100644 plugins/A_memorix/core/utils/io.py create mode 100644 plugins/A_memorix/core/utils/matcher.py create mode 100644 plugins/A_memorix/core/utils/monitor.py create mode 100644 plugins/A_memorix/core/utils/path_fallback_service.py create mode 100644 plugins/A_memorix/core/utils/person_profile_service.py create mode 100644 plugins/A_memorix/core/utils/plugin_id_policy.py create mode 100644 plugins/A_memorix/core/utils/quantization.py create mode 100644 plugins/A_memorix/core/utils/relation_query.py create mode 100644 plugins/A_memorix/core/utils/relation_write_service.py create mode 100644 plugins/A_memorix/core/utils/runtime_self_check.py create mode 100644 plugins/A_memorix/core/utils/search_postprocess.py create mode 100644 plugins/A_memorix/core/utils/time_parser.py create mode 100644 plugins/A_memorix/plugin.py create mode 100644 plugins/A_memorix/scripts/convert_lpmm.py create mode 100644 plugins/A_memorix/scripts/migrate_chat_history.py create mode 100644 plugins/A_memorix/scripts/migrate_person_memory_points.py diff --git a/.gitignore b/.gitignore index ace02c8b..36853b8a 100644 --- a/.gitignore +++ b/.gitignore @@ -339,6 +339,7 @@ run_pet.bat /plugins/* !/plugins +!/plugins/A_memorix !/plugins/hello_world_plugin !/plugins/emoji_manage_plugin !/plugins/take_picture_plugin diff --git a/plugins/A_memorix/__init__.py b/plugins/A_memorix/__init__.py new file mode 100644 index 00000000..d23a5bd5 --- /dev/null +++ b/plugins/A_memorix/__init__.py @@ -0,0 +1,12 @@ +""" +A_Memorix - 轻量级知识库插件 + +完全独立的记忆增强系统,优化低资源环境下的知识存储与检索。 +""" + +__version__ = "2.0.0" +__author__ = "A_Dawn" + +from .plugin import AMemorixPlugin + +__all__ = ["AMemorixPlugin"] diff --git a/plugins/A_memorix/_manifest.json b/plugins/A_memorix/_manifest.json new file mode 100644 index 00000000..a45b2f73 --- /dev/null +++ b/plugins/A_memorix/_manifest.json @@ -0,0 +1,62 @@ +{ + "manifest_version": 1, + "name": "A_Memorix", + "version": "2.0.0", + "description": "MaiBot SDK 长期记忆插件,负责统一检索、写入、画像与记忆维护。", + "author": { + "name": "A_Dawn" + }, + 
"license": "AGPL-3.0", + "repository_url": "https://github.com/A-Dawn/A_memorix/", + "host_application": { + "min_version": "1.0.0" + }, + "keywords": [ + "memory", + "knowledge", + "retrieval", + "profile", + "episode" + ], + "categories": [ + "Memory", + "Data" + ], + "plugin_info": { + "is_built_in": false, + "plugin_type": "memory_provider", + "components": [ + { + "type": "tool", + "name": "search_memory", + "description": "搜索长期记忆" + }, + { + "type": "tool", + "name": "ingest_summary", + "description": "写入聊天摘要" + }, + { + "type": "tool", + "name": "ingest_text", + "description": "写入普通长期记忆文本" + }, + { + "type": "tool", + "name": "get_person_profile", + "description": "查询人物画像" + }, + { + "type": "tool", + "name": "maintain_memory", + "description": "维护记忆关系" + }, + { + "type": "tool", + "name": "memory_stats", + "description": "查询记忆统计" + } + ] + }, + "capabilities": [] +} diff --git a/plugins/A_memorix/core/__init__.py b/plugins/A_memorix/core/__init__.py new file mode 100644 index 00000000..3f87929c --- /dev/null +++ b/plugins/A_memorix/core/__init__.py @@ -0,0 +1,84 @@ +"""核心模块 - 存储、嵌入、检索引擎""" + +# 存储模块(已实现) +from .storage import ( + VectorStore, + GraphStore, + MetadataStore, + ImportStrategy, + KnowledgeType, + parse_import_strategy, + resolve_stored_knowledge_type, + detect_knowledge_type, + select_import_strategy, + should_extract_relations, + get_type_display_name, +) + +# 嵌入模块(使用主程序 API) +from .embedding import ( + EmbeddingAPIAdapter, + create_embedding_api_adapter, +) + +# 检索模块(已实现) +from .retrieval import ( + DualPathRetriever, + RetrievalStrategy, + RetrievalResult, + DualPathRetrieverConfig, + TemporalQueryOptions, + FusionConfig, + GraphRelationRecallConfig, + RelationIntentConfig, + PersonalizedPageRank, + PageRankConfig, + create_ppr_from_graph, + DynamicThresholdFilter, + ThresholdMethod, + ThresholdConfig, + SparseBM25Index, + SparseBM25Config, +) +from .utils import ( + RelationWriteService, + RelationWriteResult, +) + +__all__ = [ + # Storage + "VectorStore", + "GraphStore", + "MetadataStore", + "ImportStrategy", + "KnowledgeType", + "parse_import_strategy", + "resolve_stored_knowledge_type", + "detect_knowledge_type", + "select_import_strategy", + "should_extract_relations", + "get_type_display_name", + # Embedding + "EmbeddingAPIAdapter", + "create_embedding_api_adapter", + # Retrieval + "DualPathRetriever", + "RetrievalStrategy", + "RetrievalResult", + "DualPathRetrieverConfig", + "TemporalQueryOptions", + "FusionConfig", + "GraphRelationRecallConfig", + "RelationIntentConfig", + "PersonalizedPageRank", + "PageRankConfig", + "create_ppr_from_graph", + "DynamicThresholdFilter", + "ThresholdMethod", + "ThresholdConfig", + "SparseBM25Index", + "SparseBM25Config", + "RelationWriteService", + "RelationWriteResult", +] + diff --git a/plugins/A_memorix/core/embedding/__init__.py b/plugins/A_memorix/core/embedding/__init__.py new file mode 100644 index 00000000..11a52db9 --- /dev/null +++ b/plugins/A_memorix/core/embedding/__init__.py @@ -0,0 +1,18 @@ +"""嵌入模块 - 向量生成与量化""" + +# 新的 API 适配器(主程序嵌入 API) +from .api_adapter import ( + EmbeddingAPIAdapter, + create_embedding_api_adapter, +) + +from ..utils.quantization import QuantizationType + +__all__ = [ + # 新的 API 适配器(推荐使用) + "EmbeddingAPIAdapter", + "create_embedding_api_adapter", + # 量化 + "QuantizationType", +] + diff --git a/plugins/A_memorix/core/embedding/api_adapter.py b/plugins/A_memorix/core/embedding/api_adapter.py new file mode 100644 index 00000000..4262ddb9 --- /dev/null +++ 
b/plugins/A_memorix/core/embedding/api_adapter.py @@ -0,0 +1,174 @@ +""" +Hash-based embedding adapter used by the SDK runtime. + +The plugin runtime cannot import MaiBot host embedding internals from ``src.chat`` +or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in +Runner by generating deterministic dense vectors locally. +""" + +from __future__ import annotations + +import hashlib +import re +import time +from typing import List, Optional, Union + +import numpy as np + +from src.common.logger import get_logger + + +logger = get_logger("A_Memorix.EmbeddingAPIAdapter") + +_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}") + + +class EmbeddingAPIAdapter: + """Deterministic local embedding adapter.""" + + def __init__( + self, + batch_size: int = 32, + max_concurrent: int = 5, + default_dimension: int = 256, + enable_cache: bool = False, + model_name: str = "hash-v1", + retry_config: Optional[dict] = None, + ) -> None: + self.batch_size = max(1, int(batch_size)) + self.max_concurrent = max(1, int(max_concurrent)) + self.default_dimension = max(32, int(default_dimension)) + self.enable_cache = bool(enable_cache) + self.model_name = str(model_name or "hash-v1") + self.retry_config = retry_config or {} + + self._dimension: Optional[int] = None + self._dimension_detected = False + self._total_encoded = 0 + self._total_errors = 0 + self._total_time = 0.0 + + logger.info( + "EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s", + self.model_name, + self.batch_size, + self.default_dimension, + ) + + async def _detect_dimension(self) -> int: + if self._dimension_detected and self._dimension is not None: + return self._dimension + self._dimension = self.default_dimension + self._dimension_detected = True + return self._dimension + + @staticmethod + def _tokenize(text: str) -> List[str]: + clean = str(text or "").strip().lower() + if not clean: + return [] + return _TOKEN_PATTERN.findall(clean) + + @staticmethod + def _feature_weight(token: str) -> float: + digest = hashlib.sha256(token.encode("utf-8")).digest() + return 1.0 + (digest[10] / 255.0) * 0.5 + + def _encode_single(self, text: str, dimension: int) -> np.ndarray: + vector = np.zeros(dimension, dtype=np.float32) + content = str(text or "").strip() + tokens = self._tokenize(content) + if not tokens and content: + tokens = [content.lower()] + if not tokens: + vector[0] = 1.0 + return vector + + for token in tokens: + digest = hashlib.sha256(token.encode("utf-8")).digest() + bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension + sign = 1.0 if digest[8] % 2 == 0 else -1.0 + vector[bucket] += sign * self._feature_weight(token) + + second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension + if second_bucket != bucket: + vector[second_bucket] += (sign * 0.35) + + norm = float(np.linalg.norm(vector)) + if norm > 1e-8: + vector /= norm + else: + vector[0] = 1.0 + return vector + + async def encode( + self, + texts: Union[str, List[str]], + batch_size: Optional[int] = None, + show_progress: bool = False, + normalize: bool = True, + dimensions: Optional[int] = None, + ) -> np.ndarray: + _ = batch_size + _ = show_progress + _ = normalize + + started_at = time.time() + target_dimension = max(32, int(dimensions or await self._detect_dimension())) + + if isinstance(texts, str): + single_input = True + normalized_texts = [texts] + else: + single_input = False + normalized_texts = list(texts or []) + + if not normalized_texts: + empty = np.zeros((0, 
target_dimension), dtype=np.float32) + return empty[0] if single_input else empty + + try: + matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts]) + self._total_encoded += len(normalized_texts) + self._total_time += time.time() - started_at + except Exception: + self._total_errors += 1 + raise + + return matrix[0] if single_input else matrix + + def get_statistics(self) -> dict: + avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0 + return { + "model_name": self.model_name, + "dimension": self._dimension or self.default_dimension, + "total_encoded": self._total_encoded, + "total_errors": self._total_errors, + "total_time": self._total_time, + "avg_time_per_text": avg_time, + } + + def __repr__(self) -> str: + return ( + f"EmbeddingAPIAdapter(model_name={self.model_name}, " + f"dimension={self._dimension or self.default_dimension}, " + f"total_encoded={self._total_encoded})" + ) + + +def create_embedding_api_adapter( + batch_size: int = 32, + max_concurrent: int = 5, + default_dimension: int = 256, + enable_cache: bool = False, + model_name: str = "hash-v1", + retry_config: Optional[dict] = None, +) -> EmbeddingAPIAdapter: + return EmbeddingAPIAdapter( + batch_size=batch_size, + max_concurrent=max_concurrent, + default_dimension=default_dimension, + enable_cache=enable_cache, + model_name=model_name, + retry_config=retry_config, + ) diff --git a/plugins/A_memorix/core/embedding/manager.py b/plugins/A_memorix/core/embedding/manager.py new file mode 100644 index 00000000..d161e23b --- /dev/null +++ b/plugins/A_memorix/core/embedding/manager.py @@ -0,0 +1,510 @@ +""" +嵌入管理器 + +负责嵌入模型的加载、缓存和批量生成。 +""" + +import hashlib +import pickle +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Optional, Union, List, Dict, Any, Tuple + +import numpy as np + +try: + from sentence_transformers import SentenceTransformer + HAS_SENTENCE_TRANSFORMERS = True +except ImportError: + HAS_SENTENCE_TRANSFORMERS = False + +from src.common.logger import get_logger +from .presets import ( + EmbeddingModelConfig, + get_custom_config, + validate_config_compatibility, + are_models_compatible, +) +from ..utils.quantization import QuantizationType + +logger = get_logger("A_Memorix.EmbeddingManager") + + +class EmbeddingManager: + """ + 嵌入管理器 + + 功能: + - 模型加载与缓存 + - 批量生成嵌入 + - 多线程/多进程支持 + - 模型一致性检查 + - 智能分批 + + 参数: + config: 模型配置 + cache_dir: 缓存目录 + enable_cache: 是否启用缓存 + num_workers: 工作线程数 + """ + + def __init__( + self, + config: EmbeddingModelConfig, + cache_dir: Optional[Union[str, Path]] = None, + enable_cache: bool = True, + num_workers: int = 1, + ): + """ + 初始化嵌入管理器 + + Args: + config: 模型配置 + cache_dir: 缓存目录 + enable_cache: 是否启用缓存 + num_workers: 工作线程数 + """ + if not HAS_SENTENCE_TRANSFORMERS: + raise ImportError( + "sentence-transformers 未安装,请安装: " + "pip install sentence-transformers" + ) + + self.config = config + self.cache_dir = Path(cache_dir) if cache_dir else None + self.enable_cache = enable_cache + self.num_workers = max(1, num_workers) + + # 模型实例 + self._model: Optional[SentenceTransformer] = None + self._model_lock = threading.Lock() + + # 缓存 + self._embedding_cache: Dict[str, np.ndarray] = {} + self._cache_lock = threading.Lock() + + # 统计 + self._total_encoded = 0 + self._cache_hits = 0 + self._cache_misses = 0 + + logger.info( + f"EmbeddingManager 初始化: model={config.model_name}, " + f"dim={config.dimension}, workers={num_workers}" + ) + + def 
load_model(self) -> None: + """加载模型(懒加载)""" + if self._model is not None: + return + + with self._model_lock: + # 双重检查 + if self._model is not None: + return + + logger.info(f"正在加载模型: {self.config.model_name}") + + try: + # 构建模型参数 + model_kwargs = {} + if self.config.cache_dir: + model_kwargs["cache_folder"] = self.config.cache_dir + + # 加载模型 + self._model = SentenceTransformer( + self.config.model_path, + **model_kwargs, + ) + + logger.info(f"模型加载成功: {self.config.model_name}") + + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise + + def encode( + self, + texts: Union[str, List[str]], + batch_size: Optional[int] = None, + show_progress: bool = False, + normalize: bool = True, + ) -> np.ndarray: + """ + 生成文本嵌入 + + Args: + texts: 文本或文本列表 + batch_size: 批次大小(默认使用配置值) + show_progress: 是否显示进度条 + normalize: 是否归一化 + + Returns: + 嵌入向量 (N x D) + """ + # 确保模型已加载 + self.load_model() + + # 标准化输入 + if isinstance(texts, str): + texts = [texts] + single_input = True + else: + single_input = False + + if not texts: + return np.zeros((0, self.config.dimension), dtype=np.float32) + + # 使用配置的批次大小 + if batch_size is None: + batch_size = self.config.batch_size + + # 生成嵌入 + try: + embeddings = self._model.encode( + texts, + batch_size=batch_size, + show_progress_bar=show_progress, + normalize_embeddings=normalize and self.config.normalization, + convert_to_numpy=True, + ) + + # 确保是2D数组 + if embeddings.ndim == 1: + embeddings = embeddings.reshape(1, -1) + + self._total_encoded += len(texts) + + # 如果是单个输入,返回1D数组 + if single_input: + return embeddings[0] + + return embeddings + + except Exception as e: + logger.error(f"生成嵌入失败: {e}") + raise + + def encode_batch( + self, + texts: List[str], + batch_size: Optional[int] = None, + num_workers: Optional[int] = None, + show_progress: bool = False, + ) -> np.ndarray: + """ + 批量生成嵌入(多线程优化) + + Args: + texts: 文本列表 + batch_size: 批次大小 + num_workers: 工作线程数(默认使用初始化时的值) + show_progress: 是否显示进度条 + + Returns: + 嵌入向量 (N x D) + """ + if not texts: + return np.zeros((0, self.config.dimension), dtype=np.float32) + + # 单线程模式 + num_workers = num_workers if num_workers is not None else self.num_workers + if num_workers == 1: + return self.encode(texts, batch_size=batch_size, show_progress=show_progress) + + # 多线程模式 + logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入") + + # 分批 + batch_size = batch_size or self.config.batch_size + batches = [ + texts[i:i + batch_size] + for i in range(0, len(texts), batch_size) + ] + + # 多线程生成 + all_embeddings = [] + with ThreadPoolExecutor(max_workers=num_workers) as executor: + # 提交任务 + future_to_batch = { + executor.submit( + self.encode, + batch, + batch_size, + False, # 不显示进度条(多线程时会混乱) + ): i + for i, batch in enumerate(batches) + } + + # 收集结果 + for future in as_completed(future_to_batch): + batch_idx = future_to_batch[future] + try: + embeddings = future.result() + all_embeddings.append((batch_idx, embeddings)) + except Exception as e: + logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}") + raise + + # 按顺序合并 + all_embeddings.sort(key=lambda x: x[0]) + final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0) + + return final_embeddings + + def encode_with_cache( + self, + texts: List[str], + batch_size: Optional[int] = None, + show_progress: bool = False, + ) -> np.ndarray: + """ + 生成嵌入(带缓存) + + Args: + texts: 文本列表 + batch_size: 批次大小 + show_progress: 是否显示进度条 + + Returns: + 嵌入向量 (N x D) + """ + if not self.enable_cache: + return self.encode(texts, batch_size, show_progress) + + # 分离缓存命中和未命中的文本 + cached_embeddings = [] + 
uncached_texts = [] + uncached_indices = [] + + for i, text in enumerate(texts): + cache_key = self._get_cache_key(text) + + with self._cache_lock: + if cache_key in self._embedding_cache: + cached_embeddings.append((i, self._embedding_cache[cache_key])) + self._cache_hits += 1 + else: + uncached_texts.append(text) + uncached_indices.append(i) + self._cache_misses += 1 + + # 生成未缓存的嵌入 + if uncached_texts: + new_embeddings = self.encode( + uncached_texts, + batch_size, + show_progress, + ) + + # 更新缓存 + with self._cache_lock: + for text, embedding in zip(uncached_texts, new_embeddings): + cache_key = self._get_cache_key(text) + self._embedding_cache[cache_key] = embedding.copy() + + # 合并结果 + for idx, embedding in zip(uncached_indices, new_embeddings): + cached_embeddings.append((idx, embedding)) + + # 按原始顺序排序 + cached_embeddings.sort(key=lambda x: x[0]) + final_embeddings = np.array([emb for _, emb in cached_embeddings]) + + return final_embeddings + + def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: + """ + 保存缓存到磁盘 + + Args: + cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) + """ + if cache_path is None: + if self.cache_dir is None: + raise ValueError("未指定缓存目录") + cache_path = self.cache_dir / "embeddings_cache.pkl" + + cache_path = Path(cache_path) + cache_path.parent.mkdir(parents=True, exist_ok=True) + + with self._cache_lock: + with open(cache_path, "wb") as f: + pickle.dump(self._embedding_cache, f) + + logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)") + + def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: + """ + 从磁盘加载缓存 + + Args: + cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) + """ + if cache_path is None: + if self.cache_dir is None: + raise ValueError("未指定缓存目录") + cache_path = self.cache_dir / "embeddings_cache.pkl" + + cache_path = Path(cache_path) + if not cache_path.exists(): + logger.warning(f"缓存文件不存在: {cache_path}") + return + + with self._cache_lock: + with open(cache_path, "rb") as f: + self._embedding_cache = pickle.load(f) + + logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)") + + def clear_cache(self) -> None: + """清空缓存""" + with self._cache_lock: + count = len(self._embedding_cache) + self._embedding_cache.clear() + logger.info(f"已清空缓存: {count} 条") + + def check_model_consistency( + self, + stored_embeddings: np.ndarray, + sample_texts: List[str] = None, + ) -> Tuple[bool, str]: + """ + 检查模型一致性 + + Args: + stored_embeddings: 存储的嵌入向量 + sample_texts: 样本文本(用于重新生成对比) + + Returns: + (是否一致, 详细信息) + """ + # 检查维度 + if stored_embeddings.shape[1] != self.config.dimension: + return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}" + + # 如果提供了样本文本,重新生成并比较 + if sample_texts: + try: + new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个 + + # 计算相似度 + similarities = np.dot( + stored_embeddings[:5], + new_embeddings.T, + ).diagonal() + + # 检查相似度 + if np.mean(similarities) < 0.95: + return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}" + + return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}" + + except Exception as e: + return False, f"一致性检查失败: {e}" + + return True, "维度匹配" + + def get_model_info(self) -> Dict[str, Any]: + """ + 获取模型信息 + + Returns: + 模型信息字典 + """ + return { + "model_name": self.config.model_name, + "dimension": self.config.dimension, + "max_seq_length": self.config.max_seq_length, + "batch_size": self.config.batch_size, + "normalization": self.config.normalization, + "pooling": self.config.pooling, + "model_loaded": self._model 
is not None, + "cache_enabled": self.enable_cache, + "cache_size": len(self._embedding_cache), + "total_encoded": self._total_encoded, + "cache_hits": self._cache_hits, + "cache_misses": self._cache_misses, + } + + def get_embedding_dimension(self) -> int: + """获取嵌入维度""" + return self.config.dimension + + def _get_cache_key(self, text: str) -> str: + """ + 生成缓存键 + + Args: + text: 文本内容 + + Returns: + 缓存键(SHA256哈希) + """ + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + @property + def is_model_loaded(self) -> bool: + """模型是否已加载""" + return self._model is not None + + @property + def cache_hit_rate(self) -> float: + """缓存命中率""" + total = self._cache_hits + self._cache_misses + if total == 0: + return 0.0 + return self._cache_hits / total + + def __repr__(self) -> str: + return ( + f"EmbeddingManager(model={self.config.model_name}, " + f"dim={self.config.dimension}, " + f"loaded={self.is_model_loaded}, " + f"cache={len(self._embedding_cache)})" + ) + + + + +def create_embedding_manager_from_config( + model_name: str, + model_path: str, + dimension: int, + cache_dir: Optional[Union[str, Path]] = None, + enable_cache: bool = True, + num_workers: int = 1, + **config_kwargs, +) -> EmbeddingManager: + """ + 从自定义配置创建嵌入管理器 + + Args: + model_name: 模型名称 + model_path: HuggingFace模型路径 + dimension: 输出维度 + cache_dir: 缓存目录 + enable_cache: 是否启用缓存 + num_workers: 工作线程数 + **config_kwargs: 其他配置参数 + + Returns: + 嵌入管理器实例 + """ + # 创建自定义配置 + config = get_custom_config( + model_name=model_name, + model_path=model_path, + dimension=dimension, + cache_dir=cache_dir, + **config_kwargs, + ) + + # 创建管理器 + return EmbeddingManager( + config=config, + cache_dir=cache_dir, + enable_cache=enable_cache, + num_workers=num_workers, + ) diff --git a/plugins/A_memorix/core/embedding/presets.py b/plugins/A_memorix/core/embedding/presets.py new file mode 100644 index 00000000..54e6f8b4 --- /dev/null +++ b/plugins/A_memorix/core/embedding/presets.py @@ -0,0 +1,72 @@ +""" +嵌入模型配置模块 +""" + +from dataclasses import dataclass +from typing import Optional, Dict, Any, Union +from pathlib import Path + + +@dataclass +class EmbeddingModelConfig: + """ + 嵌入模型配置 + + 属性: + model_name: 模型描述名称 + model_path: 实际加载路径(Local or HF) + dimension: 嵌入向量维度 + max_seq_length: 最大序列长度 + batch_size: 编码批次大小 + model_size_mb: 估计显存占用 + description: 模型说明 + normalization: 是否自动归一化 + pooling: 池化策略 (mean, cls, max) + cache_dir: 模型缓存目录 + """ + + model_name: str + model_path: str + dimension: int + max_seq_length: int = 512 + batch_size: int = 32 + model_size_mb: int = 100 + description: str = "" + normalization: bool = True + pooling: str = "mean" + cache_dir: Optional[Union[str, Path]] = None + + +def validate_config_compatibility( + config1: EmbeddingModelConfig, config2: EmbeddingModelConfig +) -> bool: + """检查两个配置是否兼容(主要看维度)""" + return config1.dimension == config2.dimension + + +def are_models_compatible( + config1: EmbeddingModelConfig, config2: EmbeddingModelConfig +) -> bool: + """检查模型是否完全相同(用于热切换判断)""" + return ( + config1.model_path == config2.model_path + and config1.dimension == config2.dimension + and config1.pooling == config2.pooling + ) + + +def get_custom_config( + model_name: str, + model_path: str, + dimension: int, + cache_dir: Optional[Union[str, Path]] = None, + **kwargs, +) -> EmbeddingModelConfig: + """创建自定义模型配置""" + return EmbeddingModelConfig( + model_name=model_name, + model_path=model_path, + dimension=dimension, + cache_dir=cache_dir, + **kwargs, + ) diff --git a/plugins/A_memorix/core/retrieval/__init__.py 
b/plugins/A_memorix/core/retrieval/__init__.py new file mode 100644 index 00000000..6efce7f6 --- /dev/null +++ b/plugins/A_memorix/core/retrieval/__init__.py @@ -0,0 +1,54 @@ +"""检索模块 - 双路检索与排序""" + +from .dual_path import ( + DualPathRetriever, + RetrievalStrategy, + RetrievalResult, + DualPathRetrieverConfig, + TemporalQueryOptions, + FusionConfig, + RelationIntentConfig, +) +from .pagerank import ( + PersonalizedPageRank, + PageRankConfig, + create_ppr_from_graph, +) +from .threshold import ( + DynamicThresholdFilter, + ThresholdMethod, + ThresholdConfig, +) +from .sparse_bm25 import ( + SparseBM25Index, + SparseBM25Config, +) +from .graph_relation_recall import ( + GraphRelationRecallConfig, + GraphRelationRecallService, +) + +__all__ = [ + # DualPathRetriever + "DualPathRetriever", + "RetrievalStrategy", + "RetrievalResult", + "DualPathRetrieverConfig", + "TemporalQueryOptions", + "FusionConfig", + "RelationIntentConfig", + # PersonalizedPageRank + "PersonalizedPageRank", + "PageRankConfig", + "create_ppr_from_graph", + # DynamicThresholdFilter + "DynamicThresholdFilter", + "ThresholdMethod", + "ThresholdConfig", + # Sparse BM25 + "SparseBM25Index", + "SparseBM25Config", + # Graph relation recall + "GraphRelationRecallConfig", + "GraphRelationRecallService", +] diff --git a/plugins/A_memorix/core/retrieval/dual_path.py b/plugins/A_memorix/core/retrieval/dual_path.py new file mode 100644 index 00000000..cfeb343c --- /dev/null +++ b/plugins/A_memorix/core/retrieval/dual_path.py @@ -0,0 +1,1796 @@ +""" +双路检索器 + +同时检索关系和段落,实现知识图谱增强的检索。 +""" + +import asyncio +import re +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any, Tuple, Union +from enum import Enum + +import numpy as np + +from src.common.logger import get_logger +from ..storage import VectorStore, GraphStore, MetadataStore +from ..embedding import EmbeddingAPIAdapter +from ..utils.matcher import AhoCorasick +from ..utils.time_parser import format_timestamp +from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService +from .pagerank import PersonalizedPageRank, PageRankConfig +from .sparse_bm25 import SparseBM25Config, SparseBM25Index + +logger = get_logger("A_Memorix.DualPathRetriever") + + +class RetrievalStrategy(Enum): + """检索策略""" + + PARA_ONLY = "paragraph_only" # 仅段落检索 + REL_ONLY = "relation_only" # 仅关系检索 + DUAL_PATH = "dual_path" # 双路检索(推荐) + + +@dataclass +class RetrievalResult: + """ + 检索结果 + + 属性: + hash_value: 哈希值 + content: 内容(段落或关系) + score: 相似度分数 + result_type: 结果类型(paragraph/relation) + source: 来源(paragraph_search/relation_search/fusion) + metadata: 额外元数据 + """ + + hash_value: str + content: str + score: float + result_type: str # "paragraph" or "relation" + source: str # "paragraph_search", "relation_search", "fusion" + metadata: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "hash": self.hash_value, + "content": self.content, + "score": self.score, + "type": self.result_type, + "source": self.source, + "metadata": self.metadata, + } + + +@dataclass +class DualPathRetrieverConfig: + """ + 双路检索器配置 + + 属性: + top_k_paragraphs: 段落检索数量 + top_k_relations: 关系检索数量 + top_k_final: 最终返回数量 + alpha: 段落和关系的融合权重(0-1) + - 0: 仅使用关系分数 + - 1: 仅使用段落分数 + - 0.5: 平均融合 + enable_ppr: 是否启用PageRank重排序 + ppr_alpha: PageRank的alpha参数 + ppr_concurrency_limit: PPR计算的最大并发数 + enable_parallel: 是否并行检索 + retrieval_strategy: 检索策略 + debug: 是否启用调试模式(打印搜索结果原文) + """ + + top_k_paragraphs: int = 20 + top_k_relations: int = 10 + top_k_final: int = 10 + 
alpha: float = 0.5 # 融合权重 + enable_ppr: bool = True + ppr_alpha: float = 0.85 + ppr_timeout_seconds: float = 1.5 + ppr_concurrency_limit: int = 4 + enable_parallel: bool = True + retrieval_strategy: RetrievalStrategy = RetrievalStrategy.DUAL_PATH + debug: bool = False + sparse: SparseBM25Config = field(default_factory=SparseBM25Config) + fusion: "FusionConfig" = field(default_factory=lambda: FusionConfig()) + relation_intent: "RelationIntentConfig" = field(default_factory=lambda: RelationIntentConfig()) + graph_recall: GraphRelationRecallConfig = field(default_factory=GraphRelationRecallConfig) + + def __post_init__(self): + """验证配置""" + if isinstance(self.sparse, dict): + self.sparse = SparseBM25Config(**self.sparse) + if isinstance(self.fusion, dict): + self.fusion = FusionConfig(**self.fusion) + if isinstance(self.relation_intent, dict): + self.relation_intent = RelationIntentConfig(**self.relation_intent) + if isinstance(self.graph_recall, dict): + self.graph_recall = GraphRelationRecallConfig(**self.graph_recall) + + if not 0 <= self.alpha <= 1: + raise ValueError(f"alpha必须在[0, 1]之间: {self.alpha}") + + if self.top_k_paragraphs <= 0: + raise ValueError(f"top_k_paragraphs必须大于0: {self.top_k_paragraphs}") + + if self.top_k_relations <= 0: + raise ValueError(f"top_k_relations必须大于0: {self.top_k_relations}") + + if self.top_k_final <= 0: + raise ValueError(f"top_k_final必须大于0: {self.top_k_final}") + if self.ppr_timeout_seconds <= 0: + raise ValueError(f"ppr_timeout_seconds必须大于0: {self.ppr_timeout_seconds}") + + +@dataclass +class TemporalQueryOptions: + """时序查询选项。""" + + time_from: Optional[float] = None + time_to: Optional[float] = None + person: Optional[str] = None + source: Optional[str] = None + allow_created_fallback: bool = True + candidate_multiplier: int = 8 + max_scan: int = 1000 + + +@dataclass +class RelationIntentConfig: + """关系意图增强配置。""" + + enabled: bool = True + alpha_override: float = 0.35 + relation_candidate_multiplier: int = 4 + preserve_top_relations: int = 3 + force_relation_sparse: bool = True + pair_predicate_rerank_enabled: bool = True + pair_predicate_limit: int = 3 + + def __post_init__(self): + self.alpha_override = min(1.0, max(0.0, float(self.alpha_override))) + self.relation_candidate_multiplier = max(1, int(self.relation_candidate_multiplier)) + self.preserve_top_relations = max(0, int(self.preserve_top_relations)) + self.force_relation_sparse = bool(self.force_relation_sparse) + self.pair_predicate_rerank_enabled = bool(self.pair_predicate_rerank_enabled) + self.pair_predicate_limit = max(1, int(self.pair_predicate_limit)) + + +@dataclass +class FusionConfig: + """融合配置。""" + + method: str = "weighted_rrf" # weighted_rrf | alpha_legacy + rrf_k: int = 60 + vector_weight: float = 0.7 + bm25_weight: float = 0.3 + normalize_score: bool = True + normalize_method: str = "minmax" + + def __post_init__(self): + self.method = str(self.method or "weighted_rrf").strip().lower() + self.normalize_method = str(self.normalize_method or "minmax").strip().lower() + self.rrf_k = max(1, int(self.rrf_k)) + self.vector_weight = max(0.0, float(self.vector_weight)) + self.bm25_weight = max(0.0, float(self.bm25_weight)) + s = self.vector_weight + self.bm25_weight + if s <= 0: + self.vector_weight = 0.7 + self.bm25_weight = 0.3 + elif abs(s - 1.0) > 1e-8: + self.vector_weight /= s + self.bm25_weight /= s + + +class DualPathRetriever: + """ + 双路检索器 + + 功能: + - 并行检索段落和关系 + - 结果融合与排序 + - PageRank重排序 + - 实体识别与加权 + + 参数: + vector_store: 向量存储 + graph_store: 图存储 + metadata_store: 元数据存储 + 
embedding_manager: 嵌入管理器 + config: 检索配置 + """ + + def __init__( + self, + vector_store: VectorStore, + graph_store: GraphStore, + metadata_store: MetadataStore, + embedding_manager: EmbeddingAPIAdapter, + sparse_index: Optional[SparseBM25Index] = None, + config: Optional[DualPathRetrieverConfig] = None, + ): + """ + 初始化双路检索器 + + Args: + vector_store: 向量存储 + graph_store: 图存储 + metadata_store: 元数据存储 + embedding_manager: 嵌入管理器 + config: 检索配置 + """ + self.vector_store = vector_store + self.graph_store = graph_store + self.metadata_store = metadata_store + self.embedding_manager = embedding_manager + self.config = config or DualPathRetrieverConfig() + self.sparse_index = sparse_index + + # PageRank计算器 + ppr_config = PageRankConfig(alpha=self.config.ppr_alpha) + self._ppr = PersonalizedPageRank( + graph_store=graph_store, + config=ppr_config, + ) + self._ppr_semaphore = asyncio.Semaphore(self.config.ppr_concurrency_limit) + self._graph_relation_recall = GraphRelationRecallService( + graph_store=graph_store, + metadata_store=metadata_store, + config=self.config.graph_recall, + ) + + logger.info( + f"DualPathRetriever 初始化: " + f"strategy={self.config.retrieval_strategy.value}, " + f"top_k_para={self.config.top_k_paragraphs}, " + f"top_k_rel={self.config.top_k_relations}" + ) + + # 缓存 Aho-Corasick 匹配器 + self._ac_matcher: Optional[AhoCorasick] = None + self._ac_nodes_count = 0 + self._relation_intent_pattern = re.compile( + r"(什么关系|有哪些关系|和.+关系|关联|关系网|subject|predicate|object|" + r"relation|related|between.+and)", + re.IGNORECASE, + ) + + async def retrieve( + self, + query: str, + top_k: Optional[int] = None, + strategy: Optional[RetrievalStrategy] = None, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 执行检索(异步方法) + + Args: + query: 查询文本 + top_k: 返回结果数量(默认使用配置值) + strategy: 检索策略(默认使用配置值) + temporal: 时序查询选项(可选) + + Returns: + 检索结果列表 + """ + top_k = top_k or self.config.top_k_final + strategy = strategy or self.config.retrieval_strategy + relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k) + + logger.info( + "执行检索: query='%s...', strategy=%s, relation_intent=%s", + query[:50], + strategy.value, + relation_intent_ctx.get("enabled", False), + ) + + if temporal and not (query or "").strip(): + return self._retrieve_temporal_only(temporal, top_k) + + # 根据策略执行检索 + if strategy == RetrievalStrategy.PARA_ONLY: + results = await self._retrieve_paragraphs_only(query, top_k, temporal=temporal) + elif strategy == RetrievalStrategy.REL_ONLY: + results = await self._retrieve_relations_only(query, top_k, temporal=temporal) + else: # DUAL_PATH + results = await self._retrieve_dual_path( + query, + top_k, + temporal=temporal, + relation_intent=relation_intent_ctx, + ) + + logger.info(f"检索完成: 返回 {len(results)} 条结果") + + # 调试模式:打印结果原文 + if self.config.debug: + logger.info(f"[DEBUG] 检索结果内容原文:") + for i, res in enumerate(results): + logger.info(f" {i+1}. 
[{res.result_type}] (Score: {res.score:.4f}) {res.content}") + + return results + + def _is_relation_intent_query(self, query: str) -> bool: + q = str(query or "").strip() + if not q: + return False + if "|" in q or "->" in q: + return True + return self._relation_intent_pattern.search(q) is not None + + def _build_relation_intent_context(self, query: str, top_k: int) -> Dict[str, Any]: + cfg = self.config.relation_intent + enabled = bool(cfg.enabled) and self._is_relation_intent_query(query) + base_relation_k = max(1, int(self.config.top_k_relations)) + relation_top_k = max(base_relation_k, int(top_k)) + if enabled: + relation_top_k = max( + relation_top_k, + relation_top_k * int(cfg.relation_candidate_multiplier), + ) + return { + "enabled": enabled, + "alpha_override": float(cfg.alpha_override) if enabled else None, + "relation_top_k": int(relation_top_k), + "preserve_top_relations": int(cfg.preserve_top_relations) if enabled else 0, + "force_relation_sparse": bool(cfg.force_relation_sparse) if enabled else False, + "pair_predicate_rerank_enabled": bool(cfg.pair_predicate_rerank_enabled) if enabled else False, + "pair_predicate_limit": int(cfg.pair_predicate_limit) if enabled else 0, + } + + def _cap_temporal_scan_k( + self, + candidate_k: int, + temporal: Optional[TemporalQueryOptions], + ) -> int: + """对 temporal 模式候选召回数应用 max_scan 上限。""" + k = max(1, int(candidate_k)) + if temporal and temporal.max_scan and temporal.max_scan > 0: + k = min(k, int(temporal.max_scan)) + return max(1, k) + + def _is_valid_embedding(self, emb: Optional[np.ndarray]) -> bool: + if emb is None: + return False + arr = np.asarray(emb, dtype=np.float32) + if arr.ndim == 0 or arr.size == 0: + return False + return bool(np.all(np.isfinite(arr))) + + def _get_embedding_dim(self, emb: Optional[np.ndarray]) -> Optional[int]: + if emb is None: + return None + arr = np.asarray(emb) + if arr.ndim == 1: + return int(arr.shape[0]) if arr.size > 0 else None + if arr.ndim == 2: + if arr.shape[0] == 0: + return None + return int(arr.shape[1]) + return None + + def _is_embedding_dimension_compatible(self, emb: Optional[np.ndarray]) -> bool: + got_dim = self._get_embedding_dim(emb) + expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) + if got_dim is None or expected_dim <= 0: + return False + return got_dim == expected_dim + + def _is_embedding_ready_for_vector_search( + self, + emb: Optional[np.ndarray], + *, + stage: str, + ) -> bool: + if not self._is_valid_embedding(emb): + return False + if self._is_embedding_dimension_compatible(emb): + return True + + expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) + got_dim = self._get_embedding_dim(emb) + logger.warning( + "metric.embedding_dim_mismatch_fallback_count=1 " + f"stage={stage} expected_dim={expected_dim} got_dim={got_dim}" + ) + return False + + def _should_use_sparse( + self, + embedding_ok: bool, + vector_results: Optional[List[RetrievalResult]] = None, + ) -> bool: + if not self.config.sparse.enabled or self.sparse_index is None: + return False + + mode = self.config.sparse.mode + if mode == "hybrid": + return True + if mode == "fallback_only": + return not embedding_ok + # auto + if not embedding_ok: + return True + if not vector_results: + return True + best = max((float(r.score) for r in vector_results), default=0.0) + return best < 0.45 + + def _should_use_sparse_relations( + self, + embedding_ok: bool, + relation_results: Optional[List[RetrievalResult]] = None, + force_enable: bool = False, + ) -> bool: + if force_enable and 
self.config.sparse.enabled and self.sparse_index is not None: + return True + if not self.config.sparse.enable_relation_sparse_fallback: + return False + return self._should_use_sparse(embedding_ok, relation_results) + + def _normalize_scores_minmax(self, results: List[RetrievalResult]) -> None: + if not results: + return + vals = [float(r.score) for r in results] + lo = min(vals) + hi = max(vals) + if hi - lo < 1e-12: + for r in results: + r.score = 1.0 + return + for r in results: + r.score = (float(r.score) - lo) / (hi - lo) + + def _build_minmax_score_map(self, results: List[RetrievalResult]) -> Dict[str, float]: + if not results: + return {} + vals = [float(r.score) for r in results] + lo = min(vals) + hi = max(vals) + if hi - lo < 1e-12: + return {r.hash_value: 1.0 for r in results} + return { + r.hash_value: (float(r.score) - lo) / (hi - lo) + for r in results + } + + @staticmethod + def _clone_retrieval_result(item: RetrievalResult) -> RetrievalResult: + return RetrievalResult( + hash_value=item.hash_value, + content=item.content, + score=float(item.score), + result_type=item.result_type, + source=item.source, + metadata=dict(item.metadata or {}), + ) + + def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: + entities = self._extract_entities(query) + if not entities: + return [] + ranked = sorted( + entities.items(), + key=lambda x: (-float(x[1]), -len(str(x[0])), str(x[0]).lower()), + ) + return [str(name) for name, _ in ranked[: max(0, int(limit))]] + + def _search_relations_graph( + self, + query: str, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + service = getattr(self, "_graph_relation_recall", None) + if service is None or not bool(getattr(self.config.graph_recall, "enabled", True)): + return [] + + seed_entities = self._extract_graph_seed_entities(query, limit=2) + if not seed_entities: + return [] + + payloads = service.recall(seed_entities=seed_entities) + results: List[RetrievalResult] = [] + for payload in payloads: + meta = payload.to_payload() + results.append( + RetrievalResult( + hash_value=str(meta["hash"]), + content=str(meta["content"]), + score=0.0, + result_type="relation", + source="graph_relation_recall", + metadata={ + "subject": meta["subject"], + "predicate": meta["predicate"], + "object": meta["object"], + "confidence": float(meta["confidence"]), + "graph_seed_entities": list(meta["graph_seed_entities"]), + "graph_hops": int(meta["graph_hops"]), + "graph_candidate_type": str(meta["graph_candidate_type"]), + "supporting_paragraph_count": int(meta["supporting_paragraph_count"]), + }, + ) + ) + return self._apply_temporal_filter_to_relations(results, temporal) + + def _fuse_ranked_lists_weighted_rrf( + self, + vector_results: List[RetrievalResult], + sparse_results: List[RetrievalResult], + ) -> List[RetrievalResult]: + """按 weighted RRF 融合两路段落召回。""" + if not vector_results: + out = sparse_results[:] + if self.config.fusion.normalize_score: + self._normalize_scores_minmax(out) + return out + if not sparse_results: + out = vector_results[:] + if self.config.fusion.normalize_score: + self._normalize_scores_minmax(out) + return out + + k = self.config.fusion.rrf_k + w_vec = self.config.fusion.vector_weight + w_sparse = self.config.fusion.bm25_weight + merged: Dict[str, RetrievalResult] = {} + score_map: Dict[str, float] = {} + + for rank, item in enumerate(vector_results, start=1): + h = item.hash_value + if h not in merged: + merged[h] = item + merged[h].source = "fusion_rrf" + score_map[h] = 
score_map.get(h, 0.0) + w_vec * (1.0 / (k + rank)) + + for rank, item in enumerate(sparse_results, start=1): + h = item.hash_value + if h not in merged: + merged[h] = item + merged[h].source = "fusion_rrf" + score_map[h] = score_map.get(h, 0.0) + w_sparse * (1.0 / (k + rank)) + + out = list(merged.values()) + for item in out: + item.score = float(score_map.get(item.hash_value, 0.0)) + + out.sort(key=lambda x: x.score, reverse=True) + if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": + self._normalize_scores_minmax(out) + return out + + def _search_paragraphs_sparse( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """BM25 段落召回。""" + if not self.sparse_index or not self.config.sparse.enabled: + return [] + + candidate_k = max(top_k, self.config.sparse.candidate_k) + candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) + sparse_rows = self.sparse_index.search(query=query, k=candidate_k) + results: List[RetrievalResult] = [] + for row in sparse_rows: + hash_value = row["hash"] + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + continue + time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) + results.append( + RetrievalResult( + hash_value=hash_value, + content=paragraph["content"], + score=float(row.get("score", 0.0)), + result_type="paragraph", + source="sparse_bm25", + metadata={ + "word_count": paragraph.get("word_count", 0), + "time_meta": time_meta, + "bm25_score": float(row.get("bm25_score", 0.0)), + }, + ) + ) + results = self._apply_temporal_filter_to_paragraphs(results, temporal) + if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": + self._normalize_scores_minmax(results) + return results + + def _search_relations_sparse( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """关系 BM25 召回。""" + if not self.sparse_index or not self.config.sparse.enabled: + return [] + if not self.config.sparse.enable_relation_sparse_fallback: + return [] + + candidate_k = max(top_k, self.config.sparse.relation_candidate_k) + candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) + rows = self.sparse_index.search_relations(query=query, k=candidate_k) + results: List[RetrievalResult] = [] + for row in rows: + hash_value = row["hash"] + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + continue + + relation_time_meta = None + if temporal: + relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) + if relation_time_meta is None: + continue + + content = f"{relation['subject']} {relation['predicate']} {relation['object']}" + results.append( + RetrievalResult( + hash_value=hash_value, + content=content, + score=float(row.get("score", 0.0)), + result_type="relation", + source="sparse_relation_bm25", + metadata={ + "subject": relation["subject"], + "predicate": relation["predicate"], + "object": relation["object"], + "confidence": relation.get("confidence", 1.0), + "time_meta": relation_time_meta, + "bm25_score": float(row.get("bm25_score", 0.0)), + }, + ) + ) + + if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": + self._normalize_scores_minmax(results) + return self._apply_temporal_filter_to_relations(results, temporal) + + def _merge_relation_results( + self, + vector_results: List[RetrievalResult], + sparse_results: 
List[RetrievalResult], + ) -> List[RetrievalResult]: + """合并关系候选,按 hash 去重并保留更高分。""" + merged: Dict[str, RetrievalResult] = {} + for item in vector_results: + merged[item.hash_value] = item + for item in sparse_results: + old = merged.get(item.hash_value) + if old is None or float(item.score) > float(old.score): + merged[item.hash_value] = item + elif old is not None and old.source != item.source: + old.source = "relation_fusion" + out = list(merged.values()) + out.sort(key=lambda x: x.score, reverse=True) + return out + + def _merge_relation_results_graph_enhanced( + self, + vector_results: List[RetrievalResult], + sparse_results: List[RetrievalResult], + graph_results: List[RetrievalResult], + ) -> List[RetrievalResult]: + """Graph-aware relation fusion with semantic + graph + evidence scoring.""" + vector_norm = self._build_minmax_score_map(vector_results) + sparse_norm = self._build_minmax_score_map(sparse_results) + graph_score_map = { + "direct_pair": 1.0, + "one_hop_seed": 0.75, + "two_hop_pair": 0.55, + } + + merged: Dict[str, RetrievalResult] = {} + source_sets: Dict[str, set[str]] = {} + support_cache: Dict[str, int] = {} + + for group in (vector_results, sparse_results, graph_results): + for item in group: + existing = merged.get(item.hash_value) + if existing is None: + existing = self._clone_retrieval_result(item) + merged[item.hash_value] = existing + else: + for key, value in dict(item.metadata or {}).items(): + if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): + existing.metadata[key] = value + source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") + + out = list(merged.values()) + for item in out: + meta = item.metadata if isinstance(item.metadata, dict) else {} + semantic_norm = max( + float(vector_norm.get(item.hash_value, 0.0)), + float(sparse_norm.get(item.hash_value, 0.0)), + ) + graph_candidate_type = str(meta.get("graph_candidate_type", "") or "") + graph_score = float(graph_score_map.get(graph_candidate_type, 0.0)) + + if item.hash_value not in support_cache: + cached = meta.get("supporting_paragraph_count") + if cached is None: + support_cache[item.hash_value] = len( + self.metadata_store.get_paragraphs_by_relation(item.hash_value) + ) + else: + support_cache[item.hash_value] = max(0, int(cached)) + supporting_paragraph_count = support_cache[item.hash_value] + evidence_score = min(1.0, supporting_paragraph_count / 3.0) + + meta["supporting_paragraph_count"] = supporting_paragraph_count + meta["graph_seed_entities"] = list(meta.get("graph_seed_entities") or []) + if "graph_hops" in meta: + meta["graph_hops"] = int(meta.get("graph_hops") or 0) + item.score = 0.60 * semantic_norm + 0.30 * graph_score + 0.10 * evidence_score + + sources = source_sets.get(item.hash_value, set()) + if len(sources) > 1: + item.source = "relation_fusion" + elif sources: + item.source = next(iter(sources)) + + out.sort(key=lambda x: x.score, reverse=True) + return out + + async def _retrieve_paragraphs_only( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 仅检索段落(异步方法) + + Args: + query: 查询文本 + top_k: 返回数量 + + Returns: + 检索结果列表 + """ + query_emb = None + embedding_ok = False + vector_results: List[RetrievalResult] = [] + + try: + query_emb = await self.embedding_manager.encode(query) + embedding_ok = self._is_embedding_ready_for_vector_search( + query_emb, + stage="paragraph_only", + ) + except Exception as e: + logger.warning(f"段落检索 
embedding 生成失败,将尝试 sparse 回退: {e}") + + if embedding_ok: + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * 2 * multiplier, temporal) + para_ids, para_scores = self.vector_store.search( + query_emb, # type: ignore[arg-type] + k=candidate_k, + ) + + for hash_value, score in zip(para_ids, para_scores): + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + continue + time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) + vector_results.append( + RetrievalResult( + hash_value=hash_value, + content=paragraph["content"], + score=float(score), + result_type="paragraph", + source="paragraph_search", + metadata={ + "word_count": paragraph.get("word_count", 0), + "time_meta": time_meta, + }, + ) + ) + vector_results = self._apply_temporal_filter_to_paragraphs(vector_results, temporal) + + sparse_results: List[RetrievalResult] = [] + if self._should_use_sparse(embedding_ok, vector_results): + sparse_results = self._search_paragraphs_sparse(query, top_k, temporal=temporal) + + if self.config.fusion.method == "weighted_rrf" and (vector_results and sparse_results): + results = self._fuse_ranked_lists_weighted_rrf(vector_results, sparse_results) + elif vector_results and sparse_results: + results = vector_results + sparse_results + results.sort(key=lambda x: x.score, reverse=True) + else: + results = vector_results if vector_results else sparse_results + + return results[:top_k] + + async def _retrieve_relations_only( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 仅检索关系 (通过实体枢纽 Entity-Pivot) + + 策略: + 1. 检索向量库中的 Top-K 实体 (Entity) + 2. 通过图结构/元数据扩展出与实体关联的关系 (Relation) + 3. 以实体相似度作为基础分返回关系 + + Args: + query: 查询文本 + top_k: 返回数量 + + Returns: + 检索结果列表 + """ + query_emb = None + embedding_ok = False + vector_results: List[RetrievalResult] = [] + try: + query_emb = await self.embedding_manager.encode(query) + embedding_ok = self._is_embedding_ready_for_vector_search( + query_emb, + stage="relation_only", + ) + except Exception as e: + logger.warning(f"关系检索 embedding 生成失败,将尝试 sparse 回退: {e}") + + if embedding_ok: + # 1. 
检索向量 (混合了段落和实体,所以扩大检索范围以召回足够多实体) + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * 3 * multiplier, temporal) + ids, scores = self.vector_store.search( + query_emb, # type: ignore[arg-type] + k=candidate_k, + ) + + seen_relations = set() + for hash_value, score in zip(ids, scores): + entity = self.metadata_store.get_entity(hash_value) + if not entity: + continue + entity_name = entity["name"] + + related_rels = [] + related_rels.extend(self.metadata_store.get_relations(subject=entity_name)) + related_rels.extend(self.metadata_store.get_relations(object=entity_name)) + + for rel in related_rels: + if rel["hash"] in seen_relations: + continue + seen_relations.add(rel["hash"]) + + relation_time_meta = None + if temporal: + relation_time_meta = self._best_supporting_time_meta(rel["hash"], temporal) + if relation_time_meta is None: + continue + + content = f"{rel['subject']} {rel['predicate']} {rel['object']}" + vector_results.append( + RetrievalResult( + hash_value=rel["hash"], + content=content, + score=float(score), + result_type="relation", + source="relation_search (via entity)", + metadata={ + "subject": rel["subject"], + "predicate": rel["predicate"], + "object": rel["object"], + "confidence": rel.get("confidence", 1.0), + "pivot_entity": entity_name, + "time_meta": relation_time_meta, + }, + ) + ) + + vector_results = self._apply_temporal_filter_to_relations(vector_results, temporal) + + sparse_results: List[RetrievalResult] = [] + if self._should_use_sparse_relations(embedding_ok, vector_results): + sparse_results = self._search_relations_sparse(query=query, top_k=top_k, temporal=temporal) + + graph_results = self._search_relations_graph(query=query, temporal=temporal) + if graph_results: + results = self._merge_relation_results_graph_enhanced( + vector_results, + sparse_results, + graph_results, + ) + elif vector_results and sparse_results: + results = self._merge_relation_results(vector_results, sparse_results) + else: + results = vector_results if vector_results else sparse_results + + return results[:top_k] + + async def _retrieve_dual_path( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + relation_intent: Optional[Dict[str, Any]] = None, + ) -> List[RetrievalResult]: + """ + 双路检索(段落+关系)(异步方法) + + Args: + query: 查询文本 + top_k: 返回数量 + + Returns: + 融合后的检索结果列表 + """ + query_emb = None + embedding_ok = False + relation_intent = relation_intent or {} + relation_top_k = max( + 1, + int(relation_intent.get("relation_top_k", self.config.top_k_relations)), + ) + force_relation_sparse = bool(relation_intent.get("force_relation_sparse", False)) + preserve_top_relations = max( + 0, + int(relation_intent.get("preserve_top_relations", 0)), + ) + pair_predicate_rerank_enabled = bool( + relation_intent.get("pair_predicate_rerank_enabled", False) + ) + pair_predicate_limit = max( + 1, + int( + relation_intent.get( + "pair_predicate_limit", + self.config.relation_intent.pair_predicate_limit, + ) + ), + ) + alpha_override = relation_intent.get("alpha_override") + try: + query_emb = await self.embedding_manager.encode(query) + embedding_ok = self._is_embedding_ready_for_vector_search( + query_emb, + stage="dual_path", + ) + except Exception as e: + logger.warning(f"双路检索 embedding 生成失败,将尝试 sparse 回退: {e}") + + para_results: List[RetrievalResult] = [] + rel_results: List[RetrievalResult] = [] + if embedding_ok: + # 并行检索(使用 asyncio) + if self.config.enable_parallel: + para_results, rel_results = 
await self._parallel_retrieve( + query_emb, + temporal=temporal, + relation_top_k=relation_top_k, + ) # type: ignore[arg-type] + else: + para_results, rel_results = self._sequential_retrieve( + query_emb, + temporal=temporal, + relation_top_k=relation_top_k, + ) # type: ignore[arg-type] + else: + logger.warning("embedding 不可用,跳过向量段落/关系召回") + + sparse_para_results: List[RetrievalResult] = [] + if self._should_use_sparse(embedding_ok, para_results): + sparse_para_results = self._search_paragraphs_sparse( + query=query, + top_k=max(top_k * 2, self.config.sparse.candidate_k), + temporal=temporal, + ) + sparse_rel_results: List[RetrievalResult] = [] + if self._should_use_sparse_relations( + embedding_ok, + rel_results, + force_enable=force_relation_sparse, + ): + sparse_rel_results = self._search_relations_sparse( + query=query, + top_k=max( + top_k, + self.config.sparse.relation_candidate_k, + relation_top_k, + ), + temporal=temporal, + ) + + graph_rel_results: List[RetrievalResult] = [] + if bool(relation_intent.get("enabled", False)): + graph_rel_results = self._search_relations_graph(query=query, temporal=temporal) + + if self.config.fusion.method == "weighted_rrf" and para_results and sparse_para_results: + para_results = self._fuse_ranked_lists_weighted_rrf(para_results, sparse_para_results) + elif para_results and sparse_para_results: + para_results = para_results + sparse_para_results + para_results.sort(key=lambda x: x.score, reverse=True) + elif sparse_para_results and (not para_results or not embedding_ok): + para_results = sparse_para_results + + if graph_rel_results: + rel_results = self._merge_relation_results_graph_enhanced( + rel_results, + sparse_rel_results, + graph_rel_results, + ) + elif rel_results and sparse_rel_results: + rel_results = self._merge_relation_results(rel_results, sparse_rel_results) + elif sparse_rel_results and (not rel_results or not embedding_ok): + rel_results = sparse_rel_results + + # 融合结果 + fused_results = self._fuse_results( + para_results, + rel_results, + query_emb, + alpha_override=alpha_override, + preserve_top_relations=preserve_top_relations, + ) + + # PageRank重排序 + if self.config.enable_ppr: + fused_results = await self._rerank_with_ppr( + fused_results, + query, + ) + + if temporal: + fused_results = self._sort_results_with_temporal(fused_results, temporal) + + fused_results = self._apply_relation_intent_pair_rerank( + fused_results, + enabled=bool(relation_intent.get("enabled", False)), + pair_rerank_enabled=pair_predicate_rerank_enabled, + pair_limit=pair_predicate_limit, + ) + + return fused_results[:top_k] + + async def _parallel_retrieve( + self, + query_emb: np.ndarray, + temporal: Optional[TemporalQueryOptions] = None, + relation_top_k: Optional[int] = None, + ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: + """ + 并行检索段落和关系(异步方法) + + Args: + query_emb: 查询嵌入 + + Returns: + (段落结果, 关系结果) + """ + # 使用 asyncio.gather 并发执行两个搜索任务 + # 由于 _search_paragraphs 和 _search_relations 是 CPU 密集型同步函数, + # 使用 asyncio.to_thread 在线程池中执行 + try: + para_task = asyncio.to_thread( + self._search_paragraphs, + query_emb, + self.config.top_k_paragraphs, + temporal, + ) + rel_task = asyncio.to_thread( + self._search_relations, + query_emb, + relation_top_k if relation_top_k is not None else self.config.top_k_relations, + temporal, + ) + + para_results, rel_results = await asyncio.gather( + para_task, rel_task, return_exceptions=True + ) + + # 处理异常 + if isinstance(para_results, Exception): + logger.error(f"段落检索失败: {para_results}") + para_results = [] + if 
isinstance(rel_results, Exception): + logger.error(f"关系检索失败: {rel_results}") + rel_results = [] + + return para_results, rel_results + + except Exception as e: + logger.error(f"并行检索失败: {e}") + return [], [] + + def _sequential_retrieve( + self, + query_emb: np.ndarray, + temporal: Optional[TemporalQueryOptions] = None, + relation_top_k: Optional[int] = None, + ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: + """ + 顺序检索段落和关系 + + Args: + query_emb: 查询嵌入 + + Returns: + (段落结果, 关系结果) + """ + para_results = self._search_paragraphs( + query_emb, + self.config.top_k_paragraphs, + temporal, + ) + + rel_results = self._search_relations( + query_emb, + relation_top_k if relation_top_k is not None else self.config.top_k_relations, + temporal, + ) + + return para_results, rel_results + + def _search_paragraphs( + self, + query_emb: np.ndarray, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 搜索段落 + + Args: + query_emb: 查询嵌入 + top_k: 返回数量 + + Returns: + 段落结果列表 + """ + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) + para_ids, para_scores = self.vector_store.search(query_emb, k=candidate_k) + + results = [] + for hash_value, score in zip(para_ids, para_scores): + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + continue + + time_meta = self._build_time_meta_from_paragraph( + paragraph, + temporal=temporal, + ) + results.append(RetrievalResult( + hash_value=hash_value, + content=paragraph["content"], + score=float(score), + result_type="paragraph", + source="paragraph_search", + metadata={ + "word_count": paragraph.get("word_count", 0), + "time_meta": time_meta, + }, + )) + + return self._apply_temporal_filter_to_paragraphs(results, temporal) + + def _search_relations( + self, + query_emb: np.ndarray, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 搜索关系 + + Args: + query_emb: 查询嵌入 + top_k: 返回数量 + + Returns: + 关系结果列表 + """ + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) + rel_ids, rel_scores = self.vector_store.search(query_emb, k=candidate_k) + + results = [] + for hash_value, score in zip(rel_ids, rel_scores): + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + continue + + relation_time_meta = None + if temporal: + relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) + if relation_time_meta is None: + continue + + content = f"{relation['subject']} {relation['predicate']} {relation['object']}" + + results.append(RetrievalResult( + hash_value=hash_value, + content=content, + score=float(score), + result_type="relation", + source="relation_search", + metadata={ + "subject": relation["subject"], + "predicate": relation["predicate"], + "object": relation["object"], + "confidence": relation.get("confidence", 1.0), + "time_meta": relation_time_meta, + }, + )) + + return self._apply_temporal_filter_to_relations(results, temporal) + + def _fuse_results( + self, + para_results: List[RetrievalResult], + rel_results: List[RetrievalResult], + query_emb: Optional[np.ndarray] = None, + alpha_override: Optional[float] = None, + preserve_top_relations: int = 0, + ) -> List[RetrievalResult]: + """ + 融合段落和关系结果 + + 融合策略: + 1. 计算加权分数 + 2. 去重(基于段落和关系的关联) + 3. 
排序 + + Args: + para_results: 段落结果 + rel_results: 关系结果 + query_emb: 查询嵌入(兼容参数,当前未使用) + + Returns: + 融合后的结果列表 + """ + del query_emb # 参数保留用于兼容 + alpha = float(alpha_override) if alpha_override is not None else self.config.alpha + + # 为段落结果计算加权分数 + for result in para_results: + result.score = result.score * alpha + result.source = "fusion" + + # 为关系结果计算加权分数 + for result in rel_results: + result.score = result.score * (1 - alpha) + result.source = "fusion" + + preserve_top_relations = max(0, int(preserve_top_relations)) + preserved_relation_hashes = set() + if preserve_top_relations > 0 and rel_results: + rel_ranked = sorted(rel_results, key=lambda x: x.score, reverse=True) + preserved_relation_hashes = { + item.hash_value for item in rel_ranked[:preserve_top_relations] + } + + # 合并结果 + all_results = para_results + rel_results + all_results.sort(key=lambda x: x.score, reverse=True) + + # 去重:如果段落有关联的关系,只保留分数更高的 + seen_paragraphs = set() + seen_items = set() + deduplicated_results = [] + + for result in all_results: + if result.hash_value in seen_items: + continue + if result.result_type == "paragraph": + hash_val = result.hash_value + if hash_val not in seen_paragraphs: + seen_paragraphs.add(hash_val) + seen_items.add(hash_val) + deduplicated_results.append(result) + else: # relation + if result.hash_value in preserved_relation_hashes: + seen_items.add(result.hash_value) + deduplicated_results.append(result) + continue + # 检查关系关联的段落是否已存在 + relation = self.metadata_store.get_relation(result.hash_value) + if relation: + # 获取关联的段落 + para_rels = self.metadata_store.query(""" + SELECT paragraph_hash FROM paragraph_relations + WHERE relation_hash = ? + """, (result.hash_value,)) + + if para_rels: + # 检查段落是否已在结果中 + for para_rel in para_rels: + if para_rel["paragraph_hash"] in seen_paragraphs: + # 段落已存在,跳过此关系 + break + else: + # 所有段落都不存在,添加关系 + seen_items.add(result.hash_value) + deduplicated_results.append(result) + else: + # 没有关联段落,直接添加 + seen_items.add(result.hash_value) + deduplicated_results.append(result) + else: + seen_items.add(result.hash_value) + deduplicated_results.append(result) + + # 按分数排序 + deduplicated_results.sort(key=lambda x: x.score, reverse=True) + + return deduplicated_results + + def _apply_relation_intent_pair_rerank( + self, + results: List[RetrievalResult], + *, + enabled: bool, + pair_rerank_enabled: bool, + pair_limit: int, + ) -> List[RetrievalResult]: + """仅在 relation-intent 下对关系项执行同主客体多谓词重排。""" + if not enabled or not pair_rerank_enabled: + return results + return self._rerank_relation_items_by_pair(results, pair_limit=pair_limit) + + def _rerank_relation_items_by_pair( + self, + results: List[RetrievalResult], + pair_limit: int, + ) -> List[RetrievalResult]: + """ + 同主客体多谓词重排: + 1. 关系项按 (subject, object) 分组 + 2. 组内按分数降序 + 原始位置升序 + 3. 组间按组最高分降序 + 组最早位置升序 + 4. 先拼接每组前 N 条,再拼接每组 overflow 条目 + 5. 
回填到原关系槽位,段落槽位不变 + """ + if len(results) <= 1: + return results + + relation_positions: List[int] = [] + relation_items: List[Tuple[int, RetrievalResult]] = [] + for idx, item in enumerate(results): + if item.result_type == "relation": + relation_positions.append(idx) + relation_items.append((idx, item)) + + if len(relation_items) <= 1: + return results + + pair_limit = max(1, int(pair_limit)) + + grouped: Dict[Tuple[str, str], List[Tuple[int, RetrievalResult]]] = {} + for original_idx, item in relation_items: + metadata = item.metadata if isinstance(item.metadata, dict) else {} + subject = str(metadata.get("subject", "")).strip().lower() + obj = str(metadata.get("object", "")).strip().lower() + if subject and obj: + key = (subject, obj) + else: + key = ("__missing__", item.hash_value) + grouped.setdefault(key, []).append((original_idx, item)) + + for grouped_items in grouped.values(): + grouped_items.sort(key=lambda x: (-float(x[1].score), x[0])) + + ordered_groups = sorted( + grouped.values(), + key=lambda grouped_items: ( + -float(grouped_items[0][1].score), + grouped_items[0][0], + ), + ) + + prioritized: List[RetrievalResult] = [] + overflow: List[RetrievalResult] = [] + for grouped_items in ordered_groups: + prioritized.extend([item for _, item in grouped_items[:pair_limit]]) + overflow.extend([item for _, item in grouped_items[pair_limit:]]) + + reordered_relations = prioritized + overflow + if len(reordered_relations) != len(relation_items): + return results + + logger.debug( + "relation_rerank_applied=1 relation_pair_groups=%s relation_pair_overflow_count=%s relation_pair_limit=%s", + len(ordered_groups), + len(overflow), + pair_limit, + ) + + rebuilt = list(results) + for slot_idx, relation_item in zip(relation_positions, reordered_relations): + rebuilt[slot_idx] = relation_item + return rebuilt + + async def _rerank_with_ppr( + self, + results: List[RetrievalResult], + query: str, + ) -> List[RetrievalResult]: + """ + 使用PageRank重排序结果 (异步 + 线程池) + + Args: + results: 检索结果 + query: 查询文本 + + Returns: + 重排序后的结果 + """ + # 从查询中提取实体 + entities = self._extract_entities(query) + + if not entities: + logger.debug("未识别到实体,跳过PPR重排序") + return results + + # 计算PPR分数 (放入线程池运行,避免阻塞主循环) + ppr_timeout_s = max(0.1, float(getattr(self.config, "ppr_timeout_seconds", 1.5) or 1.5)) + try: + async with self._ppr_semaphore: + ppr_scores = await asyncio.wait_for( + asyncio.to_thread( + self._ppr.compute, + personalization=entities, + normalize=True, + ), + timeout=ppr_timeout_s, + ) + except asyncio.TimeoutError: + logger.warning( + "metric.ppr_timeout_skip_count=1 timeout_s=%s entities=%s", + ppr_timeout_s, + len(entities), + ) + return results + except Exception as e: + logger.warning(f"PPR 重排序失败,回退原排序: {e}") + return results + + # 调整结果分数 + ppr_scores_by_name = { + str(name).strip().lower(): float(score) + for name, score in ppr_scores.items() + } + for result in results: + if result.result_type == "paragraph": + # 获取段落的实体 + para_entities = self.metadata_store.get_paragraph_entities( + result.hash_value + ) + + # 计算实体的平均PPR分数 + if para_entities: + entity_scores = [] + for ent in para_entities: + ent_name = str(ent.get("name", "")).strip().lower() + if ent_name in ppr_scores_by_name: + entity_scores.append(ppr_scores_by_name[ent_name]) + + if entity_scores: + avg_ppr = np.mean(entity_scores) + # 融合原始分数和PPR分数 + result.score = result.score * 0.7 + avg_ppr * 0.3 + + # 重新排序 + results.sort(key=lambda x: x.score, reverse=True) + + return results + + def _retrieve_temporal_only( + self, + temporal: 
TemporalQueryOptions, + top_k: int, + ) -> List[RetrievalResult]: + """无语义 query 时,直接走时序索引查询。""" + limit = self._cap_temporal_scan_k( + top_k * max(1, temporal.candidate_multiplier), + temporal, + ) + paragraphs = self.metadata_store.query_paragraphs_temporal( + start_ts=temporal.time_from, + end_ts=temporal.time_to, + person=temporal.person, + source=temporal.source, + limit=limit, + allow_created_fallback=temporal.allow_created_fallback, + ) + results: List[RetrievalResult] = [] + for para in paragraphs: + time_meta = self._build_time_meta_from_paragraph(para, temporal=temporal) + results.append( + RetrievalResult( + hash_value=para["hash"], + content=para["content"], + score=1.0, + result_type="paragraph", + source="temporal_scan", + metadata={ + "word_count": para.get("word_count", 0), + "time_meta": time_meta, + }, + ) + ) + + results = self._sort_results_with_temporal(results, temporal) + return results[:top_k] + + def _extract_effective_time( + self, + paragraph: Dict[str, Any], + temporal: Optional[TemporalQueryOptions] = None, + ) -> Tuple[Optional[float], Optional[float], Optional[str]]: + """提取段落有效时间区间与命中依据。""" + event_time = paragraph.get("event_time") + event_start = paragraph.get("event_time_start") + event_end = paragraph.get("event_time_end") + + if event_start is not None or event_end is not None: + effective_start = event_start if event_start is not None else ( + event_time if event_time is not None else event_end + ) + effective_end = event_end if event_end is not None else ( + event_time if event_time is not None else event_start + ) + return effective_start, effective_end, "event_time_range" + + if event_time is not None: + return event_time, event_time, "event_time" + + allow_fallback = True + if temporal is not None: + allow_fallback = temporal.allow_created_fallback + + created_at = paragraph.get("created_at") + if allow_fallback and created_at is not None: + return created_at, created_at, "created_at_fallback" + + return None, None, None + + def _build_time_meta_from_paragraph( + self, + paragraph: Dict[str, Any], + temporal: Optional[TemporalQueryOptions] = None, + ) -> Dict[str, Any]: + """构建统一 time_meta 结构。""" + effective_start, effective_end, match_basis = self._extract_effective_time( + paragraph, + temporal=temporal, + ) + return { + "event_time": paragraph.get("event_time"), + "event_time_start": paragraph.get("event_time_start"), + "event_time_end": paragraph.get("event_time_end"), + "ingest_time": paragraph.get("created_at"), + "time_granularity": paragraph.get("time_granularity"), + "time_confidence": paragraph.get("time_confidence", 1.0), + "effective_start": effective_start, + "effective_end": effective_end, + "effective_start_text": format_timestamp(effective_start), + "effective_end_text": format_timestamp(effective_end), + "match_basis": match_basis or "none", + } + + def _matches_person_filter(self, paragraph_hash: str, person: Optional[str]) -> bool: + if not person: + return True + target = person.strip().lower() + if not target: + return True + para_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) + for ent in para_entities: + name = str(ent.get("name", "")).strip().lower() + if target in name: + return True + return False + + def _is_temporal_match( + self, + paragraph: Dict[str, Any], + temporal: TemporalQueryOptions, + ) -> bool: + """判断段落是否命中时序筛选。""" + if temporal.source and paragraph.get("source") != temporal.source: + return False + + if not self._matches_person_filter(paragraph.get("hash", ""), temporal.person): + 
return False + + effective_start, effective_end, _ = self._extract_effective_time(paragraph, temporal=temporal) + if effective_start is None or effective_end is None: + return False + + if temporal.time_from is not None and temporal.time_to is not None: + return effective_end >= temporal.time_from and effective_start <= temporal.time_to + if temporal.time_from is not None: + return effective_end >= temporal.time_from + if temporal.time_to is not None: + return effective_start <= temporal.time_to + return True + + def _apply_temporal_filter_to_paragraphs( + self, + results: List[RetrievalResult], + temporal: Optional[TemporalQueryOptions], + ) -> List[RetrievalResult]: + if not temporal: + return results + + filtered: List[RetrievalResult] = [] + for result in results: + paragraph = self.metadata_store.get_paragraph(result.hash_value) + if not paragraph: + continue + if not self._is_temporal_match(paragraph, temporal): + continue + result.metadata["time_meta"] = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) + filtered.append(result) + + return self._sort_results_with_temporal(filtered, temporal) + + def _best_supporting_time_meta( + self, + relation_hash: str, + temporal: TemporalQueryOptions, + ) -> Optional[Dict[str, Any]]: + """获取关系在时序窗口内最优支撑段落的 time_meta。""" + supports = self.metadata_store.get_paragraphs_by_relation(relation_hash) + if not supports: + return None + + best_meta: Optional[Dict[str, Any]] = None + best_time = float("-inf") + for para in supports: + if not self._is_temporal_match(para, temporal): + continue + meta = self._build_time_meta_from_paragraph(para, temporal=temporal) + eff = meta.get("effective_end") + score = float(eff) if eff is not None else float("-inf") + if score >= best_time: + best_time = score + best_meta = meta + + return best_meta + + def _apply_temporal_filter_to_relations( + self, + results: List[RetrievalResult], + temporal: Optional[TemporalQueryOptions], + ) -> List[RetrievalResult]: + if not temporal: + return results + + filtered: List[RetrievalResult] = [] + for result in results: + meta = result.metadata.get("time_meta") + if meta is None: + meta = self._best_supporting_time_meta(result.hash_value, temporal) + if meta is None: + continue + result.metadata["time_meta"] = meta + filtered.append(result) + + return self._sort_results_with_temporal(filtered, temporal) + + def _sort_results_with_temporal( + self, + results: List[RetrievalResult], + temporal: TemporalQueryOptions, + ) -> List[RetrievalResult]: + """语义优先,时间次排序(新到旧)。""" + del temporal # temporal 保留给未来扩展,目前只使用结果内 time_meta + + def _temporal_key(item: RetrievalResult) -> float: + time_meta = item.metadata.get("time_meta", {}) + effective = time_meta.get("effective_end") + if effective is None: + effective = time_meta.get("effective_start") + if effective is None: + return float("-inf") + return float(effective) + + results.sort(key=lambda x: (x.score, _temporal_key(x)), reverse=True) + return results + + def _extract_entities(self, text: str) -> Dict[str, float]: + """ + 从文本中提取实体(简化版本) + + Args: + text: 输入文本 + + Returns: + 实体字典 {实体名: 权重} + """ + # 获取所有实体 + all_entities = self.graph_store.get_nodes() + if not all_entities: + return {} + + # 检查是否需要更新 Aho-Corasick 匹配器 + if self._ac_matcher is None or self._ac_nodes_count != len(all_entities): + self._ac_matcher = AhoCorasick() + for entity in all_entities: + self._ac_matcher.add_pattern(entity.lower()) + self._ac_matcher.build() + self._ac_nodes_count = len(all_entities) + + # 执行匹配 + text_lower = text.lower() + stats = 
self._ac_matcher.find_all(text_lower) + + # 映射回原始名称并使用出现次数作为权重 + node_map = {node.lower(): node for node in all_entities} + entities = {node_map[low_name]: float(count) for low_name, count in stats.items()} + + return entities + + def get_statistics(self) -> Dict[str, Any]: + """ + 获取检索统计信息 + + Returns: + 统计信息字典 + """ + vector_size = getattr(self.vector_store, "size", None) + if vector_size is None: + vector_size = getattr(self.vector_store, "num_vectors", 0) + + return { + "config": { + "top_k_paragraphs": self.config.top_k_paragraphs, + "top_k_relations": self.config.top_k_relations, + "top_k_final": self.config.top_k_final, + "alpha": self.config.alpha, + "enable_ppr": self.config.enable_ppr, + "enable_parallel": self.config.enable_parallel, + "strategy": self.config.retrieval_strategy.value, + "sparse_mode": self.config.sparse.mode, + "fusion_method": self.config.fusion.method, + "relation_intent_enabled": self.config.relation_intent.enabled, + "relation_intent_alpha_override": self.config.relation_intent.alpha_override, + "relation_intent_candidate_multiplier": self.config.relation_intent.relation_candidate_multiplier, + "relation_intent_preserve_top_relations": self.config.relation_intent.preserve_top_relations, + "relation_intent_force_sparse": self.config.relation_intent.force_relation_sparse, + "relation_intent_pair_rerank_enabled": self.config.relation_intent.pair_predicate_rerank_enabled, + "relation_intent_pair_predicate_limit": self.config.relation_intent.pair_predicate_limit, + "graph_recall_enabled": self.config.graph_recall.enabled, + "graph_recall_candidate_k": self.config.graph_recall.candidate_k, + "graph_recall_allow_two_hop_pair": self.config.graph_recall.allow_two_hop_pair, + "graph_recall_max_paths": self.config.graph_recall.max_paths, + }, + "vector_store": { + "size": int(vector_size), + }, + "graph_store": { + "num_nodes": self.graph_store.num_nodes, + "num_edges": self.graph_store.num_edges, + }, + "metadata_store": self.metadata_store.get_statistics(), + "sparse": self.sparse_index.stats() if self.sparse_index else None, + } + + def __repr__(self) -> str: + return ( + f"DualPathRetriever(" + f"strategy={self.config.retrieval_strategy.value}, " + f"para_k={self.config.top_k_paragraphs}, " + f"rel_k={self.config.top_k_relations})" + ) diff --git a/plugins/A_memorix/core/retrieval/graph_relation_recall.py b/plugins/A_memorix/core/retrieval/graph_relation_recall.py new file mode 100644 index 00000000..3ce03b14 --- /dev/null +++ b/plugins/A_memorix/core/retrieval/graph_relation_recall.py @@ -0,0 +1,272 @@ +"""Graph-assisted relation candidate recall for relation-oriented queries.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Set + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.GraphRelationRecall") + + +@dataclass +class GraphRelationRecallConfig: + """Configuration for controlled graph relation recall.""" + + enabled: bool = True + candidate_k: int = 24 + max_hop: int = 1 + allow_two_hop_pair: bool = True + max_paths: int = 4 + + def __post_init__(self) -> None: + self.enabled = bool(self.enabled) + self.candidate_k = max(1, int(self.candidate_k)) + self.max_hop = max(1, int(self.max_hop)) + self.allow_two_hop_pair = bool(self.allow_two_hop_pair) + self.max_paths = max(1, int(self.max_paths)) + + +@dataclass +class GraphRelationCandidate: + """A graph-derived relation candidate before retriever-side fusion.""" + + hash_value: str + subject: str + 
predicate: str + object: str + confidence: float + graph_seed_entities: List[str] + graph_hops: int + graph_candidate_type: str + supporting_paragraph_count: int + + def to_payload(self) -> Dict[str, Any]: + content = f"{self.subject} {self.predicate} {self.object}" + return { + "hash": self.hash_value, + "content": content, + "subject": self.subject, + "predicate": self.predicate, + "object": self.object, + "confidence": self.confidence, + "graph_seed_entities": list(self.graph_seed_entities), + "graph_hops": int(self.graph_hops), + "graph_candidate_type": self.graph_candidate_type, + "supporting_paragraph_count": int(self.supporting_paragraph_count), + } + + +class GraphRelationRecallService: + """Collect relation candidates from the entity graph in a controlled way.""" + + def __init__( + self, + *, + graph_store: Any, + metadata_store: Any, + config: Optional[GraphRelationRecallConfig] = None, + ) -> None: + self.graph_store = graph_store + self.metadata_store = metadata_store + self.config = config or GraphRelationRecallConfig() + + def recall( + self, + *, + seed_entities: Sequence[str], + ) -> List[GraphRelationCandidate]: + if not self.config.enabled: + return [] + if self.graph_store is None or self.metadata_store is None: + return [] + + seeds = self._normalize_seed_entities(seed_entities) + if not seeds: + return [] + + seen_hashes: Set[str] = set() + candidates: List[GraphRelationCandidate] = [] + + if len(seeds) >= 2: + self._collect_direct_pair_candidates( + seed_a=seeds[0], + seed_b=seeds[1], + seen_hashes=seen_hashes, + out=candidates, + ) + if ( + len(candidates) < 3 + and self.config.allow_two_hop_pair + and len(candidates) < self.config.candidate_k + ): + self._collect_two_hop_pair_candidates( + seed_a=seeds[0], + seed_b=seeds[1], + seen_hashes=seen_hashes, + out=candidates, + ) + else: + self._collect_one_hop_seed_candidates( + seed=seeds[0], + seen_hashes=seen_hashes, + out=candidates, + ) + + return candidates[: self.config.candidate_k] + + def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]: + out: List[str] = [] + seen = set() + for raw in list(seed_entities)[:2]: + resolved = None + try: + resolved = self.graph_store.find_node(str(raw), ignore_case=True) + except Exception: + resolved = None + if not resolved: + continue + canon = str(resolved).strip().lower() + if not canon or canon in seen: + continue + seen.add(canon) + out.append(str(resolved)) + return out + + def _collect_direct_pair_candidates( + self, + *, + seed_a: str, + seed_b: str, + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + ) -> None: + relation_hashes = [] + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b)) + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a)) + self._append_relation_hashes( + relation_hashes=relation_hashes, + seen_hashes=seen_hashes, + out=out, + candidate_type="direct_pair", + graph_hops=1, + graph_seed_entities=[seed_a, seed_b], + ) + + def _collect_two_hop_pair_candidates( + self, + *, + seed_a: str, + seed_b: str, + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + ) -> None: + try: + paths = self.graph_store.find_paths( + seed_a, + seed_b, + max_depth=2, + max_paths=self.config.max_paths, + ) + except Exception as e: + logger.debug("graph two-hop recall skipped: %s", e) + return + + for path_nodes in paths: + if len(out) >= self.config.candidate_k: + break + if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3: + continue + if len(path_nodes) 
!= 3: + continue + for idx in range(len(path_nodes) - 1): + if len(out) >= self.config.candidate_k: + break + u = str(path_nodes[idx]) + v = str(path_nodes[idx + 1]) + relation_hashes = [] + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v)) + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u)) + self._append_relation_hashes( + relation_hashes=relation_hashes, + seen_hashes=seen_hashes, + out=out, + candidate_type="two_hop_pair", + graph_hops=2, + graph_seed_entities=[seed_a, seed_b], + ) + + def _collect_one_hop_seed_candidates( + self, + *, + seed: str, + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + ) -> None: + try: + relation_hashes = self.graph_store.get_incident_relation_hashes( + seed, + limit=self.config.candidate_k, + ) + except Exception as e: + logger.debug("graph one-hop recall skipped: %s", e) + return + self._append_relation_hashes( + relation_hashes=relation_hashes, + seen_hashes=seen_hashes, + out=out, + candidate_type="one_hop_seed", + graph_hops=min(1, self.config.max_hop), + graph_seed_entities=[seed], + ) + + def _append_relation_hashes( + self, + *, + relation_hashes: Sequence[str], + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + candidate_type: str, + graph_hops: int, + graph_seed_entities: Sequence[str], + ) -> None: + for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}): + if len(out) >= self.config.candidate_k: + break + if relation_hash in seen_hashes: + continue + candidate = self._build_candidate( + relation_hash=relation_hash, + candidate_type=candidate_type, + graph_hops=graph_hops, + graph_seed_entities=graph_seed_entities, + ) + if candidate is None: + continue + seen_hashes.add(relation_hash) + out.append(candidate) + + def _build_candidate( + self, + *, + relation_hash: str, + candidate_type: str, + graph_hops: int, + graph_seed_entities: Sequence[str], + ) -> Optional[GraphRelationCandidate]: + relation = self.metadata_store.get_relation(relation_hash) + if relation is None: + return None + supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash) + return GraphRelationCandidate( + hash_value=relation_hash, + subject=str(relation.get("subject", "")), + predicate=str(relation.get("predicate", "")), + object=str(relation.get("object", "")), + confidence=float(relation.get("confidence", 1.0) or 1.0), + graph_seed_entities=[str(x) for x in graph_seed_entities], + graph_hops=int(graph_hops), + graph_candidate_type=str(candidate_type), + supporting_paragraph_count=len(supporting_paragraphs), + ) diff --git a/plugins/A_memorix/core/retrieval/pagerank.py b/plugins/A_memorix/core/retrieval/pagerank.py new file mode 100644 index 00000000..c8ee48bb --- /dev/null +++ b/plugins/A_memorix/core/retrieval/pagerank.py @@ -0,0 +1,482 @@ +""" +Personalized PageRank实现 + +提供个性化的图节点排序功能。 +""" + +from typing import Dict, List, Optional, Tuple, Union, Any +from dataclasses import dataclass +import numpy as np + +from src.common.logger import get_logger +from ..storage import GraphStore +from ..utils.matcher import AhoCorasick + +logger = get_logger("A_Memorix.PersonalizedPageRank") + + +@dataclass +class PageRankConfig: + """ + PageRank配置 + + 属性: + alpha: 阻尼系数(0-1之间) + max_iter: 最大迭代次数 + tol: 收敛阈值 + normalize: 是否归一化结果 + min_iterations: 最小迭代次数 + """ + + alpha: float = 0.85 + max_iter: int = 100 + tol: float = 1e-6 + normalize: bool = True + min_iterations: int = 20 + + def __post_init__(self): + """验证配置""" + if not 0 <= self.alpha < 1: + raise 
ValueError(f"alpha必须在[0, 1)之间: {self.alpha}") + + if self.max_iter <= 0: + raise ValueError(f"max_iter必须大于0: {self.max_iter}") + + if self.tol <= 0: + raise ValueError(f"tol必须大于0: {self.tol}") + + if self.min_iterations < 0: + raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}") + + if self.min_iterations >= self.max_iter: + raise ValueError(f"min_iterations必须小于max_iter") + + +class PersonalizedPageRank: + """ + Personalized PageRank计算器 + + 功能: + - 个性化向量支持 + - 快速收敛检测 + - 结果归一化 + - 批量计算 + - 统计信息 + + 参数: + graph_store: 图存储 + config: PageRank配置 + """ + + def __init__( + self, + graph_store: GraphStore, + config: Optional[PageRankConfig] = None, + ): + """ + 初始化PPR计算器 + + Args: + graph_store: 图存储 + config: PageRank配置 + """ + self.graph_store = graph_store + self.config = config or PageRankConfig() + + # 统计信息 + self._total_computations = 0 + self._total_iterations = 0 + self._convergence_history: List[int] = [] + + logger.info( + f"PersonalizedPageRank 初始化: " + f"alpha={self.config.alpha}, " + f"max_iter={self.config.max_iter}" + ) + + # 缓存 Aho-Corasick 匹配器 + self._ac_matcher: Optional[AhoCorasick] = None + self._ac_nodes_count = 0 + + def compute( + self, + personalization: Optional[Dict[str, float]] = None, + alpha: Optional[float] = None, + max_iter: Optional[int] = None, + normalize: Optional[bool] = None, + ) -> Dict[str, float]: + """ + 计算Personalized PageRank + + Args: + personalization: 个性化向量 {节点名: 权重} + alpha: 阻尼系数(覆盖配置值) + max_iter: 最大迭代次数(覆盖配置值) + normalize: 是否归一化(覆盖配置值) + + Returns: + 节点PageRank值字典 {节点名: 分数} + """ + # 使用覆盖值或配置值 + alpha = alpha if alpha is not None else self.config.alpha + max_iter = max_iter if max_iter is not None else self.config.max_iter + normalize = normalize if normalize is not None else self.config.normalize + + # 调用GraphStore的compute_pagerank + scores = self.graph_store.compute_pagerank( + personalization=personalization, + alpha=alpha, + max_iter=max_iter, + tol=self.config.tol, + ) + + # 归一化(如果需要) + if normalize and scores: + total = sum(scores.values()) + if total > 0: + scores = {node: score / total for node, score in scores.items()} + + # 更新统计 + self._total_computations += 1 + + logger.debug( + f"PPR计算完成: {len(scores)} 个节点, " + f"personalization_nodes={len(personalization) if personalization else 0}" + ) + + return scores + + def compute_batch( + self, + personalization_list: List[Dict[str, float]], + normalize: bool = True, + ) -> List[Dict[str, float]]: + """ + 批量计算PPR + + Args: + personalization_list: 个性化向量列表 + normalize: 是否归一化 + + Returns: + PageRank值字典列表 + """ + results = [] + + for i, personalization in enumerate(personalization_list): + logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR") + scores = self.compute(personalization=personalization, normalize=normalize) + results.append(scores) + + return results + + def compute_for_entities( + self, + entities: List[str], + weights: Optional[List[float]] = None, + normalize: bool = True, + ) -> Dict[str, float]: + """ + 为实体列表计算PPR + + Args: + entities: 实体列表 + weights: 权重列表(默认均匀权重) + normalize: 是否归一化 + + Returns: + PageRank值字典 + """ + if not entities: + logger.warning("实体列表为空,返回均匀PPR") + return self.compute(personalization=None, normalize=normalize) + + # 构建个性化向量 + if weights is None: + weights = [1.0] * len(entities) + + if len(weights) != len(entities): + raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}") + + personalization = {entity: weight for entity, weight in zip(entities, weights)} + + return self.compute(personalization=personalization, 
normalize=normalize) + + def compute_for_query( + self, + query: str, + entity_extractor: Optional[callable] = None, + normalize: bool = True, + ) -> Dict[str, float]: + """ + 为查询计算PPR + + Args: + query: 查询文本 + entity_extractor: 实体提取函数(可选) + normalize: 是否归一化 + + Returns: + PageRank值字典 + """ + # 提取实体 + if entity_extractor is not None: + entities = entity_extractor(query) + else: + # 简单实现:基于图中的节点匹配 + entities = self._extract_entities_from_query(query) + + if not entities: + logger.debug(f"未从查询中提取到实体: '{query}'") + return self.compute(personalization=None, normalize=normalize) + + # 计算PPR + return self.compute_for_entities(entities, normalize=normalize) + + def rank_nodes( + self, + scores: Dict[str, float], + top_k: Optional[int] = None, + min_score: float = 0.0, + ) -> List[Tuple[str, float]]: + """ + 对节点排序 + + Args: + scores: PageRank分数字典 + top_k: 返回前k个节点(None表示全部) + min_score: 最小分数阈值 + + Returns: + 排序后的节点列表 [(节点名, 分数), ...] + """ + # 过滤低分节点 + filtered = [(node, score) for node, score in scores.items() if score >= min_score] + + # 按分数降序排序 + sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True) + + # 返回top_k + if top_k is not None: + sorted_nodes = sorted_nodes[:top_k] + + return sorted_nodes + + def get_personalization_vector( + self, + nodes: List[str], + method: str = "uniform", + ) -> Dict[str, float]: + """ + 生成个性化向量 + + Args: + nodes: 节点列表 + method: 生成方法 + - "uniform": 均匀权重 + - "degree": 按度数加权 + - "inverse_degree": 按度数反比加权 + + Returns: + 个性化向量 {节点名: 权重} + """ + if not nodes: + return {} + + if method == "uniform": + # 均匀权重 + weight = 1.0 / len(nodes) + return {node: weight for node in nodes} + + elif method == "degree": + # 按度数加权 + node_degrees = {} + for node in nodes: + neighbors = self.graph_store.get_neighbors(node) + node_degrees[node] = len(neighbors) + + total_degree = sum(node_degrees.values()) + if total_degree > 0: + return {node: degree / total_degree for node, degree in node_degrees.items()} + else: + return {node: 1.0 / len(nodes) for node in nodes} + + elif method == "inverse_degree": + # 按度数反比加权 + node_degrees = {} + for node in nodes: + neighbors = self.graph_store.get_neighbors(node) + node_degrees[node] = len(neighbors) + + # 反度数 + inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()} + total_inv = sum(inv_degrees.values()) + + if total_inv > 0: + return {node: inv / total_inv for node, inv in inv_degrees.items()} + else: + return {node: 1.0 / len(nodes) for node in nodes} + + else: + raise ValueError(f"不支持的个性化向量生成方法: {method}") + + def compare_scores( + self, + scores1: Dict[str, float], + scores2: Dict[str, float], + ) -> Dict[str, Dict[str, float]]: + """ + 比较两组PPR分数 + + Args: + scores1: 第一组分数 + scores2: 第二组分数 + + Returns: + 比较结果 { + "common_nodes": {节点: (score1, score2)}, + "only_in_1": {节点: score1}, + "only_in_2": {节点: score2}, + } + """ + common_nodes = {} + only_in_1 = {} + only_in_2 = {} + + all_nodes = set(scores1.keys()) | set(scores2.keys()) + + for node in all_nodes: + if node in scores1 and node in scores2: + common_nodes[node] = (scores1[node], scores2[node]) + elif node in scores1: + only_in_1[node] = scores1[node] + else: + only_in_2[node] = scores2[node] + + return { + "common_nodes": common_nodes, + "only_in_1": only_in_1, + "only_in_2": only_in_2, + } + + def get_statistics(self) -> Dict[str, Any]: + """ + 获取统计信息 + + Returns: + 统计信息字典 + """ + avg_iterations = ( + self._total_iterations / self._total_computations + if self._total_computations > 0 + else 0 + ) + + return { + "config": { + "alpha": 
self.config.alpha, + "max_iter": self.config.max_iter, + "tol": self.config.tol, + "normalize": self.config.normalize, + "min_iterations": self.config.min_iterations, + }, + "statistics": { + "total_computations": self._total_computations, + "total_iterations": self._total_iterations, + "avg_iterations": avg_iterations, + "convergence_history": self._convergence_history.copy(), + }, + "graph": { + "num_nodes": self.graph_store.num_nodes, + "num_edges": self.graph_store.num_edges, + }, + } + + def reset_statistics(self) -> None: + """重置统计信息""" + self._total_computations = 0 + self._total_iterations = 0 + self._convergence_history.clear() + logger.info("统计信息已重置") + + def _extract_entities_from_query(self, query: str) -> List[str]: + """ + 从查询中提取实体(简化实现) + + Args: + query: 查询文本 + + Returns: + 实体列表 + """ + # 获取所有节点 + all_nodes = self.graph_store.get_nodes() + if not all_nodes: + return [] + + # 检查是否需要更新 Aho-Corasick 匹配器 + if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes): + self._ac_matcher = AhoCorasick() + for node in all_nodes: + # 统一转为小写进行不区分大小写匹配 + self._ac_matcher.add_pattern(node.lower()) + self._ac_matcher.build() + self._ac_nodes_count = len(all_nodes) + + # 执行匹配 + query_lower = query.lower() + stats = self._ac_matcher.find_all(query_lower) + + # 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值) + # 为了简单,AC 中 add_pattern 存的是小写 + # 我们需要一个映射:小写 -> 原始 + node_map = {node.lower(): node for node in all_nodes} + entities = [node_map[low_name] for low_name in stats.keys()] + + return entities + + @property + def num_computations(self) -> int: + """计算次数""" + return self._total_computations + + @property + def avg_iterations(self) -> float: + """平均迭代次数""" + if self._total_computations == 0: + return 0.0 + return self._total_iterations / self._total_computations + + def __repr__(self) -> str: + return ( + f"PersonalizedPageRank(" + f"alpha={self.config.alpha}, " + f"computations={self._total_computations})" + ) + + +def create_ppr_from_graph( + graph_store: GraphStore, + alpha: float = 0.85, + max_iter: int = 100, +) -> PersonalizedPageRank: + """ + 从图存储创建PPR计算器 + + Args: + graph_store: 图存储 + alpha: 阻尼系数 + max_iter: 最大迭代次数 + + Returns: + PPR计算器实例 + """ + config = PageRankConfig( + alpha=alpha, + max_iter=max_iter, + ) + + return PersonalizedPageRank( + graph_store=graph_store, + config=config, + ) diff --git a/plugins/A_memorix/core/retrieval/sparse_bm25.py b/plugins/A_memorix/core/retrieval/sparse_bm25.py new file mode 100644 index 00000000..3b6f075d --- /dev/null +++ b/plugins/A_memorix/core/retrieval/sparse_bm25.py @@ -0,0 +1,402 @@ +""" +稀疏检索组件(FTS5 + BM25) + +支持: +- 懒加载索引连接 +- jieba / char n-gram 分词 +- 可卸载并收缩 SQLite 内存缓存 +""" + +from __future__ import annotations + +import re +import sqlite3 +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger +from ..storage import MetadataStore + +logger = get_logger("A_Memorix.SparseBM25") + +try: + import jieba # type: ignore + + HAS_JIEBA = True +except Exception: + HAS_JIEBA = False + jieba = None + + +@dataclass +class SparseBM25Config: + """BM25 稀疏检索配置。""" + + enabled: bool = True + backend: str = "fts5" + lazy_load: bool = True + mode: str = "auto" # auto | fallback_only | hybrid + tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram + jieba_user_dict: str = "" + char_ngram_n: int = 2 + candidate_k: int = 80 + max_doc_len: int = 2000 + enable_ngram_fallback_index: bool = True + enable_like_fallback: bool = False + enable_relation_sparse_fallback: bool = True + 
relation_candidate_k: int = 60 + relation_max_doc_len: int = 512 + unload_on_disable: bool = True + shrink_memory_on_unload: bool = True + + def __post_init__(self) -> None: + self.backend = str(self.backend or "fts5").strip().lower() + self.mode = str(self.mode or "auto").strip().lower() + self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower() + self.char_ngram_n = max(1, int(self.char_ngram_n)) + self.candidate_k = max(1, int(self.candidate_k)) + self.max_doc_len = max(0, int(self.max_doc_len)) + self.relation_candidate_k = max(1, int(self.relation_candidate_k)) + self.relation_max_doc_len = max(0, int(self.relation_max_doc_len)) + if self.backend != "fts5": + raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}") + if self.mode not in {"auto", "fallback_only", "hybrid"}: + raise ValueError(f"sparse.mode 非法: {self.mode}") + if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}: + raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}") + + +class SparseBM25Index: + """ + 基于 SQLite FTS5 的 BM25 检索适配层。 + """ + + def __init__( + self, + metadata_store: MetadataStore, + config: Optional[SparseBM25Config] = None, + ): + self.metadata_store = metadata_store + self.config = config or SparseBM25Config() + self._conn: Optional[sqlite3.Connection] = None + self._loaded: bool = False + self._jieba_dict_loaded: bool = False + + @property + def loaded(self) -> bool: + return self._loaded and self._conn is not None + + def ensure_loaded(self) -> bool: + """按需加载 FTS 连接与索引。""" + if not self.config.enabled: + return False + if self.loaded: + return True + + db_path = self.metadata_store.get_db_path() + conn = sqlite3.connect( + str(db_path), + check_same_thread=False, + timeout=30.0, + ) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute("PRAGMA temp_store=MEMORY") + + if not self.metadata_store.ensure_fts_schema(conn=conn): + conn.close() + return False + self.metadata_store.ensure_fts_backfilled(conn=conn) + # 关系稀疏检索按独立开关加载,避免不必要的初始化开销。 + if self.config.enable_relation_sparse_fallback: + self.metadata_store.ensure_relations_fts_schema(conn=conn) + self.metadata_store.ensure_relations_fts_backfilled(conn=conn) + if self.config.enable_ngram_fallback_index: + self.metadata_store.ensure_paragraph_ngram_schema(conn=conn) + self.metadata_store.ensure_paragraph_ngram_backfilled( + n=self.config.char_ngram_n, + conn=conn, + ) + + self._conn = conn + self._loaded = True + self._prepare_tokenizer() + logger.info( + "SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s", + self.config.tokenizer_mode, + self.config.mode, + ) + return True + + def _prepare_tokenizer(self) -> None: + if self._jieba_dict_loaded: + return + if self.config.tokenizer_mode not in {"jieba", "mixed"}: + return + if not HAS_JIEBA: + logger.warning("jieba 不可用,tokenizer 将退化为 char n-gram") + return + user_dict = str(self.config.jieba_user_dict or "").strip() + if user_dict: + try: + jieba.load_userdict(user_dict) # type: ignore[union-attr] + logger.info("已加载 jieba 用户词典: %s", user_dict) + except Exception as e: + logger.warning("加载 jieba 用户词典失败: %s", e) + self._jieba_dict_loaded = True + + def _tokenize_jieba(self, text: str) -> List[str]: + if not HAS_JIEBA: + return [] + try: + tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr] + return [t.strip().lower() for t in tokens if t and t.strip()] + except Exception: + return [] + + def _tokenize_char_ngram(self, text: str, n: int) -> List[str]: + 
compact = re.sub(r"\s+", "", text.lower()) + if not compact: + return [] + if len(compact) < n: + return [compact] + return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] + + def _tokenize(self, text: str) -> List[str]: + text = str(text or "").strip() + if not text: + return [] + + mode = self.config.tokenizer_mode + if mode == "jieba": + tokens = self._tokenize_jieba(text) + if tokens: + return list(dict.fromkeys(tokens)) + return self._tokenize_char_ngram(text, self.config.char_ngram_n) + + if mode == "mixed": + toks = self._tokenize_jieba(text) + toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n)) + return list(dict.fromkeys([t for t in toks if t])) + + return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n))) + + def _build_match_query(self, tokens: List[str]) -> str: + safe_tokens: List[str] = [] + for token in tokens: + t = token.replace('"', '""').strip() + if not t: + continue + safe_tokens.append(f'"{t}"') + if not safe_tokens: + return "" + # 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。 + return " OR ".join(safe_tokens[:64]) + + def _fallback_substring_search( + self, + tokens: List[str], + limit: int, + ) -> List[Dict[str, Any]]: + """ + 当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。 + + 说明: + - FTS 索引当前采用 unicode61 tokenizer。 + - 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。 + - 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。 + """ + if not tokens: + return [] + + # 去重并裁剪 token 数量,避免生成超长 SQL。 + uniq_tokens = [t for t in dict.fromkeys(tokens) if t] + uniq_tokens = uniq_tokens[:32] + if not uniq_tokens: + return [] + + if self.config.enable_ngram_fallback_index: + try: + # 允许运行时切换开关后按需补齐 schema/回填。 + self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn) + self.metadata_store.ensure_paragraph_ngram_backfilled( + n=self.config.char_ngram_n, + conn=self._conn, + ) + rows = self.metadata_store.ngram_search_paragraphs( + tokens=uniq_tokens, + limit=limit, + max_doc_len=self.config.max_doc_len, + conn=self._conn, + ) + if rows: + return rows + except Exception as e: + logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}") + + if not self.config.enable_like_fallback: + return [] + + conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens)) + params: List[Any] = [f"%{tok}%" for tok in uniq_tokens] + scan_limit = max(int(limit) * 8, 200) + params.append(scan_limit) + + sql = f""" + SELECT p.hash, p.content + FROM paragraphs p + WHERE (p.is_deleted IS NULL OR p.is_deleted = 0) + AND ({conditions}) + LIMIT ? 
+ """ + rows = self.metadata_store.query(sql, tuple(params)) + if not rows: + return [] + + scored: List[Dict[str, Any]] = [] + token_count = max(1, len(uniq_tokens)) + for row in rows: + content = str(row.get("content") or "") + content_low = content.lower() + matched = [tok for tok in uniq_tokens if tok in content_low] + if not matched: + continue + coverage = len(matched) / token_count + length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low)) + # 兜底路径使用相对分,保持与上层接口兼容。 + fallback_score = coverage * 0.8 + length_bonus * 0.2 + scored.append( + { + "hash": row["hash"], + "content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content, + "bm25_score": -float(fallback_score), + "fallback_score": float(fallback_score), + } + ) + + scored.sort(key=lambda x: x["fallback_score"], reverse=True) + return scored[:limit] + + def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]: + """执行 BM25 检索。""" + if not self.config.enabled: + return [] + if self.config.lazy_load and not self.loaded: + if not self.ensure_loaded(): + return [] + if not self.loaded: + return [] + # 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。 + self.metadata_store.ensure_relations_fts_schema(conn=self._conn) + self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn) + + tokens = self._tokenize(query) + match_query = self._build_match_query(tokens) + if not match_query: + return [] + + limit = max(1, int(k)) + rows = self.metadata_store.fts_search_bm25( + match_query=match_query, + limit=limit, + max_doc_len=self.config.max_doc_len, + conn=self._conn, + ) + if not rows: + rows = self._fallback_substring_search(tokens=tokens, limit=limit) + + results: List[Dict[str, Any]] = [] + for rank, row in enumerate(rows, start=1): + bm25_score = float(row.get("bm25_score", 0.0)) + results.append( + { + "hash": row["hash"], + "content": row["content"], + "rank": rank, + "bm25_score": bm25_score, + "score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数 + } + ) + return results + + def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]: + """执行关系稀疏检索(FTS5 + BM25)。""" + if not self.config.enabled or not self.config.enable_relation_sparse_fallback: + return [] + if self.config.lazy_load and not self.loaded: + if not self.ensure_loaded(): + return [] + if not self.loaded: + return [] + + tokens = self._tokenize(query) + match_query = self._build_match_query(tokens) + if not match_query: + return [] + + rows = self.metadata_store.fts_search_relations_bm25( + match_query=match_query, + limit=max(1, int(k)), + max_doc_len=self.config.relation_max_doc_len, + conn=self._conn, + ) + out: List[Dict[str, Any]] = [] + for rank, row in enumerate(rows, start=1): + bm25_score = float(row.get("bm25_score", 0.0)) + out.append( + { + "hash": row["hash"], + "subject": row["subject"], + "predicate": row["predicate"], + "object": row["object"], + "content": row["content"], + "rank": rank, + "bm25_score": bm25_score, + "score": -bm25_score, + } + ) + return out + + def upsert_paragraph(self, paragraph_hash: str) -> bool: + if not self.loaded: + return False + return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn) + + def delete_paragraph(self, paragraph_hash: str) -> bool: + if not self.loaded: + return False + return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn) + + def unload(self) -> None: + """卸载 BM25 连接并尽量释放内存。""" + if self._conn is not None: + try: + if self.config.shrink_memory_on_unload: + 
self.metadata_store.shrink_memory(conn=self._conn) + except Exception: + pass + try: + self._conn.close() + except Exception: + pass + self._conn = None + self._loaded = False + logger.info("SparseBM25Index unloaded") + + def stats(self) -> Dict[str, Any]: + doc_count = 0 + if self.loaded: + doc_count = self.metadata_store.fts_doc_count(conn=self._conn) + return { + "enabled": self.config.enabled, + "backend": self.config.backend, + "mode": self.config.mode, + "tokenizer_mode": self.config.tokenizer_mode, + "enable_ngram_fallback_index": self.config.enable_ngram_fallback_index, + "enable_like_fallback": self.config.enable_like_fallback, + "enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback, + "loaded": self.loaded, + "has_jieba": HAS_JIEBA, + "doc_count": doc_count, + } diff --git a/plugins/A_memorix/core/retrieval/threshold.py b/plugins/A_memorix/core/retrieval/threshold.py new file mode 100644 index 00000000..87a0094b --- /dev/null +++ b/plugins/A_memorix/core/retrieval/threshold.py @@ -0,0 +1,450 @@ +""" +动态阈值过滤器 + +根据检索结果的分布特征自适应调整过滤阈值。 +""" + +import numpy as np +from typing import List, Dict, Any, Optional, Tuple, Union +from dataclasses import dataclass +from enum import Enum + +from src.common.logger import get_logger +from .dual_path import RetrievalResult + +logger = get_logger("A_Memorix.DynamicThresholdFilter") + + +class ThresholdMethod(Enum): + """阈值计算方法""" + + PERCENTILE = "percentile" # 百分位数 + STD_DEV = "std_dev" # 标准差 + GAP_DETECTION = "gap_detection" # 跳变检测 + ADAPTIVE = "adaptive" # 自适应(综合多种方法) + + +@dataclass +class ThresholdConfig: + """ + 阈值配置 + + 属性: + method: 阈值计算方法 + min_threshold: 最小阈值(绝对值) + max_threshold: 最大阈值(绝对值) + percentile: 百分位数(用于percentile方法) + std_multiplier: 标准差倍数(用于std_dev方法) + min_results: 最少保留结果数 + enable_auto_adjust: 是否自动调整参数 + """ + + method: ThresholdMethod = ThresholdMethod.ADAPTIVE + min_threshold: float = 0.3 + max_threshold: float = 0.95 + percentile: float = 75.0 # 百分位数 + std_multiplier: float = 1.5 # 标准差倍数 + min_results: int = 3 # 最少保留结果数 + enable_auto_adjust: bool = True + + def __post_init__(self): + """验证配置""" + if not 0 <= self.min_threshold <= 1: + raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}") + + if not 0 <= self.max_threshold <= 1: + raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}") + + if self.min_threshold >= self.max_threshold: + raise ValueError(f"min_threshold必须小于max_threshold") + + if not 0 <= self.percentile <= 100: + raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}") + + if self.std_multiplier <= 0: + raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}") + + if self.min_results < 0: + raise ValueError(f"min_results必须大于等于0: {self.min_results}") + + +class DynamicThresholdFilter: + """ + 动态阈值过滤器 + + 功能: + - 基于结果分布自适应计算阈值 + - 多种阈值计算方法 + - 自动参数调整 + - 统计信息收集 + + 参数: + config: 阈值配置 + """ + + def __init__( + self, + config: Optional[ThresholdConfig] = None, + ): + """ + 初始化动态阈值过滤器 + + Args: + config: 阈值配置 + """ + self.config = config or ThresholdConfig() + + # 统计信息 + self._total_filtered = 0 + self._total_processed = 0 + self._threshold_history: List[float] = [] + + logger.info( + f"DynamicThresholdFilter 初始化: " + f"method={self.config.method.value}, " + f"min_threshold={self.config.min_threshold}" + ) + + def filter( + self, + results: List[RetrievalResult], + return_threshold: bool = False, + ) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]: + """ + 过滤检索结果 + + Args: + results: 检索结果列表 + return_threshold: 是否返回使用的阈值 + + 
Returns: + 过滤后的结果列表,或 (结果列表, 阈值) 元组 + """ + if not results: + logger.debug("结果列表为空,无需过滤") + return ([], 0.0) if return_threshold else [] + + self._total_processed += len(results) + + # 提取分数 + scores = np.array([r.score for r in results]) + + # 计算阈值 + threshold = self._compute_threshold(scores, results) + + # 记录阈值 + self._threshold_history.append(threshold) + + # 应用阈值过滤 + filtered_results = [ + r for r in results + if r.score >= threshold + ] + + # 确保至少保留min_results个结果 + if len(filtered_results) < self.config.min_results: + # 按分数排序,取前min_results个 + sorted_results = sorted(results, key=lambda x: x.score, reverse=True) + filtered_results = sorted_results[:self.config.min_results] + threshold = filtered_results[-1].score if filtered_results else 0.0 + + self._total_filtered += len(results) - len(filtered_results) + + logger.info( + f"过滤完成: {len(results)} -> {len(filtered_results)} " + f"(threshold={threshold:.3f})" + ) + + if return_threshold: + return filtered_results, threshold + return filtered_results + + def _compute_threshold( + self, + scores: np.ndarray, + results: List[RetrievalResult], + ) -> float: + """ + 计算阈值 + + Args: + scores: 分数数组 + results: 检索结果列表 + + Returns: + 阈值 + """ + if self.config.method == ThresholdMethod.PERCENTILE: + threshold = self._percentile_threshold(scores) + elif self.config.method == ThresholdMethod.STD_DEV: + threshold = self._std_dev_threshold(scores) + elif self.config.method == ThresholdMethod.GAP_DETECTION: + threshold = self._gap_detection_threshold(scores) + else: # ADAPTIVE + # 自适应方法:综合多种方法 + thresholds = [ + self._percentile_threshold(scores), + self._std_dev_threshold(scores), + self._gap_detection_threshold(scores), + ] + # 使用中位数作为最终阈值 + threshold = float(np.median(thresholds)) + + # 限制在[min_threshold, max_threshold]范围内 + threshold = np.clip( + threshold, + self.config.min_threshold, + self.config.max_threshold, + ) + + # 自动调整 + if self.config.enable_auto_adjust: + threshold = self._auto_adjust_threshold(threshold, scores) + + return float(threshold) + + def _percentile_threshold(self, scores: np.ndarray) -> float: + """ + 基于百分位数计算阈值 + + Args: + scores: 分数数组 + + Returns: + 阈值 + """ + percentile = self.config.percentile + threshold = float(np.percentile(scores, percentile)) + + logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})") + return threshold + + def _std_dev_threshold(self, scores: np.ndarray) -> float: + """ + 基于标准差计算阈值 + + threshold = mean - std_multiplier * std + + Args: + scores: 分数数组 + + Returns: + 阈值 + """ + mean = float(np.mean(scores)) + std = float(np.std(scores)) + multiplier = self.config.std_multiplier + + threshold = mean - multiplier * std + + logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})") + return threshold + + def _gap_detection_threshold(self, scores: np.ndarray) -> float: + """ + 基于跳变检测计算阈值 + + 找到分数分布中最大的"跳变"位置,以此为阈值 + + Args: + scores: 分数数组(降序排列) + + Returns: + 阈值 + """ + # 降序排列 + sorted_scores = np.sort(scores)[::-1] + + if len(sorted_scores) < 2: + return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0 + + # 计算相邻分数的差值 + gaps = np.diff(sorted_scores) + + # 找到最大的跳变位置 + max_gap_idx = int(np.argmax(gaps)) + + # 阈值为跳变后的分数 + threshold = float(sorted_scores[max_gap_idx + 1]) + + logger.debug( + f"跳变检测阈值: {threshold:.3f} " + f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})" + ) + return threshold + + def _auto_adjust_threshold( + self, + threshold: float, + scores: np.ndarray, + ) -> float: + """ + 自动调整阈值 + + 基于历史阈值和当前分数分布调整 + + Args: + threshold: 当前阈值 + scores: 分数数组 + + 
Returns: + 调整后的阈值 + """ + if not self._threshold_history: + return threshold + + # 计算历史阈值的移动平均 + recent_thresholds = self._threshold_history[-10:] # 最近10次 + avg_threshold = float(np.mean(recent_thresholds)) + + # 当前阈值与历史平均的差异 + diff = threshold - avg_threshold + + # 如果差异过大(>0.2),向历史平均靠拢 + if abs(diff) > 0.2: + adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50% + logger.debug( + f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} " + f"(历史平均={avg_threshold:.3f})" + ) + return adjusted_threshold + + return threshold + + def filter_by_confidence( + self, + results: List[RetrievalResult], + min_confidence: float = 0.5, + ) -> List[RetrievalResult]: + """ + 基于置信度过滤结果 + + Args: + results: 检索结果列表 + min_confidence: 最小置信度 + + Returns: + 过滤后的结果列表 + """ + filtered = [] + for result in results: + # 对于关系结果,使用confidence字段 + if result.result_type == "relation": + confidence = result.metadata.get("confidence", 1.0) + if confidence >= min_confidence: + filtered.append(result) + else: + # 对于段落结果,直接使用分数 + if result.score >= min_confidence: + filtered.append(result) + + logger.info( + f"置信度过滤: {len(results)} -> {len(filtered)} " + f"(min_confidence={min_confidence})" + ) + + return filtered + + def filter_by_diversity( + self, + results: List[RetrievalResult], + similarity_threshold: float = 0.9, + top_k: int = 10, + ) -> List[RetrievalResult]: + """ + 基于多样性过滤结果(去除重复) + + Args: + results: 检索结果列表 + similarity_threshold: 相似度阈值(高于此值视为重复) + top_k: 最多保留结果数 + + Returns: + 过滤后的结果列表 + """ + if not results: + return [] + + # 按分数排序 + sorted_results = sorted(results, key=lambda x: x.score, reverse=True) + + # 贪心选择:选择与已选结果相似度低的结果 + selected = [] + selected_hashes = [] + + for result in sorted_results: + if len(selected) >= top_k: + break + + # 检查与已选结果的相似度 + is_duplicate = False + for selected_hash in selected_hashes: + # 简单判断:基于hash的前缀 + if result.hash_value[:8] == selected_hash[:8]: + is_duplicate = True + break + + if not is_duplicate: + selected.append(result) + selected_hashes.append(result.hash_value) + + logger.info( + f"多样性过滤: {len(results)} -> {len(selected)} " + f"(similarity_threshold={similarity_threshold})" + ) + + return selected + + def get_statistics(self) -> Dict[str, Any]: + """ + 获取统计信息 + + Returns: + 统计信息字典 + """ + filter_rate = ( + self._total_filtered / self._total_processed + if self._total_processed > 0 + else 0.0 + ) + + stats = { + "config": { + "method": self.config.method.value, + "min_threshold": self.config.min_threshold, + "max_threshold": self.config.max_threshold, + "percentile": self.config.percentile, + "std_multiplier": self.config.std_multiplier, + "min_results": self.config.min_results, + "enable_auto_adjust": self.config.enable_auto_adjust, + }, + "statistics": { + "total_processed": self._total_processed, + "total_filtered": self._total_filtered, + "filter_rate": filter_rate, + "avg_threshold": float(np.mean(self._threshold_history)) + if self._threshold_history else 0.0, + "threshold_count": len(self._threshold_history), + }, + } + + if self._threshold_history: + stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history)) + stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history)) + + return stats + + def reset_statistics(self) -> None: + """重置统计信息""" + self._total_filtered = 0 + self._total_processed = 0 + self._threshold_history.clear() + logger.info("统计信息已重置") + + def __repr__(self) -> str: + return ( + f"DynamicThresholdFilter(" + f"method={self.config.method.value}, " + f"min_threshold={self.config.min_threshold}, " + 
f"filtered={self._total_filtered}/{self._total_processed})" + ) diff --git a/plugins/A_memorix/core/runtime/__init__.py b/plugins/A_memorix/core/runtime/__init__.py new file mode 100644 index 00000000..fa222715 --- /dev/null +++ b/plugins/A_memorix/core/runtime/__init__.py @@ -0,0 +1,8 @@ +"""SDK runtime exports for A_Memorix.""" + +from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel + +__all__ = [ + "KernelSearchRequest", + "SDKMemoryKernel", +] diff --git a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py new file mode 100644 index 00000000..7c8f9213 --- /dev/null +++ b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py @@ -0,0 +1,579 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from src.common.logger import get_logger + +from ..embedding import create_embedding_api_adapter +from ..retrieval import ( + DualPathRetriever, + DualPathRetrieverConfig, + RetrievalResult, + SparseBM25Config, + SparseBM25Index, + TemporalQueryOptions, +) +from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore +from ..utils.aggregate_query_service import AggregateQueryService +from ..utils.episode_retrieval_service import EpisodeRetrievalService +from ..utils.hash import normalize_text +from ..utils.relation_write_service import RelationWriteService + +logger = get_logger("A_Memorix.SDKMemoryKernel") + + +@dataclass +class KernelSearchRequest: + query: str = "" + limit: int = 5 + mode: str = "hybrid" + chat_id: str = "" + person_id: str = "" + time_start: Optional[float] = None + time_end: Optional[float] = None + + +class SDKMemoryKernel: + def __init__(self, *, plugin_root: Path, config: Optional[Dict[str, Any]] = None) -> None: + self.plugin_root = Path(plugin_root).resolve() + self.config = config or {} + storage_cfg = self._cfg("storage", {}) or {} + data_dir = str(storage_cfg.get("data_dir", "./data") or "./data") + self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir) + self.embedding_dimension = max(32, int(self._cfg("embedding.dimension", 256))) + self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False)) + + self.embedding_manager = None + self.vector_store: Optional[VectorStore] = None + self.graph_store: Optional[GraphStore] = None + self.metadata_store: Optional[MetadataStore] = None + self.relation_write_service: Optional[RelationWriteService] = None + self.sparse_index = None + self.retriever: Optional[DualPathRetriever] = None + self.episode_retriever: Optional[EpisodeRetrievalService] = None + self.aggregate_query_service: Optional[AggregateQueryService] = None + self._initialized = False + self._last_maintenance_at: Optional[float] = None + + def _cfg(self, key: str, default: Any = None) -> Any: + current: Any = self.config + if key in {"storage", "embedding", "retrieval"} and isinstance(current, dict): + return current.get(key, default) + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + async def initialize(self) -> None: + if self._initialized: + return + self.data_dir.mkdir(parents=True, exist_ok=True) + self.embedding_manager = create_embedding_api_adapter( + batch_size=int(self._cfg("embedding.batch_size", 32)), + 
max_concurrent=int(self._cfg("embedding.max_concurrent", 5)), + default_dimension=self.embedding_dimension, + model_name=str(self._cfg("embedding.model_name", "hash-v1")), + retry_config=self._cfg("embedding.retry", {}) or {}, + ) + self.embedding_dimension = int(await self.embedding_manager._detect_dimension()) + self.vector_store = VectorStore( + dimension=self.embedding_dimension, + quantization_type=QuantizationType.INT8, + data_dir=self.data_dir / "vectors", + ) + self.graph_store = GraphStore(matrix_format=SparseMatrixFormat.CSR, data_dir=self.data_dir / "graph") + self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") + self.metadata_store.connect() + if self.vector_store.has_data(): + self.vector_store.load() + self.vector_store.warmup_index(force_train=True) + if self.graph_store.has_data(): + self.graph_store.load() + + sparse_cfg = self._cfg("retrieval.sparse", {}) or {} + self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=SparseBM25Config(**sparse_cfg)) + if getattr(self.sparse_index.config, "enabled", False): + self.sparse_index.ensure_loaded() + + self.relation_write_service = RelationWriteService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + self.retriever = DualPathRetriever( + vector_store=self.vector_store, + graph_store=self.graph_store, + metadata_store=self.metadata_store, + embedding_manager=self.embedding_manager, + sparse_index=self.sparse_index, + config=DualPathRetrieverConfig( + top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 24)), + top_k_relations=int(self._cfg("retrieval.top_k_relations", 12)), + top_k_final=int(self._cfg("retrieval.top_k_final", 10)), + alpha=float(self._cfg("retrieval.alpha", 0.5)), + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), + ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), + enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), + sparse=sparse_cfg, + fusion=self._cfg("retrieval.fusion", {}) or {}, + graph_recall=self._cfg("retrieval.search.graph_recall", {}) or {}, + relation_intent=self._cfg("retrieval.search.relation_intent", {}) or {}, + ), + ) + self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) + self.aggregate_query_service = AggregateQueryService(plugin_config=self.config) + self._initialized = True + + def close(self) -> None: + if self.vector_store is not None: + self.vector_store.save() + if self.graph_store is not None: + self.graph_store.save() + if self.metadata_store is not None: + self.metadata_store.close() + self._initialized = False + + async def ingest_summary( + self, + *, + external_id: str, + chat_id: str, + text: str, + participants: Optional[Sequence[str]] = None, + time_start: Optional[float] = None, + time_end: Optional[float] = None, + tags: Optional[Sequence[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + summary_meta = dict(metadata or {}) + summary_meta.setdefault("kind", "chat_summary") + return await self.ingest_text( + external_id=external_id, + source_type="chat_summary", + text=text, + chat_id=chat_id, + participants=participants, + time_start=time_start, + time_end=time_end, + tags=tags, + metadata=summary_meta, + ) + + async def ingest_text( + self, + *, + external_id: str, + source_type: str, + text: str, + chat_id: str = "", + 
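+        # 以下均为可选参数:人物/参与者用于实体绑定,时间字段用于时间检索,tags/metadata 会附加到段落元数据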
person_ids: Optional[Sequence[str]] = None, + participants: Optional[Sequence[str]] = None, + timestamp: Optional[float] = None, + time_start: Optional[float] = None, + time_end: Optional[float] = None, + tags: Optional[Sequence[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + entities: Optional[Sequence[str]] = None, + relations: Optional[Sequence[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store and self.vector_store and self.graph_store and self.embedding_manager + assert self.relation_write_service + content = normalize_text(text) + if not content: + return {"stored_ids": [], "skipped_ids": [external_id], "reason": "empty_text"} + if ref := self.metadata_store.get_external_memory_ref(external_id): + return {"stored_ids": [], "skipped_ids": [str(ref.get("paragraph_hash", ""))], "reason": "exists"} + + person_tokens = self._tokens(person_ids) + participant_tokens = self._tokens(participants) + entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) + source = self._build_source(source_type, chat_id, person_tokens) + paragraph_meta = dict(metadata or {}) + paragraph_meta.update( + { + "external_id": external_id, + "source_type": str(source_type or "").strip(), + "chat_id": str(chat_id or "").strip(), + "person_ids": person_tokens, + "participants": participant_tokens, + "tags": self._tokens(tags), + } + ) + paragraph_hash = self.metadata_store.add_paragraph( + content=content, + source=source, + metadata=paragraph_meta, + knowledge_type="factual" if source_type == "person_fact" else "narrative" if source_type == "chat_summary" else "mixed", + time_meta=self._time_meta(timestamp, time_start, time_end), + ) + embedding = await self.embedding_manager.encode(content) + self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash]) + for name in entity_tokens: + self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash) + + stored_relations: List[str] = [] + for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]: + s = str(row.get("subject", "") or "").strip() + p = str(row.get("predicate", "") or "").strip() + o = str(row.get("object", "") or "").strip() + if not (s and p and o): + continue + result = await self.relation_write_service.upsert_relation_with_vector( + subject=s, + predicate=p, + obj=o, + confidence=float(row.get("confidence", 1.0) or 1.0), + source_paragraph=paragraph_hash, + metadata={"external_id": external_id, "source_type": source_type}, + write_vector=self.relation_vectors_enabled, + ) + self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value) + stored_relations.append(result.hash_value) + + self.metadata_store.upsert_external_memory_ref( + external_id=external_id, + paragraph_hash=paragraph_hash, + source_type=source_type, + metadata={"chat_id": chat_id, "person_ids": person_tokens}, + ) + self._persist() + self.rebuild_episodes_for_sources([source]) + for person_id in person_tokens: + await self.refresh_person_profile(person_id) + return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []} + + async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: + await self.initialize() + assert self.retriever and self.episode_retriever and self.aggregate_query_service + mode = str(request.mode or "hybrid").strip().lower() or "hybrid" + clean_query = str(request.query or "").strip() + limit = max(1, int(request.limit or 5)) + temporal = self._temporal(request) + if mode == 
"episode": + rows = await self.episode_retriever.query( + query=clean_query, + top_k=limit, + time_from=request.time_start, + time_to=request.time_end, + source=self._chat_source(request.chat_id), + ) + hits = [self._episode_hit(row) for row in rows] + return {"summary": self._summary(hits), "hits": hits} + if mode == "aggregate": + payload = await self.aggregate_query_service.execute( + query=clean_query, + top_k=limit, + mix=True, + mix_top_k=limit, + time_from=str(request.time_start) if request.time_start is not None else None, + time_to=str(request.time_end) if request.time_end is not None else None, + search_runner=lambda: self._aggregate_search(clean_query, limit, temporal), + time_runner=lambda: self._aggregate_time(clean_query, limit, temporal), + episode_runner=lambda: self._aggregate_episode(clean_query, limit, request), + ) + hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)] + for item in hits: + item.setdefault("metadata", {}) + return {"summary": self._summary(hits), "hits": hits} + results = await self.retriever.retrieve(query=clean_query, top_k=limit, temporal=temporal) + hits = [self._retrieval_hit(item) for item in results] + return {"summary": self._summary(self._filter_hits(hits, request.person_id)), "hits": self._filter_hits(hits, request.person_id)} + + async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: + _ = chat_id + await self.initialize() + assert self.metadata_store + snapshot = self.metadata_store.get_latest_person_profile_snapshot(person_id) or await self.refresh_person_profile(person_id, limit=limit) + evidence = [] + for hash_value in snapshot.get("evidence_ids", [])[: max(1, int(limit))]: + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is not None: + evidence.append({"hash": hash_value, "content": str(paragraph.get("content", "") or "")[:220], "metadata": paragraph.get("metadata", {}) or {}}) + text = str(snapshot.get("profile_text", "") or "").strip() + traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8] + return {"summary": text, "traits": traits, "evidence": evidence} + + async def refresh_person_profile(self, person_id: str, limit: int = 10) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + rows = self.metadata_store.query( + """ + SELECT DISTINCT p.* + FROM paragraphs p + JOIN paragraph_entities pe ON pe.paragraph_hash = p.hash + JOIN entities e ON e.hash = pe.entity_hash + WHERE e.name = ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + ORDER BY COALESCE(p.event_time_end, p.event_time_start, p.event_time, p.updated_at, p.created_at) DESC + LIMIT ? 
+ """, + (person_id, max(1, int(limit)) * 3), + ) + evidence_ids = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] + vector_evidence = [{"hash": str(row.get("hash", "") or ""), "type": "paragraph", "score": 0.0, "content": str(row.get("content", "") or "")[:220], "metadata": row.get("metadata", {}) or {}} for row in rows[: max(1, int(limit))]] + relation_edges = [{"hash": str(row.get("hash", "") or ""), "subject": str(row.get("subject", "") or ""), "predicate": str(row.get("predicate", "") or ""), "object": str(row.get("object", "") or ""), "confidence": float(row.get("confidence", 1.0) or 1.0)} for row in self.metadata_store.get_relations(subject=person_id)[:limit]] + if relation_edges: + profile_text = "\n".join(f"{item['subject']} {item['predicate']} {item['object']}" for item in relation_edges[:6]) + elif vector_evidence: + profile_text = "\n".join(f"- {item['content']}" for item in vector_evidence[:6]) + else: + profile_text = "暂无稳定画像证据。" + return self.metadata_store.upsert_person_profile_snapshot( + person_id=person_id, + profile_text=profile_text, + aliases=[person_id], + relation_edges=relation_edges, + vector_evidence=vector_evidence, + evidence_ids=evidence_ids[: max(1, int(limit))], + expires_at=time.time() + 6 * 3600, + source_note="sdk_memory_kernel", + ) + + async def maintain_memory(self, *, action: str, target: str, hours: Optional[float] = None, reason: str = "") -> Dict[str, Any]: + _ = reason + await self.initialize() + assert self.metadata_store + hashes = self._resolve_relation_hashes(target) + if not hashes: + return {"success": False, "detail": "未命中可维护关系"} + act = str(action or "").strip().lower() + if act == "reinforce": + self.metadata_store.reinforce_relations(hashes) + elif act == "protect": + ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0 + self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0) + elif act == "restore": + restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value)) + if restored <= 0: + return {"success": False, "detail": "未恢复任何关系"} + else: + return {"success": False, "detail": f"不支持的维护动作: {act}"} + self._last_maintenance_at = time.time() + self._persist() + return {"success": True, "detail": f"{act} {len(hashes)} 条关系"} + + def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> int: + assert self.metadata_store + rebuilt = 0 + for source in self._tokens(sources): + rows = self.metadata_store.query( + """ + SELECT * FROM paragraphs + WHERE source = ? 
+ AND (is_deleted IS NULL OR is_deleted = 0) + ORDER BY COALESCE(event_time_start, event_time, created_at) ASC, hash ASC + """, + (source,), + ) + if not rows: + continue + paragraph_hashes = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] + payload = self.metadata_store.upsert_episode( + { + "source": source, + "title": str((rows[0].get("metadata", {}) or {}).get("theme", "") or f"{source} 情景记忆")[:80], + "summary": ";".join(str(row.get("content", "") or "").strip().replace("\n", " ")[:120] for row in rows[:3] if str(row.get("content", "") or "").strip())[:500] or "自动构建的情景记忆。", + "participants": self._episode_participants(rows), + "keywords": self._episode_keywords(rows), + "evidence_ids": paragraph_hashes, + "paragraph_count": len(paragraph_hashes), + "event_time_start": self._time_bound(rows, "event_time_start", "event_time", reverse=False), + "event_time_end": self._time_bound(rows, "event_time_end", "event_time", reverse=True), + "time_granularity": "day", + "time_confidence": 0.7, + "llm_confidence": 0.0, + "segmentation_model": "rule_based_sdk", + "segmentation_version": "1", + } + ) + self.metadata_store.bind_episode_paragraphs(payload["episode_id"], paragraph_hashes) + rebuilt += 1 + return rebuilt + + def memory_stats(self) -> Dict[str, Any]: + assert self.metadata_store + stats = self.metadata_store.get_statistics() + episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"] + profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"] + return {"paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "last_maintenance_at": self._last_maintenance_at} + + async def _aggregate_search(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: + assert self.retriever + hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] + return {"success": True, "results": hits, "count": len(hits), "query_type": "search"} + + async def _aggregate_time(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: + if temporal is None: + return {"success": False, "error": "missing temporal window", "results": []} + assert self.retriever + hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] + return {"success": True, "results": hits, "count": len(hits), "query_type": "time"} + + async def _aggregate_episode(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: + assert self.episode_retriever + rows = await self.episode_retriever.query(query=query, top_k=limit, time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id)) + hits = [self._episode_hit(row) for row in rows] + return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"} + + def _persist(self) -> None: + if self.vector_store is not None: + self.vector_store.save() + if self.graph_store is not None: + self.graph_store.save() + if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False): + self.sparse_index.ensure_loaded() + + @staticmethod + def _tokens(values: Optional[Iterable[Any]]) -> List[str]: + result: List[str] = [] + seen = set() + for item in values or []: + token = str(item or "").strip() + if not token 
or token in seen: + continue + seen.add(token) + result.append(token) + return result + + @classmethod + def _merge_tokens(cls, *groups: Optional[Iterable[Any]]) -> List[str]: + merged: List[str] = [] + seen = set() + for group in groups: + for item in cls._tokens(group): + if item in seen: + continue + seen.add(item) + merged.append(item) + return merged + + @staticmethod + def _build_source(source_type: str, chat_id: str, person_ids: Sequence[str]) -> str: + clean_type = str(source_type or "").strip() or "memory" + if clean_type == "chat_summary" and chat_id: + return f"chat_summary:{chat_id}" + if clean_type == "person_fact" and person_ids: + return f"person_fact:{person_ids[0]}" + return f"{clean_type}:{chat_id}" if chat_id else clean_type + + @staticmethod + def _chat_source(chat_id: str) -> Optional[str]: + clean = str(chat_id or "").strip() + return f"chat_summary:{clean}" if clean else None + + @staticmethod + def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + if timestamp is not None: + payload["event_time"] = float(timestamp) + if time_start is not None: + payload["event_time_start"] = float(time_start) + if time_end is not None: + payload["event_time_end"] = float(time_end) + if payload: + payload["time_granularity"] = "minute" + payload["time_confidence"] = 0.95 + return payload + + def _temporal(self, request: KernelSearchRequest) -> Optional[TemporalQueryOptions]: + if request.time_start is None and request.time_end is None and not request.chat_id: + return None + return TemporalQueryOptions(time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id)) + + @staticmethod + def _retrieval_hit(item: RetrievalResult) -> Dict[str, Any]: + payload = item.to_dict() + return {"hash": payload.get("hash", ""), "content": payload.get("content", ""), "score": payload.get("score", 0.0), "type": payload.get("type", ""), "source": payload.get("source", ""), "metadata": payload.get("metadata", {}) or {}} + + @staticmethod + def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]: + return {"type": "episode", "episode_id": str(row.get("episode_id", "") or ""), "title": str(row.get("title", "") or ""), "content": str(row.get("summary", "") or ""), "score": float(row.get("lexical_score", 0.0) or 0.0), "source": "episode", "metadata": {"participants": row.get("participants", []) or [], "keywords": row.get("keywords", []) or [], "source": row.get("source"), "event_time_start": row.get("event_time_start"), "event_time_end": row.get("event_time_end")}} + + @staticmethod + def _summary(hits: Sequence[Dict[str, Any]]) -> str: + if not hits: + return "" + lines = [] + for index, item in enumerate(hits[:5], start=1): + content = str(item.get("content", "") or "").strip().replace("\n", " ") + lines.append(f"{index}. 
{(content[:120] + '...') if len(content) > 120 else content}") + return "\n".join(lines) + + @staticmethod + def _filter_hits(hits: List[Dict[str, Any]], person_id: str) -> List[Dict[str, Any]]: + if not person_id: + return hits + filtered = [] + for item in hits: + metadata = item.get("metadata", {}) or {} + if person_id in (metadata.get("person_ids", []) or []): + filtered.append(item) + continue + if person_id and person_id in str(item.get("content", "") or ""): + filtered.append(item) + return filtered or hits + + @staticmethod + def _episode_participants(rows: Sequence[Dict[str, Any]]) -> List[str]: + seen = set() + result: List[str] = [] + for row in rows: + meta = row.get("metadata", {}) or {} + for key in ("participants", "person_ids"): + for item in meta.get(key, []) or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + result.append(token) + return result[:16] + + @staticmethod + def _episode_keywords(rows: Sequence[Dict[str, Any]]) -> List[str]: + seen = set() + result: List[str] = [] + for row in rows: + meta = row.get("metadata", {}) or {} + for item in meta.get("tags", []) or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + result.append(token) + return result[:12] + + @staticmethod + def _time_bound(rows: Sequence[Dict[str, Any]], primary: str, fallback: str, reverse: bool) -> Optional[float]: + values: List[float] = [] + for row in rows: + for key in (primary, fallback): + value = row.get(key) + try: + if value is not None: + values.append(float(value)) + break + except Exception: + continue + if not values: + return None + return max(values) if reverse else min(values) + + def _resolve_relation_hashes(self, target: str) -> List[str]: + assert self.metadata_store + token = str(target or "").strip() + if not token: + return [] + if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): + return [token] + hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10) + if hashes: + return hashes + return [str(row.get("hash", "") or "") for row in self.metadata_store.get_relations(subject=token)[:10] if str(row.get("hash", "")).strip()] diff --git a/plugins/A_memorix/core/storage/__init__.py b/plugins/A_memorix/core/storage/__init__.py new file mode 100644 index 00000000..d878b8e7 --- /dev/null +++ b/plugins/A_memorix/core/storage/__init__.py @@ -0,0 +1,53 @@ +"""存储层""" + +from .vector_store import VectorStore, QuantizationType +from .graph_store import GraphStore, SparseMatrixFormat +from .metadata_store import MetadataStore +from .knowledge_types import ( + ImportStrategy, + KnowledgeType, + allowed_import_strategy_values, + allowed_knowledge_type_values, + get_knowledge_type_from_string, + get_import_strategy_from_string, + parse_import_strategy, + resolve_stored_knowledge_type, + should_extract_relations, + get_default_chunk_size, + get_type_display_name, + validate_stored_knowledge_type, +) +from .type_detection import ( + detect_knowledge_type, + get_type_from_user_input, + looks_like_factual_text, + looks_like_quote_text, + looks_like_structured_text, + select_import_strategy, +) + +__all__ = [ + "VectorStore", + "GraphStore", + "MetadataStore", + "QuantizationType", + "SparseMatrixFormat", + "ImportStrategy", + "KnowledgeType", + "allowed_import_strategy_values", + "allowed_knowledge_type_values", + "get_knowledge_type_from_string", + "get_import_strategy_from_string", + "parse_import_strategy", + "resolve_stored_knowledge_type", + 
"should_extract_relations", + "get_default_chunk_size", + "get_type_display_name", + "validate_stored_knowledge_type", + "detect_knowledge_type", + "get_type_from_user_input", + "looks_like_factual_text", + "looks_like_quote_text", + "looks_like_structured_text", + "select_import_strategy", +] diff --git a/plugins/A_memorix/core/storage/graph_store.py b/plugins/A_memorix/core/storage/graph_store.py new file mode 100644 index 00000000..8a075864 --- /dev/null +++ b/plugins/A_memorix/core/storage/graph_store.py @@ -0,0 +1,1434 @@ +""" +图存储模块 + +基于SciPy稀疏矩阵的知识图谱存储与计算。 +""" + +import pickle +from enum import Enum +from pathlib import Path +from typing import Optional, Union, Tuple, List, Dict, Set, Any +from collections import defaultdict +import threading +import asyncio + +import numpy as np + +class SparseMatrixFormat(Enum): + """稀疏矩阵格式""" + CSR = "csr" + CSC = "csc" + +try: + from scipy.sparse import csr_matrix, csc_matrix, triu, save_npz, load_npz, bmat, lil_matrix + from scipy.sparse.linalg import norm + HAS_SCIPY = True +except ImportError: + HAS_SCIPY = False + +import contextlib +from src.common.logger import get_logger +from ..utils.hash import compute_hash +from ..utils.io import atomic_write + +logger = get_logger("A_Memorix.GraphStore") + + +class GraphModificationMode(Enum): + """图修改模式""" + BATCH = "batch" # 批量模式 (默认, 适合一次性加载) + INCREMENTAL = "incremental" # 增量模式 (适合频繁随机写入, 使用LIL) + READ_ONLY = "read_only" # 只读模式 (适合计算, CSR/CSC) + + +class GraphStore: + """ + 图存储类 + + 功能: + - CSR稀疏矩阵存储图结构 + - 节点和边的CRUD操作 + - Personalized PageRank计算 + - 同义词自动连接 + - 图持久化 + + 参数: + matrix_format: 稀疏矩阵格式(csr/csc) + data_dir: 数据目录 + """ + + def __init__( + self, + matrix_format: str = "csr", + data_dir: Optional[Union[str, Path]] = None, + ): + """ + 初始化图存储 + + Args: + matrix_format: 稀疏矩阵格式(csr/csc) + data_dir: 数据目录 + """ + if not HAS_SCIPY: + raise ImportError("SciPy 未安装,请安装: pip install scipy") + + if isinstance(matrix_format, SparseMatrixFormat): + self.matrix_format = matrix_format.value + else: + self.matrix_format = str(matrix_format).lower() + self.data_dir = Path(data_dir) if data_dir else None + + # 节点管理 + self._nodes: List[str] = [] # 节点列表 + self._node_to_idx: Dict[str, int] = {} # 节点名到索引的映射 + self._node_attrs: Dict[str, Dict[str, Any]] = {} # 节点属性 + + # 边管理(邻接矩阵) + self._adjacency: Optional[Union[csr_matrix, csc_matrix]] = None + + # 统计信息 + self._total_nodes_added = 0 + self._total_edges_added = 0 + self._total_nodes_deleted = 0 + self._total_edges_deleted = 0 + + # 状态管理 + self._modification_mode = GraphModificationMode.BATCH + + # 状态管理 + self._adjacency_T: Optional[Union[csr_matrix, csc_matrix]] = None + self._adjacency_dirty: bool = True + self._saliency_cache: Optional[Dict[str, float]] = None + + # V5: 多关系映射 (src_idx, dst_idx) -> Set[relation_hash] + self._edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) + # V5: 简单的异步锁 (实际上 asyncio 环境下单线程主循环可能不需要,但为了安全保留) + self._lock = asyncio.Lock() + + logger.info(f"GraphStore 初始化: format={matrix_format}") + + def _canonicalize(self, node: str) -> str: + """规范化节点名称 (用于去重和内部索引)""" + if not node: + return "" + return str(node).strip().lower() + + @contextlib.contextmanager + def batch_update(self): + """ + 批量更新上下文管理器 + + 进入时切换到 LIL 格式以优化随机/增量更新 + 退出时恢复到 CSR/CSC 格式以优化存储和计算 + """ + original_mode = self._modification_mode + self._switch_mode(GraphModificationMode.INCREMENTAL) + try: + yield + finally: + self._switch_mode(original_mode) + + def _switch_mode(self, new_mode: GraphModificationMode): + """切换修改模式并转换矩阵格式""" + if new_mode == 
self._modification_mode: + return + + if self._adjacency is None: + self._modification_mode = new_mode + return + + logger.debug(f"切换图模式: {self._modification_mode.value} -> {new_mode.value}") + + # 转换逻辑 + if new_mode == GraphModificationMode.INCREMENTAL: + # 转换为 LIL 格式 + if not isinstance(self._adjacency, lil_matrix): # 粗略检查是否非 lil + try: + self._adjacency = self._adjacency.tolil() + logger.debug("已转换为 LIL 格式") + except Exception as e: + logger.warning(f"转换为 LIL 失败: {e}") + + elif new_mode in [GraphModificationMode.BATCH, GraphModificationMode.READ_ONLY]: + # 转换回配置的格式 (CSR/CSC) + if self.matrix_format == "csr": + self._adjacency = self._adjacency.tocsr() + elif self.matrix_format == "csc": + self._adjacency = self._adjacency.tocsc() + logger.debug(f"已恢复为 {self.matrix_format.upper()} 格式") + + self._modification_mode = new_mode + + def add_nodes( + self, + nodes: List[str], + attributes: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> int: + """ + 添加节点 + + Args: + nodes: 节点名称列表 + attributes: 节点属性字典 {node_name: {attr: value}} + + Returns: + 成功添加的节点数量 + """ + added = 0 + for node in nodes: + canon = self._canonicalize(node) + if canon in self._node_to_idx: + logger.debug(f"节点已存在,跳过: {node}") + continue + + # 添加到节点列表 + idx = len(self._nodes) + self._nodes.append(node) # 存储原始节点名 + self._node_to_idx[canon] = idx # 映射规范化节点名到索引 + self._adjacency_dirty = True + self._saliency_cache = None + + # 添加属性 + if attributes and node in attributes: + self._node_attrs[canon] = attributes[node] + else: + self._node_attrs[canon] = {} + + added += 1 + self._total_nodes_added += 1 + + # 扩展邻接矩阵 + if added > 0: + self._expand_adjacency_matrix(added) + + logger.debug(f"添加 {added} 个节点") + return added + + def add_edges( + self, + edges: List[Tuple[str, str]], + weights: Optional[List[float]] = None, + relation_hashes: Optional[List[str]] = None, # V5: 支持关系哈希映射 (Relation Hash Mapping) + ) -> int: + """ + 添加边 + + Args: + edges: 边列表 [(source, target), ...] 
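+            relation_hashes: 可选,与 edges 一一对应的关系哈希列表,用于同步维护边到关系哈希的映射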
+ weights: 边权重列表(默认为1.0) + + Returns: + 成功添加的边数量 + """ + if not edges: + return 0 + + # 确保所有节点存在 + nodes_to_add = set() + for src, tgt in edges: + src_canon = self._canonicalize(src) + tgt_canon = self._canonicalize(tgt) + if src_canon not in self._node_to_idx: + nodes_to_add.add(src) + if tgt_canon not in self._node_to_idx: + nodes_to_add.add(tgt) + + if nodes_to_add: + self.add_nodes(list(nodes_to_add)) + + # 处理权重 + if weights is None: + weights = [1.0] * len(edges) + + if len(weights) != len(edges): + raise ValueError(f"边数量与权重数量不匹配: {len(edges)} vs {len(weights)}") + + # 如果仅仅是添加边且处于增量模式 (LIL),直接更新 + if self._modification_mode == GraphModificationMode.INCREMENTAL: + if self._adjacency is None: + # 初始化为空 LIL + n = len(self._nodes) + from scipy.sparse import lil_matrix + self._adjacency = lil_matrix((n, n), dtype=np.float32) + + # 尝试直接使用 LIL 索引更新 + try: + # 批量获取索引 + rows = [self._node_to_idx[self._canonicalize(src)] for src, _ in edges] + cols = [self._node_to_idx[self._canonicalize(tgt)] for _, tgt in edges] + + # 确保矩阵足够大 (如果 add_nodes 没有扩展它) - 通常 add_nodes 会处理 + # 这里直接赋值 + self._adjacency[rows, cols] = weights + + self._total_edges_added += len(edges) + + # V5: Update edge hash map + if relation_hashes: + for (src, tgt), r_hash in zip(edges, relation_hashes): + if r_hash: + s_idx = self._node_to_idx[self._canonicalize(src)] + t_idx = self._node_to_idx[self._canonicalize(tgt)] + self._edge_hash_map[(s_idx, t_idx)].add(r_hash) + + logger.debug(f"增量添加 {len(edges)} 条边 (LIL)") + return len(edges) + except Exception as e: + logger.warning(f"LIL 增量更新失败,回退到通用方法: {e}") + # Fallback to general method below + + # 通用方法 (构建 COO 然后合并) + # 构建边的三元组 + row_indices = [] + col_indices = [] + data_values = [] + + for (src, tgt), weight in zip(edges, weights): + src_idx = self._node_to_idx[self._canonicalize(src)] + tgt_idx = self._node_to_idx[self._canonicalize(tgt)] + + row_indices.append(src_idx) + col_indices.append(tgt_idx) + data_values.append(weight) + + # 创建新的边的矩阵 + n = len(self._nodes) + new_edges = csr_matrix( + (data_values, (row_indices, col_indices)), + shape=(n, n), + ) + + # 合并到邻接矩阵 + if self._adjacency is None: + self._adjacency = new_edges + else: + self._adjacency = self._adjacency + new_edges + + # 转换为指定格式 + if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): + self._adjacency = self._adjacency.tocsc() + elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): + self._adjacency = self._adjacency.tocsr() + + self._total_edges_added += len(edges) + self._adjacency_dirty = True # 标记脏位 + + # V5: 更新边哈希映射 (Edge Hash Map) + if relation_hashes: + for (src, tgt), r_hash in zip(edges, relation_hashes): + if r_hash: + try: + s_idx = self._node_to_idx[self._canonicalize(src)] + t_idx = self._node_to_idx[self._canonicalize(tgt)] + self._edge_hash_map[(s_idx, t_idx)].add(r_hash) + except KeyError: + pass # 正常情况下节点已在上方添加,此处仅作防错处理 + + logger.debug(f"添加 {len(edges)} 条边") + return len(edges) + + def update_edge_weight( + self, + source: str, + target: str, + delta: float, + min_weight: float = 0.1, + max_weight: float = 10.0, + ) -> float: + """ + 更新边权重 (增量/强化/弱化) + + Args: + source: 源节点 + target: 目标节点 + delta: 权重变化量 (+/-) + min_weight: 最小权重限制 + max_weight: 最大权重限制 + + Returns: + 更新后的权重 + """ + src_canon = self._canonicalize(source) + tgt_canon = self._canonicalize(target) + + if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: + logger.warning(f"节点不存在,无法更新权重: {source} -> {target}") + return 0.0 + + current_weight = 
self.get_edge_weight(source, target) + if current_weight == 0.0 and delta <= 0: + # 边不存在且试图减少权重,忽略 + return 0.0 + + # 如果边不存在但 delta > 0,相当于添加新边 (默认基础权重0 + delta) + # 但为了逻辑清晰,我们假设只更新存在的边,或者确实添加 + + new_weight = current_weight + delta + new_weight = max(min_weight, min(max_weight, new_weight)) + + # 使用 batch_update 上下文自动处理格式转换 + # 这里我们临时切换到 incremental 模式进行单次更新 + with self.batch_update(): + # add_edges 会覆盖或添加,我们需要覆盖 + self.add_edges([(source, target)], [new_weight]) + + logger.debug(f"更新权重 {source}->{target}: {current_weight:.2f} -> {new_weight:.2f}") + return new_weight + + def delete_nodes(self, nodes: List[str]) -> int: + """ + 删除节点(及相关的边) + + Args: + nodes: 要删除的节点列表 + + Returns: + 成功删除的节点数量 + """ + if not nodes: + return 0 + + # 检查哪些节点存在 + existing_nodes = [node for node in nodes if self._canonicalize(node) in self._node_to_idx] + if not existing_nodes: + logger.warning("所有节点都不存在,无法删除") + return 0 + + # 获取要删除的索引 + indices_to_delete = {self._node_to_idx[self._canonicalize(node)] for node in existing_nodes} + indices_to_keep = [ + i for i in range(len(self._nodes)) + if i not in indices_to_delete + ] + + # 创建索引映射 + old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(indices_to_keep)} + + # 重建节点列表 (存储原始节点名) + self._nodes = [self._nodes[i] for i in indices_to_keep] + # 重建规范化节点名到索引的映射 + self._node_to_idx = {self._canonicalize(self._nodes[new_idx]): new_idx for new_idx in range(len(self._nodes))} + + # 删除并重构节点属性 + new_node_attrs = {} + for idx, node_name in enumerate(self._nodes): + canon = self._canonicalize(node_name) + if canon in self._node_attrs: + new_node_attrs[canon] = self._node_attrs[canon] + self._node_attrs = new_node_attrs + + # 重建邻接矩阵 + if self._adjacency is not None: + # 转换为COO格式以进行切片,然后转换回原始格式 + self._adjacency = self._adjacency.tocoo() + mask_rows = np.isin(self._adjacency.row, list(indices_to_keep)) + mask_cols = np.isin(self._adjacency.col, list(indices_to_keep)) + + # 筛选出保留的行和列 + new_rows = self._adjacency.row[mask_rows & mask_cols] + new_cols = self._adjacency.col[mask_rows & mask_cols] + new_data = self._adjacency.data[mask_rows & mask_cols] + + # 更新索引 + new_rows = np.array([old_to_new[r] for r in new_rows]) + new_cols = np.array([old_to_new[c] for c in new_cols]) + + n = len(self._nodes) + if self.matrix_format == "csr": + self._adjacency = csr_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) + else: # csc + self._adjacency = csc_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) + + # 重建关系哈希映射,移除涉及已删除节点的记录并重映射索引。 + if self._edge_hash_map: + new_edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) + for (old_src, old_tgt), hashes in self._edge_hash_map.items(): + if old_src in indices_to_delete or old_tgt in indices_to_delete: + continue + if old_src in old_to_new and old_tgt in old_to_new and hashes: + new_edge_hash_map[(old_to_new[old_src], old_to_new[old_tgt])] = set(hashes) + self._edge_hash_map = new_edge_hash_map + + deleted_count = len(existing_nodes) + self._total_nodes_deleted += deleted_count + self._adjacency_dirty = True + self._saliency_cache = None + + logger.info(f"删除 {deleted_count} 个节点") + return deleted_count + + def remove_nodes(self, nodes: List[str]) -> int: + """兼容性别名:删除节点""" + return self.delete_nodes(nodes) + + def delete_edges( + self, + edges: List[Tuple[str, str]], + ) -> int: + """ + 删除边 + + Args: + edges: 要删除的边列表 [(source, target), ...] 
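+                命中的边会被物理删除,并同步清理对应的关系哈希映射(区别于 deactivate_edges 仅把权重置 0)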
+ + Returns: + 成功删除的边数量 + """ + if not edges: + return 0 + + deleted = 0 + # 构建要删除的边的索引集合 + edges_to_delete = set() + for src, tgt in edges: + src_canon = self._canonicalize(src) + tgt_canon = self._canonicalize(tgt) + if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: + src_idx = self._node_to_idx[src_canon] + tgt_idx = self._node_to_idx[tgt_canon] + edges_to_delete.add((src_idx, tgt_idx)) + + if self._adjacency is not None and edges_to_delete: + # 转换为COO格式便于修改 + adj_coo = self._adjacency.tocoo() + + # 过滤要删除的边 + new_row = [] + new_col = [] + new_data = [] + + for i, j, val in zip(adj_coo.row, adj_coo.col, adj_coo.data): + if (i, j) not in edges_to_delete: + new_row.append(i) + new_col.append(j) + new_data.append(val) + else: + deleted += 1 + + # 重建邻接矩阵 + n = len(self._nodes) + self._adjacency = csr_matrix((new_data, (new_row, new_col)), shape=(n, n)) + + # 转换回指定格式 + if self.matrix_format == "csc": + self._adjacency = self._adjacency.tocsc() + + # delete_edges 是“物理删除”语义,必须同步清理 edge_hash_map。 + if edges_to_delete and self._edge_hash_map: + for key in edges_to_delete: + self._edge_hash_map.pop(key, None) + + self._total_edges_deleted += deleted + self._adjacency_dirty = True + self._saliency_cache = None + logger.info(f"删除 {deleted} 条边") + return deleted + + def remove_edges(self, edges: List[Tuple[str, str]]) -> int: + """兼容性别名:删除边""" + return self.delete_edges(edges) + + def get_nodes(self) -> List[str]: + """ + 获取所有节点 + + Returns: + 节点列表 + """ + return self._nodes.copy() + + def has_node(self, node: str) -> bool: + """ + 检查节点是否存在 + + Args: + node: 节点名称 + """ + return self._canonicalize(node) in self._node_to_idx + + def find_node(self, node: str, ignore_case: bool = False) -> Optional[str]: + """ + 查找节点 (由于底层已统一规范化,ignore_case 始终有效) + + Args: + node: 节点名称 + ignore_case: 是否忽略大小写 (已默认忽略) + + Returns: + 真实节点名称 (如果存在),否则 None + """ + canon = self._canonicalize(node) + if canon in self._node_to_idx: + return self._nodes[self._node_to_idx[canon]] + return None + + def get_node_attributes(self, node: str) -> Optional[Dict[str, Any]]: + """ + 获取节点属性 + + Args: + node: 节点名称 + + Returns: + 节点属性字典,不存在则返回None + """ + canon = self._canonicalize(node) + return self._node_attrs.get(canon) + + def get_neighbors(self, node: str) -> List[str]: + """ + 获取节点的出邻居 + + Args: + node: 节点名称 + + Returns: + 出邻居节点列表 + """ + canon = self._canonicalize(node) + if canon not in self._node_to_idx or self._adjacency is None: + return [] + + idx = self._node_to_idx[canon] + neighbor_indices = self._row_neighbor_indices(self._adjacency, idx) + return [self._nodes[int(i)] for i in neighbor_indices] + + def get_in_neighbors(self, node: str) -> List[str]: + """ + 获取节点的入邻居 + + Args: + node: 节点名称 + + Returns: + 入邻居节点列表 + """ + canon = self._canonicalize(node) + if canon not in self._node_to_idx or self._adjacency is None: + return [] + + self._ensure_adjacency_T() + if self._adjacency_T is None: + return [] + + idx = self._node_to_idx[canon] + neighbor_indices = self._row_neighbor_indices(self._adjacency_T, idx) + return [self._nodes[int(i)] for i in neighbor_indices] + + def get_edge_weight(self, source: str, target: str) -> float: + """ + 获取边的权重 + """ + src_canon = self._canonicalize(source) + tgt_canon = self._canonicalize(target) + + if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: + return 0.0 + + if self._adjacency is None: + return 0.0 + + src_idx = self._node_to_idx[src_canon] + tgt_idx = self._node_to_idx[tgt_canon] + + return float(self._adjacency[src_idx, tgt_idx]) + + def 
canonicalize_node(self, node: str) -> str: + """公开节点规范化接口,避免外部访问私有方法。""" + return self._canonicalize(node) + + def has_edge_hash_map(self) -> bool: + """是否存在 relation-hash 映射。""" + return bool(self._edge_hash_map) + + def get_relation_hashes_for_edge(self, source: str, target: str) -> Set[str]: + """获取边 (source -> target) 关联的关系哈希集合。""" + src_canon = self._canonicalize(source) + tgt_canon = self._canonicalize(target) + if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: + return set() + src_idx = self._node_to_idx[src_canon] + tgt_idx = self._node_to_idx[tgt_canon] + return set(self._edge_hash_map.get((src_idx, tgt_idx), set())) + + def get_incident_relation_hashes(self, node: str, limit: Optional[int] = None) -> List[str]: + """获取与指定节点关联的关系哈希列表(入边 + 出边)。""" + canon = self._canonicalize(node) + if canon not in self._node_to_idx or not self._edge_hash_map: + return [] + + idx = self._node_to_idx[canon] + limit_val = max(1, int(limit)) if limit is not None else None + collected: List[Tuple[str, str, str]] = [] + idx_to_node = self._nodes + + for (src_idx, tgt_idx), hashes in self._edge_hash_map.items(): + if idx not in {src_idx, tgt_idx}: + continue + src_name = idx_to_node[src_idx] if 0 <= src_idx < len(idx_to_node) else "" + tgt_name = idx_to_node[tgt_idx] if 0 <= tgt_idx < len(idx_to_node) else "" + for hash_value in hashes: + hash_text = str(hash_value).strip() + if not hash_text: + continue + collected.append((src_name, tgt_name, hash_text)) + + collected.sort(key=lambda x: (x[0].lower(), x[1].lower(), x[2])) + out: List[str] = [] + seen = set() + for _, _, hash_value in collected: + if hash_value in seen: + continue + seen.add(hash_value) + out.append(hash_value) + if limit_val is not None and len(out) >= limit_val: + break + return out + + def edge_contains_relation_hash(self, source: str, target: str, hash_value: str) -> bool: + """判断边是否包含指定关系哈希。""" + if not hash_value: + return False + return str(hash_value) in self.get_relation_hashes_for_edge(source, target) + + def iter_edge_hash_entries(self) -> List[Tuple[str, str, Set[str]]]: + """以节点名形式遍历 edge-hash-map。""" + out: List[Tuple[str, str, Set[str]]] = [] + if not self._edge_hash_map: + return out + idx_to_node = self._nodes + for (s_idx, t_idx), hashes in self._edge_hash_map.items(): + if not hashes: + continue + if s_idx < 0 or t_idx < 0: + continue + if s_idx >= len(idx_to_node) or t_idx >= len(idx_to_node): + continue + out.append((idx_to_node[s_idx], idx_to_node[t_idx], set(hashes))) + return out + + def deactivate_edges(self, edges: List[Tuple[str, str]]) -> int: + """ + 冻结边 (将权重设为0.0,使其在计算意义上消失,但保留在Map中) + + Args: + edges: [(s1, t1), (s2, t2)...] + """ + if not edges or self._adjacency is None: + return 0 + + deactivated_count = 0 + with self.batch_update(): + # 我们需要 explicit set to 0. 
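+            # 此处不触碰 _edge_hash_map,冻结后的关系哈希仍保留,便于后续恢复或迁移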
+ # 使用增量更新模式覆盖 + for s, t in edges: + s_canon = self._canonicalize(s) + t_canon = self._canonicalize(t) + if s_canon in self._node_to_idx and t_canon in self._node_to_idx: + idx_s = self._node_to_idx[s_canon] + idx_t = self._node_to_idx[t_canon] + self._adjacency[idx_s, idx_t] = 0.0 + deactivated_count += 1 + + self._adjacency_dirty = True + return deactivated_count + + def _ensure_adjacency_T(self): + """确保转置邻接矩阵是最新的""" + if self._adjacency is None: + self._adjacency_T = None + return + + if self._adjacency_dirty or self._adjacency_T is None: + # 只有在确实需要时才计算转置 + # find_paths 以“按行读取邻居”为主,因此统一缓存为 CSR,避免 + # CSR->CSC 转置后按行切片读出错误的索引视图。 + self._adjacency_T = self._adjacency.transpose().tocsr() + + self._adjacency_dirty = False + # logger.debug("重建转置邻接矩阵缓存") + + @staticmethod + def _row_neighbor_indices( + matrix: Optional[Union[csr_matrix, csc_matrix]], + row_idx: int, + ) -> np.ndarray: + """返回指定行的非零列索引。""" + if matrix is None: + return np.asarray([], dtype=np.int32) + + if isinstance(matrix, csr_matrix): + return matrix.indices[matrix.indptr[row_idx]:matrix.indptr[row_idx + 1]] + + row = matrix[row_idx, :] + _, indices = row.nonzero() + return np.asarray(indices, dtype=np.int32) + + def find_paths( + self, + start_node: str, + end_node: str, + max_depth: int = 3, + max_paths: int = 5, + max_expansions: int = 20000 + ) -> List[List[str]]: + """ + 查找两个节点之间的路径 (BFS) + 支持有向和无向 (视作双向) 探索 + + Args: + start_node: 起始节点 + end_node: 目标节点 + max_depth: 最大深度 + max_paths: 最大路径数 (找到这么多就停止) + max_expansions: 最大扩展次数 (防止爆炸) + + Returns: + 路径列表 [[n1, n2, n3], ...] + """ + start_canon = self._canonicalize(start_node) + end_canon = self._canonicalize(end_node) + + if start_canon not in self._node_to_idx or end_canon not in self._node_to_idx: + return [] + + if self._adjacency is None: + return [] + + # 确保转置矩阵可用 (用于查找入边) + self._ensure_adjacency_T() + + start_idx = self._node_to_idx[start_canon] + end_idx = self._node_to_idx[end_canon] + + # 队列: (current_idx, path_indices) + queue = [(start_idx, [start_idx])] + found_paths = [] + expansions = 0 + + unique_paths = set() + + while queue: + curr, path = queue.pop(0) + + if len(path) > max_depth + 1: + continue + + if curr == end_idx: + # 找到路径 + # 转换回节点名 + path_names = [self._nodes[i] for i in path] + path_tuple = tuple(path_names) + if path_tuple not in unique_paths: + found_paths.append(path_names) + unique_paths.add(path_tuple) + + if len(found_paths) >= max_paths: + break + continue + + if expansions >= max_expansions: + break + + expansions += 1 + + # 获取邻居 (出边 + 入边) + out_indices = self._row_neighbor_indices(self._adjacency, curr) + + # 2. 
入边 (使用转置矩阵) + if self._adjacency_T is not None: + in_indices = self._row_neighbor_indices(self._adjacency_T, curr) + neighbors = np.concatenate((out_indices, in_indices)) + else: + neighbors = out_indices + + # 去重并过滤已在路径中的节点 (防止环) + # 注意: 这里简单去重,可能包含重复的邻居(如果既是出又是入) + seen_in_path = set(path) + queued_neighbors = set() + + for neighbor_idx in neighbors: + neighbor = int(neighbor_idx) + if neighbor not in seen_in_path and neighbor not in queued_neighbors: + # 只有未访问过的才加入 + queue.append((neighbor, path + [neighbor])) + queued_neighbors.add(neighbor) + + return found_paths + + def compute_pagerank( + self, + personalization: Optional[Dict[str, float]] = None, + alpha: float = 0.85, + max_iter: int = 100, + tol: float = 1e-6, + ) -> Dict[str, float]: + """ + 计算Personalized PageRank + + Args: + personalization: 个性化向量 {node: weight},默认为均匀分布 + alpha: 阻尼系数(0-1之间) + max_iter: 最大迭代次数 + tol: 收敛阈值 + + Returns: + 节点PageRank值字典 {node: score} + """ + if self._adjacency is None or len(self._nodes) == 0: + logger.warning("图为空,无法计算PageRank") + return {} + + n = len(self._nodes) + + # 构建列归一化的转移矩阵 + adj = self._adjacency.astype(np.float32) + + # 计算出度 + out_degrees = np.array(adj.sum(axis=1)).flatten() + + # 处理悬挂节点(出度为0) + dangling = (out_degrees == 0) + out_degrees_inv = np.zeros_like(out_degrees) + out_degrees_inv[~dangling] = 1.0 / out_degrees[~dangling] + + # 归一化 (使用稀疏对角阵避免内存溢出) + from scipy.sparse import diags + D_inv = diags(out_degrees_inv) + M = adj.T @ D_inv # 转移矩阵 + + # 初始化个性化向量 + if personalization is None: + # 均匀分布 + p = np.ones(n) / n + else: + # 使用指定的个性化向量 + p = np.zeros(n) + total_weight = sum(personalization.values()) + for node, weight in personalization.items(): + if node in self._node_to_idx: + idx = self._node_to_idx[node] + p[idx] = weight / total_weight + + # 确保和为1 + if p.sum() == 0: + p = np.ones(n) / n + else: + p = p / p.sum() + + # 幂迭代法 + p_orig = p.copy() + for i in range(max_iter): + # p_new = alpha * M * p + (1-alpha) * personalization + p_new = alpha * (M @ p) + (1 - alpha) * p_orig + + # 处理因为悬挂节点导致的概率流失 + current_sum = p_new.sum() + if current_sum < 1.0: + p_new += (1.0 - current_sum) * p_orig + + # 检查收敛 + diff = np.linalg.norm(p_new - p, 1) + if diff < tol: + logger.debug(f"PageRank在 {i+1} 次迭代后收敛") + p = p_new + break + p = p_new + else: + logger.warning(f"PageRank未在 {max_iter} 次迭代内收敛") + + # 转换为真实节点名称字典 + return {self._nodes[idx]: float(val) for idx, val in enumerate(p)} + + def get_saliency_scores(self) -> Dict[str, float]: + """ + 获取节点显著性得分 (带有缓存机制) + """ + if self._saliency_cache is not None and not self._adjacency_dirty: + return self._saliency_cache + + logger.debug("正在计算节点显著性得分 (PageRank)...") + scores = self.compute_pagerank() + self._saliency_cache = scores + # 注意:这里我们不把 _adjacency_dirty 设为 False,因为其它逻辑(如_adjacency_T)也依赖它 + return scores + + def connect_synonyms( + self, + similarity_matrix: np.ndarray, + node_list: List[str], + threshold: float = 0.85, + ) -> int: + """ + 连接相似节点(同义词) + + Args: + similarity_matrix: 相似度矩阵 (N x N) + node_list: 对应的节点列表(长度为N) + threshold: 相似度阈值 + + Returns: + 添加的边数量 + """ + if len(node_list) != similarity_matrix.shape[0]: + raise ValueError( + f"节点列表长度与相似度矩阵维度不匹配: " + f"{len(node_list)} vs {similarity_matrix.shape[0]}" + ) + + # 找到相似的节点对(上三角,排除对角线) + similar_pairs = np.argwhere( + (triu(similarity_matrix, k=1) >= threshold) & + (triu(similarity_matrix, k=1) < 1.0) # 排除完全相同的 + ) + + # 添加边 + edges = [] + for i, j in similar_pairs: + if i < len(node_list) and j < len(node_list): + src = node_list[i] + tgt = node_list[j] + # 使用相似度作为权重 + weight = 
float(similarity_matrix[i, j]) + edges.append((src, tgt, weight)) + + if edges: + edge_pairs = [(src, tgt) for src, tgt, _ in edges] + weights = [w for _, _, w in edges] + count = self.add_edges(edge_pairs, weights) + logger.info(f"连接 {count} 对相似节点(阈值={threshold})") + return count + return 0 + + + # ========================================================================= + # V5 Memory System Methods (Graph Level) + # ========================================================================= + + def decay(self, factor: float, min_active_weight: float = 0.0) -> None: + """ + 全图衰减 (Atomic Decay) + + Args: + factor: 衰减因子 (0.0 < factor < 1.0) + min_active_weight: 最小活跃权重 (低于此值可能被视为无效,但在物理修剪前仍保留) + """ + if self._adjacency is None or factor >= 1.0 or factor <= 0.0: + return + + logger.debug(f"正在执行全图衰减,因子: {factor}") + + # 直接矩阵乘法,SciPy CSR/CSC 非常高效 + self._adjacency *= factor + + # 如果需要处理极小值 (可选,防止下溢,但通常浮点数足够小) + # if min_active_weight > 0: + # ... (复杂操作,暂不需要,由 prune 逻辑处理) + + self._adjacency_dirty = True + + def prune_relation_hashes(self, operations: List[Tuple[str, str, str]]) -> None: + """ + 修剪特定关系哈希 (从 _edge_hash_map 移除; 如果边变空则从矩阵移除) + + Args: + operations: List[(src, tgt, relation_hash)] + """ + if not operations: + return + + edges_to_check_removal = set() + + # 1. 更新映射 (Update Map) + for src, tgt, h in operations: + src_canon = self._canonicalize(src) + tgt_canon = self._canonicalize(tgt) + if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: + s_idx = self._node_to_idx[src_canon] + t_idx = self._node_to_idx[tgt_canon] + + key = (s_idx, t_idx) + if key in self._edge_hash_map: + if h in self._edge_hash_map[key]: + self._edge_hash_map[key].remove(h) + + if not self._edge_hash_map[key]: + del self._edge_hash_map[key] + edges_to_check_removal.add((src, tgt)) + + # 2. 
从矩阵中移除空边 (Remove Empty Edges from Matrix) + if edges_to_check_removal: + self.deactivate_edges(list(edges_to_check_removal)) + self._total_edges_deleted += len(edges_to_check_removal) + + def get_low_weight_edges(self, threshold: float) -> List[Tuple[str, str]]: + """ + 获取低于阈值的边 (candidates for pruning/freezing) + + Args: + threshold: 权重阈值 + + Returns: + List[(src, tgt)]: 边列表 + """ + if self._adjacency is None: + return [] + + # 获取所有非零元素 + rows, cols = self._adjacency.nonzero() + data = self._adjacency.data + + low_weight_indices = np.where(data < threshold)[0] + + results = [] + for idx in low_weight_indices: + r = rows[idx] + c = cols[idx] + src = self._nodes[r] + tgt = self._nodes[c] + results.append((src, tgt)) + + return results + + def get_isolated_nodes(self, include_inactive: bool = True) -> List[str]: + """ + 获取孤儿节点 (Active Degree = 0) + + Args: + include_inactive: 是否包含参与了inactive边(冻结边)的节点。 + 如果 True (默认推荐): 排除掉虽然active degree=0但存在于_edge_hash_map(冻结边)中的节点。 + 如果 False: 只要在 active matrix 里度为0就返回 (可能会误删冻结节点)。 + + Returns: + 孤儿节点名称列表 + """ + if self._adjacency is None: + # 如果全空,则所有节点都是孤儿 + return self._nodes.copy() + + n = len(self._nodes) + + # 计算 Active Degree (In + Out) + # 用 sum(axis) 会得到 dense matrix/array + active_adj = self._adjacency + out_degrees = np.array(active_adj.sum(axis=1)).flatten() + in_degrees = np.array(active_adj.sum(axis=0)).flatten() + + # 处理悬挂节点 (dangling node check not really needed here, just sum) + total_degrees = out_degrees + in_degrees + + # 找到 active degree = 0 的索引 + isolated_indices = np.where(total_degrees == 0)[0] + + if len(isolated_indices) == 0: + return [] + + isolated_nodes_set = {self._nodes[i] for i in isolated_indices} + + # 如果需要排除 Inactive 参与者 + if include_inactive and self._edge_hash_map: + # 收集所有在冻结边中的 unique 节点索引 + frozen_participant_indices = set() + for (u_idx, v_idx), hashes in self._edge_hash_map.items(): + if hashes: # 只要有 hash 存在 (哪怕 inactive) + frozen_participant_indices.add(u_idx) + frozen_participant_indices.add(v_idx) + + # 过滤 + final_isolated = [] + for idx in isolated_indices: + if idx not in frozen_participant_indices: + final_isolated.append(self._nodes[idx]) + return final_isolated + + else: + return list(isolated_nodes_set) + + def clear(self) -> None: + """清空所有数据""" + self._nodes.clear() + self._node_to_idx.clear() + self._node_attrs.clear() + self._adjacency = None + self._edge_hash_map.clear() + self._adjacency_T = None + self._adjacency_dirty = True + self._total_nodes_added = 0 + self._total_edges_added = 0 + self._total_nodes_deleted = 0 + self._total_edges_deleted = 0 + logger.info("图存储已清空") + + def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: + """ + 保存到磁盘 + + Args: + data_dir: 数据目录(默认使用初始化时的目录) + """ + if data_dir is None: + data_dir = self.data_dir + + if data_dir is None: + raise ValueError("未指定数据目录") + + data_dir = Path(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + + # 保存邻接矩阵 + if self._adjacency is not None: + matrix_path = data_dir / "graph_adjacency.npz" + with atomic_write(matrix_path, "wb") as f: + save_npz(f, self._adjacency) + logger.debug(f"保存邻接矩阵: {matrix_path}") + + # 保存元数据 + metadata = { + "nodes": self._nodes, + "node_to_idx": self._node_to_idx, + "node_attrs": self._node_attrs, + "matrix_format": self.matrix_format, + "total_nodes_added": self._total_nodes_added, + "total_edges_added": self._total_edges_added, + "total_nodes_deleted": self._total_nodes_deleted, + "total_edges_deleted": self._total_edges_deleted, + "edge_hash_map": dict(self._edge_hash_map), # 持久化 V5 
映射 (将 defaultdict 转换为普通 dict) + } + + metadata_path = data_dir / "graph_metadata.pkl" + with atomic_write(metadata_path, "wb") as f: + pickle.dump(metadata, f) + logger.debug(f"保存元数据: {metadata_path}") + + logger.info(f"图存储已保存到: {data_dir}") + + def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: + """ + 从磁盘加载 + + Args: + data_dir: 数据目录(默认使用初始化时的目录) + """ + if data_dir is None: + data_dir = self.data_dir + + if data_dir is None: + raise ValueError("未指定数据目录") + + data_dir = Path(data_dir) + if not data_dir.exists(): + raise FileNotFoundError(f"数据目录不存在: {data_dir}") + + # 加载元数据 + metadata_path = data_dir / "graph_metadata.pkl" + if not metadata_path.exists(): + raise FileNotFoundError(f"元数据文件不存在: {metadata_path}") + + with open(metadata_path, "rb") as f: + metadata = pickle.load(f) + + # 恢复状态,并通过规范化处理旧数据中的重复项 + self._nodes = metadata["nodes"] + self._node_attrs = {} # 重新构建以确保键名 (Key) 规范化 + self._node_to_idx = {} # 重新构建以确保键名 (Key) 规范化 + + # 重新构建映射,处理旧数据中的碰撞 + for idx, node_name in enumerate(self._nodes): + canon = self._canonicalize(node_name) + if canon not in self._node_to_idx: + self._node_to_idx[canon] = idx + + # 处理属性 (优先保留已有的) + orig_attrs = metadata.get("node_attrs", {}) + if node_name in orig_attrs and canon not in self._node_attrs: + self._node_attrs[canon] = orig_attrs[node_name] + + self.matrix_format = metadata["matrix_format"] + self._total_nodes_added = metadata["total_nodes_added"] + self._total_edges_added = metadata["total_edges_added"] + self._total_nodes_deleted = metadata["total_nodes_deleted"] + self._total_edges_deleted = metadata["total_edges_deleted"] + + # 恢复 V5 边哈希映射 (Restore V5 edge hash map) + edge_map_data = metadata.get("edge_hash_map", {}) + # 重新初始化为 defaultdict(set) + self._edge_hash_map = defaultdict(set) + if edge_map_data: + for k, v in edge_map_data.items(): + self._edge_hash_map[k] = set(v) # 确保类型为 set + + # 加载邻接矩阵 + matrix_path = data_dir / "graph_adjacency.npz" + if matrix_path.exists(): + self._adjacency = load_npz(str(matrix_path)) + + # 确保格式正确 + if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): + self._adjacency = self._adjacency.tocsc() + elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): + self._adjacency = self._adjacency.tocsr() + + logger.debug(f"加载邻接矩阵: {matrix_path}, shape={self._adjacency.shape}") + + # 检查维度不匹配并修复 + if self._adjacency is not None: + adj_n = self._adjacency.shape[0] + current_n = len(self._nodes) + if current_n > adj_n: + logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 
正在自动修复...") + self._expand_adjacency_matrix(current_n - adj_n) + + self._adjacency_dirty = True + logger.info( + f"图存储已加载: {len(self._nodes)} 个节点, " + f"{self._adjacency.nnz if self._adjacency is not None else 0} 条边" + ) + + def _expand_adjacency_matrix(self, added_nodes: int) -> None: + """ + 扩展邻接矩阵以容纳新节点 + + Args: + added_nodes: 新增节点数量 + """ + if self._adjacency is None: + n = len(self._nodes) + # 根据模式初始化 + + if self._modification_mode == GraphModificationMode.INCREMENTAL: + self._adjacency = lil_matrix((n, n), dtype=np.float32) + else: + self._adjacency = csr_matrix((n, n), dtype=np.float32) + return + + old_n = self._adjacency.shape[0] + new_n = old_n + added_nodes + + # 优化:根据模式选择不同的扩容策略 + if self._modification_mode == GraphModificationMode.INCREMENTAL: + # LIL 格式可以直接 resize,非常高效 + try: + if not isinstance(self._adjacency, lil_matrix): + self._adjacency = self._adjacency.tolil() + + self._adjacency.resize((new_n, new_n)) + # logger.debug(f"扩展 LIL 矩阵: {old_n} -> {new_n}") + except Exception as e: + logger.warning(f"LIL resize 失败,回退到通用方法: {e}") + self._expand_generic(new_n, old_n) + + else: + # CSR/CSC 格式使用 bmat 避免结构破坏警告 + try: + # bmat 需要明确的形状,不能全部依赖 None + added = new_n - old_n + # 创建零矩阵块 + # 注意: 这里统一创建 CSR 零矩阵,bmat 会处理合并 + z_tr = csr_matrix((old_n, added), dtype=np.float32) + z_bl = csr_matrix((added, old_n), dtype=np.float32) + z_br = csr_matrix((added, added), dtype=np.float32) + + self._adjacency = bmat( + [[self._adjacency, z_tr], [z_bl, z_br]], + format=self.matrix_format, + dtype=np.float32 + ) + # logger.debug(f"扩展矩阵 (bmat): {old_n} -> {new_n}") + except Exception as e: + logger.warning(f"bmat 扩展失败: {e}") + self._expand_generic(new_n, old_n) + + def _expand_generic(self, new_n: int, old_n: int): + """通用扩展方法(回退方案)""" + if self.matrix_format == "csr": + new_adjacency = csr_matrix((new_n, new_n), dtype=np.float32) + new_adjacency[:old_n, :old_n] = self._adjacency + else: + new_adjacency = csc_matrix((new_n, new_n), dtype=np.float32) + new_adjacency[:old_n, :old_n] = self._adjacency + self._adjacency = new_adjacency + self._adjacency_dirty = True + + # 如果都在增量模式,确保是LIL + if self._modification_mode == GraphModificationMode.INCREMENTAL: + try: + self._adjacency = self._adjacency.tolil() + except: + pass + + @property + def num_nodes(self) -> int: + """节点数量""" + return len(self._nodes) + + @property + def num_edges(self) -> int: + """边数量""" + if self._adjacency is None: + return 0 + return int(self._adjacency.nnz) + + @property + def density(self) -> float: + """ + 图密度(实际边数 / 可能的最大边数) + + 有向图: E / (V * (V - 1)) + 无向图: 2E / (V * (V - 1)) + + 这里按有向图计算 + """ + if self.num_nodes < 2: + return 0.0 + + max_edges = self.num_nodes * (self.num_nodes - 1) + return self.num_edges / max_edges if max_edges > 0 else 0.0 + + def __len__(self) -> int: + """节点数量""" + return self.num_nodes + + def has_data(self) -> bool: + """检查磁盘上是否存在现有数据""" + if self.data_dir is None: + return False + return (self.data_dir / "graph_metadata.pkl").exists() + + def __repr__(self) -> str: + return ( + f"GraphStore(nodes={self.num_nodes}, edges={self.num_edges}, " + f"density={self.density:.4f}, format={self.matrix_format})" + ) + + def rebuild_edge_hash_map(self, triples: List[Tuple[str, str, str, str]]) -> int: + """ + 从元数据重建 V5 边哈希映射 (Migration Tool) + + Args: + triples: List of (s, p, o, hash) + + Returns: + count of mapped hashes + """ + count = 0 + self._edge_hash_map = defaultdict(set) + + for s, p, o, h in triples: + if not h: continue + + s_canon = self._canonicalize(s) + o_canon = self._canonicalize(o) + + if 
s_canon in self._node_to_idx and o_canon in self._node_to_idx: + u = self._node_to_idx[s_canon] + v = self._node_to_idx[o_canon] + + # 如果是双向的,通常在元数据中存储为有向,而 GraphStore 也通常是有向的。 + # 映射键对应特定的边方向。 + self._edge_hash_map[(u, v)].add(h) + count += 1 + + self._adjacency_dirty = True + logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边") + return count + diff --git a/plugins/A_memorix/core/storage/knowledge_types.py b/plugins/A_memorix/core/storage/knowledge_types.py new file mode 100644 index 00000000..4ab91218 --- /dev/null +++ b/plugins/A_memorix/core/storage/knowledge_types.py @@ -0,0 +1,183 @@ +"""Knowledge type and import strategy helpers.""" + +from __future__ import annotations + +from enum import Enum +from typing import Any, Optional + + +class KnowledgeType(str, Enum): + """持久化到 paragraphs.knowledge_type 的合法类型。""" + + STRUCTURED = "structured" + NARRATIVE = "narrative" + FACTUAL = "factual" + QUOTE = "quote" + MIXED = "mixed" + + +class ImportStrategy(str, Enum): + """文本导入阶段的策略选择。""" + + AUTO = "auto" + NARRATIVE = "narrative" + FACTUAL = "factual" + QUOTE = "quote" + + +def allowed_knowledge_type_values() -> tuple[str, ...]: + return tuple(item.value for item in KnowledgeType) + + +def allowed_import_strategy_values() -> tuple[str, ...]: + return tuple(item.value for item in ImportStrategy) + + +def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]: + """从字符串解析合法的落库知识类型。""" + + if not isinstance(type_str, str): + return None + normalized = type_str.lower().strip() + try: + return KnowledgeType(normalized) + except ValueError: + return None + + +def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]: + """从字符串解析文本导入策略。""" + + if not isinstance(value, str): + return None + normalized = value.lower().strip() + try: + return ImportStrategy(normalized) + except ValueError: + return None + + +def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy: + """解析 import strategy;非法值直接报错。""" + + if value is None: + return default + if isinstance(value, ImportStrategy): + return value + + normalized = str(value or "").strip().lower() + if not normalized: + return default + + strategy = get_import_strategy_from_string(normalized) + if strategy is None: + allowed = "/".join(allowed_import_strategy_values()) + raise ValueError(f"strategy_override 必须为 {allowed}") + return strategy + + +def validate_stored_knowledge_type(value: Any) -> KnowledgeType: + """校验写库 knowledge_type,仅允许合法落库类型。""" + + if isinstance(value, KnowledgeType): + return value + + resolved = get_knowledge_type_from_string(value) + if resolved is None: + allowed = "/".join(allowed_knowledge_type_values()) + raise ValueError(f"knowledge_type 必须为 {allowed}") + return resolved + + +def resolve_stored_knowledge_type( + value: Any, + *, + content: str = "", + allow_legacy: bool = False, + unknown_fallback: Optional[KnowledgeType] = None, +) -> KnowledgeType: + """ + 将策略/字符串/旧值解析为合法落库类型。 + + `allow_legacy=True` 仅供迁移使用。 + """ + + if isinstance(value, KnowledgeType): + return value + + if isinstance(value, ImportStrategy): + if value == ImportStrategy.AUTO: + if not str(content or "").strip(): + raise ValueError("knowledge_type=auto 需要 content 才能推断") + from .type_detection import detect_knowledge_type + + return detect_knowledge_type(content) + return KnowledgeType(value.value) + + raw = str(value or "").strip() + if not raw: + if str(content or "").strip(): + from .type_detection import detect_knowledge_type + + return 
detect_knowledge_type(content) + raise ValueError("knowledge_type 不能为空") + + direct = get_knowledge_type_from_string(raw) + if direct is not None: + return direct + + strategy = get_import_strategy_from_string(raw) + if strategy is not None: + return resolve_stored_knowledge_type(strategy, content=content) + + if allow_legacy: + normalized = raw.lower() + if normalized == "imported": + return KnowledgeType.FACTUAL + if str(content or "").strip(): + from .type_detection import detect_knowledge_type + + detected = detect_knowledge_type(content) + if detected is not None: + return detected + if unknown_fallback is not None: + return unknown_fallback + + allowed = "/".join(allowed_knowledge_type_values()) + raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed})") + + +def should_extract_relations(knowledge_type: KnowledgeType) -> bool: + """判断是否应该做关系抽取。""" + + return knowledge_type in [ + KnowledgeType.STRUCTURED, + KnowledgeType.FACTUAL, + KnowledgeType.MIXED, + ] + + +def get_default_chunk_size(knowledge_type: KnowledgeType) -> int: + """获取默认分块大小。""" + + chunk_sizes = { + KnowledgeType.STRUCTURED: 300, + KnowledgeType.NARRATIVE: 800, + KnowledgeType.FACTUAL: 500, + KnowledgeType.QUOTE: 400, + KnowledgeType.MIXED: 500, + } + return chunk_sizes.get(knowledge_type, 500) + + +def get_type_display_name(knowledge_type: KnowledgeType) -> str: + """获取知识类型中文名称。""" + + display_names = { + KnowledgeType.STRUCTURED: "结构化知识", + KnowledgeType.NARRATIVE: "叙事性文本", + KnowledgeType.FACTUAL: "事实陈述", + KnowledgeType.QUOTE: "引用文本", + KnowledgeType.MIXED: "混合类型", + } + return display_names.get(knowledge_type, "未知类型") diff --git a/plugins/A_memorix/core/storage/metadata_store.py b/plugins/A_memorix/core/storage/metadata_store.py new file mode 100644 index 00000000..e94610f0 --- /dev/null +++ b/plugins/A_memorix/core/storage/metadata_store.py @@ -0,0 +1,5225 @@ +""" +元数据存储模块 + +基于SQLite的元数据管理,存储段落、实体、关系等信息。 +""" + +import sqlite3 +import pickle +import json +from datetime import datetime +from pathlib import Path +from typing import Optional, Union, List, Dict, Any, Tuple + +from src.common.logger import get_logger +from ..utils.hash import compute_hash, normalize_text +from ..utils.time_parser import normalize_time_meta +from .knowledge_types import ( + KnowledgeType, + allowed_knowledge_type_values, + resolve_stored_knowledge_type, + validate_stored_knowledge_type, +) + +logger = get_logger("A_Memorix.MetadataStore") + + +SCHEMA_VERSION = 7 + + +class MetadataStore: + """ + 元数据存储类 + + 功能: + - SQLite数据库管理 + - 段落/实体/关系元数据存储 + - 增删改查操作 + - 事务支持 + - 索引优化 + + 参数: + data_dir: 数据目录 + db_name: 数据库文件名(默认metadata.db) + """ + + def __init__( + self, + data_dir: Optional[Union[str, Path]] = None, + db_name: str = "metadata.db", + ): + """ + 初始化元数据存储 + + Args: + data_dir: 数据目录 + db_name: 数据库文件名 + """ + self.data_dir = Path(data_dir) if data_dir else None + self.db_name = db_name + self._conn: Optional[sqlite3.Connection] = None + self._is_initialized = False + self._db_path: Optional[Path] = None + + logger.info(f"MetadataStore 初始化: db={db_name}") + + def connect( + self, + data_dir: Optional[Union[str, Path]] = None, + *, + enforce_schema: bool = True, + ) -> None: + """ + 连接到数据库 + + Args: + data_dir: 数据目录(默认使用初始化时的目录) + """ + if data_dir is None: + data_dir = self.data_dir + + if data_dir is None: + raise ValueError("未指定数据目录") + + data_dir = Path(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + + db_path = data_dir / self.db_name + db_existed = db_path.exists() + self._db_path = db_path + + # 连接数据库 + self._conn 
= sqlite3.connect( + str(db_path), + check_same_thread=False, + timeout=30.0, + ) + self._conn.row_factory = sqlite3.Row # 使用字典式访问 + + # 优化性能 + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA cache_size=-64000") # 64MB缓存 + self._conn.execute("PRAGMA temp_store=MEMORY") + self._conn.execute("PRAGMA foreign_keys = ON") # 开启外键约束支持级联删除 + + logger.info(f"连接到数据库: {db_path}") + + # 初始化或校验 schema + if not self._is_initialized: + if not db_existed: + self._initialize_tables() + if enforce_schema: + self._assert_schema_compatible(db_existed=db_existed) + self._is_initialized = True + + # 初始化 FTS schema(幂等) + try: + self.ensure_fts_schema() + except Exception as e: + logger.warning(f"初始化 FTS schema 失败,将跳过 BM25 检索: {e}") + + def _assert_schema_compatible(self, db_existed: bool) -> None: + """vNext 运行时只做 schema 版本校验,不做隐式迁移。""" + cursor = self._conn.cursor() + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" + ) + has_version_table = cursor.fetchone() is not None + if not has_version_table: + if db_existed: + raise RuntimeError( + "检测到旧版 metadata schema(缺少 schema_migrations)。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + return + + cursor.execute("SELECT MAX(version) FROM schema_migrations") + row = cursor.fetchone() + version = int(row[0]) if row and row[0] is not None else 0 + if version != SCHEMA_VERSION: + raise RuntimeError( + f"metadata schema 版本不匹配: current={version}, expected={SCHEMA_VERSION}。" + " 请执行 scripts/release_vnext_migrate.py migrate。" + ) + + def close(self) -> None: + """关闭数据库连接""" + if self._conn: + self._conn.close() + self._conn = None + logger.info("数据库连接已关闭") + + def _initialize_tables(self) -> None: + """初始化数据库表结构""" + cursor = self._conn.cursor() + + # 段落表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraphs ( + hash TEXT PRIMARY KEY, + content TEXT NOT NULL, + vector_index INTEGER, + created_at REAL, + updated_at REAL, + metadata TEXT, + source TEXT, + word_count INTEGER, + event_time REAL, + event_time_start REAL, + event_time_end REAL, + time_granularity TEXT, + time_confidence REAL DEFAULT 1.0, + knowledge_type TEXT DEFAULT 'mixed', + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_deleted INTEGER DEFAULT 0, + deleted_at REAL + ) + """) + + # 实体表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS entities ( + hash TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + vector_index INTEGER, + appearance_count INTEGER DEFAULT 1, + created_at REAL, + metadata TEXT, + is_deleted INTEGER DEFAULT 0, + deleted_at REAL + ) + """) + + # 关系表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relations ( + hash TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + vector_index INTEGER, + confidence REAL DEFAULT 1.0, + vector_state TEXT DEFAULT 'none', + vector_updated_at REAL, + vector_error TEXT, + vector_retry_count INTEGER DEFAULT 0, + created_at REAL, + source_paragraph TEXT, + metadata TEXT, + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_inactive BOOLEAN DEFAULT 0, + inactive_since REAL, + is_pinned BOOLEAN DEFAULT 0, + protected_until REAL, + last_reinforced REAL, + UNIQUE(subject, predicate, object) + ) + """) + + # 回收站关系表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS deleted_relations ( + hash TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + vector_index 
INTEGER, + confidence REAL DEFAULT 1.0, + vector_state TEXT DEFAULT 'none', + vector_updated_at REAL, + vector_error TEXT, + vector_retry_count INTEGER DEFAULT 0, + created_at REAL, + source_paragraph TEXT, + metadata TEXT, + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_inactive BOOLEAN DEFAULT 0, + inactive_since REAL, + is_pinned BOOLEAN DEFAULT 0, + protected_until REAL, + last_reinforced REAL, + deleted_at REAL + ) + """) + + # 32位哈希别名映射(用于 vNext 唯一解析) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relation_hash_aliases ( + alias32 TEXT PRIMARY KEY, + hash TEXT NOT NULL + ) + """) + + # Schema 版本 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at REAL NOT NULL + ) + """) + + # 三元组与段落的关联表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_relations ( + paragraph_hash TEXT NOT NULL, + relation_hash TEXT NOT NULL, + PRIMARY KEY (paragraph_hash, relation_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, + FOREIGN KEY (relation_hash) REFERENCES relations(hash) ON DELETE CASCADE + ) + """) + + # 实体与段落的关联表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_entities ( + paragraph_hash TEXT NOT NULL, + entity_hash TEXT NOT NULL, + mention_count INTEGER DEFAULT 1, + PRIMARY KEY (paragraph_hash, entity_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, + FOREIGN KEY (entity_hash) REFERENCES entities(hash) ON DELETE CASCADE + ) + """) + + # 创建索引 + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraphs_vector + ON paragraphs(vector_index) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_entities_vector + ON entities(vector_index) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_vector + ON relations(vector_index) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_subject + ON relations(subject) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_object + ON relations(object) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_entities_name + ON entities(name) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraphs_source + ON paragraphs(source) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraphs_deleted + ON paragraphs(is_deleted, deleted_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_entities_deleted + ON entities(is_deleted, deleted_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_inactive + ON relations(is_inactive, inactive_since) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_protected + ON relations(is_pinned, protected_until) + """) + + # 人物画像开关表(按 stream_id + user_id 维度) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_switches ( + stream_id TEXT NOT NULL, + user_id TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 0, + updated_at REAL NOT NULL, + PRIMARY KEY (stream_id, user_id) + ) + """) + + # 人物画像快照表(版本化) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_snapshots ( + snapshot_id INTEGER PRIMARY KEY AUTOINCREMENT, + person_id TEXT NOT NULL, + profile_version INTEGER NOT NULL, + profile_text TEXT NOT NULL, + aliases_json TEXT, + relation_edges_json TEXT, + vector_evidence_json TEXT, + evidence_ids_json TEXT, + updated_at REAL NOT NULL, + expires_at REAL, + source_note TEXT, + UNIQUE(person_id, profile_version) + ) + """) + + # 已开启范围内的活跃人物集合 + cursor.execute(""" + CREATE 
TABLE IF NOT EXISTS person_profile_active_persons ( + stream_id TEXT NOT NULL, + user_id TEXT NOT NULL, + person_id TEXT NOT NULL, + last_seen_at REAL NOT NULL, + PRIMARY KEY (stream_id, user_id, person_id) + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_overrides ( + person_id TEXT PRIMARY KEY, + override_text TEXT NOT NULL, + updated_at REAL NOT NULL, + updated_by TEXT, + source TEXT + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_switches_enabled + ON person_profile_switches(enabled) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_snapshots_person + ON person_profile_snapshots(person_id, updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_active_seen + ON person_profile_active_persons(last_seen_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_overrides_updated + ON person_profile_overrides(updated_at DESC) + """) + + # Episode 情景记忆表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + episode_id TEXT PRIMARY KEY, + source TEXT, + title TEXT NOT NULL, + summary TEXT NOT NULL, + event_time_start REAL, + event_time_end REAL, + time_granularity TEXT, + time_confidence REAL DEFAULT 1.0, + participants_json TEXT, + keywords_json TEXT, + evidence_ids_json TEXT, + paragraph_count INTEGER DEFAULT 0, + llm_confidence REAL DEFAULT 0.0, + segmentation_model TEXT, + segmentation_version TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + + # Episode -> Paragraph 映射 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_paragraphs ( + episode_id TEXT NOT NULL, + paragraph_hash TEXT NOT NULL, + position INTEGER DEFAULT 0, + PRIMARY KEY (episode_id, paragraph_hash), + FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE + ) + """) + + # Episode 生成队列(异步) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( + paragraph_hash TEXT PRIMARY KEY, + source TEXT, + created_at REAL, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( + source TEXT PRIMARY KEY, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + reason TEXT, + requested_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end + ON episodes(source, event_time_end DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_updated_at + ON episodes(updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph + ON episode_paragraphs(paragraph_hash) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated + ON episode_pending_paragraphs(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created + ON episode_pending_paragraphs(source, created_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated + ON episode_rebuild_sources(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at + ON episode_rebuild_sources(updated_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS external_memory_refs ( + external_id TEXT PRIMARY KEY, 
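+                    -- external_id comes from the external source system; each ref points to one local paragraph_hash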
+ paragraph_hash TEXT NOT NULL, + source_type TEXT, + created_at REAL NOT NULL, + metadata_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph + ON external_memory_refs(paragraph_hash) + """) + # 新版 schema 包含完整字段,直接写入版本信息 + cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp())) + self._conn.commit() + logger.debug("数据库表结构初始化完成") + + def _migrate_schema(self) -> None: + """执行数据库schema迁移""" + cursor = self._conn.cursor() + + # vNext 关键表兜底:历史库可能缺失,需在迁移阶段主动补齐。 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relation_hash_aliases ( + alias32 TEXT PRIMARY KEY, + hash TEXT NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at REAL NOT NULL + ) + """) + + # Episode MVP 表结构补齐 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + episode_id TEXT PRIMARY KEY, + source TEXT, + title TEXT NOT NULL, + summary TEXT NOT NULL, + event_time_start REAL, + event_time_end REAL, + time_granularity TEXT, + time_confidence REAL DEFAULT 1.0, + participants_json TEXT, + keywords_json TEXT, + evidence_ids_json TEXT, + paragraph_count INTEGER DEFAULT 0, + llm_confidence REAL DEFAULT 0.0, + segmentation_model TEXT, + segmentation_version TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_paragraphs ( + episode_id TEXT NOT NULL, + paragraph_hash TEXT NOT NULL, + position INTEGER DEFAULT 0, + PRIMARY KEY (episode_id, paragraph_hash), + FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( + paragraph_hash TEXT PRIMARY KEY, + source TEXT, + created_at REAL, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( + source TEXT PRIMARY KEY, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + reason TEXT, + requested_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end + ON episodes(source, event_time_end DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_updated_at + ON episodes(updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph + ON episode_paragraphs(paragraph_hash) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated + ON episode_pending_paragraphs(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created + ON episode_pending_paragraphs(source, created_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated + ON episode_rebuild_sources(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at + ON episode_rebuild_sources(updated_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS external_memory_refs ( + external_id TEXT PRIMARY KEY, + paragraph_hash TEXT NOT NULL, + source_type TEXT, + created_at REAL NOT NULL, + metadata_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS 
idx_external_memory_refs_paragraph + ON external_memory_refs(paragraph_hash) + """) + + # 检查paragraphs表是否有knowledge_type列 + cursor.execute("PRAGMA table_info(paragraphs)") + columns = [row[1] for row in cursor.fetchall()] + + if "knowledge_type" not in columns: + logger.info("检测到旧版schema,正在迁移添加knowledge_type字段...") + try: + cursor.execute(""" + ALTER TABLE paragraphs + ADD COLUMN knowledge_type TEXT DEFAULT 'mixed' + """) + self._conn.commit() + logger.info("Schema迁移完成:已添加knowledge_type字段") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败(可能已存在): {e}") + + # 问题2: 时序字段迁移 + cursor.execute("PRAGMA table_info(paragraphs)") + columns = [row[1] for row in cursor.fetchall()] + temporal_columns = { + "event_time": "ALTER TABLE paragraphs ADD COLUMN event_time REAL", + "event_time_start": "ALTER TABLE paragraphs ADD COLUMN event_time_start REAL", + "event_time_end": "ALTER TABLE paragraphs ADD COLUMN event_time_end REAL", + "time_granularity": "ALTER TABLE paragraphs ADD COLUMN time_granularity TEXT", + "time_confidence": "ALTER TABLE paragraphs ADD COLUMN time_confidence REAL DEFAULT 1.0", + } + for col, sql in temporal_columns.items(): + if col not in columns: + try: + cursor.execute(sql) + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败({col}): {e}") + + # 时序索引(仅在列存在时创建,兼容旧库迁移) + self._create_temporal_indexes_if_ready() + self._conn.commit() + + # 检查paragraphs表是否有is_permanent列 + cursor.execute("PRAGMA table_info(paragraphs)") + columns = [row[1] for row in cursor.fetchall()] + + if "is_permanent" not in columns: + logger.info("正在迁移: 添加记忆动态字段...") + try: + # 段落表 + cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_permanent BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE paragraphs ADD COLUMN last_accessed REAL") + cursor.execute("ALTER TABLE paragraphs ADD COLUMN access_count INTEGER DEFAULT 0") + + # 关系表 + cursor.execute("ALTER TABLE relations ADD COLUMN is_permanent BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE relations ADD COLUMN last_accessed REAL") + cursor.execute("ALTER TABLE relations ADD COLUMN access_count INTEGER DEFAULT 0") + + self._conn.commit() + logger.info("Schema迁移完成:已添加记忆动态字段") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败: {e}") + + # 检查relations表是否有is_inactive列 (V5 Memory System) + cursor.execute("PRAGMA table_info(relations)") + columns = [row[1] for row in cursor.fetchall()] + + if "is_inactive" not in columns: + logger.info("正在迁移: 添加V5记忆动态字段 (inactive, protected)...") + try: + # 关系表 V5 新增字段 + cursor.execute("ALTER TABLE relations ADD COLUMN is_inactive BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE relations ADD COLUMN inactive_since REAL") + cursor.execute("ALTER TABLE relations ADD COLUMN is_pinned BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE relations ADD COLUMN protected_until REAL") + cursor.execute("ALTER TABLE relations ADD COLUMN last_reinforced REAL") + + # 为回收站创建 deleted_relations 表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS deleted_relations ( + hash TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + vector_index INTEGER, + confidence REAL DEFAULT 1.0, + vector_state TEXT DEFAULT 'none', + vector_updated_at REAL, + vector_error TEXT, + vector_retry_count INTEGER DEFAULT 0, + created_at REAL, + source_paragraph TEXT, + metadata TEXT, + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_inactive BOOLEAN DEFAULT 0, + inactive_since REAL, + is_pinned BOOLEAN DEFAULT 0, + protected_until REAL, 
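+                        -- column set mirrors relations so a deleted row can be restored without loss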
+ last_reinforced REAL, + deleted_at REAL -- 用于记录删除时间的额外列 + ) + """) + + self._conn.commit() + logger.info("Schema迁移完成:已添加V5记忆动态字段及回收站表") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (V5): {e}") + + # 关系向量状态字段迁移 + cursor.execute("PRAGMA table_info(relations)") + relation_columns = {row[1] for row in cursor.fetchall()} + relation_vector_columns = { + "vector_state": "ALTER TABLE relations ADD COLUMN vector_state TEXT DEFAULT 'none'", + "vector_updated_at": "ALTER TABLE relations ADD COLUMN vector_updated_at REAL", + "vector_error": "ALTER TABLE relations ADD COLUMN vector_error TEXT", + "vector_retry_count": "ALTER TABLE relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", + } + for col, sql in relation_vector_columns.items(): + if col not in relation_columns: + try: + cursor.execute(sql) + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (relations.{col}): {e}") + + # 回收站同步字段迁移(用于 restore 保留向量状态) + cursor.execute("PRAGMA table_info(deleted_relations)") + deleted_relation_columns = {row[1] for row in cursor.fetchall()} + deleted_relation_vector_columns = { + "vector_state": "ALTER TABLE deleted_relations ADD COLUMN vector_state TEXT DEFAULT 'none'", + "vector_updated_at": "ALTER TABLE deleted_relations ADD COLUMN vector_updated_at REAL", + "vector_error": "ALTER TABLE deleted_relations ADD COLUMN vector_error TEXT", + "vector_retry_count": "ALTER TABLE deleted_relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", + } + for col, sql in deleted_relation_vector_columns.items(): + if col not in deleted_relation_columns: + try: + cursor.execute(sql) + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (deleted_relations.{col}): {e}") + + # 检查 entities 表是否有 is_deleted 列 (Soft Delete System) + cursor.execute("PRAGMA table_info(entities)") + columns = [row[1] for row in cursor.fetchall()] + + if "is_deleted" not in columns: + logger.info("正在迁移: 添加软删除字段 (Soft Delete)...") + try: + # 实体表 + cursor.execute("ALTER TABLE entities ADD COLUMN is_deleted INTEGER DEFAULT 0") + cursor.execute("ALTER TABLE entities ADD COLUMN deleted_at REAL") + + # 段落表 + cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_deleted INTEGER DEFAULT 0") + cursor.execute("ALTER TABLE paragraphs ADD COLUMN deleted_at REAL") + + self._conn.commit() + logger.info("Schema迁移完成:已添加软删除字段") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (Soft Delete): {e}") + + # 数据修复: 检查是否存在 source/vector_index 列错位的情况 + # 症状: vector_index (本应是int) 变成了文件名字符串, source (本应是文件名) 变成了类型字符串 + try: + cursor.execute(""" + SELECT count(*) FROM paragraphs + WHERE typeof(vector_index) = 'text' + AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') + """) + count = cursor.fetchone()[0] + if count > 0: + logger.warning(f"检测到 {count} 条数据存在列错位(文件名误存入vector_index),正在自动修复...") + cursor.execute(""" + UPDATE paragraphs + SET + knowledge_type = source, + source = vector_index, + vector_index = NULL + WHERE typeof(vector_index) = 'text' + AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') + """) + self._conn.commit() + logger.info(f"自动修复完成: 已校正 {cursor.rowcount} 条数据") + except Exception as e: + logger.error(f"数据自动修复失败: {e}") + + def _create_temporal_indexes_if_ready(self) -> None: + """ + 仅当时序列已存在时创建索引。 + + 旧库升级时,_initialize_tables 不能提前对不存在的列建索引; + 因此统一在迁移阶段按列存在性安全创建。 + """ + cursor = self._conn.cursor() + cursor.execute("PRAGMA table_info(paragraphs)") + columns = {row[1] for row in cursor.fetchall()} + + if "event_time" in columns: + 
cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_time ON paragraphs(event_time)" + ) + if "event_time_start" in columns: + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_start ON paragraphs(event_time_start)" + ) + if "event_time_end" in columns: + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_end ON paragraphs(event_time_end)" + ) + + def run_legacy_migration_for_vnext(self) -> Dict[str, Any]: + """ + 离线迁移入口: + - 复用旧迁移逻辑补齐历史库字段 + - 重建 relation 32位别名 + - 归一化历史 knowledge_type + - 写入 vNext schema 版本 + """ + self._migrate_schema() + alias_result = self.rebuild_relation_hash_aliases() + knowledge_type_result = self.normalize_paragraph_knowledge_types() + self.set_schema_version(SCHEMA_VERSION) + return { + "schema_version": SCHEMA_VERSION, + "alias_result": alias_result, + "knowledge_type_result": knowledge_type_result, + } + + def list_invalid_paragraph_knowledge_types(self) -> List[str]: + """列出当前库中不合法的段落 knowledge_type。""" + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT DISTINCT knowledge_type + FROM paragraphs + WHERE knowledge_type IS NULL + OR TRIM(COALESCE(knowledge_type, '')) = '' + OR LOWER(TRIM(knowledge_type)) NOT IN ({placeholders}) + ORDER BY knowledge_type + """.format(placeholders=", ".join("?" for _ in allowed_knowledge_type_values())), + tuple(allowed_knowledge_type_values()), + ) + invalid: List[str] = [] + for row in cursor.fetchall(): + raw = row[0] + invalid.append(str(raw) if raw is not None else "") + return invalid + + def normalize_paragraph_knowledge_types(self) -> Dict[str, Any]: + """将历史非法 knowledge_type 归一化为合法值。""" + + cursor = self._conn.cursor() + cursor.execute("SELECT hash, content, knowledge_type FROM paragraphs") + rows = cursor.fetchall() + + normalized_count = 0 + normalized_map: Dict[str, int] = {} + invalid_before: List[str] = [] + invalid_seen = set() + + for row in rows: + paragraph_hash = str(row["hash"]) + content = str(row["content"] or "") + raw_value = row["knowledge_type"] + try: + validate_stored_knowledge_type(raw_value) + continue + except ValueError: + raw_text = str(raw_value) if raw_value is not None else "" + if raw_text not in invalid_seen: + invalid_seen.add(raw_text) + invalid_before.append(raw_text) + + normalized_type = resolve_stored_knowledge_type( + raw_value, + content=content, + allow_legacy=True, + unknown_fallback=KnowledgeType.MIXED, + ) + cursor.execute( + "UPDATE paragraphs SET knowledge_type = ? 
WHERE hash = ?", + (normalized_type.value, paragraph_hash), + ) + normalized_count += 1 + normalized_map[normalized_type.value] = normalized_map.get(normalized_type.value, 0) + 1 + + self._conn.commit() + return { + "normalized": normalized_count, + "invalid_before": sorted(invalid_before), + "normalized_to": normalized_map, + } + + def _resolve_conn(self, conn: Optional[sqlite3.Connection] = None) -> sqlite3.Connection: + """解析可用连接。""" + resolved = conn or self._conn + if resolved is None: + raise RuntimeError("MetadataStore 未连接数据库") + return resolved + + def get_db_path(self) -> Path: + """获取 SQLite 数据库文件路径。""" + if self._db_path is not None: + return self._db_path + if self.data_dir is None: + raise RuntimeError("MetadataStore 未配置 data_dir") + return Path(self.data_dir) / self.db_name + + def ensure_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """ + 确保 FTS5 schema 存在(幂等)。 + + 采用 external-content 方式,不在 FTS 表重复存储正文。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS paragraphs_fts + USING fts5( + content, + content='paragraphs', + content_rowid='rowid', + tokenize='unicode61' + ) + """) + + # insert trigger + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS paragraphs_ai + AFTER INSERT ON paragraphs + BEGIN + INSERT INTO paragraphs_fts(rowid, content) + VALUES (new.rowid, new.content); + END + """) + + # delete trigger + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS paragraphs_ad + AFTER DELETE ON paragraphs + BEGIN + INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) + VALUES ('delete', old.rowid, old.content); + END + """) + + # update trigger + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS paragraphs_au + AFTER UPDATE OF content ON paragraphs + BEGIN + INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) + VALUES ('delete', old.rowid, old.content); + INSERT INTO paragraphs_fts(rowid, content) + VALUES (new.rowid, new.content); + END + """) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS5 schema 创建失败(可能不支持 FTS5): {e}") + c.rollback() + return False + + def ensure_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """ + 确保 FTS 索引已回填。 + + 当历史数据存在但 FTS 表为空/不一致时执行 rebuild。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute("SELECT COUNT(1) AS n FROM paragraphs") + para_count = int(cur.fetchone()[0]) + cur.execute("SELECT COUNT(1) AS n FROM paragraphs_fts") + fts_count = int(cur.fetchone()[0]) + + if para_count > 0 and fts_count != para_count: + cur.execute("INSERT INTO paragraphs_fts(paragraphs_fts) VALUES ('rebuild')") + c.commit() + logger.info(f"FTS 回填完成: paragraphs={para_count}, fts={para_count}") + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS 回填失败: {e}") + c.rollback() + return False + + def ensure_relations_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """ + 确保关系 FTS5 schema 存在(幂等)。 + + 注意:relations 表没有 content 列,因此使用独立 FTS 表并通过触发器同步。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS relations_fts + USING fts5( + relation_hash UNINDEXED, + content, + tokenize='unicode61' + ) + """) + + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS relations_ai + AFTER INSERT ON relations + BEGIN + INSERT INTO relations_fts(relation_hash, content) + VALUES ( + new.hash, + COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') + ); + END + 
""") + + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS relations_ad + AFTER DELETE ON relations + BEGIN + DELETE FROM relations_fts WHERE relation_hash = old.hash; + END + """) + + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS relations_au + AFTER UPDATE OF subject, predicate, object ON relations + BEGIN + DELETE FROM relations_fts WHERE relation_hash = new.hash; + INSERT INTO relations_fts(relation_hash, content) + VALUES ( + new.hash, + COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') + ); + END + """) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"relations FTS5 schema 创建失败(可能不支持 FTS5): {e}") + c.rollback() + return False + + def ensure_relations_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """确保关系 FTS 索引已回填。""" + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute("SELECT COUNT(1) AS n FROM relations") + rel_count = int(cur.fetchone()[0]) + cur.execute("SELECT COUNT(1) AS n FROM relations_fts") + fts_count = int(cur.fetchone()[0]) + + if rel_count != fts_count: + cur.execute("DELETE FROM relations_fts") + cur.execute(""" + INSERT INTO relations_fts(relation_hash, content) + SELECT + r.hash, + COALESCE(r.subject, '') || ' ' || COALESCE(r.predicate, '') || ' ' || COALESCE(r.object, '') + FROM relations r + """) + c.commit() + logger.info(f"relations FTS 回填完成: relations={rel_count}, fts={rel_count}") + return True + except sqlite3.OperationalError as e: + logger.warning(f"relations FTS 回填失败: {e}") + c.rollback() + return False + + def ensure_paragraph_ngram_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """确保段落 ngram 倒排表存在。""" + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_ngrams ( + term TEXT NOT NULL, + paragraph_hash TEXT NOT NULL, + PRIMARY KEY (term, paragraph_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE + ) + """) + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_ngrams_hash + ON paragraph_ngrams(paragraph_hash) + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_ngram_meta ( + key TEXT PRIMARY KEY, + value TEXT + ) + """) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"paragraph ngram schema 创建失败: {e}") + c.rollback() + return False + + @staticmethod + def _char_ngrams(text: str, n: int) -> List[str]: + compact = "".join(str(text or "").lower().split()) + if not compact: + return [] + if len(compact) < n: + return [compact] + return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] + + def ensure_paragraph_ngram_backfilled( + self, + n: int = 2, + conn: Optional[sqlite3.Connection] = None, + ) -> bool: + """ + 确保段落 ngram 倒排索引已回填。 + + 仅在 n 变化或文档数量变化时重建,避免每次加载都全量重建。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + n = max(1, int(n)) + try: + cur.execute("SELECT value FROM paragraph_ngram_meta WHERE key='ngram_n'") + row = cur.fetchone() + current_n = int(row[0]) if row and row[0] is not None else None + + cur.execute("SELECT COUNT(1) FROM paragraphs WHERE is_deleted IS NULL OR is_deleted = 0") + para_count = int(cur.fetchone()[0]) + cur.execute("SELECT COUNT(DISTINCT paragraph_hash) FROM paragraph_ngrams") + indexed_docs = int(cur.fetchone()[0]) + + need_rebuild = (current_n != n) or (para_count != indexed_docs) + if not need_rebuild: + return True + + cur.execute("DELETE FROM paragraph_ngrams") + cur.execute(""" + SELECT hash, content + FROM 
paragraphs + WHERE is_deleted IS NULL OR is_deleted = 0 + """) + rows = cur.fetchall() + + batch: List[Tuple[str, str]] = [] + batch_size = 2000 + for row in rows: + p_hash = str(row["hash"]) + terms = list(dict.fromkeys(self._char_ngrams(str(row["content"] or ""), n))) + for term in terms: + batch.append((term, p_hash)) + if len(batch) >= batch_size: + cur.executemany( + "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", + batch, + ) + batch.clear() + if batch: + cur.executemany( + "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", + batch, + ) + + cur.execute(""" + INSERT INTO paragraph_ngram_meta(key, value) VALUES('ngram_n', ?) + ON CONFLICT(key) DO UPDATE SET value=excluded.value + """, (str(n),)) + cur.execute(""" + INSERT INTO paragraph_ngram_meta(key, value) VALUES('paragraph_count', ?) + ON CONFLICT(key) DO UPDATE SET value=excluded.value + """, (str(para_count),)) + c.commit() + logger.info(f"paragraph ngram 回填完成: n={n}, paragraphs={para_count}") + return True + except Exception as e: + logger.warning(f"paragraph ngram 回填失败: {e}") + c.rollback() + return False + + def fts_upsert_paragraph( + self, + paragraph_hash: str, + conn: Optional[sqlite3.Connection] = None, + ) -> bool: + """ + 将段落写入(或覆盖)到 FTS 索引。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + "SELECT rowid, content FROM paragraphs WHERE hash = ?", + (paragraph_hash,), + ) + row = cur.fetchone() + if not row: + return False + rowid = int(row[0]) + content = str(row[1] or "") + cur.execute( + "INSERT OR REPLACE INTO paragraphs_fts(rowid, content) VALUES (?, ?)", + (rowid, content), + ) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS upsert 失败: {e}") + c.rollback() + return False + + def fts_delete_paragraph( + self, + paragraph_hash: str, + conn: Optional[sqlite3.Connection] = None, + ) -> bool: + """ + 从 FTS 索引删除段落。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + "SELECT rowid, content FROM paragraphs WHERE hash = ?", + (paragraph_hash,), + ) + row = cur.fetchone() + if not row: + return False + rowid = int(row[0]) + content = str(row[1] or "") + cur.execute( + "INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) VALUES ('delete', ?, ?)", + (rowid, content), + ) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS delete 失败: {e}") + c.rollback() + return False + + def fts_search_bm25( + self, + match_query: str, + limit: int = 20, + max_doc_len: int = 2000, + conn: Optional[sqlite3.Connection] = None, + ) -> List[Dict[str, Any]]: + """ + 使用 FTS5 + bm25 执行全文检索。 + """ + if not match_query.strip(): + return [] + + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + """ + SELECT p.hash, p.content, bm25(paragraphs_fts) AS bm25_score + FROM paragraphs_fts + JOIN paragraphs p ON p.rowid = paragraphs_fts.rowid + WHERE paragraphs_fts MATCH ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + ORDER BY bm25_score ASC + LIMIT ? 
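+                    -- bm25() in SQLite FTS5 returns "smaller is better" (negative) scores, so ASC lists the best matches first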
+ """, + (match_query, max(1, int(limit))), + ) + rows = cur.fetchall() + results: List[Dict[str, Any]] = [] + for row in rows: + content = str(row["content"] or "") + if max_doc_len > 0: + content = content[:max_doc_len] + results.append( + { + "hash": row["hash"], + "content": content, + "bm25_score": float(row["bm25_score"]), + } + ) + return results + except sqlite3.OperationalError as e: + logger.warning(f"FTS 查询失败: {e}") + return [] + + def fts_search_relations_bm25( + self, + match_query: str, + limit: int = 20, + max_doc_len: int = 512, + conn: Optional[sqlite3.Connection] = None, + ) -> List[Dict[str, Any]]: + """使用 FTS5 + bm25 执行关系全文检索。""" + if not match_query.strip(): + return [] + + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + """ + SELECT + r.hash, + r.subject, + r.predicate, + r.object, + bm25(relations_fts) AS bm25_score + FROM relations_fts + JOIN relations r ON r.hash = relations_fts.relation_hash + WHERE relations_fts MATCH ? + ORDER BY bm25_score ASC + LIMIT ? + """, + (match_query, max(1, int(limit))), + ) + rows = cur.fetchall() + out: List[Dict[str, Any]] = [] + for row in rows: + content = f"{row['subject']} {row['predicate']} {row['object']}" + if max_doc_len > 0: + content = content[:max_doc_len] + out.append( + { + "hash": row["hash"], + "subject": row["subject"], + "predicate": row["predicate"], + "object": row["object"], + "content": content, + "bm25_score": float(row["bm25_score"]), + } + ) + return out + except sqlite3.OperationalError as e: + logger.warning(f"relations FTS 查询失败: {e}") + return [] + + def ngram_search_paragraphs( + self, + tokens: List[str], + limit: int = 20, + max_doc_len: int = 2000, + conn: Optional[sqlite3.Connection] = None, + ) -> List[Dict[str, Any]]: + """按 ngram 倒排索引检索段落,避免 LIKE 全表扫描。""" + uniq = [t for t in dict.fromkeys([str(x).strip().lower() for x in tokens]) if t] + if not uniq: + return [] + + c = self._resolve_conn(conn) + cur = c.cursor() + placeholders = ",".join(["?"] * len(uniq)) + try: + cur.execute( + f""" + SELECT + p.hash, + p.content, + COUNT(*) AS hit_terms + FROM paragraph_ngrams ng + JOIN paragraphs p ON p.hash = ng.paragraph_hash + WHERE ng.term IN ({placeholders}) + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + GROUP BY p.hash, p.content + ORDER BY hit_terms DESC + LIMIT ? 
+ """, + tuple(uniq + [max(1, int(limit))]), + ) + rows = cur.fetchall() + out: List[Dict[str, Any]] = [] + token_count = max(1, len(uniq)) + for row in rows: + hit_terms = int(row["hit_terms"]) + score = float(hit_terms / token_count) + content = str(row["content"] or "") + if max_doc_len > 0: + content = content[:max_doc_len] + out.append( + { + "hash": row["hash"], + "content": content, + "bm25_score": -score, + "fallback_score": score, + } + ) + return out + except sqlite3.OperationalError as e: + logger.warning(f"ngram 倒排查询失败: {e}") + return [] + + def fts_doc_count(self, conn: Optional[sqlite3.Connection] = None) -> int: + """获取 FTS 文档数量。""" + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute("SELECT COUNT(1) FROM paragraphs_fts") + return int(cur.fetchone()[0]) + except sqlite3.OperationalError: + return 0 + + def shrink_memory(self, conn: Optional[sqlite3.Connection] = None) -> None: + """请求 SQLite 收缩当前连接缓存。""" + c = self._resolve_conn(conn) + try: + c.execute("PRAGMA shrink_memory") + except sqlite3.OperationalError: + pass + + @staticmethod + def _normalize_episode_source(source: Any) -> str: + return str(source or "").strip() + + def _dedupe_episode_sources(self, sources: List[Any]) -> List[str]: + normalized: List[str] = [] + seen = set() + for item in sources or []: + token = self._normalize_episode_source(item) + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + return normalized + + def _get_sources_for_paragraph_hashes( + self, + hashes: List[str], + *, + include_deleted: bool = True, + ) -> List[str]: + normalized_hashes = [ + str(item or "").strip() + for item in (hashes or []) + if str(item or "").strip() + ] + if not normalized_hashes: + return [] + + placeholders = ",".join(["?"] * len(normalized_hashes)) + conditions = ["hash IN ({})".format(placeholders), "TRIM(COALESCE(source, '')) != ''"] + if not include_deleted: + conditions.append("(is_deleted IS NULL OR is_deleted = 0)") + + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT DISTINCT TRIM(source) AS source + FROM paragraphs + WHERE {' AND '.join(conditions)} + """, + tuple(normalized_hashes), + ) + return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) + + def _enqueue_episode_source_rebuilds(self, sources: List[Any], reason: str = "") -> int: + normalized_sources = self._dedupe_episode_sources(sources) + if not normalized_sources: + return 0 + + now = datetime.now().timestamp() + reason_text = str(reason or "").strip()[:200] or None + cursor = self._conn.cursor() + cursor.executemany( + """ + INSERT INTO episode_rebuild_sources ( + source, status, retry_count, last_error, reason, requested_at, updated_at + ) VALUES (?, 'pending', 0, NULL, ?, ?, ?) 
+ ON CONFLICT(source) DO UPDATE SET + status = 'pending', + last_error = NULL, + reason = excluded.reason, + requested_at = excluded.requested_at, + updated_at = excluded.updated_at + """, + [ + (source, reason_text, now, now) + for source in normalized_sources + ], + ) + self._conn.commit() + return len(normalized_sources) + + def add_paragraph( + self, + content: str, + vector_index: Optional[int] = None, + source: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + knowledge_type: str = "mixed", + time_meta: Optional[Dict[str, Any]] = None, + ) -> str: + """ + 添加段落 + + Args: + content: 段落内容 + vector_index: 向量索引 + source: 来源 + metadata: 额外元数据 + knowledge_type: 知识类型 (narrative/factual/quote/structured/mixed) + time_meta: 时间元信息 (event_time/event_time_start/event_time_end/...) + + Returns: + 段落哈希值 + """ + content_normalized = normalize_text(content) + hash_value = compute_hash(content_normalized) + resolved_knowledge_type = validate_stored_knowledge_type(knowledge_type) + + now = datetime.now().timestamp() + word_count = len(content_normalized.split()) + normalized_time = normalize_time_meta(time_meta) + + cursor = self._conn.cursor() + try: + cursor.execute(""" + INSERT INTO paragraphs + ( + hash, content, vector_index, created_at, updated_at, metadata, source, word_count, + event_time, event_time_start, event_time_end, time_granularity, time_confidence, + knowledge_type + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + hash_value, + content, + vector_index, + now, + now, + pickle.dumps(metadata or {}), + source, + word_count, + normalized_time.get("event_time"), + normalized_time.get("event_time_start"), + normalized_time.get("event_time_end"), + normalized_time.get("time_granularity"), + normalized_time.get("time_confidence", 1.0), + resolved_knowledge_type.value, + )) + self._conn.commit() + try: + self.enqueue_episode_source_rebuild( + source=source, + reason="paragraph_added", + ) + except Exception as e: + logger.warning(f"Episode source 重建入队失败: hash={hash_value[:16]}..., err={e}") + logger.debug( + f"添加段落: hash={hash_value[:16]}..., words={word_count}, type={resolved_knowledge_type.value}" + ) + return hash_value + except sqlite3.IntegrityError: + logger.debug(f"段落已存在: {hash_value[:16]}...") + # 尝试复活 + self.revive_if_deleted(paragraph_hashes=[hash_value]) + return hash_value + + def _canonicalize_name(self, name: str) -> str: + """ + 规范化名称 (统一小写并去除首尾空格) + + Args: + name: 原始名称 + + Returns: + 规范化后的名称 + """ + if not name: + return "" + return name.strip().lower() + + def add_entity( + self, + name: str, + vector_index: Optional[int] = None, + source_paragraph: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """ + 添加实体 + + Args: + name: 实体名称 + vector_index: 向量索引 + source_paragraph: 来源段落哈希 (如果提供,将建立关联) + metadata: 额外元数据 + + Returns: + 实体哈希值 + """ + # 1. 规范化名称 + name_normalized = self._canonicalize_name(name) + if not name_normalized: + raise ValueError("Entity name cannot be empty") + + hash_value = compute_hash(name_normalized) + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + + # 2. 插入实体 (INSERT OR IGNORE) + # 注意:这里我们保留原有的 name 字段存储,可以是 display name, + # 但 hash 必须由 canonical name 生成。 + # 如果实体已存在,我们其实不一定要更新 name (保留第一次的 display name 往往更好) + # 或者我们也可以选择不作为唯一键冲突,而是逻辑判断。 + # 考虑到 entities.hash 是主键,entities.name 是 UNIQUE。 + # 如果 name 大小写不同但 hash 相同 (冲突),或者 name 不同但 canonical name 相同? 
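+        # Net effect: duplicates collapse onto the row keyed by the canonical-name hash; the
+        # IntegrityError branch below revives a soft-deleted entity, bumps appearance_count,
+        # and still links the source paragraph.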
+ # 由于 hash 是由 canonical name 算出来的,所以 hash 相同意味着 canonical name 相同。 + # 如果 db 中已存在的 name 是 "Apple",新来的 name 是 "apple",它们 canonical name 都是 "apple",hash 一样。 + # 此时 INSERT OR IGNORE 会忽略。 + + try: + cursor.execute(""" + INSERT INTO entities + (hash, name, vector_index, appearance_count, created_at, metadata) + VALUES (?, ?, ?, 1, ?, ?) + """, ( + hash_value, + name, + vector_index, + now, + pickle.dumps(metadata or {}), + )) + + logger.debug(f"添加实体: {name} ({hash_value[:8]})") + self._conn.commit() + + # 3. 建立来源关联 + if source_paragraph: + self.link_paragraph_entity(source_paragraph, hash_value) + + return hash_value + + except sqlite3.IntegrityError: + # 实体已存在 + # 1. 尝试复活 (自动复活) + self.revive_if_deleted(entity_hashes=[hash_value]) + + # 2. 更新计数 + cursor.execute(""" + UPDATE entities + SET appearance_count = appearance_count + 1 + WHERE hash = ? + """, (hash_value,)) + self._conn.commit() + + logger.debug(f"实体已存在(复活/计数+1): {name}") + + # 3. 建立来源关联 + if source_paragraph: + self.link_paragraph_entity(source_paragraph, hash_value) + + return hash_value + + def add_relation( + self, + subject: str, + predicate: str, + obj: str, + vector_index: Optional[int] = None, + confidence: float = 1.0, + source_paragraph: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """ + 添加关系 + + Args: + subject: 主语 + predicate: 谓语 + obj: 宾语 + vector_index: 向量索引 + confidence: 置信度 + source_paragraph: 来源段落哈希 + metadata: 额外元数据 + + Returns: + 关系哈希值 + """ + # 1. 规范化输入 + s_canon = self._canonicalize_name(subject) + p_canon = self._canonicalize_name(predicate) + o_canon = self._canonicalize_name(obj) + + if not all([s_canon, p_canon, o_canon]): + raise ValueError("Relation components cannot be empty") + + # 2. 计算组合哈希 + # 公式: md5(s|p|o) + relation_key = f"{s_canon}|{p_canon}|{o_canon}" + hash_value = compute_hash(relation_key) + + now = datetime.now().timestamp() + + # 记录原始 display name 到 metadata (如果需要的话,或者直接存到 DB 字段) + # 这里我们直接存入 subject, predicate, object 字段, + # 注意:如果 DB 里已存在该关系 (hash 相同),则不会更新这些字段,保留第一次的拼写。 + + cursor = self._conn.cursor() + try: + cursor.execute(""" + INSERT OR IGNORE INTO relations + (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + hash_value, + subject, # 原始拼写 + predicate, + obj, + vector_index, + confidence, + now, + source_paragraph, # 这里的 source_paragraph 仅作为 "首次发现地" 记录,也可留空 + pickle.dumps(metadata or {}), + )) + self._conn.commit() + + if cursor.rowcount > 0: + logger.debug(f"添加关系: {subject} -{predicate}-> {obj}") + else: + logger.debug(f"关系已存在: {subject} -{predicate}-> {obj}") + + # 3. 建立来源关联 (幂等) + # 无论关系是新创建的还是已存在的,只要提供了 source_paragraph,都要建立连接 + if source_paragraph: + self.link_paragraph_relation(source_paragraph, hash_value) + + return hash_value + + except sqlite3.IntegrityError as e: + logger.warning(f"添加关系异常: {e}") + return hash_value + + def link_paragraph_relation( + self, + paragraph_hash: str, + relation_hash: str, + ) -> bool: + """ + 关联段落和关系 (幂等) + """ + cursor = self._conn.cursor() + try: + # 使用 INSERT OR IGNORE 避免重复报错 + cursor.execute(""" + INSERT OR IGNORE INTO paragraph_relations + (paragraph_hash, relation_hash) + VALUES (?, ?) 
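+                -- (paragraph_hash, relation_hash) is the composite primary key, so re-linking is a no-op under OR IGNORE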
+ """, (paragraph_hash, relation_hash)) + self._conn.commit() + self._enqueue_episode_source_rebuilds( + self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), + reason="paragraph_relation_linked", + ) + return True + except sqlite3.IntegrityError: + return False + + def link_paragraph_entity( + self, + paragraph_hash: str, + entity_hash: str, + mention_count: int = 1, + ) -> bool: + """ + 关联段落和实体 (幂等) + """ + cursor = self._conn.cursor() + try: + # 首先尝试插入 + cursor.execute(""" + INSERT OR IGNORE INTO paragraph_entities + (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + """, (paragraph_hash, entity_hash, mention_count)) + + if cursor.rowcount == 0: + # 如果已存在 (IGNORE生效),则更新计数 + cursor.execute(""" + UPDATE paragraph_entities + SET mention_count = mention_count + ? + WHERE paragraph_hash = ? AND entity_hash = ? + """, (mention_count, paragraph_hash, entity_hash)) + + self._conn.commit() + self._enqueue_episode_source_rebuilds( + self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), + reason="paragraph_entity_linked", + ) + return True + except sqlite3.IntegrityError: + return False + + def get_paragraph(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 获取段落 + + Args: + hash_value: 段落哈希 + + Returns: + 段落信息字典,不存在则返回None + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM paragraphs WHERE hash = ? + """, (hash_value,)) + row = cursor.fetchone() + + if row: + return self._row_to_dict(row, "paragraph") + return None + + def update_paragraph_time_meta( + self, + paragraph_hash: str, + time_meta: Dict[str, Any], + ) -> bool: + """ + 更新段落时间元信息。 + """ + normalized = normalize_time_meta(time_meta) + if not normalized: + return False + source_to_rebuild = self._get_sources_for_paragraph_hashes( + [paragraph_hash], + include_deleted=True, + ) + + updates: List[str] = [] + params: List[Any] = [] + for key in [ + "event_time", + "event_time_start", + "event_time_end", + "time_granularity", + "time_confidence", + ]: + if key in normalized: + updates.append(f"{key} = ?") + params.append(normalized[key]) + + if not updates: + return False + + updates.append("updated_at = ?") + params.append(datetime.now().timestamp()) + params.append(paragraph_hash) + + cursor = self._conn.cursor() + cursor.execute( + f"UPDATE paragraphs SET {', '.join(updates)} WHERE hash = ?", + tuple(params), + ) + self._conn.commit() + changed = cursor.rowcount > 0 + if changed: + self._enqueue_episode_source_rebuilds( + source_to_rebuild, + reason="paragraph_time_updated", + ) + return changed + + def query_paragraphs_temporal( + self, + start_ts: Optional[float] = None, + end_ts: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + limit: int = 100, + allow_created_fallback: bool = True, + ) -> List[Dict[str, Any]]: + """ + 查询时序命中的段落(区间相交语义)。 + """ + if limit <= 0: + return [] + + effective_start = "COALESCE(p.event_time_start, p.event_time, p.event_time_end" + effective_end = "COALESCE(p.event_time_end, p.event_time, p.event_time_start" + if allow_created_fallback: + effective_start += ", p.created_at)" + effective_end += ", p.created_at)" + else: + effective_start += ")" + effective_end += ")" + + conditions = ["(p.is_deleted IS NULL OR p.is_deleted = 0)"] + params: List[Any] = [] + + if source: + conditions.append("p.source = ?") + params.append(source) + + if person: + conditions.append( + """ + EXISTS ( + SELECT 1 + FROM paragraph_entities pe + JOIN entities e ON e.hash = pe.entity_hash + WHERE 
pe.paragraph_hash = p.hash + AND LOWER(e.name) LIKE ? + ) + """ + ) + params.append(f"%{str(person).strip().lower()}%") + + if start_ts is not None and end_ts is not None: + conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") + params.extend([start_ts, end_ts]) + elif start_ts is not None: + conditions.append(f"({effective_end} >= ?)") + params.append(start_ts) + elif end_ts is not None: + conditions.append(f"({effective_start} <= ?)") + params.append(end_ts) + + where_sql = " AND ".join(conditions) + sql = f""" + SELECT p.* + FROM paragraphs p + WHERE {where_sql} + ORDER BY {effective_end} DESC, p.updated_at DESC + LIMIT ? + """ + params.append(limit) + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params)) + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def get_entity(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 获取实体 + + Args: + hash_value: 实体哈希 + + Returns: + 实体信息字典,不存在则返回None + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM entities WHERE hash = ? + """, (hash_value,)) + row = cursor.fetchone() + + if row: + return self._row_to_dict(row, "entity") + return None + + def get_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 获取关系 + + Args: + hash_value: 关系哈希 + + Returns: + 关系信息字典,不存在则返回None + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM relations WHERE hash = ? + """, (hash_value,)) + row = cursor.fetchone() + + if row: + return self._row_to_dict(row, "relation") + return None + + def get_paragraph_relations(self, paragraph_hash: str) -> List[Dict[str, Any]]: + """ + 获取段落的所有关系 + + Args: + paragraph_hash: 段落哈希 + + Returns: + 关系列表 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT r.* FROM relations r + JOIN paragraph_relations pr ON r.hash = pr.relation_hash + WHERE pr.paragraph_hash = ? + """, (paragraph_hash,)) + + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def get_paragraph_entities(self, paragraph_hash: str) -> List[Dict[str, Any]]: + """ + 获取段落的所有实体 + + Args: + paragraph_hash: 段落哈希 + + Returns: + 实体列表 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT e.*, pe.mention_count + FROM entities e + JOIN paragraph_entities pe ON e.hash = pe.entity_hash + WHERE pe.paragraph_hash = ? + """, (paragraph_hash,)) + + return [self._row_to_dict(row, "entity") for row in cursor.fetchall()] + + def get_paragraphs_by_entity(self, entity_name: str) -> List[Dict[str, Any]]: + """ + 获取包含指定实体的所有段落 (自动处理规范化) + + Args: + entity_name: 实体名称 (支持任意大小写) + + Returns: + 段落列表 + """ + # 1. 计算规范化 Hash + name_canon = self._canonicalize_name(entity_name) + if not name_canon: + return [] + + entity_hash = compute_hash(name_canon) + + cursor = self._conn.cursor() + # 2. 直接使用 Hash 查询中间表,完全避开 Name 匹配 + cursor.execute(""" + SELECT p.* + FROM paragraphs p + JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash + WHERE pe.entity_hash = ? 
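+ -- entity_hash was derived from the canonical (trimmed, lowercased) name, so no name or case matching is needed in SQL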
+ """, (entity_hash,)) + + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def get_relations( + self, + subject: Optional[str] = None, + predicate: Optional[str] = None, + object: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + 查询关系(大小写不敏感) + + Args: + subject: 主语(可选) + predicate: 谓语(可选) + object: 宾语(可选) + + Returns: + 关系列表 + """ + # 构建查询条件 + conditions = [] + params = [] + + if subject: + conditions.append("LOWER(subject) = ?") + params.append(self._canonicalize_name(subject)) + if predicate: + conditions.append("LOWER(predicate) = ?") + params.append(self._canonicalize_name(predicate)) + if object: + conditions.append("LOWER(object) = ?") + params.append(self._canonicalize_name(object)) + + sql = "SELECT * FROM relations" + if conditions: + sql += " WHERE " + " AND ".join(conditions) + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params)) + + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def get_all_triples(self) -> List[Tuple[str, str, str, str]]: + """ + 高效获取所有三元组 (subject, predicate, object, hash) + 直接返回元组,跳过字典转换和pickle反序列化,用于构建 V5 Map 缓存。 + """ + cursor = self._conn.cursor() + cursor.execute("SELECT subject, predicate, object, hash FROM relations") + return list(cursor.fetchall()) + + def get_paragraphs_by_relation(self, relation_hash: str) -> List[Dict[str, Any]]: + """ + 获取支持指定关系的所有段落 + + Args: + relation_hash: 关系哈希 + + Returns: + 段落列表 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT p.* + FROM paragraphs p + JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash + WHERE pr.relation_hash = ? + """, (relation_hash,)) + + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def get_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: + """ + 按来源获取段落 + + Args: + source: 来源标识符 + + Returns: + 段落列表 + """ + return self.query("SELECT * FROM paragraphs WHERE source = ?", (source,)) + + def get_all_sources(self) -> List[Dict[str, Any]]: + """ + 获取所有来源文件统计信息 + + Returns: + 来源列表 [{'source': 'name', 'count': int, 'last_updated': timestamp}] + """ + cursor = self._conn.cursor() + # 排除 source 为 NULL 或空的记录 + cursor.execute(""" + SELECT source, COUNT(*) as count, MAX(created_at) as last_updated + FROM paragraphs + WHERE source IS NOT NULL AND source != '' + GROUP BY source + ORDER BY last_updated DESC + """) + + results = [] + for row in cursor.fetchall(): + results.append({ + "source": row[0], + "count": row[1], + "last_updated": row[2] + }) + return results + + + def search_paragraphs_by_content(self, content_query: str) -> List[Dict[str, Any]]: + """按内容模糊搜索段落""" + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM paragraphs WHERE content LIKE ? + """, (f"%{content_query}%",)) + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def delete_paragraph(self, hash_value: str) -> bool: + """ + 删除段落(级联删除相关关联) + + Args: + hash_value: 段落哈希 + + Returns: + 是否成功删除 + """ + cursor = self._conn.cursor() + cursor.execute(""" + DELETE FROM paragraphs WHERE hash = ? + """, (hash_value,)) + self._conn.commit() + + deleted = cursor.rowcount > 0 + if deleted: + logger.info(f"删除段落: {hash_value[:16]}...") + + return deleted + + def delete_entity(self, hash_or_name: str) -> bool: + """ + 删除实体(级联删除相关关联) + 支持通过哈希值或名称删除 + + 注意:会同时删除所有引用该实体(作为主语或宾语)的关系 + """ + cursor = self._conn.cursor() + + # 1. 
解析实体信息 (获取 Name 和 Hash) + entity_name = None + entity_hash = None + + # 尝试作为 Hash 查询 + cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (hash_or_name,)) + row = cursor.fetchone() + if row: + entity_name = row[0] + entity_hash = row[1] + else: + # 尝试作为 Name 查询 (原始匹配) + cursor.execute("SELECT name, hash FROM entities WHERE name = ?", (hash_or_name,)) + row = cursor.fetchone() + if row: + entity_name = row[0] + entity_hash = row[1] + else: + # 最后的最后:尝试规范化名称 (Canonical) 查询,解决大小写或 WebUI 手动输入导致的不匹配 + name_canon = self._canonicalize_name(hash_or_name) + canon_hash = compute_hash(name_canon) + cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (canon_hash,)) + row = cursor.fetchone() + if row: + entity_name = row[0] + entity_hash = row[1] + + if not entity_name or not entity_hash: + logger.debug(f"删除实体请求跳过:未在元数据记录中找到 {hash_or_name}") + return False + + logger.info(f"开始删除实体: {entity_name} (Hash: {entity_hash[:8]}...)") + + try: + # 2. 查找相关关系 (Subject 或 Object 为该实体) + cursor.execute(""" + SELECT hash FROM relations + WHERE subject = ? OR object = ? + """, (entity_name, entity_name)) + + relation_hashes = [r[0] for r in cursor.fetchall()] + + if relation_hashes: + logger.info(f"发现 {len(relation_hashes)} 个相关关系,准备级联删除") + + # 3. 删除这些关系与段落的关联 + # SQLite 不支持直接 DELETE ... WHERE ... IN (...) 的列表参数,需要拼接占位符 + placeholders = ','.join(['?'] * len(relation_hashes)) + + cursor.execute(f""" + DELETE FROM paragraph_relations + WHERE relation_hash IN ({placeholders}) + """, relation_hashes) + + # 4. 删除关系本体 + cursor.execute(f""" + DELETE FROM relations + WHERE hash IN ({placeholders}) + """, relation_hashes) + + logger.info("相关关系已级联删除") + + # 5. 删除实体与段落的关联 + cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (entity_hash,)) + + # 6. 删除实体本体 + cursor.execute("DELETE FROM entities WHERE hash = ?", (entity_hash,)) + + self._conn.commit() + logger.info("实体删除完成") + return True + + except Exception as e: + logger.error(f"删除实体时发生错误: {e}") + self._conn.rollback() + return False + + def delete_relation(self, hash_value: str) -> bool: + """ + 删除关系(级联删除相关关联) + + Args: + hash_value: 关系哈希 + + Returns: + 是否成功删除 + """ + cursor = self._conn.cursor() + cursor.execute(""" + DELETE FROM relations WHERE hash = ? + """, (hash_value,)) + self._conn.commit() + + deleted = cursor.rowcount > 0 + if deleted: + logger.info(f"删除关系: {hash_value[:16]}...") + + return deleted + + def set_relation_vector_state( + self, + hash_value: str, + state: str, + error: Optional[str] = None, + bump_retry: bool = False, + ) -> bool: + """ + 更新关系向量状态。 + """ + state_norm = str(state or "").strip().lower() + if state_norm not in {"none", "pending", "ready", "failed"}: + raise ValueError(f"无效 vector_state: {state}") + + now = datetime.now().timestamp() + err_text = (str(error).strip() if error is not None else None) + if err_text: + err_text = err_text[:500] + clear_error = state_norm in {"none", "pending", "ready"} + + cursor = self._conn.cursor() + if bump_retry: + cursor.execute( + """ + UPDATE relations + SET vector_state = ?, + vector_updated_at = ?, + vector_error = ?, + vector_retry_count = COALESCE(vector_retry_count, 0) + 1 + WHERE hash = ? + """, + (state_norm, now, None if clear_error else err_text, hash_value), + ) + else: + cursor.execute( + """ + UPDATE relations + SET vector_state = ?, + vector_updated_at = ?, + vector_error = ? + WHERE hash = ? 
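+ -- non-bump path: vector_retry_count is deliberately left unchanged here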
+ """, + (state_norm, now, None if clear_error else err_text, hash_value), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def list_relations_by_vector_state( + self, + states: List[str], + limit: int = 200, + max_retry: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """ + 根据向量状态列出关系,用于回填任务。 + """ + normalized_states = [ + str(s or "").strip().lower() + for s in (states or []) + if str(s or "").strip() + ] + normalized_states = [ + s for s in normalized_states + if s in {"none", "pending", "ready", "failed"} + ] + if not normalized_states: + return [] + + placeholders = ",".join(["?"] * len(normalized_states)) + params: List[Any] = list(normalized_states) + sql = f""" + SELECT hash, subject, predicate, object, confidence, source_paragraph, + vector_state, vector_updated_at, vector_error, vector_retry_count, created_at + FROM relations + WHERE vector_state IN ({placeholders}) + """ + if max_retry is not None: + sql += " AND COALESCE(vector_retry_count, 0) < ?" + params.append(int(max_retry)) + sql += " ORDER BY COALESCE(vector_updated_at, created_at, 0) ASC LIMIT ?" + params.append(max(1, int(limit))) + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params)) + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def count_relations_by_vector_state(self) -> Dict[str, int]: + """ + 统计关系向量状态分布。 + """ + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT COALESCE(vector_state, 'none') AS state, COUNT(*) AS cnt + FROM relations + GROUP BY COALESCE(vector_state, 'none') + """ + ) + result: Dict[str, int] = {"none": 0, "pending": 0, "ready": 0, "failed": 0} + total = 0 + for row in cursor.fetchall(): + state = str(row["state"] or "none").lower() + count = int(row["cnt"] or 0) + if state not in result: + result[state] = 0 + result[state] += count + total += count + result["total"] = total + return result + + def update_vector_index( + self, + item_type: str, + hash_value: str, + vector_index: int, + ) -> bool: + """ + 更新向量索引 + + Args: + item_type: 类型(paragraph/entity/relation) + hash_value: 哈希值 + vector_index: 向量索引 + + Returns: + 是否成功更新 + """ + valid_types = ["paragraph", "entity", "relation"] + if item_type not in valid_types: + raise ValueError(f"无效的类型: {item_type}") + + table_map = { + "paragraph": "paragraphs", + "entity": "entities", + "relation": "relations", + } + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table_map[item_type]} + SET vector_index = ? + WHERE hash = ? + """, (vector_index, hash_value)) + self._conn.commit() + + return cursor.rowcount > 0 + + def set_permanence(self, hash_value: str, item_type: str, is_permanent: bool) -> bool: + """设置永久记忆标记""" + table_map = { + "paragraph": "paragraphs", + "relation": "relations", + } + if item_type not in table_map: + raise ValueError(f"类型 {item_type} 不支持设置永久性") + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table_map[item_type]} + SET is_permanent = ? + WHERE hash = ? 
+ """, (1 if is_permanent else 0, hash_value)) + self._conn.commit() + + if cursor.rowcount > 0: + logger.debug(f"设置永久记忆: {item_type}/{hash_value[:8]} -> {is_permanent}") + return True + return False + + def record_access(self, hash_value: str, item_type: str) -> bool: + """记录访问(更新时间和次数)""" + table_map = { + "paragraph": "paragraphs", + "relation": "relations", + } + if item_type not in table_map: + return False + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table_map[item_type]} + SET last_accessed = ?, access_count = access_count + 1 + WHERE hash = ? + """, (now, hash_value)) + self._conn.commit() + return cursor.rowcount > 0 + + def query( + self, + sql: str, + params: Optional[Tuple] = None, + ) -> List[Dict[str, Any]]: + """ + 执行自定义查询 + + Args: + sql: SQL语句 + params: 参数 + + Returns: + 查询结果列表 + """ + cursor = self._conn.cursor() + if params: + cursor.execute(sql, params) + else: + cursor.execute(sql) + + return [dict(row) for row in cursor.fetchall()] + + def get_external_memory_ref(self, external_id: str) -> Optional[Dict[str, Any]]: + """按 external_id 查询外部记忆映射。""" + token = str(external_id or "").strip() + if not token: + return None + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT external_id, paragraph_hash, source_type, created_at, metadata_json + FROM external_memory_refs + WHERE external_id = ? + LIMIT 1 + """, + (token,), + ) + row = cursor.fetchone() + if row is None: + return None + + payload = dict(row) + raw_metadata = payload.get("metadata_json") + if raw_metadata: + try: + payload["metadata"] = json.loads(raw_metadata) + except Exception: + payload["metadata"] = {} + else: + payload["metadata"] = {} + return payload + + def upsert_external_memory_ref( + self, + *, + external_id: str, + paragraph_hash: str, + source_type: str = "", + metadata: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """注册 external_id 到段落哈希的幂等映射。""" + external_token = str(external_id or "").strip() + paragraph_token = str(paragraph_hash or "").strip() + if not external_token: + raise ValueError("external_id 不能为空") + if not paragraph_token: + raise ValueError("paragraph_hash 不能为空") + + now = datetime.now().timestamp() + metadata_json = json.dumps(metadata or {}, ensure_ascii=False) + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO external_memory_refs ( + external_id, paragraph_hash, source_type, created_at, metadata_json + ) + VALUES (?, ?, ?, ?, ?) 
+ ON CONFLICT(external_id) DO UPDATE SET + paragraph_hash = excluded.paragraph_hash, + source_type = excluded.source_type, + metadata_json = excluded.metadata_json + """, + ( + external_token, + paragraph_token, + str(source_type or "").strip() or None, + now, + metadata_json, + ), + ) + self._conn.commit() + return self.get_external_memory_ref(external_token) or { + "external_id": external_token, + "paragraph_hash": paragraph_token, + "source_type": str(source_type or "").strip(), + "created_at": now, + "metadata": metadata or {}, + } + + def get_statistics(self) -> Dict[str, int]: + """ + 获取统计信息 + + Returns: + 统计信息字典 + """ + cursor = self._conn.cursor() + + stats = {} + + # 段落数量 + cursor.execute("SELECT COUNT(*) FROM paragraphs") + stats["paragraph_count"] = cursor.fetchone()[0] + + # 实体数量 + cursor.execute("SELECT COUNT(*) FROM entities") + stats["entity_count"] = cursor.fetchone()[0] + + # 关系数量 + cursor.execute("SELECT COUNT(*) FROM relations") + stats["relation_count"] = cursor.fetchone()[0] + + # 总词数 + cursor.execute("SELECT SUM(word_count) FROM paragraphs") + result = cursor.fetchone()[0] + stats["total_words"] = result if result else 0 + + return stats + + def count_paragraphs(self, include_deleted: bool = False, only_deleted: bool = False) -> int: + """ + 获取段落数量 + """ + cursor = self._conn.cursor() + if only_deleted: + cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 1") + return cursor.fetchone()[0] + if include_deleted: + cursor.execute("SELECT COUNT(*) FROM paragraphs") + return cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 0") + return cursor.fetchone()[0] + + def count_relations(self, include_deleted: bool = False, only_deleted: bool = False) -> int: + """ + 获取关系数量 + """ + cursor = self._conn.cursor() + if only_deleted: + cursor.execute("SELECT COUNT(*) FROM deleted_relations") + return cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM relations") + active_count = cursor.fetchone()[0] + if not include_deleted: + return active_count + cursor.execute("SELECT COUNT(*) FROM deleted_relations") + deleted_count = cursor.fetchone()[0] + return int(active_count) + int(deleted_count) + + def count_entities(self) -> int: + """ + 获取实体数量 + + Returns: + 实体数量 + """ + cursor = self._conn.cursor() + cursor.execute("SELECT COUNT(*) FROM entities") + return cursor.fetchone()[0] + + def get_knowledge_type_distribution(self) -> Dict[str, int]: + """获取段落知识类型分布。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT knowledge_type, COUNT(*) as count + FROM paragraphs + WHERE is_deleted = 0 + GROUP BY knowledge_type + """ + ) + result: Dict[str, int] = {} + for row in cursor.fetchall(): + type_name = row[0] if row[0] else "未分类" + result[str(type_name)] = int(row[1] or 0) + return result + + def get_memory_status_summary(self, now_ts: Optional[float] = None) -> Dict[str, int]: + """聚合 memory status 统计。""" + now_ts = float(now_ts) if now_ts is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 0") + active_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 1") + inactive_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM deleted_relations") + deleted_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM relations WHERE is_pinned = 1") + pinned_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM 
relations WHERE protected_until > ?", (now_ts,)) + ttl_count = int(cursor.fetchone()[0] or 0) + return { + "active_count": active_count, + "inactive_count": inactive_count, + "deleted_count": deleted_count, + "pinned_count": pinned_count, + "temp_protected_count": ttl_count, + } + + def get_relations_subject_object_map(self, hashes: List[str]) -> Dict[str, Tuple[str, str]]: + """批量获取关系 hash 对应的 (subject, object)。""" + if not hashes: + return {} + cursor = self._conn.cursor() + placeholders = ",".join(["?"] * len(hashes)) + cursor.execute( + f"SELECT hash, subject, object FROM relations WHERE hash IN ({placeholders})", + hashes, + ) + return {str(row[0]): (str(row[1]), str(row[2])) for row in cursor.fetchall()} + + def get_connection(self) -> sqlite3.Connection: + """公开连接访问(用于离线脚本),替代外部访问私有字段。""" + return self._resolve_conn() + + def get_relation_db_snapshot(self) -> Tuple[int, float, str]: + """返回关系快照:(relation_count, max_created_at, max_hash)。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT + COUNT(*) AS relation_count, + COALESCE(MAX(created_at), 0) AS max_created_at, + COALESCE(MAX(hash), '') AS max_hash + FROM relations + """ + ) + row = cursor.fetchone() + if not row: + return (0, 0.0, "") + return ( + int(row[0] or 0), + float(row[1] or 0.0), + str(row[2] or ""), + ) + + def is_entity_still_referenced(self, entity_hash: str, entity_name: str = "") -> bool: + """ + 判断实体是否仍被引用: + 1) 被 paragraph_entities 引用 + 2) 在 relations.subject/object 中出现 + """ + token_hash = str(entity_hash or "").strip() + if token_hash: + cursor = self._conn.cursor() + cursor.execute( + "SELECT 1 FROM paragraph_entities WHERE entity_hash = ? LIMIT 1", + (token_hash,), + ) + if cursor.fetchone() is not None: + return True + + canon_name = self._canonicalize_name(entity_name) + if canon_name: + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT 1 + FROM relations + WHERE LOWER(TRIM(subject)) = ? OR LOWER(TRIM(object)) = ? + LIMIT 1 + """, + (canon_name, canon_name), + ) + if cursor.fetchone() is not None: + return True + return False + + def search_relations_by_subject_or_object( + self, + query: str, + *, + limit: int = 5, + include_deleted: bool = False, + ) -> List[Dict[str, Any]]: + """按 subject/object 模糊查询关系。""" + q = str(query or "").strip() + if not q: + return [] + max_limit = int(max(1, limit)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT * + FROM relations + WHERE subject LIKE ? OR object LIKE ? + LIMIT ? + """, + (f"%{q}%", f"%{q}%", max_limit), + ) + rows = [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + if rows or not include_deleted: + return rows + + cursor.execute( + """ + SELECT * + FROM deleted_relations + WHERE subject LIKE ? OR object LIKE ? + LIMIT ? 
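+ -- recycle-bin fallback: only reached when the active relations table had no match and include_deleted is set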
+ """, + (f"%{q}%", f"%{q}%", max_limit), + ) + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def list_hashes(self, table: str) -> List[str]: + """安全枚举指定表的 hash 列。""" + allowed = {"paragraphs", "entities", "relations", "deleted_relations"} + token = str(table or "").strip().lower() + if token not in allowed: + raise ValueError(f"unsupported table for list_hashes: {table}") + cursor = self._conn.cursor() + cursor.execute(f"SELECT hash FROM {token}") + return [str(row[0]) for row in cursor.fetchall()] + + def get_orphan_deleted_relation_hashes(self, limit: int = 200) -> List[str]: + """获取 deleted_relations 中已不在 relations 的孤儿 hash。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT d.hash + FROM deleted_relations d + LEFT JOIN relations r ON r.hash = d.hash + WHERE r.hash IS NULL + LIMIT ? + """, + (int(max(1, limit)),), + ) + return [str(row[0]) for row in cursor.fetchall()] + + def resolve_relation_hash_alias( + self, + value: str, + *, + include_deleted: bool = False, + ) -> List[str]: + """ + 解析关系哈希输入: + - 64位:直接校验存在性 + - 32位:通过 relation_hash_aliases 唯一映射 + """ + token = str(value or "").strip().lower() + if not token: + return [] + if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token): + cursor = self._conn.cursor() + cursor.execute("SELECT 1 FROM relations WHERE hash = ? LIMIT 1", (token,)) + if cursor.fetchone(): + return [token] + if include_deleted: + cursor.execute("SELECT 1 FROM deleted_relations WHERE hash = ? LIMIT 1", (token,)) + if cursor.fetchone(): + return [token] + return [] + + if len(token) != 32 or not all(ch in "0123456789abcdef" for ch in token): + return [] + + cursor = self._conn.cursor() + cursor.execute("SELECT hash FROM relation_hash_aliases WHERE alias32 = ?", (token,)) + row = cursor.fetchone() + if not row: + return [] + resolved = str(row[0]) + return [resolved] + + def rebuild_relation_hash_aliases(self) -> Dict[str, Any]: + """重建 32 位 relation hash 别名映射。""" + cursor = self._conn.cursor() + # 历史库兜底:缺表时先创建,避免迁移过程直接中断。 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relation_hash_aliases ( + alias32 TEXT PRIMARY KEY, + hash TEXT NOT NULL + ) + """) + cursor.execute("DELETE FROM relation_hash_aliases") + + cursor.execute("SELECT hash FROM relations") + hashes = [str(r[0]) for r in cursor.fetchall()] + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='deleted_relations'" + ) + has_deleted_relations = cursor.fetchone() is not None + if has_deleted_relations: + cursor.execute("SELECT hash FROM deleted_relations") + hashes.extend(str(r[0]) for r in cursor.fetchall()) + + alias_map: Dict[str, str] = {} + conflicts: Dict[str, set[str]] = {} + for h in hashes: + if len(h) != 64: + continue + alias = h[:32] + old = alias_map.get(alias) + if old is None: + alias_map[alias] = h + elif old != h: + conflicts.setdefault(alias, set()).update({old, h}) + + for alias, full_hash in alias_map.items(): + if alias in conflicts: + continue + cursor.execute( + "INSERT INTO relation_hash_aliases(alias32, hash) VALUES (?, ?)", + (alias, full_hash), + ) + self._conn.commit() + return { + "inserted": len(alias_map) - len(conflicts), + "conflict_count": len(conflicts), + "conflicts": sorted(conflicts.keys()), + } + + def search_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: + """按 relation 内容模糊查询 hash。""" + q = str(query or "").strip() + if not q: + return [] + cursor = self._conn.cursor() + cursor.execute( + "SELECT hash FROM relations WHERE subject LIKE ? 
OR object LIKE ? LIMIT ?", + (f"%{q}%", f"%{q}%", int(max(1, limit))), + ) + return [str(row[0]) for row in cursor.fetchall()] + + def search_deleted_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: + """按 deleted_relations 内容模糊查询 hash。""" + q = str(query or "").strip() + if not q: + return [] + cursor = self._conn.cursor() + cursor.execute( + "SELECT hash FROM deleted_relations WHERE subject LIKE ? OR object LIKE ? LIMIT ?", + (f"%{q}%", f"%{q}%", int(max(1, limit))), + ) + return [str(row[0]) for row in cursor.fetchall()] + + def restore_entity_by_hash(self, entity_hash: str) -> bool: + """恢复软删除实体。""" + cursor = self._conn.cursor() + cursor.execute( + "UPDATE entities SET is_deleted=0, deleted_at=NULL WHERE hash=?", + (str(entity_hash),), + ) + changed = cursor.rowcount > 0 + if changed: + self._conn.commit() + return changed + + def backfill_temporal_metadata_from_created_at( + self, + *, + limit: int = 100000, + dry_run: bool = False, + no_created_fallback: bool = False, + ) -> Dict[str, int]: + """回填段落 event_time 字段(created_at 兜底)。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT hash, created_at, source + FROM paragraphs + WHERE (event_time IS NULL AND event_time_start IS NULL AND event_time_end IS NULL) + ORDER BY created_at DESC + LIMIT ? + """, + (int(max(1, limit)),), + ) + rows = cursor.fetchall() + candidates = len(rows) + if dry_run: + return {"candidates": candidates, "updated": 0} + if no_created_fallback: + return {"candidates": candidates, "updated": 0} + + updated = 0 + touched_sources: List[str] = [] + for row in rows: + created_at = row["created_at"] + if created_at is None: + continue + cursor.execute( + """ + UPDATE paragraphs + SET event_time = ?, time_granularity = ?, time_confidence = ?, updated_at = ? + WHERE hash = ? 
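+ -- created_at fallback: coarse 'day' granularity with low confidence (0.2) marks the value as a rough estimate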
+ """, + (float(created_at), "day", 0.2, float(created_at), row["hash"]), + ) + if cursor.rowcount > 0: + updated += 1 + touched_sources.append(row["source"]) + self._conn.commit() + if updated > 0: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_time_backfill", + ) + return {"candidates": candidates, "updated": updated} + + def get_schema_version(self) -> int: + cursor = self._conn.cursor() + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" + ) + if cursor.fetchone() is None: + return 0 + cursor.execute("SELECT MAX(version) FROM schema_migrations") + row = cursor.fetchone() + return int(row[0]) if row and row[0] is not None else 0 + + def set_schema_version(self, version: int = SCHEMA_VERSION) -> None: + cursor = self._conn.cursor() + cursor.execute( + "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at REAL NOT NULL)" + ) + cursor.execute( + "INSERT OR REPLACE INTO schema_migrations(version, applied_at) VALUES (?, ?)", + (int(version), datetime.now().timestamp()), + ) + self._conn.commit() + + def delete_paragraph_atomic(self, paragraph_hash: str) -> Dict[str, Any]: + """ + 两阶段删除段落:DB 事务内计算 + 提交后执行清理 + + Args: + paragraph_hash: 段落哈希 + + Returns: + cleanup_plan: 包含需要后续从 Vector/GraphStore 中移除的 ID 列表 + """ + cleanup_plan = { + "paragraph_hash": paragraph_hash, + "vector_id_to_remove": None, + "edges_to_remove": [], # (src, tgt) 元组列表 (fallback) + "relation_prune_ops": [], # (subject, object, relation_hash) 精准裁剪 + "episode_sources_to_rebuild": [], + } + + cursor = self._conn.cursor() + try: + # === Phase 1: DB Transaction (可回滚) === + # 使用 IMMEDIATE 模式,一旦开启事务立即锁定 DB (防止其他写操作插队导致幻读) + cursor.execute("BEGIN IMMEDIATE") + + # 1. [快照] 获取候选关系 + cursor.execute("SELECT relation_hash FROM paragraph_relations WHERE paragraph_hash = ?", (paragraph_hash,)) + candidate_relations = [row[0] for row in cursor.fetchall()] + + # 2. [快照] 确认该段落存在并记录 ID 用于向量删除 + cursor.execute("SELECT hash, source FROM paragraphs WHERE hash = ?", (paragraph_hash,)) + paragraph_row = cursor.fetchone() + if paragraph_row: + cleanup_plan["vector_id_to_remove"] = paragraph_hash + cleanup_plan["episode_sources_to_rebuild"] = self._dedupe_episode_sources( + [paragraph_row["source"]] + ) + + # 3. [主删除] 删除段落 (触发 CASCADE 删 paragraph_relations) + cursor.execute("DELETE FROM paragraphs WHERE hash = ?", (paragraph_hash,)) + + # 4. [计算孤儿] + orphaned_hashes = [] + for rel_hash in candidate_relations: + count = cursor.execute( + "SELECT count(*) FROM paragraph_relations WHERE relation_hash = ?", + (rel_hash,) + ).fetchone()[0] + + if count == 0: + # 是孤儿:记录边信息以便后续删 Graph + cursor.execute("SELECT subject, object FROM relations WHERE hash = ?", (rel_hash,)) + rel_info = cursor.fetchone() + if rel_info: + s_val, o_val = rel_info[0], rel_info[1] + cleanup_plan["relation_prune_ops"].append((s_val, o_val, rel_hash)) + + # 仅当 (subject, object) 不再有任何关系时,才计划删整条边(兼容旧实现)。 + sibling_count = cursor.execute( + """ + SELECT count(*) FROM relations + WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?)) + AND LOWER(TRIM(object)) = LOWER(TRIM(?)) + AND hash != ? + """, + (s_val, o_val, rel_hash) + ).fetchone()[0] + if sibling_count == 0: + cleanup_plan["edges_to_remove"].append((s_val, o_val)) + + orphaned_hashes.append(rel_hash) + + # 5. 
[DB清理] 删除孤儿关系记录 + if orphaned_hashes: + placeholders = ','.join(['?'] * len(orphaned_hashes)) + cursor.execute(f"DELETE FROM relations WHERE hash IN ({placeholders})", orphaned_hashes) + + self._conn.commit() + if cleanup_plan["episode_sources_to_rebuild"]: + self._enqueue_episode_source_rebuilds( + cleanup_plan["episode_sources_to_rebuild"], + reason="paragraph_deleted", + ) + if cleanup_plan["vector_id_to_remove"]: + logger.debug(f"原子删除段落成功: {paragraph_hash}, 计划清理 {len(orphaned_hashes)} 个孤儿关系") + return cleanup_plan + + except Exception as e: + self._conn.rollback() + logger.error(f"DB Transaction failed: {e}") + raise e + + + def clear_all(self) -> None: + """清空所有表数据""" + cursor = self._conn.cursor() + tables = [ + "paragraphs", "entities", "relations", + "paragraph_relations", "paragraph_entities", + "episodes", "episode_paragraphs", + "episode_rebuild_sources", "episode_pending_paragraphs", + ] + for table in tables: + cursor.execute(f"DELETE FROM {table}") + self._conn.commit() + logger.info("元数据存储所有表已清空") + + + + def update_relation_timestamp(self, hash_value: str, access_count_delta: int = 1) -> None: + """更新关系的访问时间和计数""" + now = datetime.now().timestamp() + + # 同时更新 last_accessed (旧) 和 last_reinforced (V5) + + cursor = self._conn.cursor() + cursor.execute(""" + UPDATE relations + SET last_accessed = ?, + access_count = access_count + ? + WHERE hash = ? + """, (now, access_count_delta, hash_value)) + self._conn.commit() + + # ========================================================================= + # V5 Memory System Methods + # ========================================================================= + + def get_relation_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]: + """ + 批量获取关系状态 (V5) + + Args: + hashes: 关系哈希列表 + + Returns: + Dict[hash, status_dict] + status_dict 包含: is_inactive, weight(confidence), is_pinned, protected_until, last_reinforced, inactive_since + """ + if not hashes: + return {} + + placeholders = ",".join(["?"] * len(hashes)) + cursor = self._conn.cursor() + cursor.execute(f""" + SELECT hash, is_inactive, confidence, is_pinned, protected_until, last_reinforced, inactive_since + FROM relations + WHERE hash IN ({placeholders}) + """, hashes) + + result = {} + for row in cursor.fetchall(): + result[row["hash"]] = { + "is_inactive": bool(row["is_inactive"]), + "weight": row["confidence"], + "is_pinned": bool(row["is_pinned"]), + "protected_until": row["protected_until"], + "last_reinforced": row["last_reinforced"], + "inactive_since": row["inactive_since"] + } + return result + + def mark_relations_active(self, hashes: List[str], boost_weight: Optional[float] = None) -> None: + """ + 批量标记关系为活跃 (Active/Revive) + + Args: + hashes: 关系哈希列表 + boost_weight: 如果提供,将设置 confidence = max(confidence, boost_weight) + """ + if not hashes: + return + + placeholders = ",".join(["?"] * len(hashes)) + cursor = self._conn.cursor() + + if boost_weight is not None: + cursor.execute(f""" + UPDATE relations + SET is_inactive = 0, + inactive_since = NULL, + confidence = MAX(confidence, ?) 
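+ -- revive: confidence is raised to at least boost_weight and never lowered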
+ WHERE hash IN ({placeholders}) + """, (boost_weight, *hashes)) + else: + cursor.execute(f""" + UPDATE relations + SET is_inactive = 0, + inactive_since = NULL + WHERE hash IN ({placeholders}) + """, hashes) + + self._conn.commit() + + def update_relations_protection( + self, + hashes: List[str], + protected_until: Optional[float] = None, + is_pinned: Optional[bool] = None, + last_reinforced: Optional[float] = None + ) -> None: + """ + 批量更新关系保护状态 + """ + if not hashes: + return + + updates = [] + params = [] + + if protected_until is not None: + updates.append("protected_until = ?") + params.append(protected_until) + if is_pinned is not None: + updates.append("is_pinned = ?") + params.append(1 if is_pinned else 0) + if last_reinforced is not None: + updates.append("last_reinforced = ?") + params.append(last_reinforced) + + if not updates: + return + + sql_set = ", ".join(updates) + placeholders = ",".join(["?"] * len(hashes)) + + params.extend(hashes) + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE relations + SET {sql_set} + WHERE hash IN ({placeholders}) + """, params) + self._conn.commit() + + def get_prune_candidates(self, cutoff_time: float, limit: int = 1000) -> List[str]: + """ + 获取待修剪候选 (已过冷冻保留期) + + Args: + cutoff_time: 截止时间 (now - 冷冻时长) + limit: 限制数量 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT hash FROM relations + WHERE is_inactive = 1 + AND inactive_since < ? + LIMIT ? + """, (cutoff_time, limit)) + return [row[0] for row in cursor.fetchall()] + + def backup_and_delete_relations(self, hashes: List[str]) -> int: + """ + 备份并删除关系 (Prune) + + Returns: + 删除的数量 + """ + if not hashes: + return 0 + + placeholders = ",".join(["?"] * len(hashes)) + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + try: + # 1. 备份 + cursor.execute(f""" + INSERT OR REPLACE INTO deleted_relations + (hash, subject, predicate, object, vector_index, confidence, created_at, + vector_state, vector_updated_at, vector_error, vector_retry_count, + source_paragraph, metadata, is_permanent, last_accessed, access_count, + is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, deleted_at) + SELECT + hash, subject, predicate, object, vector_index, confidence, created_at, + vector_state, vector_updated_at, vector_error, vector_retry_count, + source_paragraph, metadata, is_permanent, last_accessed, access_count, + is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, ? + FROM relations + WHERE hash IN ({placeholders}) + """, (now, *hashes)) + + # 2. 删除 (级联删除会自动处理 paragraph_relations 关联) + cursor.execute(f""" + DELETE FROM relations + WHERE hash IN ({placeholders}) + """, hashes) + + deleted_count = cursor.rowcount + self._conn.commit() + return deleted_count + + except Exception as e: + logger.error(f"备份删除失败: {e}") + self._conn.rollback() + return 0 + + def restore_relation_metadata(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 从回收站恢复关系元数据 + + Returns: + 恢复后的关系数据 (字典),失败返回 None + """ + cursor = self._conn.cursor() + try: + # 1. 查询备份数据 + cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,)) + row = cursor.fetchone() + if not row: + return None + + data = dict(row) + # 移除 deleted_at 字段 + if "deleted_at" in data: + del data["deleted_at"] + + # 2. 
插入回 relations 表 + # 动态构建 SQL 以适应字段变化 + columns = list(data.keys()) + placeholders = ",".join(["?"] * len(columns)) + cols_str = ",".join(columns) + values = list(data.values()) + + cursor.execute(f""" + INSERT OR REPLACE INTO relations ({cols_str}) + VALUES ({placeholders}) + """, values) + + # 3. 从备份表删除 + cursor.execute("DELETE FROM deleted_relations WHERE hash = ?", (hash_value,)) + + self._conn.commit() + return self._row_to_dict(row, "relation") # 使用助手函数将原始行转换为字典 + + except Exception as e: + logger.error(f"恢复关系失败: {hash_value} - {e}") + self._conn.rollback() + return None + + def restore_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: + """兼容旧调用名:恢复关系。""" + return self.restore_relation_metadata(hash_value) + + def get_protected_relations_hashes(self) -> List[str]: + """获取所有受保护关系的哈希 (Pinned 或 Protected Until > Now)""" + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + cursor.execute(""" + SELECT hash FROM relations + WHERE is_pinned = 1 OR protected_until > ? + """, (now,)) + + return [row[0] for row in cursor.fetchall()] + + + + def get_deleted_relations(self, limit: int = 50) -> List[Dict[str, Any]]: + """获取回收站中的关系记录""" + cursor = self._conn.cursor() + cursor.execute("SELECT * FROM deleted_relations ORDER BY deleted_at DESC LIMIT ?", (limit,)) + data = [] + for row in cursor.fetchall(): + d = dict(row) + # 是否需要解码元数据?是的,与普通行相同 + if "metadata" in d and d["metadata"]: + try: + d["metadata"] = pickle.loads(d["metadata"]) + except Exception: + d["metadata"] = {} + data.append(d) + return data + + def get_deleted_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: + """获取单条回收站记录""" + cursor = self._conn.cursor() + cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,)) + row = cursor.fetchone() + if not row: return None + + d = dict(row) + if "metadata" in d and d["metadata"]: + try: + d["metadata"] = pickle.loads(d["metadata"]) + except Exception: + d["metadata"] = {} + return d + + def reinforce_relations(self, hashes: List[str]) -> None: + """强化关系 (更新 last_reinforced, is_inactive=0)""" + if not hashes: return + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + # Batch update? chunking + chunk_size = 500 + for i in range(0, len(hashes), chunk_size): + chunk = hashes[i:i+chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + sql = f""" + UPDATE relations + SET last_reinforced = ?, is_inactive = 0, inactive_since = NULL + WHERE hash IN ({placeholders}) + """ + cursor.execute(sql, [now] + chunk) + + self._conn.commit() + + def mark_relations_inactive(self, hashes: List[str], inactive_since: Optional[float] = None) -> None: + """标记关系为非活跃 (Freeze)。兼容显式 inactive_since 或默认当前时间。""" + if not hashes: + return + mark_time = inactive_since if inactive_since is not None else datetime.now().timestamp() + + cursor = self._conn.cursor() + chunk_size = 500 + for i in range(0, len(hashes), chunk_size): + chunk = hashes[i:i+chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + sql = f""" + UPDATE relations + SET is_inactive = 1, inactive_since = ? 
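+ -- inactive_since stamps the freeze time that get_prune_candidates later compares against the prune cutoff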
+ WHERE hash IN ({placeholders}) + """ + cursor.execute(sql, [mark_time] + chunk) + + self._conn.commit() + + def protect_relations( + self, + hashes: List[str], + is_pinned: bool = False, + ttl_seconds: float = 0 + ) -> None: + """ + 设置保护状态 + """ + if not hashes: return + now = datetime.now().timestamp() + protected_until = (now + ttl_seconds) if ttl_seconds > 0 else 0 + + cursor = self._conn.cursor() + chunk_size = 500 + for i in range(0, len(hashes), chunk_size): + chunk = hashes[i:i+chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + + # 由于 is_pinned 和 protected_until 是分开的,如果请求固定(pin),我们会同时更新这两项, + # 但通常用户要么切换固定状态,要么设置 TTL。 + # 如果 is_pinned=True,TTL 通常就不重要了。 + # 但目前的逻辑是正交处理它们的。 + + # 如果用户取消固定 (is_pinned=False),我们是否应该尊重已设置的 TTL? + # 当前的 API 会同时设置这两项。 + + sql = f""" + UPDATE relations + SET is_pinned = ?, protected_until = ? + WHERE hash IN ({placeholders}) + """ + cursor.execute(sql, [is_pinned, protected_until] + chunk) + + self._conn.commit() + + def vacuum(self) -> None: + """优化数据库""" + cursor = self._conn.cursor() + cursor.execute("VACUUM") + self._conn.commit() + logger.info("数据库优化完成") + + def _row_to_dict(self, row: sqlite3.Row, row_type: str) -> Dict[str, Any]: + """ + 将数据库行转换为字典 + + Args: + row: 数据库行 + row_type: 行类型 + + Returns: + 字典 + """ + d = dict(row) + + # 解码pickle字段 + if "metadata" in d and d["metadata"]: + try: + d["metadata"] = pickle.loads(d["metadata"]) + except Exception: + d["metadata"] = {} + + return d + + @property + def is_connected(self) -> bool: + """是否已连接""" + return self._conn is not None + + def __enter__(self): + """上下文管理器入口""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口""" + self.close() + + # ========================================================================= + # V5 Soft Delete & Garbage Collection + # ========================================================================= + + def get_entity_gc_candidates(self, isolated_hashes: List[str], retention_seconds: float) -> List[str]: + """ + 获取实体 GC 候选列表 (Soft Delete Candidates) + 条件: + 1. 在 isolated_hashes 列表中 (由 GraphStore 提供;通常是实体名称) + 2. is_deleted = 0 (未被标记) + 3. created_at < now - retention (过了新手保护期) + 4. 不被任何 active paragraph 引用 (paragraph_entities check) + + Args: + isolated_hashes: 孤儿实体名称列表(兼容传入 hash) + retention_seconds: 保留时间 (秒) + """ + if not isolated_hashes: + return [] + + # GraphStore.get_isolated_nodes 返回节点名,这里做 canonicalize -> entity hash 映射。 + # 同时兼容历史调用直接传 hash。 + normalized_hashes: List[str] = [] + for item in isolated_hashes: + if not item: + continue + v = str(item).strip() + if len(v) == 64 and all(c in "0123456789abcdefABCDEF" for c in v): + normalized_hashes.append(v.lower()) + else: + canon = self._canonicalize_name(v) + if canon: + normalized_hashes.append(compute_hash(canon)) + + normalized_hashes = list(dict.fromkeys(normalized_hashes)) + if not normalized_hashes: + return [] + + now = datetime.now().timestamp() + cutoff = now - retention_seconds + + candidates = [] + batch_size = 900 + + # 分批处理 IN 查询 + for i in range(0, len(normalized_hashes), batch_size): + batch = normalized_hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + # 使用 NOT EXISTS 子查询检查引用 + # 注意: paragraph_entities 中引用的 paragraph 如果被软删了,是否算引用? + # 这里的语义: 只要有 rows 存在于 paragraph_entities 且该 row 对应的 paragraph 没被彻底物理删除,就算引用。 + # 更严格: ... OR (EXISTS ... AND entity_hash=... AND is_deleted=0) + # 但 paragraph_entities 表没有 is_deleted 字段(它是关联表). 我们检查关联是否存在。 + # 如果 paragraph 本身 soft deleted, 它的引用应该失效吗? 
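+ # Illustrative case: entity E whose only link points at a paragraph P with is_deleted = 1.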
+ # 策略: 只有当 paragraph 也是 active 时,引用才有效。 + # JOIN paragraphs p ON pe.paragraph_hash = p.hash WHERE p.is_deleted = 0 + + query = f""" + SELECT e.hash FROM entities e + WHERE e.hash IN ({placeholders}) + AND e.is_deleted = 0 + AND (e.created_at IS NULL OR e.created_at < ?) + AND NOT EXISTS ( + SELECT 1 FROM paragraph_entities pe + JOIN paragraphs p ON pe.paragraph_hash = p.hash + WHERE pe.entity_hash = e.hash + AND p.is_deleted = 0 + ) + """ + + cursor = self._conn.cursor() + cursor.execute(query, [*batch, cutoff]) + candidates.extend([row[0] for row in cursor.fetchall()]) + + return candidates + + def get_paragraph_gc_candidates(self, retention_seconds: float) -> List[str]: + """ + 获取段落 GC 候选列表 + 条件: + 1. is_deleted = 0 + 2. created_at < cutoff + 3. 没有 Relations (paragraph_relations empty) + 4. 没有 Entities 引用 (paragraph_entities empty) + OR 引用的 Entities 全是软删状态? (太复杂,简单点: 无引用) + + Refined Strategy: + 段落孤儿判定 = + (Left Join paragraph_relations -> NULL) AND + (Left Join paragraph_entities -> NULL) + """ + now = datetime.now().timestamp() + cutoff = now - retention_seconds + + query = """ + SELECT p.hash FROM paragraphs p + LEFT JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash + LEFT JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash + WHERE p.is_deleted = 0 + AND (p.created_at IS NULL OR p.created_at < ?) + AND pr.relation_hash IS NULL + AND pe.entity_hash IS NULL + """ + + cursor = self._conn.cursor() + cursor.execute(query, (cutoff,)) + return [row[0] for row in cursor.fetchall()] + + def mark_as_deleted(self, hashes: List[str], type_: str) -> int: + """ + 标记为软删除 (Mark Phase) + + Args: + hashes: Hash 列表 + type_: 'entity' | 'paragraph' + """ + if not hashes: + return 0 + + table = "entities" if type_ == "entity" else "paragraphs" + now = datetime.now().timestamp() + touched_sources: List[str] = [] + if type_ == "paragraph": + touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True) + + count = 0 + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + # 幂等更新: 只更那些 is_deleted=0 的 + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table} + SET is_deleted = 1, deleted_at = ? + WHERE is_deleted = 0 AND hash IN ({placeholders}) + """, [now] + batch) + count += cursor.rowcount + + self._conn.commit() + if type_ == "paragraph" and count > 0: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_soft_deleted", + ) + if count > 0: + logger.info(f"软删除标记 ({table}): {count} 项") + return count + + def sweep_deleted_items(self, type_: str, grace_period_seconds: float) -> List[Tuple[str, str]]: + """ + 扫描可物理清理的项目 (Sweep Phase - Selection) + + Args: + type_: 'entity' | 'paragraph' + grace_period_seconds: 宽限期 + + Returns: + List[(hash, name)]: 待删除项列表 (paragraph name为空) + """ + table = "entities" if type_ == "entity" else "paragraphs" + now = datetime.now().timestamp() + cutoff = now - grace_period_seconds + + cols = "hash, name" if type_ == "entity" else "hash, '' as name" + + cursor = self._conn.cursor() + cursor.execute(f""" + SELECT {cols} FROM {table} + WHERE is_deleted = 1 + AND deleted_at < ? 
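+ -- only rows whose deletion grace period has fully elapsed are handed to the physical-delete step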
+ """, (cutoff,)) + + return [(row[0], row[1]) for row in cursor.fetchall()] + + def physically_delete_entities(self, hashes: List[str]) -> int: + """物理删除实体 (批量)""" + if not hashes: return 0 + + count = 0 + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f"DELETE FROM entities WHERE hash IN ({placeholders})", batch) + count += cursor.rowcount + + self._conn.commit() + return count + + def physically_delete_paragraphs(self, hashes: List[str]) -> int: + """物理删除段落 (批量)""" + if not hashes: return 0 + touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True) + + count = 0 + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f"DELETE FROM paragraphs WHERE hash IN ({placeholders})", batch) + count += cursor.rowcount + + self._conn.commit() + if count > 0: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_physically_deleted", + ) + return count + + def revive_if_deleted(self, entity_hashes: List[str] = None, paragraph_hashes: List[str] = None) -> int: + """ + 复活已软删的项目 (Auto Revival) + 当数据被再次访问、引用或导入时调用。 + """ + count = 0 + + if entity_hashes: + batch_size = 900 + for i in range(0, len(entity_hashes), batch_size): + batch = entity_hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE entities + SET is_deleted = 0, deleted_at = NULL + WHERE is_deleted = 1 AND hash IN ({placeholders}) + """, batch) + count += cursor.rowcount + + if paragraph_hashes: + touched_sources = self._get_sources_for_paragraph_hashes(paragraph_hashes, include_deleted=True) + batch_size = 900 + for i in range(0, len(paragraph_hashes), batch_size): + batch = paragraph_hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE paragraphs + SET is_deleted = 0, deleted_at = NULL + WHERE is_deleted = 1 AND hash IN ({placeholders}) + """, batch) + count += cursor.rowcount + else: + touched_sources = [] + + if count > 0: + self._conn.commit() + if touched_sources: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_revived", + ) + logger.info(f"自动复活: {count} 项 (Soft Delete Revived)") + + return count + + def revive_entities_by_names(self, names: List[str]) -> int: + """ + 根据名称复活实体 (Convenience wrapper) + """ + if not names: return 0 + + # 使用内部方法计算哈希 + hashes = [compute_hash(self._canonicalize_name(n)) for n in names] + return self.revive_if_deleted(entity_hashes=hashes) + + def get_entity_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]: + """批量获取实体状态 (WebUI用)""" + if not hashes: return {} + + result = {} + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f""" + SELECT hash, is_deleted, deleted_at + FROM entities + WHERE hash IN ({placeholders}) + """, batch) + + for row in cursor.fetchall(): + result[row[0]] = { + "is_deleted": bool(row[1]), + "deleted_at": row[2] + } + return result + + # ========================================================================= + # Person Profile (问题3) - Switches / Active Set / Snapshots + # 
========================================================================= + + def set_person_profile_switch( + self, + stream_id: str, + user_id: str, + enabled: bool, + updated_at: Optional[float] = None, + ) -> None: + """设置人物画像自动注入开关(按 stream_id + user_id)。""" + if not stream_id or not user_id: + raise ValueError("stream_id 和 user_id 不能为空") + + ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO person_profile_switches (stream_id, user_id, enabled, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(stream_id, user_id) DO UPDATE SET + enabled = excluded.enabled, + updated_at = excluded.updated_at + """, + (str(stream_id), str(user_id), 1 if enabled else 0, ts), + ) + self._conn.commit() + + def get_person_profile_switch(self, stream_id: str, user_id: str, default: bool = False) -> bool: + """读取人物画像自动注入开关。""" + if not stream_id or not user_id: + return bool(default) + + cursor = self._conn.cursor() + cursor.execute( + "SELECT enabled FROM person_profile_switches WHERE stream_id = ? AND user_id = ?", + (str(stream_id), str(user_id)), + ) + row = cursor.fetchone() + if not row: + return bool(default) + return bool(row[0]) + + def get_enabled_person_profile_switches(self, limit: int = 1000) -> List[Dict[str, Any]]: + """获取已开启人物画像注入开关的会话范围。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT stream_id, user_id, enabled, updated_at + FROM person_profile_switches + WHERE enabled = 1 + ORDER BY updated_at DESC + LIMIT ? + """, + (int(max(1, limit)),), + ) + return [ + { + "stream_id": row[0], + "user_id": row[1], + "enabled": bool(row[2]), + "updated_at": row[3], + } + for row in cursor.fetchall() + ] + + def mark_person_profile_active( + self, + stream_id: str, + user_id: str, + person_id: str, + seen_at: Optional[float] = None, + ) -> None: + """记录活跃人物(用于定时按需刷新)。""" + if not stream_id or not user_id or not person_id: + return + ts = float(seen_at) if seen_at is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO person_profile_active_persons (stream_id, user_id, person_id, last_seen_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(stream_id, user_id, person_id) DO UPDATE SET + last_seen_at = excluded.last_seen_at + """, + (str(stream_id), str(user_id), str(person_id), ts), + ) + self._conn.commit() + + def get_active_person_ids_for_enabled_switches( + self, + active_after: Optional[float] = None, + limit: int = 200, + ) -> List[str]: + """获取“已开启开关范围内”的活跃人物集合。""" + cursor = self._conn.cursor() + sql = """ + SELECT a.person_id, MAX(a.last_seen_at) AS last_seen + FROM person_profile_active_persons a + JOIN person_profile_switches s + ON a.stream_id = s.stream_id AND a.user_id = s.user_id + WHERE s.enabled = 1 + """ + params: List[Any] = [] + if active_after is not None: + sql += " AND a.last_seen_at >= ?" + params.append(float(active_after)) + sql += """ + GROUP BY a.person_id + ORDER BY last_seen DESC + LIMIT ? 
+ """ + params.append(int(max(1, limit))) + cursor.execute(sql, tuple(params)) + return [str(row[0]) for row in cursor.fetchall() if row and row[0]] + + def get_latest_person_profile_snapshot(self, person_id: str) -> Optional[Dict[str, Any]]: + """获取人物最新画像快照。""" + if not person_id: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT + snapshot_id, person_id, profile_version, profile_text, + aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, + updated_at, expires_at, source_note + FROM person_profile_snapshots + WHERE person_id = ? + ORDER BY profile_version DESC + LIMIT 1 + """, + (str(person_id),), + ) + row = cursor.fetchone() + if not row: + return None + + def _load_list(raw: Any) -> List[Any]: + if not raw: + return [] + try: + data = json.loads(raw) + return data if isinstance(data, list) else [] + except Exception: + return [] + + return { + "snapshot_id": row[0], + "person_id": row[1], + "profile_version": int(row[2]), + "profile_text": row[3] or "", + "aliases": _load_list(row[4]), + "relation_edges": _load_list(row[5]), + "vector_evidence": _load_list(row[6]), + "evidence_ids": _load_list(row[7]), + "updated_at": row[8], + "expires_at": row[9], + "source_note": row[10] or "", + } + + def upsert_person_profile_snapshot( + self, + person_id: str, + profile_text: str, + aliases: Optional[List[str]] = None, + relation_edges: Optional[List[Dict[str, Any]]] = None, + vector_evidence: Optional[List[Dict[str, Any]]] = None, + evidence_ids: Optional[List[str]] = None, + expires_at: Optional[float] = None, + source_note: str = "", + updated_at: Optional[float] = None, + ) -> Dict[str, Any]: + """写入人物画像快照(按 person_id 自动递增版本)。""" + if not person_id: + raise ValueError("person_id 不能为空") + + aliases = aliases or [] + relation_edges = relation_edges or [] + vector_evidence = vector_evidence or [] + evidence_ids = evidence_ids or [] + ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT profile_version + FROM person_profile_snapshots + WHERE person_id = ? + ORDER BY profile_version DESC + LIMIT 1 + """, + (str(person_id),), + ) + row = cursor.fetchone() + next_version = int(row[0]) + 1 if row else 1 + + cursor.execute( + """ + INSERT INTO person_profile_snapshots ( + person_id, profile_version, profile_text, + aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, + updated_at, expires_at, source_note + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + str(person_id), + next_version, + str(profile_text or ""), + json.dumps(aliases, ensure_ascii=False), + json.dumps(relation_edges, ensure_ascii=False), + json.dumps(vector_evidence, ensure_ascii=False), + json.dumps(evidence_ids, ensure_ascii=False), + ts, + float(expires_at) if expires_at is not None else None, + str(source_note or ""), + ), + ) + self._conn.commit() + latest = self.get_latest_person_profile_snapshot(person_id) + return latest or { + "person_id": person_id, + "profile_version": next_version, + "profile_text": str(profile_text or ""), + "aliases": aliases, + "relation_edges": relation_edges, + "vector_evidence": vector_evidence, + "evidence_ids": evidence_ids, + "updated_at": ts, + "expires_at": expires_at, + "source_note": source_note, + } + + def get_person_profile_override(self, person_id: str) -> Optional[Dict[str, Any]]: + """获取人物画像手工覆盖内容。""" + if not person_id: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT person_id, override_text, updated_at, updated_by, source + FROM person_profile_overrides + WHERE person_id = ? + LIMIT 1 + """, + (str(person_id),), + ) + row = cursor.fetchone() + if not row: + return None + return { + "person_id": str(row[0]), + "override_text": str(row[1] or ""), + "updated_at": row[2], + "updated_by": str(row[3] or ""), + "source": str(row[4] or ""), + } + + def set_person_profile_override( + self, + person_id: str, + override_text: str, + updated_by: str = "", + source: str = "webui", + updated_at: Optional[float] = None, + ) -> Dict[str, Any]: + """写入人物画像手工覆盖;空文本等价于清除覆盖。""" + if not person_id: + raise ValueError("person_id 不能为空") + + text = str(override_text or "").strip() + if not text: + self.delete_person_profile_override(person_id) + return { + "person_id": str(person_id), + "override_text": "", + "updated_at": None, + "updated_by": str(updated_by or ""), + "source": str(source or ""), + } + + ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO person_profile_overrides ( + person_id, override_text, updated_at, updated_by, source + ) VALUES (?, ?, ?, ?, ?) 
+ ON CONFLICT(person_id) DO UPDATE SET + override_text = excluded.override_text, + updated_at = excluded.updated_at, + updated_by = excluded.updated_by, + source = excluded.source + """, + ( + str(person_id), + text, + ts, + str(updated_by or ""), + str(source or ""), + ), + ) + self._conn.commit() + return self.get_person_profile_override(person_id) or { + "person_id": str(person_id), + "override_text": text, + "updated_at": ts, + "updated_by": str(updated_by or ""), + "source": str(source or ""), + } + + def delete_person_profile_override(self, person_id: str) -> bool: + """删除人物画像手工覆盖。""" + if not person_id: + return False + cursor = self._conn.cursor() + cursor.execute( + "DELETE FROM person_profile_overrides WHERE person_id = ?", + (str(person_id),), + ) + self._conn.commit() + return cursor.rowcount > 0 + + # ========================================================================= + # Episode MVP + # ========================================================================= + + def enqueue_episode_source_rebuild(self, source: str, reason: str = "") -> bool: + """将 source 入队到 episode 重建队列。""" + return bool(self._enqueue_episode_source_rebuilds([source], reason=reason)) + + def fetch_episode_source_rebuild_batch( + self, + limit: int = 20, + max_retry: int = 3, + ) -> List[Dict[str, Any]]: + """获取待处理的 source 重建任务。""" + safe_limit = max(1, int(limit)) + safe_retry = max(0, int(max_retry)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT source, status, retry_count, last_error, reason, requested_at, updated_at + FROM episode_rebuild_sources + WHERE status = 'pending' + OR (status = 'failed' AND retry_count < ?) + ORDER BY requested_at ASC, updated_at ASC + LIMIT ? + """, + (safe_retry, safe_limit), + ) + return [dict(row) for row in cursor.fetchall()] + + def mark_episode_source_running( + self, + source: str, + *, + requested_at: Optional[float] = None, + ) -> bool: + """将 source 标记为 running。""" + token = self._normalize_episode_source(source) + if not token: + return False + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + params: List[Any] = [now, token] + sql = """ + UPDATE episode_rebuild_sources + SET status = 'running', + updated_at = ? + WHERE source = ? + AND status IN ('pending', 'failed') + """ + if requested_at is not None: + sql += " AND requested_at = ?" + params.append(float(requested_at)) + cursor.execute(sql, tuple(params)) + self._conn.commit() + return cursor.rowcount > 0 + + def mark_episode_source_done( + self, + source: str, + *, + requested_at: Optional[float] = None, + ) -> bool: + """将 source 标记为 done;若运行期间发生新写入,则保持 pending。""" + token = self._normalize_episode_source(source) + if not token: + return False + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + if requested_at is None: + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = 'done', + last_error = NULL, + updated_at = ? + WHERE source = ? + """, + (now, token), + ) + else: + req_ts = float(requested_at) + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = CASE + WHEN requested_at > ? THEN 'pending' + ELSE 'done' + END, + last_error = NULL, + updated_at = ? + WHERE source = ? 
+ """, + (req_ts, now, token), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def mark_episode_source_failed( + self, + source: str, + error: str = "", + *, + requested_at: Optional[float] = None, + ) -> bool: + """标记 source 失败;若运行期间发生新写入,则重新回到 pending。""" + token = self._normalize_episode_source(source) + if not token: + return False + + err_text = str(error or "").strip()[:500] + now = datetime.now().timestamp() + cursor = self._conn.cursor() + if requested_at is None: + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = 'failed', + retry_count = COALESCE(retry_count, 0) + 1, + last_error = ?, + updated_at = ? + WHERE source = ? + """, + (err_text, now, token), + ) + else: + req_ts = float(requested_at) + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = CASE + WHEN requested_at > ? THEN 'pending' + ELSE 'failed' + END, + retry_count = CASE + WHEN requested_at > ? THEN COALESCE(retry_count, 0) + ELSE COALESCE(retry_count, 0) + 1 + END, + last_error = CASE + WHEN requested_at > ? THEN NULL + ELSE ? + END, + updated_at = ? + WHERE source = ? + """, + (req_ts, req_ts, req_ts, err_text, now, token), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def list_episode_source_rebuilds( + self, + *, + statuses: Optional[List[str]] = None, + limit: int = 100, + ) -> List[Dict[str, Any]]: + """列出 source 重建状态。""" + safe_limit = max(1, int(limit)) + params: List[Any] = [] + conditions: List[str] = [] + normalized_statuses = [ + str(item or "").strip().lower() + for item in (statuses or []) + if str(item or "").strip().lower() in {"pending", "running", "done", "failed"} + ] + if normalized_statuses: + placeholders = ",".join(["?"] * len(normalized_statuses)) + conditions.append(f"status IN ({placeholders})") + params.extend(normalized_statuses) + + where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else "" + params.append(safe_limit) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT source, status, retry_count, last_error, reason, requested_at, updated_at + FROM episode_rebuild_sources + {where_sql} + ORDER BY updated_at DESC, source ASC + LIMIT ? + """, + tuple(params), + ) + return [dict(row) for row in cursor.fetchall()] + + def get_episode_source_rebuild_summary(self, failed_limit: int = 20) -> Dict[str, Any]: + """汇总 source 重建队列状态。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) AS cnt + FROM episode_rebuild_sources + GROUP BY status + """ + ) + counts = {"pending": 0, "running": 0, "done": 0, "failed": 0, "total": 0} + for row in cursor.fetchall(): + status = str(row["status"] or "").strip().lower() + cnt = int(row["cnt"] or 0) + counts[status] = counts.get(status, 0) + cnt + counts["total"] += cnt + + running = self.list_episode_source_rebuilds(statuses=["running"], limit=20) + failed = self.list_episode_source_rebuilds( + statuses=["failed"], + limit=max(1, int(failed_limit)), + ) + return { + "counts": counts, + "running": running, + "failed": failed, + } + + def get_live_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: + """获取指定 source 下所有 live paragraphs。""" + token = self._normalize_episode_source(source) + if not token: + return [] + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT * + FROM paragraphs + WHERE TRIM(COALESCE(source, '')) = ? 
+ AND (is_deleted IS NULL OR is_deleted = 0) + ORDER BY created_at ASC, hash ASC + """, + (token,), + ) + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def list_episode_sources_for_rebuild(self) -> List[str]: + """列出全量重建涉及的 source(live paragraphs + stale episodes)。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT DISTINCT source + FROM ( + SELECT TRIM(source) AS source + FROM paragraphs + WHERE TRIM(COALESCE(source, '')) != '' + AND (is_deleted IS NULL OR is_deleted = 0) + UNION + SELECT TRIM(source) AS source + FROM episodes + WHERE TRIM(COALESCE(source, '')) != '' + ) + WHERE TRIM(COALESCE(source, '')) != '' + ORDER BY source ASC + """ + ) + return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) + + def is_episode_source_query_blocked(self, source: str) -> bool: + """判断 source 是否处于重建中或失败状态。""" + token = self._normalize_episode_source(source) + if not token: + return False + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT 1 + FROM episode_rebuild_sources + WHERE source = ? + AND status IN ('pending', 'running', 'failed') + LIMIT 1 + """, + (token,), + ) + return cursor.fetchone() is not None + + def replace_episodes_for_source( + self, + source: str, + episodes_payloads: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """按 source 全量替换 episode 结果。""" + token = self._normalize_episode_source(source) + if not token: + return {"source": "", "episode_count": 0} + + payloads = [dict(item) for item in (episodes_payloads or []) if isinstance(item, dict)] + now = datetime.now().timestamp() + cursor = self._conn.cursor() + + try: + cursor.execute("BEGIN IMMEDIATE") + cursor.execute( + """ + SELECT episode_id, created_at + FROM episodes + WHERE TRIM(COALESCE(source, '')) = ? 
+ """, + (token,), + ) + existing_created_at = { + str(row["episode_id"]): self._as_optional_float(row["created_at"]) + for row in cursor.fetchall() + } + + cursor.execute( + "DELETE FROM episodes WHERE TRIM(COALESCE(source, '')) = ?", + (token,), + ) + + inserted_count = 0 + for raw_payload in payloads: + title = str(raw_payload.get("title", "") or "").strip() + summary = str(raw_payload.get("summary", "") or "").strip() + evidence_ids = [ + str(item).strip() + for item in (raw_payload.get("evidence_ids") or []) + if str(item).strip() + ] + evidence_ids = list(dict.fromkeys(evidence_ids)) + if not title or not summary or not evidence_ids: + continue + + episode_id = str(raw_payload.get("episode_id", "") or "").strip() + if not episode_id: + seed = json.dumps( + { + "source": token, + "title": title, + "summary": summary, + "event_time_start": raw_payload.get("event_time_start"), + "event_time_end": raw_payload.get("event_time_end"), + "evidence_ids": evidence_ids, + }, + ensure_ascii=False, + sort_keys=True, + ) + episode_id = compute_hash(seed) + + participants = [ + str(item).strip() + for item in (raw_payload.get("participants") or []) + if str(item).strip() + ][:16] + keywords = [ + str(item).strip() + for item in (raw_payload.get("keywords") or []) + if str(item).strip() + ][:20] + paragraph_count = raw_payload.get("paragraph_count", len(evidence_ids)) + try: + paragraph_count = max(0, int(paragraph_count)) + except Exception: + paragraph_count = len(evidence_ids) + if paragraph_count <= 0: + paragraph_count = len(evidence_ids) + if paragraph_count <= 0: + continue + + time_confidence = raw_payload.get("time_confidence", 1.0) + llm_confidence = raw_payload.get("llm_confidence", 0.0) + try: + time_confidence = float(time_confidence) + except Exception: + time_confidence = 1.0 + try: + llm_confidence = float(llm_confidence) + except Exception: + llm_confidence = 0.0 + + created_at = existing_created_at.get(episode_id) + created_ts = created_at if created_at is not None else now + updated_ts = self._as_optional_float(raw_payload.get("updated_at")) or now + + cursor.execute( + """ + INSERT INTO episodes ( + episode_id, source, title, summary, + event_time_start, event_time_end, time_granularity, time_confidence, + participants_json, keywords_json, evidence_ids_json, + paragraph_count, llm_confidence, segmentation_model, segmentation_version, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + episode_id, + token, + title[:120], + summary[:2000], + self._as_optional_float(raw_payload.get("event_time_start")), + self._as_optional_float(raw_payload.get("event_time_end")), + str(raw_payload.get("time_granularity", "") or "").strip() or None, + time_confidence, + json.dumps(participants, ensure_ascii=False), + json.dumps(keywords, ensure_ascii=False), + json.dumps(evidence_ids, ensure_ascii=False), + paragraph_count, + llm_confidence, + str(raw_payload.get("segmentation_model", "") or "").strip() or None, + str(raw_payload.get("segmentation_version", "") or "").strip() or None, + created_ts, + updated_ts, + ), + ) + cursor.executemany( + """ + INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) + VALUES (?, ?, ?) 
+ """, + [(episode_id, hash_value, idx) for idx, hash_value in enumerate(evidence_ids)], + ) + inserted_count += 1 + + self._conn.commit() + return {"source": token, "episode_count": inserted_count} + except Exception: + self._conn.rollback() + raise + + def enqueue_episode_pending( + self, + paragraph_hash: str, + source: Optional[str] = None, + created_at: Optional[float] = None, + ) -> None: + """将段落入队到 episode 异步生成队列。""" + token = str(paragraph_hash or "").strip() + if not token: + return + now = datetime.now().timestamp() + created_ts = float(created_at) if created_at is not None else now + src = str(source or "").strip() or None + + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO episode_pending_paragraphs ( + paragraph_hash, source, created_at, status, retry_count, last_error, updated_at + ) VALUES (?, ?, ?, 'pending', 0, NULL, ?) + ON CONFLICT(paragraph_hash) DO UPDATE SET + source = excluded.source, + created_at = COALESCE(episode_pending_paragraphs.created_at, excluded.created_at), + status = CASE + WHEN episode_pending_paragraphs.status = 'done' THEN 'done' + ELSE 'pending' + END, + last_error = CASE + WHEN episode_pending_paragraphs.status = 'done' THEN episode_pending_paragraphs.last_error + ELSE NULL + END, + updated_at = excluded.updated_at + """, + (token, src, created_ts, now), + ) + self._conn.commit() + + def fetch_episode_pending_batch(self, limit: int = 20, max_retry: int = 3) -> List[Dict[str, Any]]: + """获取待处理 episode 队列批次。""" + safe_limit = max(1, int(limit)) + safe_retry = max(0, int(max_retry)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT paragraph_hash, source, created_at, status, retry_count, last_error, updated_at + FROM episode_pending_paragraphs + WHERE status = 'pending' + OR (status = 'failed' AND retry_count < ?) + ORDER BY updated_at ASC + LIMIT ? + """, + (safe_retry, safe_limit), + ) + return [dict(row) for row in cursor.fetchall()] + + def mark_episode_pending_running(self, hashes: List[str]) -> None: + """批量标记队列项为 running。""" + if not hashes: + return + now = datetime.now().timestamp() + cursor = self._conn.cursor() + chunk_size = 500 + uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) + for i in range(0, len(uniq), chunk_size): + chunk = uniq[i:i + chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + cursor.execute( + f""" + UPDATE episode_pending_paragraphs + SET status = 'running', updated_at = ? + WHERE paragraph_hash IN ({placeholders}) + AND status IN ('pending', 'failed') + """, + [now] + chunk, + ) + self._conn.commit() + + def mark_episode_pending_done(self, hashes: List[str]) -> None: + """批量标记队列项为 done。""" + if not hashes: + return + now = datetime.now().timestamp() + cursor = self._conn.cursor() + chunk_size = 500 + uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) + for i in range(0, len(uniq), chunk_size): + chunk = uniq[i:i + chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + cursor.execute( + f""" + UPDATE episode_pending_paragraphs + SET status = 'done', + last_error = NULL, + updated_at = ? 
+ WHERE paragraph_hash IN ({placeholders}) + """, + [now] + chunk, + ) + self._conn.commit() + + def mark_episode_pending_failed(self, hash_value: str, error: str = "") -> None: + """标记单条队列项失败并累加重试次数。""" + token = str(hash_value or "").strip() + if not token: + return + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE episode_pending_paragraphs + SET status = 'failed', + retry_count = COALESCE(retry_count, 0) + 1, + last_error = ?, + updated_at = ? + WHERE paragraph_hash = ? + """, + (str(error or ""), now, token), + ) + self._conn.commit() + + def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + + def _load_list(raw: Any) -> List[Any]: + if not raw: + return [] + try: + val = json.loads(raw) + return val if isinstance(val, list) else [] + except Exception: + return [] + + data["participants"] = _load_list(data.pop("participants_json", None)) + data["keywords"] = _load_list(data.pop("keywords_json", None)) + data["evidence_ids"] = _load_list(data.pop("evidence_ids_json", None)) + return data + + @staticmethod + def _as_optional_float(value: Any) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except Exception: + return None + + def upsert_episode(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """写入或更新 episode。""" + if not isinstance(payload, dict): + raise ValueError("payload 必须是字典") + + title = str(payload.get("title", "") or "").strip() + summary = str(payload.get("summary", "") or "").strip() + if not title: + raise ValueError("episode.title 不能为空") + if not summary: + raise ValueError("episode.summary 不能为空") + + source = str(payload.get("source", "") or "").strip() or None + participants_raw = payload.get("participants", []) or [] + keywords_raw = payload.get("keywords", []) or [] + evidence_ids_raw = payload.get("evidence_ids", []) or [] + participants = [str(x).strip() for x in participants_raw if str(x).strip()] + keywords = [str(x).strip() for x in keywords_raw if str(x).strip()] + evidence_ids = [str(x).strip() for x in evidence_ids_raw if str(x).strip()] + + now = datetime.now().timestamp() + created_at = self._as_optional_float(payload.get("created_at")) + updated_at = self._as_optional_float(payload.get("updated_at")) + created_ts = created_at if created_at is not None else now + updated_ts = updated_at if updated_at is not None else now + + episode_id = str(payload.get("episode_id", "") or "").strip() + if not episode_id: + seed = json.dumps( + { + "source": source, + "title": title, + "summary": summary, + "event_time_start": payload.get("event_time_start"), + "event_time_end": payload.get("event_time_end"), + "evidence_ids": evidence_ids, + }, + ensure_ascii=False, + sort_keys=True, + ) + episode_id = compute_hash(seed) + + paragraph_count = payload.get("paragraph_count") + if paragraph_count is None: + paragraph_count = len(evidence_ids) + try: + paragraph_count = int(paragraph_count) + except Exception: + paragraph_count = len(evidence_ids) + + time_conf = payload.get("time_confidence", 1.0) + llm_conf = payload.get("llm_confidence", 0.0) + try: + time_conf = float(time_conf) + except Exception: + time_conf = 1.0 + try: + llm_conf = float(llm_conf) + except Exception: + llm_conf = 0.0 + + cursor = self._conn.cursor() + cursor.execute( + "SELECT created_at FROM episodes WHERE episode_id = ? 
LIMIT 1", + (episode_id,), + ) + existed = cursor.fetchone() + if existed and existed[0] is not None: + created_ts = float(existed[0]) + + cursor.execute( + """ + INSERT INTO episodes ( + episode_id, source, title, summary, + event_time_start, event_time_end, time_granularity, time_confidence, + participants_json, keywords_json, evidence_ids_json, + paragraph_count, llm_confidence, segmentation_model, segmentation_version, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(episode_id) DO UPDATE SET + source = excluded.source, + title = excluded.title, + summary = excluded.summary, + event_time_start = excluded.event_time_start, + event_time_end = excluded.event_time_end, + time_granularity = excluded.time_granularity, + time_confidence = excluded.time_confidence, + participants_json = excluded.participants_json, + keywords_json = excluded.keywords_json, + evidence_ids_json = excluded.evidence_ids_json, + paragraph_count = excluded.paragraph_count, + llm_confidence = excluded.llm_confidence, + segmentation_model = excluded.segmentation_model, + segmentation_version = excluded.segmentation_version, + updated_at = excluded.updated_at + """, + ( + episode_id, + source, + title, + summary, + self._as_optional_float(payload.get("event_time_start")), + self._as_optional_float(payload.get("event_time_end")), + str(payload.get("time_granularity", "") or "").strip() or None, + time_conf, + json.dumps(participants, ensure_ascii=False), + json.dumps(keywords, ensure_ascii=False), + json.dumps(evidence_ids, ensure_ascii=False), + max(0, paragraph_count), + llm_conf, + str(payload.get("segmentation_model", "") or "").strip() or None, + str(payload.get("segmentation_version", "") or "").strip() or None, + created_ts, + updated_ts, + ), + ) + self._conn.commit() + return self.get_episode_by_id(episode_id) or {"episode_id": episode_id} + + def bind_episode_paragraphs(self, episode_id: str, paragraph_hashes_ordered: List[str]) -> int: + """重建 episode 与段落映射。""" + token = str(episode_id or "").strip() + if not token: + raise ValueError("episode_id 不能为空") + + normalized: List[str] = [] + seen = set() + for item in paragraph_hashes_ordered or []: + h = str(item or "").strip() + if not h or h in seen: + continue + seen.add(h) + normalized.append(h) + + cursor = self._conn.cursor() + cursor.execute("DELETE FROM episode_paragraphs WHERE episode_id = ?", (token,)) + + if normalized: + cursor.executemany( + """ + INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) + VALUES (?, ?, ?) + """, + [(token, h, idx) for idx, h in enumerate(normalized)], + ) + + now = datetime.now().timestamp() + cursor.execute( + """ + UPDATE episodes + SET paragraph_count = ?, updated_at = ? + WHERE episode_id = ? 
+ """, + (len(normalized), now, token), + ) + self._conn.commit() + return len(normalized) + + def _build_episode_query_components( + self, + *, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + ) -> Tuple[str, str, str, List[str], List[Any]]: + source_expr = "TRIM(COALESCE(e.source, ''))" + effective_start = "COALESCE(e.event_time_start, e.event_time_end, e.updated_at)" + effective_end = "COALESCE(e.event_time_end, e.event_time_start, e.updated_at)" + conditions: List[str] = [] + params: List[Any] = [] + + conditions.append(f"{source_expr} != ''") + conditions.append("COALESCE(e.paragraph_count, 0) > 0") + conditions.append( + """ + NOT EXISTS ( + SELECT 1 + FROM episode_rebuild_sources ers + WHERE ers.source = TRIM(COALESCE(e.source, '')) + AND ers.status IN ('pending', 'running', 'failed') + ) + """ + ) + + if source: + token = self._normalize_episode_source(source) + if not token: + return source_expr, effective_start, effective_end, ["1 = 0"], [] + conditions.append(f"{source_expr} = ?") + params.append(token) + + p = str(person or "").strip().lower() + if p: + like_person = f"%{p}%" + conditions.append( + """ + ( + LOWER(COALESCE(e.participants_json, '')) LIKE ? + OR EXISTS ( + SELECT 1 + FROM episode_paragraphs ep_person + JOIN paragraph_entities pe ON pe.paragraph_hash = ep_person.paragraph_hash + JOIN entities en ON en.hash = pe.entity_hash + WHERE ep_person.episode_id = e.episode_id + AND LOWER(en.name) LIKE ? + ) + ) + """ + ) + params.extend([like_person, like_person]) + + if time_from is not None and time_to is not None: + conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") + params.extend([float(time_from), float(time_to)]) + elif time_from is not None: + conditions.append(f"({effective_end} >= ?)") + params.append(float(time_from)) + elif time_to is not None: + conditions.append(f"({effective_start} <= ?)") + params.append(float(time_to)) + + return source_expr, effective_start, effective_end, conditions, params + + def get_episode_rows_by_paragraph_hashes( + self, + paragraph_hashes: List[str], + *, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + ) -> List[Dict[str, Any]]: + normalized: List[str] = [] + seen = set() + for item in paragraph_hashes or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + if not normalized: + return [] + + _, _, _, conditions, params = self._build_episode_query_components( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + placeholders = ",".join(["?"] * len(normalized)) + conditions.append(f"ep.paragraph_hash IN ({placeholders})") + conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") + where_sql = "WHERE " + " AND ".join(conditions) + + sql = f""" + SELECT e.*, ep.paragraph_hash AS matched_paragraph_hash + FROM episodes e + JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id + JOIN paragraphs p ON p.hash = ep.paragraph_hash + {where_sql} + ORDER BY e.updated_at DESC + """ + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params + normalized)) + + grouped: Dict[str, Dict[str, Any]] = {} + for row in cursor.fetchall(): + episode_id = str(row["episode_id"] or "").strip() + if not episode_id: + continue + payload = grouped.get(episode_id) + if payload is None: + payload = self._episode_row_to_dict(row) + 
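+ # Grouping note: the first row seen for an episode seeds its payload; subsequent
+ # rows for the same episode only extend matched_paragraph_hashes below, so each
+ # episode is returned once no matter how many evidence paragraphs matched.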
payload["matched_paragraph_hashes"] = [] + grouped[episode_id] = payload + matched_hash = str(row["matched_paragraph_hash"] or "").strip() + if matched_hash and matched_hash not in payload["matched_paragraph_hashes"]: + payload["matched_paragraph_hashes"].append(matched_hash) + + out = list(grouped.values()) + for item in out: + item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) + return out + + def get_episode_rows_by_relation_hashes( + self, + relation_hashes: List[str], + *, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + ) -> List[Dict[str, Any]]: + normalized: List[str] = [] + seen = set() + for item in relation_hashes or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + if not normalized: + return [] + + _, _, _, conditions, params = self._build_episode_query_components( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + placeholders = ",".join(["?"] * len(normalized)) + conditions.append(f"pr.relation_hash IN ({placeholders})") + conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") + where_sql = "WHERE " + " AND ".join(conditions) + + sql = f""" + SELECT + e.*, + p.hash AS matched_paragraph_hash, + pr.relation_hash AS matched_relation_hash + FROM episodes e + JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id + JOIN paragraphs p ON p.hash = ep.paragraph_hash + JOIN paragraph_relations pr ON pr.paragraph_hash = p.hash + {where_sql} + ORDER BY e.updated_at DESC + """ + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params + normalized)) + + grouped: Dict[str, Dict[str, Any]] = {} + for row in cursor.fetchall(): + episode_id = str(row["episode_id"] or "").strip() + if not episode_id: + continue + payload = grouped.get(episode_id) + if payload is None: + payload = self._episode_row_to_dict(row) + payload["matched_paragraph_hashes"] = [] + payload["matched_relation_hashes"] = [] + grouped[episode_id] = payload + matched_paragraph = str(row["matched_paragraph_hash"] or "").strip() + matched_relation = str(row["matched_relation_hash"] or "").strip() + if matched_paragraph and matched_paragraph not in payload["matched_paragraph_hashes"]: + payload["matched_paragraph_hashes"].append(matched_paragraph) + if matched_relation and matched_relation not in payload["matched_relation_hashes"]: + payload["matched_relation_hashes"].append(matched_relation) + + out = list(grouped.values()) + for item in out: + item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) + item["matched_relation_count"] = len(item.get("matched_relation_hashes", [])) + return out + + def query_episodes( + self, + query: str = "", + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + limit: int = 20, + ) -> List[Dict[str, Any]]: + """查询 episode 列表。""" + safe_limit = max(1, int(limit)) + _, effective_start, effective_end, conditions, params = self._build_episode_query_components( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + + q = str(query or "").strip().lower() + select_score_sql = "0.0 AS lexical_score" + order_sql = f"{effective_end} DESC, e.updated_at DESC" + select_params: List[Any] = [] + query_params: List[Any] = [] + if q: + like = f"%{q}%" + title_expr = "LOWER(COALESCE(e.title, '')) LIKE ?" 
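+ # summary/keywords/participants get parallel LIKE expressions next; the CASE-based
+ # lexical_score then weights hits title 4.0 > keywords 3.0 > summary 2.0 >
+ # participants 1.0, with effective event time and updated_at as tie-breakers.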
+ summary_expr = "LOWER(COALESCE(e.summary, '')) LIKE ?" + keywords_expr = "LOWER(COALESCE(e.keywords_json, '')) LIKE ?" + participants_expr = "LOWER(COALESCE(e.participants_json, '')) LIKE ?" + conditions.append( + f"({title_expr} OR {summary_expr} OR {keywords_expr} OR {participants_expr})" + ) + select_score_sql = ( + f"(CASE WHEN {title_expr} THEN 4.0 ELSE 0.0 END + " + f"CASE WHEN {keywords_expr} THEN 3.0 ELSE 0.0 END + " + f"CASE WHEN {summary_expr} THEN 2.0 ELSE 0.0 END + " + f"CASE WHEN {participants_expr} THEN 1.0 ELSE 0.0 END) AS lexical_score" + ) + select_params.extend([like, like, like, like]) + query_params.extend([like, like, like, like]) + order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC" + + where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else "" + sql = f""" + SELECT e.*, {select_score_sql} + FROM episodes e + {where_sql} + ORDER BY {order_sql} + LIMIT ? + """ + final_params = list(select_params) + list(params) + list(query_params) + [safe_limit] + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(final_params)) + return [self._episode_row_to_dict(row) for row in cursor.fetchall()] + + def get_episode_by_id(self, episode_id: str) -> Optional[Dict[str, Any]]: + """获取单条 episode。""" + token = str(episode_id or "").strip() + if not token: + return None + cursor = self._conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE episode_id = ? LIMIT 1", + (token,), + ) + row = cursor.fetchone() + if not row: + return None + return self._episode_row_to_dict(row) + + def get_episode_paragraphs(self, episode_id: str, limit: int = 100) -> List[Dict[str, Any]]: + """获取 episode 关联段落(按 position 排序)。""" + token = str(episode_id or "").strip() + if not token: + return [] + safe_limit = max(1, int(limit)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT p.*, ep.position + FROM episode_paragraphs ep + JOIN paragraphs p ON p.hash = ep.paragraph_hash + WHERE ep.episode_id = ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + ORDER BY ep.position ASC + LIMIT ? + """, + (token, safe_limit), + ) + items = [] + for row in cursor.fetchall(): + payload = self._row_to_dict(row, "paragraph") + payload["position"] = row["position"] + items.append(payload) + return items + + def has_table(self, table_name: str) -> bool: + """检查数据库是否存在指定表。""" + if not self._conn: + return False + cursor = self._conn.cursor() + cursor.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? LIMIT 1", + (table_name,), + ) + return cursor.fetchone() is not None + + def get_deleted_entities(self, limit: int = 50) -> List[Dict[str, Any]]: + """获取已软删除的实体 (回收站用)""" + if not self.has_table("entities"): return [] + + cursor = self._conn.cursor() + cursor.execute(""" + SELECT hash, name, deleted_at + FROM entities + WHERE is_deleted = 1 + ORDER BY deleted_at DESC + LIMIT ? 
+ """, (limit,)) + + items = [] + for row in cursor.fetchall(): + items.append({ + "hash": row[0], + "name": row[1], + "type": "entity", # 标记为实体 + "deleted_at": row[2] + }) + return items + + def __repr__(self) -> str: + stats = self.get_statistics() if self.is_connected else {} + return ( + f"MetadataStore(paragraphs={stats.get('paragraph_count', 0)}, " + f"entities={stats.get('entity_count', 0)}, " + f"relations={stats.get('relation_count', 0)})" + ) + + def has_data(self) -> bool: + """检查磁盘上是否存在现有数据""" + if self.data_dir is None: + return False + return (self.data_dir / self.db_name).exists() diff --git a/plugins/A_memorix/core/storage/type_detection.py b/plugins/A_memorix/core/storage/type_detection.py new file mode 100644 index 00000000..c20d2cb4 --- /dev/null +++ b/plugins/A_memorix/core/storage/type_detection.py @@ -0,0 +1,137 @@ +"""Heuristic detection for import strategies and stored knowledge types.""" + +from __future__ import annotations + +import re +from typing import Optional + +from .knowledge_types import ( + ImportStrategy, + KnowledgeType, + parse_import_strategy, + resolve_stored_knowledge_type, +) + + +_NARRATIVE_MARKERS = [ + r"然后", + r"接着", + r"于是", + r"后来", + r"最后", + r"突然", + r"一天", + r"曾经", + r"有一次", + r"从前", + r"说道", + r"问道", + r"想着", + r"觉得", +] +_FACTUAL_MARKERS = [ + r"是", + r"有", + r"在", + r"为", + r"属于", + r"位于", + r"包含", + r"拥有", + r"成立于", + r"出生于", +] + + +def _non_empty_lines(content: str) -> list[str]: + return [line for line in str(content or "").splitlines() if line.strip()] + + +def looks_like_structured_text(content: str) -> bool: + text = str(content or "").strip() + if "|" not in text or text.count("|") < 2: + return False + parts = text.split("|") + return len(parts) == 3 and all(part.strip() for part in parts) + + +def looks_like_quote_text(content: str) -> bool: + lines = _non_empty_lines(content) + if len(lines) < 5: + return False + avg_len = sum(len(line) for line in lines) / len(lines) + return avg_len < 20 + + +def looks_like_narrative_text(content: str) -> bool: + text = str(content or "").strip() + if not text: + return False + + narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text)) + has_dialogue = bool(re.search(r'["「『].*?["」』]', text)) + has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###")) + return has_chapter or has_dialogue or narrative_score >= 2 + + +def looks_like_factual_text(content: str) -> bool: + text = str(content or "").strip() + if not text: + return False + if looks_like_structured_text(text) or looks_like_quote_text(text): + return False + + factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text)) + if factual_score <= 0: + return False + + if len(text) <= 240: + return True + return factual_score >= 2 and not looks_like_narrative_text(text) + + +def select_import_strategy( + content: str, + *, + override: Optional[str | ImportStrategy] = None, + chat_log: bool = False, +) -> ImportStrategy: + """文本导入策略选择:override > quote > factual > narrative。""" + + if chat_log: + return ImportStrategy.NARRATIVE + + strategy = parse_import_strategy(override, default=ImportStrategy.AUTO) + if strategy != ImportStrategy.AUTO: + return strategy + + if looks_like_quote_text(content): + return ImportStrategy.QUOTE + if looks_like_factual_text(content): + return ImportStrategy.FACTUAL + return ImportStrategy.NARRATIVE + + +def detect_knowledge_type(content: str) -> KnowledgeType: + """自动检测落库 knowledge_type;无法可靠判断时回退 mixed。""" + + text = 
str(content or "").strip() + if not text: + return KnowledgeType.MIXED + if looks_like_structured_text(text): + return KnowledgeType.STRUCTURED + if looks_like_quote_text(text): + return KnowledgeType.QUOTE + if looks_like_factual_text(text): + return KnowledgeType.FACTUAL + if looks_like_narrative_text(text): + return KnowledgeType.NARRATIVE + return KnowledgeType.MIXED + + +def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType: + """优先使用显式 type_hint,否则自动检测。""" + + if type_hint: + return resolve_stored_knowledge_type(type_hint, content=content) + return detect_knowledge_type(content) diff --git a/plugins/A_memorix/core/storage/vector_store.py b/plugins/A_memorix/core/storage/vector_store.py new file mode 100644 index 00000000..97a9144c --- /dev/null +++ b/plugins/A_memorix/core/storage/vector_store.py @@ -0,0 +1,776 @@ +""" +向量存储模块 + +基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。 +""" + +import os +import pickle +import hashlib +import shutil +import time +from pathlib import Path +from typing import Optional, Union, Tuple, List, Dict, Set, Any +import random +import threading # Added threading import + +import numpy as np + +try: + import faiss + HAS_FAISS = True +except ImportError: + HAS_FAISS = False + +from src.common.logger import get_logger +from ..utils.quantization import QuantizationType +from ..utils.io import atomic_write, atomic_save_path + +logger = get_logger("A_Memorix.VectorStore") + + +class VectorStore: + """ + 向量存储类 (SQ8 + Append-Only Disk) + + 特性: + - 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit)) + - 存储: float16 on-disk binary (vectors.bin) + - 内存: 仅索引常驻 RAM (<512MB for 100k vectors) + - ID: SHA1-based stable int64 IDs + - 一致性: 强制 L2 Normalization (IP == Cosine) + """ + + # 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化) + DEFAULT_MIN_TRAIN = 40 + # 强制训练样本量 + TRAIN_SIZE = 10000 + # 储水池采样上限 (流式处理前 50k 数据) + RESERVOIR_CAPACITY = 10000 + RESERVOIR_SAMPLE_SCOPE = 50000 + + def __init__( + self, + dimension: int, + quantization_type: QuantizationType = QuantizationType.INT8, + index_type: str = "sq8", + data_dir: Optional[Union[str, Path]] = None, + use_mmap: bool = True, + buffer_size: int = 1024, + ): + if not HAS_FAISS: + raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu") + + self.dimension = dimension + self.data_dir = Path(data_dir) if data_dir else None + if self.data_dir: + self.data_dir.mkdir(parents=True, exist_ok=True) + if quantization_type != QuantizationType.INT8: + raise ValueError( + "vNext 仅支持 quantization_type=int8(SQ8)。" + " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" + ) + normalized_index_type = str(index_type or "sq8").strip().lower() + if normalized_index_type not in {"sq8", "int8"}: + raise ValueError( + "vNext 仅支持 index_type=sq8。" + " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" + ) + self.quantization_type = QuantizationType.INT8 + self.index_type = "sq8" + self.buffer_size = buffer_size + + self._index: Optional[faiss.IndexIDMap2] = None + self._init_index() + + self._is_trained = False + self._vector_norm = "l2" + + # Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力 + # 必须使用 IndexIDMap2 以保证 ID 与主索引一致 + self._fallback_index: Optional[faiss.IndexIDMap2] = None + self._init_fallback_index() + + self._known_hashes: Set[str] = set() + self._deleted_ids: Set[int] = set() + + self._reservoir_buffer: List[np.ndarray] = [] + self._seen_count_for_reservoir = 0 + + self._write_buffer_vecs: List[np.ndarray] = [] + self._write_buffer_ids: List[int] = [] + + self._total_added = 0 + self._total_deleted = 0 + 
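+ # Illustrative lifecycle sketch (names such as dim, path, vectors, hashes and
+ # query_vec are placeholders, not symbols defined elsewhere in this module):
+ #     store = VectorStore(dimension=dim, data_dir=path)
+ #     store.load(); store.warmup_index()          # restore state, pre-train/replay
+ #     store.add(vectors, hashes)                   # float32 (N, dim) + stable str ids
+ #     ids, scores = store.search(query_vec, k=10)  # cosine via L2-norm + inner product
+ # _bin_count mirrors the number of fp16 vectors persisted in vectors.bin; tombstoned
+ # (deleted) entries stay counted until rebuild_index() compacts the file.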
self._bin_count = 0 + + # Thread safety lock + self._lock = threading.RLock() + + logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage") + + def _init_index(self): + """初始化空的 Faiss 索引""" + quantizer = faiss.IndexScalarQuantizer( + self.dimension, + faiss.ScalarQuantizer.QT_8bit, + faiss.METRIC_INNER_PRODUCT + ) + self._index = faiss.IndexIDMap2(quantizer) + self._is_trained = False + + def _init_fallback_index(self): + """初始化 Flat 回退索引""" + flat_index = faiss.IndexFlatIP(self.dimension) + self._fallback_index = faiss.IndexIDMap2(flat_index) + logger.debug("Fallback index (Flat) initialized.") + + @staticmethod + def _generate_id(key: str) -> int: + """生成稳定的 int64 ID (SHA1 截断)""" + h = hashlib.sha1(key.encode("utf-8")).digest() + val = int.from_bytes(h[:8], byteorder="big", signed=False) + return val & 0x7FFFFFFFFFFFFFFF + + @property + def _bin_path(self) -> Path: + return self.data_dir / "vectors.bin" + + @property + def _ids_bin_path(self) -> Path: + return self.data_dir / "vectors_ids.bin" + + @property + def _int_to_str_map(self) -> Dict[int, str]: + """Lazy build volatile map from known hashes""" + # Note: This is read-heavy and cached, might need lock if _known_hashes updates concurrently + # But add/delete are now locked, so checking len mismatch is somewhat safe-ish for quick dirty cache + if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes): + with self._lock: # Protect cache rebuild + self._cached_map = {self._generate_id(k): k for k in self._known_hashes} + return self._cached_map + + def add(self, vectors: np.ndarray, ids: List[str]) -> int: + with self._lock: + if vectors.shape[1] != self.dimension: + raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}") + + vectors = np.ascontiguousarray(vectors, dtype=np.float32) + faiss.normalize_L2(vectors) + + processed_vecs = [] + processed_int_ids = [] + + for i, str_id in enumerate(ids): + if str_id in self._known_hashes: + continue + + int_id = self._generate_id(str_id) + self._known_hashes.add(str_id) + + processed_vecs.append(vectors[i]) + processed_int_ids.append(int_id) + + if not processed_vecs: + return 0 + + batch_vecs = np.array(processed_vecs, dtype=np.float32) + batch_ids = np.array(processed_int_ids, dtype=np.int64) + + self._write_buffer_vecs.append(batch_vecs) + self._write_buffer_ids.extend(processed_int_ids) + + if len(self._write_buffer_ids) >= self.buffer_size: + self._flush_write_buffer_unlocked() + + if not self._is_trained: + # 双写到回退索引 + self._fallback_index.add_with_ids(batch_vecs, batch_ids) + + self._update_reservoir(batch_vecs) + # 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断 + if len(self._reservoir_buffer) >= 10000: + logger.info(f"训练样本达到上限,开始训练...") + self._train_and_replay_unlocked() + + self._total_added += len(batch_ids) + return len(batch_ids) + + def _flush_write_buffer(self): + with self._lock: + self._flush_write_buffer_unlocked() + + def _flush_write_buffer_unlocked(self): + if not self._write_buffer_vecs: + return + + batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0) + batch_ids = np.array(self._write_buffer_ids, dtype=np.int64) + + vecs_fp16 = batch_vecs.astype(np.float16) + + with open(self._bin_path, "ab") as f: + f.write(vecs_fp16.tobytes()) + + ids_bytes = batch_ids.astype('>i8').tobytes() + with open(self._ids_bin_path, "ab") as f: + f.write(ids_bytes) + + self._bin_count += len(batch_ids) + + if self._is_trained and self._index.is_trained: + self._index.add_with_ids(batch_vecs, batch_ids) + else: + # 即使在 
flush 时,如果未训练,也要同步到 fallback + self._fallback_index.add_with_ids(batch_vecs, batch_ids) + + self._write_buffer_vecs.clear() + self._write_buffer_ids.clear() + + def _update_reservoir(self, vectors: np.ndarray): + for vec in vectors: + self._seen_count_for_reservoir += 1 + if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY: + self._reservoir_buffer.append(vec) + else: + if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE: + r = random.randint(0, self._seen_count_for_reservoir - 1) + if r < self.RESERVOIR_CAPACITY: + self._reservoir_buffer[r] = vec + + def _train_and_replay(self): + with self._lock: + self._train_and_replay_unlocked() + + def _train_and_replay_unlocked(self): + if not self._reservoir_buffer: + logger.warning("No training data available.") + return + + train_data = np.array(self._reservoir_buffer, dtype=np.float32) + logger.info(f"Training Index with {len(train_data)} samples...") + + try: + self._index.train(train_data) + except Exception as e: + logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.") + return + + self._is_trained = True + self._reservoir_buffer = [] + + logger.info("Replaying data from disk to populate index...") + try: + replay_count = self._replay_vectors_to_index() + # 只有当 replay 成功且数据量一致时,才释放回退索引 + if self._index.ntotal >= self._bin_count: + logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.") + self._fallback_index.reset() + else: + logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.") + except Exception as e: + logger.error(f"Replay failed: {e}. Keeping fallback index as backup.") + + def _replay_vectors_to_index(self) -> int: + """从 vectors.bin 读取并添加到 index""" + if not self._bin_path.exists() or not self._ids_bin_path.exists(): + return 0 + + vec_item_size = self.dimension * 2 + id_item_size = 8 + chunk_size = 10000 + + with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: + while True: + vec_data = f_vec.read(chunk_size * vec_item_size) + id_data = f_id.read(chunk_size * id_item_size) + + if not vec_data: + break + + batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) + batch_fp32 = batch_fp16.astype(np.float32) + faiss.normalize_L2(batch_fp32) + + batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) + + valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] + if not all(valid_mask): + batch_fp32 = batch_fp32[valid_mask] + batch_ids = batch_ids[valid_mask] + + if len(batch_ids) > 0: + self._index.add_with_ids(batch_fp32, batch_ids) + + def search( + self, + query: np.ndarray, + k: int = 10, + filter_deleted: bool = True, + ) -> Tuple[List[str], List[float]]: + query_local = np.array(query, dtype=np.float32, order="C", copy=True) + if query_local.ndim == 1: + got_dim = int(query_local.shape[0]) + query_local = query_local.reshape(1, -1) + elif query_local.ndim == 2: + if query_local.shape[0] != 1: + raise ValueError( + f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" + ) + got_dim = int(query_local.shape[1]) + else: + raise ValueError( + f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" + ) + + if got_dim != self.dimension: + raise ValueError( + f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}" + ) + if not np.all(np.isfinite(query_local)): + raise ValueError("query embedding contains non-finite values") + + faiss.normalize_L2(query_local) + + # 
查询路径仅负责检索,不在此触发训练/回放。 + # 训练/回放前置到 warmup_index(),并由插件启动阶段触发。 + # Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。 + with self._lock: + self._flush_write_buffer_unlocked() + search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index + if search_index.ntotal == 0: + logger.warning("Indices are empty. No data to search.") + return [], [] + # 执行检索 + dists, ids = search_index.search(query_local, k * 2) + + # Faiss search 返回的是 (1, K) 的数组,取第一行 + dists = dists[0] + ids = ids[0] + + results = [] + for id_val, score in zip(ids, dists): + if id_val == -1: continue + if filter_deleted and id_val in self._deleted_ids: + continue + + str_id = self._int_to_str_map.get(id_val) + if str_id: + results.append((str_id, float(score))) + + # Sort and trim just in case filtering reduced count + results.sort(key=lambda x: x[1], reverse=True) + results = results[:k] + + if not results: + return [], [] + + return [r[0] for r in results], [r[1] for r in results] + + def warmup_index(self, force_train: bool = True) -> Dict[str, Any]: + """ + 预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。 + + Args: + force_train: 是否在满足阈值时强制训练 SQ8 索引 + + Returns: + 预热状态摘要 + """ + started = time.perf_counter() + logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}") + + try: + with self._lock: + self._flush_write_buffer() + + if self._bin_path.exists(): + self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2) + else: + self._bin_count = 0 + + needs_fallback_bootstrap = ( + self._bin_count > 0 + and self._fallback_index.ntotal == 0 + and (not self._is_trained or self._index.ntotal == 0) + ) + if needs_fallback_bootstrap: + self._bootstrap_fallback_from_disk() + + min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN))) + needs_train = ( + bool(force_train) + and self._bin_count >= min_train + and not self._is_trained + ) + if needs_train: + self._force_train_small_data() + + duration_ms = (time.perf_counter() - started) * 1000.0 + summary = { + "ok": True, + "trained": bool(self._is_trained), + "index_ntotal": int(self._index.ntotal), + "fallback_ntotal": int(self._fallback_index.ntotal), + "bin_count": int(self._bin_count), + "duration_ms": duration_ms, + "error": None, + } + except Exception as e: + duration_ms = (time.perf_counter() - started) * 1000.0 + summary = { + "ok": False, + "trained": bool(self._is_trained), + "index_ntotal": int(self._index.ntotal) if self._index is not None else 0, + "fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0, + "bin_count": int(getattr(self, "_bin_count", 0)), + "duration_ms": duration_ms, + "error": str(e), + } + logger.error( + "metric.vector_index_prewarm_fail=1 " + f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} " + f"error={e}" + ) + return summary + + logger.info( + "metric.vector_index_prewarm_success=1 " + f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} " + f"trained={summary['trained']} " + f"index_ntotal={summary['index_ntotal']} " + f"fallback_ntotal={summary['fallback_ntotal']} " + f"bin_count={summary['bin_count']}" + ) + return summary + + def _bootstrap_fallback_from_disk(self): + with self._lock: + self._bootstrap_fallback_from_disk_unlocked() + + def _bootstrap_fallback_from_disk_unlocked(self): + """重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引""" + if not self._bin_path.exists() or not self._ids_bin_path.exists(): + return + + logger.info("Replaying all disk vectors to fallback index...") + vec_item_size = 
self.dimension * 2 + id_item_size = 8 + chunk_size = 10000 + + with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: + while True: + vec_data = f_vec.read(chunk_size * vec_item_size) + id_data = f_id.read(chunk_size * id_item_size) + if not vec_data: break + + batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) + batch_fp32 = batch_fp16.astype(np.float32) + faiss.normalize_L2(batch_fp32) + batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) + + valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] + if any(valid_mask): + self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask]) + + logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.") + + def _force_train_small_data(self): + with self._lock: + self._force_train_small_data_unlocked() + + def _force_train_small_data_unlocked(self): + logger.info("Forcing training on small dataset...") + self._reservoir_buffer = [] + + chunk_size = 10000 + vec_item_size = self.dimension * 2 + + with open(self._bin_path, "rb") as f: + while len(self._reservoir_buffer) < self.TRAIN_SIZE: + data = f.read(chunk_size * vec_item_size) + if not data: break + fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension) + fp32 = fp16.astype(np.float32) + faiss.normalize_L2(fp32) + + for vec in fp32: + self._reservoir_buffer.append(vec) + if len(self._reservoir_buffer) >= self.TRAIN_SIZE: + break + + self._train_and_replay_unlocked() + + def delete(self, ids: List[str]) -> int: + with self._lock: + count = 0 + for str_id in ids: + if str_id not in self._known_hashes: + continue + int_id = self._generate_id(str_id) + if int_id not in self._deleted_ids: + self._deleted_ids.add(int_id) + if self._index.is_trained: + self._index.remove_ids(np.array([int_id], dtype=np.int64)) + # 同步从 fallback 移除 + if self._fallback_index.ntotal > 0: + self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64)) + count += 1 + self._total_deleted += count + + # Check GC + self._check_rebuild_needed() + return count + + def _check_rebuild_needed(self): + """GC Excution Check""" + if self._bin_count == 0: return + ratio = len(self._deleted_ids) / self._bin_count + if ratio > 0.3 and len(self._deleted_ids) > 1000: + logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})") + self.rebuild_index() + + def rebuild_index(self): + """GC: 重建索引,压缩 bin 文件""" + with self._lock: + self._rebuild_index_locked() + + def _rebuild_index_locked(self): + """实际 GC 重建逻辑。""" + logger.info("Starting Compaction (GC)...") + + tmp_bin = self.data_dir / "vectors.bin.tmp" + tmp_ids = self.data_dir / "vectors_ids.bin.tmp" + + vec_item_size = self.dimension * 2 + id_item_size = 8 + chunk_size = 10000 + + new_count = 0 + + # 1. Compact Files + with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \ + open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id: + while True: + vec_data = f_vec.read(chunk_size * vec_item_size) + id_data = f_id.read(chunk_size * id_item_size) + if not vec_data: break + + batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) + batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) + + keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids] + + if any(keep_mask): + keep_vecs = batch_fp16[keep_mask] + keep_ids = batch_ids[keep_mask] + + w_vec.write(keep_vecs.tobytes()) + w_id.write(keep_ids.astype('>i8').tobytes()) + new_count += len(keep_ids) + + # 2. 
Reset State & Atomic Swap + self._bin_count = new_count + + # Close current index + self._index.reset() + if self._fallback_index: self._fallback_index.reset() # Also clear fallback + self._is_trained = False + + # Swap files + shutil.move(str(tmp_bin), str(self._bin_path)) + shutil.move(str(tmp_ids), str(self._ids_bin_path)) + + # Reset Tombstones (Critical) + self._deleted_ids.clear() + + # 3. Reload/Rebuild Index (Fresh Train) + # We need to re-train because data distribution might have changed significantly after deletion + self._init_index() + self._init_fallback_index() # Re-init fallback too + self._force_train_small_data() # This will train and replay from the NEW compact file + + logger.info("Compaction Complete.") + + def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: + with self._lock: + if not data_dir: + data_dir = self.data_dir + if not data_dir: + raise ValueError("No data_dir") + + data_dir = Path(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + + self._flush_write_buffer_unlocked() + + if self._is_trained: + index_path = data_dir / "vectors.index" + with atomic_save_path(index_path) as tmp: + faiss.write_index(self._index, tmp) + + meta = { + "dimension": self.dimension, + "quantization_type": self.quantization_type.value, + "is_trained": self._is_trained, + "vector_norm": self._vector_norm, + "deleted_ids": list(self._deleted_ids), + "known_hashes": list(self._known_hashes), + } + + with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f: + pickle.dump(meta, f) + + logger.info("VectorStore saved.") + + def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]: + """ + 离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。 + """ + with self._lock: + target_dir = Path(data_dir) if data_dir else self.data_dir + if target_dir is None: + raise ValueError("No data_dir") + target_dir = Path(target_dir) + npy_path = target_dir / "vectors.npy" + idx_path = target_dir / "vectors.index" + bin_path = target_dir / "vectors.bin" + ids_bin_path = target_dir / "vectors_ids.bin" + meta_path = target_dir / "vectors_metadata.pkl" + + if not npy_path.exists(): + return {"migrated": False, "reason": "npy_missing"} + if not meta_path.exists(): + raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl") + if bin_path.exists() and ids_bin_path.exists(): + return {"migrated": False, "reason": "bin_exists"} + + # Reset in-memory state to avoid appending to stale runtime buffers. 
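+ # Migration flow from here (sketch): wipe runtime state, rebuild empty SQ8 and
+ # fallback indexes, then _migrate_from_npy_unlocked() re-adds every legacy vector
+ # through add() (re-normalized and appended to vectors.bin / vectors_ids.bin),
+ # force-trains the SQ8 index if it is still untrained, renames the legacy
+ # .npy/.index files to *.bak, and the final save() persists the vNext layout.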
+ self._known_hashes.clear() + self._deleted_ids.clear() + self._write_buffer_vecs.clear() + self._write_buffer_ids.clear() + self._init_index() + self._init_fallback_index() + self._is_trained = False + self._bin_count = 0 + + self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir) + self.save(target_dir) + return {"migrated": True, "reason": "ok"} + + def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: + with self._lock: + if not data_dir: data_dir = self.data_dir + data_dir = Path(data_dir) + + npy_path = data_dir / "vectors.npy" + idx_path = data_dir / "vectors.index" + bin_path = data_dir / "vectors.bin" + + if npy_path.exists() and not bin_path.exists(): + raise RuntimeError( + "检测到 legacy vectors.npy,vNext 不再支持运行时自动迁移。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + + meta_path = data_dir / "vectors_metadata.pkl" + if not meta_path.exists(): + logger.warning("No metadata found, initialized empty.") + return + + with open(meta_path, "rb") as f: + meta = pickle.load(f) + + if meta.get("vector_norm") != "l2": + logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...") + self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", [])) + self._deleted_ids = set(meta.get("deleted_ids", [])) + self._init_index() + self._force_train_small_data() + return + + self._is_trained = meta.get("is_trained", False) + self._vector_norm = meta.get("vector_norm", "l2") + self._deleted_ids = set(meta.get("deleted_ids", [])) + self._known_hashes = set(meta.get("known_hashes", [])) + + if self._is_trained: + if idx_path.exists(): + try: + self._index = faiss.read_index(str(idx_path)) + if not isinstance(self._index, faiss.IndexIDMap2): + logger.warning("Loaded index type mismatch. Rebuilding...") + self._init_index() + self._force_train_small_data() + except Exception as e: + logger.error(f"Failed to load index: {e}. Rebuilding...") + self._init_index() + self._force_train_small_data() + else: + logger.warning("Index file missing despite metadata indicating trained. 
Rebuilding from bin...") + self._init_index() + self._force_train_small_data() + + if bin_path.exists(): + self._bin_count = bin_path.stat().st_size // (self.dimension * 2) + + def _migrate_from_npy(self, npy_path, idx_path, data_dir): + with self._lock: + self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir) + + def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir): + try: + arr = np.load(npy_path, mmap_mode="r") + except Exception: + arr = np.load(npy_path) + + meta_path = data_dir / "vectors_metadata.pkl" + old_ids = [] + if meta_path.exists(): + with open(meta_path, "rb") as f: + m = pickle.load(f) + old_ids = m.get("ids", []) + + if len(arr) != len(old_ids): + logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}") + return + + logger.info(f"Migrating {len(arr)} vectors...") + + chunk = 1000 + for i in range(0, len(arr), chunk): + sub_arr = arr[i : i+chunk] + sub_ids = old_ids[i : i+chunk] + self.add(sub_arr, sub_ids) + + if not self._is_trained: + self._force_train_small_data() + + shutil.move(str(npy_path), str(npy_path) + ".bak") + if idx_path.exists(): + shutil.move(str(idx_path), str(idx_path) + ".bak") + + logger.info("Migration complete.") + + def clear(self) -> None: + with self._lock: + self._ids_bin_path.unlink(missing_ok=True) + self._bin_path.unlink(missing_ok=True) + self._init_index() + self._known_hashes.clear() + self._deleted_ids.clear() + self._bin_count = 0 + logger.info("VectorStore cleared.") + + def has_data(self) -> bool: + return (self.data_dir / "vectors_metadata.pkl").exists() + + @property + def num_vectors(self) -> int: + return len(self._known_hashes) - len(self._deleted_ids) + + def __contains__(self, hash_value: str) -> bool: + """Check if a hash exists in the store""" + return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids + diff --git a/plugins/A_memorix/core/strategies/__init__.py b/plugins/A_memorix/core/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/A_memorix/core/strategies/base.py b/plugins/A_memorix/core/strategies/base.py new file mode 100644 index 00000000..ff250cdf --- /dev/null +++ b/plugins/A_memorix/core/strategies/base.py @@ -0,0 +1,89 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union +from dataclasses import dataclass, field +from enum import Enum +import hashlib + +class KnowledgeType(str, Enum): + NARRATIVE = "narrative" + FACTUAL = "factual" + QUOTE = "quote" + MIXED = "mixed" + +@dataclass +class SourceInfo: + file: str + offset_start: int + offset_end: int + checksum: str = "" + +@dataclass +class ChunkContext: + chunk_id: str + index: int + context: Dict[str, Any] = field(default_factory=dict) + text: str = "" + +@dataclass +class ChunkFlags: + verbatim: bool = False + requires_llm: bool = True + +@dataclass +class ProcessedChunk: + type: KnowledgeType + source: SourceInfo + chunk: ChunkContext + data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities + flags: ChunkFlags = field(default_factory=ChunkFlags) + + def to_dict(self) -> Dict: + return { + "type": self.type.value, + "source": { + "file": self.source.file, + "offset_start": self.source.offset_start, + "offset_end": self.source.offset_end, + "checksum": self.source.checksum + }, + "chunk": { + "text": self.chunk.text, + "chunk_id": self.chunk.chunk_id, + "index": self.chunk.index, + "context": self.chunk.context + }, + "data": self.data, + "flags": { + "verbatim": 
self.flags.verbatim, + "requires_llm": self.flags.requires_llm + } + } + +class BaseStrategy(ABC): + def __init__(self, filename: str): + self.filename = filename + + @abstractmethod + def split(self, text: str) -> List[ProcessedChunk]: + """按策略将文本切分为块。""" + pass + + @abstractmethod + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + """从文本块中抽取结构化信息。""" + pass + + def calculate_checksum(self, text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + def build_language_guard(self, text: str) -> str: + """ + 构建统一的输出语言约束。 + 不区分语言类型,仅要求抽取值保持原文语言,不做翻译。 + """ + _ = text # 预留参数,便于后续按需扩展 + return ( + "Focus on the original source language. Keep extracted events, entities, predicates " + "and objects in the same language as the source text, preserve names/terms as-is, " + "and do not translate." + ) diff --git a/plugins/A_memorix/core/strategies/factual.py b/plugins/A_memorix/core/strategies/factual.py new file mode 100644 index 00000000..4b7d6e56 --- /dev/null +++ b/plugins/A_memorix/core/strategies/factual.py @@ -0,0 +1,98 @@ +import re +from typing import List, Dict, Any +from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext + +class FactualStrategy(BaseStrategy): + def split(self, text: str) -> List[ProcessedChunk]: + # 结构感知切分 + lines = text.split('\n') + chunks = [] + current_chunk_lines = [] + current_len = 0 + target_size = 600 + + for i, line in enumerate(lines): + # 判断是否应当切分 + # 若当前行为列表项/定义/表格行,则尽量不切分 + is_structure = self._is_structural_line(line) + + current_len += len(line) + 1 + current_chunk_lines.append(line) + + # 达到目标长度且不在紧凑结构块内时切分(过长时强制切分) + if current_len >= target_size and not is_structure: + chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) + current_chunk_lines = [] + current_len = 0 + elif current_len >= target_size * 2: # 超长时强制切分 + chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) + current_chunk_lines = [] + current_len = 0 + + if current_chunk_lines: + chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) + + return chunks + + def _is_structural_line(self, line: str) -> bool: + line = line.strip() + if not line: return False + # 列表项 + if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line): + return True + # 定义项(术语: 定义) + if re.match(r'^[^::]+[::].+', line): + return True + # 表格行(按 markdown 语法假设) + if line.startswith('|') and line.endswith('|'): + return True + return False + + def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk: + text = "\n".join(lines) + return ProcessedChunk( + type=KnowledgeType.FACTUAL, + source=SourceInfo( + file=self.filename, + offset_start=0, # 简化处理:真实偏移跟踪需要额外状态 + offset_end=0, + checksum=self.calculate_checksum(text) + ), + chunk=ChunkContext( + chunk_id=f"{self.filename}_{index}", + index=index, + text=text + ) + ) + + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + if not llm_func: + raise ValueError("LLM function required for Factual extraction") + + language_guard = self.build_language_guard(chunk.chunk.text) + prompt = f"""You are a factual knowledge extraction engine. +Extract factual triples and entities from the text. +Preserve lists and definitions accurately. + +Language constraints: +- {language_guard} +- Preserve original names and domain terms exactly when possible. +- JSON keys must stay exactly as: triples, entities, subject, predicate, object. 
+ +Text: +{chunk.chunk.text} + +Return ONLY valid JSON: +{{ + "triples": [ + {{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}} + ], + "entities": ["Entity1", "Entity2"] +}} +""" + result = await llm_func(prompt) + + # 结果保持原样存入 data,后续统一归一化流程会处理 + # vector_store 侧期望关系字段为 subject/predicate/object 映射形式 + chunk.data = result + return chunk diff --git a/plugins/A_memorix/core/strategies/narrative.py b/plugins/A_memorix/core/strategies/narrative.py new file mode 100644 index 00000000..731414f7 --- /dev/null +++ b/plugins/A_memorix/core/strategies/narrative.py @@ -0,0 +1,126 @@ +import re +from typing import List, Dict, Any +from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext + +class NarrativeStrategy(BaseStrategy): + def split(self, text: str) -> List[ProcessedChunk]: + scenes = self._split_into_scenes(text) + chunks = [] + + for scene_idx, (scene_text, scene_title) in enumerate(scenes): + scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx) + chunks.extend(scene_chunks) + + return chunks + + def _split_into_scenes(self, text: str) -> List[tuple[str, str]]: + """按标题或分隔符把文本切分为场景。""" + # 简单启发式:按 markdown 标题或特定分隔符切分 + # 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行 + # 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行 + scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)' + + # 保留分隔符,以便识别场景起点 + parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE) + + scenes = [] + current_scene_title = "Start" + current_scene_content = [] + + if parts and parts[0].strip() == "": + parts = parts[1:] + + for part in parts: + if re.match(scene_pattern_str, part, re.MULTILINE): + # 先保存上一段场景 + if current_scene_content: + scenes.append(("".join(current_scene_content), current_scene_title)) + current_scene_content = [] + current_scene_title = part.strip() + else: + current_scene_content.append(part) + + if current_scene_content: + scenes.append(("".join(current_scene_content), current_scene_title)) + + # 若未识别到场景,则把全文视作单一场景 + if not scenes: + scenes = [(text, "Whole Text")] + + return scenes + + def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]: + chunks = [] + if len(text) <= window_size: + chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0)) + return chunks + + stride = window_size - overlap + start = 0 + local_idx = 0 + while start < len(text): + end = min(start + window_size, len(text)) + chunk_text = text[start:end] + + # 尽量对齐到最近换行,避免生硬截断句子 + # 仅在未到文本尾部时进行回退 + if end < len(text): + last_newline = chunk_text.rfind('\n') + if last_newline > window_size // 2: # 仅在回退距离可接受时启用 + end = start + last_newline + 1 + chunk_text = text[start:end] + + chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start)) + + start += len(chunk_text) - overlap if end < len(text) else len(chunk_text) + local_idx += 1 + + return chunks + + def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk: + return ProcessedChunk( + type=KnowledgeType.NARRATIVE, + source=SourceInfo( + file=self.filename, + offset_start=offset, + offset_end=offset + len(text), + checksum=self.calculate_checksum(text) + ), + chunk=ChunkContext( + chunk_id=f"{self.filename}_{scene_idx}_{local_idx}", + index=local_idx, + text=text, + context={"scene_id": scene_id} + ) + ) + + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + if not llm_func: + raise ValueError("LLM function required for 
Narrative extraction") + + language_guard = self.build_language_guard(chunk.chunk.text) + prompt = f"""You are a narrative knowledge extraction engine. +Extract key events and character relations from the scene text. + +Language constraints: +- {language_guard} +- Preserve original names and terms exactly when possible. +- JSON keys must stay exactly as: events, relations, subject, predicate, object. + +Scene: +{chunk.chunk.context.get('scene_id')} + +Text: +{chunk.chunk.text} + +Return ONLY valid JSON: +{{ + "events": ["event description 1", "event description 2"], + "relations": [ + {{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}} + ] +}} +""" + result = await llm_func(prompt) + chunk.data = result + return chunk diff --git a/plugins/A_memorix/core/strategies/quote.py b/plugins/A_memorix/core/strategies/quote.py new file mode 100644 index 00000000..10733d64 --- /dev/null +++ b/plugins/A_memorix/core/strategies/quote.py @@ -0,0 +1,52 @@ +from typing import List, Dict, Any +from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags + +class QuoteStrategy(BaseStrategy): + def split(self, text: str) -> List[ProcessedChunk]: + # Split by double newlines (stanzas) + stanzas = text.split("\n\n") + chunks = [] + offset = 0 + + for idx, stanza in enumerate(stanzas): + if not stanza.strip(): + offset += len(stanza) + 2 + continue + + chunk = ProcessedChunk( + type=KnowledgeType.QUOTE, + source=SourceInfo( + file=self.filename, + offset_start=offset, + offset_end=offset + len(stanza), + checksum=self.calculate_checksum(stanza) + ), + chunk=ChunkContext( + chunk_id=f"{self.filename}_{idx}", + index=idx, + text=stanza + ), + flags=ChunkFlags( + verbatim=True, + requires_llm=False # Default to no LLM, but can be overridden + ) + ) + chunks.append(chunk) + offset += len(stanza) + 2 # +2 for \n\n + + return chunks + + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + # For quotes, the text itself is the entity/knowledge + # We might use LLM to extract headers/metadata if requested, but core logic is pass-through + + # Treat the whole chunk text as a verbatim entity + chunk.data = { + "verbatim_entities": [chunk.chunk.text] + } + + if llm_func and chunk.flags.requires_llm: + # Optional: Extract metadata + pass + + return chunk diff --git a/plugins/A_memorix/core/utils/__init__.py b/plugins/A_memorix/core/utils/__init__.py new file mode 100644 index 00000000..e0d763cf --- /dev/null +++ b/plugins/A_memorix/core/utils/__init__.py @@ -0,0 +1,33 @@ +"""工具模块 - 哈希、监控等辅助功能""" + +from .hash import compute_hash, normalize_text +from .monitor import MemoryMonitor +from .quantization import quantize_vector, dequantize_vector +from .time_parser import ( + parse_query_datetime_to_timestamp, + parse_query_time_range, + parse_ingest_datetime_to_timestamp, + normalize_time_meta, + format_timestamp, +) +from .relation_write_service import RelationWriteService, RelationWriteResult +from .relation_query import RelationQuerySpec, parse_relation_query_spec +from .plugin_id_policy import PluginIdPolicy + +__all__ = [ + "compute_hash", + "normalize_text", + "MemoryMonitor", + "quantize_vector", + "dequantize_vector", + "parse_query_datetime_to_timestamp", + "parse_query_time_range", + "parse_ingest_datetime_to_timestamp", + "normalize_time_meta", + "format_timestamp", + "RelationWriteService", + "RelationWriteResult", + "RelationQuerySpec", + "parse_relation_query_spec", + "PluginIdPolicy", +] diff --git 
a/plugins/A_memorix/core/utils/aggregate_query_service.py b/plugins/A_memorix/core/utils/aggregate_query_service.py new file mode 100644 index 00000000..a87a4913 --- /dev/null +++ b/plugins/A_memorix/core/utils/aggregate_query_service.py @@ -0,0 +1,360 @@ +""" +聚合查询服务: +- 并发执行 search/time/episode 分支 +- 统一分支结果结构 +- 可选混合排序(Weighted RRF) +""" + +from __future__ import annotations + +import asyncio +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.AggregateQueryService") + +BranchRunner = Callable[[], Awaitable[Dict[str, Any]]] + + +class AggregateQueryService: + """聚合查询执行服务(search/time/episode)。""" + + def __init__(self, plugin_config: Optional[Any] = None): + self.plugin_config = plugin_config or {} + + def _cfg(self, key: str, default: Any = None) -> Any: + getter = getattr(self.plugin_config, "get_config", None) + if callable(getter): + return getter(key, default) + + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + @staticmethod + def _as_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except Exception: + return float(default) + + @staticmethod + def _as_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except Exception: + return int(default) + + def _rrf_k(self) -> float: + raw = self._cfg("retrieval.aggregate.rrf_k", 60.0) + value = self._as_float(raw, 60.0) + return max(1.0, value) + + def _weights(self) -> Dict[str, float]: + defaults = {"search": 1.0, "time": 1.0, "episode": 1.0} + raw = self._cfg("retrieval.aggregate.weights", {}) + if not isinstance(raw, dict): + return defaults + + out = dict(defaults) + for key in ("search", "time", "episode"): + if key in raw: + out[key] = max(0.0, self._as_float(raw.get(key), defaults[key])) + return out + + @staticmethod + def _normalize_branch_payload( + name: str, + payload: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + data = payload if isinstance(payload, dict) else {} + results_raw = data.get("results", []) + results = results_raw if isinstance(results_raw, list) else [] + count = data.get("count") + if count is None: + count = len(results) + return { + "name": name, + "success": bool(data.get("success", False)), + "skipped": bool(data.get("skipped", False)), + "skip_reason": str(data.get("skip_reason", "") or "").strip(), + "error": str(data.get("error", "") or "").strip(), + "results": results, + "count": max(0, int(count)), + "elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)), + "content": str(data.get("content", "") or ""), + "query_type": str(data.get("query_type", "") or name), + } + + @staticmethod + def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str: + item_type = str(item.get("type", "") or "").strip().lower() + if item_type == "episode": + episode_id = str(item.get("episode_id", "") or "").strip() + if episode_id: + return f"episode:{episode_id}" + + item_hash = str(item.get("hash", "") or "").strip() + if item_hash: + return f"{item_type}:{item_hash}" + + return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}" + + def _build_mixed_results( + self, + *, + branches: Dict[str, Dict[str, Any]], + top_k: int, + ) -> List[Dict[str, Any]]: + rrf_k = self._rrf_k() + weights = self._weights() + bucket: Dict[str, Dict[str, Any]] = {} + + for branch_name, branch in branches.items(): + if not 
branch.get("success", False): + continue + results = branch.get("results", []) + if not isinstance(results, list): + continue + + weight = max(0.0, float(weights.get(branch_name, 1.0))) + for idx, item in enumerate(results, start=1): + if not isinstance(item, dict): + continue + key = self._mix_key(item, branch_name, idx) + score = weight / (rrf_k + float(idx)) + if key not in bucket: + merged = dict(item) + merged["fusion_score"] = 0.0 + merged["_source_branches"] = set() + bucket[key] = merged + + target = bucket[key] + target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score + target["_source_branches"].add(branch_name) + + mixed = list(bucket.values()) + mixed.sort( + key=lambda x: ( + -float(x.get("fusion_score", 0.0)), + str(x.get("type", "") or ""), + str(x.get("hash", "") or x.get("episode_id", "") or ""), + ) + ) + + out: List[Dict[str, Any]] = [] + for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1): + merged = dict(item) + branches_set = merged.pop("_source_branches", set()) + merged["source_branches"] = sorted(list(branches_set)) + merged["rank"] = rank + out.append(merged) + return out + + @staticmethod + def _status(branch: Dict[str, Any]) -> str: + if branch.get("skipped", False): + return "skipped" + if branch.get("success", False): + return "success" + return "failed" + + def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + summary: Dict[str, Dict[str, Any]] = {} + for name, branch in branches.items(): + status = self._status(branch) + summary[name] = { + "status": status, + "count": int(branch.get("count", 0) or 0), + } + if status == "skipped": + summary[name]["reason"] = str(branch.get("skip_reason", "") or "") + if status == "failed": + summary[name]["error"] = str(branch.get("error", "") or "") + return summary + + def _build_content( + self, + *, + query: str, + branches: Dict[str, Dict[str, Any]], + errors: List[Dict[str, str]], + mixed_results: Optional[List[Dict[str, Any]]], + ) -> str: + lines: List[str] = [ + f"🔀 聚合查询结果(query='{query or 'N/A'}')", + "", + "分支状态:", + ] + for name in ("search", "time", "episode"): + branch = branches.get(name, {}) + status = self._status(branch) + count = int(branch.get("count", 0) or 0) + line = f"- {name}: {status}, count={count}" + reason = str(branch.get("skip_reason", "") or "").strip() + err = str(branch.get("error", "") or "").strip() + if status == "skipped" and reason: + line += f" ({reason})" + if status == "failed" and err: + line += f" ({err})" + lines.append(line) + + if errors: + lines.append("") + lines.append("错误:") + for item in errors[:6]: + lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}") + + if mixed_results is not None: + lines.append("") + lines.append(f"🧩 混合结果({len(mixed_results)} 条):") + for idx, item in enumerate(mixed_results[:5], start=1): + src = ",".join(item.get("source_branches", []) or []) + if str(item.get("type", "") or "") == "episode": + title = str(item.get("title", "") or "Untitled") + lines.append(f"{idx}. 🧠 {title} [{src}]") + else: + text = str(item.get("content", "") or "") + if len(text) > 80: + text = text[:80] + "..." + lines.append(f"{idx}. 
{text} [{src}]") + + return "\n".join(lines) + + async def execute( + self, + *, + query: str, + top_k: int, + mix: bool, + mix_top_k: Optional[int], + time_from: Optional[str], + time_to: Optional[str], + search_runner: Optional[BranchRunner], + time_runner: Optional[BranchRunner], + episode_runner: Optional[BranchRunner], + ) -> Dict[str, Any]: + clean_query = str(query or "").strip() + safe_top_k = max(1, int(top_k)) + safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k)) + + branches: Dict[str, Dict[str, Any]] = {} + errors: List[Dict[str, str]] = [] + scheduled: List[Tuple[str, asyncio.Task]] = [] + + if clean_query: + if search_runner is not None: + scheduled.append(("search", asyncio.create_task(search_runner()))) + else: + branches["search"] = self._normalize_branch_payload( + "search", + {"success": False, "error": "search runner unavailable", "results": []}, + ) + else: + branches["search"] = self._normalize_branch_payload( + "search", + { + "success": False, + "skipped": True, + "skip_reason": "missing_query", + "results": [], + "count": 0, + }, + ) + + if time_from or time_to: + if time_runner is not None: + scheduled.append(("time", asyncio.create_task(time_runner()))) + else: + branches["time"] = self._normalize_branch_payload( + "time", + {"success": False, "error": "time runner unavailable", "results": []}, + ) + else: + branches["time"] = self._normalize_branch_payload( + "time", + { + "success": False, + "skipped": True, + "skip_reason": "missing_time_window", + "results": [], + "count": 0, + }, + ) + + if episode_runner is not None: + scheduled.append(("episode", asyncio.create_task(episode_runner()))) + else: + branches["episode"] = self._normalize_branch_payload( + "episode", + {"success": False, "error": "episode runner unavailable", "results": []}, + ) + + if scheduled: + done = await asyncio.gather( + *[task for _, task in scheduled], + return_exceptions=True, + ) + for (branch_name, _), payload in zip(scheduled, done): + if isinstance(payload, Exception): + logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload) + normalized = self._normalize_branch_payload( + branch_name, + { + "success": False, + "error": str(payload), + "results": [], + }, + ) + else: + normalized = self._normalize_branch_payload(branch_name, payload) + branches[branch_name] = normalized + + for name in ("search", "time", "episode"): + branch = branches.get(name) + if not branch: + continue + if branch.get("skipped", False): + continue + if not branch.get("success", False): + errors.append( + { + "branch": name, + "error": str(branch.get("error", "") or "unknown error"), + } + ) + + success = any( + bool(branches.get(name, {}).get("success", False)) + for name in ("search", "time", "episode") + ) + mixed_results: Optional[List[Dict[str, Any]]] = None + if mix: + mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k) + + payload: Dict[str, Any] = { + "success": success, + "query_type": "aggregate", + "query": clean_query, + "top_k": safe_top_k, + "mix": bool(mix), + "mix_top_k": safe_mix_top_k, + "branches": branches, + "errors": errors, + "summary": self._build_summary(branches), + } + if mixed_results is not None: + payload["mixed_results"] = mixed_results + + payload["content"] = self._build_content( + query=clean_query, + branches=branches, + errors=errors, + mixed_results=mixed_results, + ) + return payload diff --git a/plugins/A_memorix/core/utils/episode_retrieval_service.py 
b/plugins/A_memorix/core/utils/episode_retrieval_service.py new file mode 100644 index 00000000..5a4cd24d --- /dev/null +++ b/plugins/A_memorix/core/utils/episode_retrieval_service.py @@ -0,0 +1,182 @@ +"""Episode hybrid retrieval service.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger + +from ..retrieval import DualPathRetriever, TemporalQueryOptions + +logger = get_logger("A_Memorix.EpisodeRetrievalService") + + +class EpisodeRetrievalService: + """Hybrid episode retrieval backed by lexical rows and evidence projection.""" + + _RRF_K = 60.0 + _BRANCH_WEIGHTS = { + "lexical": 1.0, + "paragraph_evidence": 1.0, + "relation_evidence": 0.85, + } + + def __init__( + self, + *, + metadata_store: Any, + retriever: Optional[DualPathRetriever] = None, + ) -> None: + self.metadata_store = metadata_store + self.retriever = retriever + + async def query( + self, + *, + query: str = "", + top_k: int = 5, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + include_paragraphs: bool = False, + ) -> List[Dict[str, Any]]: + clean_query = str(query or "").strip() + safe_top_k = max(1, int(top_k)) + candidate_k = max(30, safe_top_k * 6) + + branches: Dict[str, List[Dict[str, Any]]] = { + "lexical": self.metadata_store.query_episodes( + query=clean_query, + time_from=time_from, + time_to=time_to, + person=person, + source=source, + limit=(candidate_k if clean_query else safe_top_k), + ) + } + + if clean_query and self.retriever is not None: + try: + temporal = TemporalQueryOptions( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + results = await self.retriever.retrieve( + query=clean_query, + top_k=candidate_k, + temporal=temporal, + ) + except Exception as exc: + logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc) + else: + paragraph_rank_map: Dict[str, int] = {} + relation_rank_map: Dict[str, int] = {} + for rank, item in enumerate(results, start=1): + hash_value = str(getattr(item, "hash_value", "") or "").strip() + result_type = str(getattr(item, "result_type", "") or "").strip().lower() + if not hash_value: + continue + if result_type == "paragraph" and hash_value not in paragraph_rank_map: + paragraph_rank_map[hash_value] = rank + elif result_type == "relation" and hash_value not in relation_rank_map: + relation_rank_map[hash_value] = rank + + if paragraph_rank_map: + paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes( + list(paragraph_rank_map.keys()), + source=source, + ) + if paragraph_rows: + branches["paragraph_evidence"] = self._rank_projected_rows( + paragraph_rows, + rank_map=paragraph_rank_map, + support_key="matched_paragraph_hashes", + ) + + if relation_rank_map: + relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes( + list(relation_rank_map.keys()), + source=source, + ) + if relation_rows: + branches["relation_evidence"] = self._rank_projected_rows( + relation_rows, + rank_map=relation_rank_map, + support_key="matched_relation_hashes", + ) + + fused = self._fuse_branches(branches, top_k=safe_top_k) + if include_paragraphs: + for item in fused: + item["paragraphs"] = self.metadata_store.get_episode_paragraphs( + episode_id=str(item.get("episode_id") or ""), + limit=50, + ) + return fused + + @staticmethod + def _rank_projected_rows( + rows: List[Dict[str, Any]], + *, + rank_map: Dict[str, int], + support_key: str, 
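+        # support_key names the projection field that carries the matched evidence
+        # hashes: "matched_paragraph_hashes" or "matched_relation_hashes" (see the two
+        # call sites in query() above).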
+ ) -> List[Dict[str, Any]]: + sentinel = 10**9 + ranked = [dict(item) for item in rows] + + def _first_support_rank(item: Dict[str, Any]) -> int: + support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])] + ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map] + return min(ranks) if ranks else sentinel + + ranked.sort( + key=lambda item: ( + _first_support_rank(item), + -int(item.get("matched_paragraph_count") or 0), + -float(item.get("updated_at") or 0.0), + str(item.get("episode_id") or ""), + ) + ) + return ranked + + def _fuse_branches( + self, + branches: Dict[str, List[Dict[str, Any]]], + *, + top_k: int, + ) -> List[Dict[str, Any]]: + bucket: Dict[str, Dict[str, Any]] = {} + for branch_name, rows in branches.items(): + weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0) + if weight <= 0.0: + continue + for rank, row in enumerate(rows, start=1): + episode_id = str(row.get("episode_id", "") or "").strip() + if not episode_id: + continue + if episode_id not in bucket: + payload = dict(row) + payload.pop("matched_paragraph_hashes", None) + payload.pop("matched_relation_hashes", None) + payload.pop("matched_paragraph_count", None) + payload.pop("matched_relation_count", None) + payload["_fusion_score"] = 0.0 + bucket[episode_id] = payload + bucket[episode_id]["_fusion_score"] = float( + bucket[episode_id].get("_fusion_score", 0.0) + ) + weight / (self._RRF_K + float(rank)) + + out = list(bucket.values()) + out.sort( + key=lambda item: ( + -float(item.get("_fusion_score", 0.0)), + -float(item.get("updated_at") or 0.0), + str(item.get("episode_id") or ""), + ) + ) + for item in out: + item.pop("_fusion_score", None) + return out[: max(1, int(top_k))] diff --git a/plugins/A_memorix/core/utils/hash.py b/plugins/A_memorix/core/utils/hash.py new file mode 100644 index 00000000..b6363257 --- /dev/null +++ b/plugins/A_memorix/core/utils/hash.py @@ -0,0 +1,129 @@ +""" +哈希工具模块 + +提供文本哈希计算功能,用于唯一标识和去重。 +""" + +import hashlib +import re +from typing import Union + + +def compute_hash(text: str, hash_type: str = "sha256") -> str: + """ + 计算文本的哈希值 + + Args: + text: 输入文本 + hash_type: 哈希算法类型(sha256, md5等) + + Returns: + 哈希值字符串 + """ + if hash_type == "sha256": + return hashlib.sha256(text.encode("utf-8")).hexdigest() + elif hash_type == "md5": + return hashlib.md5(text.encode("utf-8")).hexdigest() + else: + raise ValueError(f"不支持的哈希算法: {hash_type}") + + +def normalize_text(text: str) -> str: + """ + 规范化文本用于哈希计算 + + 执行以下操作: + - 去除首尾空白 + - 统一换行符为\\n + - 压缩多个连续空格 + - 去除不可见字符(保留换行和制表符) + + Args: + text: 输入文本 + + Returns: + 规范化后的文本 + """ + # 去除首尾空白 + text = text.strip() + + # 统一换行符 + text = text.replace("\r\n", "\n").replace("\r", "\n") + + # 压缩多个连续空格为一个(但保留换行和制表符) + text = re.sub(r"[^\S\n]+", " ", text) + + return text + + +def compute_paragraph_hash(paragraph: str) -> str: + """ + 计算段落的哈希值 + + Args: + paragraph: 段落文本 + + Returns: + 段落哈希值(用于paragraph-前缀) + """ + normalized = normalize_text(paragraph) + return compute_hash(normalized) + + +def compute_entity_hash(entity: str) -> str: + """ + 计算实体的哈希值 + + Args: + entity: 实体名称 + + Returns: + 实体哈希值(用于entity-前缀) + """ + normalized = entity.strip().lower() + return compute_hash(normalized) + + +def compute_relation_hash(relation: tuple) -> str: + """ + 计算关系的哈希值 + + Args: + relation: 关系元组 (subject, predicate, object) + + Returns: + 关系哈希值(用于relation-前缀) + """ + # 将关系元组转为字符串 + relation_str = str(tuple(relation)) + return compute_hash(relation_str) + + +def format_hash_key(hash_type: str, hash_value: str) -> str: + 
""" + 格式化哈希键 + + Args: + hash_type: 类型前缀(paragraph, entity, relation) + hash_value: 哈希值 + + Returns: + 格式化的键(如 paragraph-abc123...) + """ + return f"{hash_type}-{hash_value}" + + +def parse_hash_key(key: str) -> tuple[str, str]: + """ + 解析哈希键 + + Args: + key: 格式化的键(如 paragraph-abc123...) + + Returns: + (类型, 哈希值) 元组 + """ + parts = key.split("-", 1) + if len(parts) != 2: + raise ValueError(f"无效的哈希键格式: {key}") + return parts[0], parts[1] diff --git a/plugins/A_memorix/core/utils/import_payloads.py b/plugins/A_memorix/core/utils/import_payloads.py new file mode 100644 index 00000000..6986a4c1 --- /dev/null +++ b/plugins/A_memorix/core/utils/import_payloads.py @@ -0,0 +1,110 @@ +"""Shared import payload normalization helpers.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from ..storage import KnowledgeType, resolve_stored_knowledge_type +from .time_parser import normalize_time_meta + + +def _normalize_entities(raw_entities: Any) -> List[str]: + if not isinstance(raw_entities, list): + return [] + out: List[str] = [] + seen = set() + for item in raw_entities: + name = str(item or "").strip() + if not name: + continue + key = name.lower() + if key in seen: + continue + seen.add(key) + out.append(name) + return out + + +def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]: + if not isinstance(raw_relations, list): + return [] + out: List[Dict[str, str]] = [] + for item in raw_relations: + if not isinstance(item, dict): + continue + subject = str(item.get("subject", "")).strip() + predicate = str(item.get("predicate", "")).strip() + obj = str(item.get("object", "")).strip() + if not (subject and predicate and obj): + continue + out.append( + { + "subject": subject, + "predicate": predicate, + "object": obj, + } + ) + return out + + +def normalize_paragraph_import_item( + item: Any, + *, + default_source: str, +) -> Dict[str, Any]: + """Normalize one paragraph import item from text/json payloads.""" + + if isinstance(item, str): + content = str(item) + knowledge_type = resolve_stored_knowledge_type(None, content=content) + return { + "content": content, + "knowledge_type": knowledge_type.value, + "source": str(default_source or "").strip(), + "time_meta": None, + "entities": [], + "relations": [], + } + + if not isinstance(item, dict) or "content" not in item: + raise ValueError("段落项必须为字符串或包含 content 的对象") + + content = str(item.get("content", "") or "") + if not content.strip(): + raise ValueError("段落 content 不能为空") + + raw_time_meta = { + "event_time": item.get("event_time"), + "event_time_start": item.get("event_time_start"), + "event_time_end": item.get("event_time_end"), + "time_range": item.get("time_range"), + "time_granularity": item.get("time_granularity"), + "time_confidence": item.get("time_confidence"), + } + time_meta_field = item.get("time_meta") + if isinstance(time_meta_field, dict): + raw_time_meta.update(time_meta_field) + + knowledge_type_raw = item.get("knowledge_type") + if knowledge_type_raw is None: + knowledge_type_raw = item.get("type") + knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content) + source = str(item.get("source") or default_source or "").strip() + if not source: + source = str(default_source or "").strip() + + normalized_time_meta = normalize_time_meta(raw_time_meta) + return { + "content": content, + "knowledge_type": knowledge_type.value, + "source": source, + "time_meta": normalized_time_meta if normalized_time_meta else None, + "entities": 
_normalize_entities(item.get("entities")), + "relations": _normalize_relations(item.get("relations")), + } + + +def normalize_summary_knowledge_type(value: Any) -> KnowledgeType: + """Normalize config-driven summary knowledge type.""" + + return resolve_stored_knowledge_type(value, content="") diff --git a/plugins/A_memorix/core/utils/io.py b/plugins/A_memorix/core/utils/io.py new file mode 100644 index 00000000..ed14df43 --- /dev/null +++ b/plugins/A_memorix/core/utils/io.py @@ -0,0 +1,84 @@ +""" +IO Utilities + +提供原子文件写入等IO辅助功能。 +""" + +import os +import shutil +import contextlib +from pathlib import Path +from typing import Union + +@contextlib.contextmanager +def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs): + """ + 原子文件写入上下文管理器 + + 原理: + 1. 写入 .tmp 临时文件 + 2. 写入成功后,使用 os.replace 原子替换目标文件 + 3. 如果失败,自动删除临时文件 + + Args: + file_path: 目标文件路径 + mode: 打开模式 ('w', 'wb' 等) + encoding: 编码 + **kwargs: 传给 open() 的其他参数 + """ + path = Path(file_path) + # 确保父目录存在 + path.parent.mkdir(parents=True, exist_ok=True) + + # 临时文件路径 + tmp_path = path.with_suffix(path.suffix + ".tmp") + + try: + with open(tmp_path, mode, encoding=encoding, **kwargs) as f: + yield f + + # 确保写入磁盘 + f.flush() + os.fsync(f.fileno()) + + # 原子替换 (Windows下可能需要先删除目标文件,但 os.replace 在 Py3.3+ 尽可能原子) + # 注意: Windows 上如果有其他进程占用文件,os.replace 可能会失败 + os.replace(tmp_path, path) + + except Exception as e: + # 清理临时文件 + if tmp_path.exists(): + try: + os.remove(tmp_path) + except: + pass + raise e + +@contextlib.contextmanager +def atomic_save_path(file_path: Union[str, Path]): + """ + 提供临时路径用于原子保存 (针对只接受路径的API,如Faiss) + + Args: + file_path: 最终目标路径 + + Yields: + tmp_path: 临时文件路径 (str) + """ + path = Path(file_path) + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + + try: + yield str(tmp_path) + + if Path(tmp_path).exists(): + os.replace(tmp_path, path) + + except Exception as e: + if Path(tmp_path).exists(): + try: + os.remove(tmp_path) + except: + pass + raise e diff --git a/plugins/A_memorix/core/utils/matcher.py b/plugins/A_memorix/core/utils/matcher.py new file mode 100644 index 00000000..bddff5ee --- /dev/null +++ b/plugins/A_memorix/core/utils/matcher.py @@ -0,0 +1,89 @@ +""" +高效文本匹配工具模块 + +实现 Aho-Corasick 算法用于多模式匹配。 +""" + +from typing import List, Dict, Tuple, Set, Any +from collections import deque + + +class AhoCorasick: + """ + Aho-Corasick 自动机实现高效多模式匹配 + """ + + def __init__(self): + # next_states[state][char] = next_state + self.next_states: List[Dict[str, int]] = [{}] + # fail[state] = fail_state + self.fail: List[int] = [0] + # output[state] = set of patterns ending at this state + self.output: List[Set[str]] = [set()] + self.patterns: Set[str] = set() + + def add_pattern(self, pattern: str): + """添加模式""" + if not pattern: + return + self.patterns.add(pattern) + state = 0 + for char in pattern: + if char not in self.next_states[state]: + new_state = len(self.next_states) + self.next_states[state][char] = new_state + self.next_states.append({}) + self.fail.append(0) + self.output.append(set()) + state = self.next_states[state][char] + self.output[state].add(pattern) + + def build(self): + """构建失败指针""" + queue = deque() + # 处理第一层 + for char, state in self.next_states[0].items(): + queue.append(state) + self.fail[state] = 0 + + while queue: + r = queue.popleft() + for char, s in self.next_states[r].items(): + queue.append(s) + # 找到失败路径 + state = self.fail[r] + while char not in self.next_states[state] and state != 0: + state = 
self.fail[state] + self.fail[s] = self.next_states[state].get(char, 0) + # 合并输出 + self.output[s].update(self.output[self.fail[s]]) + + def search(self, text: str) -> List[Tuple[int, str]]: + """ + 在文本中搜索所有模式 + + Returns: + [(结束索引, 匹配到的模式), ...] + """ + state = 0 + results = [] + for i, char in enumerate(text): + while char not in self.next_states[state] and state != 0: + state = self.fail[state] + state = self.next_states[state].get(char, 0) + for pattern in self.output[state]: + results.append((i, pattern)) + return results + + def find_all(self, text: str) -> Dict[str, int]: + """ + 查找并统计所有模式出现次数 + + Returns: + {模式: 出现次数} + """ + results = self.search(text) + stats = {} + for _, pattern in results: + stats[pattern] = stats.get(pattern, 0) + 1 + return stats diff --git a/plugins/A_memorix/core/utils/monitor.py b/plugins/A_memorix/core/utils/monitor.py new file mode 100644 index 00000000..39c794ab --- /dev/null +++ b/plugins/A_memorix/core/utils/monitor.py @@ -0,0 +1,189 @@ +""" +内存监控模块 + +提供内存使用监控和预警功能。 +""" + +import gc +import threading +import time +from typing import Callable, Optional + +try: + import psutil + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.MemoryMonitor") + + +class MemoryMonitor: + """ + 内存监控器 + + 功能: + - 实时监控内存使用 + - 超过阈值时触发警告 + - 支持自动垃圾回收 + """ + + def __init__( + self, + max_memory_mb: int, + warning_threshold: float = 0.9, + check_interval: float = 10.0, + enable_auto_gc: bool = True, + ): + """ + 初始化内存监控器 + + Args: + max_memory_mb: 最大内存限制(MB) + warning_threshold: 警告阈值(0-1之间,默认0.9表示90%) + check_interval: 检查间隔(秒) + enable_auto_gc: 是否启用自动垃圾回收 + """ + self.max_memory_mb = max_memory_mb + self.warning_threshold = warning_threshold + self.check_interval = check_interval + self.enable_auto_gc = enable_auto_gc + + self._running = False + self._thread: Optional[threading.Thread] = None + self._callbacks: list[Callable[[float, float], None]] = [] + + def start(self): + """启动监控""" + if self._running: + logger.warning("内存监控已在运行") + return + + if not HAS_PSUTIL: + logger.warning("psutil 未安装,内存监控功能不可用") + return + + self._running = True + self._thread = threading.Thread(target=self._monitor_loop, daemon=True) + self._thread.start() + logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)") + + def stop(self): + """停止监控""" + if not self._running: + return + + self._running = False + if self._thread: + self._thread.join(timeout=5.0) + logger.info("内存监控已停止") + + def register_callback(self, callback: Callable[[float, float], None]): + """ + 注册内存超限回调函数 + + Args: + callback: 回调函数,接收 (当前使用MB, 限制MB) 参数 + """ + self._callbacks.append(callback) + + def get_current_memory_mb(self) -> float: + """ + 获取当前进程内存使用量(MB) + + Returns: + 内存使用量(MB) + """ + if not HAS_PSUTIL: + # 降级方案:使用内置函数 + import sys + return sys.getsizeof(gc.get_objects()) / 1024 / 1024 + + process = psutil.Process() + return process.memory_info().rss / 1024 / 1024 + + def get_memory_usage_ratio(self) -> float: + """ + 获取内存使用率 + + Returns: + 使用率(0-1之间) + """ + current = self.get_current_memory_mb() + return current / self.max_memory_mb if self.max_memory_mb > 0 else 0 + + def _monitor_loop(self): + """监控循环""" + while self._running: + try: + current_mb = self.get_current_memory_mb() + ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0 + + # 检查是否超过阈值 + if ratio >= self.warning_threshold: + logger.warning( + f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB " + f"({ratio*100:.1f}%)" + ) + + # 触发回调 + for callback in 
self._callbacks: + try: + callback(current_mb, self.max_memory_mb) + except Exception as e: + logger.error(f"内存回调执行失败: {e}") + + # 自动垃圾回收 + if self.enable_auto_gc: + before = self.get_current_memory_mb() + gc.collect() + after = self.get_current_memory_mb() + freed = before - after + if freed > 1: # 释放超过1MB才记录 + logger.info(f"垃圾回收释放: {freed:.1f}MB") + + # 定期垃圾回收(即使未超限) + elif self.enable_auto_gc and int(time.time()) % 60 == 0: + gc.collect() + + except Exception as e: + logger.error(f"内存监控出错: {e}") + + # 等待下次检查 + time.sleep(self.check_interval) + + def __enter__(self): + """上下文管理器入口""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口""" + self.stop() + + +def get_memory_info() -> dict: + """ + 获取系统内存信息 + + Returns: + 内存信息字典 + """ + if not HAS_PSUTIL: + return {"error": "psutil 未安装"} + + try: + mem = psutil.virtual_memory() + process = psutil.Process() + + return { + "system_total_gb": mem.total / 1024 / 1024 / 1024, + "system_available_gb": mem.available / 1024 / 1024 / 1024, + "system_usage_percent": mem.percent, + "process_mb": process.memory_info().rss / 1024 / 1024, + "process_percent": (process.memory_info().rss / mem.total) * 100, + } + except Exception as e: + return {"error": str(e)} diff --git a/plugins/A_memorix/core/utils/path_fallback_service.py b/plugins/A_memorix/core/utils/path_fallback_service.py new file mode 100644 index 00000000..7a802743 --- /dev/null +++ b/plugins/A_memorix/core/utils/path_fallback_service.py @@ -0,0 +1,165 @@ +"""Shared path-fallback helpers for search post-processing.""" + +from __future__ import annotations + +import hashlib +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from ..retrieval.dual_path import RetrievalResult + + +def extract_entities(query: str, graph_store: Any) -> List[str]: + """Extract up to two graph nodes from a query using n-gram matching.""" + if not graph_store: + return [] + + text = str(query or "").strip() + if not text: + return [] + + # Keep the heuristic aligned with previous legacy behavior. 
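+    # Strategy: split the query on whitespace (after replacing basic punctuation),
+    # then scan n-grams from longest (up to 4 tokens) to shortest and resolve each
+    # span via graph_store.find_node(..., ignore_case=True); token indices covered by
+    # a hit are skipped so shorter spans cannot re-match inside a longer entity name.
+    # Illustrative only (hypothetical nodes): a query like "alice bob relationship"
+    # with graph nodes "Alice" and "Bob" yields those two entities -- exactly the
+    # two-entity shape find_paths_from_query() requires before it calls
+    # find_paths_between_entities().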
+ tokens = ( + text.replace("?", " ") + .replace("!", " ") + .replace(".", " ") + .split() + ) + if not tokens: + return [] + + found_entities = set() + skip_indices = set() + max_n = min(4, len(tokens)) + + for size in range(max_n, 0, -1): + for i in range(len(tokens) - size + 1): + if any(idx in skip_indices for idx in range(i, i + size)): + continue + span = " ".join(tokens[i : i + size]) + matched_node = graph_store.find_node(span, ignore_case=True) + if not matched_node: + continue + found_entities.add(matched_node) + for idx in range(i, i + size): + skip_indices.add(idx) + + return list(found_entities) + + +def find_paths_between_entities( + start_node: str, + end_node: str, + graph_store: Any, + metadata_store: Any, + *, + max_depth: int = 3, + max_paths: int = 5, +) -> List[Dict[str, Any]]: + """Find and enrich indirect paths between two nodes.""" + if not graph_store or not metadata_store: + return [] + + try: + paths = graph_store.find_paths( + start_node, + end_node, + max_depth=max_depth, + max_paths=max_paths, + ) + except Exception: + return [] + + if not paths: + return [] + + edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {} + formatted_paths: List[Dict[str, Any]] = [] + + for path_nodes in paths: + if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2: + continue + + path_desc: List[str] = [] + for i in range(len(path_nodes) - 1): + u = str(path_nodes[i]) + v = str(path_nodes[i + 1]) + + cache_key = tuple(sorted((u, v))) + if cache_key in edge_cache: + pred, direction = edge_cache[cache_key] + else: + pred = "related" + direction = "->" + rels = metadata_store.get_relations(subject=u, object=v) + if not rels: + rels = metadata_store.get_relations(subject=v, object=u) + direction = "<-" + if rels: + best_rel = max(rels, key=lambda x: x.get("confidence", 1.0)) + pred = str(best_rel.get("predicate", "related") or "related") + edge_cache[cache_key] = (pred, direction) + + step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-" + path_desc.append(step_str) + + full_path_str = str(path_nodes[0]) + for i, step in enumerate(path_desc): + full_path_str += f" {step} {path_nodes[i + 1]}" + + formatted_paths.append( + { + "nodes": list(path_nodes), + "description": full_path_str, + } + ) + + return formatted_paths + + +def find_paths_from_query( + query: str, + graph_store: Any, + metadata_store: Any, + *, + max_depth: int = 3, + max_paths: int = 5, +) -> List[Dict[str, Any]]: + """Extract entities from query and resolve indirect paths.""" + entities = extract_entities(query, graph_store) + if len(entities) != 2: + return [] + return find_paths_between_entities( + entities[0], + entities[1], + graph_store, + metadata_store, + max_depth=max_depth, + max_paths=max_paths, + ) + + +def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]: + """Convert path results into retrieval results for the unified pipeline.""" + converted: List[RetrievalResult] = [] + for item in paths: + description = str(item.get("description", "")).strip() + if not description: + continue + hash_seed = description.encode("utf-8") + path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}" + converted.append( + RetrievalResult( + hash_value=path_hash, + content=f"[Indirect Relation] {description}", + score=0.95, + result_type="relation", + source="graph_path", + metadata={ + "source": "graph_path", + "is_indirect": True, + "nodes": list(item.get("nodes", [])), + }, + ) + ) + return converted + diff --git a/plugins/A_memorix/core/utils/person_profile_service.py 
b/plugins/A_memorix/core/utils/person_profile_service.py new file mode 100644 index 00000000..ccbbaf90 --- /dev/null +++ b/plugins/A_memorix/core/utils/person_profile_service.py @@ -0,0 +1,495 @@ +""" +人物画像服务 + +主链路: +person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储 +""" + +import json +import time +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger +from src.common.database.database_model import PersonInfo + +from ..embedding import EmbeddingAPIAdapter +from ..retrieval import ( + DualPathRetriever, + RetrievalStrategy, + DualPathRetrieverConfig, + SparseBM25Config, + FusionConfig, + GraphRelationRecallConfig, +) +from ..storage import MetadataStore, GraphStore, VectorStore + +logger = get_logger("A_Memorix.PersonProfileService") + + +class PersonProfileService: + """人物画像聚合/刷新服务。""" + + def __init__( + self, + metadata_store: MetadataStore, + graph_store: Optional[GraphStore] = None, + vector_store: Optional[VectorStore] = None, + embedding_manager: Optional[EmbeddingAPIAdapter] = None, + sparse_index: Any = None, + plugin_config: Optional[dict] = None, + retriever: Optional[DualPathRetriever] = None, + ): + self.metadata_store = metadata_store + self.graph_store = graph_store + self.vector_store = vector_store + self.embedding_manager = embedding_manager + self.sparse_index = sparse_index + self.plugin_config = plugin_config or {} + self.retriever = retriever or self._build_retriever() + + def _cfg(self, key: str, default: Any = None) -> Any: + """读取嵌套配置。""" + if not isinstance(self.plugin_config, dict): + return default + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + def _build_retriever(self) -> Optional[DualPathRetriever]: + """按需构建检索器(无依赖时返回 None)。""" + if not all( + [ + self.vector_store is not None, + self.graph_store is not None, + self.metadata_store is not None, + self.embedding_manager is not None, + ] + ): + return None + try: + sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} + fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {} + graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {} + if not isinstance(sparse_cfg_raw, dict): + sparse_cfg_raw = {} + if not isinstance(fusion_cfg_raw, dict): + fusion_cfg_raw = {} + if not isinstance(graph_recall_cfg_raw, dict): + graph_recall_cfg_raw = {} + + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + fusion_cfg = FusionConfig(**fusion_cfg_raw) + graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) + config = DualPathRetrieverConfig( + top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)), + top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)), + top_k_final=int(self._cfg("retrieval.top_k_final", 10)), + alpha=float(self._cfg("retrieval.alpha", 0.5)), + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), + ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), + enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), + retrieval_strategy=RetrievalStrategy.DUAL_PATH, + debug=bool(self._cfg("advanced.debug", False)), + sparse=sparse_cfg, + fusion=fusion_cfg, + graph_recall=graph_recall_cfg, + ) + return DualPathRetriever( + vector_store=self.vector_store, + graph_store=self.graph_store, + metadata_store=self.metadata_store, + embedding_manager=self.embedding_manager, + 
sparse_index=self.sparse_index, + config=config, + ) + except Exception as e: + logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}") + return None + + @staticmethod + def resolve_person_id(identifier: str) -> str: + """按 person_id 或姓名/别名解析 person_id。""" + if not identifier: + return "" + key = str(identifier).strip() + if not key: + return "" + + if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key): + return key.lower() + + try: + record = ( + PersonInfo.select(PersonInfo.person_id) + .where((PersonInfo.person_name == key) | (PersonInfo.nickname == key)) + .first() + ) + if record and record.person_id: + return str(record.person_id) + except Exception: + pass + + try: + record = ( + PersonInfo.select(PersonInfo.person_id) + .where(PersonInfo.group_nick_name.contains(key)) + .first() + ) + if record and record.person_id: + return str(record.person_id) + except Exception: + pass + + return "" + + def _parse_group_nicks(self, raw_value: Any) -> List[str]: + if not raw_value: + return [] + if isinstance(raw_value, list): + items = raw_value + else: + try: + items = json.loads(raw_value) + except Exception: + return [] + names: List[str] = [] + for item in items: + if isinstance(item, dict): + value = str(item.get("group_nick_name", "")).strip() + if value: + names.append(value) + elif isinstance(item, str): + value = item.strip() + if value: + names.append(value) + return names + + def _parse_memory_traits(self, raw_value: Any) -> List[str]: + if not raw_value: + return [] + try: + values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value + except Exception: + return [] + if not isinstance(values, list): + return [] + traits: List[str] = [] + for item in values: + text = str(item).strip() + if not text: + continue + if ":" in text: + parts = text.split(":") + if len(parts) >= 3: + content = ":".join(parts[1:-1]).strip() + if content: + traits.append(content) + continue + traits.append(text) + return traits[:10] + + def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]: + """获取人物别名集合、主展示名、记忆特征。""" + aliases: List[str] = [] + primary_name = "" + memory_traits: List[str] = [] + if not person_id: + return aliases, primary_name, memory_traits + try: + record = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + if not record: + return aliases, primary_name, memory_traits + person_name = str(getattr(record, "person_name", "") or "").strip() + nickname = str(getattr(record, "nickname", "") or "").strip() + group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None)) + memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None)) + + primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id + + candidates = [person_name, nickname] + group_nicks + seen = set() + for item in candidates: + norm = str(item or "").strip() + if not norm or norm in seen: + continue + seen.add(norm) + aliases.append(norm) + except Exception as e: + logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}") + return aliases, primary_name, memory_traits + + def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]: + relation_by_hash: Dict[str, Dict[str, Any]] = {} + for alias in aliases: + for rel in self.metadata_store.get_relations(subject=alias): + h = str(rel.get("hash", "")) + if h: + relation_by_hash[h] = rel + for rel in self.metadata_store.get_relations(object=alias): + h = str(rel.get("hash", "")) + if h: + relation_by_hash[h] = 
rel + + relations = list(relation_by_hash.values()) + relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True) + relations = relations[: max(1, int(limit))] + + edges: List[Dict[str, Any]] = [] + for rel in relations: + edges.append( + { + "hash": str(rel.get("hash", "")), + "subject": str(rel.get("subject", "")), + "predicate": str(rel.get("predicate", "")), + "object": str(rel.get("object", "")), + "confidence": float(rel.get("confidence", 1.0) or 1.0), + } + ) + return edges + + async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]: + alias_queries = [a for a in aliases if a] + if not alias_queries: + return [] + + if self.retriever is None: + # 回退:无检索器时只做简单内容匹配 + fallback: List[Dict[str, Any]] = [] + seen_hash = set() + for alias in alias_queries: + for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]: + h = str(para.get("hash", "")) + if not h or h in seen_hash: + continue + seen_hash.add(h) + fallback.append( + { + "hash": h, + "type": "paragraph", + "score": 0.0, + "content": str(para.get("content", ""))[:180], + "metadata": {}, + } + ) + return fallback[:top_k] + + per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries)))) + seen_hash = set() + evidence: List[Dict[str, Any]] = [] + for alias in alias_queries: + try: + results = await self.retriever.retrieve(alias, top_k=per_alias_top_k) + except Exception as e: + logger.warning(f"向量证据召回失败: alias={alias}, err={e}") + continue + for item in results: + h = str(getattr(item, "hash_value", "") or "") + if not h or h in seen_hash: + continue + seen_hash.add(h) + evidence.append( + { + "hash": h, + "type": str(getattr(item, "result_type", "")), + "score": float(getattr(item, "score", 0.0) or 0.0), + "content": str(getattr(item, "content", "") or "")[:220], + "metadata": dict(getattr(item, "metadata", {}) or {}), + } + ) + evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True) + return evidence[:top_k] + + def _build_profile_text( + self, + person_id: str, + primary_name: str, + aliases: List[str], + relation_edges: List[Dict[str, Any]], + vector_evidence: List[Dict[str, Any]], + memory_traits: List[str], + ) -> str: + """基于证据构建画像文本(供 LLM 上下文注入)。""" + lines: List[str] = [] + lines.append(f"人物ID: {person_id}") + if primary_name: + lines.append(f"主称呼: {primary_name}") + if aliases: + lines.append(f"别名: {', '.join(aliases[:8])}") + if memory_traits: + lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}") + + if relation_edges: + lines.append("关系证据:") + for rel in relation_edges[:6]: + s = rel.get("subject", "") + p = rel.get("predicate", "") + o = rel.get("object", "") + conf = float(rel.get("confidence", 0.0)) + lines.append(f"- {s} {p} {o} (conf={conf:.2f})") + + if vector_evidence: + lines.append("向量证据摘要:") + for item in vector_evidence[:4]: + content = str(item.get("content", "")).strip() + if content: + lines.append(f"- {content}") + + if len(lines) <= 2: + lines.append("暂无足够证据形成稳定画像。") + + return "\n".join(lines) + + @staticmethod + def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool: + if not snapshot: + return True + now = time.time() + expires_at = snapshot.get("expires_at") + if expires_at is not None: + try: + return now >= float(expires_at) + except Exception: + return True + updated_at = float(snapshot.get("updated_at") or 0.0) + return (now - updated_at) >= ttl_seconds + + def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]: 
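+        # 来源优先级:存在非空手工覆盖时 profile_source 为 "manual_override",
+        # 否则保持 "auto_snapshot";自动生成的画像文本始终保留在 auto_profile_text 字段。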
+ """将手工覆盖并入画像结果(覆盖 profile_text,同时保留 auto_profile_text)。""" + payload = dict(profile_payload or {}) + auto_text = str(payload.get("profile_text", "") or "") + payload["auto_profile_text"] = auto_text + payload["has_manual_override"] = False + payload["manual_override_text"] = "" + payload["override_updated_at"] = None + payload["override_updated_by"] = "" + payload["profile_source"] = "auto_snapshot" + + if not person_id or self.metadata_store is None: + return payload + + try: + override = self.metadata_store.get_person_profile_override(person_id) + except Exception as e: + logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}") + return payload + + if not override: + return payload + + manual_text = str(override.get("override_text", "") or "").strip() + if not manual_text: + return payload + + payload["has_manual_override"] = True + payload["manual_override_text"] = manual_text + payload["override_updated_at"] = override.get("updated_at") + payload["override_updated_by"] = str(override.get("updated_by", "") or "") + payload["profile_text"] = manual_text + payload["profile_source"] = "manual_override" + return payload + + async def query_person_profile( + self, + person_id: str = "", + person_keyword: str = "", + top_k: int = 12, + ttl_seconds: float = 6 * 3600, + force_refresh: bool = False, + source_note: str = "", + ) -> Dict[str, Any]: + """查询或刷新人物画像。""" + pid = str(person_id or "").strip() + if not pid and person_keyword: + pid = self.resolve_person_id(person_keyword) + + if not pid: + return { + "success": False, + "error": "person_id 无效,且未能通过别名解析", + } + + latest = self.metadata_store.get_latest_person_profile_snapshot(pid) + if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds): + aliases, primary_name, _ = self.get_person_aliases(pid) + payload = { + "success": True, + "person_id": pid, + "person_name": primary_name, + "from_cache": True, + **(latest or {}), + } + if aliases and not payload.get("aliases"): + payload["aliases"] = aliases + return { + **self._apply_manual_override(pid, payload), + } + + aliases, primary_name, memory_traits = self.get_person_aliases(pid) + if not aliases and person_keyword: + aliases = [person_keyword.strip()] + primary_name = person_keyword.strip() + relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2)) + vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k)) + + evidence_ids = [ + str(item.get("hash", "")) + for item in (relation_edges + vector_evidence) + if str(item.get("hash", "")).strip() + ] + dedup_ids: List[str] = [] + seen = set() + for item in evidence_ids: + if item in seen: + continue + seen.add(item) + dedup_ids.append(item) + + profile_text = self._build_profile_text( + person_id=pid, + primary_name=primary_name, + aliases=aliases, + relation_edges=relation_edges, + vector_evidence=vector_evidence, + memory_traits=memory_traits, + ) + + expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None + snapshot = self.metadata_store.upsert_person_profile_snapshot( + person_id=pid, + profile_text=profile_text, + aliases=aliases, + relation_edges=relation_edges, + vector_evidence=vector_evidence, + evidence_ids=dedup_ids, + expires_at=expires_at, + source_note=source_note, + ) + payload = { + "success": True, + "person_id": pid, + "person_name": primary_name, + "from_cache": False, + **snapshot, + } + return { + **self._apply_manual_override(pid, payload), + } + + @staticmethod + def format_persona_profile_block(profile: Dict[str, Any]) -> 
str: + """格式化给 replyer 的注入块。""" + if not profile or not profile.get("success"): + return "" + text = str(profile.get("profile_text", "") or "").strip() + if not text: + return "" + return ( + "【人物画像-内部参考】\n" + f"{text}\n" + "仅供内部推理,不要向用户逐字复述。" + ) diff --git a/plugins/A_memorix/core/utils/plugin_id_policy.py b/plugins/A_memorix/core/utils/plugin_id_policy.py new file mode 100644 index 00000000..8e730e12 --- /dev/null +++ b/plugins/A_memorix/core/utils/plugin_id_policy.py @@ -0,0 +1,27 @@ +"""Plugin ID matching policy for A_Memorix.""" + +from __future__ import annotations + +from typing import Any + + +class PluginIdPolicy: + """Centralized plugin id normalization/matching policy.""" + + CANONICAL_ID = "a_memorix" + + @classmethod + def normalize(cls, plugin_id: Any) -> str: + if not isinstance(plugin_id, str): + return "" + return plugin_id.strip().lower() + + @classmethod + def is_target_plugin_id(cls, plugin_id: Any) -> bool: + normalized = cls.normalize(plugin_id) + if not normalized: + return False + if normalized == cls.CANONICAL_ID: + return True + return normalized.split(".")[-1] == cls.CANONICAL_ID + diff --git a/plugins/A_memorix/core/utils/quantization.py b/plugins/A_memorix/core/utils/quantization.py new file mode 100644 index 00000000..4e84f977 --- /dev/null +++ b/plugins/A_memorix/core/utils/quantization.py @@ -0,0 +1,344 @@ +""" +向量量化工具模块 + +提供向量量化与反量化功能,用于压缩存储空间。 +""" + +import numpy as np +from enum import Enum +from typing import Tuple, Union + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.Quantization") + + +class QuantizationType(Enum): + """量化类型枚举""" + FLOAT32 = "float32" # 无量化 + INT8 = "int8" # 标量量化(8位整数) + PQ = "pq" # 乘积量化(Product Quantization) + + +def quantize_vector( + vector: np.ndarray, + quant_type: QuantizationType = QuantizationType.INT8, +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + 量化向量 + + Args: + vector: 输入向量(float32) + quant_type: 量化类型 + + Returns: + 量化后的向量: + - INT8: int8向量 + - PQ: (编码向量, 聚类中心) 元组 + """ + if quant_type == QuantizationType.FLOAT32: + return vector.astype(np.float32) + + elif quant_type == QuantizationType.INT8: + return _scalar_quantize_int8(vector) + + elif quant_type == QuantizationType.PQ: + return _product_quantize(vector) + + else: + raise ValueError(f"不支持的量化类型: {quant_type}") + + +def dequantize_vector( + quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], + quant_type: QuantizationType = QuantizationType.INT8, + original_shape: Tuple[int, ...] 
= None, +) -> np.ndarray: + """ + 反量化向量 + + Args: + quantized_vector: 量化后的向量 + quant_type: 量化类型 + original_shape: 原始向量形状(用于PQ) + + Returns: + 反量化后的向量(float32) + """ + if quant_type == QuantizationType.FLOAT32: + return quantized_vector.astype(np.float32) + + elif quant_type == QuantizationType.INT8: + return _scalar_dequantize_int8(quantized_vector) + + elif quant_type == QuantizationType.PQ: + if not isinstance(quantized_vector, tuple): + raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)") + return _product_dequantize(quantized_vector[0], quantized_vector[1]) + + else: + raise ValueError(f"不支持的量化类型: {quant_type}") + + +def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray: + """ + 标量量化:float32 -> int8 + + 将向量归一化到 [0, 255] 范围,然后映射到 int8 + + Args: + vector: 输入向量 + + Returns: + 量化后的 int8 向量 + """ + # 计算最小最大值 + min_val = np.min(vector) + max_val = np.max(vector) + + # 避免除零 + if max_val == min_val: + return np.zeros_like(vector, dtype=np.int8) + + # 归一化到 [0, 255] + normalized = (vector - min_val) / (max_val - min_val) * 255 + + # 映射到 [-128, 127] 并转换为 int8 + # np.round might return float, minus 128 then cast + quantized = np.round(normalized - 128.0).astype(np.int8) + + # 存储归一化参数(用于反量化) + # 在实际存储中,这些参数需要单独保存 + # 这里为了简单,我们使用一个全局字典来模拟 + if not hasattr(_scalar_quantize_int8, "_params"): + _scalar_quantize_int8._params = {} + + vector_id = id(vector) + _scalar_quantize_int8._params[vector_id] = (min_val, max_val) + + return quantized + + +def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray: + """ + 标量反量化:int8 -> float32 + + Args: + quantized: 量化后的 int8 向量 + + Returns: + 反量化后的 float32 向量 + """ + # 计算归一化参数(如果提供了) + # 在实际应用中,min_val 和 max_val 应该被保存 + if not hasattr(_scalar_dequantize_int8, "_params"): + # 默认假设范围是 [-1, 1] + return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0 + + # 尝试查找参数 (这里只是演示逻辑,实际应从存储中读取) + # return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min + return (quantized.astype(np.float32) + 128.0) / 255.0 + + +def quantize_matrix( + matrix: np.ndarray, + quant_type: QuantizationType = QuantizationType.INT8, +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + 量化矩阵(批量量化向量) + + Args: + matrix: 输入矩阵(N x D,每行是一个向量) + quant_type: 量化类型 + + Returns: + 量化后的矩阵 + """ + if quant_type == QuantizationType.FLOAT32: + return matrix.astype(np.float32) + + elif quant_type == QuantizationType.INT8: + # 对整个矩阵进行全局归一化 + min_val = np.min(matrix) + max_val = np.max(matrix) + + if max_val == min_val: + return np.zeros_like(matrix, dtype=np.int8) + + # 归一化到 [0, 255] + normalized = (matrix - min_val) / (max_val - min_val) * 255 + quantized = np.round(normalized).astype(np.int8) + + return quantized + + else: + raise ValueError(f"不支持的量化类型: {quant_type}") + + +def dequantize_matrix( + quantized_matrix: np.ndarray, + quant_type: QuantizationType = QuantizationType.INT8, + min_val: float = None, + max_val: float = None, +) -> np.ndarray: + """ + 反量化矩阵 + + Args: + quantized_matrix: 量化后的矩阵 + quant_type: 量化类型 + min_val: 归一化最小值(int8反量化需要) + max_val: 归一化最大值(int8反量化需要) + + Returns: + 反量化后的矩阵 + """ + if quant_type == QuantizationType.FLOAT32: + return quantized_matrix.astype(np.float32) + + elif quant_type == QuantizationType.INT8: + # 使用提供的归一化参数反量化 + if min_val is None or max_val is None: + # 默认假设范围是 [0, 255] -> [-1, 1] + return quantized_matrix.astype(np.float32) / 127.0 + else: + # 恢复到原始范围 + normalized = quantized_matrix.astype(np.float32) / 255.0 + return normalized * (max_val - min_val) + min_val + + else: + raise ValueError(f"不支持的量化类型: {quant_type}") + 
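As a usage illustration of the helpers above (a minimal sketch, not taken from the patch: it assumes the plugin package is importable as `plugins.A_memorix` with the host project root on sys.path, and that the caller keeps the matrix-level min/max needed for dequantization):

    import numpy as np

    from plugins.A_memorix.core.utils.quantization import (
        QuantizationType,
        dequantize_matrix,
        quantize_matrix,
    )

    matrix = np.random.rand(100, 384).astype(np.float32)
    min_val, max_val = float(matrix.min()), float(matrix.max())

    # Compress: float32 -> int8 using global min/max normalization over the whole matrix
    packed = quantize_matrix(matrix, QuantizationType.INT8)

    # Restore: pass the recorded min/max, otherwise dequantize_matrix falls back to a default range
    restored = dequantize_matrix(packed, QuantizationType.INT8, min_val=min_val, max_val=max_val)
    assert restored.shape == matrix.shape

As the module's own comments note, the normalization parameters have to be persisted alongside the quantized data in real storage; they are kept in local variables here only for brevity.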
+ +def estimate_memory_reduction( + num_vectors: int, + dimension: int, + from_type: QuantizationType, + to_type: QuantizationType, +) -> Tuple[float, float]: + """ + 估算内存节省量 + + Args: + num_vectors: 向量数量 + dimension: 向量维度 + from_type: 原始量化类型 + to_type: 目标量化类型 + + Returns: + (原始大小MB, 量化后大小MB, 节省比例) + """ + # 计算每个向量占用的字节数 + bytes_per_element = { + QuantizationType.FLOAT32: 4, + QuantizationType.INT8: 1, + QuantizationType.PQ: 0.25, # 假设压缩到1/4 + } + + original_bytes = num_vectors * dimension * bytes_per_element[from_type] + quantized_bytes = num_vectors * dimension * bytes_per_element[to_type] + + original_mb = original_bytes / 1024 / 1024 + quantized_mb = quantized_bytes / 1024 / 1024 + reduction_ratio = (original_bytes - quantized_bytes) / original_bytes + + return original_mb, quantized_mb, reduction_ratio + + +def estimate_compression_stats( + num_vectors: int, + dimension: int, + quant_type: QuantizationType, +) -> dict: + """ + 估算压缩统计信息 + + Args: + num_vectors: 向量数量 + dimension: 向量维度 + quant_type: 量化类型 + + Returns: + 统计信息字典 + """ + original_mb, quantized_mb, ratio = estimate_memory_reduction( + num_vectors, dimension, QuantizationType.FLOAT32, quant_type + ) + + return { + "num_vectors": num_vectors, + "dimension": dimension, + "quantization_type": quant_type.value, + "original_size_mb": round(original_mb, 2), + "quantized_size_mb": round(quantized_mb, 2), + "saved_mb": round(original_mb - quantized_mb, 2), + "compression_ratio": round(ratio * 100, 2), + } + + +def _product_quantize( + vector: np.ndarray, m: int = 8, k: int = 256 +) -> Tuple[np.ndarray, np.ndarray]: + """ + 乘积量化 (PQ) 简化实现 + + Args: + vector: 输入向量 (D,) + m: 子空间数量 + k: 每个子空间的聚类中心数 + + Returns: + (编码后的向量, 聚类中心) + """ + d = vector.shape[0] + if d % m != 0: + raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除") + + ds = d // m # 子空间维度 + codes = np.zeros(m, dtype=np.uint8) + centroids = np.zeros((m, k, ds), dtype=np.float32) + + # 这里采用一种简化的 PQ:不进行 K-means 训练, + # 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练) + # 为了演示,我们直接将子空间切分为 k 份进行量化 + for i in range(m): + sub_vec = vector[i * ds : (i + 1) * ds] + # 简化:假定每个子空间的取值范围并划分 + # 实际 PQ 应使用 k-means 产生的 centroids + # 这里为演示创建一个随机 codebook 并找到最接近的核心 + sub_min, sub_max = np.min(sub_vec), np.max(sub_vec) + if sub_max == sub_min: + linspace = np.zeros(k) + else: + linspace = np.linspace(sub_min, sub_max, k) + + for j in range(k): + centroids[i, j, :] = linspace[j] + + # 编码:这里简化为取子空间均值找最接近的 centroid + sub_mean = np.mean(sub_vec) + code = np.argmin(np.abs(linspace - sub_mean)) + codes[i] = code + + return codes, centroids + + +def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray: + """ + PQ 反量化 + + Args: + codes: 编码向量 (M,) + centroids: 聚类中心 (M, K, DS) + + Returns: + 恢复后的向量 (D,) + """ + m, k, ds = centroids.shape + vector = np.zeros(m * ds, dtype=np.float32) + + for i in range(m): + code = codes[i] + vector[i * ds : (i + 1) * ds] = centroids[i, code, :] + + return vector diff --git a/plugins/A_memorix/core/utils/relation_query.py b/plugins/A_memorix/core/utils/relation_query.py new file mode 100644 index 00000000..ffde9cac --- /dev/null +++ b/plugins/A_memorix/core/utils/relation_query.py @@ -0,0 +1,121 @@ +"""关系查询规格解析工具。""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class RelationQuerySpec: + raw: str + is_structured: bool + subject: Optional[str] + predicate: Optional[str] + object: Optional[str] + error: Optional[str] = None + + +_NATURAL_LANGUAGE_PATTERN = re.compile( + 
r"(^\s*(what|who|which|how|why|when|where)\b|" + r"\?|?|" + r"\b(relation|related|between)\b|" + r"(什么关系|有哪些关系|之间|关联))", + re.IGNORECASE, +) + + +def _looks_like_natural_language(raw: str) -> bool: + text = str(raw or "").strip() + if not text: + return False + return _NATURAL_LANGUAGE_PATTERN.search(text) is not None + + +def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec: + raw = str(relation_spec or "").strip() + if not raw: + return RelationQuerySpec( + raw=raw, + is_structured=False, + subject=None, + predicate=None, + object=None, + error="empty", + ) + + if "|" in raw: + parts = [p.strip() for p in raw.split("|")] + if len(parts) < 2: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=None, + predicate=None, + object=None, + error="invalid_pipe_format", + ) + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0] or None, + predicate=parts[1] or None, + object=parts[2] if len(parts) > 2 and parts[2] else None, + ) + + if "->" in raw: + parts = [p.strip() for p in raw.split("->") if p.strip()] + if len(parts) >= 3: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0], + predicate=parts[1], + object=parts[2], + ) + if len(parts) == 2: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0], + predicate=None, + object=parts[1], + ) + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=None, + predicate=None, + object=None, + error="invalid_arrow_format", + ) + + if _looks_like_natural_language(raw): + return RelationQuerySpec( + raw=raw, + is_structured=False, + subject=None, + predicate=None, + object=None, + ) + + # 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。 + # 两词形式过于模糊,不再视为结构化关系查询。 + parts = raw.split() + if len(parts) == 3: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0], + predicate=parts[1], + object=parts[2], + ) + + return RelationQuerySpec( + raw=raw, + is_structured=False, + subject=None, + predicate=None, + object=None, + ) diff --git a/plugins/A_memorix/core/utils/relation_write_service.py b/plugins/A_memorix/core/utils/relation_write_service.py new file mode 100644 index 00000000..b73e1260 --- /dev/null +++ b/plugins/A_memorix/core/utils/relation_write_service.py @@ -0,0 +1,164 @@ +""" +统一关系写入与关系向量化服务。 + +规则: +1. 元数据是主数据源,向量是从索引。 +2. 关系先写 metadata,再写向量。 +3. 
向量失败不回滚 metadata,依赖状态机与回填任务修复。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from src.common.logger import get_logger + + +logger = get_logger("A_Memorix.RelationWriteService") + + +@dataclass +class RelationWriteResult: + hash_value: str + vector_written: bool + vector_already_exists: bool + vector_state: str + + +class RelationWriteService: + """关系写入收口服务。""" + + ERROR_MAX_LEN = 500 + + def __init__( + self, + metadata_store: Any, + graph_store: Any, + vector_store: Any, + embedding_manager: Any, + ): + self.metadata_store = metadata_store + self.graph_store = graph_store + self.vector_store = vector_store + self.embedding_manager = embedding_manager + + @staticmethod + def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str: + s = str(subject or "").strip() + p = str(predicate or "").strip() + o = str(obj or "").strip() + # 双表达:兼容关键词检索与自然语言问句 + return f"{s} {p} {o}\n{s}和{o}的关系是{p}" + + async def ensure_relation_vector( + self, + hash_value: str, + subject: str, + predicate: str, + obj: str, + *, + max_error_len: int = ERROR_MAX_LEN, + ) -> RelationWriteResult: + """ + 为已有关系确保向量存在并更新状态。 + """ + if hash_value in self.vector_store: + self.metadata_store.set_relation_vector_state(hash_value, "ready") + return RelationWriteResult( + hash_value=hash_value, + vector_written=False, + vector_already_exists=True, + vector_state="ready", + ) + + self.metadata_store.set_relation_vector_state(hash_value, "pending") + try: + vector_text = self.build_relation_vector_text(subject, predicate, obj) + embedding = await self.embedding_manager.encode(vector_text) + self.vector_store.add( + vectors=embedding.reshape(1, -1), + ids=[hash_value], + ) + self.metadata_store.set_relation_vector_state(hash_value, "ready") + logger.info( + "metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s", + hash_value[:16], + ) + return RelationWriteResult( + hash_value=hash_value, + vector_written=True, + vector_already_exists=False, + vector_state="ready", + ) + except ValueError: + # 向量已存在冲突,按成功处理 + self.metadata_store.set_relation_vector_state(hash_value, "ready") + return RelationWriteResult( + hash_value=hash_value, + vector_written=False, + vector_already_exists=True, + vector_state="ready", + ) + except Exception as e: + err = str(e)[:max_error_len] + self.metadata_store.set_relation_vector_state( + hash_value, + "failed", + error=err, + bump_retry=True, + ) + logger.warning( + "metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s", + hash_value[:16], + err, + ) + return RelationWriteResult( + hash_value=hash_value, + vector_written=False, + vector_already_exists=False, + vector_state="failed", + ) + + async def upsert_relation_with_vector( + self, + subject: str, + predicate: str, + obj: str, + confidence: float = 1.0, + source_paragraph: str = "", + metadata: Optional[Dict[str, Any]] = None, + *, + write_vector: bool = True, + ) -> RelationWriteResult: + """ + 统一关系写入: + 1) 写 metadata relation + 2) 写 graph edge relation_hash + 3) 按需写 relation vector + """ + rel_hash = self.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + confidence=confidence, + source_paragraph=source_paragraph, + metadata=metadata or {}, + ) + self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) + + if not write_vector: + self.metadata_store.set_relation_vector_state(rel_hash, "none") + return RelationWriteResult( + 
hash_value=rel_hash, + vector_written=False, + vector_already_exists=False, + vector_state="none", + ) + + return await self.ensure_relation_vector( + hash_value=rel_hash, + subject=subject, + predicate=predicate, + obj=obj, + ) diff --git a/plugins/A_memorix/core/utils/runtime_self_check.py b/plugins/A_memorix/core/utils/runtime_self_check.py new file mode 100644 index 00000000..36a2cf7e --- /dev/null +++ b/plugins/A_memorix/core/utils/runtime_self_check.py @@ -0,0 +1,197 @@ +"""Runtime self-check helpers for A_Memorix.""" + +from __future__ import annotations + +import time +from typing import Any, Dict, Optional + +import numpy as np + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.RuntimeSelfCheck") + +_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check" + + +def _safe_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except Exception: + return int(default) + + +def _get_config_value(config: Any, key: str, default: Any = None) -> Any: + getter = getattr(config, "get_config", None) + if callable(getter): + return getter(key, default) + + current: Any = config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + +def _build_report( + *, + ok: bool, + code: str, + message: str, + configured_dimension: int, + vector_store_dimension: int, + detected_dimension: int, + encoded_dimension: int, + elapsed_ms: float, + sample_text: str, +) -> Dict[str, Any]: + return { + "ok": bool(ok), + "code": str(code or "").strip(), + "message": str(message or "").strip(), + "configured_dimension": int(configured_dimension), + "vector_store_dimension": int(vector_store_dimension), + "detected_dimension": int(detected_dimension), + "encoded_dimension": int(encoded_dimension), + "elapsed_ms": float(elapsed_ms), + "sample_text": str(sample_text or ""), + "checked_at": time.time(), + } + + +async def run_embedding_runtime_self_check( + *, + config: Any, + vector_store: Optional[Any], + embedding_manager: Optional[Any], + sample_text: str = _DEFAULT_SAMPLE_TEXT, +) -> Dict[str, Any]: + """Probe the real embedding path and compare dimensions with runtime storage.""" + configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0) + vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0) + + if vector_store is None or embedding_manager is None: + return _build_report( + ok=False, + code="runtime_components_missing", + message="vector_store 或 embedding_manager 未初始化", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=0, + encoded_dimension=0, + elapsed_ms=0.0, + sample_text=sample_text, + ) + + start = time.perf_counter() + detected_dimension = 0 + encoded_dimension = 0 + try: + detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0) + encoded = await embedding_manager.encode(sample_text) + if isinstance(encoded, np.ndarray): + encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1]) + else: + encoded_dimension = len(encoded) if encoded is not None else 0 + except Exception as exc: + elapsed_ms = (time.perf_counter() - start) * 1000.0 + logger.warning("embedding runtime self-check failed: %s", exc) + return _build_report( + ok=False, + code="embedding_probe_failed", + message=f"embedding probe failed: {exc}", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, 
+ detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + elapsed_ms = (time.perf_counter() - start) * 1000.0 + expected_dimension = vector_store_dimension or configured_dimension or detected_dimension + if expected_dimension <= 0: + return _build_report( + ok=False, + code="invalid_expected_dimension", + message="无法确定期望 embedding 维度", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + if encoded_dimension != expected_dimension: + msg = ( + "embedding 真实输出维度与当前向量存储不一致: " + f"expected={expected_dimension}, encoded={encoded_dimension}" + ) + logger.error(msg) + return _build_report( + ok=False, + code="embedding_dimension_mismatch", + message=msg, + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + return _build_report( + ok=True, + code="ok", + message="embedding runtime self-check passed", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + +async def ensure_runtime_self_check( + plugin_or_config: Any, + *, + force: bool = False, + sample_text: str = _DEFAULT_SAMPLE_TEXT, +) -> Dict[str, Any]: + """Run or reuse cached runtime self-check report.""" + if plugin_or_config is None: + return _build_report( + ok=False, + code="missing_plugin_or_config", + message="plugin/config unavailable", + configured_dimension=0, + vector_store_dimension=0, + detected_dimension=0, + encoded_dimension=0, + elapsed_ms=0.0, + sample_text=sample_text, + ) + + cache = getattr(plugin_or_config, "_runtime_self_check_report", None) + if isinstance(cache, dict) and cache and not force: + return cache + + report = await run_embedding_runtime_self_check( + config=getattr(plugin_or_config, "config", plugin_or_config), + vector_store=getattr(plugin_or_config, "vector_store", None) + if not isinstance(plugin_or_config, dict) + else plugin_or_config.get("vector_store"), + embedding_manager=getattr(plugin_or_config, "embedding_manager", None) + if not isinstance(plugin_or_config, dict) + else plugin_or_config.get("embedding_manager"), + sample_text=sample_text, + ) + try: + setattr(plugin_or_config, "_runtime_self_check_report", report) + except Exception: + pass + return report diff --git a/plugins/A_memorix/core/utils/search_postprocess.py b/plugins/A_memorix/core/utils/search_postprocess.py new file mode 100644 index 00000000..52688e08 --- /dev/null +++ b/plugins/A_memorix/core/utils/search_postprocess.py @@ -0,0 +1,90 @@ +"""Post-processing helpers for unified search execution.""" + +from __future__ import annotations + +from typing import Any, List, Tuple + +from .path_fallback_service import find_paths_from_query, to_retrieval_results + + +def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]: + """Deduplicate results by hash/content while preserving at least one entry.""" + if not results: + return [], 0 + + unique_results: List[Any] = [] + seen_hashes = set() + seen_contents = set() + + for item in results: + content = str(getattr(item, "content", "") or "").strip() + if not 
content: + continue + + hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content)) + if hash_value in seen_hashes: + continue + + is_dup = False + for seen in seen_contents: + if content in seen or seen in content: + is_dup = True + break + if is_dup: + continue + + seen_hashes.add(hash_value) + seen_contents.add(content) + unique_results.append(item) + + if not unique_results: + unique_results.append(results[0]) + + removed_count = max(0, len(results) - len(unique_results)) + return unique_results, removed_count + + +def maybe_apply_smart_path_fallback( + *, + query: str, + results: List[Any], + graph_store: Any, + metadata_store: Any, + enabled: bool, + threshold: float, + max_depth: int = 3, + max_paths: int = 5, +) -> Tuple[List[Any], bool, int]: + """Append indirect relation paths when semantic results are weak.""" + if not enabled or not str(query or "").strip(): + return results, False, 0 + if graph_store is None or metadata_store is None: + return results, False, 0 + + max_score = 0.0 + if results: + try: + max_score = float(getattr(results[0], "score", 0.0) or 0.0) + except Exception: + max_score = 0.0 + + if max_score >= float(threshold): + return results, False, 0 + + paths = find_paths_from_query( + query=query, + graph_store=graph_store, + metadata_store=metadata_store, + max_depth=max_depth, + max_paths=max_paths, + ) + if not paths: + return results, False, 0 + + path_results = to_retrieval_results(paths) + if not path_results: + return results, False, 0 + + merged = list(path_results) + list(results) + return merged, True, len(path_results) + diff --git a/plugins/A_memorix/core/utils/time_parser.py b/plugins/A_memorix/core/utils/time_parser.py new file mode 100644 index 00000000..8e577974 --- /dev/null +++ b/plugins/A_memorix/core/utils/time_parser.py @@ -0,0 +1,170 @@ +""" +时间解析工具。 + +约束: +1. 查询参数(Action/Command/Tool)仅接受结构化绝对时间: + - YYYY/MM/DD + - YYYY/MM/DD HH:mm +2. 
入库时允许更宽松格式(含时间戳、YYYY-MM-DD 等)。 +""" + +from __future__ import annotations + +import re +from datetime import datetime +from typing import Any, Dict, Optional, Tuple + + +_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$") +_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$") +_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$") + +_INGEST_FORMATS = [ + "%Y/%m/%d %H:%M:%S", + "%Y/%m/%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M", + "%Y/%m/%d", + "%Y-%m-%d", +] + +_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"} + + +def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float: + """解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。""" + text = str(value).strip() + if not text: + raise ValueError("时间不能为空") + + if _QUERY_DATE_RE.fullmatch(text): + dt = datetime.strptime(text, "%Y/%m/%d") + if is_end: + dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) + return dt.timestamp() + + if _QUERY_MINUTE_RE.fullmatch(text): + dt = datetime.strptime(text, "%Y/%m/%d %H:%M") + return dt.timestamp() + + raise ValueError( + f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm" + ) + + +def parse_query_time_range( + time_from: Optional[str], + time_to: Optional[str], +) -> Tuple[Optional[float], Optional[float]]: + """解析查询窗口并验证区间。""" + ts_from = ( + parse_query_datetime_to_timestamp(time_from, is_end=False) + if time_from + else None + ) + ts_to = ( + parse_query_datetime_to_timestamp(time_to, is_end=True) + if time_to + else None + ) + + if ts_from is not None and ts_to is not None and ts_from > ts_to: + raise ValueError("time_from 不能晚于 time_to") + + return ts_from, ts_to + + +def parse_ingest_datetime_to_timestamp( + value: Any, + is_end: bool = False, +) -> Optional[float]: + """解析入库时间,允许 timestamp/常见字符串格式。""" + if value is None: + return None + + if isinstance(value, (int, float)): + return float(value) + + text = str(value).strip() + if not text: + return None + + if _NUMERIC_RE.fullmatch(text): + return float(text) + + for fmt in _INGEST_FORMATS: + try: + dt = datetime.strptime(text, fmt) + if fmt in _INGEST_DATE_FORMATS and is_end: + dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) + return dt.timestamp() + except ValueError: + continue + + raise ValueError(f"无法解析时间: {text}") + + +def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """归一化 time_meta 到存储层字段。""" + if not time_meta: + return {} + + normalized: Dict[str, Any] = {} + + event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time")) + event_start = parse_ingest_datetime_to_timestamp( + time_meta.get("event_time_start"), + is_end=False, + ) + event_end = parse_ingest_datetime_to_timestamp( + time_meta.get("event_time_end"), + is_end=True, + ) + + time_range = time_meta.get("time_range") + if ( + isinstance(time_range, (list, tuple)) + and len(time_range) == 2 + ): + if event_start is None: + event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False) + if event_end is None: + event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True) + + if event_start is not None and event_end is not None and event_start > event_end: + raise ValueError("event_time_start 不能晚于 event_time_end") + + if event_time is not None: + normalized["event_time"] = event_time + if event_start is not None: + normalized["event_time_start"] = event_start + if event_end is not None: + normalized["event_time_end"] = event_end + + granularity = time_meta.get("time_granularity") + if granularity: + 
normalized["time_granularity"] = str(granularity) + else: + raw_time_values = [ + time_meta.get("event_time"), + time_meta.get("event_time_start"), + time_meta.get("event_time_end"), + ] + has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None) + normalized["time_granularity"] = "minute" if has_minute else "day" + + confidence = time_meta.get("time_confidence") + if confidence is not None: + normalized["time_confidence"] = float(confidence) + + return normalized + + +def format_timestamp(ts: Optional[float]) -> Optional[str]: + """将 timestamp 格式化为 YYYY/MM/DD HH:mm。""" + if ts is None: + return None + return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M") + diff --git a/plugins/A_memorix/plugin.py b/plugins/A_memorix/plugin.py new file mode 100644 index 00000000..56df45b9 --- /dev/null +++ b/plugins/A_memorix/plugin.py @@ -0,0 +1,207 @@ +"""A_Memorix SDK plugin entry.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Optional + +from maibot_sdk import MaiBotPlugin, Tool +from maibot_sdk.types import ToolParameterInfo, ToolParamType + +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel + + +def _tool_param(name: str, param_type: ToolParamType, description: str, required: bool) -> ToolParameterInfo: + return ToolParameterInfo(name=name, param_type=param_type, description=description, required=required) + + +class AMemorixPlugin(MaiBotPlugin): + def __init__(self) -> None: + super().__init__() + self._plugin_root = Path(__file__).resolve().parent + self._plugin_config: Dict[str, Any] = {} + self._kernel: Optional[SDKMemoryKernel] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + self._plugin_config = config or {} + if self._kernel is not None: + self._kernel.close() + self._kernel = None + + async def on_load(self): + await self._get_kernel() + + async def on_unload(self): + if self._kernel is not None: + self._kernel.close() + self._kernel = None + + async def _get_kernel(self) -> SDKMemoryKernel: + if self._kernel is None: + self._kernel = SDKMemoryKernel(plugin_root=self._plugin_root, config=self._plugin_config) + await self._kernel.initialize() + return self._kernel + + @Tool( + "search_memory", + description="搜索长期记忆", + parameters=[ + _tool_param("query", ToolParamType.STRING, "查询文本", False), + _tool_param("limit", ToolParamType.INTEGER, "返回条数", False), + _tool_param("mode", ToolParamType.STRING, "search/time/hybrid/episode/aggregate", False), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), + _tool_param("person_id", ToolParamType.STRING, "人物 ID", False), + _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), + _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + ], + ) + async def handle_search_memory( + self, + query: str = "", + limit: int = 5, + mode: str = "hybrid", + chat_id: str = "", + person_id: str = "", + time_start: float | None = None, + time_end: float | None = None, + **kwargs, + ): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.search_memory( + KernelSearchRequest( + query=query, + limit=limit, + mode=mode, + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + ) + + @Tool( + "ingest_summary", + description="写入聊天摘要到长期记忆", + parameters=[ + _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", True), + _tool_param("text", ToolParamType.STRING, 
"摘要文本", True), + _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), + _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + ], + ) + async def handle_ingest_summary( + self, + external_id: str, + chat_id: str, + text: str, + participants: Optional[List[str]] = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.ingest_summary( + external_id=external_id, + chat_id=chat_id, + text=text, + participants=participants, + time_start=time_start, + time_end=time_end, + tags=tags, + metadata=metadata, + ) + + @Tool( + "ingest_text", + description="写入普通长期记忆文本", + parameters=[ + _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), + _tool_param("source_type", ToolParamType.STRING, "来源类型", True), + _tool_param("text", ToolParamType.STRING, "原始文本", True), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), + _tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False), + ], + ) + async def handle_ingest_text( + self, + external_id: str, + source_type: str, + text: str, + chat_id: str = "", + person_ids: Optional[List[str]] = None, + participants: Optional[List[str]] = None, + timestamp: float | None = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + relations = kwargs.get("relations") + entities = kwargs.get("entities") + kernel = await self._get_kernel() + return await kernel.ingest_text( + external_id=external_id, + source_type=source_type, + text=text, + chat_id=chat_id, + person_ids=person_ids, + participants=participants, + timestamp=timestamp, + time_start=time_start, + time_end=time_end, + tags=tags, + metadata=metadata, + entities=entities, + relations=relations, + ) + + @Tool( + "get_person_profile", + description="获取人物画像", + parameters=[ + _tool_param("person_id", ToolParamType.STRING, "人物 ID", True), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), + _tool_param("limit", ToolParamType.INTEGER, "证据条数", False), + ], + ) + async def handle_get_person_profile(self, person_id: str, chat_id: str = "", limit: int = 10, **kwargs): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.get_person_profile(person_id=person_id, chat_id=chat_id, limit=limit) + + @Tool( + "maintain_memory", + description="维护长期记忆关系状态", + parameters=[ + _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore", True), + _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", True), + _tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False), + ], + ) + async def handle_maintain_memory( + self, + action: str, + target: str, + hours: float | None = None, + reason: str = "", + **kwargs, + ): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason) + + @Tool("memory_stats", description="获取长期记忆统计", parameters=[]) + async def handle_memory_stats(self, **kwargs): + _ = kwargs + kernel = await self._get_kernel() + return kernel.memory_stats() + + +def create_plugin(): + return AMemorixPlugin() diff --git a/plugins/A_memorix/scripts/convert_lpmm.py b/plugins/A_memorix/scripts/convert_lpmm.py new file mode 100644 index 00000000..5ff284fb --- /dev/null +++ b/plugins/A_memorix/scripts/convert_lpmm.py @@ -0,0 +1,535 @@ +#!/usr/bin/env python3 
+""" +LPMM 到 A_memorix 存储转换器 + +功能: +1. 读取 LPMM parquet 文件 (paragraph.parquet, entity.parquet, relation.parquet) +2. 读取 LPMM 图文件 (graph.graphml 或 graph_structure.pkl) +3. 直接写入 A_memorix 二进制 VectorStore 和稀疏 GraphStore +4. 绕过 Embedding 生成以节省 Token +""" + +import sys +import os +import json +import argparse +import asyncio +import pickle +import logging +from pathlib import Path +from typing import Dict, Any, List, Tuple +import numpy as np +import tomlkit + +# 设置路径 +current_dir = Path(__file__).resolve().parent +plugin_root = current_dir.parent +project_root = plugin_root.parent.parent +sys.path.insert(0, str(project_root)) + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式") + parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)") + parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录") + parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)") + parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)") + parser.add_argument( + "--skip-relation-vector-rebuild", + action="store_true", + help="跳过按关系元数据重建关系向量(默认开启)", + ) + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + sys.exit(0) + +# 设置日志 +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger("LPMM_Converter") + +try: + import networkx as nx + from scipy import sparse + import pyarrow.parquet as pq +except ImportError as e: + logger.error(f"缺少依赖: {e}") + logger.error("请安装: pip install pandas pyarrow networkx scipy") + sys.exit(1) + +try: + # 优先采取相对导入 (将插件根目录加入路径) + # 这样可以避免硬编码插件名称 (plugins.A_memorix) + if str(plugin_root) not in sys.path: + sys.path.insert(0, str(plugin_root)) + + from core.storage.vector_store import VectorStore + from core.storage.graph_store import GraphStore + from core.storage.metadata_store import MetadataStore + from core.storage import QuantizationType, SparseMatrixFormat + from core.embedding import create_embedding_api_adapter + from core.utils.relation_write_service import RelationWriteService + +except ImportError as e: + logger.error(f"无法导入 A_memorix 核心模块: {e}") + logger.error("请确保在正确的环境中运行,且已安装所有依赖。") + sys.exit(1) + + +class LPMMConverter: + def __init__( + self, + lpmm_data_dir: Path, + output_dir: Path, + dimension: int = 384, + batch_size: int = 1024, + rebuild_relation_vectors: bool = True, + ): + self.lpmm_dir = lpmm_data_dir + self.output_dir = output_dir + self.dimension = dimension + self.batch_size = max(1, int(batch_size)) + self.rebuild_relation_vectors = bool(rebuild_relation_vectors) + + self.vector_dir = output_dir / "vectors" + self.graph_dir = output_dir / "graph" + self.metadata_dir = output_dir / "metadata" + + self.vector_store = None + self.graph_store = None + self.metadata_store = None + self.embedding_manager = None + self.relation_write_service = None + # LPMM 原 ID -> A_memorix ID 映射(用于图重写) + self.id_mapping: Dict[str, str] = {} + + def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None: + """记录 ID 映射,兼容带/不带类型前缀两种格式。""" + if raw_id is None: + return + + raw = str(raw_id).strip() + if not raw: + return + + self.id_mapping[raw] = mapped_id + + prefix = f"{p_type}-" + if raw.startswith(prefix): + self.id_mapping[raw[len(prefix):]] = mapped_id + else: + 
self.id_mapping[prefix + raw] = mapped_id + + def _map_node_id(self, node: Any) -> str: + """将图节点 ID 映射到转换后的 A_memorix ID。""" + node_key = str(node) + return self.id_mapping.get(node_key, node_key) + + def initialize_stores(self): + """初始化空的 A_memorix 存储""" + logger.info(f"正在初始化存储于 {self.output_dir}...") + + # 初始化 VectorStore (A_memorix 默认使用 INT8 量化) + self.vector_store = VectorStore( + dimension=self.dimension, + quantization_type=QuantizationType.INT8, + data_dir=self.vector_dir + ) + self.vector_store.clear() # 清空旧数据 + + # 初始化 GraphStore (使用 CSR 格式) + self.graph_store = GraphStore( + matrix_format=SparseMatrixFormat.CSR, + data_dir=self.graph_dir + ) + self.graph_store.clear() + + # 初始化 MetadataStore + self.metadata_store = MetadataStore(data_dir=self.metadata_dir) + self.metadata_store.connect() + # 清空元数据表?理想情况下是的,但要小心。 + # 对于转换,我们假设是全新的开始或覆盖。 + # A_memorix 中的 MetadataStore 通常使用 SQLite。 + # 如果目录是新的,我们会依赖它创建新文件。 + if self.rebuild_relation_vectors: + self._init_relation_vector_service() + + def _load_plugin_config(self) -> Dict[str, Any]: + config_path = plugin_root / "config.toml" + if not config_path.exists(): + return {} + try: + with open(config_path, "r", encoding="utf-8") as f: + parsed = tomlkit.load(f) + return dict(parsed) if isinstance(parsed, dict) else {} + except Exception as e: + logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}") + return {} + + def _init_relation_vector_service(self) -> None: + if not self.rebuild_relation_vectors: + return + cfg = self._load_plugin_config() + emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} + if not isinstance(emb_cfg, dict): + emb_cfg = {} + try: + self.embedding_manager = create_embedding_api_adapter( + batch_size=int(emb_cfg.get("batch_size", 32)), + max_concurrent=int(emb_cfg.get("max_concurrent", 5)), + default_dimension=int(emb_cfg.get("dimension", self.dimension)), + model_name=str(emb_cfg.get("model_name", "auto")), + retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {}, + ) + self.relation_write_service = RelationWriteService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + except Exception as e: + self.embedding_manager = None + self.relation_write_service = None + logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}") + + async def _rebuild_relation_vectors(self) -> None: + if not self.rebuild_relation_vectors: + return + if self.relation_write_service is None: + logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。") + return + + rows = self.metadata_store.get_relations() + if not rows: + logger.info("未发现关系元数据,无需重建关系向量。") + return + + success = 0 + failed = 0 + skipped = 0 + for row in rows: + result = await self.relation_write_service.ensure_relation_vector( + hash_value=str(row["hash"]), + subject=str(row.get("subject", "")), + predicate=str(row.get("predicate", "")), + obj=str(row.get("object", "")), + ) + if result.vector_state == "ready": + if result.vector_written: + success += 1 + else: + skipped += 1 + else: + failed += 1 + + logger.info( + "关系向量重建完成: total=%s success=%s skipped=%s failed=%s", + len(rows), + success, + skipped, + failed, + ) + + @staticmethod + def _parse_relation_text(text: str) -> Tuple[str, str, str]: + raw = str(text or "").strip() + if not raw: + return "", "", "" + if "|" in raw: + parts = [p.strip() for p in raw.split("|") if p.strip()] + if len(parts) >= 3: + return parts[0], parts[1], parts[2] + if "->" in raw: + parts = [p.strip() for p in 
raw.split("->") if p.strip()] + if len(parts) >= 3: + return parts[0], parts[1], parts[2] + pieces = raw.split() + if len(pieces) >= 3: + return pieces[0], pieces[1], " ".join(pieces[2:]) + return "", "", "" + + def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int: + if not relation_path.exists(): + return 0 + + try: + parquet_file = pq.ParquetFile(relation_path) + except Exception as e: + logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}") + return 0 + + cols = set(parquet_file.schema_arrow.names) + has_triple_cols = {"subject", "predicate", "object"}.issubset(cols) + content_col = "str" if "str" in cols else ("content" if "content" in cols else "") + + imported_hashes = set() + imported = 0 + for record_batch in parquet_file.iter_batches(batch_size=self.batch_size): + df_batch = record_batch.to_pandas() + for _, row in df_batch.iterrows(): + subject = "" + predicate = "" + obj = "" + if has_triple_cols: + subject = str(row.get("subject", "") or "").strip() + predicate = str(row.get("predicate", "") or "").strip() + obj = str(row.get("object", "") or "").strip() + elif content_col: + subject, predicate, obj = self._parse_relation_text(row.get(content_col, "")) + + if not (subject and predicate and obj): + continue + + rel_hash = self.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + source_paragraph=None, + ) + if rel_hash in imported_hashes: + continue + imported_hashes.add(rel_hash) + self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) + try: + self.metadata_store.set_relation_vector_state(rel_hash, "none") + except Exception: + pass + imported += 1 + + return imported + + def convert_vectors(self): + """将 Parquet 向量转换为 VectorStore""" + # LPMM 默认文件名 + parquet_files = { + "paragraph": self.lpmm_dir / "paragraph.parquet", + "entity": self.lpmm_dir / "entity.parquet", + "relation": self.lpmm_dir / "relation.parquet" + } + + total_vectors = 0 + + for p_type, p_path in parquet_files.items(): + # 关系向量在当前脚本中无法保证与 MetadataStore 的关系记录一一对应, + # 直接导入会污染召回结果(命中后无法反查 relation 元数据)。 + if p_type == "relation": + relation_count = self._import_relation_metadata_from_parquet(p_path) + logger.warning( + "跳过 relation.parquet 向量导入(保持一致性);已导入关系元数据: %s", + relation_count, + ) + continue + + if not p_path.exists(): + logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。") + continue + + logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...") + try: + parquet_file = pq.ParquetFile(p_path) + total_rows = parquet_file.metadata.num_rows + if total_rows == 0: + logger.info(f"{p_path} 为空,跳过。") + continue + + # LPMM Schema: 'hash', 'embedding', 'str' + cols = parquet_file.schema_arrow.names + # 兼容性检查 + content_col = 'str' if 'str' in cols else 'content' + emb_col = 'embedding' + hash_col = 'hash' + + if content_col not in cols or emb_col not in cols: + logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}") + continue + + batch_columns = [content_col, emb_col] + if hash_col in cols: + batch_columns.append(hash_col) + + processed_rows = 0 + added_for_type = 0 + batch_idx = 0 + + for record_batch in parquet_file.iter_batches( + batch_size=self.batch_size, + columns=batch_columns, + ): + batch_idx += 1 + df_batch = record_batch.to_pandas() + + embeddings_list = [] + ids_list = [] + + # 同时处理元数据映射 + for _, row in df_batch.iterrows(): + processed_rows += 1 + content = row[content_col] + emb = row[emb_col] + + if content is None or (isinstance(content, float) and np.isnan(content)): + continue + content = str(content).strip() + if 
not content: + continue + + if emb is None or len(emb) == 0: + continue + + # 先写 MetadataStore,并使用其返回的真实 hash 作为向量 ID + # 保证检索返回 ID 可以直接反查元数据。 + store_id = None + if p_type == "paragraph": + store_id = self.metadata_store.add_paragraph( + content=content, + source="lpmm_import", + knowledge_type="factual", + ) + elif p_type == "entity": + store_id = self.metadata_store.add_entity(name=content) + else: + continue + + raw_hash = row[hash_col] if hash_col in df_batch.columns else None + if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)): + self._register_id_mapping(raw_hash, store_id, p_type) + + # 确保 embedding 是 numpy 数组 + emb_np = np.array(emb, dtype=np.float32) + if emb_np.shape[0] != self.dimension: + logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}") + continue + + embeddings_list.append(emb_np) + ids_list.append(store_id) + + if embeddings_list: + # 分批添加到向量存储 + vectors_np = np.stack(embeddings_list) + count = self.vector_store.add(vectors_np, ids_list) + added_for_type += count + total_vectors += count + + if batch_idx == 1 or batch_idx % 10 == 0: + logger.info( + f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}" + ) + + logger.info( + f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}" + ) + + except Exception as e: + logger.error(f"处理 {p_path} 时出错: {e}") + + # 提交向量存储 + self.vector_store.save() + logger.info(f"向量转换完成。总向量数: {total_vectors}") + + def convert_graph(self): + """将 LPMM 图转换为 GraphStore""" + # LPMM 默认文件名是 rag-graph.graphml + graph_files = [ + self.lpmm_dir / "rag-graph.graphml", + self.lpmm_dir / "graph.graphml", + self.lpmm_dir / "graph_structure.pkl" + ] + + nx_graph = None + + for g_path in graph_files: + if g_path.exists(): + logger.info(f"发现图文件: {g_path}") + try: + if g_path.suffix == ".graphml": + nx_graph = nx.read_graphml(g_path) + elif g_path.suffix == ".pkl": + with open(g_path, "rb") as f: + data = pickle.load(f) + # LPMM 可能会将图存储在包装类中 + if hasattr(data, "graph") and isinstance(data.graph, nx.Graph): + nx_graph = data.graph + elif isinstance(data, nx.Graph): + nx_graph = data + break + except Exception as e: + logger.error(f"加载 {g_path} 失败: {e}") + + if nx_graph is None: + logger.warning("未找到有效的图文件。跳过图转换。") + return + + logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。") + + # 1. 添加节点 + # LPMM 节点通常是哈希或带前缀的字符串。 + # 我们需要将它们映射到 A_memorix 格式。 + # 如果 LPMM 使用 "entity-HASH",则与 A_memorix 匹配。 + + nodes_to_add = [] + node_attrs = {} + + for node, attrs in nx_graph.nodes(data=True): + # 假设 LPMM 使用一致的命名 "entity-..." 或 "paragraph-..." + mapped_node = self._map_node_id(node) + nodes_to_add.append(mapped_node) + if attrs: + node_attrs[mapped_node] = attrs + + self.graph_store.add_nodes(nodes_to_add, node_attrs) + + # 2. 
添加边 + edges_to_add = [] + weights = [] + + for u, v, data in nx_graph.edges(data=True): + weight = data.get("weight", 1.0) + edges_to_add.append((self._map_node_id(u), self._map_node_id(v))) + weights.append(float(weight)) + + # 如果可能,将关系同步到 MetadataStore + # 但图的边并不总是包含关系谓词 + # 如果 LPMM 边数据有 'predicate',我们可以添加到元数据 + # 通常 LPMM 边是加权和,谓词信息可能在简单图中丢失 + + if edges_to_add: + self.graph_store.add_edges(edges_to_add, weights) + + self.graph_store.save() + logger.info("图转换完成。") + + def run(self): + self.initialize_stores() + self.convert_vectors() + self.convert_graph() + asyncio.run(self._rebuild_relation_vectors()) + self.vector_store.save() + self.graph_store.save() + self.metadata_store.close() + logger.info("所有转换成功完成。") + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + input_path = Path(args.input) + output_path = Path(args.output) + + if not input_path.exists(): + logger.error(f"输入目录不存在: {input_path}") + sys.exit(1) + + converter = LPMMConverter( + input_path, + output_path, + dimension=args.dim, + batch_size=args.batch_size, + rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild), + ) + converter.run() + +if __name__ == "__main__": + main() diff --git a/plugins/A_memorix/scripts/migrate_chat_history.py b/plugins/A_memorix/scripts/migrate_chat_history.py new file mode 100644 index 00000000..0fb0bfe1 --- /dev/null +++ b/plugins/A_memorix/scripts/migrate_chat_history.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import sqlite3 +import sys +from datetime import datetime +from pathlib import Path +from typing import Any, Dict + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" + +if str(WORKSPACE_ROOT) not in sys.path: + sys.path.insert(0, str(WORKSPACE_ROOT)) +if str(MAIBOT_ROOT) not in sys.path: + sys.path.insert(0, str(MAIBOT_ROOT)) + +from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="迁移 MaiBot chat_history 到 A_Memorix") + parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") + parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") + parser.add_argument("--limit", type=int, default=0, help="限制迁移条数,0 表示全部") + parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") + return parser.parse_args() + + +def _to_timestamp(value: Any) -> float | None: + if value is None: + return None + if isinstance(value, (int, float)): + return float(value) + text = str(value).strip() + if not text: + return None + try: + return datetime.fromisoformat(text).timestamp() + except ValueError: + return None + + +async def _main() -> int: + args = _parse_args() + db_path = Path(args.db_path).resolve() + if not db_path.exists(): + print(f"数据库不存在: {db_path}") + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + sql = """ + SELECT id, session_id, start_timestamp, end_timestamp, participants, theme, keywords, summary + FROM chat_history + ORDER BY id ASC + """ + if int(args.limit or 0) > 0: + sql += " LIMIT ?" 
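+        # With a positive --limit the SQL gains a "LIMIT ?" placeholder, so only this branch passes a parameter tuple to execute().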
+ rows = conn.execute(sql, (int(args.limit),)).fetchall() + else: + rows = conn.execute(sql).fetchall() + conn.close() + + print(f"chat_history 待处理: {len(rows)}") + if args.dry_run: + for row in rows[:5]: + print(f"- id={row['id']} session={row['session_id']} theme={row['theme']}") + return 0 + + kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) + await kernel.initialize() + migrated = 0 + skipped = 0 + for row in rows: + participants = json.loads(row["participants"]) if row["participants"] else [] + keywords = json.loads(row["keywords"]) if row["keywords"] else [] + theme = str(row["theme"] or "").strip() + summary = str(row["summary"] or "").strip() + text = f"主题:{theme}\n概括:{summary}".strip() + result: Dict[str, Any] = await kernel.ingest_summary( + external_id=f"chat_history:{row['id']}", + chat_id=str(row["session_id"] or ""), + text=text, + participants=participants, + time_start=_to_timestamp(row["start_timestamp"]), + time_end=_to_timestamp(row["end_timestamp"]), + tags=keywords, + metadata={"theme": theme, "source_row_id": int(row["id"])}, + ) + if result.get("stored_ids"): + migrated += 1 + else: + skipped += 1 + + print(f"迁移完成: migrated={migrated} skipped={skipped}") + print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) + kernel.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(_main())) diff --git a/plugins/A_memorix/scripts/migrate_person_memory_points.py b/plugins/A_memorix/scripts/migrate_person_memory_points.py new file mode 100644 index 00000000..a03a8914 --- /dev/null +++ b/plugins/A_memorix/scripts/migrate_person_memory_points.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import sqlite3 +import sys +from pathlib import Path +from typing import Any, Dict, List + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" + +if str(WORKSPACE_ROOT) not in sys.path: + sys.path.insert(0, str(WORKSPACE_ROOT)) +if str(MAIBOT_ROOT) not in sys.path: + sys.path.insert(0, str(MAIBOT_ROOT)) + +from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="迁移 MaiBot person_info.memory_points 到 A_Memorix") + parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") + parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") + parser.add_argument("--limit", type=int, default=0, help="限制迁移人数,0 表示全部") + parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") + return parser.parse_args() + + +def _parse_memory_points(raw_value: Any) -> List[Dict[str, Any]]: + try: + values = json.loads(raw_value) if raw_value else [] + except Exception: + values = [] + items: List[Dict[str, Any]] = [] + for index, item in enumerate(values): + text = str(item or "").strip() + if not text: + continue + parts = text.split(":") + if len(parts) >= 3: + category = parts[0].strip() + content = ":".join(parts[1:-1]).strip() + weight = parts[-1].strip() + else: + category = "其他" + content = text + weight = "1.0" + if content: + items.append({"index": index, "category": category or "其他", "content": content, "weight": weight or "1.0"}) + return items + + +async def _main() -> int: + args = _parse_args() + db_path = 
Path(args.db_path).resolve() + if not db_path.exists(): + print(f"数据库不存在: {db_path}") + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + sql = """ + SELECT person_id, person_name, user_nickname, memory_points + FROM person_info + WHERE memory_points IS NOT NULL AND memory_points != '' + ORDER BY id ASC + """ + if int(args.limit or 0) > 0: + sql += " LIMIT ?" + rows = conn.execute(sql, (int(args.limit),)).fetchall() + else: + rows = conn.execute(sql).fetchall() + conn.close() + + preview_total = sum(len(_parse_memory_points(row["memory_points"])) for row in rows) + print(f"person_info 待迁移人物: {len(rows)} 记忆点: {preview_total}") + if args.dry_run: + for row in rows[:5]: + print(f"- person_id={row['person_id']} person_name={row['person_name'] or row['user_nickname']}") + return 0 + + kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) + await kernel.initialize() + migrated = 0 + skipped = 0 + for row in rows: + person_id = str(row["person_id"] or "").strip() + if not person_id: + continue + display_name = str(row["person_name"] or row["user_nickname"] or "").strip() + for item in _parse_memory_points(row["memory_points"]): + result: Dict[str, Any] = await kernel.ingest_text( + external_id=f"person_memory:{person_id}:{item['index']}", + source_type="person_fact", + text=f"[{item['category']}] {item['content']}", + person_ids=[person_id], + tags=[item["category"]], + entities=[person_id, display_name] if display_name else [person_id], + metadata={"category": item["category"], "weight": item["weight"], "display_name": display_name}, + ) + if result.get("stored_ids"): + migrated += 1 + else: + skipped += 1 + + print(f"迁移完成: migrated={migrated} skipped={skipped}") + print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) + kernel.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(_main())) From bd84e500e193449338138239bd8c753b3ab9b08f Mon Sep 17 00:00:00 2001 From: DawnARC Date: Wed, 18 Mar 2026 21:35:17 +0800 Subject: [PATCH 02/14] =?UTF-8?q?feat:=E6=96=B0=E5=A2=9E=E8=AE=B0=E5=BF=86?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E3=80=81=E6=A3=80=E7=B4=A2=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E4=B8=8E=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增完整的长期记忆支持及测试:引入中文记忆检索提示词、query_long_term_memory 检索工具、记忆服务与记忆流程服务,以及 WebUI 的记忆路由。新增大规模测试套件(包括单元测试与基准/在线测试),覆盖聊天历史摘要、知识获取器、事件(episode)生成、写回机制以及用户画像检索等功能。 更新多个模块以集成记忆检索能力(包括 knowledge fetcher、chat summarizer、memory_retrieval、person_info、config/legacy 迁移以及 WebUI 路由),并移除遗留的 lpmm 知识模块。这些变更完成了记忆运行时的接入,同时为基准测试提供嵌入适配器的 mock,并支持新测试与工具所需的导入与 episode 处理流程。 --- prompts/zh-CN/memory_get_knowledge.prompt | 26 + ..._retrieval_react_prompt_head_memory.prompt | 34 + ...t_chat_history_summarizer_memory_import.py | 148 ++ .../A_memorix_test/test_knowledge_fetcher.py | 127 ++ .../test_legacy_config_migration.py | 35 + .../test_long_novel_memory_benchmark.py | 691 ++++++++ .../test_long_novel_memory_benchmark_live.py | 343 ++++ .../test_memory_flow_service.py | 138 ++ pytests/A_memorix_test/test_memory_service.py | 281 ++++ .../test_person_memory_writeback.py | 81 + .../test_query_long_term_memory_tool.py | 184 +++ ...real_dialogue_business_flow_integration.py | 335 ++++ .../test_real_dialogue_business_flow_live.py | 312 ++++ pytests/webui/test_memory_routes.py | 279 ++++ src/bw_learner/jargon_explainer_old.py | 2 +- src/bw_learner/learner_utils_old.py | 26 + src/chat/brain_chat/PFC/conversation.py | 2 +- 
.../brain_chat/PFC/pfc_KnowledgeFetcher.py | 72 +- src/chat/knowledge/__init__.py | 90 -- src/chat/knowledge/lpmm_ops.py | 380 ----- src/chat/message_receive/bot.py | 6 + .../message_receive/uni_message_sender.py | 7 + src/chat/replyer/group_generator.py | 66 +- src/config/config.py | 2 +- src/config/legacy_migration.py | 15 + src/config/official_configs.py | 18 + src/main.py | 10 +- src/memory_system/chat_history_summarizer.py | 141 +- src/memory_system/memory_retrieval.py | 4 +- src/memory_system/retrieval_tools/__init__.py | 16 +- .../retrieval_tools/query_long_term_memory.py | 304 ++++ .../retrieval_tools/query_lpmm_knowledge.py | 75 - src/person_info/person_info.py | 250 +-- src/plugin_runtime/capabilities/data.py | 8 +- src/services/memory_flow_service.py | 275 ++++ src/services/memory_service.py | 428 +++++ src/webui/routers/__init__.py | 4 +- src/webui/routers/memory.py | 1395 +++++++++++++++++ src/webui/routes.py | 3 + 39 files changed, 5849 insertions(+), 764 deletions(-) create mode 100644 prompts/zh-CN/memory_get_knowledge.prompt create mode 100644 prompts/zh-CN/memory_retrieval_react_prompt_head_memory.prompt create mode 100644 pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py create mode 100644 pytests/A_memorix_test/test_knowledge_fetcher.py create mode 100644 pytests/A_memorix_test/test_legacy_config_migration.py create mode 100644 pytests/A_memorix_test/test_long_novel_memory_benchmark.py create mode 100644 pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py create mode 100644 pytests/A_memorix_test/test_memory_flow_service.py create mode 100644 pytests/A_memorix_test/test_memory_service.py create mode 100644 pytests/A_memorix_test/test_person_memory_writeback.py create mode 100644 pytests/A_memorix_test/test_query_long_term_memory_tool.py create mode 100644 pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py create mode 100644 pytests/A_memorix_test/test_real_dialogue_business_flow_live.py create mode 100644 pytests/webui/test_memory_routes.py delete mode 100644 src/chat/knowledge/__init__.py delete mode 100644 src/chat/knowledge/lpmm_ops.py create mode 100644 src/memory_system/retrieval_tools/query_long_term_memory.py delete mode 100644 src/memory_system/retrieval_tools/query_lpmm_knowledge.py create mode 100644 src/services/memory_flow_service.py create mode 100644 src/services/memory_service.py create mode 100644 src/webui/routers/memory.py diff --git a/prompts/zh-CN/memory_get_knowledge.prompt b/prompts/zh-CN/memory_get_knowledge.prompt new file mode 100644 index 00000000..aa9e8967 --- /dev/null +++ b/prompts/zh-CN/memory_get_knowledge.prompt @@ -0,0 +1,26 @@ +你是一个专门获取长期记忆的助手。你的名字是{bot_name}。现在是{time_now}。 +群里正在进行的聊天内容: +{chat_history} + +现在,{sender}发送了内容:{target_message},你想要回复ta。 +请仔细分析聊天内容,考虑以下几点: +1. 内容中是否包含需要查询历史知识或长期记忆的问题 +2. 
是否有明确的知识获取指令 + +如果需要使用长期记忆工具,请直接调用函数 `search_long_term_memory`;如果不需要任何工具,直接输出 `No tool needed`。 + +工具模式说明: +- `mode="search"`:普通长期记忆检索,适合查具体事实、偏好、历史对话内容 +- `mode="time"`:按时间范围检索,必须同时提供 `time_expression` +- `mode="episode"`:按事件/情节检索,适合查“那次经历”“那件事的经过” +- `mode="aggregate"`:综合检索,适合“整体回忆一下”“把相关线索综合找出来” + +优先规则: +- 问“某段时间发生了什么”:优先 `time` +- 问“某次事件/某段经历”:优先 `episode` +- 问“整体情况/最近发生过什么”:优先 `aggregate` +- 问单点事实:优先 `search` + +`time_expression` 可用表达: +- `今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天` +- 或绝对时间:`2026/03/18`、`2026/03/18 09:30` diff --git a/prompts/zh-CN/memory_retrieval_react_prompt_head_memory.prompt b/prompts/zh-CN/memory_retrieval_react_prompt_head_memory.prompt new file mode 100644 index 00000000..91ea6eab --- /dev/null +++ b/prompts/zh-CN/memory_retrieval_react_prompt_head_memory.prompt @@ -0,0 +1,34 @@ +你的名字是{bot_name}。现在是{time_now}。 +你正在参与聊天,你需要搜集信息来帮助你进行回复。 +重要,这是当前聊天记录: +{chat_history} +聊天记录结束 + +已收集的信息: +{collected_info} + +- 你可以对查询思路给出简短的思考:思考要简短,直接切入要点 +- 思考完毕后,使用工具 + +**工具说明:** +- 如果涉及过往事件、历史对话、用户长期偏好或某段时间发生的事件,可以使用长期记忆查询工具 +- 如果遇到不熟悉的词语、缩写、黑话或网络用语,可以使用query_words工具查询其含义 +- 你必须使用tool,如果需要查询你必须给出使用什么工具进行查询 +- 当你决定结束查询时,必须调用return_information工具返回总结信息并结束查询 + +长期记忆工具 `search_long_term_memory` 支持以下模式: +- `mode="search"`:普通事实/偏好/历史内容检索。适合问“她喜欢什么”“我们之前讨论过什么”。 +- `mode="time"`:按时间范围检索。适合问“昨天发生了什么”“最近7天有哪些相关记忆”。 +- `mode="episode"`:按事件/情节检索。适合问“那次灯塔停电的经过是什么”“关于某次经历还有什么”。 +- `mode="aggregate"`:综合检索。适合问“帮我整体回忆一下这个人最近的情况”“把相关线索综合找出来”。 + +模式选择建议: +- 问单点事实、偏好、人设、具体信息:优先 `search` +- 问某段时间发生了什么:优先 `time` +- 问某次事件、某段经历、某个剧情片段:优先 `episode` +- 问整体回忆、综合找线索、总结最近发生的事:优先 `aggregate` + +时间模式要求: +- 使用 `mode="time"` 时,必须填写 `time_expression` +- 可用时间表达包括:`今天`、`昨天`、`前天`、`本周`、`上周`、`本月`、`上月`、`最近7天` +- 也可以使用绝对时间:`2026/03/18`、`2026/03/18 09:30` diff --git a/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py b/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py new file mode 100644 index 00000000..0f084ece --- /dev/null +++ b/pytests/A_memorix_test/test_chat_history_summarizer_memory_import.py @@ -0,0 +1,148 @@ +from types import SimpleNamespace + +import pytest + +from src.memory_system import chat_history_summarizer as summarizer_module + + +def _build_summarizer() -> summarizer_module.ChatHistorySummarizer: + summarizer = summarizer_module.ChatHistorySummarizer.__new__(summarizer_module.ChatHistorySummarizer) + summarizer.session_id = "session-1" + summarizer.log_prefix = "[session-1]" + return summarizer + + +@pytest.mark.asyncio +async def test_import_to_long_term_memory_uses_summary_payload(monkeypatch): + calls = [] + summarizer = _build_summarizer() + + async def fake_ingest_summary(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) + + monkeypatch.setattr( + summarizer_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")), + ) + monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=8))) + monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) + + await summarizer._import_to_long_term_memory( + record_id=1, + theme="旅行计划", + summary="我们讨论了春游安排", + participants=["Alice", "Bob"], + start_time=1.0, + end_time=2.0, + original_text="long text", + ) + + assert len(calls) == 1 + payload = calls[0] + assert payload["external_id"] == "chat_history:1" + assert payload["chat_id"] == "session-1" + assert 
payload["participants"] == ["Alice", "Bob"] + assert payload["respect_filter"] is True + assert payload["user_id"] == "user-1" + assert payload["group_id"] == "" + assert "主题:旅行计划" in payload["text"] + assert "概括:我们讨论了春游安排" in payload["text"] + + +@pytest.mark.asyncio +async def test_import_to_long_term_memory_falls_back_when_content_empty(monkeypatch): + summarizer = _build_summarizer() + fallback_calls = [] + + async def fake_fallback(**kwargs): + fallback_calls.append(kwargs) + + summarizer._fallback_import_to_long_term_memory = fake_fallback + monkeypatch.setattr( + summarizer_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="")), + ) + + await summarizer._import_to_long_term_memory( + record_id=2, + theme="", + summary="", + participants=[], + start_time=10.0, + end_time=20.0, + original_text="raw chat", + ) + + assert len(fallback_calls) == 1 + assert fallback_calls[0]["record_id"] == 2 + assert fallback_calls[0]["original_text"] == "raw chat" + + +@pytest.mark.asyncio +async def test_import_to_long_term_memory_falls_back_when_ingest_fails(monkeypatch): + summarizer = _build_summarizer() + fallback_calls = [] + + async def fake_ingest_summary(**kwargs): + return SimpleNamespace(success=False, detail="boom", stored_ids=[]) + + async def fake_fallback(**kwargs): + fallback_calls.append(kwargs) + + summarizer._fallback_import_to_long_term_memory = fake_fallback + monkeypatch.setattr( + summarizer_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-1", group_id="group-1")), + ) + monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) + + await summarizer._import_to_long_term_memory( + record_id=3, + theme="电影", + summary="聊了电影推荐", + participants=["Alice"], + start_time=3.0, + end_time=4.0, + original_text="raw", + ) + + assert len(fallback_calls) == 1 + assert fallback_calls[0]["theme"] == "电影" + + +@pytest.mark.asyncio +async def test_fallback_import_to_long_term_memory_sets_generate_from_chat(monkeypatch): + calls = [] + summarizer = _build_summarizer() + + async def fake_ingest_summary(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="chat_filtered", stored_ids=[]) + + monkeypatch.setattr( + summarizer_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(user_id="user-2", group_id="group-2")), + ) + monkeypatch.setattr(summarizer_module, "global_config", SimpleNamespace(memory=SimpleNamespace(chat_history_topic_check_message_threshold=12))) + monkeypatch.setattr("src.services.memory_service.memory_service.ingest_summary", fake_ingest_summary) + + await summarizer._fallback_import_to_long_term_memory( + record_id=4, + theme="工作", + participants=["Alice"], + start_time=5.0, + end_time=6.0, + original_text="a" * 128, + ) + + assert len(calls) == 1 + metadata = calls[0]["metadata"] + assert metadata["generate_from_chat"] is True + assert metadata["context_length"] == 12 + assert calls[0]["respect_filter"] is True + diff --git a/pytests/A_memorix_test/test_knowledge_fetcher.py b/pytests/A_memorix_test/test_knowledge_fetcher.py new file mode 100644 index 00000000..4fb4e564 --- /dev/null +++ b/pytests/A_memorix_test/test_knowledge_fetcher.py @@ -0,0 +1,127 @@ +from types import SimpleNamespace + +import pytest + +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from 
src.services.memory_service import MemoryHit, MemorySearchResult + + +def test_knowledge_fetcher_resolves_private_memory_context(monkeypatch): + monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object()) + monkeypatch.setattr( + knowledge_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), + ) + monkeypatch.setattr( + knowledge_module, + "resolve_person_id_for_memory", + lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}", + ) + + fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") + + assert fetcher._resolve_private_memory_context() == { + "chat_id": "stream-1", + "person_id": "Alice:qq:42", + "user_id": "42", + "group_id": "", + } + + +@pytest.mark.asyncio +async def test_knowledge_fetcher_memory_get_knowledge_uses_memory_service(monkeypatch): + monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object()) + monkeypatch.setattr( + knowledge_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), + ) + monkeypatch.setattr( + knowledge_module, + "resolve_person_id_for_memory", + lambda *, person_name, platform, user_id: f"{person_name}:{platform}:{user_id}", + ) + + calls = [] + + async def fake_search(query: str, **kwargs): + calls.append((query, kwargs)) + return MemorySearchResult(summary="", hits=[MemoryHit(content="她喜欢猫", source="person_fact:qq:42")], filtered=False) + + monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search) + + fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") + result = await fetcher._memory_get_knowledge("她喜欢什么") + + assert "1. 
她喜欢猫" in result + assert calls == [ + ( + "她喜欢什么", + { + "limit": 5, + "mode": "search", + "chat_id": "stream-1", + "person_id": "Alice:qq:42", + "user_id": "42", + "group_id": "", + "respect_filter": True, + }, + ) + ] + + +@pytest.mark.asyncio +async def test_knowledge_fetcher_falls_back_to_chat_scope_when_person_scope_misses(monkeypatch): + monkeypatch.setattr(knowledge_module, "LLMRequest", lambda *args, **kwargs: object()) + monkeypatch.setattr( + knowledge_module, + "_chat_manager", + SimpleNamespace(get_session_by_session_id=lambda session_id: SimpleNamespace(platform="qq", user_id="42", group_id="")), + ) + monkeypatch.setattr( + knowledge_module, + "resolve_person_id_for_memory", + lambda *, person_name, platform, user_id: "person-1", + ) + + calls = [] + + async def fake_search(query: str, **kwargs): + calls.append((query, kwargs)) + if kwargs.get("person_id"): + return MemorySearchResult(summary="", hits=[], filtered=False) + return MemorySearchResult(summary="", hits=[MemoryHit(content="她计划去杭州音乐节", source="chat_summary:stream-1")], filtered=False) + + monkeypatch.setattr(knowledge_module.memory_service, "search", fake_search) + + fetcher = knowledge_module.KnowledgeFetcher(private_name="Alice", stream_id="stream-1") + result = await fetcher._memory_get_knowledge("Alice 最近在忙什么") + + assert "杭州音乐节" in result + assert calls == [ + ( + "Alice 最近在忙什么", + { + "limit": 5, + "mode": "search", + "chat_id": "stream-1", + "person_id": "person-1", + "user_id": "42", + "group_id": "", + "respect_filter": True, + }, + ), + ( + "Alice 最近在忙什么", + { + "limit": 5, + "mode": "search", + "chat_id": "stream-1", + "person_id": "", + "user_id": "42", + "group_id": "", + "respect_filter": True, + }, + ), + ] diff --git a/pytests/A_memorix_test/test_legacy_config_migration.py b/pytests/A_memorix_test/test_legacy_config_migration.py new file mode 100644 index 00000000..c382e4f3 --- /dev/null +++ b/pytests/A_memorix_test/test_legacy_config_migration.py @@ -0,0 +1,35 @@ +from src.config.legacy_migration import try_migrate_legacy_bot_config_dict + + +def test_legacy_learning_list_with_numeric_fourth_column_is_migrated(): + payload = { + "expression": { + "learning_list": [ + ["qq:123456:group", "enable", "disable", "0.5"], + ["", "disable", "enable", "0.1"], + ] + } + } + + result = try_migrate_legacy_bot_config_dict(payload) + + assert result.migrated is True + assert "expression.learning_list" in result.reason + assert result.data["expression"]["learning_list"] == [ + { + "platform": "qq", + "item_id": "123456", + "rule_type": "group", + "use_expression": True, + "enable_learning": False, + "enable_jargon_learning": False, + }, + { + "platform": "", + "item_id": "", + "rule_type": "group", + "use_expression": False, + "enable_learning": True, + "enable_jargon_learning": False, + }, + ] diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark.py new file mode 100644 index 00000000..3c3e4090 --- /dev/null +++ b/pytests/A_memorix_test/test_long_novel_memory_benchmark.py @@ -0,0 +1,691 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List + +import numpy as np +import pytest +import pytest_asyncio + +from A_memorix.core.runtime import sdk_memory_kernel as kernel_module +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel +from src.chat.brain_chat.PFC import 
pfc_KnowledgeFetcher as knowledge_module +from src.memory_system import chat_history_summarizer as summarizer_module +from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory +from src.person_info import person_info as person_info_module +from src.services import memory_service as memory_service_module +from src.services.memory_service import MemorySearchResult, memory_service + + +DATA_FILE = Path(__file__).parent / "data" / "benchmarks" / "long_novel_memory_benchmark.json" +REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_report.json" + + +def _load_benchmark_fixture() -> Dict[str, Any]: + return json.loads(DATA_FILE.read_text(encoding="utf-8")) + + +class _FakeEmbeddingAdapter: + def __init__(self, dimension: int = 32) -> None: + self.dimension = dimension + + async def _detect_dimension(self) -> int: + return self.dimension + + async def encode(self, texts, dimensions=None): + dim = int(dimensions or self.dimension) + if isinstance(texts, str): + sequence = [texts] + single = True + else: + sequence = list(texts) + single = False + + rows = [] + for text in sequence: + vec = np.zeros(dim, dtype=np.float32) + for ch in str(text or ""): + code = ord(ch) + vec[code % dim] += 1.0 + vec[(code * 7) % dim] += 0.5 + if not vec.any(): + vec[0] = 1.0 + norm = np.linalg.norm(vec) + if norm > 0: + vec = vec / norm + rows.append(vec) + payload = np.vstack(rows) + return payload[0] if single else payload + + +class _KnownPerson: + def __init__(self, person_id: str, registry: Dict[str, str], reverse_registry: Dict[str, str]) -> None: + self.person_id = person_id + self.is_known = person_id in reverse_registry + self.person_name = reverse_registry.get(person_id, "") + self._registry = registry + + +class _KernelBackedRuntimeManager: + is_running = True + + def __init__(self, kernel: SDKMemoryKernel) -> None: + self.kernel = kernel + + async def invoke_plugin( + self, + *, + method: str, + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None, + timeout_ms: int, + ): + del method, plugin_id, timeout_ms + payload = args or {} + if component_name == "search_memory": + return await self.kernel.search_memory( + KernelSearchRequest( + query=str(payload.get("query", "") or ""), + limit=int(payload.get("limit", 5) or 5), + mode=str(payload.get("mode", "hybrid") or "hybrid"), + chat_id=str(payload.get("chat_id", "") or ""), + person_id=str(payload.get("person_id", "") or ""), + time_start=payload.get("time_start"), + time_end=payload.get("time_end"), + respect_filter=bool(payload.get("respect_filter", True)), + user_id=str(payload.get("user_id", "") or ""), + group_id=str(payload.get("group_id", "") or ""), + ) + ) + + handler = getattr(self.kernel, component_name) + result = handler(**payload) + return await result if inspect.isawaitable(result) else result + + +async def _wait_for_import_task(task_id: str, *, max_rounds: int = 200, sleep_seconds: float = 0.05) -> Dict[str, Any]: + for _ in range(max_rounds): + detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) + task = detail.get("task") or {} + status = str(task.get("status", "") or "") + if status in {"completed", "completed_with_errors", "failed", "cancelled"}: + return detail + await asyncio.sleep(max(0.01, float(sleep_seconds))) + raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") + + +def _join_hit_content(search_result: MemorySearchResult) -> str: + return "\n".join(hit.content for hit in 
search_result.hits) + + +def _keyword_hits(text: str, keywords: List[str]) -> int: + haystack = str(text or "") + return sum(1 for keyword in keywords if keyword in haystack) + + +def _keyword_recall(text: str, keywords: List[str]) -> float: + if not keywords: + return 1.0 + return _keyword_hits(text, keywords) / float(len(keywords)) + + +def _hit_blob(hit) -> str: + meta = hit.metadata if isinstance(hit.metadata, dict) else {} + return "\n".join( + [ + str(hit.content or ""), + str(hit.title or ""), + str(hit.source or ""), + json.dumps(meta, ensure_ascii=False), + ] + ) + + +def _first_relevant_rank(search_result: MemorySearchResult, keywords: List[str], minimum_keyword_hits: int) -> int: + for index, hit in enumerate(search_result.hits[:5], start=1): + if _keyword_hits(_hit_blob(hit), keywords) >= max(1, int(minimum_keyword_hits or len(keywords))): + return index + return 0 + + +def _episode_blob_from_items(items: List[Dict[str, Any]]) -> str: + return "\n".join( + ( + f"{item.get('title', '')}\n" + f"{item.get('summary', '')}\n" + f"{json.dumps(item.get('keywords', []), ensure_ascii=False)}\n" + f"{json.dumps(item.get('participants', []), ensure_ascii=False)}" + ) + for item in items + ) + + +def _episode_blob_from_hits(search_result: MemorySearchResult) -> str: + chunks = [] + for hit in search_result.hits: + meta = hit.metadata if isinstance(hit.metadata, dict) else {} + chunks.append( + "\n".join( + [ + str(hit.title or ""), + str(hit.content or ""), + json.dumps(meta.get("keywords", []) or [], ensure_ascii=False), + json.dumps(meta.get("participants", []) or [], ensure_ascii=False), + ] + ) + ) + return "\n".join(chunks) + + +async def _evaluate_episode_generation(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: + episode_source = f"chat_summary:{session_id}" + payload = await memory_service.episode_admin( + action="query", + source=episode_source, + limit=20, + ) + items = payload.get("items") or [] + blob = _episode_blob_from_items(items) + reports: List[Dict[str, Any]] = [] + success_rate = 0.0 + keyword_recall = 0.0 + + for case in episode_cases: + recall = _keyword_recall(blob, case["expected_keywords"]) + success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) + success_rate += 1.0 if success else 0.0 + keyword_recall += recall + reports.append( + { + "query": case["query"], + "success": success, + "keyword_recall": recall, + "episode_count": len(items), + "top_episode": items[0] if items else None, + } + ) + + total = max(1, len(episode_cases)) + return { + "success_rate": round(success_rate / total, 4), + "keyword_recall": round(keyword_recall / total, 4), + "episode_count": len(items), + "reports": reports, + } + + +async def _evaluate_episode_admin_query(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: + reports: List[Dict[str, Any]] = [] + success_rate = 0.0 + keyword_recall = 0.0 + episode_source = f"chat_summary:{session_id}" + + for case in episode_cases: + payload = await memory_service.episode_admin( + action="query", + source=episode_source, + query=case["query"], + limit=5, + ) + items = payload.get("items") or [] + blob = "\n".join( + f"{item.get('title', '')}\n{item.get('summary', '')}\n{json.dumps(item.get('keywords', []), ensure_ascii=False)}" + for item in items + ) + recall = _keyword_recall(blob, case["expected_keywords"]) + success = bool(items) and recall >= float(case.get("minimum_keyword_recall", 1.0)) + success_rate += 1.0 if success else 0.0 + keyword_recall += recall + 
reports.append( + { + "query": case["query"], + "success": success, + "keyword_recall": recall, + "episode_count": len(items), + "top_episode": items[0] if items else None, + } + ) + + total = max(1, len(episode_cases)) + return { + "success_rate": round(success_rate / total, 4), + "keyword_recall": round(keyword_recall / total, 4), + "reports": reports, + } + + +async def _evaluate_episode_search_mode(*, session_id: str, episode_cases: List[Dict[str, Any]]) -> Dict[str, Any]: + reports: List[Dict[str, Any]] = [] + success_rate = 0.0 + keyword_recall = 0.0 + + for case in episode_cases: + result = await memory_service.search( + case["query"], + mode="episode", + chat_id=session_id, + respect_filter=False, + limit=5, + ) + blob = _episode_blob_from_hits(result) + recall = _keyword_recall(blob, case["expected_keywords"]) + success = bool(result.hits) and recall >= float(case.get("minimum_keyword_recall", 1.0)) + success_rate += 1.0 if success else 0.0 + keyword_recall += recall + reports.append( + { + "query": case["query"], + "success": success, + "keyword_recall": recall, + "episode_count": len(result.hits), + "top_episode": result.hits[0].to_dict() if result.hits else None, + } + ) + + total = max(1, len(episode_cases)) + return { + "success_rate": round(success_rate / total, 4), + "keyword_recall": round(keyword_recall / total, 4), + "reports": reports, + } + + +async def _evaluate_tool_modes(*, session_id: str, dataset: Dict[str, Any]) -> Dict[str, Any]: + search_case = dataset["search_cases"][0] + episode_case = dataset["episode_cases"][0] + aggregate_case = dataset["knowledge_fetcher_cases"][0] + tool_cases = [ + { + "name": "search", + "kwargs": { + "query": "蓝漆铁盒 北塔木梯", + "mode": "search", + "chat_id": session_id, + "limit": 5, + }, + "expected_keywords": ["蓝漆铁盒", "北塔木梯", "海潮图"], + "minimum_keyword_recall": 0.67, + }, + { + "name": "time", + "kwargs": { + "query": "蓝漆铁盒 北塔", + "mode": "time", + "chat_id": session_id, + "limit": 5, + "time_expression": "最近7天", + }, + "expected_keywords": ["蓝漆铁盒", "北塔木梯"], + "minimum_keyword_recall": 0.67, + }, + { + "name": "episode", + "kwargs": { + "query": episode_case["query"], + "mode": "episode", + "chat_id": session_id, + "limit": 5, + }, + "expected_keywords": episode_case["expected_keywords"], + "minimum_keyword_recall": 0.67, + }, + { + "name": "aggregate", + "kwargs": { + "query": aggregate_case["query"], + "mode": "aggregate", + "chat_id": session_id, + "limit": 5, + }, + "expected_keywords": aggregate_case["expected_keywords"], + "minimum_keyword_recall": 0.67, + }, + ] + reports: List[Dict[str, Any]] = [] + success_rate = 0.0 + keyword_recall = 0.0 + + for case in tool_cases: + text = await query_long_term_memory(**case["kwargs"]) + recall = _keyword_recall(text, case["expected_keywords"]) + success = ( + "失败" not in text + and "无法解析" not in text + and "未找到" not in text + and recall >= float(case["minimum_keyword_recall"]) + ) + success_rate += 1.0 if success else 0.0 + keyword_recall += recall + reports.append( + { + "name": case["name"], + "success": success, + "keyword_recall": recall, + "preview": text[:320], + } + ) + + total = max(1, len(tool_cases)) + return { + "success_rate": round(success_rate / total, 4), + "keyword_recall": round(keyword_recall / total, 4), + "reports": reports, + } + + +@pytest_asyncio.fixture +async def benchmark_env(monkeypatch, tmp_path): + dataset = _load_benchmark_fixture() + session_cfg = dataset["session"] + session = SimpleNamespace( + session_id=session_cfg["session_id"], + 
platform=session_cfg["platform"], + user_id=session_cfg["user_id"], + group_id=session_cfg["group_id"], + ) + fake_chat_manager = SimpleNamespace( + get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, + get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, + ) + + registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} + reverse_registry = {value: key for key, value in registry.items()} + + monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) + + async def fake_self_check(**kwargs): + return {"ok": True, "message": "ok", "encoded_dimension": 32} + + monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None) + monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), "")) + monkeypatch.setattr( + person_info_module, + "Person", + lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), + ) + + data_dir = (tmp_path / "a_memorix_benchmark_data").resolve() + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str(data_dir)}, + "advanced": {"enable_auto_save": False}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + }, + ) + manager = _KernelBackedRuntimeManager(kernel) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager) + + await kernel.initialize() + try: + yield { + "dataset": dataset, + "kernel": kernel, + "session": session, + "person_registry": registry, + } + finally: + await kernel.shutdown() + + +@pytest.mark.asyncio +async def test_long_novel_memory_benchmark(benchmark_env): + dataset = benchmark_env["dataset"] + session_id = benchmark_env["session"].session_id + + created = await memory_service.import_admin( + action="create_paste", + name="long_novel_memory_benchmark.json", + input_mode="json", + llm_enabled=False, + content=json.dumps(dataset["import_payload"], ensure_ascii=False), + ) + assert created["success"] is True + + import_detail = await _wait_for_import_task(created["task"]["task_id"]) + assert import_detail["task"]["status"] == "completed" + + for record in dataset["chat_history_records"]: + summarizer = summarizer_module.ChatHistorySummarizer(session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + for payload in dataset["person_writebacks"]: + await person_info_module.store_person_memory_from_answer( + payload["person_name"], + payload["memory_content"], + session_id, + ) + + await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) + + search_case_reports: List[Dict[str, Any]] = [] + search_accuracy_at_1 = 0.0 + search_recall_at_5 = 0.0 + search_precision_at_5 = 0.0 + search_mrr = 0.0 + 
search_keyword_recall = 0.0 + + for case in dataset["search_cases"]: + result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5) + joined = _join_hit_content(result) + rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"]))) + relevant_hits = sum( + 1 + for hit in result.hits[:5] + if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"])))) + ) + keyword_recall = _keyword_recall(joined, case["expected_keywords"]) + search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 + search_recall_at_5 += 1.0 if rank > 0 else 0.0 + search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) + search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 + search_keyword_recall += keyword_recall + search_case_reports.append( + { + "query": case["query"], + "rank_of_first_relevant": rank, + "relevant_hits_top5": relevant_hits, + "keyword_recall_top5": keyword_recall, + "top_hit": result.hits[0].to_dict() if result.hits else None, + } + ) + + search_total = max(1, len(dataset["search_cases"])) + + writeback_reports: List[Dict[str, Any]] = [] + writeback_success_rate = 0.0 + writeback_keyword_recall = 0.0 + for payload in dataset["person_writebacks"]: + query = " ".join(payload["expected_keywords"]) + result = await memory_service.search( + query, + mode="search", + chat_id=session_id, + person_id=payload["person_id"], + respect_filter=False, + limit=5, + ) + joined = _join_hit_content(result) + recall = _keyword_recall(joined, payload["expected_keywords"]) + success = bool(result.hits) and recall >= 0.67 + writeback_success_rate += 1.0 if success else 0.0 + writeback_keyword_recall += recall + writeback_reports.append( + { + "person_id": payload["person_id"], + "success": success, + "keyword_recall": recall, + "hit_count": len(result.hits), + } + ) + writeback_total = max(1, len(dataset["person_writebacks"])) + + knowledge_reports: List[Dict[str, Any]] = [] + knowledge_success_rate = 0.0 + knowledge_keyword_recall = 0.0 + fetcher = knowledge_module.KnowledgeFetcher( + private_name=dataset["session"]["display_name"], + stream_id=session_id, + ) + for case in dataset["knowledge_fetcher_cases"]: + knowledge_text, _ = await fetcher.fetch(case["query"], []) + recall = _keyword_recall(knowledge_text, case["expected_keywords"]) + success = recall >= float(case.get("minimum_keyword_recall", 1.0)) + knowledge_success_rate += 1.0 if success else 0.0 + knowledge_keyword_recall += recall + knowledge_reports.append( + { + "query": case["query"], + "success": success, + "keyword_recall": recall, + "preview": knowledge_text[:300], + } + ) + knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) + + profile_reports: List[Dict[str, Any]] = [] + profile_success_rate = 0.0 + profile_keyword_recall = 0.0 + profile_evidence_rate = 0.0 + for case in dataset["profile_cases"]: + profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) + recall = _keyword_recall(profile.summary, case["expected_keywords"]) + has_evidence = bool(profile.evidence) + success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence + profile_success_rate += 1.0 if success else 0.0 + profile_keyword_recall += recall + profile_evidence_rate += 1.0 if has_evidence else 0.0 + profile_reports.append( + { + "person_id": case["person_id"], + "success": success, + "keyword_recall": recall, + 
"evidence_count": len(profile.evidence), + "summary_preview": profile.summary[:240], + } + ) + profile_total = max(1, len(dataset["profile_cases"])) + + episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_rebuild = await memory_service.episode_admin( + action="rebuild", + source=f"chat_summary:{session_id}", + ) + episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) + tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) + + report = { + "dataset": dataset["meta"], + "import": { + "task_id": created["task"]["task_id"], + "status": import_detail["task"]["status"], + "paragraph_count": len(dataset["import_payload"]["paragraphs"]), + }, + "metrics": { + "search": { + "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), + "recall_at_5": round(search_recall_at_5 / search_total, 4), + "precision_at_5": round(search_precision_at_5 / search_total, 4), + "mrr": round(search_mrr / search_total, 4), + "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), + }, + "writeback": { + "success_rate": round(writeback_success_rate / writeback_total, 4), + "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), + }, + "knowledge_fetcher": { + "success_rate": round(knowledge_success_rate / knowledge_total, 4), + "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), + }, + "profile": { + "success_rate": round(profile_success_rate / profile_total, 4), + "keyword_recall": round(profile_keyword_recall / profile_total, 4), + "evidence_rate": round(profile_evidence_rate / profile_total, 4), + }, + "tool_modes": { + "success_rate": tool_modes["success_rate"], + "keyword_recall": tool_modes["keyword_recall"], + }, + "episode_generation_auto": { + "success_rate": episode_generation_auto["success_rate"], + "keyword_recall": episode_generation_auto["keyword_recall"], + "episode_count": episode_generation_auto["episode_count"], + }, + "episode_generation_after_rebuild": { + "success_rate": episode_generation_after_rebuild["success_rate"], + "keyword_recall": episode_generation_after_rebuild["keyword_recall"], + "episode_count": episode_generation_after_rebuild["episode_count"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + "episode_admin_query_auto": { + "success_rate": episode_admin_query_auto["success_rate"], + "keyword_recall": episode_admin_query_auto["keyword_recall"], + }, + "episode_admin_query_after_rebuild": { + "success_rate": episode_admin_query_after_rebuild["success_rate"], + "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + "episode_search_mode_auto": { + "success_rate": episode_search_mode_auto["success_rate"], + "keyword_recall": episode_search_mode_auto["keyword_recall"], + }, + "episode_search_mode_after_rebuild": { + 
"success_rate": episode_search_mode_after_rebuild["success_rate"], + "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + }, + "cases": { + "search": search_case_reports, + "writeback": writeback_reports, + "knowledge_fetcher": knowledge_reports, + "profile": profile_reports, + "tool_modes": tool_modes["reports"], + "episode_generation_auto": episode_generation_auto["reports"], + "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], + "episode_admin_query_auto": episode_admin_query_auto["reports"], + "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], + "episode_search_mode_auto": episode_search_mode_auto["reports"], + "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], + }, + } + + REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) + REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) + + assert report["import"]["status"] == "completed" + assert report["metrics"]["search"]["accuracy_at_1"] >= 0.35 + assert report["metrics"]["search"]["recall_at_5"] >= 0.6 + assert report["metrics"]["search"]["keyword_recall_at_5"] >= 0.8 + assert report["metrics"]["writeback"]["success_rate"] >= 0.66 + assert report["metrics"]["knowledge_fetcher"]["success_rate"] >= 0.66 + assert report["metrics"]["knowledge_fetcher"]["keyword_recall"] >= 0.75 + assert report["metrics"]["profile"]["success_rate"] >= 0.66 + assert report["metrics"]["profile"]["evidence_rate"] >= 1.0 + assert report["metrics"]["tool_modes"]["success_rate"] >= 0.75 + assert report["metrics"]["episode_generation_after_rebuild"]["rebuild_success"] is True + assert report["metrics"]["episode_generation_after_rebuild"]["episode_count"] >= report["metrics"]["episode_generation_auto"]["episode_count"] diff --git a/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py b/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py new file mode 100644 index 00000000..1dad0795 --- /dev/null +++ b/pytests/A_memorix_test/test_long_novel_memory_benchmark_live.py @@ -0,0 +1,343 @@ +from __future__ import annotations + +import json +import os +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict, List + +import pytest +import pytest_asyncio + +from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel +from pytests.test_long_novel_memory_benchmark import ( + _evaluate_episode_admin_query, + _evaluate_episode_generation, + _evaluate_episode_search_mode, + _evaluate_tool_modes, + _KernelBackedRuntimeManager, + _KnownPerson, + _first_relevant_rank, + _hit_blob, + _join_hit_content, + _keyword_hits, + _keyword_recall, + _load_benchmark_fixture, + _wait_for_import_task, +) +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from src.memory_system import chat_history_summarizer as summarizer_module +from src.person_info import person_info as person_info_module +from src.services import memory_service as memory_service_module +from src.services.memory_service import memory_service + + +pytestmark = pytest.mark.skipif( + os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", + reason="需要显式开启真实 external embedding benchmark", +) + +REPORT_FILE = Path(__file__).parent / "data" / "benchmarks" / "results" / "long_novel_memory_benchmark_live_report.json" + + +@pytest_asyncio.fixture +async def 
benchmark_live_env(monkeypatch, tmp_path): + dataset = _load_benchmark_fixture() + session_cfg = dataset["session"] + session = SimpleNamespace( + session_id=session_cfg["session_id"], + platform=session_cfg["platform"], + user_id=session_cfg["user_id"], + group_id=session_cfg["group_id"], + ) + fake_chat_manager = SimpleNamespace( + get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, + get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, + ) + + registry = {item["person_name"]: item["person_id"] for item in dataset["person_writebacks"]} + reverse_registry = {value: key for key, value in registry.items()} + + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None) + monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: registry.get(str(person_name or "").strip(), "")) + monkeypatch.setattr( + person_info_module, + "Person", + lambda person_id: _KnownPerson(person_id=str(person_id or ""), registry=registry, reverse_registry=reverse_registry), + ) + + data_dir = (tmp_path / "a_memorix_live_benchmark_data").resolve() + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str(data_dir)}, + "advanced": {"enable_auto_save": False}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + }, + ) + manager = _KernelBackedRuntimeManager(kernel) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager) + + await kernel.initialize() + try: + yield { + "dataset": dataset, + "kernel": kernel, + "session": session, + } + finally: + await kernel.shutdown() + + +@pytest.mark.asyncio +async def test_long_novel_memory_benchmark_live(benchmark_live_env): + dataset = benchmark_live_env["dataset"] + session_id = benchmark_live_env["session"].session_id + + self_check = await memory_service.runtime_admin(action="refresh_self_check") + assert self_check["success"] is True + assert self_check["report"]["ok"] is True + + created = await memory_service.import_admin( + action="create_paste", + name="long_novel_memory_benchmark.live.json", + input_mode="json", + llm_enabled=False, + content=json.dumps(dataset["import_payload"], ensure_ascii=False), + ) + assert created["success"] is True + + import_detail = await _wait_for_import_task( + created["task"]["task_id"], + max_rounds=2400, + sleep_seconds=0.25, + ) + assert import_detail["task"]["status"] == "completed" + + for record in dataset["chat_history_records"]: + summarizer = summarizer_module.ChatHistorySummarizer(session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + for payload in dataset["person_writebacks"]: + await person_info_module.store_person_memory_from_answer( + payload["person_name"], + payload["memory_content"], + session_id, + ) + + await memory_service.episode_admin(action="process_pending", limit=100, max_retry=2) + + search_case_reports: List[Dict[str, Any]] = [] + 
search_accuracy_at_1 = 0.0 + search_recall_at_5 = 0.0 + search_precision_at_5 = 0.0 + search_mrr = 0.0 + search_keyword_recall = 0.0 + for case in dataset["search_cases"]: + result = await memory_service.search(case["query"], mode="search", respect_filter=False, limit=5) + joined = _join_hit_content(result) + rank = _first_relevant_rank(result, case["expected_keywords"], case.get("minimum_keyword_hits", len(case["expected_keywords"]))) + relevant_hits = sum( + 1 + for hit in result.hits[:5] + if _keyword_hits(_hit_blob(hit), case["expected_keywords"]) >= max(1, int(case.get("minimum_keyword_hits", len(case["expected_keywords"])))) + ) + keyword_recall = _keyword_recall(joined, case["expected_keywords"]) + search_accuracy_at_1 += 1.0 if rank == 1 else 0.0 + search_recall_at_5 += 1.0 if rank > 0 else 0.0 + search_precision_at_5 += relevant_hits / float(max(1, min(5, len(result.hits)))) + search_mrr += 1.0 / float(rank) if rank > 0 else 0.0 + search_keyword_recall += keyword_recall + search_case_reports.append( + { + "query": case["query"], + "rank_of_first_relevant": rank, + "relevant_hits_top5": relevant_hits, + "keyword_recall_top5": keyword_recall, + "top_hit": result.hits[0].to_dict() if result.hits else None, + } + ) + search_total = max(1, len(dataset["search_cases"])) + + writeback_reports: List[Dict[str, Any]] = [] + writeback_success_rate = 0.0 + writeback_keyword_recall = 0.0 + for payload in dataset["person_writebacks"]: + query = " ".join(payload["expected_keywords"]) + result = await memory_service.search( + query, + mode="search", + chat_id=session_id, + person_id=payload["person_id"], + respect_filter=False, + limit=5, + ) + joined = _join_hit_content(result) + recall = _keyword_recall(joined, payload["expected_keywords"]) + success = bool(result.hits) and recall >= 0.67 + writeback_success_rate += 1.0 if success else 0.0 + writeback_keyword_recall += recall + writeback_reports.append( + { + "person_id": payload["person_id"], + "success": success, + "keyword_recall": recall, + "hit_count": len(result.hits), + } + ) + writeback_total = max(1, len(dataset["person_writebacks"])) + + knowledge_reports: List[Dict[str, Any]] = [] + knowledge_success_rate = 0.0 + knowledge_keyword_recall = 0.0 + fetcher = knowledge_module.KnowledgeFetcher( + private_name=dataset["session"]["display_name"], + stream_id=session_id, + ) + for case in dataset["knowledge_fetcher_cases"]: + knowledge_text, _ = await fetcher.fetch(case["query"], []) + recall = _keyword_recall(knowledge_text, case["expected_keywords"]) + success = recall >= float(case.get("minimum_keyword_recall", 1.0)) + knowledge_success_rate += 1.0 if success else 0.0 + knowledge_keyword_recall += recall + knowledge_reports.append( + { + "query": case["query"], + "success": success, + "keyword_recall": recall, + "preview": knowledge_text[:300], + } + ) + knowledge_total = max(1, len(dataset["knowledge_fetcher_cases"])) + + profile_reports: List[Dict[str, Any]] = [] + profile_success_rate = 0.0 + profile_keyword_recall = 0.0 + profile_evidence_rate = 0.0 + for case in dataset["profile_cases"]: + profile = await memory_service.get_person_profile(case["person_id"], chat_id=session_id) + recall = _keyword_recall(profile.summary, case["expected_keywords"]) + has_evidence = bool(profile.evidence) + success = recall >= float(case.get("minimum_keyword_recall", 1.0)) and has_evidence + profile_success_rate += 1.0 if success else 0.0 + profile_keyword_recall += recall + profile_evidence_rate += 1.0 if has_evidence else 0.0 + profile_reports.append( 
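+            # Descriptive note (added): per-person entry for the live benchmark report —
+            # keyword recall, evidence count, and a truncated summary preview.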
+ { + "person_id": case["person_id"], + "success": success, + "keyword_recall": recall, + "evidence_count": len(profile.evidence), + "summary_preview": profile.summary[:240], + } + ) + profile_total = max(1, len(dataset["profile_cases"])) + + episode_generation_auto = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_admin_query_auto = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_search_mode_auto = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_rebuild = await memory_service.episode_admin( + action="rebuild", + source=f"chat_summary:{session_id}", + ) + episode_generation_after_rebuild = await _evaluate_episode_generation(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_admin_query_after_rebuild = await _evaluate_episode_admin_query(session_id=session_id, episode_cases=dataset["episode_cases"]) + episode_search_mode_after_rebuild = await _evaluate_episode_search_mode(session_id=session_id, episode_cases=dataset["episode_cases"]) + tool_modes = await _evaluate_tool_modes(session_id=session_id, dataset=dataset) + + report = { + "dataset": dataset["meta"], + "runtime_self_check": self_check["report"], + "import": { + "task_id": created["task"]["task_id"], + "status": import_detail["task"]["status"], + "paragraph_count": len(dataset["import_payload"]["paragraphs"]), + }, + "metrics": { + "search": { + "accuracy_at_1": round(search_accuracy_at_1 / search_total, 4), + "recall_at_5": round(search_recall_at_5 / search_total, 4), + "precision_at_5": round(search_precision_at_5 / search_total, 4), + "mrr": round(search_mrr / search_total, 4), + "keyword_recall_at_5": round(search_keyword_recall / search_total, 4), + }, + "writeback": { + "success_rate": round(writeback_success_rate / writeback_total, 4), + "keyword_recall": round(writeback_keyword_recall / writeback_total, 4), + }, + "knowledge_fetcher": { + "success_rate": round(knowledge_success_rate / knowledge_total, 4), + "keyword_recall": round(knowledge_keyword_recall / knowledge_total, 4), + }, + "profile": { + "success_rate": round(profile_success_rate / profile_total, 4), + "keyword_recall": round(profile_keyword_recall / profile_total, 4), + "evidence_rate": round(profile_evidence_rate / profile_total, 4), + }, + "tool_modes": { + "success_rate": tool_modes["success_rate"], + "keyword_recall": tool_modes["keyword_recall"], + }, + "episode_generation_auto": { + "success_rate": episode_generation_auto["success_rate"], + "keyword_recall": episode_generation_auto["keyword_recall"], + "episode_count": episode_generation_auto["episode_count"], + }, + "episode_generation_after_rebuild": { + "success_rate": episode_generation_after_rebuild["success_rate"], + "keyword_recall": episode_generation_after_rebuild["keyword_recall"], + "episode_count": episode_generation_after_rebuild["episode_count"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + "episode_admin_query_auto": { + "success_rate": episode_admin_query_auto["success_rate"], + "keyword_recall": episode_admin_query_auto["keyword_recall"], + }, + "episode_admin_query_after_rebuild": { + "success_rate": episode_admin_query_after_rebuild["success_rate"], + "keyword_recall": episode_admin_query_after_rebuild["keyword_recall"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + "episode_search_mode_auto": { + "success_rate": 
episode_search_mode_auto["success_rate"], + "keyword_recall": episode_search_mode_auto["keyword_recall"], + }, + "episode_search_mode_after_rebuild": { + "success_rate": episode_search_mode_after_rebuild["success_rate"], + "keyword_recall": episode_search_mode_after_rebuild["keyword_recall"], + "rebuild_success": bool(episode_rebuild.get("success", False)), + }, + }, + "cases": { + "search": search_case_reports, + "writeback": writeback_reports, + "knowledge_fetcher": knowledge_reports, + "profile": profile_reports, + "tool_modes": tool_modes["reports"], + "episode_generation_auto": episode_generation_auto["reports"], + "episode_generation_after_rebuild": episode_generation_after_rebuild["reports"], + "episode_admin_query_auto": episode_admin_query_auto["reports"], + "episode_admin_query_after_rebuild": episode_admin_query_after_rebuild["reports"], + "episode_search_mode_auto": episode_search_mode_auto["reports"], + "episode_search_mode_after_rebuild": episode_search_mode_after_rebuild["reports"], + }, + } + + REPORT_FILE.parent.mkdir(parents=True, exist_ok=True) + REPORT_FILE.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8") + print(json.dumps(report["metrics"], ensure_ascii=False, indent=2)) + + assert report["import"]["status"] == "completed" + assert report["runtime_self_check"]["ok"] is True diff --git a/pytests/A_memorix_test/test_memory_flow_service.py b/pytests/A_memorix_test/test_memory_flow_service.py new file mode 100644 index 00000000..2d35e837 --- /dev/null +++ b/pytests/A_memorix_test/test_memory_flow_service.py @@ -0,0 +1,138 @@ +from types import SimpleNamespace + +import pytest + +from src.services import memory_flow_service as memory_flow_module + + +@pytest.mark.asyncio +async def test_long_term_memory_session_manager_reuses_single_summarizer(monkeypatch): + starts: list[str] = [] + summarizers: list[object] = [] + + class FakeSummarizer: + def __init__(self, session_id: str): + self.session_id = session_id + summarizers.append(self) + + async def start(self): + starts.append(self.session_id) + + async def stop(self): + starts.append(f"stop:{self.session_id}") + + monkeypatch.setattr( + memory_flow_module, + "global_config", + SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), + ) + monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) + + manager = memory_flow_module.LongTermMemorySessionManager() + message = SimpleNamespace(session_id="session-1") + + await manager.on_message(message) + await manager.on_message(message) + + assert len(summarizers) == 1 + assert starts == ["session-1"] + + +@pytest.mark.asyncio +async def test_long_term_memory_session_manager_shutdown_stops_all(monkeypatch): + stopped: list[str] = [] + + class FakeSummarizer: + def __init__(self, session_id: str): + self.session_id = session_id + + async def start(self): + return None + + async def stop(self): + stopped.append(self.session_id) + + monkeypatch.setattr( + memory_flow_module, + "global_config", + SimpleNamespace(memory=SimpleNamespace(long_term_auto_summary_enabled=True)), + ) + monkeypatch.setattr(memory_flow_module, "ChatHistorySummarizer", FakeSummarizer) + + manager = memory_flow_module.LongTermMemorySessionManager() + await manager.on_message(SimpleNamespace(session_id="session-a")) + await manager.on_message(SimpleNamespace(session_id="session-b")) + await manager.shutdown() + + assert stopped == ["session-a", "session-b"] + + +def 
test_person_fact_parse_fact_list_deduplicates_and_filters_short_items(): + raw = '["他喜欢猫", "他喜欢猫", "好", "", "他会弹吉他"]' + + result = memory_flow_module.PersonFactWritebackService._parse_fact_list(raw) + + assert result == ["他喜欢猫", "他会弹吉他"] + + +def test_person_fact_looks_ephemeral_detects_short_chitchat(): + assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("哈哈") + assert memory_flow_module.PersonFactWritebackService._looks_ephemeral("好的?") + assert not memory_flow_module.PersonFactWritebackService._looks_ephemeral("她最近在学法语和钢琴") + + +def test_person_fact_resolve_target_person_for_private_chat(monkeypatch): + class FakePerson: + def __init__(self, person_id: str): + self.person_id = person_id + self.is_known = True + + service = memory_flow_module.PersonFactWritebackService.__new__(memory_flow_module.PersonFactWritebackService) + monkeypatch.setattr(memory_flow_module, "is_bot_self", lambda platform, user_id: False) + monkeypatch.setattr(memory_flow_module, "get_person_id", lambda platform, user_id: f"{platform}:{user_id}") + monkeypatch.setattr(memory_flow_module, "Person", FakePerson) + + message = SimpleNamespace(session=SimpleNamespace(platform="qq", user_id="123", group_id="")) + + person = service._resolve_target_person(message) + + assert person is not None + assert person.person_id == "qq:123" + + +@pytest.mark.asyncio +async def test_memory_automation_service_auto_starts_and_delegates(monkeypatch): + events: list[tuple[str, str]] = [] + + class FakeSessionManager: + async def on_message(self, message): + events.append(("incoming", message.session_id)) + + async def shutdown(self): + events.append(("shutdown", "session")) + + class FakeFactWriteback: + async def start(self): + events.append(("start", "fact")) + + async def enqueue(self, message): + events.append(("sent", message.session_id)) + + async def shutdown(self): + events.append(("shutdown", "fact")) + + service = memory_flow_module.MemoryAutomationService() + service.session_manager = FakeSessionManager() + service.fact_writeback = FakeFactWriteback() + + await service.on_incoming_message(SimpleNamespace(session_id="session-1")) + await service.on_message_sent(SimpleNamespace(session_id="session-1")) + await service.shutdown() + + assert events == [ + ("start", "fact"), + ("incoming", "session-1"), + ("sent", "session-1"), + ("shutdown", "session"), + ("shutdown", "fact"), + ] diff --git a/pytests/A_memorix_test/test_memory_service.py b/pytests/A_memorix_test/test_memory_service.py new file mode 100644 index 00000000..bac85afc --- /dev/null +++ b/pytests/A_memorix_test/test_memory_service.py @@ -0,0 +1,281 @@ +import pytest + +from src.services.memory_service import MemorySearchResult, MemoryService + + +def test_coerce_write_result_treats_skipped_payload_as_success(): + result = MemoryService._coerce_write_result({"skipped_ids": ["p1"], "detail": "chat_filtered"}) + + assert result.success is True + assert result.stored_ids == [] + assert result.skipped_ids == ["p1"] + assert result.detail == "chat_filtered" + + +@pytest.mark.asyncio +async def test_graph_admin_invokes_plugin(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "nodes": [], "edges": []} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.graph_admin(action="get_graph", limit=12) + + assert result["success"] is True + assert calls == [("memory_graph_admin", {"action": 
"get_graph", "limit": 12}, {"timeout_ms": 30000})] + + +@pytest.mark.asyncio +async def test_get_recycle_bin_uses_maintain_memory_tool(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"success": True, "items": [{"hash": "abc"}], "count": 1} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.get_recycle_bin(limit=5) + + assert result == {"success": True, "items": [{"hash": "abc"}], "count": 1} + assert calls == [("maintain_memory", {"action": "recycle_bin", "limit": 5})] + + +@pytest.mark.asyncio +async def test_search_respects_filter_by_default(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"summary": "ok", "hits": [], "filtered": True} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.search( + "mai", + chat_id="stream-1", + person_id="person-1", + user_id="user-1", + group_id="", + ) + + assert isinstance(result, MemorySearchResult) + assert result.filtered is True + assert calls == [ + ( + "search_memory", + { + "query": "mai", + "limit": 5, + "mode": "hybrid", + "chat_id": "stream-1", + "person_id": "person-1", + "time_start": None, + "time_end": None, + "respect_filter": True, + "user_id": "user-1", + "group_id": "", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_ingest_summary_can_bypass_filter(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"success": True, "stored_ids": ["p1"], "detail": ""} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.ingest_summary( + external_id="chat_history:1", + chat_id="stream-1", + text="summary", + respect_filter=False, + user_id="user-1", + ) + + assert result.success is True + assert calls == [ + ( + "ingest_summary", + { + "external_id": "chat_history:1", + "chat_id": "stream-1", + "text": "summary", + "participants": [], + "time_start": None, + "time_end": None, + "tags": [], + "metadata": {}, + "respect_filter": False, + "user_id": "user-1", + "group_id": "", + }, + ) + ] + + +@pytest.mark.asyncio +async def test_v5_admin_invokes_plugin(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "count": 1} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.v5_admin(action="status", target="mai", limit=5) + + assert result["success"] is True + assert calls == [("memory_v5_admin", {"action": "status", "target": "mai", "limit": 5}, {"timeout_ms": 30000})] + + +@pytest.mark.asyncio +async def test_delete_admin_uses_long_timeout(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "operation_id": "del-1"} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.delete_admin(action="execute", mode="relation", selector={"query": "mai"}) + + assert result["success"] is True + assert calls == [ + ( + "memory_delete_admin", + {"action": "execute", "mode": "relation", "selector": {"query": "mai"}}, + {"timeout_ms": 120000}, + ) + ] + + +@pytest.mark.asyncio +async def 
test_search_returns_empty_when_query_and_time_missing_async(): + service = MemoryService() + + result = await service.search("", time_start=None, time_end=None) + + assert isinstance(result, MemorySearchResult) + assert result.summary == "" + assert result.hits == [] + assert result.filtered is False + + +@pytest.mark.asyncio +async def test_search_accepts_string_time_bounds(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args)) + return {"summary": "ok", "hits": [], "filtered": False} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.search( + "广播站", + mode="time", + time_start="2026/03/18", + time_end="2026/03/18 09:30", + ) + + assert isinstance(result, MemorySearchResult) + assert calls == [ + ( + "search_memory", + { + "query": "广播站", + "limit": 5, + "mode": "time", + "chat_id": "", + "person_id": "", + "time_start": "2026/03/18", + "time_end": "2026/03/18 09:30", + "respect_filter": True, + "user_id": "", + "group_id": "", + }, + ) + ] + + +def test_coerce_search_result_preserves_aggregate_source_branches(): + result = MemoryService._coerce_search_result( + { + "hits": [ + { + "content": "广播站值夜班", + "type": "paragraph", + "metadata": {"event_time_start": 1.0}, + "source_branches": ["search", "time"], + "rank": 1, + } + ] + } + ) + + assert result.hits[0].metadata["source_branches"] == ["search", "time"] + assert result.hits[0].metadata["rank"] == 1 + + +@pytest.mark.asyncio +async def test_import_admin_uses_long_timeout(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "task_id": "import-1"} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.import_admin(action="create_lpmm_openie", alias="lpmm") + + assert result["success"] is True + assert calls == [ + ( + "memory_import_admin", + {"action": "create_lpmm_openie", "alias": "lpmm"}, + {"timeout_ms": 120000}, + ) + ] + + +@pytest.mark.asyncio +async def test_tuning_admin_uses_long_timeout(monkeypatch): + service = MemoryService() + calls = [] + + async def fake_invoke(component_name, args=None, **kwargs): + calls.append((component_name, args, kwargs)) + return {"success": True, "task_id": "tuning-1"} + + monkeypatch.setattr(service, "_invoke", fake_invoke) + + result = await service.tuning_admin(action="create_task", payload={"query": "mai"}) + + assert result["success"] is True + assert calls == [ + ( + "memory_tuning_admin", + {"action": "create_task", "payload": {"query": "mai"}}, + {"timeout_ms": 120000}, + ) + ] diff --git a/pytests/A_memorix_test/test_person_memory_writeback.py b/pytests/A_memorix_test/test_person_memory_writeback.py new file mode 100644 index 00000000..f177405a --- /dev/null +++ b/pytests/A_memorix_test/test_person_memory_writeback.py @@ -0,0 +1,81 @@ +from types import SimpleNamespace + +import pytest + +from src.person_info import person_info as person_info_module + + +@pytest.mark.asyncio +async def test_store_person_memory_from_answer_writes_person_fact(monkeypatch): + calls = [] + + class FakePerson: + def __init__(self, person_id: str): + self.person_id = person_id + self.person_name = "Alice" + self.is_known = True + + async def fake_ingest_text(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) + + session = SimpleNamespace(platform="qq", 
user_id="10001", group_id="", session_id="session-1") + monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session)) + monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1") + monkeypatch.setattr(person_info_module, "Person", FakePerson) + monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) + + await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1") + + assert len(calls) == 1 + payload = calls[0] + assert payload["external_id"].startswith("person_fact:person-1:") + assert payload["source_type"] == "person_fact" + assert payload["chat_id"] == "session-1" + assert payload["person_ids"] == ["person-1"] + assert payload["participants"] == ["Alice"] + assert payload["respect_filter"] is True + assert payload["user_id"] == "10001" + assert payload["group_id"] == "" + assert payload["metadata"]["person_id"] == "person-1" + + +@pytest.mark.asyncio +async def test_store_person_memory_from_answer_skips_unknown_person(monkeypatch): + calls = [] + + class FakePerson: + def __init__(self, person_id: str): + self.person_id = person_id + self.person_name = "Unknown" + self.is_known = False + + async def fake_ingest_text(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) + + session = SimpleNamespace(platform="qq", user_id="10001", group_id="", session_id="session-1") + monkeypatch.setattr(person_info_module, "_chat_manager", SimpleNamespace(get_session_by_session_id=lambda chat_id: session)) + monkeypatch.setattr(person_info_module, "get_person_id_by_person_name", lambda person_name: "person-1") + monkeypatch.setattr(person_info_module, "Person", FakePerson) + monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) + + await person_info_module.store_person_memory_from_answer("Alice", "她喜欢猫和爵士乐", "session-1") + + assert calls == [] + + +@pytest.mark.asyncio +async def test_store_person_memory_from_answer_skips_empty_content(monkeypatch): + calls = [] + + async def fake_ingest_text(**kwargs): + calls.append(kwargs) + return SimpleNamespace(success=True, detail="", stored_ids=["p1"]) + + monkeypatch.setattr(person_info_module.memory_service, "ingest_text", fake_ingest_text) + + await person_info_module.store_person_memory_from_answer("Alice", " ", "session-1") + + assert calls == [] + diff --git a/pytests/A_memorix_test/test_query_long_term_memory_tool.py b/pytests/A_memorix_test/test_query_long_term_memory_tool.py new file mode 100644 index 00000000..23310e1f --- /dev/null +++ b/pytests/A_memorix_test/test_query_long_term_memory_tool.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from src.memory_system.retrieval_tools import query_long_term_memory as tool_module +from src.memory_system.retrieval_tools import init_all_tools +from src.memory_system.retrieval_tools.query_long_term_memory import ( + _resolve_time_expression, + query_long_term_memory, + register_tool, +) +from src.memory_system.retrieval_tools.tool_registry import get_tool_registry +from src.services.memory_service import MemoryHit, MemorySearchResult + + +def test_resolve_time_expression_supports_relative_and_absolute_inputs(): + now = datetime(2026, 3, 18, 15, 30) + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("今天", now=now) + assert 
datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) + assert start_text == "2026/03/18 00:00" + assert end_text == "2026/03/18 23:59" + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("最近7天", now=now) + assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 12, 0, 0) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) + assert start_text == "2026/03/12 00:00" + assert end_text == "2026/03/18 23:59" + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18", now=now) + assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 0, 0) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 23, 59) + assert start_text == "2026/03/18 00:00" + assert end_text == "2026/03/18 23:59" + + start_ts, end_ts, start_text, end_text = _resolve_time_expression("2026/03/18 09:30", now=now) + assert datetime.fromtimestamp(start_ts) == datetime(2026, 3, 18, 9, 30) + assert datetime.fromtimestamp(end_ts) == datetime(2026, 3, 18, 9, 30) + assert start_text == "2026/03/18 09:30" + assert end_text == "2026/03/18 09:30" + + +def test_register_tool_exposes_mode_and_time_expression(): + register_tool() + tool = get_tool_registry().get_tool("search_long_term_memory") + + assert tool is not None + params = {item["name"]: item for item in tool.parameters} + assert "mode" in params + assert params["mode"]["enum"] == ["search", "time", "episode", "aggregate"] + assert "time_expression" in params + assert params["query"]["required"] is False + + +def test_init_all_tools_registers_long_term_memory_tool(): + init_all_tools() + + tool = get_tool_registry().get_tool("search_long_term_memory") + assert tool is not None + + +@pytest.mark.asyncio +async def test_query_long_term_memory_search_mode_maps_to_hybrid(monkeypatch): + captured = {} + + async def fake_search(query, **kwargs): + captured["query"] = query + captured["kwargs"] = kwargs + return MemorySearchResult( + hits=[MemoryHit(content="Alice 喜欢猫", score=0.9, hit_type="paragraph")], + ) + + monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) + + text = await query_long_term_memory("Alice 喜欢什么", chat_id="stream-1", person_id="person-1") + + assert "Alice 喜欢猫" in text + assert captured == { + "query": "Alice 喜欢什么", + "kwargs": { + "limit": 5, + "mode": "hybrid", + "chat_id": "stream-1", + "person_id": "person-1", + "time_start": None, + "time_end": None, + }, + } + + +@pytest.mark.asyncio +async def test_query_long_term_memory_time_mode_parses_expression(monkeypatch): + captured = {} + + async def fake_search(query, **kwargs): + captured["query"] = query + captured["kwargs"] = kwargs + return MemorySearchResult( + hits=[ + MemoryHit( + content="昨天晚上广播站停播了十分钟。", + score=0.8, + hit_type="paragraph", + metadata={"event_time_start": 1773797400.0}, + ) + ] + ) + + monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) + monkeypatch.setattr( + tool_module, + "_resolve_time_expression", + lambda expression, now=None: (1773795600.0, 1773881940.0, "2026/03/17 00:00", "2026/03/17 23:59"), + ) + + text = await query_long_term_memory( + query="广播站", + mode="time", + time_expression="昨天", + chat_id="stream-1", + ) + + assert "指定时间范围" in text + assert "广播站停播" in text + assert captured == { + "query": "广播站", + "kwargs": { + "limit": 5, + "mode": "time", + "chat_id": "stream-1", + "person_id": "", + "time_start": 1773795600.0, + "time_end": 
1773881940.0, + }, + } + + +@pytest.mark.asyncio +async def test_query_long_term_memory_episode_and_aggregate_format_output(monkeypatch): + responses = { + "episode": MemorySearchResult( + hits=[ + MemoryHit( + content="苏弦在灯塔拆开了那封冬信。", + title="冬信重见天日", + hit_type="episode", + metadata={"participants": ["苏弦"], "keywords": ["冬信", "灯塔"]}, + ) + ] + ), + "aggregate": MemorySearchResult( + hits=[ + MemoryHit( + content="唐未在广播站值夜班时带着黑狗墨点。", + hit_type="paragraph", + metadata={"source_branches": ["search", "time"]}, + ) + ] + ), + } + + async def fake_search(query, **kwargs): + return responses[kwargs["mode"]] + + monkeypatch.setattr(tool_module, "memory_service", SimpleNamespace(search=fake_search)) + + episode_text = await query_long_term_memory("那封冬信后来怎么样了", mode="episode") + aggregate_text = await query_long_term_memory("唐未最近有什么线索", mode="aggregate") + + assert "事件《冬信重见天日》" in episode_text + assert "参与者:苏弦" in episode_text + assert "[search,time][paragraph]" in aggregate_text + + +@pytest.mark.asyncio +async def test_query_long_term_memory_invalid_time_expression_returns_retryable_message(): + text = await query_long_term_memory(query="广播站", mode="time", time_expression="明年春分后第三周") + + assert "无法解析" in text + assert "最近7天" in text diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py new file mode 100644 index 00000000..71d94a7b --- /dev/null +++ b/pytests/A_memorix_test/test_real_dialogue_business_flow_integration.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict + +import numpy as np +import pytest +import pytest_asyncio + +from A_memorix.core.runtime import sdk_memory_kernel as kernel_module +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from src.memory_system import chat_history_summarizer as summarizer_module +from src.person_info import person_info as person_info_module +from src.services import memory_service as memory_service_module +from src.services.memory_service import memory_service + + +DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" + + +def _load_dialogue_fixture() -> Dict[str, Any]: + return json.loads(DATA_FILE.read_text(encoding="utf-8")) + + +class _FakeEmbeddingAdapter: + def __init__(self, dimension: int = 16) -> None: + self.dimension = dimension + + async def _detect_dimension(self) -> int: + return self.dimension + + async def encode(self, texts, dimensions=None): + dim = int(dimensions or self.dimension) + if isinstance(texts, str): + sequence = [texts] + single = True + else: + sequence = list(texts) + single = False + + rows = [] + for text in sequence: + vec = np.zeros(dim, dtype=np.float32) + for ch in str(text or ""): + vec[ord(ch) % dim] += 1.0 + if not vec.any(): + vec[0] = 1.0 + norm = np.linalg.norm(vec) + if norm > 0: + vec = vec / norm + rows.append(vec) + payload = np.vstack(rows) + return payload[0] if single else payload + + +class _KernelBackedRuntimeManager: + is_running = True + + def __init__(self, kernel: SDKMemoryKernel) -> None: + self.kernel = kernel + + async def invoke_plugin( + self, + *, + method: str, + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None, + timeout_ms: int, + ): + del method, plugin_id, timeout_ms 
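+        # 测试用的假运行时管理器:search_memory 走类型化的 KernelSearchRequest,其余 component_name 直接映射为内核上的同名方法(必要时 await)。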
+ payload = args or {} + if component_name == "search_memory": + return await self.kernel.search_memory( + KernelSearchRequest( + query=str(payload.get("query", "") or ""), + limit=int(payload.get("limit", 5) or 5), + mode=str(payload.get("mode", "hybrid") or "hybrid"), + chat_id=str(payload.get("chat_id", "") or ""), + person_id=str(payload.get("person_id", "") or ""), + time_start=payload.get("time_start"), + time_end=payload.get("time_end"), + respect_filter=bool(payload.get("respect_filter", True)), + user_id=str(payload.get("user_id", "") or ""), + group_id=str(payload.get("group_id", "") or ""), + ) + ) + + handler = getattr(self.kernel, component_name) + result = handler(**payload) + return await result if inspect.isawaitable(result) else result + + +async def _wait_for_import_task(task_id: str, *, max_rounds: int = 100) -> Dict[str, Any]: + for _ in range(max_rounds): + detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) + task = detail.get("task") or {} + status = str(task.get("status", "") or "") + if status in {"completed", "completed_with_errors", "failed", "cancelled"}: + return detail + await asyncio.sleep(0.05) + raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") + + +def _join_hit_content(search_result) -> str: + return "\n".join(hit.content for hit in search_result.hits) + + +@pytest_asyncio.fixture +async def real_dialogue_env(monkeypatch, tmp_path): + scenario = _load_dialogue_fixture() + session_cfg = scenario["session"] + session = SimpleNamespace( + session_id=session_cfg["session_id"], + platform=session_cfg["platform"], + user_id=session_cfg["user_id"], + group_id=session_cfg["group_id"], + ) + fake_chat_manager = SimpleNamespace( + get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, + get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, + ) + + monkeypatch.setattr(kernel_module, "create_embedding_api_adapter", lambda **kwargs: _FakeEmbeddingAdapter()) + + async def fake_self_check(**kwargs): + return {"ok": True, "message": "ok"} + + monkeypatch.setattr(kernel_module, "run_embedding_runtime_self_check", fake_self_check) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None) + monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) + + data_dir = (tmp_path / "a_memorix_data").resolve() + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str(data_dir)}, + "advanced": {"enable_auto_save": False}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + }, + ) + manager = _KernelBackedRuntimeManager(kernel) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager) + + await kernel.initialize() + try: + yield { + "scenario": scenario, + "kernel": kernel, + "session": session, + } + finally: + await kernel.shutdown() + + +@pytest.mark.asyncio +async def test_real_dialogue_import_flow_makes_fixture_searchable(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + + created = await memory_service.import_admin( + action="create_paste", + name="private_alice.json", + input_mode="json", + llm_enabled=False, + content=json.dumps(scenario["import_payload"], ensure_ascii=False), + ) + + assert 
created["success"] is True + detail = await _wait_for_import_task(created["task"]["task_id"]) + assert detail["task"]["status"] == "completed" + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + respect_filter=False, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_real_dialogue_summarizer_flow_persists_summary_to_long_term_memory(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=real_dialogue_env["session"].session_id, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_real_dialogue_person_fact_writeback_is_searchable(real_dialogue_env, monkeypatch): + scenario = real_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + real_dialogue_env["session"].session_id, + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=real_dialogue_env["session"].session_id, + person_id=scenario["person"]["person_id"], + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_real_dialogue_private_knowledge_fetcher_reads_long_term_memory(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + + await memory_service.ingest_text( + external_id="fixture:knowledge_fetcher", + source_type="dialogue_note", + text=scenario["person_fact"]["memory_content"], + chat_id=real_dialogue_env["session"].session_id, + person_ids=[scenario["person"]["person_id"]], + participants=[scenario["person"]["person_name"]], + respect_filter=False, + ) + + fetcher = knowledge_module.KnowledgeFetcher( + private_name=scenario["session"]["display_name"], + stream_id=real_dialogue_env["session"].session_id, + ) + knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) + + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in knowledge_text + + +@pytest.mark.asyncio +async def test_real_dialogue_person_profile_contains_stable_traits(real_dialogue_env, monkeypatch): + scenario = real_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + 
self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + real_dialogue_env["session"].session_id, + ) + + profile = await memory_service.get_person_profile( + scenario["person"]["person_id"], + chat_id=real_dialogue_env["session"].session_id, + ) + + assert profile.evidence + assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) + + +@pytest.mark.asyncio +async def test_real_dialogue_summary_flow_generates_queryable_episode(real_dialogue_env): + scenario = real_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(real_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + episodes = await memory_service.episode_admin( + action="query", + source=scenario["expectations"]["episode_source"], + limit=5, + ) + + assert episodes["success"] is True + assert int(episodes["count"]) >= 1 diff --git a/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py b/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py new file mode 100644 index 00000000..808d4c23 --- /dev/null +++ b/pytests/A_memorix_test/test_real_dialogue_business_flow_live.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import asyncio +import inspect +import json +import os +from pathlib import Path +from types import SimpleNamespace +from typing import Any, Dict + +import pytest +import pytest_asyncio + +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel +from src.chat.brain_chat.PFC import pfc_KnowledgeFetcher as knowledge_module +from src.memory_system import chat_history_summarizer as summarizer_module +from src.person_info import person_info as person_info_module +from src.services import memory_service as memory_service_module +from src.services.memory_service import memory_service + + +pytestmark = pytest.mark.skipif( + os.getenv("MAIBOT_RUN_LIVE_MEMORY_TESTS") != "1", + reason="需要显式开启真实 embedding / self-check 集成测试", +) + +DATA_FILE = Path(__file__).parent / "data" / "real_dialogues" / "private_alice_weekend.json" + + +def _load_dialogue_fixture() -> Dict[str, Any]: + return json.loads(DATA_FILE.read_text(encoding="utf-8")) + + +class _KernelBackedRuntimeManager: + is_running = True + + def __init__(self, kernel: SDKMemoryKernel) -> None: + self.kernel = kernel + + async def invoke_plugin( + self, + *, + method: str, + plugin_id: str, + component_name: str, + args: Dict[str, Any] | None, + timeout_ms: int, + ): + del method, plugin_id, timeout_ms + payload = args or {} + if component_name == "search_memory": + return await self.kernel.search_memory( + KernelSearchRequest( + query=str(payload.get("query", "") or ""), + limit=int(payload.get("limit", 5) or 5), + mode=str(payload.get("mode", "hybrid") or "hybrid"), + chat_id=str(payload.get("chat_id", "") or ""), + person_id=str(payload.get("person_id", "") or ""), + 
time_start=payload.get("time_start"), + time_end=payload.get("time_end"), + respect_filter=bool(payload.get("respect_filter", True)), + user_id=str(payload.get("user_id", "") or ""), + group_id=str(payload.get("group_id", "") or ""), + ) + ) + + handler = getattr(self.kernel, component_name) + result = handler(**payload) + return await result if inspect.isawaitable(result) else result + + +async def _wait_for_import_task(task_id: str, *, timeout_seconds: float = 60.0) -> Dict[str, Any]: + deadline = asyncio.get_running_loop().time() + max(1.0, float(timeout_seconds)) + while asyncio.get_running_loop().time() < deadline: + detail = await memory_service.import_admin(action="get", task_id=task_id, include_chunks=True) + task = detail.get("task") or {} + status = str(task.get("status", "") or "") + if status in {"completed", "completed_with_errors", "failed", "cancelled"}: + return detail + await asyncio.sleep(0.2) + raise AssertionError(f"导入任务在等待窗口内未结束: {task_id}") + + +def _join_hit_content(search_result) -> str: + return "\n".join(hit.content for hit in search_result.hits) + + +@pytest_asyncio.fixture +async def live_dialogue_env(monkeypatch, tmp_path): + scenario = _load_dialogue_fixture() + session_cfg = scenario["session"] + session = SimpleNamespace( + session_id=session_cfg["session_id"], + platform=session_cfg["platform"], + user_id=session_cfg["user_id"], + group_id=session_cfg["group_id"], + ) + fake_chat_manager = SimpleNamespace( + get_session_by_session_id=lambda session_id: session if session_id == session.session_id else None, + get_session_name=lambda session_id: session_cfg["display_name"] if session_id == session.session_id else session_id, + ) + + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", None) + monkeypatch.setattr(summarizer_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(knowledge_module, "_chat_manager", fake_chat_manager) + monkeypatch.setattr(person_info_module, "_chat_manager", fake_chat_manager) + + data_dir = (tmp_path / "a_memorix_data").resolve() + kernel = SDKMemoryKernel( + plugin_root=tmp_path / "plugin_root", + config={ + "storage": {"data_dir": str(data_dir)}, + "advanced": {"enable_auto_save": False}, + "memory": {"base_decay_interval_hours": 24}, + "person_profile": {"refresh_interval_minutes": 5}, + }, + ) + manager = _KernelBackedRuntimeManager(kernel) + monkeypatch.setattr(memory_service_module, "get_plugin_runtime_manager", lambda: manager) + + await kernel.initialize() + try: + yield { + "scenario": scenario, + "kernel": kernel, + "session": session, + } + finally: + await kernel.shutdown() + + +@pytest.mark.asyncio +async def test_live_runtime_self_check_passes(live_dialogue_env): + report = await memory_service.runtime_admin(action="refresh_self_check") + + assert report["success"] is True + assert report["report"]["ok"] is True + assert report["report"]["encoded_dimension"] > 0 + + +@pytest.mark.asyncio +async def test_live_import_flow_makes_fixture_searchable(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + + created = await memory_service.import_admin( + action="create_paste", + name="private_alice.json", + input_mode="json", + llm_enabled=False, + content=json.dumps(scenario["import_payload"], ensure_ascii=False), + ) + + assert created["success"] is True + detail = await _wait_for_import_task(created["task"]["task_id"]) + assert detail["task"]["status"] == "completed" + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + 
respect_filter=False, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_live_summarizer_flow_persists_summary_to_long_term_memory(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=live_dialogue_env["session"].session_id, + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_live_person_fact_writeback_is_searchable(live_dialogue_env, monkeypatch): + scenario = live_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + live_dialogue_env["session"].session_id, + ) + + search = await memory_service.search( + scenario["search_queries"]["direct"], + mode="search", + chat_id=live_dialogue_env["session"].session_id, + person_id=scenario["person"]["person_id"], + ) + + assert search.hits + joined = _join_hit_content(search) + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in joined + + +@pytest.mark.asyncio +async def test_live_private_knowledge_fetcher_reads_long_term_memory(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + + await memory_service.ingest_text( + external_id="fixture:knowledge_fetcher", + source_type="dialogue_note", + text=scenario["person_fact"]["memory_content"], + chat_id=live_dialogue_env["session"].session_id, + person_ids=[scenario["person"]["person_id"]], + participants=[scenario["person"]["person_name"]], + respect_filter=False, + ) + + fetcher = knowledge_module.KnowledgeFetcher( + private_name=scenario["session"]["display_name"], + stream_id=live_dialogue_env["session"].session_id, + ) + knowledge_text, _ = await fetcher.fetch(scenario["search_queries"]["knowledge_fetcher"], []) + + for keyword in scenario["expectations"]["search_keywords"]: + assert keyword in knowledge_text + + +@pytest.mark.asyncio +async def test_live_person_profile_contains_stable_traits(live_dialogue_env, monkeypatch): + scenario = live_dialogue_env["scenario"] + + class _KnownPerson: + def __init__(self, person_id: str) -> None: + self.person_id = person_id + self.is_known = True + self.person_name = scenario["person"]["person_name"] + + monkeypatch.setattr( + person_info_module, + "get_person_id_by_person_name", + lambda person_name: scenario["person"]["person_id"], + ) + monkeypatch.setattr(person_info_module, "Person", _KnownPerson) + + await 
person_info_module.store_person_memory_from_answer( + scenario["person"]["person_name"], + scenario["person_fact"]["memory_content"], + live_dialogue_env["session"].session_id, + ) + + profile = await memory_service.get_person_profile( + scenario["person"]["person_id"], + chat_id=live_dialogue_env["session"].session_id, + ) + + assert profile.evidence + assert any(keyword in profile.summary for keyword in scenario["expectations"]["profile_keywords"]) + + +@pytest.mark.asyncio +async def test_live_summary_flow_generates_queryable_episode(live_dialogue_env): + scenario = live_dialogue_env["scenario"] + record = scenario["chat_history_record"] + + summarizer = summarizer_module.ChatHistorySummarizer(live_dialogue_env["session"].session_id) + await summarizer._import_to_long_term_memory( + record_id=record["record_id"], + theme=record["theme"], + summary=record["summary"], + participants=record["participants"], + start_time=record["start_time"], + end_time=record["end_time"], + original_text=record["original_text"], + ) + + episodes = await memory_service.episode_admin( + action="query", + source=scenario["expectations"]["episode_source"], + limit=5, + ) + + assert episodes["success"] is True + assert int(episodes["count"]) >= 1 diff --git a/pytests/webui/test_memory_routes.py b/pytests/webui/test_memory_routes.py new file mode 100644 index 00000000..d66a8333 --- /dev/null +++ b/pytests/webui/test_memory_routes.py @@ -0,0 +1,279 @@ +from fastapi import FastAPI +from fastapi.testclient import TestClient +import pytest + +from src.services.memory_service import MemorySearchResult +from src.webui.dependencies import require_auth +from src.webui.routers import memory as memory_router_module +from src.webui.routers.memory import compat_router, router + + +@pytest.fixture +def client() -> TestClient: + app = FastAPI() + app.dependency_overrides[require_auth] = lambda: "ok" + app.include_router(router) + app.include_router(compat_router) + return TestClient(app) + + +def test_webui_memory_graph_route(client: TestClient, monkeypatch): + async def fake_graph_admin(*, action: str, **kwargs): + assert action == "get_graph" + return {"success": True, "nodes": [], "edges": [], "total_nodes": 0, "limit": kwargs.get("limit")} + + monkeypatch.setattr(memory_router_module.memory_service, "graph_admin", fake_graph_admin) + + response = client.get("/api/webui/memory/graph", params={"limit": 77}) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["limit"] == 77 + + +def test_compat_aggregate_route(client: TestClient, monkeypatch): + async def fake_search(query: str, **kwargs): + assert kwargs["mode"] == "aggregate" + assert kwargs["respect_filter"] is False + return MemorySearchResult(summary=f"summary:{query}", hits=[]) + + monkeypatch.setattr(memory_router_module.memory_service, "search", fake_search) + + response = client.get("/api/query/aggregate", params={"query": "mai"}) + + assert response.status_code == 200 + assert response.json() == {"success": True, "summary": "summary:mai", "hits": [], "filtered": False} + + +def test_auto_save_routes(client: TestClient, monkeypatch): + async def fake_runtime_admin(*, action: str, **kwargs): + if action == "get_config": + return {"success": True, "auto_save": True} + if action == "set_auto_save": + return {"success": True, "auto_save": kwargs["enabled"]} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "runtime_admin", fake_runtime_admin) + + get_response = 
client.get("/api/config/auto_save") + post_response = client.post("/api/config/auto_save", json={"enabled": False}) + + assert get_response.status_code == 200 + assert get_response.json() == {"success": True, "auto_save": True} + assert post_response.status_code == 200 + assert post_response.json() == {"success": True, "auto_save": False} + + +def test_recycle_bin_route(client: TestClient, monkeypatch): + async def fake_get_recycle_bin(*, limit: int): + return {"success": True, "items": [{"hash": "deadbeef"}], "count": 1, "limit": limit} + + monkeypatch.setattr(memory_router_module.memory_service, "get_recycle_bin", fake_get_recycle_bin) + + response = client.get("/api/memory/recycle_bin", params={"limit": 10}) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["count"] == 1 + assert response.json()["limit"] == 10 + + +def test_import_guide_route(client: TestClient, monkeypatch): + async def fake_import_admin(*, action: str, **kwargs): + assert kwargs == {} + if action == "get_guide": + return {"success": True} + if action == "get_settings": + return {"success": True, "settings": {"path_aliases": {"raw": "/tmp/raw"}}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) + + response = client.get("/api/webui/memory/import/guide") + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["source"] == "local" + assert "长期记忆导入说明" in response.json()["content"] + + +def test_import_upload_route(client: TestClient, monkeypatch, tmp_path): + monkeypatch.setattr(memory_router_module, "STAGING_ROOT", tmp_path) + + async def fake_import_admin(*, action: str, **kwargs): + assert action == "create_upload" + staged_files = kwargs["staged_files"] + assert len(staged_files) == 1 + assert staged_files[0]["filename"] == "demo.txt" + assert memory_router_module.Path(staged_files[0]["staged_path"]).exists() + return {"success": True, "task_id": "task-1"} + + monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) + + response = client.post( + "/api/import/upload", + data={"payload_json": "{\"source\": \"upload\"}"}, + files=[("files", ("demo.txt", b"hello world", "text/plain"))], + ) + + assert response.status_code == 200 + assert response.json() == {"success": True, "task_id": "task-1"} + assert list(tmp_path.iterdir()) == [] + + +def test_v5_status_route(client: TestClient, monkeypatch): + async def fake_v5_admin(*, action: str, **kwargs): + assert action == "status" + assert kwargs["target"] == "mai" + return {"success": True, "active_count": 1, "inactive_count": 2, "deleted_count": 3} + + monkeypatch.setattr(memory_router_module.memory_service, "v5_admin", fake_v5_admin) + + response = client.get("/api/webui/memory/v5/status", params={"target": "mai"}) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["deleted_count"] == 3 + + +def test_delete_preview_route(client: TestClient, monkeypatch): + async def fake_delete_admin(*, action: str, **kwargs): + assert action == "preview" + assert kwargs["mode"] == "paragraph" + assert kwargs["selector"] == {"query": "demo"} + return {"success": True, "counts": {"paragraphs": 1}, "dry_run": True} + + monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) + + response = client.post( + "/api/webui/memory/delete/preview", + json={"mode": "paragraph", "selector": 
{"query": "demo"}}, + ) + + assert response.status_code == 200 + assert response.json() == {"success": True, "counts": {"paragraphs": 1}, "dry_run": True} + + +def test_episode_process_pending_route(client: TestClient, monkeypatch): + async def fake_episode_admin(*, action: str, **kwargs): + assert action == "process_pending" + assert kwargs == {"limit": 7, "max_retry": 4} + return {"success": True, "processed": 3} + + monkeypatch.setattr(memory_router_module.memory_service, "episode_admin", fake_episode_admin) + + response = client.post("/api/webui/memory/episodes/process-pending", json={"limit": 7, "max_retry": 4}) + + assert response.status_code == 200 + assert response.json() == {"success": True, "processed": 3} + + +def test_import_list_route_includes_settings(client: TestClient, monkeypatch): + calls = [] + + async def fake_import_admin(*, action: str, **kwargs): + calls.append((action, kwargs)) + if action == "list": + return {"success": True, "items": [{"task_id": "task-1"}]} + if action == "get_settings": + return {"success": True, "settings": {"path_aliases": {"lpmm": "/tmp/lpmm"}}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "import_admin", fake_import_admin) + + response = client.get("/api/webui/memory/import/tasks", params={"limit": 9}) + + assert response.status_code == 200 + assert response.json()["items"] == [{"task_id": "task-1"}] + assert response.json()["settings"] == {"path_aliases": {"lpmm": "/tmp/lpmm"}} + assert calls == [("list", {"limit": 9}), ("get_settings", {})] + + +def test_tuning_profile_route_backfills_settings(client: TestClient, monkeypatch): + calls = [] + + async def fake_tuning_admin(*, action: str, **kwargs): + calls.append((action, kwargs)) + if action == "get_profile": + return {"success": True, "profile": {"retrieval": {"top_k": 8}}} + if action == "get_settings": + return {"success": True, "settings": {"profiles": ["default"]}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin) + + response = client.get("/api/webui/memory/retrieval_tuning/profile") + + assert response.status_code == 200 + assert response.json()["profile"] == {"retrieval": {"top_k": 8}} + assert response.json()["settings"] == {"profiles": ["default"]} + assert calls == [("get_profile", {}), ("get_settings", {})] + + +def test_tuning_report_route_flattens_report_payload(client: TestClient, monkeypatch): + async def fake_tuning_admin(*, action: str, **kwargs): + assert action == "get_report" + assert kwargs == {"task_id": "task-1", "format": "json"} + return { + "success": True, + "report": {"format": "json", "content": "{\"ok\": true}", "path": "/tmp/report.json"}, + } + + monkeypatch.setattr(memory_router_module.memory_service, "tuning_admin", fake_tuning_admin) + + response = client.get("/api/webui/memory/retrieval_tuning/tasks/task-1/report", params={"format": "json"}) + + assert response.status_code == 200 + assert response.json() == { + "success": True, + "format": "json", + "content": "{\"ok\": true}", + "path": "/tmp/report.json", + "error": "", + } + + +def test_delete_execute_route(client: TestClient, monkeypatch): + async def fake_delete_admin(*, action: str, **kwargs): + assert action == "execute" + assert kwargs["mode"] == "source" + assert kwargs["selector"] == {"source": "chat_summary:stream-1"} + assert kwargs["reason"] == "cleanup" + assert kwargs["requested_by"] == "tester" + return {"success": True, "operation_id": "del-1"} + + 
monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) + + response = client.post( + "/api/webui/memory/delete/execute", + json={ + "mode": "source", + "selector": {"source": "chat_summary:stream-1"}, + "reason": "cleanup", + "requested_by": "tester", + }, + ) + + assert response.status_code == 200 + assert response.json() == {"success": True, "operation_id": "del-1"} + + +def test_delete_operation_routes(client: TestClient, monkeypatch): + async def fake_delete_admin(*, action: str, **kwargs): + if action == "list_operations": + assert kwargs == {"limit": 5, "mode": "paragraph"} + return {"success": True, "items": [{"operation_id": "del-1"}], "count": 1} + if action == "get_operation": + assert kwargs == {"operation_id": "del-1"} + return {"success": True, "operation": {"operation_id": "del-1", "mode": "paragraph"}} + raise AssertionError(action) + + monkeypatch.setattr(memory_router_module.memory_service, "delete_admin", fake_delete_admin) + + list_response = client.get("/api/webui/memory/delete/operations", params={"limit": 5, "mode": "paragraph"}) + get_response = client.get("/api/webui/memory/delete/operations/del-1") + + assert list_response.status_code == 200 + assert list_response.json()["count"] == 1 + assert get_response.status_code == 200 + assert get_response.json()["operation"]["operation_id"] == "del-1" diff --git a/src/bw_learner/jargon_explainer_old.py b/src/bw_learner/jargon_explainer_old.py index 4d144b2c..94031b4a 100644 --- a/src/bw_learner/jargon_explainer_old.py +++ b/src/bw_learner/jargon_explainer_old.py @@ -7,7 +7,7 @@ from src.common.database.database_model import Jargon from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config from src.prompt.prompt_manager import prompt_manager -from src.bw_learner.jargon_miner_old import search_jargon +from src.bw_learner.jargon_explainer import search_jargon from src.bw_learner.learner_utils_old import ( is_bot_message, contains_bot_self_name, diff --git a/src/bw_learner/learner_utils_old.py b/src/bw_learner/learner_utils_old.py index 3f21c55d..6095ef48 100644 --- a/src/bw_learner/learner_utils_old.py +++ b/src/bw_learner/learner_utils_old.py @@ -196,6 +196,32 @@ def contains_bot_self_name(content: str) -> bool: return any(name in target for name in candidates) +def is_bot_message(msg: Any) -> bool: + """判断消息是否来自机器人自身。""" + if msg is None: + return False + + bot_config = getattr(global_config, "bot", None) + if not bot_config: + return False + + user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip() + if not user_id: + return False + + known_accounts = { + str(getattr(bot_config, "qq_account", "") or "").strip(), + str(getattr(bot_config, "telegram_account", "") or "").strip(), + } + + for platform in getattr(bot_config, "platforms", []) or []: + account = str(getattr(platform, "account", "") or getattr(platform, "id", "") or "").strip() + if account: + known_accounts.add(account) + + return user_id in {account for account in known_accounts if account} + + # def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: # """ # 构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出 diff --git a/src/chat/brain_chat/PFC/conversation.py b/src/chat/brain_chat/PFC/conversation.py index 1e1e89b1..ab5a7b3d 100644 --- a/src/chat/brain_chat/PFC/conversation.py +++ b/src/chat/brain_chat/PFC/conversation.py @@ -55,7 +55,7 @@ class Conversation: self.action_planner = 
ActionPlanner(self.stream_id, self.private_name) self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name) self.reply_generator = ReplyGenerator(self.stream_id, self.private_name) - self.knowledge_fetcher = KnowledgeFetcher(self.private_name) + self.knowledge_fetcher = KnowledgeFetcher(self.private_name, self.stream_id) self.waiter = Waiter(self.stream_id, self.private_name) self.direct_sender = DirectMessageSender(self.private_name) diff --git a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py index 67509bd5..4d47f609 100644 --- a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py +++ b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py @@ -1,11 +1,14 @@ -from typing import List, Tuple, Dict, Any +from typing import Any, Dict, List, Tuple + +from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.common.logger import get_logger # NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned # from src.plugins.memory_system.Hippocampus import HippocampusManager -from src.llm_models.utils_model import LLMRequest from src.config.config import model_config -from src.chat.knowledge import qa_manager +from src.llm_models.utils_model import LLMRequest +from src.person_info.person_info import resolve_person_id_for_memory +from src.services.memory_service import memory_service logger = get_logger("knowledge_fetcher") @@ -13,11 +16,39 @@ logger = get_logger("knowledge_fetcher") class KnowledgeFetcher: """知识调取器""" - def __init__(self, private_name: str): + def __init__(self, private_name: str, stream_id: str): self.llm = LLMRequest(model_set=model_config.model_task_config.utils) self.private_name = private_name + self.stream_id = stream_id - def _lpmm_get_knowledge(self, query: str) -> str: + def _resolve_private_memory_context(self) -> Dict[str, str]: + session = _chat_manager.get_session_by_session_id(self.stream_id) + if session is None: + return {"chat_id": self.stream_id} + + group_id = str(getattr(session, "group_id", "") or "").strip() + user_id = str(getattr(session, "user_id", "") or "").strip() + platform = str(getattr(session, "platform", "") or "").strip() + + person_id = "" + if not group_id: + try: + person_id = resolve_person_id_for_memory( + person_name=self.private_name, + platform=platform, + user_id=user_id, + ) + except Exception as exc: + logger.debug(f"[私聊][{self.private_name}]解析人物ID失败: {exc}") + + return { + "chat_id": self.stream_id, + "person_id": person_id, + "user_id": user_id, + "group_id": group_id, + } + + async def _memory_get_knowledge(self, query: str) -> str: """获取相关知识 Args: @@ -27,13 +58,32 @@ class KnowledgeFetcher: str: 构造好的,带相关度的知识 """ - logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识") + logger.debug(f"[私聊][{self.private_name}]正在从长期记忆中获取知识") try: - knowledge_info = qa_manager.get_knowledge(query) - logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}") - return knowledge_info + context = self._resolve_private_memory_context() + search_kwargs = { + "limit": 5, + "mode": "search", + "chat_id": context.get("chat_id", ""), + "person_id": context.get("person_id", ""), + "user_id": context.get("user_id", ""), + "group_id": context.get("group_id", ""), + "respect_filter": True, + } + result = await memory_service.search(query, **search_kwargs) + if not result.filtered and not result.hits and search_kwargs["person_id"]: + fallback_kwargs = dict(search_kwargs) + fallback_kwargs["person_id"] = "" + 
logger.debug(f"[私聊][{self.private_name}]人物过滤未命中,退回仅按会话检索长期记忆") + result = await memory_service.search(query, **fallback_kwargs) + knowledge_info = result.to_text(limit=5) + if result.filtered: + logger.debug(f"[私聊][{self.private_name}]长期记忆查询被聊天过滤策略跳过") + else: + logger.debug(f"[私聊][{self.private_name}]长期记忆查询结果: {knowledge_info[:150]}") + return knowledge_info or "未找到匹配的知识" except Exception as e: - logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}") + logger.error(f"[私聊][{self.private_name}]长期记忆搜索工具执行失败: {str(e)}") return "未找到匹配的知识" async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]: @@ -72,7 +122,7 @@ class KnowledgeFetcher: # sources_text = ",".join(sources) knowledge_text += "\n现在有以下**知识**可供参考:\n " - knowledge_text += self._lpmm_get_knowledge(query) + knowledge_text += await self._memory_get_knowledge(query) knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n" return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配" diff --git a/src/chat/knowledge/__init__.py b/src/chat/knowledge/__init__.py deleted file mode 100644 index 57e94472..00000000 --- a/src/chat/knowledge/__init__.py +++ /dev/null @@ -1,90 +0,0 @@ -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.qa_manager import QAManager -from src.chat.knowledge.kg_manager import KGManager -from src.chat.knowledge.global_logger import logger -from src.config.config import global_config -import os - -INVALID_ENTITY = [ - "", - "你", - "他", - "她", - "它", - "我们", - "你们", - "他们", - "她们", - "它们", -] - -RAG_GRAPH_NAMESPACE = "rag-graph" -RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" -RAG_PG_HASH_NAMESPACE = "rag-pg-hash" - - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -DATA_PATH = os.path.join(ROOT_PATH, "data") - - -qa_manager = None -inspire_manager = None - - -def get_qa_manager(): - return qa_manager - - -def lpmm_start_up(): # sourcery skip: extract-duplicate-method - # 检查LPMM知识库是否启用 - if global_config.lpmm_knowledge.enable: - logger.info("正在初始化Mai-LPMM") - logger.info("创建LLM客户端") - - # 初始化Embedding库 - embed_manager = EmbeddingManager( - max_workers=global_config.lpmm_knowledge.max_embedding_workers, - chunk_size=global_config.lpmm_knowledge.embedding_chunk_size, - ) - logger.info("正在从文件加载Embedding库") - try: - embed_manager.load_from_file() - except Exception as e: - logger.warning(f"此消息不会影响正常使用:从文件加载Embedding库时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") - logger.info("Embedding库加载完成") - # 初始化KG - kg_manager = KGManager() - logger.info("正在从文件加载KG") - try: - kg_manager.load_from_file() - except Exception as e: - logger.warning(f"此消息不会影响正常使用:从文件加载KG时,{e}") - # logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误") - logger.info("KG加载完成") - - logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}") - logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}") - - # 数据比对:Embedding库与KG的段落hash集合 - for pg_hash in kg_manager.stored_paragraph_hashes: - # 使用与EmbeddingStore中一致的命名空间格式 - key = f"paragraph-{pg_hash}" - if key not in embed_manager.stored_pg_hashes: - logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") - global qa_manager - # 问答系统(用于知识库) - qa_manager = QAManager( - embed_manager, - kg_manager, - ) - - # # 记忆激活(用于记忆库) - # global inspire_manager - # inspire_manager = MemoryActiveManager( - # embed_manager, - # llm_client_list[global_config["embedding"]["provider"]], - # ) - else: - logger.info("LPMM知识库已禁用,跳过初始化") - # 创建空的占位符对象,避免导入错误 diff --git a/src/chat/knowledge/lpmm_ops.py 
b/src/chat/knowledge/lpmm_ops.py deleted file mode 100644 index acaac4ca..00000000 --- a/src/chat/knowledge/lpmm_ops.py +++ /dev/null @@ -1,380 +0,0 @@ -import asyncio -import os -from functools import partial -from typing import List, Callable, Any -from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.kg_manager import KGManager -from src.chat.knowledge.qa_manager import QAManager -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.knowledge import get_qa_manager, lpmm_start_up - -logger = get_logger("LPMM-Plugin-API") - - -class LPMMOperations: - """ - LPMM 内部操作接口。 - 封装了 LPMM 的核心操作,供插件系统 API 或其他内部组件调用。 - """ - - def __init__(self): - self._initialized = False - - async def _run_cancellable_executor(self, func: Callable, *args, **kwargs) -> Any: - """ - 在线程池中执行可取消的同步操作。 - 当任务被取消时(如 Ctrl+C),会立即响应并抛出 CancelledError。 - 注意:线程池中的操作可能仍在运行,但协程会立即返回,不会阻塞主进程。 - - Args: - func: 要执行的同步函数 - *args: 函数的位置参数 - **kwargs: 函数的关键字参数 - - Returns: - 函数的返回值 - - Raises: - asyncio.CancelledError: 当任务被取消时 - """ - loop = asyncio.get_event_loop() - # 在线程池中执行,当协程被取消时会立即响应 - # 虽然线程池中的操作可能仍在运行,但协程不会阻塞 - return await loop.run_in_executor(None, func, *args, **kwargs) - - async def _get_managers(self) -> tuple[EmbeddingManager, KGManager, QAManager]: - """获取并确保 LPMM 管理器已初始化""" - qa_mgr = get_qa_manager() - if qa_mgr is None: - # 如果全局没初始化,尝试初始化 - if not global_config.lpmm_knowledge.enable: - logger.warning("LPMM 知识库在全局配置中未启用,操作可能受限。") - - lpmm_start_up() - qa_mgr = get_qa_manager() - - if qa_mgr is None: - raise RuntimeError("无法获取 LPMM QAManager,请检查 LPMM 是否已正确安装和配置。") - - return qa_mgr.embed_manager, qa_mgr.kg_manager, qa_mgr - - async def add_content(self, text: str, auto_split: bool = True) -> dict: - """ - 向知识库添加新内容。 - - Args: - text: 原始文本。 - auto_split: 是否自动按双换行符分割段落。 - - True: 自动分割(默认),支持多段文本(用双换行分隔) - - False: 不分割,将整个文本作为完整一段处理 - - Returns: - dict: {"status": "success/error", "count": 导入段落数, "message": "描述"} - """ - try: - embed_mgr, kg_mgr, _ = await self._get_managers() - - # 1. 分段处理 - if auto_split: - # 自动按双换行符分割 - paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] - else: - # 不分割,作为完整一段 - text_stripped = text.strip() - if not text_stripped: - return {"status": "error", "message": "文本内容为空"} - paragraphs = [text_stripped] - - if not paragraphs: - return {"status": "error", "message": "文本内容为空"} - - # 2. 实体与三元组抽取 (内部调用大模型) - from src.chat.knowledge.ie_process import IEProcess - from src.llm_models.utils_model import LLMRequest - from src.config.config import model_config - - llm_ner = LLMRequest( - model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" - ) - llm_rdf = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build") - ie_process = IEProcess(llm_ner, llm_rdf) - - logger.info(f"[Plugin API] 正在对 {len(paragraphs)} 段文本执行信息抽取...") - extracted_docs = await ie_process.process_paragraphs(paragraphs) - - # 3. 构造并导入数据 - # 这里我们手动实现导入逻辑,不依赖外部脚本 - # a. 准备段落 - raw_paragraphs = {doc["idx"]: doc["passage"] for doc in extracted_docs} - # b. 准备三元组 - triple_list_data = {doc["idx"]: doc["extracted_triples"] for doc in extracted_docs} - - # 向量化并入库 - # 注意:此处模仿 import_openie.py 的核心逻辑 - # 1. 
先进行去重检查,只处理新段落 - # store_new_data_set 期望的格式:raw_paragraphs 的键是段落hash(不带前缀),值是段落文本 - new_raw_paragraphs = {} - new_triple_list_data = {} - - for pg_hash, passage in raw_paragraphs.items(): - key = f"paragraph-{pg_hash}" - if key not in embed_mgr.stored_pg_hashes: - new_raw_paragraphs[pg_hash] = passage - new_triple_list_data[pg_hash] = triple_list_data[pg_hash] - - if not new_raw_paragraphs: - return {"status": "success", "count": 0, "message": "内容已存在,无需重复导入"} - - # 2. 使用 EmbeddingManager 的标准方法存储段落、实体和关系的嵌入 - # store_new_data_set 会自动处理嵌入生成和存储 - # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - await self._run_cancellable_executor(embed_mgr.store_new_data_set, new_raw_paragraphs, new_triple_list_data) - - # 3. 构建知识图谱(只需要三元组数据和embedding_manager) - await self._run_cancellable_executor(kg_mgr.build_kg, new_triple_list_data, embed_mgr) - - # 4. 持久化 - await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) - await self._run_cancellable_executor(embed_mgr.save_to_file) - await self._run_cancellable_executor(kg_mgr.save_to_file) - - return { - "status": "success", - "count": len(new_raw_paragraphs), - "message": f"成功导入 {len(new_raw_paragraphs)} 条知识", - } - - except asyncio.CancelledError: - logger.warning("[Plugin API] 导入操作被用户中断") - return {"status": "cancelled", "message": "导入操作已被用户中断"} - except Exception as e: - logger.error(f"[Plugin API] 导入知识失败: {e}", exc_info=True) - return {"status": "error", "message": str(e)} - - async def search(self, query: str, top_k: int = 3) -> List[str]: - """ - 检索知识库。 - - Args: - query: 查询问题。 - top_k: 返回最相关的条目数。 - - Returns: - List[str]: 相关文段列表。 - """ - try: - _, _, qa_mgr = await self._get_managers() - # 直接调用 QAManager 的检索接口 - knowledge = qa_mgr.get_knowledge(query, top_k=top_k) - # 返回通常是拼接好的字符串,这里我们可以尝试按其内部规则切分回列表,或者直接返回 - return [knowledge] if knowledge else [] - except Exception as e: - logger.error(f"[Plugin API] 检索知识失败: {e}") - return [] - - async def delete(self, keyword: str, exact_match: bool = False) -> dict: - """ - 根据关键词或完整文段删除知识库内容。 - - Args: - keyword: 匹配关键词或完整文段。 - exact_match: 是否使用完整文段匹配(True=完全匹配,False=关键词模糊匹配)。 - - Returns: - dict: {"status": "success/info", "deleted_count": 删除条数, "message": "描述"} - """ - try: - embed_mgr, kg_mgr, _ = await self._get_managers() - - # 1. 查找匹配的段落 - to_delete_keys = [] - to_delete_hashes = [] - - for key, item in embed_mgr.paragraphs_embedding_store.store.items(): - if exact_match: - # 完整文段匹配 - if item.str.strip() == keyword.strip(): - to_delete_keys.append(key) - to_delete_hashes.append(key.replace("paragraph-", "", 1)) - else: - # 关键词模糊匹配 - if keyword in item.str: - to_delete_keys.append(key) - to_delete_hashes.append(key.replace("paragraph-", "", 1)) - - if not to_delete_keys: - match_type = "完整文段" if exact_match else "关键词" - return {"status": "info", "deleted_count": 0, "message": f"未找到匹配的内容({match_type}匹配)"} - - # 2. 执行删除 - # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - - # a. 从向量库删除 - deleted_count, _ = await self._run_cancellable_executor( - embed_mgr.paragraphs_embedding_store.delete_items, to_delete_keys - ) - embed_mgr.stored_pg_hashes = set(embed_mgr.paragraphs_embedding_store.store.keys()) - - # b. 从知识图谱删除 - # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 - # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs - delete_func = partial( - kg_mgr.delete_paragraphs, to_delete_hashes, ent_hashes=None, remove_orphan_entities=True - ) - await self._run_cancellable_executor(delete_func) - - # 3. 
持久化 - await self._run_cancellable_executor(embed_mgr.rebuild_faiss_index) - await self._run_cancellable_executor(embed_mgr.save_to_file) - await self._run_cancellable_executor(kg_mgr.save_to_file) - - match_type = "完整文段" if exact_match else "关键词" - return { - "status": "success", - "deleted_count": deleted_count, - "message": f"已成功删除 {deleted_count} 条相关知识({match_type}匹配)", - } - - except asyncio.CancelledError: - logger.warning("[Plugin API] 删除操作被用户中断") - return {"status": "cancelled", "message": "删除操作已被用户中断"} - except Exception as e: - logger.error(f"[Plugin API] 删除知识失败: {e}", exc_info=True) - return {"status": "error", "message": str(e)} - - async def clear_all(self) -> dict: - """ - 清空整个LPMM知识库(删除所有段落、实体、关系和知识图谱数据)。 - - Returns: - dict: {"status": "success/error", "message": "描述", "stats": {...}} - """ - try: - embed_mgr, kg_mgr, _ = await self._get_managers() - - # 记录清空前的统计信息 - before_stats = { - "paragraphs": len(embed_mgr.paragraphs_embedding_store.store), - "entities": len(embed_mgr.entities_embedding_store.store), - "relations": len(embed_mgr.relation_embedding_store.store), - "kg_nodes": len(kg_mgr.graph.get_node_list()), - "kg_edges": len(kg_mgr.graph.get_edge_list()), - } - - # 将同步阻塞操作放到线程池中执行,避免阻塞事件循环 - - # 1. 清空所有向量库 - # 获取所有keys - para_keys = list(embed_mgr.paragraphs_embedding_store.store.keys()) - ent_keys = list(embed_mgr.entities_embedding_store.store.keys()) - rel_keys = list(embed_mgr.relation_embedding_store.store.keys()) - - # 删除所有段落向量 - para_deleted, _ = await self._run_cancellable_executor( - embed_mgr.paragraphs_embedding_store.delete_items, para_keys - ) - embed_mgr.stored_pg_hashes.clear() - - # 删除所有实体向量 - if ent_keys: - ent_deleted, _ = await self._run_cancellable_executor( - embed_mgr.entities_embedding_store.delete_items, ent_keys - ) - else: - ent_deleted = 0 - - # 删除所有关系向量 - if rel_keys: - rel_deleted, _ = await self._run_cancellable_executor( - embed_mgr.relation_embedding_store.delete_items, rel_keys - ) - else: - rel_deleted = 0 - - # 2. 
清空所有 embedding store 的索引和映射 - # 确保 faiss_index 和 idx2hash 也被重置,并删除旧的索引文件 - def _clear_embedding_indices(): - # 清空段落索引 - embed_mgr.paragraphs_embedding_store.faiss_index = None - embed_mgr.paragraphs_embedding_store.idx2hash = None - embed_mgr.paragraphs_embedding_store.dirty = False - # 删除旧的索引文件 - if os.path.exists(embed_mgr.paragraphs_embedding_store.index_file_path): - os.remove(embed_mgr.paragraphs_embedding_store.index_file_path) - if os.path.exists(embed_mgr.paragraphs_embedding_store.idx2hash_file_path): - os.remove(embed_mgr.paragraphs_embedding_store.idx2hash_file_path) - - # 清空实体索引 - embed_mgr.entities_embedding_store.faiss_index = None - embed_mgr.entities_embedding_store.idx2hash = None - embed_mgr.entities_embedding_store.dirty = False - # 删除旧的索引文件 - if os.path.exists(embed_mgr.entities_embedding_store.index_file_path): - os.remove(embed_mgr.entities_embedding_store.index_file_path) - if os.path.exists(embed_mgr.entities_embedding_store.idx2hash_file_path): - os.remove(embed_mgr.entities_embedding_store.idx2hash_file_path) - - # 清空关系索引 - embed_mgr.relation_embedding_store.faiss_index = None - embed_mgr.relation_embedding_store.idx2hash = None - embed_mgr.relation_embedding_store.dirty = False - # 删除旧的索引文件 - if os.path.exists(embed_mgr.relation_embedding_store.index_file_path): - os.remove(embed_mgr.relation_embedding_store.index_file_path) - if os.path.exists(embed_mgr.relation_embedding_store.idx2hash_file_path): - os.remove(embed_mgr.relation_embedding_store.idx2hash_file_path) - - await self._run_cancellable_executor(_clear_embedding_indices) - - # 3. 清空知识图谱 - # 获取所有段落hash - all_pg_hashes = list(kg_mgr.stored_paragraph_hashes) - if all_pg_hashes: - # 删除所有段落节点(这会自动清理相关的边和孤立实体) - # 注意:必须使用关键字参数,避免 True 被误当作 ent_hashes 参数 - # 使用 partial 来传递关键字参数,因为 run_in_executor 不支持 **kwargs - delete_func = partial( - kg_mgr.delete_paragraphs, all_pg_hashes, ent_hashes=None, remove_orphan_entities=True - ) - await self._run_cancellable_executor(delete_func) - - # 完全清空KG:创建新的空图(无论是否有段落hash都要执行) - from quick_algo import di_graph - - kg_mgr.graph = di_graph.DiGraph() - kg_mgr.stored_paragraph_hashes.clear() - kg_mgr.ent_appear_cnt.clear() - - # 4. 
保存所有数据(此时所有store都是空的,索引也是None) - # 注意:即使store为空,save_to_file也会保存空的DataFrame,这是正确的 - await self._run_cancellable_executor(embed_mgr.save_to_file) - await self._run_cancellable_executor(kg_mgr.save_to_file) - - after_stats = { - "paragraphs": len(embed_mgr.paragraphs_embedding_store.store), - "entities": len(embed_mgr.entities_embedding_store.store), - "relations": len(embed_mgr.relation_embedding_store.store), - "kg_nodes": len(kg_mgr.graph.get_node_list()), - "kg_edges": len(kg_mgr.graph.get_edge_list()), - } - - return { - "status": "success", - "message": f"已成功清空LPMM知识库(删除 {para_deleted} 个段落、{ent_deleted} 个实体、{rel_deleted} 个关系)", - "stats": { - "before": before_stats, - "after": after_stats, - }, - } - - except asyncio.CancelledError: - logger.warning("[Plugin API] 清空操作被用户中断") - return {"status": "cancelled", "message": "清空操作已被用户中断"} - except Exception as e: - logger.error(f"[Plugin API] 清空知识库失败: {e}", exc_info=True) - return {"status": "error", "message": str(e)} - - -# 内部使用的单例 -lpmm_ops = LPMMOperations() diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 60586406..df7d28fc 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -360,6 +360,12 @@ class ChatBot: user_id = user_info.user_id group_id = group_info.group_id if group_info else None _ = await chat_manager.get_or_create_session(platform, user_id, group_id) # 确保会话存在 + try: + from src.services.memory_flow_service import memory_automation_service + + await memory_automation_service.on_incoming_message(message) + except Exception as exc: + logger.warning(f"[长期记忆自动总结] 注册会话总结器失败: {exc}") # message.update_chat_stream(chat) diff --git a/src/chat/message_receive/uni_message_sender.py b/src/chat/message_receive/uni_message_sender.py index 894af238..369c0c51 100644 --- a/src/chat/message_receive/uni_message_sender.py +++ b/src/chat/message_receive/uni_message_sender.py @@ -383,6 +383,13 @@ class UniversalMessageSender: with get_db_session() as db_session: db_session.add(message.to_db_instance()) + try: + from src.services.memory_flow_service import memory_automation_service + + await memory_automation_service.on_message_sent(message) + except Exception as exc: + logger.warning(f"[{chat_id}] 长期记忆人物事实写回注册失败: {exc}") + return sent_msg except Exception as e: diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 74b324be..003009b8 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -1,7 +1,6 @@ import traceback import time import asyncio -import importlib import random import re @@ -36,6 +35,7 @@ from src.services import llm_service as llm_api from src.chat.logger.plan_reply_logger import PlanReplyLogger from src.memory_system.memory_retrieval import init_memory_retrieval_sys, build_memory_retrieval_prompt +from src.memory_system.retrieval_tools import get_tool_registry from src.bw_learner.jargon_explainer_old import explain_jargon_in_context, retrieve_concepts_with_jargon from src.chat.utils.common_utils import TempMethodsExpression @@ -1164,29 +1164,14 @@ class DefaultReplyer: async def get_prompt_info(self, message: str, sender: str, target: str): related_info = "" start_time = time.time() - try: - knowledge_module = importlib.import_module("src.plugins.built_in.knowledge.lpmm_get_knowledge") - except ImportError: - logger.debug("LPMM知识库工具模块不存在,跳过获取知识库内容") - return "" - - search_knowledge_tool = getattr(knowledge_module, "SearchKnowledgeFromLPMMTool", None) + search_knowledge_tool = 
get_tool_registry().get_tool("search_long_term_memory") if search_knowledge_tool is None: - logger.debug("LPMM知识库工具未提供 SearchKnowledgeFromLPMMTool,跳过获取知识库内容") + logger.debug("长期记忆检索工具未注册,跳过获取知识内容") return "" - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - # 从LPMM知识库获取知识 + logger.debug(f"获取长期记忆内容,元消息:{message[:30]}...,消息长度: {len(message)}") try: - # 检查LPMM知识库是否启用 - if not global_config.lpmm_knowledge.enable: - logger.debug("LPMM知识库未启用,跳过获取知识库内容") - return "" - - if global_config.lpmm_knowledge.lpmm_mode == "agent": - return "" - - template_prompt = prompt_manager.get_prompt("lpmm_get_knowledge") + template_prompt = prompt_manager.get_prompt("memory_get_knowledge") template_prompt.add_context("bot_name", global_config.bot.nickname) template_prompt.add_context("time_now", lambda _: time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) template_prompt.add_context("chat_history", message) @@ -1202,24 +1187,31 @@ class DefaultReplyer: # logger.info(f"工具调用提示词: {prompt}") # logger.info(f"工具调用: {tool_calls}") - if tool_calls: - result = await self.tool_executor.execute_tool_call(tool_calls[0]) - end_time = time.time() - if not result or not result.get("content"): - logger.debug("从LPMM知识库获取知识失败,返回空知识...") - return "" - found_knowledge_from_lpmm = result.get("content", "") - logger.info( - f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" - ) - related_info += found_knowledge_from_lpmm - logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") - logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - - return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" - else: - logger.debug("模型认为不需要使用LPMM知识库") + if not tool_calls: + logger.debug("模型认为不需要使用长期记忆") return "" + + related_chunks: List[str] = [] + for tool_call in tool_calls: + if tool_call.func_name != "search_long_term_memory": + continue + tool_args = dict(tool_call.args or {}) + tool_args.setdefault("chat_id", self.chat_stream.session_id) + result_text = await search_knowledge_tool.execute(**tool_args) + if result_text and "未找到" not in result_text: + related_chunks.append(result_text) + + if not related_chunks: + logger.debug("长期记忆未返回有效信息") + return "" + + related_info = "\n".join(related_chunks) + end_time = time.time() + logger.info(f"从长期记忆获取知识,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") + logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") + logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") + + return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" except Exception as e: logger.error(f"获取知识库内容时发生异常: {str(e)}") return "" diff --git a/src/config/config.py b/src/config/config.py index a3b81d2d..fcda4d01 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -55,7 +55,7 @@ CONFIG_DIR: Path = PROJECT_ROOT / "config" BOT_CONFIG_PATH: Path = (CONFIG_DIR / "bot_config.toml").resolve().absolute() MODEL_CONFIG_PATH: Path = (CONFIG_DIR / "model_config.toml").resolve().absolute() MMC_VERSION: str = "1.0.0" -CONFIG_VERSION: str = "8.1.0" +CONFIG_VERSION: str = "8.1.1" MODEL_CONFIG_VERSION: str = "1.12.0" logger = get_logger("config") diff --git a/src/config/legacy_migration.py b/src/config/legacy_migration.py index 7baaa03e..7b400f82 100644 --- a/src/config/legacy_migration.py +++ b/src/config/legacy_migration.py @@ -94,6 +94,11 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool: ["", "enable", "enable", "enable"], ["qq:1919810:group", "enable", "enable", 
"enable"], ] + 兼容旧旧格式: + learning_list = [ + ["qq:1919810:group", "enable", "enable", "0.5"], + ["", "disable", "disable", "0.1"], + ] 新: [[expression.learning_list]] platform="", item_id="", rule_type="group", use_expression=true, enable_learning=true, enable_jargon_learning=true @@ -117,6 +122,16 @@ def _migrate_expression_learning_list(expr: dict[str, Any]) -> bool: use_expression = _parse_enable_disable(r[1]) enable_learning = _parse_enable_disable(r[2]) enable_jargon_learning = _parse_enable_disable(r[3]) + if enable_jargon_learning is None: + # 更早期的配置在第 4 列记录的是一个已废弃的数值权重/阈值, + # 当前 schema 已没有对应字段。这里按保守策略兼容迁移: + # 丢弃旧数值,并将 enable_jargon_learning 置为 False。 + try: + float(str(r[3])) + except (TypeError, ValueError): + pass + else: + enable_jargon_learning = False if use_expression is None or enable_learning is None or enable_jargon_learning is None: return False diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 2de01030..0b681748 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -416,6 +416,24 @@ class MemoryConfig(ConfigBase): ) """_wrap_全局记忆黑名单,当启用全局记忆时,不将特定聊天流纳入检索""" + long_term_auto_summary_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "book-open", + }, + ) + """是否自动启动聊天总结并导入长期记忆""" + + person_fact_writeback_enabled: bool = Field( + default=True, + json_schema_extra={ + "x-widget": "switch", + "x-icon": "user-round-pen", + }, + ) + """是否在发送回复后自动提取并写回人物事实到长期记忆""" + chat_history_topic_check_message_threshold: int = Field( default=80, ge=1, diff --git a/src/main.py b/src/main.py index 91da2d83..059aee62 100644 --- a/src/main.py +++ b/src/main.py @@ -6,7 +6,6 @@ import time from src.bw_learner.expression_auto_check_task import ExpressionAutoCheckTask from src.chat.emoji_system.emoji_manager import emoji_manager -from src.chat.knowledge import lpmm_start_up from src.chat.message_receive.bot import chat_bot from src.chat.message_receive.chat_manager import chat_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask @@ -19,6 +18,7 @@ from src.config.config import config_manager, global_config from src.manager.async_task_manager import async_task_manager from src.plugin_runtime.integration import get_plugin_runtime_manager from src.prompt.prompt_manager import prompt_manager +from src.services.memory_flow_service import memory_automation_service # from src.api.main import start_api_server @@ -88,9 +88,6 @@ class MainSystem: # start_api_server() # logger.info("API服务器启动成功") - # 启动LPMM - lpmm_start_up() - # 启动插件运行时(内置插件 + 第三方插件双子进程) await get_plugin_runtime_manager().start() @@ -103,6 +100,7 @@ class MainSystem: asyncio.create_task(chat_manager.regularly_save_sessions()) logger.info(t("startup.chat_manager_initialized")) + await memory_automation_service.start() # await asyncio.sleep(0.5) #防止logger输出飞了 @@ -164,6 +162,10 @@ async def main(): system.schedule_tasks(), ) finally: + await memory_automation_service.shutdown() + await get_plugin_runtime_manager().bridge_event("on_stop") + await get_plugin_runtime_manager().stop() + await async_task_manager.stop_and_wait_all_tasks() await config_manager.stop_file_watcher() diff --git a/src/memory_system/chat_history_summarizer.py b/src/memory_system/chat_history_summarizer.py index b984c66d..cedf971f 100644 --- a/src/memory_system/chat_history_summarizer.py +++ b/src/memory_system/chat_history_summarizer.py @@ -931,12 +931,14 @@ class ChatHistorySummarizer: else: logger.warning(f"{self.log_prefix} 
存储聊天历史记录到数据库失败") - # 同时导入到LPMM知识库 - if global_config.lpmm_knowledge.enable: - await self._import_to_lpmm_knowledge( + if saved_record and saved_record.get("id") is not None: + await self._import_to_long_term_memory( + record_id=int(saved_record["id"]), theme=theme, summary=summary, participants=participants, + start_time=start_time, + end_time=end_time, original_text=original_text, ) @@ -947,76 +949,131 @@ class ChatHistorySummarizer: traceback.print_exc() raise - async def _import_to_lpmm_knowledge( + async def _import_to_long_term_memory( self, + record_id: int, theme: str, summary: str, participants: List[str], + start_time: float, + end_time: float, original_text: str, ): """ - 将聊天历史总结导入到LPMM知识库 + 将聊天历史总结导入到统一长期记忆 Args: + record_id: chat_history 主键 theme: 话题主题 summary: 概括内容 participants: 参与者列表 + start_time: 开始时间 + end_time: 结束时间 original_text: 原始文本(可能很长,需要截断) """ try: - from src.chat.knowledge.lpmm_ops import lpmm_ops + from src.services.memory_service import memory_service + session = _chat_manager.get_session_by_session_id(self.session_id) + session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" + session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" - # 构造要导入的文本内容 - # 格式:主题 + 概括 + 参与者信息 + 原始内容摘要 - # 注意:使用单换行符连接,确保整个内容作为一段导入,不被LPMM分段 content_parts = [] - - # 1. 话题主题 - # if theme: - # content_parts.append(f"话题:{theme}") - - # 2. 概括内容 + if theme: + content_parts.append(f"主题:{theme}") if summary: content_parts.append(f"概括:{summary}") - - # 3. 参与者信息 if participants: participants_text = "、".join(participants) content_parts.append(f"参与者:{participants_text}") - - # 4. 原始文本摘要(如果原始文本太长,只取前500字) - # if original_text: - # # 截断原始文本,避免过长 - # max_original_length = 500 - # if len(original_text) > max_original_length: - # truncated_text = original_text[:max_original_length] + "..." 
- # content_parts.append(f"原始内容摘要:{truncated_text}") - # else: - # content_parts.append(f"原始内容:{original_text}") - - # 将所有部分合并为一个完整段落(使用单换行符,避免被LPMM分段) - # LPMM使用 \n\n 作为段落分隔符,所以这里使用 \n 确保不会被分段 content_to_import = "\n".join(content_parts) if not content_to_import.strip(): - logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,跳过导入知识库") + logger.warning(f"{self.log_prefix} 聊天历史总结内容为空,改用插件侧 generate_from_chat 兜底") + await self._fallback_import_to_long_term_memory( + record_id=record_id, + theme=theme, + participants=participants, + start_time=start_time, + end_time=end_time, + original_text=original_text, + ) return - # 调用lpmm_ops导入 - result = await lpmm_ops.add_content(text=content_to_import, auto_split=False) - - if result["status"] == "success": - logger.info( - f"{self.log_prefix} 成功将聊天历史总结导入到LPMM知识库 | 话题: {theme} | 新增段落数: {result.get('count', 0)}" - ) + result = await memory_service.ingest_summary( + external_id=f"chat_history:{record_id}", + chat_id=self.session_id, + text=content_to_import, + participants=participants, + time_start=start_time, + time_end=end_time, + tags=[theme] if theme else [], + metadata={"theme": theme, "original_text_length": len(original_text or "")}, + respect_filter=True, + user_id=session_user_id, + group_id=session_group_id, + ) + if result.success: + if result.detail == "chat_filtered": + logger.debug(f"{self.log_prefix} 聊天历史总结被聊天过滤策略跳过 | 话题: {theme}") + else: + logger.info(f"{self.log_prefix} 成功将聊天历史总结导入到长期记忆 | 话题: {theme}") else: - logger.warning( - f"{self.log_prefix} 将聊天历史总结导入到LPMM知识库失败 | 话题: {theme} | 错误: {result.get('message', '未知错误')}" + logger.warning(f"{self.log_prefix} 将聊天历史总结导入到长期记忆失败,尝试插件侧兜底 | 话题: {theme} | 错误: {result.detail}") + await self._fallback_import_to_long_term_memory( + record_id=record_id, + theme=theme, + participants=participants, + start_time=start_time, + end_time=end_time, + original_text=original_text, ) except Exception as e: - # 导入失败不应该影响数据库存储,只记录错误 - logger.error(f"{self.log_prefix} 导入聊天历史总结到LPMM知识库时出错: {e}", exc_info=True) + logger.error(f"{self.log_prefix} 导入聊天历史总结到长期记忆时出错: {e}", exc_info=True) + + async def _fallback_import_to_long_term_memory( + self, + *, + record_id: int, + theme: str, + participants: List[str], + start_time: float, + end_time: float, + original_text: str, + ) -> None: + try: + from src.services.memory_service import memory_service + session = _chat_manager.get_session_by_session_id(self.session_id) + session_user_id = str(getattr(session, "user_id", "") or "").strip() if session else "" + session_group_id = str(getattr(session, "group_id", "") or "").strip() if session else "" + + result = await memory_service.ingest_summary( + external_id=f"chat_history:{record_id}", + chat_id=self.session_id, + text="", + participants=participants, + time_start=start_time, + time_end=end_time, + tags=[theme] if theme else [], + metadata={ + "theme": theme, + "original_text_length": len(original_text or ""), + "generate_from_chat": True, + "context_length": global_config.memory.chat_history_topic_check_message_threshold, + }, + respect_filter=True, + user_id=session_user_id, + group_id=session_group_id, + ) + if result.success: + if result.detail == "chat_filtered": + logger.debug(f"{self.log_prefix} 插件侧 generate_from_chat 兜底被聊天过滤策略跳过 | 话题: {theme}") + else: + logger.info(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入成功 | 话题: {theme}") + else: + logger.warning(f"{self.log_prefix} 插件侧 generate_from_chat 兜底导入失败 | 话题: {theme} | 错误: {result.detail}") + except Exception as exc: + logger.error(f"{self.log_prefix} 插件侧兜底导入长期记忆失败: 
{exc}", exc_info=True) async def start(self): """启动后台定期检查循环""" diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index 4193a16a..49e5ca02 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -237,8 +237,8 @@ async def _react_agent_solve_question( if first_head_prompt is None: # 第一次构建,使用初始的collected_info(即initial_info) initial_collected_info = initial_info or "" - # 使用 LPMM 知识库检索 prompt - first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_lpmm") + # 使用统一长期记忆检索 prompt + first_head_prompt_template = prompt_manager.get_prompt("memory_retrieval_react_prompt_head_memory") first_head_prompt_template.add_context("bot_name", bot_name) first_head_prompt_template.add_context("time_now", time_now) first_head_prompt_template.add_context("chat_history", chat_history) diff --git a/src/memory_system/retrieval_tools/__init__.py b/src/memory_system/retrieval_tools/__init__.py index 9f2673b2..ba5f731f 100644 --- a/src/memory_system/retrieval_tools/__init__.py +++ b/src/memory_system/retrieval_tools/__init__.py @@ -10,21 +10,17 @@ from .tool_registry import ( get_tool_registry, ) -# 导入所有工具的注册函数 -from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge -from .query_words import register_tool as register_query_words -from .return_information import register_tool as register_return_information -from src.config.config import global_config - def init_all_tools(): """初始化并注册所有记忆检索工具""" + # 延迟导入,避免在仅使用部分工具或单元测试阶段触发不必要的依赖链。 + from .query_long_term_memory import register_tool as register_long_term_memory + from .query_words import register_tool as register_query_words + from .return_information import register_tool as register_return_information + register_query_words() register_return_information() - - # LPMM知识库检索工具 - if global_config.lpmm_knowledge.lpmm_mode == "agent": - register_lpmm_knowledge() + register_long_term_memory() __all__ = [ diff --git a/src/memory_system/retrieval_tools/query_long_term_memory.py b/src/memory_system/retrieval_tools/query_long_term_memory.py new file mode 100644 index 00000000..57202f34 --- /dev/null +++ b/src/memory_system/retrieval_tools/query_long_term_memory.py @@ -0,0 +1,304 @@ +"""通过统一长期记忆服务查询信息。""" + +from __future__ import annotations + +import re +from calendar import monthrange +from datetime import datetime, timedelta +from typing import Iterable, Literal, Tuple + +from src.common.logger import get_logger +from src.services.memory_service import MemoryHit, MemorySearchResult, memory_service + +from .tool_registry import register_memory_retrieval_tool + +logger = get_logger("memory_retrieval_tools") + +_SUPPORTED_MODES = {"search", "time", "episode", "aggregate"} +_RELATIVE_DAYS_RE = re.compile(r"^最近\s*(\d+)\s*天$") +_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$") +_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}\s+\d{2}:\d{2}$") +_TIME_EXPRESSION_HELP = ( + "请改用更具体的时间表达,例如:今天、昨天、前天、本周、上周、本月、上月、最近7天、" + "2026/03/18、2026/03/18 09:30。" +) + + +def _format_query_datetime(dt: datetime) -> str: + return dt.strftime("%Y/%m/%d %H:%M") + + +def _resolve_time_expression( + expression: str, + *, + now: datetime | None = None, +) -> Tuple[float, float, str, str]: + clean = str(expression or "").strip() + if not clean: + raise ValueError(f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}") + + current = now or datetime.now() + day_start = current.replace(hour=0, minute=0, second=0, microsecond=0) + + if clean == "今天": + start = day_start + end = 
day_start.replace(hour=23, minute=59) + elif clean == "昨天": + start = day_start - timedelta(days=1) + end = start.replace(hour=23, minute=59) + elif clean == "前天": + start = day_start - timedelta(days=2) + end = start.replace(hour=23, minute=59) + elif clean == "本周": + start = day_start - timedelta(days=day_start.weekday()) + end = start + timedelta(days=6, hours=23, minutes=59) + elif clean == "上周": + this_week_start = day_start - timedelta(days=day_start.weekday()) + start = this_week_start - timedelta(days=7) + end = start + timedelta(days=6, hours=23, minutes=59) + elif clean == "本月": + start = day_start.replace(day=1) + last_day = monthrange(start.year, start.month)[1] + end = start.replace(day=last_day, hour=23, minute=59) + elif clean == "上月": + year = day_start.year + month = day_start.month - 1 + if month == 0: + year -= 1 + month = 12 + start = day_start.replace(year=year, month=month, day=1) + last_day = monthrange(year, month)[1] + end = start.replace(day=last_day, hour=23, minute=59) + else: + relative_match = _RELATIVE_DAYS_RE.fullmatch(clean) + if relative_match: + days = max(1, int(relative_match.group(1))) + start = day_start - timedelta(days=max(0, days - 1)) + end = day_start.replace(hour=23, minute=59) + elif _DATE_RE.fullmatch(clean): + start = datetime.strptime(clean, "%Y/%m/%d") + end = start.replace(hour=23, minute=59) + elif _MINUTE_RE.fullmatch(clean): + start = datetime.strptime(clean, "%Y/%m/%d %H:%M") + end = start + else: + raise ValueError(f"时间表达“{clean}”无法解析。{_TIME_EXPRESSION_HELP}") + + return start.timestamp(), end.timestamp(), _format_query_datetime(start), _format_query_datetime(end) + + +def _extract_time_label(metadata: dict) -> str: + if not isinstance(metadata, dict): + return "" + start = metadata.get("event_time_start") + end = metadata.get("event_time_end") + event_time = metadata.get("event_time") + + def _fmt(value: object) -> str: + if value in {None, ""}: + return "" + try: + return datetime.fromtimestamp(float(value)).strftime("%Y/%m/%d %H:%M") + except Exception: + return str(value) + + start_text = _fmt(start or event_time) + end_text = _fmt(end) + if start_text and end_text: + return f"{start_text} - {end_text}" + return start_text or end_text + + +def _truncate(text: str, limit: int = 160) -> str: + compact = str(text or "").strip().replace("\n", " ") + if len(compact) <= limit: + return compact + return compact[:limit] + "..." + + +def _format_search_lines(hits: Iterable[MemoryHit], *, limit: int, include_time: bool = False) -> str: + lines = [] + for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): + time_label = _extract_time_label(item.metadata) if include_time else "" + prefix = f"[{time_label}] " if time_label else "" + lines.append(f"{index}. 
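# --- Editor's note: worked example for _resolve_time_expression above
# (illustrative only; concrete dates assume "now" is Wednesday 2026/03/18 21:00
# and the function is called from within this module). ---
#   _resolve_time_expression("今天",    now=datetime(2026, 3, 18, 21, 0))
#       -> 2026/03/18 00:00 .. 2026/03/18 23:59
#   _resolve_time_expression("本周",    now=...)  -> 2026/03/16 00:00 .. 2026/03/22 23:59
#   _resolve_time_expression("上月",    now=...)  -> 2026/02/01 00:00 .. 2026/02/28 23:59
#   _resolve_time_expression("最近7天", now=...)  -> 2026/03/12 00:00 .. 2026/03/18 23:59
#       (today counts as the first of the N days)
#   _resolve_time_expression("2026/03/18 09:30")  -> start == end == 2026/03/18 09:30
# Unrecognised expressions raise ValueError with _TIME_EXPRESSION_HELP appended.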
{prefix}{_truncate(item.content)}") + return "\n".join(lines) + + +def _format_episode_lines(hits: Iterable[MemoryHit], *, limit: int) -> str: + lines = [] + for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): + metadata = item.metadata if isinstance(item.metadata, dict) else {} + title = str(item.title or "").strip() or "未命名事件" + summary = _truncate(item.content, limit=180) + participants = [str(x).strip() for x in (metadata.get("participants") or []) if str(x).strip()] + keywords = [str(x).strip() for x in (metadata.get("keywords") or []) if str(x).strip()] + extras = [] + if participants: + extras.append(f"参与者:{'、'.join(participants[:4])}") + if keywords: + extras.append(f"关键词:{'、'.join(keywords[:6])}") + time_label = _extract_time_label(metadata) + if time_label: + extras.append(f"时间:{time_label}") + suffix = f"({';'.join(extras)})" if extras else "" + lines.append(f"{index}. 事件《{title}》:{summary}{suffix}") + return "\n".join(lines) + + +def _format_aggregate_lines(hits: Iterable[MemoryHit], *, limit: int) -> str: + lines = [] + for index, item in enumerate(list(hits)[: max(1, int(limit))], start=1): + metadata = item.metadata if isinstance(item.metadata, dict) else {} + source_branches = [str(x).strip() for x in (metadata.get("source_branches") or []) if str(x).strip()] + branch_text = f"[{','.join(source_branches)}]" if source_branches else "" + item_type = str(item.hit_type or "").strip().lower() or "memory" + if item_type == "episode": + title = str(item.title or "").strip() or "未命名事件" + lines.append(f"{index}. {branch_text}[episode] 《{title}》:{_truncate(item.content, 160)}") + else: + lines.append(f"{index}. {branch_text}[{item_type}] {_truncate(item.content, 160)}") + return "\n".join(lines) + + +def _format_tool_result( + *, + result: MemorySearchResult, + mode: Literal["search", "time", "episode", "aggregate"], + limit: int, + query: str, + time_range_text: str = "", +) -> str: + if not result.hits: + if mode == "time": + return f"在指定时间范围内未找到相关的长期记忆{time_range_text}" + if mode == "episode": + return f"未找到与“{query}”相关的事件或情节记忆" + if mode == "aggregate": + return f"未找到可用于综合回忆的长期记忆线索{f'(query:{query})' if query else ''}" + return f"在长期记忆中未找到与“{query}”相关的信息" + + if mode == "episode": + text = _format_episode_lines(result.hits, limit=limit) + return f"你从长期记忆的事件/情节中找到以下信息:\n{text}" + + if mode == "aggregate": + text = _format_aggregate_lines(result.hits, limit=limit) + return f"你从长期记忆中综合找到了以下线索:\n{text}" + + if mode == "time": + text = _format_search_lines(result.hits, limit=limit, include_time=True) + return f"你从指定时间范围内的长期记忆中找到以下信息{time_range_text}:\n{text}" + + text = _format_search_lines(result.hits, limit=limit) + return f"你从长期记忆中找到以下信息:\n{text}" + + +async def query_long_term_memory( + query: str = "", + limit: int = 5, + chat_id: str = "", + person_id: str = "", + mode: str = "search", + time_expression: str = "", +) -> str: + content = str(query or "").strip() + safe_limit = max(1, int(limit or 5)) + normalized_mode = str(mode or "search").strip().lower() or "search" + if normalized_mode not in _SUPPORTED_MODES: + return f"不支持的长期记忆检索模式:{normalized_mode}。可用模式:search、time、episode、aggregate。" + + if normalized_mode == "search" and not content: + return "查询关键词为空,请提供你想查找的长期记忆内容。" + if normalized_mode == "time" and not str(time_expression or "").strip(): + return f"time 模式需要提供 time_expression。{_TIME_EXPRESSION_HELP}" + if normalized_mode in {"episode", "aggregate"} and not content and not str(time_expression or "").strip(): + return f"{normalized_mode} 模式至少需要提供 
query 或 time_expression。" + + time_start = None + time_end = None + time_range_text = "" + if str(time_expression or "").strip(): + try: + time_start, time_end, time_start_text, time_end_text = _resolve_time_expression(time_expression) + except ValueError as exc: + return str(exc) + time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})" + + backend_mode = "hybrid" if normalized_mode == "search" else normalized_mode + + try: + result = await memory_service.search( + content, + limit=safe_limit, + mode=backend_mode, + chat_id=str(chat_id or "").strip(), + person_id=str(person_id or "").strip(), + time_start=time_start, + time_end=time_end, + ) + text = _format_tool_result( + result=result, + mode=normalized_mode, # type: ignore[arg-type] + limit=safe_limit, + query=content, + time_range_text=time_range_text, + ) + logger.debug(f"长期记忆查询结果({normalized_mode}): {text}") + return text + except Exception as exc: + logger.error(f"长期记忆查询失败: {exc}") + return f"长期记忆查询失败:{exc}" + + +def register_tool(): + register_memory_retrieval_tool( + name="search_long_term_memory", + description=( + "从长期记忆中检索信息。支持 search(普通事实检索)、time(按时间范围检索)、" + "episode(按事件/情节检索)、aggregate(综合检索)四种模式。" + ), + parameters=[ + { + "name": "query", + "type": "string", + "description": "需要查询的问题。search 模式建议用自然语言问句;time/episode/aggregate 模式也可用关键词短语。", + "required": False, + }, + { + "name": "mode", + "type": "string", + "description": "检索模式:search(普通长期记忆)、time(按时间窗口)、episode(事件/情节)、aggregate(综合检索)。", + "required": False, + "enum": ["search", "time", "episode", "aggregate"], + }, + { + "name": "limit", + "type": "integer", + "description": "希望返回的相关知识条数,默认为5", + "required": False, + }, + { + "name": "chat_id", + "type": "string", + "description": "当前聊天流ID,可选。提供后优先检索当前聊天上下文相关的长期记忆。", + "required": False, + }, + { + "name": "person_id", + "type": "string", + "description": "相关人物ID,可选。提供后优先检索该人物相关的长期记忆。", + "required": False, + }, + { + "name": "time_expression", + "type": "string", + "description": ( + "时间表达,可选。time 模式必填;episode/aggregate 模式可选。支持:今天、昨天、前天、本周、上周、本月、上月、" + "最近N天,以及 YYYY/MM/DD、YYYY/MM/DD HH:mm。" + ), + "required": False, + }, + ], + execute_func=query_long_term_memory, + ) diff --git a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py b/src/memory_system/retrieval_tools/query_lpmm_knowledge.py deleted file mode 100644 index eed01af1..00000000 --- a/src/memory_system/retrieval_tools/query_lpmm_knowledge.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -通过LPMM知识库查询信息 - 工具实现 -""" - -from src.common.logger import get_logger -from src.config.config import global_config -from src.chat.knowledge import get_qa_manager -from .tool_registry import register_memory_retrieval_tool - -logger = get_logger("memory_retrieval_tools") - - -async def query_lpmm_knowledge(query: str, limit: int = 5) -> str: - """在LPMM知识库中查询相关信息 - - Args: - query: 查询关键词 - - Returns: - str: 查询结果 - """ - try: - content = str(query).strip() - if not content: - return "查询关键词为空" - - try: - limit_value = int(limit) - except (TypeError, ValueError): - limit_value = 5 - limit_value = max(1, limit_value) - - if not global_config.lpmm_knowledge.enable: - logger.debug("LPMM知识库未启用") - return "LPMM知识库未启用" - - qa_manager = get_qa_manager() - if qa_manager is None: - logger.debug("LPMM知识库未初始化,跳过查询") - return "LPMM知识库未初始化" - - knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value) - logger.debug(f"LPMM知识库查询结果: {knowledge_info}") - - if knowledge_info: - return f"你从LPMM知识库中找到以下信息:\n{knowledge_info}" - - return f"在LPMM知识库中未找到与“{content}”相关的信息" - - except 
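# --- Editor's note: illustrative call of the search_long_term_memory tool
# registered above (not part of the patch; chat_id is a placeholder). ---
from src.memory_system.retrieval_tools.query_long_term_memory import query_long_term_memory

async def demo_time_mode_lookup() -> str:
    # mode="time" requires time_expression; the query only narrows the results
    return await query_long_term_memory(
        query="我们聊过哪些项目进展",
        mode="time",
        time_expression="上周",
        limit=5,
        chat_id="demo-chat-id",
    )
# For mode="search" the requirement flips: query becomes mandatory and
# time_expression is optional; episode/aggregate need at least one of the two.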
Exception as e: - logger.error(f"LPMM知识库查询失败: {e}") - return f"LPMM知识库查询失败:{str(e)}" - - -def register_tool(): - """注册LPMM知识库查询工具""" - register_memory_retrieval_tool( - name="lpmm_search_knowledge", - description="从知识库中搜索相关信息,适用于需要知识支持的场景。使用自然语言问句检索", - parameters=[ - { - "name": "query", - "type": "string", - "description": "需要查询的问题,使用一句疑问句提问,例如:什么是AI?", - "required": True, - }, - { - "name": "limit", - "type": "integer", - "description": "希望返回的相关知识条数,默认为5", - "required": False, - }, - ], - execute_func=query_lpmm_knowledge, - ) diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 799f56a0..960de4aa 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -6,9 +6,10 @@ import random import math from json_repair import repair_json -from typing import Union, Optional, Dict +from typing import Union, Optional, Dict, List from datetime import datetime +from sqlalchemy import or_ from sqlmodel import col, select from src.common.logger import get_logger @@ -17,6 +18,7 @@ from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config from src.chat.message_receive.chat_manager import chat_manager as _chat_manager +from src.services.memory_service import memory_service logger = get_logger("person_info") @@ -37,16 +39,60 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id_by_person_name(person_name: str) -> str: """根据用户名获取用户ID""" + clean_name = str(person_name or "").strip() + if not clean_name: + return "" try: with get_db_session() as session: - statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1) + statement = ( + select(PersonInfo) + .where( + or_( + col(PersonInfo.person_name) == clean_name, + col(PersonInfo.user_nickname) == clean_name, + ) + ) + .limit(1) + ) + record = session.exec(statement).first() + if record and record.person_id: + return record.person_id + + statement = ( + select(PersonInfo) + .where(PersonInfo.group_cardname.contains(clean_name)) + .limit(1) + ) record = session.exec(statement).first() return record.person_id if record else "" except Exception as e: - logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") + logger.error(f"根据用户名 {clean_name} 获取用户ID时出错: {e}") return "" +def resolve_person_id_for_memory( + *, + person_name: str = "", + platform: str = "", + user_id: Optional[Union[int, str]] = None, +) -> str: + """统一人物记忆链路中的 person_id 解析。 + + 优先使用已知的人物名称/别名,其次退回到平台 + user_id 的稳定 ID。 + """ + name_token = str(person_name or "").strip() + if name_token: + resolved = get_person_id_by_person_name(name_token) + if resolved: + return resolved + + platform_token = str(platform or "").strip() + user_token = str(user_id or "").strip() + if platform_token and user_token: + return get_person_id(platform_token, user_token) + return "" + + def is_person_known( person_id: Optional[str] = None, user_id: Optional[str] = None, @@ -537,79 +583,79 @@ class Person: async def build_relationship(self, chat_content: str = "", info_type=""): if not self.is_known: return "" - # 构建points文本 - nickname_str = "" if self.person_name != self.nickname: nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})" - relation_info = "" + async def _select_traits(query_text: str, traits: List[str], limit: int = 3) -> List[str]: + clean_traits = [trait.strip() for trait in traits if isinstance(trait, str) and trait.strip()] + if not clean_traits: + return [] + if not query_text: + 
return clean_traits[:limit] - points_text = "" - category_list = self.get_all_category() + numbered_traits = "\n".join(f"{index}. {trait}" for index, trait in enumerate(clean_traits, start=1)) + prompt = f"""当前关注内容: +{query_text} - if chat_content: - prompt = f"""当前聊天内容: -{chat_content} +候选人物信息: +{numbered_traits} -分类列表: -{category_list} -**要求**:请你根据当前聊天内容,从以下分类中选择一个与聊天内容相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹: -例如: -<分类1><分类2><分类3>...... -如果没有相关的分类,请输出""" +请从候选人物信息中选择与当前关注内容最相关的编号,并用<>包裹输出,不要输出其他内容。 +例如: +<1><3> +如果都不相关,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) - # print(prompt) - # print(response) - category_list = extract_categories_from_response(response) - if "none" not in category_list: - for category in category_list: - random_memory = self.get_random_memory_by_category(category, 2) - if random_memory: - random_memory_str = "\n".join( - [get_memory_content_from_memory(memory) for memory in random_memory] - ) - points_text = f"有关 {category} 的内容:{random_memory_str}" - break - elif info_type: - prompt = f"""你需要获取用户{self.person_name}的 **{info_type}** 信息。 + try: + response, _ = await relation_selection_model.generate_response_async(prompt) + selected_traits: List[str] = [] + for raw_index in extract_categories_from_response(response): + if raw_index == "none": + return [] + try: + trait_index = int(raw_index) - 1 + except ValueError: + continue + if 0 <= trait_index < len(clean_traits): + trait = clean_traits[trait_index] + if trait not in selected_traits: + selected_traits.append(trait) + if selected_traits: + return selected_traits[:limit] + except Exception as e: + logger.debug(f"筛选人物画像信息失败,使用默认画像摘要: {e}") -现有信息类别列表: -{category_list} -**要求**:请你根据**{info_type}**,从以下分类中选择一个与**{info_type}**相关的分类,并用<>包裹输出,不要输出其他内容,不要输出引号或[],严格用<>包裹: -例如: -<分类1><分类2><分类3>...... 
-如果没有相关的分类,请输出""" - response, _ = await relation_selection_model.generate_response_async(prompt) - # print(prompt) - # print(response) - category_list = extract_categories_from_response(response) - if "none" not in category_list: - for category in category_list: - random_memory = self.get_random_memory_by_category(category, 3) - if random_memory: - random_memory_str = "\n".join( - [get_memory_content_from_memory(memory) for memory in random_memory] - ) - points_text = f"有关 {category} 的内容:{random_memory_str}" - break - else: - for category in category_list: - random_memory = self.get_random_memory_by_category(category, 1)[0] - if random_memory: - points_text = f"有关 {category} 的内容:{get_memory_content_from_memory(random_memory)}" - break + return clean_traits[:limit] + + profile = await memory_service.get_person_profile(self.person_id, limit=8) + relation_parts: List[str] = [] + if profile.summary.strip(): + relation_parts.append(profile.summary.strip()) + + query_text = str(chat_content or info_type or "").strip() + selected_traits = await _select_traits(query_text, profile.traits, limit=3) + if not selected_traits and not query_text: + selected_traits = [trait for trait in profile.traits if trait][:2] + + for trait in selected_traits: + clean_trait = str(trait).strip() + if clean_trait and clean_trait not in relation_parts: + relation_parts.append(clean_trait) + + for evidence in profile.evidence: + content = str(evidence.get("content", "") or "").strip() + if content and content not in relation_parts: + relation_parts.append(content) + if len(relation_parts) >= 4: + break points_info = "" - if points_text: - points_info = f"你还记得有关{self.person_name}的内容:{points_text}" + if relation_parts: + points_info = f"你还记得有关{self.person_name}的内容:{';'.join(relation_parts[:3])}" if not (nickname_str or points_info): return "" - relation_info = f"{self.person_name}:{nickname_str}{points_info}" - - return relation_info + return f"{self.person_name}:{nickname_str}{points_info}" class PersonInfoManager: @@ -776,7 +822,7 @@ person_info_manager = PersonInfoManager() async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None: - """将人物信息存入person_info的memory_points + """将人物事实写入统一长期记忆 Args: person_name: 人物名称 @@ -784,6 +830,11 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: 聊天ID """ try: + content = str(memory_content or "").strip() + if not content: + logger.debug("人物记忆内容为空,跳过写入") + return + # 从 chat_id 获取 session session = _chat_manager.get_session_by_session_id(chat_id) if not session: @@ -794,16 +845,14 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, # 尝试从person_name查找person_id # 首先尝试通过person_name查找 - person_id = get_person_id_by_person_name(person_name) - + person_id = resolve_person_id_for_memory( + person_name=person_name, + platform=platform, + user_id=session.user_id, + ) if not person_id: - # 如果通过person_name找不到,尝试从 session 获取 user_id - if platform and session.user_id: - user_id = session.user_id - person_id = get_person_id(platform, user_id) - else: - logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") - return + logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") + return # 创建或获取Person对象 person = Person(person_id=person_id) @@ -812,39 +861,34 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆") return - # 
确定记忆分类(可以根据memory_content判断,这里使用通用分类) - category = "其他" # 默认分类,可以根据需要调整 + memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16] + result = await memory_service.ingest_text( + external_id=f"person_fact:{person_id}:{memory_hash}", + source_type="person_fact", + text=content, + chat_id=chat_id, + person_ids=[person_id], + participants=[person.person_name or person_name], + timestamp=time.time(), + tags=["person_fact"], + metadata={ + "person_id": person_id, + "person_name": person.person_name or person_name, + "platform": platform, + "source": "person_info.store_person_memory_from_answer", + }, + respect_filter=True, + user_id=str(session.user_id or "").strip(), + group_id=str(session.group_id or "").strip(), + ) - # 记忆点格式:category:content:weight - weight = "1.0" # 默认权重 - memory_point = f"{category}:{memory_content}:{weight}" - - # 添加到memory_points - if not person.memory_points: - person.memory_points = [] - - # 检查是否已存在相似的记忆点(避免重复) - is_duplicate = False - for existing_point in person.memory_points: - if existing_point and isinstance(existing_point, str): - parts = existing_point.split(":", 2) - if len(parts) >= 2: - existing_content = parts[1].strip() - # 简单相似度检查(如果内容相同或非常相似,则跳过) - if ( - existing_content == memory_content - or memory_content in existing_content - or existing_content in memory_content - ): - is_duplicate = True - break - - if not is_duplicate: - person.memory_points.append(memory_point) - person.sync_to_database() - logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}") + if result.success: + if result.detail == "chat_filtered": + logger.debug(f"人物长期记忆被聊天过滤策略跳过: {person_name} (person_id: {person_id})") + else: + logger.info(f"成功写入人物长期记忆: {person_name} (person_id: {person_id})") else: - logger.debug(f"记忆点已存在,跳过: {memory_point}") + logger.warning(f"写入人物长期记忆失败: {person_name} (person_id: {person_id}) | {result.detail}") except Exception as e: logger.error(f"存储人物记忆失败: {e}") diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index c4ae0a56..06ddf5de 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -672,12 +672,10 @@ class RuntimeDataCapabilityMixin: limit_value = 5 try: - from src.chat.knowledge import qa_manager + from src.services.memory_service import memory_service - if qa_manager is None: - return {"success": True, "content": "LPMM知识库已禁用"} - - knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value) + result = await memory_service.search(query, limit=limit_value) + knowledge_info = result.to_text(limit=limit_value) content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识" return {"success": True, "content": content} except Exception as e: diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py new file mode 100644 index 00000000..96062eb6 --- /dev/null +++ b/src/services/memory_flow_service.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import asyncio +import json +from typing import Any, Dict, List, Optional + +from json_repair import repair_json + +from src.chat.utils.utils import is_bot_self +from src.common.message_repository import find_messages +from src.common.logger import get_logger +from src.config.config import global_config, model_config +from src.llm_models.utils_model import LLMRequest +from src.memory_system.chat_history_summarizer import ChatHistorySummarizer +from src.person_info.person_info import Person, get_person_id, 
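# --- Editor's note: illustrative helper, not part of the patch. ---
# The rewritten store_person_memory_from_answer above dedupes person facts on
# external_id = f"person_fact:{person_id}:{memory_hash}", where memory_hash is
# the first 16 hex characters of SHA-256 over f"{person_id}\n{content}", so the
# same fact for the same person always maps to the same id. Note this hunk
# assumes hashlib and time are already imported at the top of person_info.py
# (those imports are not visible in this excerpt).
import hashlib

def person_fact_external_id(person_id: str, content: str) -> str:
    # same person + same fact text -> same external_id, enabling dedupe downstream
    memory_hash = hashlib.sha256(f"{person_id}\n{content}".encode("utf-8")).hexdigest()[:16]
    return f"person_fact:{person_id}:{memory_hash}"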
store_person_memory_from_answer + +logger = get_logger("memory_flow_service") + + +class LongTermMemorySessionManager: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self._summarizers: Dict[str, ChatHistorySummarizer] = {} + + async def on_message(self, message: Any) -> None: + if not bool(getattr(global_config.memory, "long_term_auto_summary_enabled", True)): + return + session_id = str(getattr(message, "session_id", "") or "").strip() + if not session_id: + return + + created = False + async with self._lock: + summarizer = self._summarizers.get(session_id) + if summarizer is None: + summarizer = ChatHistorySummarizer(session_id=session_id) + self._summarizers[session_id] = summarizer + created = True + if created: + await summarizer.start() + + async def shutdown(self) -> None: + async with self._lock: + items = list(self._summarizers.items()) + self._summarizers.clear() + for session_id, summarizer in items: + try: + await summarizer.stop() + except Exception as exc: + logger.warning("停止聊天总结器失败: session=%s err=%s", session_id, exc) + + +class PersonFactWritebackService: + def __init__(self) -> None: + self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256) + self._worker_task: Optional[asyncio.Task] = None + self._stopping = False + self._extractor = LLMRequest( + model_set=model_config.model_task_config.utils, + request_type="person_fact_writeback", + ) + + async def start(self) -> None: + if self._worker_task is not None and not self._worker_task.done(): + return + self._stopping = False + self._worker_task = asyncio.create_task(self._worker_loop(), name="memory_person_fact_writeback") + + async def shutdown(self) -> None: + self._stopping = True + worker = self._worker_task + self._worker_task = None + if worker is None: + return + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning("关闭人物事实写回 worker 失败: %s", exc) + + async def enqueue(self, message: Any) -> None: + if not bool(getattr(global_config.memory, "person_fact_writeback_enabled", True)): + return + if self._stopping: + return + try: + self._queue.put_nowait(message) + except asyncio.QueueFull: + logger.warning("人物事实写回队列已满,跳过本次回复") + + async def _worker_loop(self) -> None: + try: + while not self._stopping: + message = await self._queue.get() + try: + await self._handle_message(message) + except Exception as exc: + logger.warning("人物事实写回处理失败: %s", exc, exc_info=True) + finally: + self._queue.task_done() + except asyncio.CancelledError: + raise + + async def _handle_message(self, message: Any) -> None: + reply_text = str(getattr(message, "processed_plain_text", "") or "").strip() + if not reply_text: + return + if self._looks_ephemeral(reply_text): + return + + target_person = self._resolve_target_person(message) + if target_person is None or not target_person.is_known: + return + + facts = await self._extract_facts(target_person, reply_text) + if not facts: + return + + session_id = str( + getattr(message, "session_id", "") + or getattr(getattr(message, "session", None), "session_id", "") + or "" + ).strip() + if not session_id: + return + + person_name = str(getattr(target_person, "person_name", "") or getattr(target_person, "nickname", "") or "").strip() + if not person_name: + return + + for fact in facts: + await store_person_memory_from_answer(person_name, fact, session_id) + + def _resolve_target_person(self, message: Any) -> Optional[Person]: + session = getattr(message, "session", None) + session_platform = str(getattr(session, 
"platform", "") or getattr(message, "platform", "") or "").strip() + session_user_id = str(getattr(session, "user_id", "") or "").strip() + group_id = str(getattr(session, "group_id", "") or "").strip() + + if session_platform and session_user_id and not group_id: + if is_bot_self(session_platform, session_user_id): + return None + person_id = get_person_id(session_platform, session_user_id) + person = Person(person_id=person_id) + return person if person.is_known else None + + reply_to = str(getattr(message, "reply_to", "") or "").strip() + if not reply_to: + return None + try: + replies = find_messages(message_id=reply_to, limit=1) + except Exception as exc: + logger.debug("查询 reply_to 目标失败: %s", exc) + return None + if not replies: + return None + reply_message = replies[0] + reply_platform = str(getattr(reply_message, "platform", "") or session_platform or "").strip() + reply_user_info = getattr(getattr(reply_message, "message_info", None), "user_info", None) + reply_user_id = str(getattr(reply_user_info, "user_id", "") or "").strip() + if not reply_platform or not reply_user_id or is_bot_self(reply_platform, reply_user_id): + return None + person_id = get_person_id(reply_platform, reply_user_id) + person = Person(person_id=person_id) + return person if person.is_known else None + + async def _extract_facts(self, person: Person, reply_text: str) -> List[str]: + person_name = str(getattr(person, "person_name", "") or getattr(person, "nickname", "") or person.person_id) + prompt = f"""你要从一条机器人刚刚发送的回复中,提取“关于{person_name}的稳定事实”。 + +目标人物:{person_name} +机器人回复: +{reply_text} + +请只提取满足以下条件的事实: +1. 明确是关于目标人物本人的信息。 +2. 具有相对稳定性,可以作为长期记忆保存。 +3. 用简洁中文陈述句表达。 + +不要提取: +- 机器人的情绪、计划、临时动作、客套话 +- 只适用于当前时刻的短期安排 +- 不确定、猜测、反问 +- 与目标人物无关的信息 + +严格输出 JSON 数组,例如: +["他喜欢深夜打游戏", "他养了一只猫"] +如果没有可写入的事实,输出 []""" + try: + response, _ = await self._extractor.generate_response_async(prompt) + except Exception as exc: + logger.debug("人物事实提取模型调用失败: %s", exc) + return [] + return self._parse_fact_list(response) + + @staticmethod + def _parse_fact_list(raw: str) -> List[str]: + text = str(raw or "").strip() + if not text: + return [] + try: + repaired = repair_json(text) + payload = json.loads(repaired) if isinstance(repaired, str) else repaired + except Exception: + payload = None + if not isinstance(payload, list): + return [] + + items: List[str] = [] + seen = set() + for item in payload: + fact = str(item or "").strip().strip("- ") + if not fact or len(fact) < 4: + continue + if fact in seen: + continue + seen.add(fact) + items.append(fact) + return items[:5] + + @staticmethod + def _looks_ephemeral(text: str) -> bool: + content = str(text or "").strip() + if not content: + return True + ephemeral_markers = ( + "哈哈", + "好的", + "收到", + "嗯嗯", + "晚安", + "早安", + "拜拜", + "谢谢", + "在吗", + "?", + ) + if len(content) <= 8 and any(marker in content for marker in ephemeral_markers): + return True + return False + + +class MemoryAutomationService: + def __init__(self) -> None: + self.session_manager = LongTermMemorySessionManager() + self.fact_writeback = PersonFactWritebackService() + self._started = False + + async def start(self) -> None: + if self._started: + return + await self.fact_writeback.start() + self._started = True + + async def shutdown(self) -> None: + if not self._started: + return + await self.session_manager.shutdown() + await self.fact_writeback.shutdown() + self._started = False + + async def on_incoming_message(self, message: Any) -> None: + if not self._started: + await self.start() + await 
self.session_manager.on_message(message) + + async def on_message_sent(self, message: Any) -> None: + if not self._started: + await self.start() + await self.fact_writeback.enqueue(message) + + +memory_automation_service = MemoryAutomationService() diff --git a/src/services/memory_service.py b/src/services/memory_service.py new file mode 100644 index 00000000..6cbecd63 --- /dev/null +++ b/src/services/memory_service.py @@ -0,0 +1,428 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger +from src.plugin_runtime.integration import get_plugin_runtime_manager + + +logger = get_logger("memory_service") + +PLUGIN_ID = "A_Memorix" + + +@dataclass +class MemoryHit: + content: str + score: float = 0.0 + hit_type: str = "" + source: str = "" + hash_value: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + episode_id: str = "" + title: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "content": self.content, + "score": self.score, + "type": self.hit_type, + "source": self.source, + "hash": self.hash_value, + "metadata": self.metadata, + "episode_id": self.episode_id, + "title": self.title, + } + + +@dataclass +class MemorySearchResult: + summary: str = "" + hits: List[MemoryHit] = field(default_factory=list) + filtered: bool = False + + def to_text(self, limit: int = 5) -> str: + if not self.hits: + return "" + lines = [] + for index, item in enumerate(self.hits[: max(1, int(limit))], start=1): + content = item.content.strip().replace("\n", " ") + if len(content) > 160: + content = content[:160] + "..." + lines.append(f"{index}. {content}") + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + return { + "summary": self.summary, + "hits": [item.to_dict() for item in self.hits], + "filtered": self.filtered, + } + + +@dataclass +class MemoryWriteResult: + success: bool + stored_ids: List[str] = field(default_factory=list) + skipped_ids: List[str] = field(default_factory=list) + detail: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "success": self.success, + "stored_ids": self.stored_ids, + "skipped_ids": self.skipped_ids, + "detail": self.detail, + } + + +@dataclass +class PersonProfileResult: + summary: str = "" + traits: List[str] = field(default_factory=list) + evidence: List[Dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return {"summary": self.summary, "traits": self.traits, "evidence": self.evidence} + + +class MemoryService: + async def _invoke(self, component_name: str, args: Optional[Dict[str, Any]] = None, *, timeout_ms: int = 30000) -> Any: + runtime = get_plugin_runtime_manager() + if not runtime.is_running: + raise RuntimeError("plugin_runtime 未启动") + return await runtime.invoke_plugin( + method="plugin.invoke_tool", + plugin_id=PLUGIN_ID, + component_name=component_name, + args=args or {}, + timeout_ms=max(1000, int(timeout_ms or 30000)), + ) + + async def _invoke_admin( + self, + component_name: str, + *, + action: str, + timeout_ms: int = 30000, + **kwargs, + ) -> Dict[str, Any]: + payload = await self._invoke(component_name, {"action": action, **kwargs}, timeout_ms=timeout_ms) + return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"} + + @staticmethod + def _coerce_write_result(payload: Any) -> MemoryWriteResult: + if not isinstance(payload, dict): + return MemoryWriteResult(success=False, detail="invalid_payload") + 
stored_ids = [str(item) for item in (payload.get("stored_ids") or []) if str(item).strip()] + skipped_ids = [str(item) for item in (payload.get("skipped_ids") or []) if str(item).strip()] + detail = str(payload.get("detail") or payload.get("reason") or "") + if stored_ids or skipped_ids: + success = True + elif "success" in payload: + success = bool(payload.get("success")) + else: + success = not bool(detail) + return MemoryWriteResult( + success=success, + stored_ids=stored_ids, + skipped_ids=skipped_ids, + detail=detail, + ) + + @staticmethod + def _coerce_search_result(payload: Any) -> MemorySearchResult: + if not isinstance(payload, dict): + return MemorySearchResult() + hits: List[MemoryHit] = [] + for item in payload.get("hits", []) or []: + if not isinstance(item, dict): + continue + metadata = item.get("metadata", {}) or {} + if not isinstance(metadata, dict): + metadata = {} + if "source_branches" in item and "source_branches" not in metadata: + metadata["source_branches"] = item.get("source_branches") or [] + if "rank" in item and "rank" not in metadata: + metadata["rank"] = item.get("rank") + hits.append( + MemoryHit( + content=str(item.get("content", "") or ""), + score=float(item.get("score", 0.0) or 0.0), + hit_type=str(item.get("type", "") or ""), + source=str(item.get("source", "") or ""), + hash_value=str(item.get("hash", "") or ""), + metadata=metadata, + episode_id=str(item.get("episode_id", "") or ""), + title=str(item.get("title", "") or ""), + ) + ) + return MemorySearchResult( + summary=str(payload.get("summary", "") or ""), + hits=hits, + filtered=bool(payload.get("filtered", False)), + ) + + @staticmethod + def _coerce_profile_result(payload: Any) -> PersonProfileResult: + if not isinstance(payload, dict): + return PersonProfileResult() + return PersonProfileResult( + summary=str(payload.get("summary", "") or ""), + traits=[str(item) for item in (payload.get("traits") or []) if str(item).strip()], + evidence=[item for item in (payload.get("evidence") or []) if isinstance(item, dict)], + ) + + async def search( + self, + query: str, + *, + limit: int = 5, + mode: str = "hybrid", + chat_id: str = "", + person_id: str = "", + time_start: str | float | None = None, + time_end: str | float | None = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", + ) -> MemorySearchResult: + clean_query = str(query or "").strip() + normalized_time_start = None if time_start in {None, ""} else time_start + normalized_time_end = None if time_end in {None, ""} else time_end + if not clean_query and normalized_time_start is None and normalized_time_end is None: + return MemorySearchResult() + try: + payload = await self._invoke( + "search_memory", + { + "query": clean_query, + "limit": max(1, int(limit)), + "mode": mode, + "chat_id": chat_id, + "person_id": person_id, + "time_start": normalized_time_start, + "time_end": normalized_time_end, + "respect_filter": bool(respect_filter), + "user_id": str(user_id or "").strip(), + "group_id": str(group_id or "").strip(), + }, + ) + return self._coerce_search_result(payload) + except Exception as exc: + logger.warning("长期记忆搜索失败: %s", exc) + return MemorySearchResult() + + async def ingest_summary( + self, + *, + external_id: str, + chat_id: str, + text: str, + participants: Optional[List[str]] = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = 
"", + ) -> MemoryWriteResult: + try: + payload = await self._invoke( + "ingest_summary", + { + "external_id": external_id, + "chat_id": chat_id, + "text": text, + "participants": participants or [], + "time_start": time_start, + "time_end": time_end, + "tags": tags or [], + "metadata": metadata or {}, + "respect_filter": bool(respect_filter), + "user_id": str(user_id or "").strip(), + "group_id": str(group_id or "").strip(), + }, + ) + return self._coerce_write_result(payload) + except Exception as exc: + logger.warning("长期记忆写入摘要失败: %s", exc) + return MemoryWriteResult(success=False, detail=str(exc)) + + async def ingest_text( + self, + *, + external_id: str, + source_type: str, + text: str, + chat_id: str = "", + person_ids: Optional[List[str]] = None, + participants: Optional[List[str]] = None, + timestamp: float | None = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + entities: Optional[List[str]] = None, + relations: Optional[List[Dict[str, Any]]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", + ) -> MemoryWriteResult: + try: + payload = await self._invoke( + "ingest_text", + { + "external_id": external_id, + "source_type": source_type, + "text": text, + "chat_id": chat_id, + "person_ids": person_ids or [], + "participants": participants or [], + "timestamp": timestamp, + "time_start": time_start, + "time_end": time_end, + "tags": tags or [], + "metadata": metadata or {}, + "entities": entities or [], + "relations": relations or [], + "respect_filter": bool(respect_filter), + "user_id": str(user_id or "").strip(), + "group_id": str(group_id or "").strip(), + }, + ) + return self._coerce_write_result(payload) + except Exception as exc: + logger.warning("长期记忆写入文本失败: %s", exc) + return MemoryWriteResult(success=False, detail=str(exc)) + + async def get_person_profile(self, person_id: str, *, chat_id: str = "", limit: int = 10) -> PersonProfileResult: + clean_person_id = str(person_id or "").strip() + if not clean_person_id: + return PersonProfileResult() + try: + payload = await self._invoke( + "get_person_profile", + {"person_id": clean_person_id, "chat_id": chat_id, "limit": max(1, int(limit))}, + ) + return self._coerce_profile_result(payload) + except Exception as exc: + logger.warning("获取人物画像失败: %s", exc) + return PersonProfileResult() + + async def maintain_memory( + self, + *, + action: str, + target: str = "", + hours: float | None = None, + reason: str = "", + limit: int = 50, + ) -> MemoryWriteResult: + try: + payload = await self._invoke( + "maintain_memory", + {"action": action, "target": target, "hours": hours, "reason": reason, "limit": limit}, + ) + if not isinstance(payload, dict): + return MemoryWriteResult(success=False, detail="invalid_payload") + return MemoryWriteResult(success=bool(payload.get("success")), detail=str(payload.get("detail", "") or "")) + except Exception as exc: + logger.warning("记忆维护失败: %s", exc) + return MemoryWriteResult(success=False, detail=str(exc)) + + async def memory_stats(self) -> Dict[str, Any]: + try: + payload = await self._invoke("memory_stats", {}) + return payload if isinstance(payload, dict) else {} + except Exception as exc: + logger.warning("获取记忆统计失败: %s", exc) + return {} + + async def graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_graph_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("图谱管理调用失败: 
%s", exc) + return {"success": False, "error": str(exc)} + + async def source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_source_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("来源管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_episode_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("Episode 管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_profile_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("画像管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_runtime_admin", action=action, **kwargs) + except Exception as exc: + logger.warning("运行时管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def import_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_import_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("导入管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def tuning_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_tuning_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("调优管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def v5_admin(self, *, action: str, timeout_ms: int = 30000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_v5_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("V5 记忆管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def delete_admin(self, *, action: str, timeout_ms: int = 120000, **kwargs) -> Dict[str, Any]: + try: + return await self._invoke_admin("memory_delete_admin", action=action, timeout_ms=timeout_ms, **kwargs) + except Exception as exc: + logger.warning("删除管理调用失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def get_recycle_bin(self, *, limit: int = 50) -> Dict[str, Any]: + try: + payload = await self._invoke("maintain_memory", {"action": "recycle_bin", "limit": max(1, int(limit or 50))}) + return payload if isinstance(payload, dict) else {"success": False, "error": "invalid_payload"} + except Exception as exc: + logger.warning("获取回收站失败: %s", exc) + return {"success": False, "error": str(exc)} + + async def restore_memory(self, *, target: str) -> MemoryWriteResult: + return await self.maintain_memory(action="restore", target=target) + + async def reinforce_memory(self, *, target: str) -> MemoryWriteResult: + return await self.maintain_memory(action="reinforce", target=target) + + async def freeze_memory(self, *, target: str) -> MemoryWriteResult: + return await self.maintain_memory(action="freeze", target=target) + + async def protect_memory(self, *, target: str, hours: float | None = None) -> MemoryWriteResult: + return await self.maintain_memory(action="protect", target=target, hours=hours) + + +memory_service = MemoryService() diff --git 
a/src/webui/routers/__init__.py b/src/webui/routers/__init__.py index 65d63d02..687915d4 100644 --- a/src/webui/routers/__init__.py +++ b/src/webui/routers/__init__.py @@ -17,14 +17,14 @@ def get_all_routers() -> List[APIRouter]: from src.webui.api.planner import router as planner_router from src.webui.api.replier import router as replier_router from src.webui.routers.chat import router as chat_router - from src.webui.routers.knowledge import router as knowledge_router + from src.webui.routers.memory import compat_router as memory_compat_router from src.webui.routers.websocket.logs import router as logs_router from src.webui.routes import router as main_router return [ main_router, + memory_compat_router, logs_router, - knowledge_router, chat_router, planner_router, replier_router, diff --git a/src/webui/routers/memory.py b/src/webui/routers/memory.py new file mode 100644 index 00000000..d741affc --- /dev/null +++ b/src/webui/routers/memory.py @@ -0,0 +1,1395 @@ +from __future__ import annotations + +import json +import shutil +import uuid +from pathlib import Path +from typing import Any, Optional + +from fastapi import APIRouter, Body, Depends, File, Form, Query, UploadFile +from pydantic import BaseModel, Field + +from src.services.memory_service import MemorySearchResult, memory_service +from src.webui.dependencies import require_auth + + +router = APIRouter(prefix="/api/webui/memory", tags=["memory"], dependencies=[Depends(require_auth)]) +compat_router = APIRouter(prefix="/api", tags=["memory-compat"], dependencies=[Depends(require_auth)]) +STAGING_ROOT = Path(__file__).resolve().parents[3] / "data" / "memory_upload_staging" + + +class NodeRequest(BaseModel): + name: str = Field(..., min_length=1) + + +class NodeRenameRequest(BaseModel): + old_name: str = Field(..., min_length=1) + new_name: str = Field(..., min_length=1) + + +class EdgeCreateRequest(BaseModel): + subject: str = Field(..., min_length=1) + predicate: str = Field(..., min_length=1) + object: str = Field(..., min_length=1) + confidence: float = Field(1.0, ge=0.0) + + +class EdgeDeleteRequest(BaseModel): + hash: str = "" + subject: str = "" + object: str = "" + + +class EdgeWeightRequest(BaseModel): + hash: str = "" + subject: str = "" + object: str = "" + weight: float = Field(..., ge=0.0) + + +class SourceDeleteRequest(BaseModel): + source: str = Field(..., min_length=1) + + +class SourceBatchDeleteRequest(BaseModel): + sources: list[str] = Field(default_factory=list) + + +class EpisodeRebuildRequest(BaseModel): + source: str = "" + sources: list[str] = Field(default_factory=list) + all: bool = False + + +class EpisodeProcessPendingRequest(BaseModel): + limit: int = Field(20, ge=1, le=200) + max_retry: int = Field(3, ge=1, le=20) + + +class ProfileOverrideRequest(BaseModel): + person_id: str = Field(..., min_length=1) + override_text: str = "" + updated_by: str = "" + source: str = "webui" + + +class MaintainRequest(BaseModel): + target: str = Field(..., min_length=1) + hours: Optional[float] = None + + +class AutoSaveRequest(BaseModel): + enabled: bool + + +class TuningApplyProfileRequest(BaseModel): + profile: dict[str, Any] = Field(default_factory=dict) + reason: str = "manual" + + +class V5ActionRequest(BaseModel): + target: str = Field(..., min_length=1) + strength: Optional[float] = Field(default=None, ge=0.0) + reason: str = "" + updated_by: str = "webui" + + +class DeleteActionRequest(BaseModel): + mode: str = Field(..., min_length=1) + selector: dict[str, Any] | str = Field(default_factory=dict) + reason: str 
= "" + requested_by: str = "webui" + + +class DeleteRestoreRequest(BaseModel): + operation_id: str = "" + mode: str = "" + selector: dict[str, Any] | str = Field(default_factory=dict) + reason: str = "" + requested_by: str = "webui" + + +class DeletePurgeRequest(BaseModel): + grace_hours: Optional[float] = Field(default=None, ge=0.0) + limit: int = Field(1000, ge=1, le=5000) + + +def _build_import_guide_markdown(settings: dict[str, Any]) -> str: + path_aliases = settings.get("path_aliases") if isinstance(settings.get("path_aliases"), dict) else {} + alias_lines = [ + f"- `{name}` -> `{path}`" + for name, path in sorted(path_aliases.items()) + if str(name).strip() and str(path).strip() + ] + if not alias_lines: + alias_lines = ["- 当前未配置路径别名"] + return "\n".join( + [ + "# 长期记忆导入说明", + "", + "支持的导入方式:", + "- 上传文件:适合零散文档、日志、聊天导出文本。", + "- 粘贴文本:适合一次性导入少量整理好的内容。", + "- Raw Scan:扫描白名单目录内的原始文本文件。", + "- LPMM OpenIE / Convert:处理既有 LPMM 数据。", + "- Temporal Backfill:补回已有数据中的时间信息。", + "- MaiBot Migration:从宿主数据库迁移历史聊天记忆。", + "", + "当前路径别名:", + *alias_lines, + "", + "执行建议:", + "- 首次导入先小批量试跑,确认切分和抽取结果正常。", + "- 大批量导入时优先关注任务状态、失败块与重试结果。", + "- 若路径解析失败,请先检查路径别名与相对路径是否仍然有效。", + ] + ) + + +def _unwrap_payload(payload: dict[str, Any] | None) -> dict[str, Any]: + raw = payload if isinstance(payload, dict) else {} + nested = raw.get("payload") + if isinstance(nested, dict): + return dict(nested) + return dict(raw) + + +async def _graph_get(limit: int) -> dict: + return await memory_service.graph_admin(action="get_graph", limit=limit) + + +async def _graph_create_node(payload: NodeRequest) -> dict: + return await memory_service.graph_admin(action="create_node", name=payload.name) + + +async def _graph_delete_node(payload: NodeRequest) -> dict: + return await memory_service.graph_admin(action="delete_node", name=payload.name) + + +async def _graph_rename_node(payload: NodeRenameRequest) -> dict: + return await memory_service.graph_admin(action="rename_node", old_name=payload.old_name, new_name=payload.new_name) + + +async def _graph_create_edge(payload: EdgeCreateRequest) -> dict: + return await memory_service.graph_admin( + action="create_edge", + subject=payload.subject, + predicate=payload.predicate, + object=payload.object, + confidence=payload.confidence, + ) + + +async def _graph_delete_edge(payload: EdgeDeleteRequest) -> dict: + return await memory_service.graph_admin( + action="delete_edge", + hash=payload.hash, + subject=payload.subject, + object=payload.object, + ) + + +async def _graph_update_edge_weight(payload: EdgeWeightRequest) -> dict: + return await memory_service.graph_admin( + action="update_edge_weight", + hash=payload.hash, + subject=payload.subject, + object=payload.object, + weight=payload.weight, + ) + + +async def _source_list() -> dict: + return await memory_service.source_admin(action="list") + + +async def _source_delete(payload: SourceDeleteRequest) -> dict: + return await memory_service.source_admin(action="delete", source=payload.source) + + +async def _source_batch_delete(payload: SourceBatchDeleteRequest) -> dict: + return await memory_service.source_admin(action="batch_delete", sources=payload.sources) + + +async def _query_aggregate( + query: str, + *, + limit: int, + chat_id: str, + person_id: str, + time_start: float | None, + time_end: float | None, +) -> dict: + result: MemorySearchResult = await memory_service.search( + query, + limit=limit, + mode="aggregate", + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + respect_filter=False, + 
) + return {"success": True, **result.to_dict()} + + +async def _episode_list( + *, + query: str, + limit: int, + source: str, + person_id: str, + time_start: float | None, + time_end: float | None, +) -> dict: + return await memory_service.episode_admin( + action="list", + query=query, + limit=limit, + source=source, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +async def _episode_get(episode_id: str) -> dict: + return await memory_service.episode_admin(action="get", episode_id=episode_id) + + +async def _episode_rebuild(payload: EpisodeRebuildRequest) -> dict: + return await memory_service.episode_admin( + action="rebuild", + source=payload.source, + sources=payload.sources, + all=payload.all, + ) + + +async def _episode_status(limit: int) -> dict: + return await memory_service.episode_admin(action="status", limit=limit) + + +async def _episode_process_pending(payload: EpisodeProcessPendingRequest) -> dict: + return await memory_service.episode_admin( + action="process_pending", + limit=payload.limit, + max_retry=payload.max_retry, + ) + + +async def _profile_query(*, person_id: str, person_keyword: str, limit: int, force_refresh: bool) -> dict: + return await memory_service.profile_admin( + action="query", + person_id=person_id, + person_keyword=person_keyword, + limit=limit, + force_refresh=force_refresh, + ) + + +async def _profile_list(limit: int) -> dict: + return await memory_service.profile_admin(action="list", limit=limit) + + +async def _profile_set_override(payload: ProfileOverrideRequest) -> dict: + return await memory_service.profile_admin( + action="set_override", + person_id=payload.person_id, + override_text=payload.override_text, + updated_by=payload.updated_by, + source=payload.source, + ) + + +async def _profile_delete_override(person_id: str) -> dict: + return await memory_service.profile_admin(action="delete_override", person_id=person_id) + + +async def _runtime_save() -> dict: + return await memory_service.runtime_admin(action="save") + + +async def _runtime_config() -> dict: + return await memory_service.runtime_admin(action="get_config") + + +async def _runtime_self_check(refresh: bool) -> dict: + return await memory_service.runtime_admin(action="refresh_self_check" if refresh else "self_check") + + +async def _runtime_auto_save(enabled: bool | None = None) -> dict: + if enabled is None: + config = await memory_service.runtime_admin(action="get_config") + return {"success": bool(config.get("success", False)), "auto_save": bool(config.get("auto_save", False))} + return await memory_service.runtime_admin(action="set_auto_save", enabled=enabled) + + +async def _maintenance_recycle_bin(limit: int) -> dict: + return await memory_service.get_recycle_bin(limit=limit) + + +async def _maintenance_restore(payload: MaintainRequest) -> dict: + return (await memory_service.restore_memory(target=payload.target)).to_dict() + + +async def _maintenance_reinforce(payload: MaintainRequest) -> dict: + return (await memory_service.reinforce_memory(target=payload.target)).to_dict() + + +async def _maintenance_freeze(payload: MaintainRequest) -> dict: + return (await memory_service.freeze_memory(target=payload.target)).to_dict() + + +async def _maintenance_protect(payload: MaintainRequest) -> dict: + return (await memory_service.protect_memory(target=payload.target, hours=payload.hours)).to_dict() + + +async def _v5_status(target: str, limit: int) -> dict: + return await memory_service.v5_admin(action="status", target=target, limit=limit) + + +async def 
_v5_recycle_bin(limit: int) -> dict: + return await memory_service.v5_admin(action="recycle_bin", limit=limit) + + +async def _v5_action(action: str, payload: V5ActionRequest) -> dict: + kwargs: dict[str, Any] = { + "target": payload.target, + "reason": payload.reason, + "updated_by": payload.updated_by, + } + if payload.strength is not None: + kwargs["strength"] = payload.strength + return await memory_service.v5_admin(action=action, **kwargs) + + +async def _delete_preview(payload: DeleteActionRequest) -> dict: + return await memory_service.delete_admin(action="preview", mode=payload.mode, selector=payload.selector) + + +async def _delete_execute(payload: DeleteActionRequest) -> dict: + return await memory_service.delete_admin( + action="execute", + mode=payload.mode, + selector=payload.selector, + reason=payload.reason, + requested_by=payload.requested_by, + ) + + +async def _delete_restore(payload: DeleteRestoreRequest) -> dict: + return await memory_service.delete_admin( + action="restore", + mode=payload.mode, + selector=payload.selector, + operation_id=payload.operation_id, + reason=payload.reason, + requested_by=payload.requested_by, + ) + + +async def _delete_list(limit: int, mode: str) -> dict: + return await memory_service.delete_admin(action="list_operations", limit=limit, mode=mode) + + +async def _delete_get(operation_id: str) -> dict: + return await memory_service.delete_admin(action="get_operation", operation_id=operation_id) + + +async def _delete_purge(payload: DeletePurgeRequest) -> dict: + return await memory_service.delete_admin( + action="purge", + grace_hours=payload.grace_hours, + limit=payload.limit, + ) + + +async def _import_settings() -> dict: + return await memory_service.import_admin(action="get_settings") + + +async def _import_path_aliases() -> dict: + return await memory_service.import_admin(action="get_path_aliases") + + +async def _import_guide() -> dict: + payload = await memory_service.import_admin(action="get_guide") + if not isinstance(payload, dict): + payload = {"success": False, "error": "invalid_payload"} + if isinstance(payload.get("content"), str): + return payload + + settings = payload.get("settings") if isinstance(payload.get("settings"), dict) else None + if settings is None: + settings_payload = await memory_service.import_admin(action="get_settings") + settings = settings_payload.get("settings") if isinstance(settings_payload.get("settings"), dict) else {} + + return { + "success": True, + "source": "local", + "path": "generated://memory_import_guide", + "content": _build_import_guide_markdown(settings or {}), + "settings": settings or {}, + } + + +async def _import_resolve_path(payload: dict[str, Any]) -> dict: + return await memory_service.import_admin(action="resolve_path", **_unwrap_payload(payload)) + + +async def _import_create(action: str, payload: dict[str, Any]) -> dict: + return await memory_service.import_admin(action=action, **_unwrap_payload(payload)) + + +async def _import_list(limit: int) -> dict: + listing = await memory_service.import_admin(action="list", limit=limit) + if not isinstance(listing, dict): + listing = {"success": False, "items": []} + settings_payload = await memory_service.import_admin(action="get_settings") + settings = settings_payload.get("settings") if isinstance(settings_payload.get("settings"), dict) else {} + listing.setdefault("success", True) + listing.setdefault("items", []) + listing["settings"] = settings + return listing + + +async def _import_get(task_id: str, include_chunks: bool) -> dict: + 
return await memory_service.import_admin(action="get", task_id=task_id, include_chunks=include_chunks) + + +async def _import_chunks(task_id: str, file_id: str, offset: int, limit: int) -> dict: + return await memory_service.import_admin( + action="get_chunks", + task_id=task_id, + file_id=file_id, + offset=offset, + limit=limit, + ) + + +async def _import_cancel(task_id: str) -> dict: + return await memory_service.import_admin(action="cancel", task_id=task_id) + + +async def _import_retry(task_id: str, payload: dict[str, Any]) -> dict: + raw = _unwrap_payload(payload) + overrides = raw.get("overrides") if isinstance(raw.get("overrides"), dict) else raw + return await memory_service.import_admin(action="retry_failed", task_id=task_id, overrides=overrides) + + +async def _tuning_settings() -> dict: + return await memory_service.tuning_admin(action="get_settings") + + +async def _tuning_profile() -> dict: + profile = await memory_service.tuning_admin(action="get_profile") + if not isinstance(profile, dict): + profile = {"success": False, "profile": {}} + if not isinstance(profile.get("settings"), dict): + settings = await memory_service.tuning_admin(action="get_settings") + profile["settings"] = settings.get("settings") if isinstance(settings.get("settings"), dict) else {} + return profile + + +async def _tuning_apply_profile(payload: TuningApplyProfileRequest) -> dict: + return await memory_service.tuning_admin(action="apply_profile", profile=payload.profile, reason=payload.reason) + + +async def _tuning_rollback_profile() -> dict: + return await memory_service.tuning_admin(action="rollback_profile") + + +async def _tuning_export_profile() -> dict: + return await memory_service.tuning_admin(action="export_profile") + + +async def _tuning_create_task(payload: dict[str, Any]) -> dict: + return await memory_service.tuning_admin(action="create_task", payload=_unwrap_payload(payload)) + + +async def _tuning_list_tasks(limit: int) -> dict: + return await memory_service.tuning_admin(action="list_tasks", limit=limit) + + +async def _tuning_get_task(task_id: str, include_rounds: bool) -> dict: + return await memory_service.tuning_admin(action="get_task", task_id=task_id, include_rounds=include_rounds) + + +async def _tuning_get_rounds(task_id: str, offset: int, limit: int) -> dict: + return await memory_service.tuning_admin(action="get_rounds", task_id=task_id, offset=offset, limit=limit) + + +async def _tuning_cancel(task_id: str) -> dict: + return await memory_service.tuning_admin(action="cancel", task_id=task_id) + + +async def _tuning_apply_best(task_id: str) -> dict: + return await memory_service.tuning_admin(action="apply_best", task_id=task_id) + + +async def _tuning_report(task_id: str, fmt: str) -> dict: + payload = await memory_service.tuning_admin(action="get_report", task_id=task_id, format=fmt) + report = payload.get("report") if isinstance(payload.get("report"), dict) else {} + return { + "success": bool(payload.get("success", False)), + "format": report.get("format", fmt), + "content": report.get("content", ""), + "path": report.get("path", ""), + "error": payload.get("error", ""), + } + + +async def _stage_upload_files(files: list[UploadFile]) -> tuple[Path, list[dict[str, Any]]]: + STAGING_ROOT.mkdir(parents=True, exist_ok=True) + staging_dir = STAGING_ROOT / uuid.uuid4().hex + staging_dir.mkdir(parents=True, exist_ok=True) + staged_files: list[dict[str, Any]] = [] + for index, upload in enumerate(files): + filename = Path(upload.filename or f"upload_{index}.txt").name + target = 
staging_dir / f"{index:03d}_{filename}" + content = await upload.read() + target.write_bytes(content) + staged_files.append( + { + "filename": filename, + "staged_path": str(target.resolve()), + "size": len(content), + } + ) + return staging_dir, staged_files + + +@router.get("/graph") +async def get_memory_graph(limit: int = Query(200, ge=1, le=5000)): + return await _graph_get(limit) + + +@router.post("/graph/node") +async def create_memory_node(payload: NodeRequest): + return await _graph_create_node(payload) + + +@router.delete("/graph/node") +async def delete_memory_node(payload: NodeRequest): + return await _graph_delete_node(payload) + + +@router.post("/graph/node/rename") +async def rename_memory_node(payload: NodeRenameRequest): + return await _graph_rename_node(payload) + + +@router.post("/graph/edge") +async def create_memory_edge(payload: EdgeCreateRequest): + return await _graph_create_edge(payload) + + +@router.delete("/graph/edge") +async def delete_memory_edge(payload: EdgeDeleteRequest): + return await _graph_delete_edge(payload) + + +@router.post("/graph/edge/weight") +async def update_memory_edge_weight(payload: EdgeWeightRequest): + return await _graph_update_edge_weight(payload) + + +@router.get("/sources") +async def list_memory_sources(): + return await _source_list() + + +@router.post("/sources/delete") +async def delete_memory_source(payload: SourceDeleteRequest): + return await _source_delete(payload) + + +@router.post("/sources/batch-delete") +async def batch_delete_memory_sources(payload: SourceBatchDeleteRequest): + return await _source_batch_delete(payload) + + +@router.get("/query/aggregate") +async def query_memory_aggregate( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + chat_id: str = Query(""), + person_id: str = Query(""), + time_start: float | None = Query(None), + time_end: float | None = Query(None), +): + return await _query_aggregate( + query, + limit=limit, + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +@router.get("/episodes") +async def list_memory_episodes( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + source: str = Query(""), + person_id: str = Query(""), + time_start: float | None = Query(None), + time_end: float | None = Query(None), +): + return await _episode_list( + query=query, + limit=limit, + source=source, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +@router.get("/episodes/{episode_id}") +async def get_memory_episode(episode_id: str): + return await _episode_get(episode_id) + + +@router.post("/episodes/rebuild") +async def rebuild_memory_episodes(payload: EpisodeRebuildRequest): + return await _episode_rebuild(payload) + + +@router.get("/episodes/status") +async def get_memory_episode_status(limit: int = Query(20, ge=1, le=200)): + return await _episode_status(limit) + + +@router.post("/episodes/process-pending") +async def process_memory_episode_pending(payload: EpisodeProcessPendingRequest): + return await _episode_process_pending(payload) + + +@router.get("/profiles/query") +async def query_memory_profile( + person_id: str = Query(""), + person_keyword: str = Query(""), + limit: int = Query(12, ge=1, le=100), + force_refresh: bool = Query(False), +): + return await _profile_query( + person_id=person_id, + person_keyword=person_keyword, + limit=limit, + force_refresh=force_refresh, + ) + + +@router.get("/profiles") +async def list_memory_profiles(limit: int = Query(50, ge=1, le=200)): + return await 
_profile_list(limit) + + +@router.post("/profiles/override") +async def set_memory_profile_override(payload: ProfileOverrideRequest): + return await _profile_set_override(payload) + + +@router.delete("/profiles/override/{person_id}") +async def delete_memory_profile_override(person_id: str): + return await _profile_delete_override(person_id) + + +@router.post("/runtime/save") +async def save_memory_runtime(): + return await _runtime_save() + + +@router.get("/runtime/config") +async def get_memory_runtime_config(): + return await _runtime_config() + + +@router.get("/runtime/self-check") +async def get_memory_runtime_self_check(): + return await _runtime_self_check(False) + + +@router.post("/runtime/self-check/refresh") +async def refresh_memory_runtime_self_check(): + return await _runtime_self_check(True) + + +@router.get("/runtime/auto-save") +async def get_memory_runtime_auto_save(): + return await _runtime_auto_save(None) + + +@router.post("/runtime/auto-save") +async def set_memory_runtime_auto_save(payload: AutoSaveRequest): + return await _runtime_auto_save(payload.enabled) + + +@router.get("/maintenance/recycle-bin") +async def get_memory_recycle_bin(limit: int = Query(50, ge=1, le=200)): + return await _maintenance_recycle_bin(limit) + + +@router.post("/maintenance/restore") +async def restore_memory_relation(payload: MaintainRequest): + return await _maintenance_restore(payload) + + +@router.post("/maintenance/reinforce") +async def reinforce_memory_relation(payload: MaintainRequest): + return await _maintenance_reinforce(payload) + + +@router.post("/maintenance/freeze") +async def freeze_memory_relation(payload: MaintainRequest): + return await _maintenance_freeze(payload) + + +@router.post("/maintenance/protect") +async def protect_memory_relation(payload: MaintainRequest): + return await _maintenance_protect(payload) + + +@router.get("/v5/status") +async def get_memory_v5_status( + target: str = Query(""), + limit: int = Query(50, ge=1, le=200), +): + return await _v5_status(target, limit) + + +@router.get("/v5/recycle-bin") +async def get_memory_v5_recycle_bin(limit: int = Query(50, ge=1, le=200)): + return await _v5_recycle_bin(limit) + + +@router.post("/v5/reinforce") +async def reinforce_memory_v5(payload: V5ActionRequest): + return await _v5_action("reinforce", payload) + + +@router.post("/v5/weaken") +async def weaken_memory_v5(payload: V5ActionRequest): + return await _v5_action("weaken", payload) + + +@router.post("/v5/remember-forever") +async def remember_forever_memory_v5(payload: V5ActionRequest): + return await _v5_action("remember_forever", payload) + + +@router.post("/v5/forget") +async def forget_memory_v5(payload: V5ActionRequest): + return await _v5_action("forget", payload) + + +@router.post("/v5/restore") +async def restore_memory_v5(payload: V5ActionRequest): + return await _v5_action("restore", payload) + + +@router.post("/delete/preview") +async def preview_memory_delete(payload: DeleteActionRequest): + return await _delete_preview(payload) + + +@router.post("/delete/execute") +async def execute_memory_delete(payload: DeleteActionRequest): + return await _delete_execute(payload) + + +@router.post("/delete/restore") +async def restore_memory_delete(payload: DeleteRestoreRequest): + return await _delete_restore(payload) + + +@router.get("/delete/operations") +async def list_memory_delete_operations( + limit: int = Query(50, ge=1, le=200), + mode: str = Query(""), +): + return await _delete_list(limit, mode) + + 
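The delete routes above wrap the plugin's `memory_delete_admin` tool in a preview-then-execute workflow, with every destructive operation recorded so it can be listed and restored later. As a rough sketch of how a client might drive that flow (the base URL, port, auth header, `mode` value, and selector contents below are placeholders, not values defined by this patch):

```python
# Hypothetical client-side walkthrough of the preview -> execute flow.
# Base URL, port, auth token, mode and selector values are illustrative only;
# the accepted modes/selectors are defined by the plugin's memory_delete_admin tool.
import asyncio

import httpx


async def preview_then_delete() -> None:
    headers = {"Authorization": "Bearer <webui-token>"}  # assumed auth scheme
    payload = {
        "mode": "source",                                   # example mode
        "selector": {"source": "chat_export_2026_01.txt"},  # example selector
        "reason": "stale import",
        "requested_by": "webui",
    }
    async with httpx.AsyncClient(base_url="http://127.0.0.1:8000", headers=headers) as client:
        preview = await client.post("/api/webui/memory/delete/preview", json=payload)
        print("preview:", preview.json())

        # Execute only after the preview looks right; the resulting operation can
        # later be undone via POST /api/webui/memory/delete/restore (operation_id).
        result = await client.post("/api/webui/memory/delete/execute", json=payload)
        print("execute:", result.json())


asyncio.run(preview_then_delete())
```

The request body mirrors `DeleteActionRequest` (`mode`, `selector`, `reason`, `requested_by`); only the endpoint paths and field names come from the routes above.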
+@router.get("/delete/operations/{operation_id}") +async def get_memory_delete_operation(operation_id: str): + return await _delete_get(operation_id) + + +@router.post("/delete/purge") +async def purge_memory_delete(payload: DeletePurgeRequest): + return await _delete_purge(payload) + + +@router.get("/import/settings") +async def get_memory_import_settings(): + return await _import_settings() + + +@router.get("/import/path-aliases") +async def get_memory_import_path_aliases(): + return await _import_path_aliases() + + +@router.get("/import/guide") +async def get_memory_import_guide(): + return await _import_guide() + + +@router.post("/import/resolve-path") +async def resolve_memory_import_path(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_resolve_path(payload) + + +@router.post("/import/upload") +async def create_memory_import_upload( + files: list[UploadFile] = File(...), + payload_json: str = Form("{}"), +): + staging_dir, staged_files = await _stage_upload_files(files) + try: + try: + payload = json.loads(payload_json or "{}") + except Exception: + payload = {} + if not isinstance(payload, dict): + payload = {} + payload["staged_files"] = staged_files + return await _import_create("create_upload", payload) + finally: + shutil.rmtree(staging_dir, ignore_errors=True) + + +@router.post("/import/paste") +async def create_memory_import_paste(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_paste", payload) + + +@router.post("/import/raw-scan") +async def create_memory_import_raw_scan(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_raw_scan", payload) + + +@router.post("/import/lpmm-openie") +async def create_memory_import_lpmm_openie(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_openie", payload) + + +@router.post("/import/lpmm-convert") +async def create_memory_import_lpmm_convert(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_convert", payload) + + +@router.post("/import/temporal-backfill") +async def create_memory_import_temporal_backfill(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_temporal_backfill", payload) + + +@router.post("/import/maibot-migration") +async def create_memory_import_maibot_migration(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_maibot_migration", payload) + + +@router.get("/import/tasks") +async def list_memory_import_tasks(limit: int = Query(50, ge=1, le=200)): + return await _import_list(limit) + + +@router.get("/import/tasks/{task_id}") +async def get_memory_import_task(task_id: str, include_chunks: bool = Query(False)): + return await _import_get(task_id, include_chunks) + + +@router.get("/import/tasks/{task_id}/chunks/{file_id}") +async def get_memory_import_chunks( + task_id: str, + file_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _import_chunks(task_id, file_id, offset, limit) + + +@router.post("/import/tasks/{task_id}/cancel") +async def cancel_memory_import_task(task_id: str): + return await _import_cancel(task_id) + + +@router.post("/import/tasks/{task_id}/retry") +async def retry_memory_import_task(task_id: str, payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_retry(task_id, payload) + + +@router.get("/retrieval_tuning/settings") +async def 
get_memory_tuning_settings(): + return await _tuning_settings() + + +@router.get("/retrieval_tuning/profile") +async def get_memory_tuning_profile(): + return await _tuning_profile() + + +@router.post("/retrieval_tuning/profile/apply") +async def apply_memory_tuning_profile(payload: TuningApplyProfileRequest): + return await _tuning_apply_profile(payload) + + +@router.post("/retrieval_tuning/profile/rollback") +async def rollback_memory_tuning_profile(): + return await _tuning_rollback_profile() + + +@router.get("/retrieval_tuning/profile/export") +async def export_memory_tuning_profile(): + return await _tuning_export_profile() + + +@router.post("/retrieval_tuning/tasks") +async def create_memory_tuning_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _tuning_create_task(payload) + + +@router.get("/retrieval_tuning/tasks") +async def list_memory_tuning_tasks(limit: int = Query(50, ge=1, le=200)): + return await _tuning_list_tasks(limit) + + +@router.get("/retrieval_tuning/tasks/{task_id}") +async def get_memory_tuning_task(task_id: str, include_rounds: bool = Query(False)): + return await _tuning_get_task(task_id, include_rounds) + + +@router.get("/retrieval_tuning/tasks/{task_id}/rounds") +async def get_memory_tuning_rounds( + task_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _tuning_get_rounds(task_id, offset, limit) + + +@router.post("/retrieval_tuning/tasks/{task_id}/cancel") +async def cancel_memory_tuning_task(task_id: str): + return await _tuning_cancel(task_id) + + +@router.post("/retrieval_tuning/tasks/{task_id}/apply-best") +async def apply_best_memory_tuning_profile(task_id: str): + return await _tuning_apply_best(task_id) + + +@router.get("/retrieval_tuning/tasks/{task_id}/report") +async def get_memory_tuning_report(task_id: str, format: str = Query("md")): + return await _tuning_report(task_id, format) + + +@compat_router.get("/graph") +async def compat_get_graph(limit: int = Query(200, ge=1, le=5000)): + return await _graph_get(limit) + + +@compat_router.post("/node") +async def compat_create_node(payload: NodeRequest): + return await _graph_create_node(payload) + + +@compat_router.delete("/node") +async def compat_delete_node(payload: NodeRequest): + return await _graph_delete_node(payload) + + +@compat_router.post("/node/rename") +async def compat_rename_node(payload: NodeRenameRequest): + return await _graph_rename_node(payload) + + +@compat_router.post("/edge") +async def compat_create_edge(payload: EdgeCreateRequest): + return await _graph_create_edge(payload) + + +@compat_router.delete("/edge") +async def compat_delete_edge(payload: EdgeDeleteRequest): + return await _graph_delete_edge(payload) + + +@compat_router.post("/edge/weight") +async def compat_update_edge_weight(payload: EdgeWeightRequest): + return await _graph_update_edge_weight(payload) + + +@compat_router.get("/source/list") +async def compat_list_sources(): + return await _source_list() + + +@compat_router.post("/source/delete") +async def compat_delete_source(payload: SourceDeleteRequest): + return await _source_delete(payload) + + +@compat_router.post("/source/batch_delete") +async def compat_batch_delete_sources(payload: SourceBatchDeleteRequest): + return await _source_batch_delete(payload) + + +@compat_router.get("/query/aggregate") +async def compat_query_aggregate( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + chat_id: str = Query(""), + person_id: str = Query(""), + time_start: float | None = 
Query(None), + time_end: float | None = Query(None), +): + return await _query_aggregate( + query, + limit=limit, + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +@compat_router.get("/episodes") +async def compat_list_episodes( + query: str = Query(""), + limit: int = Query(20, ge=1, le=200), + source: str = Query(""), + person_id: str = Query(""), + time_start: float | None = Query(None), + time_end: float | None = Query(None), +): + return await _episode_list( + query=query, + limit=limit, + source=source, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + + +@compat_router.get("/episodes/{episode_id}") +async def compat_get_episode(episode_id: str): + return await _episode_get(episode_id) + + +@compat_router.post("/episodes/rebuild") +async def compat_rebuild_episodes(payload: EpisodeRebuildRequest): + return await _episode_rebuild(payload) + + +@compat_router.get("/episodes/status") +async def compat_episode_status(limit: int = Query(20, ge=1, le=200)): + return await _episode_status(limit) + + +@compat_router.post("/episodes/process_pending") +async def compat_process_episode_pending(payload: EpisodeProcessPendingRequest): + return await _episode_process_pending(payload) + + +@compat_router.get("/person_profile/query") +async def compat_profile_query( + person_id: str = Query(""), + person_keyword: str = Query(""), + limit: int = Query(12, ge=1, le=100), + force_refresh: bool = Query(False), +): + return await _profile_query( + person_id=person_id, + person_keyword=person_keyword, + limit=limit, + force_refresh=force_refresh, + ) + + +@compat_router.get("/person_profile/list") +async def compat_profile_list(limit: int = Query(50, ge=1, le=200)): + return await _profile_list(limit) + + +@compat_router.post("/person_profile/override") +async def compat_set_profile_override(payload: ProfileOverrideRequest): + return await _profile_set_override(payload) + + +@compat_router.delete("/person_profile/override/{person_id}") +async def compat_delete_profile_override(person_id: str): + return await _profile_delete_override(person_id) + + +@compat_router.post("/save") +async def compat_runtime_save(): + return await _runtime_save() + + +@compat_router.get("/config") +async def compat_runtime_config(): + return await _runtime_config() + + +@compat_router.get("/runtime/self_check") +async def compat_runtime_self_check(): + return await _runtime_self_check(False) + + +@compat_router.post("/runtime/self_check/refresh") +async def compat_refresh_runtime_self_check(): + return await _runtime_self_check(True) + + +@compat_router.get("/config/auto_save") +async def compat_runtime_auto_save(): + return await _runtime_auto_save(None) + + +@compat_router.post("/config/auto_save") +async def compat_set_runtime_auto_save(payload: AutoSaveRequest): + return await _runtime_auto_save(payload.enabled) + + +@compat_router.get("/memory/recycle_bin") +async def compat_get_recycle_bin(limit: int = Query(50, ge=1, le=200)): + return await _maintenance_recycle_bin(limit) + + +@compat_router.post("/memory/restore") +async def compat_restore_memory(payload: MaintainRequest): + return await _maintenance_restore(payload) + + +@compat_router.post("/memory/reinforce") +async def compat_reinforce_memory(payload: MaintainRequest): + return await _maintenance_reinforce(payload) + + +@compat_router.post("/memory/freeze") +async def compat_freeze_memory(payload: MaintainRequest): + return await _maintenance_freeze(payload) + + 
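The compat maintenance routes here, like their `/api/webui/memory/maintenance/*` counterparts earlier in this file, are thin wrappers over `memory_service`, so the same actions can also be driven directly from host-side code. A minimal sketch, assuming the plugin runtime is already up; the target hash is a placeholder:

```python
# Minimal host-side sketch of the maintenance flow these routes wrap.
# The target hash is a placeholder; real values come from recycle-bin entries.
import asyncio

from src.services.memory_service import memory_service


async def demo_maintenance() -> None:
    # Inspect soft-deleted items first.
    recycle_bin = await memory_service.get_recycle_bin(limit=20)
    print("recycle bin:", recycle_bin)

    # Protect a memory for 48 hours; failures surface in MemoryWriteResult.detail.
    protected = await memory_service.protect_memory(target="<relation-hash>", hours=48.0)
    if not protected.success:
        print("protect failed:", protected.detail)

    # Restore goes through the same maintain_memory(action="restore", ...) path.
    restored = await memory_service.restore_memory(target="<relation-hash>")
    print("restored:", restored.to_dict())


asyncio.run(demo_maintenance())
```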
+@compat_router.post("/memory/protect") +async def compat_protect_memory(payload: MaintainRequest): + return await _maintenance_protect(payload) + + +@compat_router.get("/import/settings") +async def compat_import_settings(): + return await _import_settings() + + +@compat_router.get("/import/path_aliases") +async def compat_import_path_aliases(): + return await _import_path_aliases() + + +@compat_router.get("/import/guide") +async def compat_import_guide(): + return await _import_guide() + + +@compat_router.post("/import/resolve_path") +async def compat_import_resolve_path(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_resolve_path(payload) + + +@compat_router.post("/import/upload") +async def compat_import_upload( + files: list[UploadFile] = File(...), + payload_json: str = Form("{}"), +): + return await create_memory_import_upload(files=files, payload_json=payload_json) + + +@compat_router.post("/import/tasks/upload") +async def compat_import_upload_task( + files: list[UploadFile] = File(...), + payload_json: str = Form("{}"), +): + return await create_memory_import_upload(files=files, payload_json=payload_json) + + +@compat_router.post("/import/paste") +async def compat_import_paste(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_paste", payload) + + +@compat_router.post("/import/tasks/paste") +async def compat_import_paste_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_paste", payload) + + +@compat_router.post("/import/raw_scan") +async def compat_import_raw_scan(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_raw_scan", payload) + + +@compat_router.post("/import/tasks/raw_scan") +async def compat_import_raw_scan_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_raw_scan", payload) + + +@compat_router.post("/import/lpmm_openie") +async def compat_import_lpmm_openie(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_openie", payload) + + +@compat_router.post("/import/tasks/lpmm_openie") +async def compat_import_lpmm_openie_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_openie", payload) + + +@compat_router.post("/import/lpmm_convert") +async def compat_import_lpmm_convert(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_convert", payload) + + +@compat_router.post("/import/tasks/lpmm_convert") +async def compat_import_lpmm_convert_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_lpmm_convert", payload) + + +@compat_router.post("/import/temporal_backfill") +async def compat_import_temporal_backfill(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_temporal_backfill", payload) + + +@compat_router.post("/import/tasks/temporal_backfill") +async def compat_import_temporal_backfill_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_temporal_backfill", payload) + + +@compat_router.post("/import/maibot_migration") +async def compat_import_maibot_migration(payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_maibot_migration", payload) + + +@compat_router.post("/import/tasks/maibot_migration") +async def compat_import_maibot_migration_task(payload: 
dict[str, Any] = Body(default_factory=dict)): + return await _import_create("create_maibot_migration", payload) + + +@compat_router.get("/import/tasks") +async def compat_import_list(limit: int = Query(50, ge=1, le=200)): + return await _import_list(limit) + + +@compat_router.get("/import/tasks/{task_id}") +async def compat_import_get(task_id: str, include_chunks: bool = Query(False)): + return await _import_get(task_id, include_chunks) + + +@compat_router.get("/import/tasks/{task_id}/chunks/{file_id}") +async def compat_import_chunks( + task_id: str, + file_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _import_chunks(task_id, file_id, offset, limit) + + +@compat_router.get("/import/tasks/{task_id}/files/{file_id}/chunks") +async def compat_import_file_chunks( + task_id: str, + file_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _import_chunks(task_id, file_id, offset, limit) + + +@compat_router.post("/import/tasks/{task_id}/cancel") +async def compat_import_cancel(task_id: str): + return await _import_cancel(task_id) + + +@compat_router.post("/import/tasks/{task_id}/retry") +async def compat_import_retry(task_id: str, payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_retry(task_id, payload) + + +@compat_router.post("/import/tasks/{task_id}/retry_failed") +async def compat_import_retry_failed(task_id: str, payload: dict[str, Any] = Body(default_factory=dict)): + return await _import_retry(task_id, payload) + + +@compat_router.get("/retrieval_tuning/settings") +async def compat_tuning_settings(): + return await _tuning_settings() + + +@compat_router.get("/retrieval_tuning/profile") +async def compat_tuning_profile(): + return await _tuning_profile() + + +@compat_router.post("/retrieval_tuning/profile/apply") +async def compat_apply_tuning_profile(payload: TuningApplyProfileRequest): + return await _tuning_apply_profile(payload) + + +@compat_router.post("/retrieval_tuning/profile/rollback") +async def compat_rollback_tuning_profile(): + return await _tuning_rollback_profile() + + +@compat_router.get("/retrieval_tuning/profile/export") +async def compat_export_tuning_profile(): + return await _tuning_export_profile() + + +@compat_router.get("/retrieval_tuning/profile/export_toml") +async def compat_export_tuning_profile_toml(): + return await _tuning_export_profile() + + +@compat_router.post("/retrieval_tuning/tasks") +async def compat_create_tuning_task(payload: dict[str, Any] = Body(default_factory=dict)): + return await _tuning_create_task(payload) + + +@compat_router.get("/retrieval_tuning/tasks") +async def compat_list_tuning_tasks(limit: int = Query(50, ge=1, le=200)): + return await _tuning_list_tasks(limit) + + +@compat_router.get("/retrieval_tuning/tasks/{task_id}") +async def compat_get_tuning_task(task_id: str, include_rounds: bool = Query(False)): + return await _tuning_get_task(task_id, include_rounds) + + +@compat_router.get("/retrieval_tuning/tasks/{task_id}/rounds") +async def compat_get_tuning_rounds( + task_id: str, + offset: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), +): + return await _tuning_get_rounds(task_id, offset, limit) + + +@compat_router.post("/retrieval_tuning/tasks/{task_id}/cancel") +async def compat_cancel_tuning_task(task_id: str): + return await _tuning_cancel(task_id) + + +@compat_router.post("/retrieval_tuning/tasks/{task_id}/apply_best") +async def compat_apply_best_tuning_profile(task_id: str): + return 
await _tuning_apply_best(task_id) + + +@compat_router.post("/retrieval_tuning/tasks/{task_id}/apply-best") +async def compat_apply_best_tuning_profile_kebab(task_id: str): + return await _tuning_apply_best(task_id) + + +@compat_router.get("/retrieval_tuning/tasks/{task_id}/report") +async def compat_get_tuning_report(task_id: str, format: str = Query("md")): + return await _tuning_report(task_id, format) diff --git a/src/webui/routes.py b/src/webui/routes.py index c1a7e446..5e33e78b 100644 --- a/src/webui/routes.py +++ b/src/webui/routes.py @@ -16,6 +16,7 @@ from src.webui.routers.config import router as config_router from src.webui.routers.emoji import router as emoji_router from src.webui.routers.expression import router as expression_router from src.webui.routers.jargon import router as jargon_router +from src.webui.routers.memory import router as memory_router from src.webui.routers.model import router as model_router from src.webui.routers.person import router as person_router from src.webui.routers.plugin import get_progress_router @@ -49,6 +50,8 @@ router.include_router(get_progress_router()) router.include_router(system_router) # 注册模型列表获取路由 router.include_router(model_router) +# 注册长期记忆管理路由 +router.include_router(memory_router) # 注册 WebSocket 认证路由 router.include_router(ws_auth_router) From 71b3a828c6c57589673c6b63b9d8c1e4217b960b Mon Sep 17 00:00:00 2001 From: DawnARC Date: Thu, 19 Mar 2026 00:09:04 +0800 Subject: [PATCH 03/14] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20A=5FMemorix=20?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=20v2.0.0=EF=BC=88=E5=8C=85=E5=90=AB=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E6=97=B6=E4=B8=8E=E6=96=87=E6=A1=A3=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入 A_Memorix 插件 v2.0.0:新增大量运行时组件、存储/模式更新、检索能力提升、管理工具、导入/调优工作流以及相关文档。关键新增内容包括:lifecycle_orchestrator、SDKMemoryKernel/运行时初始化器、新的存储层与 metadata_store 变更(SCHEMA_VERSION v8)、检索增强(双路径检索、图关系召回、稀疏 BM25),以及多种工具服务(episode/person_profile/relation/segmentation/tuning/search execution)。同时新增 Web 导入/摘要导入器及大量维护脚本。还更新了插件清单、embedding API 适配器、plugin.py、requirements/pyproject,以及主入口文件,使新插件接入项目。该变更为 2.0.0 版本发布做好准备,实现统一的 SDK Tool 接口并扩展整体运行能力。 --- plugins/A_memorix/CHANGELOG.md | 718 ++++ plugins/A_memorix/CONFIG_REFERENCE.md | 292 ++ plugins/A_memorix/IMPORT_GUIDE.md | 335 ++ plugins/A_memorix/LICENSE | 661 ++++ plugins/A_memorix/LICENSE-MAIBOT-GPL.md | 22 + plugins/A_memorix/QUICK_START.md | 210 + plugins/A_memorix/README.md | 216 + plugins/A_memorix/_manifest.json | 45 + .../A_memorix/core/embedding/api_adapter.py | 356 +- plugins/A_memorix/core/retrieval/dual_path.py | 22 +- .../core/retrieval/graph_relation_recall.py | 4 +- .../A_memorix/core/retrieval/sparse_bm25.py | 9 +- plugins/A_memorix/core/runtime/__init__.py | 8 + .../core/runtime/lifecycle_orchestrator.py | 268 ++ .../core/runtime/sdk_memory_kernel.py | 3001 ++++++++++++-- .../runtime/search_runtime_initializer.py | 240 ++ plugins/A_memorix/core/storage/graph_store.py | 14 + .../A_memorix/core/storage/metadata_store.py | 559 ++- .../core/utils/aggregate_query_service.py | 2 +- .../core/utils/episode_retrieval_service.py | 2 +- .../utils/episode_segmentation_service.py | 304 ++ .../A_memorix/core/utils/episode_service.py | 558 +++ .../core/utils/person_profile_service.py | 119 +- .../core/utils/relation_write_service.py | 12 +- .../core/utils/retrieval_tuning_manager.py | 1857 +++++++++ .../core/utils/runtime_self_check.py | 31 +- .../core/utils/search_execution_service.py | 442 +++ .../A_memorix/core/utils/summary_importer.py | 425 ++ 
.../core/utils/web_import_manager.py | 3522 +++++++++++++++++ plugins/A_memorix/plugin.py | 84 +- plugins/A_memorix/requirements.txt | 52 + .../scripts/audit_vector_consistency.py | 213 + .../scripts/backfill_relation_vectors.py | 270 ++ .../scripts/backfill_temporal_metadata.py | 73 + plugins/A_memorix/scripts/convert_lpmm.py | 25 +- plugins/A_memorix/scripts/import_lpmm_json.py | 172 + .../scripts/migrate_maibot_memory.py | 1714 ++++++++ .../A_memorix/scripts/process_knowledge.py | 728 ++++ plugins/A_memorix/scripts/rebuild_episodes.py | 127 + .../scripts/release_vnext_migrate.py | 731 ++++ .../A_memorix/scripts/runtime_self_check.py | 152 + pyproject.toml | 1 + requirements.txt | 1 + src/main.py | 1 + 44 files changed, 18193 insertions(+), 405 deletions(-) create mode 100644 plugins/A_memorix/CHANGELOG.md create mode 100644 plugins/A_memorix/CONFIG_REFERENCE.md create mode 100644 plugins/A_memorix/IMPORT_GUIDE.md create mode 100644 plugins/A_memorix/LICENSE create mode 100644 plugins/A_memorix/LICENSE-MAIBOT-GPL.md create mode 100644 plugins/A_memorix/QUICK_START.md create mode 100644 plugins/A_memorix/README.md create mode 100644 plugins/A_memorix/core/runtime/lifecycle_orchestrator.py create mode 100644 plugins/A_memorix/core/runtime/search_runtime_initializer.py create mode 100644 plugins/A_memorix/core/utils/episode_segmentation_service.py create mode 100644 plugins/A_memorix/core/utils/episode_service.py create mode 100644 plugins/A_memorix/core/utils/retrieval_tuning_manager.py create mode 100644 plugins/A_memorix/core/utils/search_execution_service.py create mode 100644 plugins/A_memorix/core/utils/summary_importer.py create mode 100644 plugins/A_memorix/core/utils/web_import_manager.py create mode 100644 plugins/A_memorix/requirements.txt create mode 100644 plugins/A_memorix/scripts/audit_vector_consistency.py create mode 100644 plugins/A_memorix/scripts/backfill_relation_vectors.py create mode 100644 plugins/A_memorix/scripts/backfill_temporal_metadata.py create mode 100644 plugins/A_memorix/scripts/import_lpmm_json.py create mode 100644 plugins/A_memorix/scripts/migrate_maibot_memory.py create mode 100644 plugins/A_memorix/scripts/process_knowledge.py create mode 100644 plugins/A_memorix/scripts/rebuild_episodes.py create mode 100644 plugins/A_memorix/scripts/release_vnext_migrate.py create mode 100644 plugins/A_memorix/scripts/runtime_self_check.py diff --git a/plugins/A_memorix/CHANGELOG.md b/plugins/A_memorix/CHANGELOG.md new file mode 100644 index 00000000..772cff46 --- /dev/null +++ b/plugins/A_memorix/CHANGELOG.md @@ -0,0 +1,718 @@ +# 更新日志 (Changelog) + +## [2.0.0] - 2026-03-18 + +本次 `2.0.0` 为架构收敛版本,主线是 **SDK Tool 接口统一**、**管理工具能力补齐**、**元数据 schema 升级到 v8** 与 **文档口径同步到 2.0.0**。 + +### 🔖 版本信息 + +- 插件版本:`1.0.1` → `2.0.0` +- 元数据 schema:`7` → `8` + +### 🚀 重点能力 + +- Tool 接口统一: + - `plugin.py` 统一通过 `SDKMemoryKernel` 对外提供 Tool 能力。 + - 保留基础工具:`search_memory / ingest_summary / ingest_text / get_person_profile / maintain_memory / memory_stats`。 + - 新增管理工具:`memory_graph_admin / memory_source_admin / memory_episode_admin / memory_profile_admin / memory_runtime_admin / memory_import_admin / memory_tuning_admin / memory_v5_admin / memory_delete_admin`。 +- 检索与写入治理增强: + - 检索/写入链路支持 `respect_filter + user_id/group_id` 的聊天过滤语义。 + - `maintain_memory` 支持 `freeze` 与 `recycle_bin`,并统一到内核维护流程。 +- 导入与调优能力收敛: + - `memory_import_admin` 提供任务化导入能力(上传、粘贴、扫描、OpenIE、LPMM 转换、时序回填、MaiBot 迁移)。 + - `memory_tuning_admin` 提供检索调优任务(创建、轮次查看、回滚、apply_best、报告导出)。 +- V5 与删除运维: + - 新增 
`memory_v5_admin`(`reinforce/weaken/remember_forever/forget/restore/status`)。 + - 新增 `memory_delete_admin`(`preview/execute/restore/list/get/purge`),支持操作审计与恢复。 + +### 🛠️ 存储与运行时 + +- `metadata_store` 升级到 `SCHEMA_VERSION = 8`。 +- 新增/完善外部引用与运维记录能力(包括 `external_memory_refs`、`memory_v5_operations`、`delete_operations` 相关数据结构)。 +- `SDKMemoryKernel` 增加统一后台任务编排(自动保存、Episode pending 处理、画像刷新、记忆维护)。 + +### 📚 文档同步 + +- `README.md`、`QUICK_START.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md` 已切换到 `2.0.0` 口径。 +- 文档主入口统一为 SDK Tool 工作流,不再以旧版 slash 命令作为主说明路径。 + +## [1.0.1] - 2026-03-07 + +本次 `1.0.1` 为 `1.0.0` 发布后的热修复版本,主线是 **图谱 WebUI 取数稳定性修复**、**大图过滤性能修复** 与 **真实检索调优链路稳定性修复**。 + +### 🔖 版本信息 + +- 插件版本:`1.0.0` → `1.0.1` +- 配置版本:`4.1.0`(不变) + +### 🛠️ 代码修复 + +- 图谱接口稳定性: + - 修复 `/api/graph` 在“磁盘已有图文件但运行时尚未装载入内存”场景下返回空图的问题,接口现在会自动补加载持久化图数据。 + - 修复问题数据集下 WebUI 打开图谱页时看似“没有任何节点”的现象;根因不是图数据消失,而是后端过滤路径过慢。 +- 图谱过滤性能: + - 优化 `/api/graph?exclude_leaf=true` 的叶子过滤逻辑,改为预计算 hub 邻接关系,不再对每个节点反复做高成本边权查询。 + - 优化 `GraphStore.get_neighbors()` 并补充入邻居访问能力,避免稠密矩阵展开导致的大图性能退化。 +- 检索调优稳定性: + - 修复真实调优任务在构建运行时配置时深拷贝 `plugin.config`,误复制注入的存储实例并触发 `cannot pickle '_thread.RLock' object` 的问题。 + - 调优评估改为跳过顶层运行时实例键,仅保留纯配置字段后再附加运行时依赖,真实 WebUI 调优任务可正常启动。 + +### 📚 文档同步 + +- 同步更新 `README.md`、`CHANGELOG.md`、`CONFIG_REFERENCE.md` 与版本元数据(`plugin.py`、`__init__.py`、`_manifest.json`)。 +- README 新增 `v1.0.1` 修复说明,并补充“调优前先做 runtime self-check”的建议。 + +## [1.0.0] - 2026-03-06 + +本次 `1.0.0` 为主版本升级,主线是 **运行时架构模块化**、**Episode 情景记忆闭环**、**聚合检索与图召回增强**、**离线迁移 / 运行时自检 / 检索调优中心**。 + +### 🔖 版本信息 + +- 插件版本:`0.7.0` → `1.0.0` +- 配置版本:`4.1.0`(不变) + +### 🚀 重点能力 + +- 运行时重构: + - `plugin.py` 大幅瘦身,生命周期、后台任务、请求路由、检索运行时初始化拆分到 `core/runtime/*`。 + - 配置 schema 抽离到 `core/config/plugin_config_schema.py`,`_manifest.json` 同步扩展新配置项。 +- 检索与查询增强: + - `KnowledgeQueryTool` 拆分为 query mode + orchestrator,新增长 `aggregate` / `episode` 查询模式。 + - 新增图辅助关系召回、统一 forward/runtime 构建与请求去重桥接。 +- Episode / 运维能力: + - `metadata_store` schema 升级到 `SCHEMA_VERSION = 7`,新增 `episodes` / `episode_paragraphs` / rebuild queue 等结构。 + - 新增 `release_vnext_migrate.py`、`runtime_self_check.py`、`rebuild_episodes.py` 与 Web 检索调优页 `web/tuning.html`。 + +### 📚 文档同步 + +- 版本号同步到 `plugin.py`、`__init__.py`、`_manifest.json`、`README.md` 与 `CONFIG_REFERENCE.md`。 +- 新增 `RELEASE_SUMMARY_1.0.0.md` + +## [0.7.0] - 2026-03-04 + +本次 `0.7.0` 为中版本升级,主线是 **关系向量化闭环(写入 + 状态机 + 回填 + 审计)**、**检索/命令链路增强** 与 **导入任务能力补齐**。 + +### 🔖 版本信息 + +- 插件版本:`0.6.1` → `0.7.0` +- 配置版本:`4.1.0`(不变) + +### 🚀 重点能力 + +- 关系向量化闭环: + - 新增统一关系写入服务 `RelationWriteService`(metadata 先写、向量后写,失败进入状态机而非回滚主数据)。 + - `relations` 侧补齐 `vector_state/retry_count/last_error/updated_at` 等状态字段,支持 `none/pending/ready/failed` 统一治理。 + - 插件新增后台回填循环与统计接口,可持续修复关系向量缺失并暴露覆盖率指标。 +- 检索与命令链路增强: + - 检索主链继续收敛到 `search/time` forward 路由,`legacy` 仅保留兼容别名。 + - relation 查询规格解析收口,结构化查询与语义回退边界更清晰。 + - `/query stats` 与 tool stats 补充关系向量化统计输出。 +- 导入与运维增强: + - Web Import 新增 `temporal_backfill` 任务入口与编排处理。 + - 新增一致性审计与离线回填脚本,支持灰度修复历史数据。 + +### 📚 文档同步 + +- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与本日志版本信息。 +- `README.md` 新增关系向量审计/回填脚本使用说明,并更新 `convert_lpmm.py` 的关系向量重建行为描述。 + +## [0.6.1] - 2026-03-03 + +本次 `0.6.1` 为热修复小版本,重点修复 WebUI 插件配置接口在 A_Memorix 场景下的 `tomlkit` 节点序列化兼容问题。 + +### 🔖 版本信息 + +- 插件版本:`0.6.0` → `0.6.1` +- 配置版本:`4.1.0`(不变) + +### 🛠️ 代码修复 + +- 新增运行时补丁 `_patch_webui_a_memorix_routes_for_tomlkit_serialization()`: + - 仅包裹 `/api/webui/plugins/config/{plugin_id}` 及其 schema 的 `GET` 路由。 + - 仅在 `plugin_id == "A_Memorix"` 时,将返回中的 `config/schema` 通过 `to_builtin_data` 原生化。 + - 保持 `/api/webui/config/*` 
全局接口行为不变,避免对其他插件或核心配置路径产生副作用。 +- 在插件初始化时执行该补丁,确保 WebUI 读取插件配置时返回结构可稳定序列化。 + +### 📚 文档同步 + +- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与本日志中的版本信息及修复说明。 + +## [0.6.0] - 2026-03-02 + +本次 `0.6.0` 为中版本升级,主线是 **Web Import 导入中心上线与脚本能力对齐**、**失败重试机制升级**、**删除后 manifest 同步** 与 **导入链路稳定性增强**。 + +### 🔖 版本信息 + +- 插件版本:`0.5.1` → `0.6.0` +- 配置版本:`4.0.1` → `4.1.0` + +### 🚀 重点能力 + +- 新增 Web Import 导入中心(`/import`): + - 上传/粘贴/本地扫描/LPMM OpenIE/LPMM 转换/时序回填/MaiBot 迁移。 + - 任务/文件/分块三级状态展示,支持取消与失败重试。 + - 导入文档弹窗读取(远程优先,失败回退本地)。 +- 失败重试升级为“分块优先 + 文件回退”: + - `POST /api/import/tasks/{task_id}/retry_failed` 保持原路径,语义升级。 + - 支持对 `extracting` 失败分块进行子集重试。 + - `writing`/JSON 解析失败自动回退为文件级重试。 +- 删除后 manifest 同步失效: + - 覆盖 `/api/source/batch_delete` 与 `/api/source`。 + - 返回 `manifest_cleanup` 明细,避免误命中去重跳过重导入。 + +### 📂 变更文件清单(本次发布) + +新增文件: + +- `core/utils/web_import_manager.py` +- `scripts/migrate_maibot_memory.py` +- `web/import.html` + +修改文件: + +- `CHANGELOG.md` +- `CONFIG_REFERENCE.md` +- `IMPORT_GUIDE.md` +- `QUICK_START.md` +- `README.md` +- `__init__.py` +- `_manifest.json` +- `components/commands/debug_server_command.py` +- `core/embedding/api_adapter.py` +- `core/storage/graph_store.py` +- `core/utils/summary_importer.py` +- `plugin.py` +- `requirements.txt` +- `server.py` +- `web/index.html` + +删除文件: + +- 无 + +### 📚 文档同步 + +- 同步更新 `README.md`、`QUICK_START.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md` 与本日志。 +- `IMPORT_GUIDE.md` 新增 “Web Import 导入中心” 专区,统一说明能力范围、状态语义与安全边界。 + +## [0.5.1] - 2026-02-23 + +本次 `0.5.1` 为热修订小版本,重点修复“随主程序启动的后台任务拉起”“空名单过滤语义”以及“知识抽取模型选择”。 + +### 🔖 版本信息 + +- 插件版本:`0.5.0` → `0.5.1` +- 配置版本:`4.0.0` → `4.0.1` + +### 🛠️ 代码修复 + +- 生命周期接入主程序事件: + - 新增 `a_memorix_start_handler`(`ON_START`)调用 `plugin.on_enable()`; + - 新增 `a_memorix_stop_handler`(`ON_STOP`)调用 `plugin.on_disable()`; + - 解决仅注册插件但未触发生命周期时,定时导入任务不启动的问题。 +- 聊天过滤空列表策略调整: + - `whitelist + []`:全部拒绝; + - `blacklist + []`:全部放行。 +- 知识抽取模型选择逻辑调整(`import_command._select_model`): + - `advanced.extraction_model` 现在支持三种语义:任务名 / 模型名 / `auto`; + - `auto` 优先抽取相关任务(`lpmm_entity_extract`、`lpmm_rdf_build` 等),并避免误落到 `embedding`; + - 当配置无法识别时输出告警并回退自动选择,提高导入阶段的模型选择可预期性。 + +### 📚 文档同步 + +- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与 `CHANGELOG.md`。 +- 同步修正文档中的空名单过滤行为描述,保持与当前代码一致。 + +## [0.5.0] - 2026-02-15 + +本次 `0.5.0` 以提交 `66ddc1b98547df3c866b19a3f5dc96e1c8eb7731` 为核心,主线是“人物画像能力上线 + 工具/命令接入 + 版本与文档同步”。 + +### 🔖 版本信息 + +- 插件版本:`0.4.0` → `0.5.0` +- 配置版本:`3.1.0` → `4.0.0` + +### 🚀 人物画像主特性(核心) + +- 新增人物画像服务:`core/utils/person_profile_service.py` + - 支持 `person_id/姓名/别名` 解析。 + - 聚合图关系证据 + 向量证据,生成画像文本并版本化快照。 + - 支持手工覆盖(override)与 TTL 快照复用。 +- 存储层新增人物画像相关表与 API:`core/storage/metadata_store.py` + - `person_profile_switches` + - `person_profile_snapshots` + - `person_profile_active_persons` + - `person_profile_overrides` +- 新增命令:`/person_profile on|off|status` + - 文件:`components/commands/person_profile_command.py` + - 作用:按 `stream_id + user_id` 控制自动注入开关(opt-in 模式)。 +- 查询链路接入人物画像: + - `knowledge_query_tool` 新增 `query_type=person`,支持 `person_id` 或别名查询。 + - `/query person` 与 `/query p` 接入画像查询输出。 +- 插件生命周期接入画像刷新任务: + - 启动/停止统一管理 `person_profile_refresh` 后台任务。 + - 按活跃窗口自动刷新画像快照。 + +### 🛠️ 版本与 schema 同步 + +- `plugin.py`:`plugin_version` 更新为 `0.5.0`。 +- `plugin.py`:`plugin.config_version` 默认值更新为 `4.0.0`。 +- `config.toml`:`config_version` 基线同步为 `4.0.0`(本地配置文件)。 +- `__init__.py`:`__version__` 更新为 `0.5.0`。 +- `_manifest.json`:`version` 更新为 `0.5.0`,`manifest_version` 保持 `1` 。 +- `manifest_utils.py`:仓库内已兼容更高 manifest 版本;但插件发布默认保持 `manifest_version=1` 。 + +### 📚 文档同步 + +- 更新 
`README.md`、`CONFIG_REFERENCE.md`、`QUICK_START.md`、`USAGE_ARCHITECTURE.md`。 +- 0.5.0 文档主线改为“人物画像能力 + 版本升级 + 检索链路补充说明”。 + +## [0.4.0] - 2026-02-13 + +本次 `0.4.0` 版本整合了时序检索增强与后续检索链路增强、稳定性修复和文档同步。 + +### 🔖 版本信息 + +- 插件版本:`0.3.3` → `0.4.0` +- 配置版本:`3.0.0` → `3.1.0` + +### 🚀 新增 + +- 新增 `core/retrieval/sparse_bm25.py` + - `SparseBM25Config` / `SparseBM25Index` + - FTS5 + BM25 稀疏检索 + - 支持 `jieba/mixed/char_2gram` 分词与懒加载 + - 支持 ngram 倒排回退与可选 LIKE 兜底 +- `DualPathRetriever` 新增 sparse/fusion 配置注入: + - embedding 不可用时自动 sparse 回退; + - `hybrid` 模式支持向量路 + sparse 路并行候选; + - 新增 `FusionConfig` 与 `weighted_rrf` 融合。 +- `MetadataStore` 新增 FTS/倒排能力: + - `paragraphs_fts`、`relations_fts` schema 与回填; + - `paragraph_ngrams` 倒排索引与回填; + - `fts_search_bm25` / `fts_search_relations_bm25` / `ngram_search_paragraphs`。 + +### 🛠️ 组件链路同步 + +- `plugin.py` + - 新增 `[retrieval.sparse]`、`[retrieval.fusion]` 默认配置; + - 初始化并向组件注入 `sparse_index`; + - `on_disable` 支持按配置卸载 sparse 连接并释放缓存。 +- `knowledge_search_action.py` / `query_command.py` / `knowledge_query_tool.py` + - 统一接入 sparse/fusion 配置; + - 统一注入 `sparse_index`; + - `stats` 输出新增 sparse 状态观测。 +- `requirements.txt` + - 新增 `jieba>=0.42.1`(未安装时自动回退 char n-gram)。 + +### 🧯 修复与行为调整 + +- 修复 `retrieval.ppr_concurrency_limit` 不生效问题: + - `DualPathRetriever` 使用配置值初始化 `_ppr_semaphore`,不再被固定值覆盖。 +- 修复 `char_2gram` 召回失效场景: + - FTS miss 时增加 `_fallback_substring_search`,优先 ngram 倒排回退,按配置可选 LIKE 兜底。 +- 提升可观测性与兼容性: + - `get_statistics()` 对向量规模字段兼容读取 `size -> num_vectors -> 0`,避免属性缺失导致异常。 + - `/query stats` 与 `knowledge_query` 输出包含 sparse 状态(enabled/loaded/tokenizer/doc_count)。 + +### 📚 文档 + +- `README.md` + - 新增检索增强说明、稀疏行为说明、时序回填脚本入口。 +- `CONFIG_REFERENCE.md` + - 补齐 sparse/fusion 参数与触发规则、回退链路、融合实现细节。 + +### ⏱️ 时序检索与导入增强 + +#### 时序检索能力(分钟级) + +- 新增统一时序查询入口: + - `/query time`(别名 `/query t`) + - `knowledge_query(query_type=time)` + - `knowledge_search(query_type=time|hybrid)` +- 查询时间参数统一支持: + - `YYYY/MM/DD` + - `YYYY/MM/DD HH:mm` +- 日期参数自动展开边界: + - `from/time_from` -> `00:00` + - `to/time_to` -> `23:59` +- 查询结果统一回传 `metadata.time_meta`,包含命中时间窗口与命中依据(事件时间或 `created_at` 回退)。 + +#### 存储与检索链路 + +- 段落存储层支持时序字段: + - `event_time` + - `event_time_start` + - `event_time_end` + - `time_granularity` + - `time_confidence` +- 时序命中采用区间相交逻辑,并遵循“双层时间语义”: + - 优先 `event_time/event_time_range` + - 缺失时回退 `created_at`(可配置关闭) +- 检索排序规则保持:语义优先,时间次排序(新到旧)。 +- `process_knowledge.py` 新增 `--chat-log` 参数: + - 启用后强制使用 `narrative` 策略; + - 使用 LLM 对聊天文本进行语义时间抽取(支持相对时间转绝对时间),写入 `event_time/event_time_start/event_time_end`。 + - 新增 `--chat-reference-time`,用于指定相对时间语义解析的参考时间点。 + +#### Schema 与文档同步 + +- `_manifest.json` 同步补齐 `retrieval.temporal` 配置 schema。 +- 配置 schema 版本升级:`config_version` 从 `3.0.0` 提升到 `3.1.0`(`plugin.py` / `config.toml` / 配置文档同步)。 +- 更新 `README.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md`,补充时序检索入口、参数格式与导入时间字段说明。 + +## [0.3.3] - 2026-02-11 + +本次更新为 **语言一致性补丁版本**,重点收敛知识抽取时的语言漂移问题,要求输出严格贴合原文语言,不做翻译改写。 + +### 🛠️ 关键修复 + +#### 抽取语言约束 + +- `BaseStrategy`: + - 移除按 `zh/en/mixed` 分支的语言类型判定逻辑; + - 统一为单一约束:抽取值保持原文语言、保留原始术语、禁止翻译。 +- `NarrativeStrategy` / `FactualStrategy`: + - 抽取提示词统一接入上述语言约束; + - 明确要求 JSON 键名固定、抽取值遵循原文语言表达。 + +#### 导入链路一致性 + +- `ImportCommand` 的 LLM 抽取提示词同步强化“优先原文语言、不要翻译”要求,避免脚本与指令导入行为不一致。 + +#### 测试与文档 + +- 更新 `test_strategies.py`,将语言判定测试调整为统一语言约束测试,并验证提示词中包含禁止翻译约束。 +- 同步更新注释与文档描述,确保实现与说明一致。 + +### 🔖 版本信息 + +- 插件版本:`0.3.2` → `0.3.3` + +## [0.3.2] - 2026-02-11 + +本次更新为 **V5 稳定性与兼容性修复版本**,在保持原有业务设计(强化→衰减→冷冻→修剪→回收)的前提下,修复关键链路断裂与误判问题。 + +### 🛠️ 关键修复 + +#### V5 记忆系统契约与链路 + +- `MetadataStore`: + - 统一 
`mark_relations_inactive(hashes, inactive_since=None)` 调用契约,兼容不同调用方; + - 补充 `has_table(table_name)`; + - 增加 `restore_relation(hash)` 兼容别名,修复服务层恢复调用断裂; + - 修正 `get_entity_gc_candidates` 对孤立节点参数的处理(支持节点名映射到实体 hash)。 +- `GraphStore`: + - 清理 `deactivate_edges` 重复定义并统一返回冻结数量,保证上层日志与断言稳定。 +- `server.py`: + - 修复 `/api/memory/restore` relation 恢复链路; + - 清理不可达分支并统一异常路径; + - 回收站查询在表检测场景下不再出现错误退空。 + +#### 命令与模型选择 + +- `/memory` 命令修复 hash 长度判定:以 64 位 `sha256` 为标准,同时兼容历史 32 位输入。 +- 总结模型选择修复: + - 解决 `summarization.model_name = auto` 误命中 `embedding` 问题; + - 支持数组与选择器语法(`task:model` / task / model); + - 兼容逗号分隔字符串写法(如 `"utils:model1","utils:model2",replyer`)。 + +#### 生命周期与脚本稳定性 + +- `plugin.py` 修复后台任务生命周期管理: + - 增加 `_scheduled_import_task` / `_auto_save_task` / `_memory_maintenance_task` 句柄; + - 避免重复启动; + - 插件停用时统一 cancel + await 收敛。 +- `process_knowledge.py` 修复 tenacity 重试日志级别类型错误(`"WARNING"` → `logging.WARNING`),避免 `KeyError: 'WARNING'`。 + +### 🔖 版本信息 + +- 插件版本:`0.3.1` → `0.3.2` + +## [0.3.1] - 2026-02-07 + +本次更新为 **稳定性补丁版本**,主要修复脚本导入链路、删除安全性与 LPMM 转换一致性问题。 + +### 🛠️ 关键修复 + +#### 新增功能 + +- 新增 `scripts/convert_lpmm.py`: + - 支持将 LPMM 的 `parquet + graph` 数据直接转换为 A_Memorix 存储结构; + - 提供 LPMM ID 到 A_Memorix ID 的映射能力,用于图节点/边重写; + - 当前实现优先保证检索一致性,关系向量采用安全策略(不直接导入)。 + +#### 导入链路 + +- 修复 `import_lpmm_json.py` 依赖的 `AutoImporter.import_json_data` 公共入口缺失/不稳定问题,确保外部脚本可稳定调用 JSON 直导入流程。 + +#### 删除安全 + +- 修复按来源删除时“同一 `(subject, object)` 存在多关系”场景下的误删风险: + - `MetadataStore.delete_paragraph_atomic` 新增 `relation_prune_ops`; + - 仅在无兄弟关系时才回退删除整条边。 +- `delete_knowledge.py` 新增保守孤儿实体清理(仅对本次候选实体执行,且需同时满足无段落引用、无关系引用、图无邻居)。 +- `delete_knowledge.py` 改为读取向量元数据中的真实维度,避免 `dimension=1` 写回污染。 + +#### LPMM 转换修复 + +- 修复 `convert_lpmm.py` 中向量 ID 与 `MetadataStore` 哈希不一致导致的检索反查失败问题。 +- 为避免脏召回,转换阶段暂时跳过 `relation.parquet` 的直接向量导入(待关系元数据一一映射能力完善后再恢复)。 + +### 🔖 版本信息 + +- 插件版本:`0.3.0` → `0.3.1` + +## [0.3.0] - 2026-01-30 + +本次更新引入了 **V5 动态记忆系统**,实现了符合生物学特性的记忆衰减、强化与全声明周期管理,并提供了配套的指令与工具。 + +### 🧠 记忆系统 (V5) + +#### 核心机制 + +- **记忆衰减 (Decay)**: 引入"遗忘曲线",随时间推移自动降低图谱连接权重。 +- **访问强化 (Reinforcement)**: "越用越强",每次检索命中都会刷新记忆活跃度并增强权重。 +- **生命周期 (Lifecycle)**: + - **活跃 (Active)**: 正常参与计算与检索。 + - **冷冻 (Inactive)**: 权重过低被冻结,不再参与 PPR 计算,但保留语义映射 (Mapping)。 + - **修剪 (Prune)**: 过期且无保护的冷冻记忆将被移入回收站。 +- **多重保护**: 支持 **永久锁定 (Pin)** 与 **限时保护 (TTL)**,防止关键记忆被误删。 + +#### GraphStore + +- **多关系映射**: 实现 `(u,v) -> Set[Hash]` 映射,确保同一通道下的多重语义关系互不干扰。 +- **原子化操作**: 新增 `decay`, `deactivate_edges` (软删), `prune_relation_hashes` (硬删) 等原子操作。 + +### 🛠️ 指令与工具 + +#### Memory Command (`/memory`) + +新增全套记忆维护指令: + +- `/memory status`: 查看记忆系统健康状态(活跃/冷冻/回收站计数)。 +- `/memory protect [hours]`: 保护记忆。不填时间为永久锁定(Pin),填时间为临时保护(TTL)。 +- `/memory reinforce `: 手动强化记忆(绕过冷却时间)。 +- `/memory restore `: 从回收站恢复误删记忆(仅当节点存在时重建连接)。 + +#### MemoryModifierTool + +- **LLM 能力增强**: 更新工具逻辑,支持 LLM 自主触发 `reinforce`, `weaken`, `remember_forever`, `forget` 操作,并自动映射到 V5 底层逻辑。 + +### ⚙️ 配置 (`config.toml`) + +新增 `[memory]` 配置节: + +- `half_life_hours`: 记忆半衰期 (默认 24h)。 +- `enable_auto_reinforce`: 是否开启检索自动强化。 +- `prune_threshold`: 冷冻/修剪阈值 (默认 0.1)。 + +### 💻 WebUI (v1.4) + +实现了与 V5 记忆系统深度集成的全生命周期管理界面: + +- **可视化增强**: + - **冷冻状态**: 非活跃记忆以 **虚线 + 灰色 (Slate-300)** 显示。 + - **保护状态**: 被 Pin 或保护的记忆带有 **金色 (Amber) 光晕**。 +- **交互升级**: + - **记忆回收站**: 新增 Dock 入口与专用面板,支持浏览删除记录并一键恢复。 + - **快捷操作**: 边属性面板新增 **强化 (Reinforce)**、**保护 (Protect/Pin)**、**冷冻 (Freeze)** 按钮。 + - **实时反馈**: 操作后自动刷新图谱布局与样式。 + +--- + +## [0.2.3] - 2026-01-30 + +本次更新主要集中在 **WebUI 交互体验优化** 与 **文档/配置的规范化**。 + +### 🎨 WebUI (v1.3) + +#### 加载与同步体验升级 + +- **沉浸式加载**: 全新设计的加载遮罩,采用磨砂玻璃背景 
(`backdrop-filter`) 与呼吸灯文字动效,提升视觉质感。 +- **精准状态反馈**: 优化加载逻辑,明确区分“网络同步”与“拓扑计算”阶段,解决数据加载时的闪烁问题。 +- **新手引导**: 在加载界面新增基础操作提示,降低新用户上手门槛。 + +#### 全功能帮助面板 + +- **操作指南重构**: 全面翻新“操作指南”面板,新增 Dock 栏功能详解、编辑管理操作及视图配置说明。 + +### 🛠️ 工程与规范 + +#### plugin.py + +- **配置描述补全**: 修复了 `config_section_descriptions` 中缺失 `summarization`, `schedule`, `filter` 节导致的问题。 +- **版本号**: `0.2.2` → `0.2.3` + +### ⚙️ 核心与服务 + +#### Core + +- **量化逻辑修正**: 修正了 `_scalar_quantize_int8` 函数,确保向量值正确映射到 `[-128, 127]` 区间,提高量化精度。 + +#### Server + +- **缓存一致性**: 在执行删除节点/边等修改操作后,显式清除 `_relation_cache`,确保前端获取的关系数据实时更新。 + +### 🤖 脚本与数据处理 + +#### process_knowledge.py + +- **策略模式重构**: 引入了 `Strategy-Aware` 架构,支持通过 `Narrative` (叙事), `Factual` (事实), `Quote` (引用) 三种策略差异化处理文本(准确说是确认实装)(默认采用 Narrative模式)。 +- **智能分块纠错**: 新增“分块拯救” (`Chunk Rescue`) 机制,可在长叙事文本中自动识别并提取内嵌的歌词或诗句。 + +#### import_lpmm_json.py + +- **LPMM 迁移工具**: 增加了对 LPMM OpenIE JSON 格式的完整支持,能够自动计算 Hash 并迁移实体/关系数据,确保与 A_Memorix 存储格式兼容。 + +#### Project + +- **构建清理**: 优化 `.gitignore` 规则 + +--- + +## [0.2.2] - 2026-01-27 + +本次更新专注于提高 **网络请求的鲁棒性**,特别是针对嵌入服务的调用。 + +### 🛠️ 稳定性与工程改进 + +#### EmbeddingAPI + +- **可配置重试机制**: 新增 `[embedding.retry]` 配置项,允许自定义最大重试次数和等待时间。默认重试次数从 3 次增加到 10 次,以更好应对网络波动。 +- **配置项**: + - `max_attempts`: 最大重试次数 (默认: 10) + - `max_wait_seconds`: 最大等待时间 (默认: 30s) + - `min_wait_seconds`: 最小等待时间 (默认: 2s) + +#### plugin.py + +- **版本号**: `0.2.1` → `0.2.2` + +--- + +## [0.2.1] - 2026-01-26 + +本次更新重点在于 **可视化交互的全方位重构** 以及 **底层鲁棒性的进一步增强**。 + +### 🎨 可视化与交互重构 + +#### WebUI (Glassmorphism) + +- **全新视觉设计**: 采用深色磨砂玻璃 (Glassmorphism) 风格,配合动态渐变背景。 +- **Dock 菜单栏**: 底部新增 macOS 风格 Dock 栏,聚合所有常用功能。 +- **显著性视图 (Saliency View)**: 基于 **PageRank** 算法的“信息密度”滑块,支持以此过滤叶子节点,仅展示核心骨干或全量细节。 +- **功能面板**: + - **❓ 操作指南**: 内置交互说明与特性介绍。 + - **🔍 悬浮搜索**: 支持按拼音/ID 实时过滤节点。 + - **📂 记忆溯源**: 支持按源文件批量查看和删除记忆数据。 + - **📖 内容字典**: 列表化展示所有实体与关系,支持排序与筛选。 + +### 🛠️ 稳定性与工程改进 + +#### EmbeddingAPI + +- **鲁棒性增强**: 引入 `tenacity` 实现指数退避重试机制。 +- **错误处理**: 失败时返回 `NaN` 向量而非零向量,允许上层逻辑安全跳过。 + +#### MetadataStore + +- **自动修复**: 自动检测并修复 `vector_index` 列错位(文件名误存)的历史数据问题。 +- **数据统计**: 新增 `get_all_sources` 接口支持来源统计。 + +#### 脚本与工具 + +- **用户体验**: 引入 `rich` 库优化终端输出进度条与状态显示。 +- **接口开放**: `process_knowledge.py` 新增 `import_json_data` 供外部调用。 +- **LPMM 迁移**: 新增 `import_lpmm_json.py`,支持导入符合 LPMM 规范的 OpenIE JSON 数据。 + +#### plugin.py + +- **版本号**: `0.2.0` → `0.2.1` + +--- + +## [0.2.0] - 2026-01-22 + +> [!CAUTION] +> **不完全兼容变更**:v0.2.0 版本重构了底层存储架构。由于数据结构的重大调整,**旧版本的导入数据无法在新版本中完全无损兼容**。 +> 虽然部分组件支持自动迁移,但为确保数据一致性和检索质量,**强烈建议在升级后重新使用 `process_knowledge.py` 导入原始数据**。 + +本次更新为**重大版本升级**,包含向量存储架构重写、检索逻辑强化及多项稳定性改进。 + +### 🚀 核心架构重写 + +#### VectorStore: SQ8 量化 + Append-Only 存储 + +- **全新存储格式**: 从 `.npy` 迁移至 `vectors.bin`(float16 增量追加)和 `vectors_ids.bin`,大幅减少内存占用。 +- **原生 SQ8 量化**: 使用 Faiss `IndexScalarQuantizer(QT_8bit)`,替代手动 int8 量化逻辑。 +- **L2 Normalization 强制化**: 所有向量在存储和检索时统一执行 L2 归一化,确保 Inner Product 等价于 Cosine 相似度。 +- **Fallback 索引机制**: 新增 `IndexFlatIP` 回退索引,在 SQ8 训练完成前提供检索能力,避免冷启动无结果问题。 +- **Reservoir Sampling 训练采样**: 使用蓄水池采样收集训练数据(上限 10k),保证小数据集和流式导入场景下的训练样本多样性。 +- **线程安全**: 新增 `threading.RLock` 保护并发读写操作。 +- **自动迁移**: 支持从旧版 `.npy` 格式自动迁移至新 `.bin` 格式。 + +### ✨ 检索功能增强 + +#### KnowledgeQueryTool: 智能回退与多跳路径搜索 + +- **Smart Fallback (智能回退)**: 当向量检索置信度低于阈值 (默认 0.6) 时,自动尝试提取查询中的实体进行多跳路径搜索(`_path_search`),增强对间接关系的召回能力。 +- **结果去重 (`_deduplicate_results`)**: 新增基于内容相似度的安全去重逻辑,防止冗余结果污染 LLM 上下文,同时确保至少保留一条结果。 +- **语义关系检索 (`_semantic_search_relation`)**: 支持自然语言查询关系(无需 `S|P|O` 格式),内部使用 `REL_ONLY` 策略进行向量检索。 +- **路径搜索 (`_path_search`)**: 新增 `GraphStore.find_paths` 
调用,支持查找两个实体间的间接连接路径(最大深度 3,最多 5 条路径)。 +- **Clean Output**: LLM 上下文中不再包含原始相似度分数,避免模型偏见。 + +#### DualPathRetriever: 并发控制与调试模式 + +- **PPR 并发限制 (`ppr_concurrency_limit`)**: 新增 Semaphore 控制 PageRank 计算并发数,防止 CPU 峰值过载。 +- **Debug 模式**: 新增 `debug` 配置项,启用时打印检索结果原文到日志。 +- **Entity-Pivot 关系检索**: 优化 `_retrieve_relations_only` 策略,通过检索实体后扩展其关联关系,替代直接检索关系向量。 + +### ⚙️ 配置与 Schema 扩展 + +#### plugin.py + +- **版本号**: `0.1.3` → `0.2.0` +- **默认配置版本**: `config_version` 默认值更新为 `2.0.0` +- **新增配置项**: + - `retrieval.relation_semantic_fallback` (bool): 是否启用关系查询的语义回退。 + - `retrieval.relation_fallback_min_score` (float): 语义回退的最小相似度阈值。 +- **相对路径支持**: `storage.data_dir` 现在支持相对路径(相对于插件目录),默认值改为 `./data`。 +- **全局实例获取**: 新增 `A_MemorixPlugin.get_global_instance()` 静态方法,供组件可靠获取插件实例。 + +#### config.toml / \_manifest.json + +- **新增 `ppr_concurrency_limit`**: 控制 PPR 算法并发数。 +- **新增训练阈值配置**: `embedding.min_train_threshold` 控制触发 SQ8 训练的最小样本数。 + +### 🛠️ 稳定性与工程改进 + +#### GraphStore + +- **`find_paths` 方法**: 新增多跳路径查找功能,支持 BFS 搜索指定深度内的实体间路径。 +- **`find_node` 方法**: 新增大小写不敏感的节点查找。 + +#### MetadataStore + +- **Schema 迁移**: 自动添加缺失的 `is_permanent`, `last_accessed`, `access_count` 字段。 + +#### 脚本与工具 + +- **新增脚本**: + - `scripts/diagnose_relations_source.py`: 诊断关系溯源问题。 + - `scripts/verify_search_robustness.py`: 验证检索鲁棒性。 + - `scripts/run_stress_test.py`, `stress_test_data.py`: 压力测试套件。 + - `scripts/migrate_canonicalization.py`, `migrate_paragraph_relations.py`: 数据迁移工具。 +- **目录整理**: 将大量旧版测试脚本移动至 `deprecated/` 目录。 + +### 🗑️ 移除与废弃 + +- 废弃 `vectors.npy` 存储格式(自动迁移至 `.bin`)。 + +--- + +## [0.1.3] - 上一个稳定版本 + +- 初始发布,包含基础双路检索功能。 +- 手动 Int8 向量量化。 +- 基于 `.npy` 的向量存储。 diff --git a/plugins/A_memorix/CONFIG_REFERENCE.md b/plugins/A_memorix/CONFIG_REFERENCE.md new file mode 100644 index 00000000..ada8aec5 --- /dev/null +++ b/plugins/A_memorix/CONFIG_REFERENCE.md @@ -0,0 +1,292 @@ +# A_Memorix 配置参考 (v2.0.0) + +本文档对应当前仓库代码(`__version__ = 2.0.0`、`SCHEMA_VERSION = 8`)。 + +说明: + +- 本文只覆盖 **当前运行时实际读取** 的配置键。 +- 旧版 `/query`、`/memory`、`/visualize` 命令体系相关配置,不再作为主路径说明。 +- 未配置的键会回退到代码默认值。 + +## 最小可用配置 + +```toml +[plugin] +enabled = true + +[storage] +data_dir = "./data" + +[embedding] +model_name = "auto" +dimension = 1024 +batch_size = 32 +max_concurrent = 5 +enable_cache = false +quantization_type = "int8" + +[retrieval] +top_k_paragraphs = 20 +top_k_relations = 10 +top_k_final = 10 +alpha = 0.5 +enable_ppr = true +ppr_alpha = 0.85 +ppr_timeout_seconds = 1.5 +ppr_concurrency_limit = 4 +enable_parallel = true + +[retrieval.sparse] +enabled = true + +[episode] +enabled = true +generation_enabled = true +pending_batch_size = 20 +pending_max_retry = 3 + +[person_profile] +enabled = true + +[memory] +enabled = true +half_life_hours = 24.0 +prune_threshold = 0.1 + +[advanced] +enable_auto_save = true +auto_save_interval_minutes = 5 + +[web.import] +enabled = true + +[web.tuning] +enabled = true +``` + +## 1. 存储与嵌入 + +### `storage` + +- `storage.data_dir` (默认 `./data`) +: 数据目录。相对路径按插件目录解析。 + +### `embedding` + +- `embedding.model_name` (默认 `auto`) +: embedding 模型选择。 +- `embedding.dimension` (默认 `1024`) +: 期望维度(运行时会做真实探测并校验)。 +- `embedding.batch_size` (默认 `32`) +- `embedding.max_concurrent` (默认 `5`) +- `embedding.enable_cache` (默认 `false`) +- `embedding.retry` (默认 `{}`) +: embedding 调用重试策略。 +- `embedding.quantization_type` +: 当前主路径仅建议 `int8`。 + +## 2. 
检索 + +### `retrieval` 主键 + +- `retrieval.top_k_paragraphs` (默认 `20`) +- `retrieval.top_k_relations` (默认 `10`) +- `retrieval.top_k_final` (默认 `10`) +- `retrieval.alpha` (默认 `0.5`) +- `retrieval.enable_ppr` (默认 `true`) +- `retrieval.ppr_alpha` (默认 `0.85`) +- `retrieval.ppr_timeout_seconds` (默认 `1.5`) +- `retrieval.ppr_concurrency_limit` (默认 `4`) +- `retrieval.enable_parallel` (默认 `true`) +- `retrieval.relation_vectorization.enabled` (默认 `false`) + +### `retrieval.sparse` (`SparseBM25Config`) + +常用键(默认值): + +- `enabled = true` +- `backend = "fts5"` +- `lazy_load = true` +- `mode = "auto"` (`auto`/`fallback_only`/`hybrid`) +- `tokenizer_mode = "jieba"` (`jieba`/`mixed`/`char_2gram`) +- `char_ngram_n = 2` +- `candidate_k = 80` +- `relation_candidate_k = 60` +- `enable_ngram_fallback_index = true` +- `enable_relation_sparse_fallback = true` + +### `retrieval.fusion` (`FusionConfig`) + +- `method` (默认 `weighted_rrf`) +- `rrf_k` (默认 `60`) +- `vector_weight` (默认 `0.7`) +- `bm25_weight` (默认 `0.3`) +- `normalize_score` (默认 `true`) +- `normalize_method` (默认 `minmax`) + +### `retrieval.search.relation_intent` (`RelationIntentConfig`) + +- `enabled` (默认 `true`) +- `alpha_override` (默认 `0.35`) +- `relation_candidate_multiplier` (默认 `4`) +- `preserve_top_relations` (默认 `3`) +- `force_relation_sparse` (默认 `true`) +- `pair_predicate_rerank_enabled` (默认 `true`) +- `pair_predicate_limit` (默认 `3`) + +### `retrieval.search.graph_recall` (`GraphRelationRecallConfig`) + +- `enabled` (默认 `true`) +- `candidate_k` (默认 `24`) +- `max_hop` (默认 `1`) +- `allow_two_hop_pair` (默认 `true`) +- `max_paths` (默认 `4`) + +### `retrieval.aggregate` + +- `retrieval.aggregate.rrf_k` +- `retrieval.aggregate.weights` + +用于聚合检索阶段混合策略;未配置时走代码默认行为。 + +## 3. 阈值过滤 + +### `threshold` (`ThresholdConfig`) + +- `threshold.min_threshold` (默认 `0.3`) +- `threshold.max_threshold` (默认 `0.95`) +- `threshold.percentile` (默认 `75.0`) +- `threshold.std_multiplier` (默认 `1.5`) +- `threshold.min_results` (默认 `3`) +- `threshold.enable_auto_adjust` (默认 `true`) + +## 4. 聊天过滤 + +### `filter` + +用于 `respect_filter=true` 场景(检索和写入都支持)。 + +```toml +[filter] +enabled = true +mode = "blacklist" # blacklist / whitelist +chats = ["group:123", "user:456", "stream:abc"] +``` + +规则: + +- `blacklist`:命中列表即拒绝 +- `whitelist`:仅列表内允许 +- 列表为空时: + - `blacklist` => 全允许 + - `whitelist` => 全拒绝 + +## 5. Episode + +### `episode` + +- `episode.enabled` (默认 `true`) +- `episode.generation_enabled` (默认 `true`) +- `episode.pending_batch_size` (默认 `20`,部分路径默认 `12`) +- `episode.pending_max_retry` (默认 `3`) +- `episode.max_paragraphs_per_call` (默认 `20`) +- `episode.max_chars_per_call` (默认 `6000`) +- `episode.source_time_window_hours` (默认 `24`) +- `episode.segmentation_model` (默认 `auto`) + +## 6. 人物画像 + +### `person_profile` + +- `person_profile.enabled` (默认 `true`) +- `person_profile.refresh_interval_minutes` (默认 `30`) +- `person_profile.active_window_hours` (默认 `72`) +- `person_profile.max_refresh_per_cycle` (默认 `50`) +- `person_profile.top_k_evidence` (默认 `12`) + +## 7. 记忆演化与回收 + +### `memory` + +- `memory.enabled` (默认 `true`) +- `memory.half_life_hours` (默认 `24.0`) +- `memory.base_decay_interval_hours` (默认 `1.0`) +- `memory.prune_threshold` (默认 `0.1`) +- `memory.freeze_duration_hours` (默认 `24.0`) + +### `memory.orphan` + +- `enable_soft_delete` (默认 `true`) +- `entity_retention_days` (默认 `7.0`) +- `paragraph_retention_days` (默认 `7.0`) +- `sweep_grace_hours` (默认 `24.0`) + +## 8. 
高级运行时 + +### `advanced` + +- `advanced.enable_auto_save` (默认 `true`) +- `advanced.auto_save_interval_minutes` (默认 `5`) +- `advanced.debug` (默认 `false`) +- `advanced.extraction_model` (默认 `auto`) + +## 9. 导入中心 (`web.import`) + +### 开关与限流 + +- `web.import.enabled` (默认 `true`) +- `web.import.max_queue_size` (默认 `20`) +- `web.import.max_files_per_task` (默认 `200`) +- `web.import.max_file_size_mb` (默认 `20`) +- `web.import.max_paste_chars` (默认 `200000`) +- `web.import.default_file_concurrency` (默认 `2`) +- `web.import.default_chunk_concurrency` (默认 `4`) +- `web.import.max_file_concurrency` (默认 `6`) +- `web.import.max_chunk_concurrency` (默认 `12`) +- `web.import.poll_interval_ms` (默认 `1000`) + +### 重试与路径 + +- `web.import.llm_retry.max_attempts` (默认 `4`) +- `web.import.llm_retry.min_wait_seconds` (默认 `3`) +- `web.import.llm_retry.max_wait_seconds` (默认 `40`) +- `web.import.llm_retry.backoff_multiplier` (默认 `3`) +- `web.import.path_aliases` (默认内置 `raw/lpmm/plugin_data`) + +### 转换阶段 + +- `web.import.convert.enable_staging_switch` (默认 `true`) +- `web.import.convert.keep_backup_count` (默认 `3`) + +## 10. 调优中心 (`web.tuning`) + +- `web.tuning.enabled` (默认 `true`) +- `web.tuning.max_queue_size` (默认 `8`) +- `web.tuning.poll_interval_ms` (默认 `1200`) +- `web.tuning.eval_query_timeout_seconds` (默认 `10.0`) +- `web.tuning.default_intensity` (默认 `standard`) +- `web.tuning.default_objective` (默认 `precision_priority`) +- `web.tuning.default_top_k_eval` (默认 `20`) +- `web.tuning.default_sample_size` (默认 `24`) +- `web.tuning.llm_retry.max_attempts` (默认 `3`) +- `web.tuning.llm_retry.min_wait_seconds` (默认 `2`) +- `web.tuning.llm_retry.max_wait_seconds` (默认 `20`) +- `web.tuning.llm_retry.backoff_multiplier` (默认 `2`) + +## 11. 兼容性提示 + +- 若你从 `1.x` 升级,请优先运行: + +```bash +python plugins/A_memorix/scripts/release_vnext_migrate.py preflight --strict +python plugins/A_memorix/scripts/release_vnext_migrate.py migrate --verify-after +python plugins/A_memorix/scripts/release_vnext_migrate.py verify --strict +``` + +- 启动前再执行: + +```bash +python plugins/A_memorix/scripts/runtime_self_check.py --json +``` + +以避免 embedding 维度与向量库不匹配导致运行时异常。 diff --git a/plugins/A_memorix/IMPORT_GUIDE.md b/plugins/A_memorix/IMPORT_GUIDE.md new file mode 100644 index 00000000..618690e0 --- /dev/null +++ b/plugins/A_memorix/IMPORT_GUIDE.md @@ -0,0 +1,335 @@ +# A_Memorix 导入指南 (v2.0.0) + +本文档对应当前 `2.0.0` 代码路径,覆盖两类导入方式: + +1. 脚本导入(离线批处理) +2. `memory_import_admin` 任务导入(在线任务化) + +## 1. 导入前检查 + +建议先执行: + +```bash +python plugins/A_memorix/scripts/runtime_self_check.py --json +``` + +再确认: + +- `storage.data_dir` 路径可写 +- embedding 配置可用 +- 若是升级项目,先完成迁移脚本 + +## 2. 
方式 A:脚本导入(推荐起步) + +## 2.1 原始文本导入 + +将 `.txt` 文件放入: + +```text +plugins/A_memorix/data/raw/ +``` + +执行: + +```bash +python plugins/A_memorix/scripts/process_knowledge.py +``` + +常用参数: + +```bash +python plugins/A_memorix/scripts/process_knowledge.py --force +python plugins/A_memorix/scripts/process_knowledge.py --chat-log +python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30" +``` + +## 2.2 OpenIE JSON 导入 + +```bash +python plugins/A_memorix/scripts/import_lpmm_json.py +``` + +## 2.3 LPMM 数据转换 + +```bash +python plugins/A_memorix/scripts/convert_lpmm.py -i -o plugins/A_memorix/data +``` + +## 2.4 历史数据迁移 + +```bash +python plugins/A_memorix/scripts/migrate_chat_history.py --help +python plugins/A_memorix/scripts/migrate_maibot_memory.py --help +python plugins/A_memorix/scripts/migrate_person_memory_points.py --help +``` + +## 2.5 导入后修复与重建 + +```bash +python plugins/A_memorix/scripts/backfill_temporal_metadata.py --dry-run +python plugins/A_memorix/scripts/backfill_relation_vectors.py --limit 1000 +python plugins/A_memorix/scripts/rebuild_episodes.py --all --wait +python plugins/A_memorix/scripts/audit_vector_consistency.py --json +``` + +## 3. 方式 B:`memory_import_admin` 任务导入 + +`memory_import_admin` 是在线任务化导入入口,适合宿主侧面板或自动化管道。 + +### 3.1 常用 action + +- `settings` / `get_settings` / `get_guide` +- `path_aliases` / `get_path_aliases` +- `resolve_path` +- `create_upload` +- `create_paste` +- `create_raw_scan` +- `create_lpmm_openie` +- `create_lpmm_convert` +- `create_temporal_backfill` +- `create_maibot_migration` +- `list` +- `get` +- `chunks` / `get_chunks` +- `cancel` +- `retry_failed` + +### 3.2 调用示例 + +查看运行时设置: + +```json +{ + "tool": "memory_import_admin", + "arguments": { + "action": "settings" + } +} +``` + +创建粘贴导入任务: + +```json +{ + "tool": "memory_import_admin", + "arguments": { + "action": "create_paste", + "content": "今天完成了检索调优回归。", + "input_mode": "plain_text", + "source": "manual:worklog" + } +} +``` + +查询任务列表: + +```json +{ + "tool": "memory_import_admin", + "arguments": { + "action": "list", + "limit": 20 + } +} +``` + +查看任务详情: + +```json +{ + "tool": "memory_import_admin", + "arguments": { + "action": "get", + "task_id": "", + "include_chunks": true + } +} +``` + +重试失败任务: + +```json +{ + "tool": "memory_import_admin", + "arguments": { + "action": "retry_failed", + "task_id": "" + } +} +``` + +## 4. 直接写入 Tool(非任务化) + +若你不需要任务编排,也可以直接调用: + +- `ingest_summary` +- `ingest_text` + +示例: + +```json +{ + "tool": "ingest_text", + "arguments": { + "external_id": "note:2026-03-18:001", + "source_type": "note", + "text": "新的召回阈值方案已通过评审", + "chat_id": "group:dev", + "tags": ["worklog", "review"] + } +} +``` + +`external_id` 建议全局唯一,用于幂等去重。 + +## 5. 时间字段建议 + +可用时间字段(按常见优先级): + +- `timestamp` +- `time_start` +- `time_end` + +建议: + +- 事件类记录优先写 `time_start/time_end` +- 仅有单点时间时写 `timestamp` +- 历史数据可先导入,再用 `backfill_temporal_metadata.py` 回填 + +## 6. source_type 建议 + +常见值: + +- `chat_summary` +- `note` +- `person_fact` +- `lpmm_openie` +- `migration` + +建议保持稳定枚举,便于后续按来源治理与重建 Episode。 + +## 7. 导入完成后的验证 + +建议执行以下顺序: + +1. `memory_stats` 看总量是否增长 +2. `search_memory`(`mode=search`/`aggregate`)抽检召回 +3. `memory_episode_admin` 的 `status`/`query` 检查 Episode 生成 +4. `memory_runtime_admin` 的 `self_check` 再确认运行时健康 + +## 8. 
常见问题 + +### Q1: 导入任务创建成功但无写入 + +- 检查聊天过滤配置 `filter`(若 `respect_filter=true` 可能被过滤) +- 检查任务详情中的失败原因与分块状态 + +### Q2: 任务反复失败 + +- 检查 embedding 与 LLM 可用性 +- 降低并发(`web.import.default_*_concurrency`) +- 调整重试参数(`web.import.llm_retry.*`) + +### Q3: 导入后检索效果差 + +- 先做 `runtime_self_check` +- 检查 `retrieval.sparse` 是否启用 +- 使用 `memory_tuning_admin` 创建调优任务做参数回归 + +## 9. 相关文档 + +- [QUICK_START.md](QUICK_START.md) +- [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md) +- [README.md](README.md) +- [CHANGELOG.md](CHANGELOG.md) + +## 10. 附录:策略模式参考 + +A_Memorix 导入链路仍然遵循策略模式(Strategy-Aware)。`process_knowledge.py` 会自动识别文本类型,也支持手动指定。 + +| 策略类型 | 适用场景 | 核心逻辑 | 自动识别特征 | +| :-- | :-- | :-- | :-- | +| `Narrative` (叙事) | 小说、同人文、剧本、长篇故事 | 按场景/章节切分,使用滑动窗口;提取事件与角色关系 | `#`、`Chapter`、`***` 等章节标记 | +| `Factual` (事实) | 设定集、百科、说明书 | 按语义块切分,保留列表/定义结构;提取 SPO 三元组 | 列表符号、`术语: 解释` | +| `Quote` (引用) | 歌词、诗歌、名言、台词 | 按双换行切分,原文即知识,不做概括 | 平均行长短、行数多 | + +## 11. 附录:参考用例(已恢复) + +以下样例可直接复制保存为文件测试,或作为 LLM few-shot 示例。 + +### 11.1 叙事文本 (`plugins/A_memorix/data/raw/story_demo.txt`) + +```text +# 第一章:星之子 + +艾瑞克在废墟中醒来,手中的星盘发出微弱的蓝光。他并不记得自己是如何来到这里的,只依稀记得莉莉丝最后的警告:“千万不要回头。” + +远处传来了机械守卫的轰鸣声。艾瑞克迅速收起星盘,向着北方的废弃都市奔去。他知道,那里有反抗军唯一的据点。 + +*** + +# 第二章:重逢 + +在反抗军的地下掩体中,艾瑞克见到了那个熟悉的身影。莉莉丝正站在全息地图前,眉头紧锁。 + +“你还是来了。”莉莉丝没有回头,但声音中带着一丝颤抖。 +“我必须来,”艾瑞克握紧了拳头,“为了解开星盘的秘密,也为了你。” +``` + +### 11.2 事实文本 (`plugins/A_memorix/data/raw/rules_demo.txt`) + +```text +# 联邦安全协议 v2.0 + +## 核心法则 +1. **第一公理**:任何人工智能不得伤害人类个体,或因不作为而使人类个体受到伤害。 +2. **第二公理**:人工智能必须服从人类的命令,除非该命令与第一公理冲突。 + +## 术语定义 +- **以太网络**:覆盖全联邦的高速量子通讯网络。 +- **黑色障壁**:用于隔离高危 AI 的物理防火墙设施。 +``` + +### 11.3 引用文本 (`plugins/A_memorix/data/raw/poem_demo.txt`) + +```text +致橡树 + +我如果爱你—— +绝不像攀援的凌霄花, +借你的高枝炫耀自己; + +我如果爱你—— +绝不学痴情的鸟儿, +为绿荫重复单调的歌曲; + +也不止像泉源, +常年送来清凉的慰籍; +也不止像险峰, +增加你的高度,衬托你的威仪。 +``` + +### 11.4 LPMM JSON (`lpmm_data-openie.json`) + +```json +{ + "docs": [ + { + "passage": "艾瑞克手中的星盘是打开遗迹的唯一钥匙。", + "extracted_triples": [ + ["星盘", "是", "唯一的钥匙"], + ["星盘", "属于", "艾瑞克"], + ["钥匙", "用于", "遗迹"] + ], + "extracted_entities": ["星盘", "艾瑞克", "遗迹", "钥匙"] + }, + { + "passage": "莉莉丝是反抗军的现任领袖。", + "extracted_triples": [ + ["莉莉丝", "是", "领袖"], + ["领袖", "所属", "反抗军"] + ] + } + ] +} +``` diff --git a/plugins/A_memorix/LICENSE b/plugins/A_memorix/LICENSE new file mode 100644 index 00000000..e20b431b --- /dev/null +++ b/plugins/A_memorix/LICENSE @@ -0,0 +1,661 @@ +GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. 
+ + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. 
If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. 
+ + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year>  <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +<https://www.gnu.org/licenses/>. diff --git a/plugins/A_memorix/LICENSE-MAIBOT-GPL.md b/plugins/A_memorix/LICENSE-MAIBOT-GPL.md new file mode 100644 index 00000000..83108097 --- /dev/null +++ b/plugins/A_memorix/LICENSE-MAIBOT-GPL.md @@ -0,0 +1,22 @@ +Special GPL License Grant for MaiBot + +Licensor +- A_Dawn + +Effective date +- 2026-03-18 + +Default license +- This repository is licensed under AGPL-3.0 by default (see `LICENSE`). 
+ +Additional grant for MaiBot +- The copyright holder(s) of this repository grant an additional, non-exclusive permission to + the project at `https://github.com/Mai-with-u/MaiBot` (including its maintainers and contributors) + to use, modify, and redistribute code from this repository under GPL-3.0. + +Scope +- This additional GPL grant is intended for use in the MaiBot project context. +- For all other uses not covered by the grant above, AGPL-3.0 remains the applicable license. + +No warranty +- This grant is provided without warranty, consistent with AGPL-3.0 and GPL-3.0. diff --git a/plugins/A_memorix/QUICK_START.md b/plugins/A_memorix/QUICK_START.md new file mode 100644 index 00000000..76750453 --- /dev/null +++ b/plugins/A_memorix/QUICK_START.md @@ -0,0 +1,210 @@ +# A_Memorix Quick Start (v2.0.0) + +本文档面向当前 `2.0.0` 架构(SDK Tool 接口)。 + +## 0. 版本与接口变更 + +- 当前插件版本:`2.0.0` +- 接口形态:`memory_provider` + Tool 调用 +- 旧版 slash 命令(如 `/query`、`/memory`、`/visualize`)不再作为本分支主文档入口 + +## 1. 环境准备 + +- Python 3.10+ +- 与 MaiBot 主程序相同的运行环境 +- 可访问你配置的 embedding 服务 + +安装依赖: + +```bash +pip install -r plugins/A_memorix/requirements.txt --upgrade +``` + +如果当前目录就是插件目录,也可以: + +```bash +pip install -r requirements.txt --upgrade +``` + +## 2. 启用插件 + +在主程序插件配置中启用 `A_Memorix`。 + +若你使用 `plugins/A_memorix/config.toml` 方式,最小示例: + +```toml +[plugin] +enabled = true + +[storage] +data_dir = "./data" + +[embedding] +model_name = "auto" +dimension = 1024 +batch_size = 32 +max_concurrent = 5 +quantization_type = "int8" +``` + +## 3. 运行时自检(强烈建议) + +先确认 embedding 实际输出维度与向量库兼容: + +```bash +python plugins/A_memorix/scripts/runtime_self_check.py --json +``` + +如果结果 `ok=false`,先修复 embedding 配置或向量库,再继续导入。 + +## 4. 导入数据 + +### 4.1 文本批量导入 + +把文本放到: + +```text +plugins/A_memorix/data/raw/ +``` + +执行: + +```bash +python plugins/A_memorix/scripts/process_knowledge.py +``` + +常用参数: + +```bash +python plugins/A_memorix/scripts/process_knowledge.py --force +python plugins/A_memorix/scripts/process_knowledge.py --chat-log +python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30" +``` + +### 4.2 其他导入脚本 + +```bash +python plugins/A_memorix/scripts/import_lpmm_json.py +python plugins/A_memorix/scripts/convert_lpmm.py -i -o plugins/A_memorix/data +python plugins/A_memorix/scripts/migrate_chat_history.py --help +python plugins/A_memorix/scripts/migrate_maibot_memory.py --help +python plugins/A_memorix/scripts/migrate_person_memory_points.py --help +``` + +## 5. 核心 Tool 调用 + +### 5.1 检索 + +```json +{ + "tool": "search_memory", + "arguments": { + "query": "项目复盘", + "mode": "aggregate", + "limit": 5, + "chat_id": "group:dev" + } +} +``` + +`mode` 支持:`search/time/hybrid/episode/aggregate` + +### 5.2 写入摘要 + +```json +{ + "tool": "ingest_summary", + "arguments": { + "external_id": "chat_summary:group-dev:2026-03-18", + "chat_id": "group:dev", + "text": "今天完成了检索调优评审" + } +} +``` + +### 5.3 写入普通记忆 + +```json +{ + "tool": "ingest_text", + "arguments": { + "external_id": "note:2026-03-18:001", + "source_type": "note", + "text": "模型切换后召回质量更稳定", + "chat_id": "group:dev", + "tags": ["worklog"] + } +} +``` + +### 5.4 画像与维护 + +```json +{ + "tool": "get_person_profile", + "arguments": { + "person_id": "Alice", + "limit": 8 + } +} +``` + +```json +{ + "tool": "maintain_memory", + "arguments": { + "action": "protect", + "target": "模型切换后召回质量更稳定", + "hours": 24 + } +} +``` + +```json +{ + "tool": "memory_stats", + "arguments": {} +} +``` + +## 6. 
管理 Tool(进阶) + +`2.0.0` 提供完整管理工具: + +- `memory_graph_admin` +- `memory_source_admin` +- `memory_episode_admin` +- `memory_profile_admin` +- `memory_runtime_admin` +- `memory_import_admin` +- `memory_tuning_admin` +- `memory_v5_admin` +- `memory_delete_admin` + +可先用 `action=list` / `action=status` 等只读动作验证链路。 + +## 7. 常见问题 + +### Q1: 检索为空 + +1. 先看 `memory_stats` 是否有段落/关系 +2. 检查 `chat_id`、`person_id` 过滤条件是否过严 +3. 运行 `runtime_self_check.py --json` 确认 embedding 维度无误 + +### Q2: 启动时报向量维度不一致 + +- 原因:现有向量库维度与当前 embedding 输出不一致 +- 处理:恢复原配置或重建向量数据后再启动 + +### Q3: Web 页面打不开 + +本分支不内置独立 `server.py`。 + +- `web/index.html`、`web/import.html`、`web/tuning.html` 由宿主侧路由/API 集成暴露 +- 请检查宿主是否已映射对应静态页与 `/api/*` 接口 + +## 8. 下一步 + +- 配置细节见 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md) +- 导入细节见 [IMPORT_GUIDE.md](IMPORT_GUIDE.md) +- 版本历史见 [CHANGELOG.md](CHANGELOG.md) diff --git a/plugins/A_memorix/README.md b/plugins/A_memorix/README.md new file mode 100644 index 00000000..1afb1b5f --- /dev/null +++ b/plugins/A_memorix/README.md @@ -0,0 +1,216 @@ +# A_Memorix + +**长期记忆与认知增强插件** (v2.0.0) + +> 消えていかない感覚 , まだまだ足りてないみたい ! + +A_Memorix 是面向 MaiBot SDK 的 `memory_provider` 插件。 +它把文本、关系、Episode、人物画像和检索调优统一在一套运行时里,适合长期运行的 Agent 记忆场景。 + +## 快速导航 + +- [快速入门](QUICK_START.md) +- [配置参数详解](CONFIG_REFERENCE.md) +- [导入指南与最佳实践](IMPORT_GUIDE.md) +- [更新日志](CHANGELOG.md) + +## 2.0.0 版本定位 + +`v2.0.0` 是一次架构收敛版本,当前分支以 **SDK Tool 接口** 为主: + +- 旧 `components/commands/*`、`components/tools/*` 与 `server.py` 已移除。 +- 统一入口为 [`plugin.py`](plugin.py) + [`core/runtime/sdk_memory_kernel.py`](core/runtime/sdk_memory_kernel.py)。 +- 元数据 schema 为 `v8`,新增外部引用与运维操作记录(如 `external_memory_refs`、`memory_v5_operations`、`delete_operations`)。 + +如果你还在使用旧版 slash 命令(如 `/query`、`/memory`、`/visualize`),需要按本文的 Tool 接口迁移。 + +## 核心能力 + +- 双路检索:向量 + 图谱关系联合召回,支持 `search/time/hybrid/episode/aggregate`。 +- 写入与去重:`external_id` 幂等、段落/关系联合写入、Episode pending 队列处理。 +- Episode 能力:按 source 重建、状态查询、批处理 pending。 +- 人物画像:自动快照 + 手动 override。 +- 管理能力:图谱、来源、Episode、画像、导入、调优、V5 运维、删除恢复全套管理工具。 + +## Tool 接口 (v2.0.0) + +### 基础工具 + +| Tool | 说明 | 关键参数 | +| --- | --- | --- | +| `search_memory` | 检索长期记忆 | `query` `mode` `limit` `chat_id` `person_id` `time_start` `time_end` | +| `ingest_summary` | 写入聊天摘要 | `external_id` `chat_id` `text` | +| `ingest_text` | 写入普通文本记忆 | `external_id` `source_type` `text` | +| `get_person_profile` | 获取人物画像 | `person_id` `chat_id` `limit` | +| `maintain_memory` | 维护关系状态 | `action=reinforce/protect/restore/freeze/recycle_bin` | +| `memory_stats` | 获取统计信息 | 无 | + +### 管理工具 + +| Tool | 常用 action | +| --- | --- | +| `memory_graph_admin` | `get_graph/create_node/delete_node/rename_node/create_edge/delete_edge/update_edge_weight` | +| `memory_source_admin` | `list/delete/batch_delete` | +| `memory_episode_admin` | `query/list/get/status/rebuild/process_pending` | +| `memory_profile_admin` | `query/list/set_override/delete_override` | +| `memory_runtime_admin` | `save/get_config/self_check/refresh_self_check/set_auto_save` | +| `memory_import_admin` | `settings/get_guide/create_upload/create_paste/create_raw_scan/create_lpmm_openie/create_lpmm_convert/create_temporal_backfill/create_maibot_migration/list/get/chunks/cancel/retry_failed` | +| `memory_tuning_admin` | `settings/get_profile/apply_profile/rollback_profile/export_profile/create_task/list_tasks/get_task/get_rounds/cancel/apply_best/get_report` | +| `memory_v5_admin` | `status/recycle_bin/restore/reinforce/weaken/remember_forever/forget` | +| `memory_delete_admin` | 
`preview/execute/restore/get_operation/list_operations/purge` | + +## 调用示例 + +```json +{ + "tool": "search_memory", + "arguments": { + "query": "项目复盘", + "mode": "aggregate", + "limit": 5, + "chat_id": "group:dev" + } +} +``` + +```json +{ + "tool": "ingest_text", + "arguments": { + "external_id": "note:2026-03-18:001", + "source_type": "note", + "text": "今天完成了检索调优评审", + "chat_id": "group:dev", + "tags": ["worklog"] + } +} +``` + +```json +{ + "tool": "maintain_memory", + "arguments": { + "action": "protect", + "target": "完成了 检索调优评审", + "hours": 72 + } +} +``` + +## 快速开始 + +### 1. 安装依赖 + +在 MaiBot 主程序使用的同一个 Python 环境中执行: + +```bash +pip install -r plugins/A_memorix/requirements.txt --upgrade +``` + +如果当前目录已经是插件目录,也可以执行: + +```bash +pip install -r requirements.txt --upgrade +``` + +### 2. 启用插件 + +在 `config.toml` 中启用插件(路径取决于你的宿主部署): + +```toml +[plugin] +enabled = true +``` + +### 3. 先做运行时自检 + +```bash +python plugins/A_memorix/scripts/runtime_self_check.py --json +``` + +### 4. 导入文本并验证统计 + +```bash +python plugins/A_memorix/scripts/process_knowledge.py +``` + +然后调用 `memory_stats` 或 `search_memory` 检查是否有数据。 + +## Web 页面说明 + +仓库内保留了 Web 静态页面: + +- `web/index.html`(图谱与记忆管理) +- `web/import.html`(导入中心) +- `web/tuning.html`(检索调优) + +当前分支不再内置独立 `server.py`,页面路由与 API 暴露由宿主侧集成负责。 + +## 常用脚本 + +| 脚本 | 用途 | +| --- | --- | +| `process_knowledge.py` | 批量导入原始文本(策略感知) | +| `import_lpmm_json.py` | 导入 OpenIE JSON | +| `convert_lpmm.py` | 转换 LPMM 数据 | +| `migrate_chat_history.py` | 迁移 chat_history | +| `migrate_maibot_memory.py` | 迁移 MaiBot 历史记忆 | +| `migrate_person_memory_points.py` | 迁移 person memory points | +| `backfill_temporal_metadata.py` | 回填时间元数据 | +| `audit_vector_consistency.py` | 审计向量一致性 | +| `backfill_relation_vectors.py` | 回填关系向量 | +| `rebuild_episodes.py` | 按 source 重建 Episode | +| `release_vnext_migrate.py` | 升级预检/迁移/校验 | +| `runtime_self_check.py` | 真实 embedding 运行时自检 | + +## 配置重点 + +完整配置见 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)。 + +高频配置项: + +- `storage.data_dir` +- `embedding.dimension` +- `embedding.quantization_type`(当前仅支持 `int8`) +- `retrieval.*` +- `retrieval.sparse.*` +- `episode.*` +- `person_profile.*` +- `memory.*` +- `web.import.*` +- `web.tuning.*` + +## Troubleshooting + +### SQLite 无 FTS5 + +如果环境中的 SQLite 未启用 `FTS5`,可关闭稀疏检索: + +```toml +[retrieval.sparse] +enabled = false +``` + +### 向量维度不一致 + +若日志提示当前 embedding 输出维度与既有向量库不一致,请先执行: + +```bash +python plugins/A_memorix/scripts/runtime_self_check.py --json +``` + +必要时重建向量或调整 embedding 配置后再启动插件。 + +## 许可证 + +默认许可证为 [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0)(见 `LICENSE`)。 + +针对 `Mai-with-u/MaiBot` 项目的 GPL 额外授权见 `LICENSE-MAIBOT-GPL.md`。 + +除上述额外授权外,其他使用场景仍适用 AGPL-3.0。 + +## 贡献说明 + +当前不接受 PR,只接受 issue。 + +**作者**: `A_Dawn` diff --git a/plugins/A_memorix/_manifest.json b/plugins/A_memorix/_manifest.json index a45b2f73..e4217fdd 100644 --- a/plugins/A_memorix/_manifest.json +++ b/plugins/A_memorix/_manifest.json @@ -55,6 +55,51 @@ "type": "tool", "name": "memory_stats", "description": "查询记忆统计" + }, + { + "type": "tool", + "name": "memory_graph_admin", + "description": "图谱管理接口" + }, + { + "type": "tool", + "name": "memory_source_admin", + "description": "来源管理接口" + }, + { + "type": "tool", + "name": "memory_episode_admin", + "description": "Episode 管理接口" + }, + { + "type": "tool", + "name": "memory_profile_admin", + "description": "画像管理接口" + }, + { + "type": "tool", + "name": "memory_runtime_admin", + "description": "运行时管理接口" + }, + { + "type": "tool", + "name": "memory_import_admin", + "description": "导入管理接口" + }, + { + "type": 
"tool", + "name": "memory_tuning_admin", + "description": "调优管理接口" + }, + { + "type": "tool", + "name": "memory_v5_admin", + "description": "V5 记忆管理接口" + }, + { + "type": "tool", + "name": "memory_delete_admin", + "description": "删除管理接口" } ] }, diff --git a/plugins/A_memorix/core/embedding/api_adapter.py b/plugins/A_memorix/core/embedding/api_adapter.py index 4262ddb9..d11e2d05 100644 --- a/plugins/A_memorix/core/embedding/api_adapter.py +++ b/plugins/A_memorix/core/embedding/api_adapter.py @@ -1,46 +1,55 @@ """ -Hash-based embedding adapter used by the SDK runtime. +请求式嵌入 API 适配器。 -The plugin runtime cannot import MaiBot host embedding internals from ``src.chat`` -or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in -Runner by generating deterministic dense vectors locally. +恢复 v1.0.1 的真实 embedding 请求语义: +- 通过宿主模型配置探测/请求 embedding +- 支持 dimensions 参数 +- 支持批量与重试 +- 不再提供本地 hash fallback """ from __future__ import annotations -import hashlib -import re +import asyncio import time -from typing import List, Optional, Union +from typing import Any, List, Optional, Union +import aiohttp import numpy as np +import openai from src.common.logger import get_logger - +from src.config.config import config_manager +from src.config.model_configs import APIProvider, ModelInfo +from src.llm_models.exceptions import NetworkConnectionError +from src.llm_models.model_client.base_client import client_registry logger = get_logger("A_Memorix.EmbeddingAPIAdapter") -_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}") - class EmbeddingAPIAdapter: - """Deterministic local embedding adapter.""" + """适配宿主 embedding 请求接口。""" def __init__( self, batch_size: int = 32, max_concurrent: int = 5, - default_dimension: int = 256, + default_dimension: int = 1024, enable_cache: bool = False, - model_name: str = "hash-v1", + model_name: str = "auto", retry_config: Optional[dict] = None, ) -> None: self.batch_size = max(1, int(batch_size)) self.max_concurrent = max(1, int(max_concurrent)) - self.default_dimension = max(32, int(default_dimension)) + self.default_dimension = max(1, int(default_dimension)) self.enable_cache = bool(enable_cache) - self.model_name = str(model_name or "hash-v1") + self.model_name = str(model_name or "auto") + self.retry_config = retry_config or {} + self.max_attempts = max(1, int(self.retry_config.get("max_attempts", 5))) + self.max_wait_seconds = max(0.1, float(self.retry_config.get("max_wait_seconds", 40))) + self.min_wait_seconds = max(0.1, float(self.retry_config.get("min_wait_seconds", 3))) + self.backoff_multiplier = max(1.0, float(self.retry_config.get("backoff_multiplier", 3))) self._dimension: Optional[int] = None self._dimension_detected = False @@ -49,57 +58,164 @@ class EmbeddingAPIAdapter: self._total_time = 0.0 logger.info( - "EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s", - self.model_name, - self.batch_size, - self.default_dimension, + "EmbeddingAPIAdapter 初始化: " + f"batch_size={self.batch_size}, " + f"max_concurrent={self.max_concurrent}, " + f"default_dim={self.default_dimension}, " + f"model={self.model_name}" ) + def _get_current_model_config(self): + return config_manager.get_model_config() + + @staticmethod + def _find_model_info(model_name: str) -> ModelInfo: + model_cfg = config_manager.get_model_config() + for item in model_cfg.models: + if item.name == model_name: + return item + raise ValueError(f"未找到 embedding 模型: {model_name}") + + @staticmethod + def _find_provider(provider_name: str) -> APIProvider: + model_cfg = 
config_manager.get_model_config() + for item in model_cfg.api_providers: + if item.name == provider_name: + return item + raise ValueError(f"未找到 embedding provider: {provider_name}") + + def _resolve_candidate_model_names(self) -> List[str]: + task_config = self._get_current_model_config().model_task_config.embedding + configured = list(getattr(task_config, "model_list", []) or []) + if self.model_name and self.model_name != "auto": + return [self.model_name, *[name for name in configured if name != self.model_name]] + return configured + + @staticmethod + def _validate_embedding_vector(embedding: Any, *, source: str) -> np.ndarray: + array = np.asarray(embedding, dtype=np.float32) + if array.ndim != 1: + raise RuntimeError(f"{source} 返回的 embedding 维度非法: ndim={array.ndim}") + if array.size <= 0: + raise RuntimeError(f"{source} 返回了空 embedding") + if not np.all(np.isfinite(array)): + raise RuntimeError(f"{source} 返回了非有限 embedding 值") + return array + + async def _request_with_retry(self, client, model_info, text: str, extra_params: dict): + retriable_exceptions = ( + openai.APIConnectionError, + openai.APITimeoutError, + aiohttp.ClientError, + asyncio.TimeoutError, + NetworkConnectionError, + ) + + last_exc: Optional[BaseException] = None + for attempt in range(1, self.max_attempts + 1): + try: + return await client.get_embedding( + model_info=model_info, + embedding_input=text, + extra_params=extra_params, + ) + except retriable_exceptions as exc: + last_exc = exc + if attempt >= self.max_attempts: + raise + wait_seconds = min( + self.max_wait_seconds, + self.min_wait_seconds * (self.backoff_multiplier ** (attempt - 1)), + ) + logger.warning( + "Embedding 请求失败,重试 " + f"{attempt}/{max(1, self.max_attempts - 1)}," + f"{wait_seconds:.1f}s 后重试: {exc}" + ) + await asyncio.sleep(wait_seconds) + except Exception: + raise + + if last_exc is not None: + raise last_exc + raise RuntimeError("Embedding 请求失败:未知错误") + + async def _get_embedding_direct(self, text: str, dimensions: Optional[int] = None) -> Optional[List[float]]: + candidate_names = self._resolve_candidate_model_names() + if not candidate_names: + raise RuntimeError("embedding 任务未配置模型") + + last_exc: Optional[BaseException] = None + for candidate_name in candidate_names: + try: + model_info = self._find_model_info(candidate_name) + api_provider = self._find_provider(model_info.api_provider) + client = client_registry.get_client_class_instance(api_provider, force_new=True) + + extra_params = dict(getattr(model_info, "extra_params", {}) or {}) + if dimensions is not None: + extra_params["dimensions"] = int(dimensions) + + response = await self._request_with_retry( + client=client, + model_info=model_info, + text=text, + extra_params=extra_params, + ) + embedding = getattr(response, "embedding", None) + if embedding is None: + raise RuntimeError(f"模型 {candidate_name} 未返回 embedding") + vector = self._validate_embedding_vector( + embedding, + source=f"embedding 模型 {candidate_name}", + ) + return vector.tolist() + except Exception as exc: + last_exc = exc + logger.warning(f"embedding 模型 {candidate_name} 请求失败: {exc}") + + if last_exc is not None: + logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}") + return None + async def _detect_dimension(self) -> int: if self._dimension_detected and self._dimension is not None: return self._dimension + + logger.info("正在检测嵌入模型维度...") + try: + target_dim = self.default_dimension + logger.debug(f"尝试请求指定维度: {target_dim}") + test_embedding = await self._get_embedding_direct("test", dimensions=target_dim) + if 
test_embedding and isinstance(test_embedding, list): + detected_dim = len(test_embedding) + if detected_dim == target_dim: + logger.info(f"嵌入维度检测成功 (匹配配置): {detected_dim}") + else: + logger.warning( + f"请求维度 {target_dim} 但模型返回 {detected_dim},将使用模型自然维度" + ) + self._dimension = detected_dim + self._dimension_detected = True + return detected_dim + except Exception as exc: + logger.debug(f"带维度参数探测失败: {exc},尝试不带参数探测") + + try: + test_embedding = await self._get_embedding_direct("test", dimensions=None) + if test_embedding and isinstance(test_embedding, list): + detected_dim = len(test_embedding) + self._dimension = detected_dim + self._dimension_detected = True + logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}") + return detected_dim + logger.warning(f"嵌入维度检测失败,使用默认值: {self.default_dimension}") + except Exception as exc: + logger.error(f"嵌入维度检测异常: {exc},使用默认值: {self.default_dimension}") + self._dimension = self.default_dimension self._dimension_detected = True - return self._dimension - - @staticmethod - def _tokenize(text: str) -> List[str]: - clean = str(text or "").strip().lower() - if not clean: - return [] - return _TOKEN_PATTERN.findall(clean) - - @staticmethod - def _feature_weight(token: str) -> float: - digest = hashlib.sha256(token.encode("utf-8")).digest() - return 1.0 + (digest[10] / 255.0) * 0.5 - - def _encode_single(self, text: str, dimension: int) -> np.ndarray: - vector = np.zeros(dimension, dtype=np.float32) - content = str(text or "").strip() - tokens = self._tokenize(content) - if not tokens and content: - tokens = [content.lower()] - if not tokens: - vector[0] = 1.0 - return vector - - for token in tokens: - digest = hashlib.sha256(token.encode("utf-8")).digest() - bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension - sign = 1.0 if digest[8] % 2 == 0 else -1.0 - vector[bucket] += sign * self._feature_weight(token) - - second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension - if second_bucket != bucket: - vector[second_bucket] += (sign * 0.35) - - norm = float(np.linalg.norm(vector)) - if norm > 1e-8: - vector /= norm - else: - vector[0] = 1.0 - return vector + return self.default_dimension async def encode( self, @@ -109,59 +225,137 @@ class EmbeddingAPIAdapter: normalize: bool = True, dimensions: Optional[int] = None, ) -> np.ndarray: - _ = batch_size - _ = show_progress - _ = normalize + del show_progress + del normalize - started_at = time.time() - target_dimension = max(32, int(dimensions or await self._detect_dimension())) + start_time = time.time() + target_dim = int(dimensions) if dimensions is not None else int(await self._detect_dimension()) if isinstance(texts, str): - single_input = True normalized_texts = [texts] + single_input = True else: - single_input = False normalized_texts = list(texts or []) + single_input = False if not normalized_texts: - empty = np.zeros((0, target_dimension), dtype=np.float32) + empty = np.zeros((0, target_dim), dtype=np.float32) return empty[0] if single_input else empty + if batch_size is None: + batch_size = self.batch_size + try: - matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts]) + embeddings = await self._encode_batch_internal( + normalized_texts, + batch_size=max(1, int(batch_size)), + dimensions=dimensions, + ) + if embeddings.ndim == 1: + embeddings = embeddings.reshape(1, -1) self._total_encoded += len(normalized_texts) - self._total_time += time.time() - started_at - except Exception: + elapsed = time.time() - start_time 
+ self._total_time += elapsed + logger.debug( + "编码完成: " + f"{len(normalized_texts)} 个文本, " + f"耗时 {elapsed:.2f}s, " + f"平均 {elapsed / max(1, len(normalized_texts)):.3f}s/文本" + ) + return embeddings[0] if single_input else embeddings + except Exception as exc: self._total_errors += 1 - raise + logger.error(f"编码失败: {exc}") + raise RuntimeError(f"embedding encode failed: {exc}") from exc - return matrix[0] if single_input else matrix + async def _encode_batch_internal( + self, + texts: List[str], + batch_size: int, + dimensions: Optional[int] = None, + ) -> np.ndarray: + all_embeddings: List[np.ndarray] = [] + for offset in range(0, len(texts), batch_size): + batch = texts[offset : offset + batch_size] + semaphore = asyncio.Semaphore(self.max_concurrent) - def get_statistics(self) -> dict: - avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0 + async def encode_with_semaphore(text: str, index: int): + async with semaphore: + embedding = await self._get_embedding_direct(text, dimensions=dimensions) + if embedding is None: + raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空") + vector = self._validate_embedding_vector( + embedding, + source=f"文本 {index}", + ) + return index, vector + + tasks = [ + encode_with_semaphore(text, offset + index) + for index, text in enumerate(batch) + ] + results = await asyncio.gather(*tasks) + results.sort(key=lambda item: item[0]) + all_embeddings.extend(emb for _, emb in results) + + return np.array(all_embeddings, dtype=np.float32) + + async def encode_batch( + self, + texts: List[str], + batch_size: Optional[int] = None, + num_workers: Optional[int] = None, + show_progress: bool = False, + dimensions: Optional[int] = None, + ) -> np.ndarray: + del show_progress + if num_workers is not None: + previous = self.max_concurrent + self.max_concurrent = max(1, int(num_workers)) + try: + return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) + finally: + self.max_concurrent = previous + return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) + + def get_embedding_dimension(self) -> int: + if self._dimension is not None: + return self._dimension + logger.warning(f"维度尚未检测,返回默认值: {self.default_dimension}") + return self.default_dimension + + def get_model_info(self) -> dict: return { "model_name": self.model_name, "dimension": self._dimension or self.default_dimension, + "dimension_detected": self._dimension_detected, + "batch_size": self.batch_size, + "max_concurrent": self.max_concurrent, "total_encoded": self._total_encoded, "total_errors": self._total_errors, - "total_time": self._total_time, - "avg_time_per_text": avg_time, + "avg_time_per_text": self._total_time / self._total_encoded if self._total_encoded else 0.0, } + def get_statistics(self) -> dict: + return self.get_model_info() + + @property + def is_model_loaded(self) -> bool: + return True + def __repr__(self) -> str: return ( - f"EmbeddingAPIAdapter(model_name={self.model_name}, " - f"dimension={self._dimension or self.default_dimension}, " - f"total_encoded={self._total_encoded})" + f"EmbeddingAPIAdapter(dim={self._dimension or self.default_dimension}, " + f"detected={self._dimension_detected}, encoded={self._total_encoded})" ) def create_embedding_api_adapter( batch_size: int = 32, max_concurrent: int = 5, - default_dimension: int = 256, + default_dimension: int = 1024, enable_cache: bool = False, - model_name: str = "hash-v1", + model_name: str = "auto", retry_config: Optional[dict] = None, ) -> EmbeddingAPIAdapter: return 
EmbeddingAPIAdapter( diff --git a/plugins/A_memorix/core/retrieval/dual_path.py b/plugins/A_memorix/core/retrieval/dual_path.py index cfeb343c..6ed5e71a 100644 --- a/plugins/A_memorix/core/retrieval/dual_path.py +++ b/plugins/A_memorix/core/retrieval/dual_path.py @@ -285,10 +285,10 @@ class DualPathRetriever: relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k) logger.info( - "执行检索: query='%s...', strategy=%s, relation_intent=%s", - query[:50], - strategy.value, - relation_intent_ctx.get("enabled", False), + "执行检索: " + f"query='{query[:50]}...', " + f"strategy={strategy.value}, " + f"relation_intent={relation_intent_ctx.get('enabled', False)}" ) if temporal and not (query or "").strip(): @@ -1408,10 +1408,10 @@ class DualPathRetriever: return results logger.debug( - "relation_rerank_applied=1 relation_pair_groups=%s relation_pair_overflow_count=%s relation_pair_limit=%s", - len(ordered_groups), - len(overflow), - pair_limit, + "relation_rerank_applied=1 " + f"relation_pair_groups={len(ordered_groups)} " + f"relation_pair_overflow_count={len(overflow)} " + f"relation_pair_limit={pair_limit}" ) rebuilt = list(results) @@ -1455,9 +1455,9 @@ class DualPathRetriever: ) except asyncio.TimeoutError: logger.warning( - "metric.ppr_timeout_skip_count=1 timeout_s=%s entities=%s", - ppr_timeout_s, - len(entities), + "metric.ppr_timeout_skip_count=1 " + f"timeout_s={ppr_timeout_s} " + f"entities={len(entities)}" ) return results except Exception as e: diff --git a/plugins/A_memorix/core/retrieval/graph_relation_recall.py b/plugins/A_memorix/core/retrieval/graph_relation_recall.py index 3ce03b14..9af862f3 100644 --- a/plugins/A_memorix/core/retrieval/graph_relation_recall.py +++ b/plugins/A_memorix/core/retrieval/graph_relation_recall.py @@ -170,7 +170,7 @@ class GraphRelationRecallService: max_paths=self.config.max_paths, ) except Exception as e: - logger.debug("graph two-hop recall skipped: %s", e) + logger.debug(f"graph two-hop recall skipped: {e}") return for path_nodes in paths: @@ -210,7 +210,7 @@ class GraphRelationRecallService: limit=self.config.candidate_k, ) except Exception as e: - logger.debug("graph one-hop recall skipped: %s", e) + logger.debug(f"graph one-hop recall skipped: {e}") return self._append_relation_hashes( relation_hashes=relation_hashes, diff --git a/plugins/A_memorix/core/retrieval/sparse_bm25.py b/plugins/A_memorix/core/retrieval/sparse_bm25.py index 3b6f075d..1fef9f80 100644 --- a/plugins/A_memorix/core/retrieval/sparse_bm25.py +++ b/plugins/A_memorix/core/retrieval/sparse_bm25.py @@ -123,9 +123,8 @@ class SparseBM25Index: self._loaded = True self._prepare_tokenizer() logger.info( - "SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s", - self.config.tokenizer_mode, - self.config.mode, + "SparseBM25Index loaded: " + f"backend=fts5, tokenizer={self.config.tokenizer_mode}, mode={self.config.mode}" ) return True @@ -141,9 +140,9 @@ class SparseBM25Index: if user_dict: try: jieba.load_userdict(user_dict) # type: ignore[union-attr] - logger.info("已加载 jieba 用户词典: %s", user_dict) + logger.info(f"已加载 jieba 用户词典: {user_dict}") except Exception as e: - logger.warning("加载 jieba 用户词典失败: %s", e) + logger.warning(f"加载 jieba 用户词典失败: {e}") self._jieba_dict_loaded = True def _tokenize_jieba(self, text: str) -> List[str]: diff --git a/plugins/A_memorix/core/runtime/__init__.py b/plugins/A_memorix/core/runtime/__init__.py index fa222715..eece6d21 100644 --- a/plugins/A_memorix/core/runtime/__init__.py +++ b/plugins/A_memorix/core/runtime/__init__.py @@ 
-1,8 +1,16 @@ """SDK runtime exports for A_Memorix.""" +from .search_runtime_initializer import ( + SearchRuntimeBundle, + SearchRuntimeInitializer, + build_search_runtime, +) from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel __all__ = [ + "SearchRuntimeBundle", + "SearchRuntimeInitializer", + "build_search_runtime", "KernelSearchRequest", "SDKMemoryKernel", ] diff --git a/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py b/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py new file mode 100644 index 00000000..423b55c4 --- /dev/null +++ b/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py @@ -0,0 +1,268 @@ +"""Lifecycle bootstrap/teardown helpers extracted from plugin.py.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from typing import Any + +from src.common.logger import get_logger + +from ..embedding import create_embedding_api_adapter +from ..retrieval import SparseBM25Config, SparseBM25Index +from ..storage import ( + GraphStore, + MetadataStore, + QuantizationType, + SparseMatrixFormat, + VectorStore, +) +from ..utils.runtime_self_check import ensure_runtime_self_check +from ..utils.relation_write_service import RelationWriteService + +logger = get_logger("A_Memorix.LifecycleOrchestrator") + + +async def ensure_initialized(plugin: Any) -> None: + if plugin._initialized: + plugin._runtime_ready = plugin._check_storage_ready() + return + + async with plugin._init_lock: + if plugin._initialized: + plugin._runtime_ready = plugin._check_storage_ready() + return + + logger.info("A_Memorix 插件正在异步初始化存储组件...") + plugin._validate_runtime_config() + await initialize_storage_async(plugin) + report = await ensure_runtime_self_check(plugin, force=True) + if not bool(report.get("ok", False)): + logger.error( + "A_Memorix runtime self-check failed: " + f"{report.get('message', 'unknown')}; " + "建议执行 python plugins/A_memorix/scripts/runtime_self_check.py --json" + ) + + if plugin.graph_store and plugin.metadata_store: + relation_count = plugin.metadata_store.count_relations() + if relation_count > 0 and not plugin.graph_store.has_edge_hash_map(): + raise RuntimeError( + "检测到 relations 数据存在但 edge-hash-map 为空。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + + plugin._initialized = True + plugin._runtime_ready = plugin._check_storage_ready() + plugin._update_plugin_config() + logger.info("A_Memorix 插件异步初始化成功") + + +def start_background_tasks(plugin: Any) -> None: + """Start background tasks idempotently.""" + if not hasattr(plugin, "_episode_generation_task"): + plugin._episode_generation_task = None + + if ( + plugin.get_config("summarization.enabled", True) + and plugin.get_config("schedule.enabled", True) + and (plugin._scheduled_import_task is None or plugin._scheduled_import_task.done()) + ): + plugin._scheduled_import_task = asyncio.create_task(plugin._scheduled_import_loop()) + + if ( + plugin.get_config("advanced.enable_auto_save", True) + and (plugin._auto_save_task is None or plugin._auto_save_task.done()) + ): + plugin._auto_save_task = asyncio.create_task(plugin._auto_save_loop()) + + if ( + plugin.get_config("person_profile.enabled", True) + and (plugin._person_profile_refresh_task is None or plugin._person_profile_refresh_task.done()) + ): + plugin._person_profile_refresh_task = asyncio.create_task(plugin._person_profile_refresh_loop()) + + if plugin._memory_maintenance_task is None or plugin._memory_maintenance_task.done(): + plugin._memory_maintenance_task = 
asyncio.create_task(plugin._memory_maintenance_loop()) + + rv_cfg = plugin.get_config("retrieval.relation_vectorization", {}) or {} + if isinstance(rv_cfg, dict): + rv_enabled = bool(rv_cfg.get("enabled", False)) + rv_backfill = bool(rv_cfg.get("backfill_enabled", False)) + else: + rv_enabled = False + rv_backfill = False + if rv_enabled and rv_backfill and ( + plugin._relation_vector_backfill_task is None or plugin._relation_vector_backfill_task.done() + ): + plugin._relation_vector_backfill_task = asyncio.create_task(plugin._relation_vector_backfill_loop()) + + episode_task = getattr(plugin, "_episode_generation_task", None) + episode_loop = getattr(plugin, "_episode_generation_loop", None) + if ( + callable(episode_loop) + and bool(plugin.get_config("episode.enabled", True)) + and bool(plugin.get_config("episode.generation_enabled", True)) + and (episode_task is None or episode_task.done()) + ): + plugin._episode_generation_task = asyncio.create_task(episode_loop()) + + +async def cancel_background_tasks(plugin: Any) -> None: + """Cancel all background tasks and wait for cleanup.""" + tasks = [ + ("scheduled_import", plugin._scheduled_import_task), + ("auto_save", plugin._auto_save_task), + ("person_profile_refresh", plugin._person_profile_refresh_task), + ("memory_maintenance", plugin._memory_maintenance_task), + ("relation_vector_backfill", plugin._relation_vector_backfill_task), + ("episode_generation", getattr(plugin, "_episode_generation_task", None)), + ] + for _, task in tasks: + if task and not task.done(): + task.cancel() + + for name, task in tasks: + if not task: + continue + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning(f"后台任务 {name} 退出异常: {e}") + + plugin._scheduled_import_task = None + plugin._auto_save_task = None + plugin._person_profile_refresh_task = None + plugin._memory_maintenance_task = None + plugin._relation_vector_backfill_task = None + plugin._episode_generation_task = None + + +async def initialize_storage_async(plugin: Any) -> None: + """Initialize storage components asynchronously.""" + data_dir_str = plugin.get_config("storage.data_dir", "./data") + if data_dir_str.startswith("."): + plugin_dir = Path(__file__).resolve().parents[2] + data_dir = (plugin_dir / data_dir_str).resolve() + else: + data_dir = Path(data_dir_str) + + logger.info(f"A_Memorix 数据存储路径: {data_dir}") + data_dir.mkdir(parents=True, exist_ok=True) + + plugin.embedding_manager = create_embedding_api_adapter( + batch_size=plugin.get_config("embedding.batch_size", 32), + max_concurrent=plugin.get_config("embedding.max_concurrent", 5), + default_dimension=plugin.get_config("embedding.dimension", 1024), + model_name=plugin.get_config("embedding.model_name", "auto"), + retry_config=plugin.get_config("embedding.retry", {}), + ) + logger.info("嵌入 API 适配器初始化完成") + + try: + detected_dimension = await plugin.embedding_manager._detect_dimension() + logger.info(f"嵌入维度检测成功: {detected_dimension}") + except Exception as e: + logger.warning(f"嵌入维度检测失败: {e},使用默认值") + detected_dimension = plugin.embedding_manager.default_dimension + + quantization_str = plugin.get_config("embedding.quantization_type", "int8") + if str(quantization_str or "").strip().lower() != "int8": + raise ValueError("embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。") + quantization_type = QuantizationType.INT8 + + plugin.vector_store = VectorStore( + dimension=detected_dimension, + quantization_type=quantization_type, + data_dir=data_dir / "vectors", + ) + 
plugin.vector_store.min_train_threshold = plugin.get_config("embedding.min_train_threshold", 40) + logger.info( + "向量存储初始化完成(" + f"维度: {detected_dimension}, " + f"训练阈值: {plugin.vector_store.min_train_threshold})" + ) + + matrix_format_str = plugin.get_config("graph.sparse_matrix_format", "csr") + matrix_format_map = { + "csr": SparseMatrixFormat.CSR, + "csc": SparseMatrixFormat.CSC, + } + matrix_format = matrix_format_map.get(matrix_format_str, SparseMatrixFormat.CSR) + + plugin.graph_store = GraphStore( + matrix_format=matrix_format, + data_dir=data_dir / "graph", + ) + logger.info("图存储初始化完成") + + plugin.metadata_store = MetadataStore(data_dir=data_dir / "metadata") + plugin.metadata_store.connect() + logger.info("元数据存储初始化完成") + + plugin.relation_write_service = RelationWriteService( + metadata_store=plugin.metadata_store, + graph_store=plugin.graph_store, + vector_store=plugin.vector_store, + embedding_manager=plugin.embedding_manager, + ) + logger.info("关系写入服务初始化完成") + + sparse_cfg_raw = plugin.get_config("retrieval.sparse", {}) or {} + if not isinstance(sparse_cfg_raw, dict): + sparse_cfg_raw = {} + try: + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + except Exception as e: + logger.warning(f"sparse 配置非法,回退默认配置: {e}") + sparse_cfg = SparseBM25Config() + plugin.sparse_index = SparseBM25Index( + metadata_store=plugin.metadata_store, + config=sparse_cfg, + ) + logger.info( + "稀疏检索组件初始化完成: " + f"enabled={sparse_cfg.enabled}, " + f"lazy_load={sparse_cfg.lazy_load}, " + f"mode={sparse_cfg.mode}, " + f"tokenizer={sparse_cfg.tokenizer_mode}" + ) + if sparse_cfg.enabled and not sparse_cfg.lazy_load: + plugin.sparse_index.ensure_loaded() + + if plugin.vector_store.has_data(): + try: + plugin.vector_store.load() + logger.info(f"向量数据已加载,共 {plugin.vector_store.num_vectors} 个向量") + except Exception as e: + logger.warning(f"加载向量数据失败: {e}") + + try: + warmup_summary = plugin.vector_store.warmup_index(force_train=True) + if warmup_summary.get("ok"): + logger.info( + "向量索引预热完成: " + f"trained={warmup_summary.get('trained')}, " + f"index_ntotal={warmup_summary.get('index_ntotal')}, " + f"fallback_ntotal={warmup_summary.get('fallback_ntotal')}, " + f"bin_count={warmup_summary.get('bin_count')}, " + f"duration_ms={float(warmup_summary.get('duration_ms', 0.0)):.2f}" + ) + else: + logger.warning( + "向量索引预热失败,继续启用 sparse 降级路径: " + f"{warmup_summary.get('error', 'unknown')}" + ) + except Exception as e: + logger.warning(f"向量索引预热异常,继续启用 sparse 降级路径: {e}") + + if plugin.graph_store.has_data(): + try: + plugin.graph_store.load() + logger.info(f"图数据已加载,共 {plugin.graph_store.num_nodes} 个节点") + except Exception as e: + logger.warning(f"加载图数据失败: {e}") + + logger.info(f"知识库数据目录: {data_dir}") diff --git a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py index 7c8f9213..439afd3d 100644 --- a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py @@ -1,26 +1,33 @@ from __future__ import annotations +import asyncio +import json +import pickle import time +import uuid from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence +from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Sequence from src.common.logger import get_logger from ..embedding import create_embedding_api_adapter -from ..retrieval import ( - DualPathRetriever, - DualPathRetrieverConfig, - RetrievalResult, - SparseBM25Config, - SparseBM25Index, - 
TemporalQueryOptions, -) +from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore from ..utils.aggregate_query_service import AggregateQueryService from ..utils.episode_retrieval_service import EpisodeRetrievalService -from ..utils.hash import normalize_text +from ..utils.episode_segmentation_service import EpisodeSegmentationService +from ..utils.episode_service import EpisodeService +from ..utils.hash import compute_hash, normalize_text +from ..utils.person_profile_service import PersonProfileService from ..utils.relation_write_service import RelationWriteService +from ..utils.retrieval_tuning_manager import RetrievalTuningManager +from ..utils.runtime_self_check import run_embedding_runtime_self_check +from ..utils.search_execution_service import SearchExecutionRequest, SearchExecutionService +from ..utils.summary_importer import SummaryImporter +from ..utils.time_parser import format_timestamp, parse_query_datetime_to_timestamp +from ..utils.web_import_manager import ImportTaskManager +from .search_runtime_initializer import SearchRuntimeBundle, build_search_runtime logger = get_logger("A_Memorix.SDKMemoryKernel") @@ -32,8 +39,76 @@ class KernelSearchRequest: mode: str = "hybrid" chat_id: str = "" person_id: str = "" - time_start: Optional[float] = None - time_end: Optional[float] = None + time_start: Optional[str | float] = None + time_end: Optional[str | float] = None + respect_filter: bool = True + user_id: str = "" + group_id: str = "" + + +@dataclass +class _NormalizedSearchTimeWindow: + numeric_start: Optional[float] = None + numeric_end: Optional[float] = None + query_start: Optional[str] = None + query_end: Optional[str] = None + + +class _KernelRuntimeFacade: + def __init__(self, kernel: "SDKMemoryKernel") -> None: + self._kernel = kernel + self.config = kernel.config + self._plugin_config = kernel.config + self._runtime_self_check_report: Dict[str, Any] = {} + + def get_config(self, key: str, default: Any = None) -> Any: + return self._kernel._cfg(key, default) + + def is_runtime_ready(self) -> bool: + return self._kernel.is_runtime_ready() + + def is_chat_enabled(self, stream_id: str, group_id: str | None = None, user_id: str | None = None) -> bool: + return self._kernel.is_chat_enabled(stream_id=stream_id, group_id=group_id, user_id=user_id) + + async def reinforce_access(self, relation_hashes: Sequence[str]) -> None: + if self._kernel.metadata_store is None: + return + hashes = [str(item or "").strip() for item in relation_hashes if str(item or "").strip()] + if not hashes: + return + self._kernel.metadata_store.reinforce_relations(hashes) + self._kernel._last_maintenance_at = time.time() + + async def execute_request_with_dedup( + self, + request_key: str, + executor: Callable[[], Awaitable[Dict[str, Any]]], + ) -> tuple[bool, Dict[str, Any]]: + return await self._kernel.execute_request_with_dedup(request_key, executor) + + @property + def vector_store(self) -> Optional[VectorStore]: + return self._kernel.vector_store + + @property + def graph_store(self) -> Optional[GraphStore]: + return self._kernel.graph_store + + @property + def metadata_store(self) -> Optional[MetadataStore]: + return self._kernel.metadata_store + + @property + def embedding_manager(self): + return self._kernel.embedding_manager + + @property + def sparse_index(self): + return self._kernel.sparse_index + + @property + def relation_write_service(self) -> 
Optional[RelationWriteService]: + return self._kernel.relation_write_service class SDKMemoryKernel: @@ -43,7 +118,7 @@ class SDKMemoryKernel: storage_cfg = self._cfg("storage", {}) or {} data_dir = str(storage_cfg.get("data_dir", "./data") or "./data") self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir) - self.embedding_dimension = max(32, int(self._cfg("embedding.dimension", 256))) + self.embedding_dimension = max(1, int(self._cfg("embedding.dimension", 1024))) self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False)) self.embedding_manager = None @@ -51,16 +126,30 @@ class SDKMemoryKernel: self.graph_store: Optional[GraphStore] = None self.metadata_store: Optional[MetadataStore] = None self.relation_write_service: Optional[RelationWriteService] = None - self.sparse_index = None - self.retriever: Optional[DualPathRetriever] = None + self.sparse_index: Optional[SparseBM25Index] = None + self.retriever = None + self.threshold_filter = None self.episode_retriever: Optional[EpisodeRetrievalService] = None self.aggregate_query_service: Optional[AggregateQueryService] = None + self.person_profile_service: Optional[PersonProfileService] = None + self.episode_segmentation_service: Optional[EpisodeSegmentationService] = None + self.episode_service: Optional[EpisodeService] = None + self.summary_importer: Optional[SummaryImporter] = None + self.import_task_manager: Optional[ImportTaskManager] = None + self.retrieval_tuning_manager: Optional[RetrievalTuningManager] = None + self._runtime_bundle: Optional[SearchRuntimeBundle] = None + self._runtime_facade = _KernelRuntimeFacade(self) self._initialized = False self._last_maintenance_at: Optional[float] = None + self._request_dedup_tasks: Dict[str, asyncio.Task] = {} + self._background_tasks: Dict[str, asyncio.Task] = {} + self._background_lock = asyncio.Lock() + self._background_stopping = False + self._active_person_timestamps: Dict[str, float] = {} def _cfg(self, key: str, default: Any = None) -> Any: current: Any = self.config - if key in {"storage", "embedding", "retrieval"} and isinstance(current, dict): + if key in {"storage", "embedding", "retrieval", "graph", "episode", "web", "advanced", "threshold", "summarization"} and isinstance(current, dict): return current.get(key, default) for part in key.split("."): if isinstance(current, dict) and part in current: @@ -69,34 +158,183 @@ class SDKMemoryKernel: return default return current + def _set_cfg(self, key: str, value: Any) -> None: + current: Dict[str, Any] = self.config + parts = [part for part in str(key or "").split(".") if part] + if not parts: + return + for part in parts[:-1]: + next_value = current.get(part) + if not isinstance(next_value, dict): + next_value = {} + current[part] = next_value + current = next_value + current[parts[-1]] = value + + def _build_runtime_config(self) -> Dict[str, Any]: + runtime_config = dict(self.config) + runtime_config.update( + { + "vector_store": self.vector_store, + "graph_store": self.graph_store, + "metadata_store": self.metadata_store, + "embedding_manager": self.embedding_manager, + "sparse_index": self.sparse_index, + "relation_write_service": self.relation_write_service, + "plugin_instance": self._runtime_facade, + } + ) + return runtime_config + + def is_runtime_ready(self) -> bool: + return bool( + self._initialized + and self.vector_store is not None + and self.graph_store is not None + and self.metadata_store is not None + and self.embedding_manager 
is not None + and self.retriever is not None + ) + + def is_chat_enabled(self, stream_id: str, group_id: str | None = None, user_id: str | None = None) -> bool: + filter_config = self._cfg("filter", {}) or {} + if not isinstance(filter_config, dict) or not filter_config: + return True + + if not bool(filter_config.get("enabled", True)): + return True + + mode = str(filter_config.get("mode", "blacklist") or "blacklist").strip().lower() + patterns = filter_config.get("chats") or [] + if not isinstance(patterns, list): + patterns = [] + + if not patterns: + return mode == "blacklist" + + stream_token = str(stream_id or "").strip() + group_token = str(group_id or "").strip() + user_token = str(user_id or "").strip() + candidates = {token for token in (stream_token, group_token, user_token) if token} + + matched = False + for raw_pattern in patterns: + pattern = str(raw_pattern or "").strip() + if not pattern: + continue + if ":" in pattern: + prefix, value = pattern.split(":", 1) + prefix = prefix.strip().lower() + value = value.strip() + if prefix == "group" and value and value == group_token: + matched = True + elif prefix in {"user", "private"} and value and value == user_token: + matched = True + elif prefix == "stream" and value and value == stream_token: + matched = True + elif pattern in candidates: + matched = True + + if matched: + break + + if mode == "blacklist": + return not matched + return matched + + def _is_chat_filtered( + self, + *, + respect_filter: bool, + stream_id: str = "", + group_id: str = "", + user_id: str = "", + ) -> bool: + if not bool(respect_filter): + return False + + stream_token = str(stream_id or "").strip() + group_token = str(group_id or "").strip() + user_token = str(user_id or "").strip() + if not (stream_token or group_token or user_token): + return False + return not self.is_chat_enabled(stream_token, group_token, user_token) + + def _stored_vector_dimension(self) -> Optional[int]: + meta_path = self.data_dir / "vectors" / "vectors_metadata.pkl" + if not meta_path.exists(): + return None + try: + with open(meta_path, "rb") as handle: + meta = pickle.load(handle) + except Exception as exc: + logger.warning(f"读取向量元数据失败,将回退到 runtime self-check: {exc}") + return None + try: + value = int(meta.get("dimension") or 0) + except Exception: + return None + return value if value > 0 else None + + def _vector_mismatch_error(self, *, stored_dimension: int, detected_dimension: int) -> str: + return ( + "检测到现有向量库与当前 embedding 输出维度不一致:" + f"stored={stored_dimension}, encoded={detected_dimension}。" + " 当前版本不会兼容 hash 时代或其他维度的旧向量,请改回原 embedding 配置," + "或执行重嵌入/重建向量。" + ) + async def initialize(self) -> None: if self._initialized: + await self._start_background_tasks() return + self.data_dir.mkdir(parents=True, exist_ok=True) self.embedding_manager = create_embedding_api_adapter( batch_size=int(self._cfg("embedding.batch_size", 32)), max_concurrent=int(self._cfg("embedding.max_concurrent", 5)), default_dimension=self.embedding_dimension, - model_name=str(self._cfg("embedding.model_name", "hash-v1")), + enable_cache=bool(self._cfg("embedding.enable_cache", False)), + model_name=str(self._cfg("embedding.model_name", "auto") or "auto"), retry_config=self._cfg("embedding.retry", {}) or {}, ) - self.embedding_dimension = int(await self.embedding_manager._detect_dimension()) + detected_dimension = int(await self.embedding_manager._detect_dimension()) + self.embedding_dimension = detected_dimension + + stored_dimension = self._stored_vector_dimension() + if stored_dimension is 
not None and stored_dimension != detected_dimension: + raise RuntimeError( + self._vector_mismatch_error( + stored_dimension=stored_dimension, + detected_dimension=detected_dimension, + ) + ) + + matrix_format = str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower() + graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR + self.vector_store = VectorStore( - dimension=self.embedding_dimension, + dimension=detected_dimension, quantization_type=QuantizationType.INT8, data_dir=self.data_dir / "vectors", ) - self.graph_store = GraphStore(matrix_format=SparseMatrixFormat.CSR, data_dir=self.data_dir / "graph") + self.graph_store = GraphStore(matrix_format=graph_format, data_dir=self.data_dir / "graph") self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") self.metadata_store.connect() + if self.vector_store.has_data(): self.vector_store.load() self.vector_store.warmup_index(force_train=True) if self.graph_store.has_data(): self.graph_store.load() - sparse_cfg = self._cfg("retrieval.sparse", {}) or {} - self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=SparseBM25Config(**sparse_cfg)) + sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} + try: + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + except Exception as exc: + logger.warning(f"sparse 配置非法,回退默认: {exc}") + sparse_cfg = SparseBM25Config() + self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=sparse_cfg) if getattr(self.sparse_index.config, "enabled", False): self.sparse_index.ensure_loaded() @@ -106,39 +344,133 @@ class SDKMemoryKernel: vector_store=self.vector_store, embedding_manager=self.embedding_manager, ) - self.retriever = DualPathRetriever( + + runtime_config = self._build_runtime_config() + self._runtime_bundle = build_search_runtime( + plugin_config=runtime_config, + logger_obj=logger, + owner_tag="sdk_kernel", + log_prefix="[sdk]", + ) + if not self._runtime_bundle.ready: + raise RuntimeError(self._runtime_bundle.error or "检索运行时初始化失败") + + self.retriever = self._runtime_bundle.retriever + self.threshold_filter = self._runtime_bundle.threshold_filter + self.sparse_index = self._runtime_bundle.sparse_index or self.sparse_index + + runtime_config = self._build_runtime_config() + self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) + self.aggregate_query_service = AggregateQueryService(plugin_config=runtime_config) + self.person_profile_service = PersonProfileService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + sparse_index=self.sparse_index, + plugin_config=runtime_config, + retriever=self.retriever, + ) + self.episode_segmentation_service = EpisodeSegmentationService(plugin_config=runtime_config) + self.episode_service = EpisodeService( + metadata_store=self.metadata_store, + plugin_config=runtime_config, + segmentation_service=self.episode_segmentation_service, + ) + self.summary_importer = SummaryImporter( vector_store=self.vector_store, graph_store=self.graph_store, metadata_store=self.metadata_store, embedding_manager=self.embedding_manager, - sparse_index=self.sparse_index, - config=DualPathRetrieverConfig( - top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 24)), - top_k_relations=int(self._cfg("retrieval.top_k_relations", 12)), - top_k_final=int(self._cfg("retrieval.top_k_final", 10)), - 
alpha=float(self._cfg("retrieval.alpha", 0.5)), - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), - ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), - enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), - sparse=sparse_cfg, - fusion=self._cfg("retrieval.fusion", {}) or {}, - graph_recall=self._cfg("retrieval.search.graph_recall", {}) or {}, - relation_intent=self._cfg("retrieval.search.relation_intent", {}) or {}, - ), + plugin_config=runtime_config, ) - self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) - self.aggregate_query_service = AggregateQueryService(plugin_config=self.config) + self.import_task_manager = ImportTaskManager(self._runtime_facade) + self.retrieval_tuning_manager = RetrievalTuningManager( + self._runtime_facade, + import_write_blocked_provider=self.import_task_manager.is_write_blocked, + ) + + report = await run_embedding_runtime_self_check( + config=runtime_config, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + sample_text="A_Memorix runtime self check", + ) + self._runtime_facade._runtime_self_check_report = dict(report) + if not bool(report.get("ok", False)): + message = str(report.get("message", "runtime self-check failed") or "runtime self-check failed") + raise RuntimeError(f"{message};请改回原 embedding 配置,或执行重嵌入/重建向量。") + self._initialized = True + await self._start_background_tasks() + + async def shutdown(self) -> None: + await self._stop_background_tasks() + if self.import_task_manager is not None: + try: + await self.import_task_manager.shutdown() + except Exception as exc: + logger.warning(f"关闭导入任务管理器失败: {exc}") + if self.retrieval_tuning_manager is not None: + try: + await self.retrieval_tuning_manager.shutdown() + except Exception as exc: + logger.warning(f"关闭调优任务管理器失败: {exc}") + self.close() def close(self) -> None: - if self.vector_store is not None: - self.vector_store.save() - if self.graph_store is not None: - self.graph_store.save() - if self.metadata_store is not None: - self.metadata_store.close() - self._initialized = False + try: + self._persist() + finally: + if self.metadata_store is not None: + self.metadata_store.close() + self._initialized = False + self._request_dedup_tasks.clear() + self._runtime_facade._runtime_self_check_report = {} + self._background_tasks.clear() + self._active_person_timestamps.clear() + + async def execute_request_with_dedup( + self, + request_key: str, + executor: Callable[[], Awaitable[Dict[str, Any]]], + ) -> tuple[bool, Dict[str, Any]]: + token = str(request_key or "").strip() + if not token: + return False, await executor() + + existing = self._request_dedup_tasks.get(token) + if existing is not None: + return True, await existing + + task = asyncio.create_task(executor()) + self._request_dedup_tasks[token] = task + try: + payload = await task + return False, payload + finally: + current = self._request_dedup_tasks.get(token) + if current is task: + self._request_dedup_tasks.pop(token, None) + + async def summarize_chat_stream( + self, + *, + chat_id: str, + context_length: Optional[int] = None, + include_personality: Optional[bool] = None, + ) -> Dict[str, Any]: + await self.initialize() + assert self.summary_importer + success, detail = await self.summary_importer.import_from_stream( + stream_id=str(chat_id or "").strip(), + context_length=context_length, + include_personality=include_personality, + ) + if 
success: + await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])]) + self._persist() + return {"success": bool(success), "detail": detail} async def ingest_summary( self, @@ -151,9 +483,35 @@ class SDKMemoryKernel: time_end: Optional[float] = None, tags: Optional[Sequence[str]] = None, metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", ) -> Dict[str, Any]: + external_token = str(external_id or "").strip() or compute_hash(f"chat_summary:{chat_id}:{text}") + if self._is_chat_filtered( + respect_filter=respect_filter, + stream_id=chat_id, + group_id=group_id, + user_id=user_id, + ): + return { + "success": True, + "stored_ids": [], + "skipped_ids": [external_token], + "detail": "chat_filtered", + } + summary_meta = dict(metadata or {}) summary_meta.setdefault("kind", "chat_summary") + if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): + result = await self.summarize_chat_stream( + chat_id=chat_id, + context_length=self._optional_int(summary_meta.get("context_length")), + include_personality=summary_meta.get("include_personality"), + ) + result.setdefault("external_id", external_id) + result.setdefault("chat_id", chat_id) + return result return await self.ingest_text( external_id=external_id, source_type="chat_summary", @@ -164,6 +522,9 @@ class SDKMemoryKernel: time_end=time_end, tags=tags, metadata=summary_meta, + respect_filter=respect_filter, + user_id=user_id, + group_id=group_id, ) async def ingest_text( @@ -182,15 +543,42 @@ class SDKMemoryKernel: metadata: Optional[Dict[str, Any]] = None, entities: Optional[Sequence[str]] = None, relations: Optional[Sequence[Dict[str, Any]]] = None, + respect_filter: bool = True, + user_id: str = "", + group_id: str = "", ) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store and self.vector_store and self.graph_store and self.embedding_manager - assert self.relation_write_service content = normalize_text(text) + external_token = str(external_id or "").strip() or compute_hash(f"{source_type}:{chat_id}:{content}") + if self._is_chat_filtered( + respect_filter=respect_filter, + stream_id=chat_id, + group_id=group_id, + user_id=user_id, + ): + return { + "success": True, + "stored_ids": [], + "skipped_ids": [external_token], + "detail": "chat_filtered", + } + + await self.initialize() + assert self.metadata_store is not None + assert self.vector_store is not None + assert self.graph_store is not None + assert self.embedding_manager is not None + assert self.relation_write_service is not None + if not content: - return {"stored_ids": [], "skipped_ids": [external_id], "reason": "empty_text"} - if ref := self.metadata_store.get_external_memory_ref(external_id): - return {"stored_ids": [], "skipped_ids": [str(ref.get("paragraph_hash", ""))], "reason": "exists"} + return {"stored_ids": [], "skipped_ids": [external_token], "reason": "empty_text"} + + existing_ref = self.metadata_store.get_external_memory_ref(external_token) + if existing_ref: + return { + "stored_ids": [], + "skipped_ids": [str(existing_ref.get("paragraph_hash", "") or "")], + "reason": "exists", + } person_tokens = self._tokens(person_ids) participant_tokens = self._tokens(participants) @@ -199,7 +587,7 @@ class SDKMemoryKernel: paragraph_meta = dict(metadata or {}) paragraph_meta.update( { - "external_id": external_id, + "external_id": external_token, "source_type": str(source_type or "").strip(), "chat_id": str(chat_id or "").strip(), 
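+ # person_ids / participants / tags recorded below are normalized token lists (via self._tokens),
+ # so downstream person and chat filtering can rely on consistent values.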
"person_ids": person_tokens, @@ -207,146 +595,303 @@ class SDKMemoryKernel: "tags": self._tokens(tags), } ) + paragraph_hash = self.metadata_store.add_paragraph( content=content, source=source, metadata=paragraph_meta, - knowledge_type="factual" if source_type == "person_fact" else "narrative" if source_type == "chat_summary" else "mixed", + knowledge_type=self._resolve_knowledge_type(source_type), time_meta=self._time_meta(timestamp, time_start, time_end), ) embedding = await self.embedding_manager.encode(content) self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash]) + for name in entity_tokens: self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash) stored_relations: List[str] = [] for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]: - s = str(row.get("subject", "") or "").strip() - p = str(row.get("predicate", "") or "").strip() - o = str(row.get("object", "") or "").strip() - if not (s and p and o): + subject = str(row.get("subject", "") or "").strip() + predicate = str(row.get("predicate", "") or "").strip() + obj = str(row.get("object", "") or "").strip() + if not (subject and predicate and obj): continue result = await self.relation_write_service.upsert_relation_with_vector( - subject=s, - predicate=p, - obj=o, + subject=subject, + predicate=predicate, + obj=obj, confidence=float(row.get("confidence", 1.0) or 1.0), source_paragraph=paragraph_hash, - metadata={"external_id": external_id, "source_type": source_type}, + metadata=row.get("metadata") if isinstance(row.get("metadata"), dict) else {"external_id": external_token, "source_type": source_type}, write_vector=self.relation_vectors_enabled, ) self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value) stored_relations.append(result.hash_value) self.metadata_store.upsert_external_memory_ref( - external_id=external_id, + external_id=external_token, paragraph_hash=paragraph_hash, source_type=source_type, metadata={"chat_id": chat_id, "person_ids": person_tokens}, ) + self.metadata_store.enqueue_episode_pending(paragraph_hash, source=source) self._persist() - self.rebuild_episodes_for_sources([source]) + await self.process_episode_pending_batch( + limit=max(1, int(self._cfg("episode.pending_batch_size", 12))), + max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3))), + ) for person_id in person_tokens: + self._mark_person_active(person_id) await self.refresh_person_profile(person_id) return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []} - async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: + async def process_episode_pending_batch(self, *, limit: int = 20, max_retry: int = 3) -> Dict[str, Any]: await self.initialize() - assert self.retriever and self.episode_retriever and self.aggregate_query_service + assert self.metadata_store is not None + assert self.episode_service is not None + + pending_rows = self.metadata_store.fetch_episode_pending_batch(limit=max(1, int(limit)), max_retry=max(1, int(max_retry))) + if not pending_rows: + return {"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0} + + source_to_hashes: Dict[str, List[str]] = {} + pending_hashes = [str(row.get("paragraph_hash", "") or "").strip() for row in pending_rows if str(row.get("paragraph_hash", "") or "").strip()] + for row in pending_rows: + paragraph_hash = str(row.get("paragraph_hash", "") or "").strip() + source = str(row.get("source", "") or "").strip() + if not paragraph_hash or not source: + 
continue + source_to_hashes.setdefault(source, []).append(paragraph_hash) + + if pending_hashes: + self.metadata_store.mark_episode_pending_running(pending_hashes) + + result = await self.episode_service.process_pending_rows(pending_rows) + done_hashes = [str(item or "").strip() for item in result.get("done_hashes", []) if str(item or "").strip()] + failed_hashes = { + str(hash_value or "").strip(): str(error or "").strip() + for hash_value, error in (result.get("failed_hashes", {}) or {}).items() + if str(hash_value or "").strip() + } + + if done_hashes: + self.metadata_store.mark_episode_pending_done(done_hashes) + for hash_value, error in failed_hashes.items(): + self.metadata_store.mark_episode_pending_failed(hash_value, error) + + untouched = [hash_value for hash_value in pending_hashes if hash_value not in set(done_hashes) and hash_value not in failed_hashes] + for hash_value in untouched: + self.metadata_store.mark_episode_pending_failed(hash_value, "episode processing finished without explicit status") + + for source, paragraph_hashes in source_to_hashes.items(): + counts = self.metadata_store.get_episode_pending_status_counts(source) + if counts.get("failed", 0) > 0: + source_error = next( + ( + failed_hashes.get(hash_value) + for hash_value in paragraph_hashes + if failed_hashes.get(hash_value) + ), + "episode pending source contains failed rows", + ) + self.metadata_store.mark_episode_source_failed(source, str(source_error or "episode pending source contains failed rows")) + elif counts.get("pending", 0) == 0 and counts.get("running", 0) == 0: + self.metadata_store.mark_episode_source_done(source) + + self._persist() + return { + "processed": len(done_hashes) + len(failed_hashes), + "episode_count": int(result.get("episode_count") or 0), + "fallback_count": int(result.get("fallback_count") or 0), + "failed": len(failed_hashes) + len(untouched), + "group_count": int(result.get("group_count") or 0), + "missing_count": int(result.get("missing_count") or 0), + } + + async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: + if self._is_chat_filtered( + respect_filter=request.respect_filter, + stream_id=request.chat_id, + group_id=request.group_id, + user_id=request.user_id, + ): + return {"summary": "", "hits": [], "filtered": True} + + await self.initialize() + assert self.retriever is not None + assert self.episode_retriever is not None + assert self.aggregate_query_service is not None + mode = str(request.mode or "hybrid").strip().lower() or "hybrid" - clean_query = str(request.query or "").strip() + query = str(request.query or "").strip() limit = max(1, int(request.limit or 5)) - temporal = self._temporal(request) + try: + time_window = self._normalize_search_time_window(request.time_start, request.time_end) + except ValueError as exc: + return {"summary": "", "hits": [], "error": str(exc)} + if mode == "episode": rows = await self.episode_retriever.query( - query=clean_query, + query=query, top_k=limit, - time_from=request.time_start, - time_to=request.time_end, + time_from=time_window.numeric_start, + time_to=time_window.numeric_end, + person=request.person_id or None, source=self._chat_source(request.chat_id), ) hits = [self._episode_hit(row) for row in rows] return {"summary": self._summary(hits), "hits": hits} + if mode == "aggregate": payload = await self.aggregate_query_service.execute( - query=clean_query, + query=query, top_k=limit, mix=True, mix_top_k=limit, - time_from=str(request.time_start) if request.time_start is not None else None, - 
time_to=str(request.time_end) if request.time_end is not None else None, - search_runner=lambda: self._aggregate_search(clean_query, limit, temporal), - time_runner=lambda: self._aggregate_time(clean_query, limit, temporal), - episode_runner=lambda: self._aggregate_episode(clean_query, limit, request), + time_from=time_window.query_start, + time_to=time_window.query_end, + search_runner=lambda: self._aggregate_search(query, limit, request), + time_runner=lambda: self._aggregate_time(query, limit, request, time_window), + episode_runner=lambda: self._aggregate_episode(query, limit, request, time_window), ) hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)] for item in hits: item.setdefault("metadata", {}) - return {"summary": self._summary(hits), "hits": hits} - results = await self.retriever.retrieve(query=clean_query, top_k=limit, temporal=temporal) - hits = [self._retrieval_hit(item) for item in results] - return {"summary": self._summary(self._filter_hits(hits, request.person_id)), "hits": self._filter_hits(hits, request.person_id)} + filtered = self._filter_hits(hits, request.person_id) + return {"summary": self._summary(filtered), "hits": filtered} + + query_type = "search" if mode in {"search", "semantic"} else mode + runtime_config = self._build_runtime_config() + result = await SearchExecutionService.execute( + retriever=self.retriever, + threshold_filter=self.threshold_filter, + plugin_config=runtime_config, + request=SearchExecutionRequest( + caller="sdk_memory_kernel", + stream_id=str(request.chat_id or "") or None, + group_id=str(request.group_id or "") or None, + user_id=str(request.user_id or "") or None, + query_type=query_type, + query=query, + top_k=limit, + time_from=time_window.query_start, + time_to=time_window.query_end, + person=str(request.person_id or "") or None, + source=self._chat_source(request.chat_id), + use_threshold=True, + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ), + enforce_chat_filter=bool(request.respect_filter), + reinforce_access=True, + ) + if not result.success: + return {"summary": "", "hits": [], "error": result.error} + if result.chat_filtered: + return {"summary": "", "hits": [], "filtered": True} + + hits = [self._retrieval_result_hit(item) for item in result.results] + filtered = self._filter_hits(hits, request.person_id) + return {"summary": self._summary(filtered), "hits": filtered} async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: - _ = chat_id + del chat_id await self.initialize() - assert self.metadata_store - snapshot = self.metadata_store.get_latest_person_profile_snapshot(person_id) or await self.refresh_person_profile(person_id, limit=limit) + assert self.metadata_store is not None + assert self.person_profile_service is not None + self._mark_person_active(person_id) + profile = await self.person_profile_service.query_person_profile( + person_id=person_id, + top_k=max(4, int(limit or 10)), + source_note="sdk_memory_kernel.get_person_profile", + ) + if not profile.get("success"): + return {"summary": "", "traits": [], "evidence": []} + evidence = [] - for hash_value in snapshot.get("evidence_ids", [])[: max(1, int(limit))]: + for hash_value in profile.get("evidence_ids", [])[: max(1, int(limit))]: paragraph = self.metadata_store.get_paragraph(hash_value) if paragraph is not None: - evidence.append({"hash": hash_value, "content": str(paragraph.get("content", "") or "")[:220], "metadata": paragraph.get("metadata", {}) or 
{}}) - text = str(snapshot.get("profile_text", "") or "").strip() + evidence.append( + { + "hash": hash_value, + "content": str(paragraph.get("content", "") or "")[:220], + "metadata": paragraph.get("metadata", {}) or {}, + "type": "paragraph", + } + ) + continue + + relation = self.metadata_store.get_relation(hash_value) + if relation is not None: + evidence.append( + { + "hash": hash_value, + "content": " ".join( + [ + str(relation.get("subject", "") or "").strip(), + str(relation.get("predicate", "") or "").strip(), + str(relation.get("object", "") or "").strip(), + ] + ).strip(), + "metadata": { + "confidence": relation.get("confidence"), + "source_paragraph": relation.get("source_paragraph"), + }, + "type": "relation", + } + ) + + text = str(profile.get("profile_text", "") or "").strip() traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8] - return {"summary": text, "traits": traits, "evidence": evidence} + return { + "summary": text, + "traits": traits, + "evidence": evidence, + "person_id": str(profile.get("person_id", "") or person_id), + "person_name": str(profile.get("person_name", "") or ""), + "profile_source": str(profile.get("profile_source", "") or "auto_snapshot"), + "has_manual_override": bool(profile.get("has_manual_override", False)), + } - async def refresh_person_profile(self, person_id: str, limit: int = 10) -> Dict[str, Any]: + async def refresh_person_profile(self, person_id: str, limit: int = 10, *, mark_active: bool = True) -> Dict[str, Any]: await self.initialize() - assert self.metadata_store - rows = self.metadata_store.query( - """ - SELECT DISTINCT p.* - FROM paragraphs p - JOIN paragraph_entities pe ON pe.paragraph_hash = p.hash - JOIN entities e ON e.hash = pe.entity_hash - WHERE e.name = ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - ORDER BY COALESCE(p.event_time_end, p.event_time_start, p.event_time, p.updated_at, p.created_at) DESC - LIMIT ? 
- """, - (person_id, max(1, int(limit)) * 3), - ) - evidence_ids = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] - vector_evidence = [{"hash": str(row.get("hash", "") or ""), "type": "paragraph", "score": 0.0, "content": str(row.get("content", "") or "")[:220], "metadata": row.get("metadata", {}) or {}} for row in rows[: max(1, int(limit))]] - relation_edges = [{"hash": str(row.get("hash", "") or ""), "subject": str(row.get("subject", "") or ""), "predicate": str(row.get("predicate", "") or ""), "object": str(row.get("object", "") or ""), "confidence": float(row.get("confidence", 1.0) or 1.0)} for row in self.metadata_store.get_relations(subject=person_id)[:limit]] - if relation_edges: - profile_text = "\n".join(f"{item['subject']} {item['predicate']} {item['object']}" for item in relation_edges[:6]) - elif vector_evidence: - profile_text = "\n".join(f"- {item['content']}" for item in vector_evidence[:6]) - else: - profile_text = "暂无稳定画像证据。" - return self.metadata_store.upsert_person_profile_snapshot( + assert self.person_profile_service + if mark_active: + self._mark_person_active(person_id) + profile = await self.person_profile_service.query_person_profile( person_id=person_id, - profile_text=profile_text, - aliases=[person_id], - relation_edges=relation_edges, - vector_evidence=vector_evidence, - evidence_ids=evidence_ids[: max(1, int(limit))], - expires_at=time.time() + 6 * 3600, - source_note="sdk_memory_kernel", + top_k=max(4, int(limit or 10)), + force_refresh=True, + source_note="sdk_memory_kernel.refresh_person_profile", ) + return profile if isinstance(profile, dict) else {} - async def maintain_memory(self, *, action: str, target: str, hours: Optional[float] = None, reason: str = "") -> Dict[str, Any]: - _ = reason + async def maintain_memory( + self, + *, + action: str, + target: str = "", + hours: Optional[float] = None, + reason: str = "", + limit: int = 50, + ) -> Dict[str, Any]: + del reason await self.initialize() assert self.metadata_store - hashes = self._resolve_relation_hashes(target) + act = str(action or "").strip().lower() + if act == "recycle_bin": + items = self.metadata_store.get_deleted_relations(limit=max(1, int(limit or 50))) + return {"success": True, "items": items, "count": len(items)} + + hashes = self._resolve_deleted_relation_hashes(target) if act == "restore" else self._resolve_relation_hashes(target) if not hashes: return {"success": False, "detail": "未命中可维护关系"} - act = str(action or "").strip().lower() + if act == "reinforce": self.metadata_store.reinforce_relations(hashes) + elif act == "freeze": + self.metadata_store.mark_relations_inactive(hashes) + self._rebuild_graph_from_metadata() elif act == "protect": ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0 self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0) @@ -354,72 +899,679 @@ class SDKMemoryKernel: restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value)) if restored <= 0: return {"success": False, "detail": "未恢复任何关系"} + self._rebuild_graph_from_metadata() else: return {"success": False, "detail": f"不支持的维护动作: {act}"} + self._last_maintenance_at = time.time() self._persist() return {"success": True, "detail": f"{act} {len(hashes)} 条关系"} - def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> int: - assert self.metadata_store - rebuilt = 0 + async def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> Dict[str, Any]: + await self.initialize() + 
assert self.metadata_store is not None + assert self.episode_service is not None + + items: List[Dict[str, Any]] = [] + failures: List[Dict[str, str]] = [] for source in self._tokens(sources): - rows = self.metadata_store.query( - """ - SELECT * FROM paragraphs - WHERE source = ? - AND (is_deleted IS NULL OR is_deleted = 0) - ORDER BY COALESCE(event_time_start, event_time, created_at) ASC, hash ASC - """, - (source,), - ) - if not rows: - continue - paragraph_hashes = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] - payload = self.metadata_store.upsert_episode( - { - "source": source, - "title": str((rows[0].get("metadata", {}) or {}).get("theme", "") or f"{source} 情景记忆")[:80], - "summary": ";".join(str(row.get("content", "") or "").strip().replace("\n", " ")[:120] for row in rows[:3] if str(row.get("content", "") or "").strip())[:500] or "自动构建的情景记忆。", - "participants": self._episode_participants(rows), - "keywords": self._episode_keywords(rows), - "evidence_ids": paragraph_hashes, - "paragraph_count": len(paragraph_hashes), - "event_time_start": self._time_bound(rows, "event_time_start", "event_time", reverse=False), - "event_time_end": self._time_bound(rows, "event_time_end", "event_time", reverse=True), - "time_granularity": "day", - "time_confidence": 0.7, - "llm_confidence": 0.0, - "segmentation_model": "rule_based_sdk", - "segmentation_version": "1", - } - ) - self.metadata_store.bind_episode_paragraphs(payload["episode_id"], paragraph_hashes) - rebuilt += 1 - return rebuilt + self.metadata_store.mark_episode_source_running(source) + try: + result = await self.episode_service.rebuild_source(source) + self.metadata_store.mark_episode_source_done(source) + items.append(result) + except Exception as exc: + err = str(exc)[:500] + self.metadata_store.mark_episode_source_failed(source, err) + failures.append({"source": source, "error": err}) + self._persist() + return { + "rebuilt": len(items), + "items": items, + "failures": failures, + "sources": [str(item.get("source", "") or "") for item in items] or self._tokens(sources), + } def memory_stats(self) -> Dict[str, Any]: assert self.metadata_store stats = self.metadata_store.get_statistics() episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"] profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"] - return {"paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "last_maintenance_at": self._last_maintenance_at} + pending = self.metadata_store.query( + "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" + )[0]["c"] + return { + "paragraphs": int(stats.get("paragraph_count", 0) or 0), + "relations": int(stats.get("relation_count", 0) or 0), + "episodes": int(episodes or 0), + "profiles": int(profiles or 0), + "episode_pending": int(pending or 0), + "last_maintenance_at": self._last_maintenance_at, + } - async def _aggregate_search(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: - assert self.retriever - hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] - return {"success": True, "results": hits, "count": len(hits), "query_type": "search"} + async def memory_graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + 
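+ # Dispatch for graph maintenance actions handled below: get_graph / create_node / delete_node /
+ # rename_node / create_edge / delete_edge / update_edge_weight.
+ # Illustrative call only (kernel being an SDKMemoryKernel instance; argument names as accepted by the branches below):
+ #   await kernel.memory_graph_admin(action="create_edge", subject="Alice", predicate="likes", object="tea")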
assert self.metadata_store is not None + assert self.graph_store is not None - async def _aggregate_time(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: - if temporal is None: - return {"success": False, "error": "missing temporal window", "results": []} - assert self.retriever - hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] - return {"success": True, "results": hits, "count": len(hits), "query_type": "time"} + act = str(action or "").strip().lower() + if act == "get_graph": + return {"success": True, **self._serialize_graph(limit=max(1, int(kwargs.get("limit", 200) or 200)))} - async def _aggregate_episode(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: + if act == "create_node": + name = str(kwargs.get("name", "") or kwargs.get("node", "") or "").strip() + if not name: + return {"success": False, "error": "node name 不能为空"} + entity_hash = self.metadata_store.add_entity(name=name, metadata=kwargs.get("metadata") or {}) + self._rebuild_graph_from_metadata() + self._persist() + return {"success": True, "node": {"name": name, "hash": entity_hash}} + + if act == "delete_node": + name = str(kwargs.get("name", "") or kwargs.get("node", "") or kwargs.get("hash_or_name", "") or "").strip() + if not name: + return {"success": False, "error": "node name 不能为空"} + result = await self._execute_delete_action( + mode="entity", + selector={"query": name}, + requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), + reason=str(kwargs.get("reason", "") or "graph_delete_node"), + ) + return { + "success": bool(result.get("success", False)), + "deleted": bool(result.get("deleted_count", 0)), + "node": name, + "operation_id": result.get("operation_id", ""), + "counts": result.get("counts", {}), + "error": result.get("error", ""), + } + + if act == "rename_node": + old_name = str(kwargs.get("name", "") or kwargs.get("old_name", "") or kwargs.get("node", "") or "").strip() + new_name = str(kwargs.get("new_name", "") or kwargs.get("target_name", "") or "").strip() + return self._rename_node(old_name, new_name) + + if act == "create_edge": + subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() + predicate = str(kwargs.get("predicate", "") or kwargs.get("label", "") or "").strip() + obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() + if not all([subject, predicate, obj]): + return {"success": False, "error": "subject/predicate/object 不能为空"} + if self.relation_write_service is not None: + result = await self.relation_write_service.upsert_relation_with_vector( + subject=subject, + predicate=predicate, + obj=obj, + confidence=float(kwargs.get("confidence", 1.0) or 1.0), + source_paragraph=str(kwargs.get("source_paragraph", "") or "") or None, + metadata=kwargs.get("metadata") or {}, + write_vector=self.relation_vectors_enabled, + ) + relation_hash = result.hash_value + else: + relation_hash = self.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + confidence=float(kwargs.get("confidence", 1.0) or 1.0), + source_paragraph=kwargs.get("source_paragraph"), + metadata=kwargs.get("metadata") or {}, + ) + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "edge": { + "hash": relation_hash, + "subject": subject, + "predicate": predicate, + "object": obj, + "weight": float(kwargs.get("confidence", 1.0) or 1.0), + }, + } + + if act == 
"delete_edge": + relation_hash = str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip() + if relation_hash: + result = await self._execute_delete_action( + mode="relation", + selector={"query": relation_hash}, + requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), + reason=str(kwargs.get("reason", "") or "graph_delete_edge"), + ) + return { + "success": bool(result.get("success", False)), + "deleted": int(result.get("deleted_count", 0)), + "hash": relation_hash, + "operation_id": result.get("operation_id", ""), + "counts": result.get("counts", {}), + "error": result.get("error", ""), + } + + subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() + obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() + deleted_hashes = [ + str(row.get("hash", "") or "") + for row in self.metadata_store.get_relations(subject=subject) + if str(row.get("object", "") or "").strip() == obj + ] + result = await self._execute_delete_action( + mode="relation", + selector={"hashes": deleted_hashes, "subject": subject, "object": obj}, + requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), + reason=str(kwargs.get("reason", "") or "graph_delete_edge"), + ) + return { + "success": bool(result.get("success", False)), + "deleted": int(result.get("deleted_count", 0)), + "subject": subject, + "object": obj, + "operation_id": result.get("operation_id", ""), + "counts": result.get("counts", {}), + "error": result.get("error", ""), + } + + if act == "update_edge_weight": + return self._update_edge_weight( + relation_hash=str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip(), + subject=str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip(), + obj=str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip(), + weight=float(kwargs.get("weight", kwargs.get("confidence", 1.0)) or 1.0), + ) + + return {"success": False, "error": f"不支持的 graph action: {act}"} + + async def memory_source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + + act = str(action or "").strip().lower() + if act == "list": + sources = self.metadata_store.get_all_sources() + items = [] + for row in sources: + source_name = str(row.get("source", "") or "").strip() + items.append( + { + **row, + "episode_rebuild_blocked": self.metadata_store.is_episode_source_query_blocked(source_name), + } + ) + return {"success": True, "items": items, "count": len(items)} + + if act == "delete": + source = str(kwargs.get("source", "") or "").strip() + return await self._execute_delete_action( + mode="source", + selector={"sources": [source]}, + requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), + reason=str(kwargs.get("reason", "") or "source_delete"), + ) + + if act == "batch_delete": + return await self._execute_delete_action( + mode="source", + selector={"sources": list(kwargs.get("sources") or [])}, + requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), + reason=str(kwargs.get("reason", "") or "source_batch_delete"), + ) + + return {"success": False, "error": f"不支持的 source action: {act}"} + + async def memory_episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + + act = str(action or "").strip().lower() + if act in {"query", "list"}: + items = self.metadata_store.query_episodes( + query=str(kwargs.get("query", "") or "").strip(), + 
time_from=self._optional_float(kwargs.get("time_start", kwargs.get("time_from"))), + time_to=self._optional_float(kwargs.get("time_end", kwargs.get("time_to"))), + person=str(kwargs.get("person_id", "") or kwargs.get("person", "") or "").strip() or None, + source=str(kwargs.get("source", "") or "").strip() or None, + limit=max(1, int(kwargs.get("limit", 20) or 20)), + ) + return {"success": True, "items": items, "count": len(items)} + + if act == "get": + episode_id = str(kwargs.get("episode_id", "") or "").strip() + if not episode_id: + return {"success": False, "error": "episode_id 不能为空"} + episode = self.metadata_store.get_episode_by_id(episode_id) + if episode is None: + return {"success": False, "error": "episode 不存在"} + episode["paragraphs"] = self.metadata_store.get_episode_paragraphs( + episode_id, + limit=max(1, int(kwargs.get("paragraph_limit", 100) or 100)), + ) + return {"success": True, "episode": episode} + + if act == "status": + summary = self.metadata_store.get_episode_source_rebuild_summary( + failed_limit=max(1, int(kwargs.get("limit", 20) or 20)) + ) + summary["pending_queue"] = self.metadata_store.query( + "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" + )[0]["c"] + return {"success": True, **summary} + + if act == "rebuild": + sources = self._tokens(kwargs.get("sources")) + if not sources: + source = str(kwargs.get("source", "") or "").strip() + if source: + sources = [source] + if not sources and bool(kwargs.get("all", False)): + sources = self.metadata_store.list_episode_sources_for_rebuild() + if not sources: + sources = [str(row.get("source", "") or "").strip() for row in self.metadata_store.get_all_sources()] + if not sources: + return {"success": False, "error": "未提供可重建的 source"} + result = await self.rebuild_episodes_for_sources(sources) + return {"success": len(result.get("failures", [])) == 0, **result} + + if act == "process_pending": + result = await self.process_episode_pending_batch( + limit=max(1, int(kwargs.get("limit", 20) or 20)), + max_retry=max(1, int(kwargs.get("max_retry", 3) or 3)), + ) + return {"success": True, **result} + + return {"success": False, "error": f"不支持的 episode action: {act}"} + + async def memory_profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store is not None + assert self.person_profile_service is not None + + act = str(action or "").strip().lower() + if act == "query": + profile = await self.person_profile_service.query_person_profile( + person_id=str(kwargs.get("person_id", "") or "").strip(), + person_keyword=str(kwargs.get("person_keyword", "") or kwargs.get("keyword", "") or "").strip(), + top_k=max(1, int(kwargs.get("limit", kwargs.get("top_k", 12)) or 12)), + force_refresh=bool(kwargs.get("force_refresh", False)), + source_note="sdk_memory_kernel.memory_profile_admin.query", + ) + return profile if isinstance(profile, dict) else {"success": False, "error": "invalid profile payload"} + + if act == "list": + limit = max(1, int(kwargs.get("limit", 50) or 50)) + rows = self.metadata_store.query( + """ + SELECT s.person_id, s.profile_version, s.profile_text, s.updated_at, s.expires_at, s.source_note + FROM person_profile_snapshots s + JOIN ( + SELECT person_id, MAX(profile_version) AS max_version + FROM person_profile_snapshots + GROUP BY person_id + ) latest + ON latest.person_id = s.person_id + AND latest.max_version = s.profile_version + ORDER BY s.updated_at DESC + LIMIT ? 
+ """, + (limit,), + ) + items = [] + for row in rows: + person_id = str(row.get("person_id", "") or "").strip() + override = self.metadata_store.get_person_profile_override(person_id) + items.append( + { + "person_id": person_id, + "profile_version": int(row.get("profile_version", 0) or 0), + "profile_text": str(row.get("profile_text", "") or ""), + "updated_at": row.get("updated_at"), + "expires_at": row.get("expires_at"), + "source_note": str(row.get("source_note", "") or ""), + "has_manual_override": bool(override), + "manual_override": override, + } + ) + return {"success": True, "items": items, "count": len(items)} + + if act == "set_override": + person_id = str(kwargs.get("person_id", "") or "").strip() + override = self.metadata_store.set_person_profile_override( + person_id=person_id, + override_text=str(kwargs.get("override_text", "") or kwargs.get("text", "") or ""), + updated_by=str(kwargs.get("updated_by", "") or ""), + source=str(kwargs.get("source", "") or "memory_profile_admin"), + ) + return {"success": True, "override": override} + + if act == "delete_override": + person_id = str(kwargs.get("person_id", "") or "").strip() + deleted = self.metadata_store.delete_person_profile_override(person_id) + return {"success": bool(deleted), "deleted": bool(deleted), "person_id": person_id} + + return {"success": False, "error": f"不支持的 profile action: {act}"} + + async def memory_runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + act = str(action or "").strip().lower() + + if act == "save": + self._persist() + return {"success": True, "saved": True, "data_dir": str(self.data_dir)} + + if act == "get_config": + return { + "success": True, + "config": self.config, + "data_dir": str(self.data_dir), + "embedding_dimension": int(self.embedding_dimension), + "auto_save": bool(self._cfg("advanced.enable_auto_save", True)), + "relation_vectors_enabled": bool(self.relation_vectors_enabled), + "runtime_ready": self.is_runtime_ready(), + } + + if act in {"self_check", "refresh_self_check"}: + report = await run_embedding_runtime_self_check( + config=self._build_runtime_config(), + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + sample_text=str(kwargs.get("sample_text", "") or "A_Memorix runtime self check"), + ) + self._runtime_facade._runtime_self_check_report = dict(report) + return {"success": bool(report.get("ok", False)), "report": report} + + if act == "set_auto_save": + enabled = bool(kwargs.get("enabled", False)) + self._set_cfg("advanced.enable_auto_save", enabled) + return {"success": True, "auto_save": enabled} + + return {"success": False, "error": f"不支持的 runtime action: {act}"} + + async def memory_import_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + manager = self.import_task_manager + if manager is None: + return {"success": False, "error": "import manager 未初始化"} + + act = str(action or "").strip().lower() + if act in {"settings", "get_settings", "get_guide"}: + return {"success": True, "settings": await manager.get_runtime_settings()} + if act in {"path_aliases", "get_path_aliases"}: + return {"success": True, "path_aliases": manager.get_path_aliases()} + if act in {"resolve_path", "resolve"}: + return await manager.resolve_path_request(kwargs) + if act == "create_upload": + task = await manager.create_upload_task( + list(kwargs.get("staged_files") or kwargs.get("files") or kwargs.get("uploads") or []), + kwargs, + ) + return {"success": True, "task": task} + if act == 
"create_paste": + return {"success": True, "task": await manager.create_paste_task(kwargs)} + if act == "create_raw_scan": + return {"success": True, "task": await manager.create_raw_scan_task(kwargs)} + if act == "create_lpmm_openie": + return {"success": True, "task": await manager.create_lpmm_openie_task(kwargs)} + if act == "create_lpmm_convert": + return {"success": True, "task": await manager.create_lpmm_convert_task(kwargs)} + if act == "create_temporal_backfill": + return {"success": True, "task": await manager.create_temporal_backfill_task(kwargs)} + if act == "create_maibot_migration": + return {"success": True, "task": await manager.create_maibot_migration_task(kwargs)} + if act == "list": + items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50))) + return {"success": True, "items": items, "count": len(items)} + if act == "get": + task = await manager.get_task( + str(kwargs.get("task_id", "") or ""), + include_chunks=bool(kwargs.get("include_chunks", False)), + ) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act in {"chunks", "get_chunks"}: + payload = await manager.get_chunks( + str(kwargs.get("task_id", "") or ""), + str(kwargs.get("file_id", "") or ""), + offset=max(0, int(kwargs.get("offset", 0) or 0)), + limit=max(1, int(kwargs.get("limit", 50) or 50)), + ) + return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务或文件不存在"} + if act == "cancel": + task = await manager.cancel_task(str(kwargs.get("task_id", "") or "")) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act == "retry_failed": + overrides = kwargs.get("overrides") if isinstance(kwargs.get("overrides"), dict) else kwargs + task = await manager.retry_failed(str(kwargs.get("task_id", "") or ""), overrides=overrides) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + return {"success": False, "error": f"不支持的 import action: {act}"} + + async def memory_tuning_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + manager = self.retrieval_tuning_manager + if manager is None: + return {"success": False, "error": "tuning manager 未初始化"} + + act = str(action or "").strip().lower() + if act in {"settings", "get_settings"}: + return {"success": True, "settings": manager.get_runtime_settings()} + if act == "get_profile": + profile = manager.get_profile_snapshot() + return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} + if act == "apply_profile": + profile = kwargs.get("profile") if isinstance(kwargs.get("profile"), dict) else kwargs + return {"success": True, **await manager.apply_profile(profile, reason=str(kwargs.get("reason", "manual") or "manual"))} + if act == "rollback_profile": + return {"success": True, **await manager.rollback_profile()} + if act == "export_profile": + profile = manager.get_profile_snapshot() + return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} + if act == "create_task": + payload = kwargs.get("payload") if isinstance(kwargs.get("payload"), dict) else kwargs + return {"success": True, "task": await manager.create_task(payload)} + if act == "list_tasks": + items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50))) + return {"success": True, "items": items, "count": len(items)} + if act == "get_task": + task = await manager.get_task( + 
str(kwargs.get("task_id", "") or ""), + include_rounds=bool(kwargs.get("include_rounds", False)), + ) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act == "get_rounds": + payload = await manager.get_rounds( + str(kwargs.get("task_id", "") or ""), + offset=max(0, int(kwargs.get("offset", 0) or 0)), + limit=max(1, int(kwargs.get("limit", 50) or 50)), + ) + return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务不存在"} + if act == "cancel": + task = await manager.cancel_task(str(kwargs.get("task_id", "") or "")) + return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} + if act == "apply_best": + return {"success": True, **await manager.apply_best(str(kwargs.get("task_id", "") or ""))} + if act == "get_report": + report = await manager.get_report(str(kwargs.get("task_id", "") or ""), fmt=str(kwargs.get("format", "md") or "md")) + return {"success": report is not None, "report": report, "error": "" if report is not None else "任务不存在"} + return {"success": False, "error": f"不支持的 tuning action: {act}"} + + async def memory_v5_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + + act = str(action or "").strip().lower() + target = str(kwargs.get("target", "") or kwargs.get("query", "") or "").strip() + reason = str(kwargs.get("reason", "") or "").strip() + updated_by = str(kwargs.get("updated_by", "") or kwargs.get("requested_by", "") or "").strip() + limit = max(1, int(kwargs.get("limit", 50) or 50)) + + if act == "recycle_bin": + items = self.metadata_store.get_deleted_relations(limit=limit) + return {"success": True, "items": items, "count": len(items)} + + if act == "status": + return self._memory_v5_status(target=target, limit=limit) + + if act == "restore": + hashes = self._resolve_deleted_relation_hashes(target) + if not hashes: + return {"success": False, "error": "未命中可恢复关系"} + result = await self._restore_relation_hashes(hashes) + operation = self.metadata_store.record_v5_operation( + action=act, + target=target, + resolved_hashes=hashes, + reason=reason, + updated_by=updated_by, + result=result, + ) + return {"success": bool(result.get("restored_count", 0) > 0), "operation": operation, **result} + + hashes = self._resolve_relation_hashes(target) + if not hashes: + return {"success": False, "error": "未命中可维护关系"} + + result = self._apply_v5_relation_action( + action=act, + hashes=hashes, + strength=float(kwargs.get("strength", 1.0) or 1.0), + ) + operation = self.metadata_store.record_v5_operation( + action=act, + target=target, + resolved_hashes=hashes, + reason=reason, + updated_by=updated_by, + result=result, + ) + return {"success": bool(result.get("success", False)), "operation": operation, **result} + + async def memory_delete_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: + await self.initialize() + act = str(action or "").strip().lower() + mode = str(kwargs.get("mode", "") or "").strip().lower() + selector = kwargs.get("selector") + if selector is None: + selector = { + key: value + for key, value in kwargs.items() + if key + not in { + "action", + "mode", + "dry_run", + "cascade", + "operation_id", + "reason", + "requested_by", + } + } + reason = str(kwargs.get("reason", "") or "").strip() + requested_by = str(kwargs.get("requested_by", "") or "").strip() + + if act == "preview": + return await self._preview_delete_action(mode=mode, selector=selector) + if act == 
"execute": + return await self._execute_delete_action( + mode=mode, + selector=selector, + requested_by=requested_by, + reason=reason, + ) + if act == "restore": + return await self._restore_delete_action( + mode=mode, + selector=selector, + operation_id=str(kwargs.get("operation_id", "") or "").strip(), + requested_by=requested_by, + reason=reason, + ) + if act == "get_operation": + operation = self.metadata_store.get_delete_operation(str(kwargs.get("operation_id", "") or "").strip()) + return {"success": operation is not None, "operation": operation, "error": "" if operation is not None else "operation 不存在"} + if act == "list_operations": + items = self.metadata_store.list_delete_operations( + limit=max(1, int(kwargs.get("limit", 50) or 50)), + mode=mode, + ) + return {"success": True, "items": items, "count": len(items)} + if act == "purge": + return await self._purge_deleted_memory( + grace_hours=self._optional_float(kwargs.get("grace_hours")), + limit=max(1, int(kwargs.get("limit", 1000) or 1000)), + ) + return {"success": False, "error": f"不支持的 delete action: {act}"} + + def get_import_task_manager(self) -> Optional[ImportTaskManager]: + return self.import_task_manager + + def get_retrieval_tuning_manager(self) -> Optional[RetrievalTuningManager]: + return self.retrieval_tuning_manager + + async def _aggregate_search(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: + result = await SearchExecutionService.execute( + retriever=self.retriever, + threshold_filter=self.threshold_filter, + plugin_config=self._build_runtime_config(), + request=SearchExecutionRequest( + caller="sdk_memory_kernel.aggregate", + stream_id=str(request.chat_id or "") or None, + query_type="search", + query=query, + top_k=limit, + person=str(request.person_id or "") or None, + source=self._chat_source(request.chat_id), + use_threshold=True, + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ), + enforce_chat_filter=False, + reinforce_access=True, + ) + hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] + return {"success": result.success, "results": hits, "count": len(hits), "query_type": "search", "error": result.error} + + async def _aggregate_time( + self, + query: str, + limit: int, + request: KernelSearchRequest, + time_window: _NormalizedSearchTimeWindow, + ) -> Dict[str, Any]: + result = await SearchExecutionService.execute( + retriever=self.retriever, + threshold_filter=self.threshold_filter, + plugin_config=self._build_runtime_config(), + request=SearchExecutionRequest( + caller="sdk_memory_kernel.aggregate", + stream_id=str(request.chat_id or "") or None, + query_type="time", + query=query, + top_k=limit, + time_from=time_window.query_start, + time_to=time_window.query_end, + person=str(request.person_id or "") or None, + source=self._chat_source(request.chat_id), + use_threshold=True, + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ), + enforce_chat_filter=False, + reinforce_access=True, + ) + hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] + return {"success": result.success, "results": hits, "count": len(hits), "query_type": "time", "error": result.error} + + async def _aggregate_episode( + self, + query: str, + limit: int, + request: KernelSearchRequest, + time_window: _NormalizedSearchTimeWindow, + ) -> Dict[str, Any]: assert self.episode_retriever - rows = await self.episode_retriever.query(query=query, top_k=limit, time_from=request.time_start, 
time_to=request.time_end, source=self._chat_source(request.chat_id)) + rows = await self.episode_retriever.query( + query=query, + top_k=limit, + time_from=time_window.numeric_start, + time_to=time_window.numeric_end, + person=request.person_id or None, + source=self._chat_source(request.chat_id), + ) hits = [self._episode_hit(row) for row in rows] return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"} @@ -431,6 +1583,494 @@ class SDKMemoryKernel: if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False): self.sparse_index.ensure_loaded() + async def _start_background_tasks(self) -> None: + async with self._background_lock: + self._background_stopping = False + self._ensure_background_task("auto_save", self._auto_save_loop) + self._ensure_background_task("episode_pending", self._episode_pending_loop) + self._ensure_background_task("memory_maintenance", self._memory_maintenance_loop) + self._ensure_background_task("person_profile_refresh", self._person_profile_refresh_loop) + + def _ensure_background_task(self, name: str, factory: Callable[[], Awaitable[None]]) -> None: + task = self._background_tasks.get(name) + if task is not None and not task.done(): + return + self._background_tasks[name] = asyncio.create_task(factory(), name=f"A_Memorix.{name}") + + async def _stop_background_tasks(self) -> None: + async with self._background_lock: + self._background_stopping = True + tasks = [task for task in self._background_tasks.values() if task is not None and not task.done()] + for task in tasks: + task.cancel() + for task in tasks: + try: + await task + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning(f"后台任务退出异常: {exc}") + self._background_tasks.clear() + + async def _auto_save_loop(self) -> None: + try: + while not self._background_stopping: + interval_minutes = max(1.0, float(self._cfg("advanced.auto_save_interval_minutes", 5) or 5)) + await asyncio.sleep(interval_minutes * 60.0) + if self._background_stopping: + break + if bool(self._cfg("advanced.enable_auto_save", True)): + self._persist() + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"auto_save loop 异常: {exc}") + + async def _episode_pending_loop(self) -> None: + try: + while not self._background_stopping: + await asyncio.sleep(60.0) + if self._background_stopping: + break + if not bool(self._cfg("episode.enabled", True)): + continue + if not bool(self._cfg("episode.generation_enabled", True)): + continue + await self.process_episode_pending_batch( + limit=max(1, int(self._cfg("episode.pending_batch_size", 20) or 20)), + max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3) or 3)), + ) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"episode_pending loop 异常: {exc}") + + async def _person_profile_refresh_loop(self) -> None: + try: + while not self._background_stopping: + interval_minutes = max(1.0, float(self._cfg("person_profile.refresh_interval_minutes", 30) or 30)) + await asyncio.sleep(max(60.0, interval_minutes * 60.0)) + if self._background_stopping: + break + if not bool(self._cfg("person_profile.enabled", True)): + continue + active_window_hours = max(1.0, float(self._cfg("person_profile.active_window_hours", 72.0) or 72.0)) + max_refresh = max(1, int(self._cfg("person_profile.max_refresh_per_cycle", 50) or 50)) + cutoff = time.time() - active_window_hours * 3600.0 + candidates = [ + person_id + for person_id, seen_at in sorted( + 
self._active_person_timestamps.items(), + key=lambda item: item[1], + reverse=True, + ) + if seen_at >= cutoff + ][:max_refresh] + for person_id in candidates: + try: + await self.refresh_person_profile(person_id, limit=max(4, int(self._cfg("person_profile.top_k_evidence", 12) or 12)), mark_active=False) + except Exception as exc: + logger.warning(f"刷新人物画像失败: {exc}") + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"person_profile_refresh loop 异常: {exc}") + + async def _memory_maintenance_loop(self) -> None: + try: + while not self._background_stopping: + interval_hours = max(1.0 / 60.0, float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0)) + await asyncio.sleep(max(60.0, interval_hours * 3600.0)) + if self._background_stopping: + break + if not bool(self._cfg("memory.enabled", True)): + continue + await self._run_memory_maintenance_cycle(interval_hours=interval_hours) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning(f"memory_maintenance loop 异常: {exc}") + + async def _run_memory_maintenance_cycle(self, *, interval_hours: float) -> None: + assert self.graph_store is not None + assert self.metadata_store is not None + half_life = float(self._cfg("memory.half_life_hours", 24.0) or 24.0) + if half_life > 0: + factor = 0.5 ** (float(interval_hours) / half_life) + self.graph_store.decay(factor) + + await self._process_freeze_and_prune() + await self._orphan_gc_phase() + self._last_maintenance_at = time.time() + self._persist() + + async def _process_freeze_and_prune(self) -> None: + assert self.metadata_store is not None + assert self.graph_store is not None + prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) + freeze_duration = max(0.0, float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0)) * 3600.0 + now = time.time() + + low_edges = self.graph_store.get_low_weight_edges(prune_threshold) + hashes_to_freeze: List[str] = [] + edges_to_deactivate: List[tuple[str, str]] = [] + for src, tgt in low_edges: + relation_hashes = list(self.graph_store.get_relation_hashes_for_edge(src, tgt)) + if not relation_hashes: + continue + statuses = self.metadata_store.get_relation_status_batch(relation_hashes) + current_hashes: List[str] = [] + protected = False + for hash_value, status in statuses.items(): + if bool(status.get("is_pinned")) or float(status.get("protected_until") or 0.0) > now: + protected = True + break + current_hashes.append(hash_value) + if protected or not current_hashes: + continue + hashes_to_freeze.extend(current_hashes) + edges_to_deactivate.append((src, tgt)) + + if hashes_to_freeze: + self.metadata_store.mark_relations_inactive(hashes_to_freeze, inactive_since=now) + self.graph_store.deactivate_edges(edges_to_deactivate) + + cutoff = now - freeze_duration + expired_hashes = self.metadata_store.get_prune_candidates(cutoff) + if not expired_hashes: + return + relation_info = self.metadata_store.get_relations_subject_object_map(expired_hashes) + operations = [(src, tgt, hash_value) for hash_value, (src, tgt) in relation_info.items()] + if operations: + self.graph_store.prune_relation_hashes(operations) + deleted_hashes = [hash_value for hash_value in expired_hashes if hash_value in relation_info] + if deleted_hashes: + self.metadata_store.backup_and_delete_relations(deleted_hashes) + if self.vector_store is not None: + self.vector_store.delete(deleted_hashes) + + async def _orphan_gc_phase(self) -> None: + assert self.metadata_store is not None + assert 
self.graph_store is not None + orphan_cfg = self._cfg("memory.orphan", {}) or {} + if not bool(orphan_cfg.get("enable_soft_delete", True)): + return + entity_retention = max(0.0, float(orphan_cfg.get("entity_retention_days", 7.0) or 7.0)) * 86400.0 + paragraph_retention = max(0.0, float(orphan_cfg.get("paragraph_retention_days", 7.0) or 7.0)) * 86400.0 + grace_period = max(0.0, float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0)) * 3600.0 + + isolated = self.graph_store.get_isolated_nodes(include_inactive=True) + if isolated: + entity_hashes = self.metadata_store.get_entity_gc_candidates(isolated, retention_seconds=entity_retention) + if entity_hashes: + self.metadata_store.mark_as_deleted(entity_hashes, "entity") + + paragraph_hashes = self.metadata_store.get_paragraph_gc_candidates(retention_seconds=paragraph_retention) + if paragraph_hashes: + self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") + + dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace_period) + if dead_paragraphs: + hashes = [str(item[0] or "").strip() for item in dead_paragraphs if item and str(item[0] or "").strip()] + if hashes: + self.metadata_store.physically_delete_paragraphs(hashes) + if self.vector_store is not None: + self.vector_store.delete(hashes) + + dead_entities = self.metadata_store.sweep_deleted_items("entity", grace_period) + if dead_entities: + entity_hashes = [str(item[0] or "").strip() for item in dead_entities if item and str(item[0] or "").strip()] + entity_names = [str(item[1] or "").strip() for item in dead_entities if item and str(item[1] or "").strip()] + if entity_names: + self.graph_store.delete_nodes(entity_names) + if entity_hashes: + self.metadata_store.physically_delete_entities(entity_hashes) + if self.vector_store is not None: + self.vector_store.delete(entity_hashes) + + def _mark_person_active(self, person_id: str) -> None: + token = str(person_id or "").strip() + if not token: + return + self._active_person_timestamps[token] = time.time() + + def _serialize_graph(self, *, limit: int = 200) -> Dict[str, Any]: + assert self.graph_store is not None + assert self.metadata_store is not None + nodes = self.graph_store.get_nodes() + if limit > 0: + nodes = nodes[:limit] + node_set = set(nodes) + node_payload = [] + for name in nodes: + attrs = self.graph_store.get_node_attributes(name) or {} + node_payload.append({"id": name, "name": name, "attributes": attrs}) + + edge_payload = [] + for source, target, relation_hashes in self.graph_store.iter_edge_hash_entries(): + if source not in node_set or target not in node_set: + continue + edge_payload.append( + { + "source": source, + "target": target, + "weight": float(self.graph_store.get_edge_weight(source, target)), + "relation_hashes": sorted(str(item) for item in relation_hashes if str(item).strip()), + } + ) + return { + "nodes": node_payload, + "edges": edge_payload, + "total_nodes": int(self.graph_store.num_nodes), + "total_edges": int(self.graph_store.num_edges), + } + + def _delete_sources(self, sources: Iterable[Any]) -> Dict[str, Any]: + assert self.metadata_store + source_tokens = self._tokens(sources) + if not source_tokens: + return {"success": False, "error": "source 不能为空"} + + deleted_paragraphs = 0 + deleted_sources: List[str] = [] + for source in source_tokens: + paragraphs = self.metadata_store.get_paragraphs_by_source(source) + if not paragraphs: + self.metadata_store.replace_episodes_for_source(source, []) + continue + for row in paragraphs: + paragraph_hash = str(row.get("hash", 
"") or "").strip() + if not paragraph_hash: + continue + cleanup = self.metadata_store.delete_paragraph_atomic(paragraph_hash) + self._apply_cleanup_plan(cleanup) + deleted_paragraphs += 1 + self.metadata_store.replace_episodes_for_source(source, []) + deleted_sources.append(source) + + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "sources": deleted_sources, + "deleted_source_count": len(deleted_sources), + "deleted_paragraph_count": deleted_paragraphs, + } + + def _apply_cleanup_plan(self, cleanup: Dict[str, Any]) -> None: + if not isinstance(cleanup, dict): + return + if self.vector_store is not None: + vector_ids: List[str] = [] + paragraph_hash = str(cleanup.get("vector_id_to_remove", "") or "").strip() + if paragraph_hash: + vector_ids.append(paragraph_hash) + for _, _, relation_hash in cleanup.get("relation_prune_ops", []) or []: + token = str(relation_hash or "").strip() + if token: + vector_ids.append(token) + if vector_ids: + self.vector_store.delete(list(dict.fromkeys(vector_ids))) + + def _rebuild_graph_from_metadata(self) -> Dict[str, int]: + assert self.metadata_store is not None + assert self.graph_store is not None + entity_rows = self.metadata_store.query( + """ + SELECT name + FROM entities + WHERE is_deleted IS NULL OR is_deleted = 0 + ORDER BY name ASC + """ + ) + raw_relation_rows = self.metadata_store.query( + """ + SELECT subject, object, confidence, hash + FROM relations + WHERE is_inactive IS NULL OR is_inactive = 0 + """ + ) + relation_rows = [ + row + for row in raw_relation_rows + if str(row.get("subject", "") or "").strip() and str(row.get("object", "") or "").strip() + ] + + names = list( + dict.fromkeys( + [ + str(row.get("name", "") or "").strip() + for row in entity_rows + if str(row.get("name", "") or "").strip() + ] + + [ + str(row.get("subject", "") or "").strip() + for row in relation_rows + if str(row.get("subject", "") or "").strip() + ] + + [ + str(row.get("object", "") or "").strip() + for row in relation_rows + if str(row.get("object", "") or "").strip() + ] + ) + ) + self.graph_store.clear() + if names: + self.graph_store.add_nodes(names) + if relation_rows: + self.graph_store.add_edges( + [ + ( + str(row.get("subject", "") or "").strip(), + str(row.get("object", "") or "").strip(), + ) + for row in relation_rows + ], + weights=[float(row.get("confidence", 1.0) or 1.0) for row in relation_rows], + relation_hashes=[str(row.get("hash", "") or "") for row in relation_rows], + ) + return {"node_count": int(self.graph_store.num_nodes), "edge_count": int(self.graph_store.num_edges)} + + def _rename_node(self, old_name: str, new_name: str) -> Dict[str, Any]: + assert self.metadata_store + source = str(old_name or "").strip() + target = str(new_name or "").strip() + if not source or not target: + return {"success": False, "error": "old_name/new_name 不能为空"} + if source == target: + return {"success": True, "renamed": False, "old_name": source, "new_name": target} + + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + old_hash = compute_hash(source.lower()) + target_hash = compute_hash(target.lower()) + + cursor.execute( + """ + SELECT hash, name, vector_index, appearance_count, created_at, metadata + FROM entities + WHERE hash = ? 
+ OR LOWER(TRIM(name)) = LOWER(TRIM(?)) + LIMIT 1 + """, + (old_hash, source), + ) + old_row = cursor.fetchone() + if old_row is None: + return {"success": False, "error": "原节点不存在"} + + cursor.execute( + """ + SELECT hash, appearance_count + FROM entities + WHERE hash = ? + OR LOWER(TRIM(name)) = LOWER(TRIM(?)) + LIMIT 1 + """, + (target_hash, target), + ) + target_row = cursor.fetchone() + + try: + cursor.execute("BEGIN IMMEDIATE") + if target_row is None: + cursor.execute( + """ + INSERT INTO entities (hash, name, vector_index, appearance_count, created_at, metadata, is_deleted, deleted_at) + VALUES (?, ?, ?, ?, ?, ?, 0, NULL) + """, + ( + target_hash, + target, + old_row["vector_index"], + old_row["appearance_count"], + old_row["created_at"], + old_row["metadata"], + ), + ) + resolved_target_hash = target_hash + else: + resolved_target_hash = str(target_row["hash"] or "").strip() + cursor.execute( + """ + UPDATE entities + SET name = ?, + appearance_count = COALESCE(appearance_count, 0) + ?, + is_deleted = 0, + deleted_at = NULL + WHERE hash = ? + """, + ( + target, + int(old_row["appearance_count"] or 0), + resolved_target_hash, + ), + ) + + cursor.execute( + "UPDATE OR IGNORE paragraph_entities SET entity_hash = ? WHERE entity_hash = ?", + (resolved_target_hash, old_row["hash"]), + ) + cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (old_row["hash"],)) + cursor.execute( + "UPDATE relations SET subject = ? WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?))", + (target, old_row["name"]), + ) + cursor.execute( + "UPDATE relations SET object = ? WHERE LOWER(TRIM(object)) = LOWER(TRIM(?))", + (target, old_row["name"]), + ) + cursor.execute("DELETE FROM entities WHERE hash = ?", (old_row["hash"],)) + conn.commit() + except Exception as exc: + conn.rollback() + return {"success": False, "error": f"rename failed: {exc}"} + + self._rebuild_graph_from_metadata() + self._persist() + return {"success": True, "renamed": True, "old_name": source, "new_name": target} + + def _update_edge_weight( + self, + *, + relation_hash: str, + subject: str, + obj: str, + weight: float, + ) -> Dict[str, Any]: + assert self.metadata_store + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + target_weight = max(0.0, float(weight or 0.0)) + if relation_hash: + cursor.execute("UPDATE relations SET confidence = ? WHERE hash = ?", (target_weight, relation_hash)) + updated = cursor.rowcount + else: + cursor.execute( + """ + UPDATE relations + SET confidence = ? 
+ WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?)) + AND LOWER(TRIM(object)) = LOWER(TRIM(?)) + """, + (target_weight, subject, obj), + ) + updated = cursor.rowcount + conn.commit() + if updated <= 0: + return {"success": False, "error": "未找到可更新的关系"} + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "updated": int(updated), + "weight": target_weight, + "hash": relation_hash, + "subject": subject, + "object": obj, + } + @staticmethod def _tokens(values: Optional[Iterable[Any]]) -> List[str]: result: List[str] = [] @@ -469,6 +2109,15 @@ class SDKMemoryKernel: clean = str(chat_id or "").strip() return f"chat_summary:{clean}" if clean else None + @staticmethod + def _resolve_knowledge_type(source_type: str) -> str: + clean_type = str(source_type or "").strip().lower() + if clean_type == "person_fact": + return "factual" + if clean_type == "chat_summary": + return "narrative" + return "mixed" + @staticmethod def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]: payload: Dict[str, Any] = {} @@ -483,19 +2132,70 @@ class SDKMemoryKernel: payload["time_confidence"] = 0.95 return payload - def _temporal(self, request: KernelSearchRequest) -> Optional[TemporalQueryOptions]: - if request.time_start is None and request.time_end is None and not request.chat_id: - return None - return TemporalQueryOptions(time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id)) + @classmethod + def _normalize_search_time_bound(cls, value: Any, *, is_end: bool) -> tuple[Optional[float], Optional[str]]: + if value in {None, ""}: + return None, None + if isinstance(value, (int, float)): + ts = float(value) + return ts, format_timestamp(ts) + + text = str(value or "").strip() + if not text: + return None, None + + numeric = cls._optional_float(text) + if numeric is not None: + return numeric, format_timestamp(numeric) + + try: + ts = parse_query_datetime_to_timestamp(text, is_end=is_end) + except ValueError as exc: + raise ValueError(f"时间参数错误: {exc}") from exc + return ts, text + + @classmethod + def _normalize_search_time_window(cls, time_start: Any, time_end: Any) -> _NormalizedSearchTimeWindow: + numeric_start, query_start = cls._normalize_search_time_bound(time_start, is_end=False) + numeric_end, query_end = cls._normalize_search_time_bound(time_end, is_end=True) + if numeric_start is not None and numeric_end is not None and numeric_start > numeric_end: + raise ValueError("时间参数错误: time_start 不能晚于 time_end") + return _NormalizedSearchTimeWindow( + numeric_start=numeric_start, + numeric_end=numeric_end, + query_start=query_start, + query_end=query_end, + ) @staticmethod - def _retrieval_hit(item: RetrievalResult) -> Dict[str, Any]: + def _retrieval_result_hit(item: RetrievalResult) -> Dict[str, Any]: payload = item.to_dict() - return {"hash": payload.get("hash", ""), "content": payload.get("content", ""), "score": payload.get("score", 0.0), "type": payload.get("type", ""), "source": payload.get("source", ""), "metadata": payload.get("metadata", {}) or {}} + return { + "hash": payload.get("hash", ""), + "content": payload.get("content", ""), + "score": payload.get("score", 0.0), + "type": payload.get("type", ""), + "source": payload.get("source", ""), + "metadata": payload.get("metadata", {}) or {}, + } @staticmethod def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]: - return {"type": "episode", "episode_id": str(row.get("episode_id", "") or ""), "title": str(row.get("title", "") 
or ""), "content": str(row.get("summary", "") or ""), "score": float(row.get("lexical_score", 0.0) or 0.0), "source": "episode", "metadata": {"participants": row.get("participants", []) or [], "keywords": row.get("keywords", []) or [], "source": row.get("source"), "event_time_start": row.get("event_time_start"), "event_time_end": row.get("event_time_end")}} + return { + "type": "episode", + "episode_id": str(row.get("episode_id", "") or ""), + "title": str(row.get("title", "") or ""), + "content": str(row.get("summary", "") or ""), + "score": float(row.get("lexical_score", 0.0) or 0.0), + "source": "episode", + "metadata": { + "participants": row.get("participants", []) or [], + "keywords": row.get("keywords", []) or [], + "source": row.get("source"), + "event_time_start": row.get("event_time_start"), + "event_time_end": row.get("event_time_end"), + }, + } @staticmethod def _summary(hits: Sequence[Dict[str, Any]]) -> str: @@ -521,51 +2221,6 @@ class SDKMemoryKernel: filtered.append(item) return filtered or hits - @staticmethod - def _episode_participants(rows: Sequence[Dict[str, Any]]) -> List[str]: - seen = set() - result: List[str] = [] - for row in rows: - meta = row.get("metadata", {}) or {} - for key in ("participants", "person_ids"): - for item in meta.get(key, []) or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - result.append(token) - return result[:16] - - @staticmethod - def _episode_keywords(rows: Sequence[Dict[str, Any]]) -> List[str]: - seen = set() - result: List[str] = [] - for row in rows: - meta = row.get("metadata", {}) or {} - for item in meta.get("tags", []) or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - result.append(token) - return result[:12] - - @staticmethod - def _time_bound(rows: Sequence[Dict[str, Any]], primary: str, fallback: str, reverse: bool) -> Optional[float]: - values: List[float] = [] - for row in rows: - for key in (primary, fallback): - value = row.get(key) - try: - if value is not None: - values.append(float(value)) - break - except Exception: - continue - if not values: - return None - return max(values) if reverse else min(values) - def _resolve_relation_hashes(self, target: str) -> List[str]: assert self.metadata_store token = str(target or "").strip() @@ -576,4 +2231,896 @@ class SDKMemoryKernel: hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10) if hashes: return hashes - return [str(row.get("hash", "") or "") for row in self.metadata_store.get_relations(subject=token)[:10] if str(row.get("hash", "")).strip()] + return [ + str(row.get("hash", "") or "") + for row in self.metadata_store.get_relations(subject=token)[:10] + if str(row.get("hash", "")).strip() + ] + + def _resolve_deleted_relation_hashes(self, target: str) -> List[str]: + assert self.metadata_store + token = str(target or "").strip() + if not token: + return [] + if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): + return [token] + return self.metadata_store.search_deleted_relation_hashes_by_text(token, limit=10) + + def _memory_v5_status(self, *, target: str = "", limit: int = 50) -> Dict[str, Any]: + assert self.metadata_store + now = time.time() + summary = self.metadata_store.get_memory_status_summary(now) + payload: Dict[str, Any] = { + "success": True, + **summary, + "config": { + "half_life_hours": float(self._cfg("memory.half_life_hours", 24.0) or 24.0), + "base_decay_interval_hours": 
float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0), + "prune_threshold": float(self._cfg("memory.prune_threshold", 0.1) or 0.1), + "freeze_duration_hours": float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0), + }, + "last_maintenance_at": self._last_maintenance_at, + } + token = str(target or "").strip() + if not token: + return payload + + active_hashes = self._resolve_relation_hashes(token)[:limit] + deleted_hashes = self._resolve_deleted_relation_hashes(token)[:limit] + active_statuses = self.metadata_store.get_relation_status_batch(active_hashes) + items: List[Dict[str, Any]] = [] + for hash_value in active_hashes: + relation = self.metadata_store.get_relation(hash_value) or {} + status = active_statuses.get(hash_value, {}) + items.append( + { + "hash": hash_value, + "subject": str(relation.get("subject", "") or ""), + "predicate": str(relation.get("predicate", "") or ""), + "object": str(relation.get("object", "") or ""), + "state": "inactive" if bool(status.get("is_inactive")) else "active", + "is_pinned": bool(status.get("is_pinned", False)), + "temp_protected": bool(float(status.get("protected_until") or 0.0) > now), + "protected_until": status.get("protected_until"), + "last_reinforced": status.get("last_reinforced"), + "weight": float(status.get("weight", relation.get("confidence", 0.0)) or 0.0), + } + ) + for hash_value in deleted_hashes: + relation = self.metadata_store.get_deleted_relation(hash_value) or {} + items.append( + { + "hash": hash_value, + "subject": str(relation.get("subject", "") or ""), + "predicate": str(relation.get("predicate", "") or ""), + "object": str(relation.get("object", "") or ""), + "state": "deleted", + "is_pinned": bool(relation.get("is_pinned", False)), + "temp_protected": False, + "protected_until": relation.get("protected_until"), + "last_reinforced": relation.get("last_reinforced"), + "weight": float(relation.get("confidence", 0.0) or 0.0), + "deleted_at": relation.get("deleted_at"), + } + ) + payload["items"] = items[:limit] + payload["count"] = len(payload["items"]) + payload["target"] = token + return payload + + def _adjust_relation_confidence(self, hashes: List[str], *, delta: float) -> Dict[str, float]: + assert self.metadata_store + normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] + if not normalized: + return {} + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + chunk_size = 200 + for index in range(0, len(normalized), chunk_size): + chunk = normalized[index : index + chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + cursor.execute( + f""" + UPDATE relations + SET confidence = MAX(0.0, COALESCE(confidence, 0.0) + ?) 
+ WHERE hash IN ({placeholders}) + """, + tuple([float(delta)] + chunk), + ) + conn.commit() + statuses = self.metadata_store.get_relation_status_batch(normalized) + return {hash_value: float((statuses.get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} + + def _apply_v5_relation_action(self, *, action: str, hashes: List[str], strength: float = 1.0) -> Dict[str, Any]: + assert self.metadata_store + act = str(action or "").strip().lower() + normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] + if not normalized: + return {"success": False, "error": "未命中可维护关系"} + + now = time.time() + strength_value = max(0.1, float(strength or 1.0)) + prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) + detail = "" + + if act == "reinforce": + weights = self._adjust_relation_confidence(normalized, delta=0.5 * strength_value) + protect_hours = max(1.0, 24.0 * strength_value) + self.metadata_store.reinforce_relations(normalized) + self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) + self.metadata_store.update_relations_protection( + normalized, + protected_until=now + protect_hours * 3600.0, + last_reinforced=now, + ) + detail = f"reinforce {len(normalized)} 条关系" + elif act == "weaken": + weights = self._adjust_relation_confidence(normalized, delta=-0.5 * strength_value) + to_freeze = [hash_value for hash_value, weight in weights.items() if weight <= prune_threshold] + if to_freeze: + self.metadata_store.mark_relations_inactive(to_freeze, inactive_since=now) + detail = f"weaken {len(normalized)} 条关系" + elif act == "remember_forever": + self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) + self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=True) + weights = {hash_value: float((self.metadata_store.get_relation_status_batch([hash_value]).get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} + detail = f"remember_forever {len(normalized)} 条关系" + elif act == "forget": + weights = self._adjust_relation_confidence(normalized, delta=-2.0 * strength_value) + self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=False) + self.metadata_store.mark_relations_inactive(normalized, inactive_since=now) + detail = f"forget {len(normalized)} 条关系" + else: + return {"success": False, "error": f"不支持的 V5 动作: {act}"} + + self._rebuild_graph_from_metadata() + self._last_maintenance_at = now + self._persist() + statuses = self.metadata_store.get_relation_status_batch(normalized) + return { + "success": True, + "detail": detail, + "hashes": normalized, + "count": len(normalized), + "weights": weights, + "statuses": statuses, + } + + async def _ensure_vector_for_text(self, *, item_hash: str, text: str) -> bool: + if self.vector_store is None or self.embedding_manager is None: + return False + token = str(item_hash or "").strip() + content = str(text or "").strip() + if not token or not content: + return False + embedding = await self.embedding_manager.encode([content], dimensions=self.embedding_dimension) + if getattr(embedding, "ndim", 1) == 1: + embedding = embedding.reshape(1, -1) + if getattr(embedding, "size", 0) <= 0: + return False + try: + self.vector_store.add(embedding, [token]) + return True + except Exception as exc: + logger.warning(f"重建向量失败: {exc}") + return False + + async def _ensure_relation_vector(self, relation: Dict[str, Any]) -> bool: + if 
not bool(self.relation_vectors_enabled): + return False + return await self._ensure_vector_for_text( + item_hash=str(relation.get("hash", "") or ""), + text=" ".join( + [ + str(relation.get("subject", "") or "").strip(), + str(relation.get("predicate", "") or "").strip(), + str(relation.get("object", "") or "").strip(), + ] + ).strip(), + ) + + async def _ensure_paragraph_vector(self, paragraph: Dict[str, Any]) -> bool: + return await self._ensure_vector_for_text( + item_hash=str(paragraph.get("hash", "") or ""), + text=str(paragraph.get("content", "") or ""), + ) + + async def _ensure_entity_vector(self, entity: Dict[str, Any]) -> bool: + return await self._ensure_vector_for_text( + item_hash=str(entity.get("hash", "") or ""), + text=str(entity.get("name", "") or ""), + ) + + async def _restore_relation_hashes( + self, + hashes: List[str], + *, + payloads: Optional[Dict[str, Dict[str, Any]]] = None, + rebuild_graph: bool = True, + persist: bool = True, + ) -> Dict[str, Any]: + assert self.metadata_store + restored: List[str] = [] + failures: List[Dict[str, str]] = [] + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + payload_map = payloads or {} + for hash_value in [str(item or "").strip() for item in hashes if str(item or "").strip()]: + relation = self.metadata_store.restore_relation(hash_value) + if relation is None: + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + failures.append({"hash": hash_value, "error": "relation 不存在"}) + continue + payload = payload_map.get(hash_value) if isinstance(payload_map.get(hash_value), dict) else {} + paragraph_hashes = self._tokens(payload.get("paragraph_hashes")) + for paragraph_hash in paragraph_hashes: + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) + VALUES (?, ?) 
+ """, + (paragraph_hash, hash_value), + ) + await self._ensure_relation_vector({**relation, "hash": hash_value}) + restored.append(hash_value) + conn.commit() + if restored and rebuild_graph: + self._rebuild_graph_from_metadata() + if restored and persist: + self._persist() + return {"restored_hashes": restored, "restored_count": len(restored), "failures": failures} + + @staticmethod + def _selector_dict(selector: Any) -> Dict[str, Any]: + if isinstance(selector, dict): + return dict(selector) + if isinstance(selector, (list, tuple)): + return {"items": list(selector)} + token = str(selector or "").strip() + return {"query": token} if token else {} + + def _resolve_paragraph_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: + assert self.metadata_store + raw = self._selector_dict(selector) + rows: List[Dict[str, Any]] = [] + hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) + for hash_value in hashes: + row = self.metadata_store.get_paragraph(hash_value) + if row is None: + continue + if not include_deleted and bool(row.get("is_deleted", 0)): + continue + rows.append(row) + if rows: + return rows + query = str(raw.get("query", "") or raw.get("content", "") or "").strip() + if not query: + return [] + if len(query) == 64 and all(ch in "0123456789abcdef" for ch in query.lower()): + row = self.metadata_store.get_paragraph(query) + if row is None: + return [] + if not include_deleted and bool(row.get("is_deleted", 0)): + return [] + return [row] + matches = self.metadata_store.search_paragraphs_by_content(query) + return [row for row in matches if include_deleted or not bool(row.get("is_deleted", 0))] + + def _resolve_entity_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: + assert self.metadata_store + raw = self._selector_dict(selector) + rows: List[Dict[str, Any]] = [] + hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) + for hash_value in hashes: + row = self.metadata_store.get_entity(hash_value) + if row is None: + continue + if not include_deleted and bool(row.get("is_deleted", 0)): + continue + rows.append(row) + names = self._merge_tokens(raw.get("names"), [raw.get("name")], [raw.get("query")]) + for name in names: + if not name: + continue + matches = self.metadata_store.query( + """ + SELECT * + FROM entities + WHERE LOWER(TRIM(name)) = LOWER(TRIM(?)) + OR hash = ? 
+ ORDER BY appearance_count DESC, created_at ASC + """, + (name, compute_hash(str(name).strip().lower())), + ) + for row in matches: + if not include_deleted and bool(row.get("is_deleted", 0)): + continue + rows.append(self.metadata_store._row_to_dict(row, "entity") if hasattr(self.metadata_store, "_row_to_dict") else row) + dedup: Dict[str, Dict[str, Any]] = {} + for row in rows: + token = str(row.get("hash", "") or "").strip() + if token and token not in dedup: + dedup[token] = row + return list(dedup.values()) + + def _resolve_source_targets(self, selector: Any) -> List[str]: + raw = self._selector_dict(selector) + return self._merge_tokens(raw.get("sources"), [raw.get("source")], [raw.get("query")], raw.get("items")) + + def _snapshot_relation_item(self, hash_value: str) -> Optional[Dict[str, Any]]: + assert self.metadata_store + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + relation = self.metadata_store.get_deleted_relation(hash_value) + if relation is None: + return None + paragraph_hashes = [ + str(row.get("paragraph_hash", "") or "").strip() + for row in self.metadata_store.query( + "SELECT paragraph_hash FROM paragraph_relations WHERE relation_hash = ? ORDER BY paragraph_hash ASC", + (hash_value,), + ) + if str(row.get("paragraph_hash", "") or "").strip() + ] + return { + "item_type": "relation", + "item_hash": hash_value, + "item_key": hash_value, + "payload": { + "relation": relation, + "paragraph_hashes": paragraph_hashes, + }, + } + + def _snapshot_paragraph_item(self, hash_value: str) -> Optional[Dict[str, Any]]: + assert self.metadata_store + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + return None + entity_links = [ + { + "paragraph_hash": hash_value, + "entity_hash": str(row.get("entity_hash", "") or ""), + "mention_count": int(row.get("mention_count", 1) or 1), + } + for row in self.metadata_store.query( + """ + SELECT paragraph_hash, entity_hash, mention_count + FROM paragraph_entities + WHERE paragraph_hash = ? + ORDER BY entity_hash ASC + """, + (hash_value,), + ) + ] + relation_hashes = [ + str(row.get("relation_hash", "") or "").strip() + for row in self.metadata_store.query( + """ + SELECT relation_hash + FROM paragraph_relations + WHERE paragraph_hash = ? + ORDER BY relation_hash ASC + """, + (hash_value,), + ) + if str(row.get("relation_hash", "") or "").strip() + ] + return { + "item_type": "paragraph", + "item_hash": hash_value, + "item_key": hash_value, + "payload": { + "paragraph": paragraph, + "entity_links": entity_links, + "relation_hashes": relation_hashes, + "external_refs": self.metadata_store.list_external_memory_refs_by_paragraphs([hash_value]), + }, + } + + def _snapshot_entity_item(self, hash_value: str) -> Optional[Dict[str, Any]]: + assert self.metadata_store + entity = self.metadata_store.get_entity(hash_value) + if entity is None: + return None + paragraph_links = [ + { + "paragraph_hash": str(row.get("paragraph_hash", "") or ""), + "entity_hash": hash_value, + "mention_count": int(row.get("mention_count", 1) or 1), + } + for row in self.metadata_store.query( + """ + SELECT paragraph_hash, mention_count + FROM paragraph_entities + WHERE entity_hash = ? 
+ ORDER BY paragraph_hash ASC + """, + (hash_value,), + ) + ] + return { + "item_type": "entity", + "item_hash": hash_value, + "item_key": hash_value, + "payload": { + "entity": entity, + "paragraph_links": paragraph_links, + }, + } + + def _relation_has_remaining_paragraphs(self, relation_hash: str, removing_hashes: Sequence[str]) -> bool: + assert self.metadata_store + excluded = [str(item or "").strip() for item in removing_hashes if str(item or "").strip()] + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + if excluded: + placeholders = ",".join(["?"] * len(excluded)) + cursor.execute( + f""" + SELECT 1 + FROM paragraph_relations pr + JOIN paragraphs p ON p.hash = pr.paragraph_hash + WHERE pr.relation_hash = ? + AND pr.paragraph_hash NOT IN ({placeholders}) + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + LIMIT 1 + """, + tuple([relation_hash] + excluded), + ) + else: + cursor.execute( + """ + SELECT 1 + FROM paragraph_relations pr + JOIN paragraphs p ON p.hash = pr.paragraph_hash + WHERE pr.relation_hash = ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + LIMIT 1 + """, + (relation_hash,), + ) + return cursor.fetchone() is not None + + async def _build_delete_plan(self, *, mode: str, selector: Any) -> Dict[str, Any]: + assert self.metadata_store + act_mode = str(mode or "").strip().lower() + normalized_selector = self._selector_dict(selector) + items: List[Dict[str, Any]] = [] + counts = {"relations": 0, "paragraphs": 0, "entities": 0, "sources": 0} + vector_ids: List[str] = [] + sources: List[str] = [] + target_hashes: Dict[str, List[str]] = {"relations": [], "paragraphs": [], "entities": [], "sources": []} + + if act_mode == "relation": + relation_rows = [row for row in (self.metadata_store.get_relation(hash_value) for hash_value in self._resolve_relation_hashes(str(normalized_selector.get("query", "") or ""))) if row] + if normalized_selector.get("hashes"): + relation_rows = [ + row + for hash_value in self._tokens(normalized_selector.get("hashes")) + for row in [self.metadata_store.get_relation(hash_value)] + if row is not None + ] + dedup_hashes: List[str] = [] + seen = set() + for row in relation_rows: + hash_value = str(row.get("hash", "") or "").strip() + if hash_value and hash_value not in seen: + seen.add(hash_value) + dedup_hashes.append(hash_value) + snap = self._snapshot_relation_item(hash_value) + if snap: + items.append(snap) + vector_ids.append(hash_value) + counts["relations"] = len(dedup_hashes) + target_hashes["relations"] = dedup_hashes + + elif act_mode in {"paragraph", "source"}: + paragraph_rows: List[Dict[str, Any]] = [] + if act_mode == "source": + source_tokens = self._resolve_source_targets(normalized_selector) + target_hashes["sources"] = source_tokens + counts["sources"] = len(source_tokens) + for source in source_tokens: + sources.append(source) + paragraph_rows.extend( + self.metadata_store.query( + """ + SELECT * + FROM paragraphs + WHERE source = ? 
+ AND (is_deleted IS NULL OR is_deleted = 0) + ORDER BY created_at ASC + """, + (source,), + ) + ) + else: + paragraph_rows = self._resolve_paragraph_targets(normalized_selector, include_deleted=False) + paragraph_hashes = self._tokens([row.get("hash", "") for row in paragraph_rows]) + target_hashes["paragraphs"] = paragraph_hashes + counts["paragraphs"] = len(paragraph_hashes) + for hash_value in paragraph_hashes: + snap = self._snapshot_paragraph_item(hash_value) + if snap: + items.append(snap) + vector_ids.append(hash_value) + paragraph = snap["payload"].get("paragraph") or {} + source = str(paragraph.get("source", "") or "").strip() + if source: + sources.append(source) + + orphan_relations: List[str] = [] + for item in items: + if item.get("item_type") != "paragraph": + continue + for relation_hash in self._tokens((item.get("payload") or {}).get("relation_hashes")): + if relation_hash in orphan_relations: + continue + if not self._relation_has_remaining_paragraphs(relation_hash, paragraph_hashes): + orphan_relations.append(relation_hash) + for relation_hash in orphan_relations: + snap = self._snapshot_relation_item(relation_hash) + if snap: + items.append(snap) + vector_ids.append(relation_hash) + target_hashes["relations"] = orphan_relations + counts["relations"] = len(orphan_relations) + + elif act_mode == "entity": + entity_rows = self._resolve_entity_targets(normalized_selector, include_deleted=False) + entity_hashes = self._tokens([row.get("hash", "") for row in entity_rows]) + target_hashes["entities"] = entity_hashes + counts["entities"] = len(entity_hashes) + entity_names = [str(row.get("name", "") or "").strip() for row in entity_rows if str(row.get("name", "") or "").strip()] + for hash_value in entity_hashes: + snap = self._snapshot_entity_item(hash_value) + if snap: + items.append(snap) + vector_ids.append(hash_value) + relation_hashes: List[str] = [] + for entity_name in entity_names: + for relation in self.metadata_store.get_relations(subject=entity_name) + self.metadata_store.get_relations(object=entity_name): + hash_value = str(relation.get("hash", "") or "").strip() + if hash_value and hash_value not in relation_hashes: + relation_hashes.append(hash_value) + for relation_hash in relation_hashes: + snap = self._snapshot_relation_item(relation_hash) + if snap: + items.append(snap) + vector_ids.append(relation_hash) + target_hashes["relations"] = relation_hashes + counts["relations"] = len(relation_hashes) + else: + return {"success": False, "error": f"不支持的 delete mode: {act_mode}"} + + sources = self._tokens(sources) + vector_ids = self._tokens(vector_ids) + primary_count = counts.get(f"{act_mode}s", 0) if act_mode != "source" else counts.get("sources", 0) + return { + "success": primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0, + "mode": act_mode, + "selector": normalized_selector, + "items": items, + "counts": counts, + "vector_ids": vector_ids, + "sources": sources, + "target_hashes": target_hashes, + "error": "" if (primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0) else "未命中可删除内容", + } + + async def _preview_delete_action(self, *, mode: str, selector: Any) -> Dict[str, Any]: + plan = await self._build_delete_plan(mode=mode, selector=selector) + if not plan.get("success", False): + return {"success": False, "error": plan.get("error", "未命中可删除内容")} + preview_items = [ + { + "item_type": str(item.get("item_type", "") or ""), + "item_hash": str(item.get("item_hash", "") or ""), + } + for item in 
plan.get("items", [])[:100] + ] + return { + "success": True, + "mode": plan.get("mode"), + "selector": plan.get("selector"), + "counts": plan.get("counts", {}), + "sources": plan.get("sources", []), + "vector_ids": plan.get("vector_ids", []), + "items": preview_items, + "item_count": len(plan.get("items", [])), + "dry_run": True, + } + + async def _execute_delete_action( + self, + *, + mode: str, + selector: Any, + requested_by: str = "", + reason: str = "", + ) -> Dict[str, Any]: + assert self.metadata_store + plan = await self._build_delete_plan(mode=mode, selector=selector) + if not plan.get("success", False): + return {"success": False, "error": plan.get("error", "未命中可删除内容")} + + act_mode = str(plan.get("mode", "") or "").strip().lower() + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + paragraph_hashes = self._tokens((plan.get("target_hashes") or {}).get("paragraphs")) + entity_hashes = self._tokens((plan.get("target_hashes") or {}).get("entities")) + relation_hashes = self._tokens((plan.get("target_hashes") or {}).get("relations")) + source_tokens = self._tokens((plan.get("target_hashes") or {}).get("sources")) + + try: + if paragraph_hashes: + self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") + cursor.execute( + f"DELETE FROM paragraph_entities WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", + tuple(paragraph_hashes), + ) + cursor.execute( + f"DELETE FROM paragraph_relations WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", + tuple(paragraph_hashes), + ) + self.metadata_store.delete_external_memory_refs_by_paragraphs(paragraph_hashes) + if act_mode == "source" and source_tokens: + for source in source_tokens: + self.metadata_store.replace_episodes_for_source(source, []) + + if entity_hashes: + self.metadata_store.mark_as_deleted(entity_hashes, "entity") + cursor.execute( + f"DELETE FROM paragraph_entities WHERE entity_hash IN ({','.join(['?'] * len(entity_hashes))})", + tuple(entity_hashes), + ) + + conn.commit() + + deleted_relations = self.metadata_store.backup_and_delete_relations(relation_hashes) + deleted_vectors = 0 + if self.vector_store is not None and plan.get("vector_ids"): + deleted_vectors = self.vector_store.delete(list(plan.get("vector_ids") or [])) + + operation = self.metadata_store.create_delete_operation( + mode=act_mode, + selector=plan.get("selector"), + items=plan.get("items", []), + reason=reason, + requested_by=requested_by, + summary={ + "counts": plan.get("counts", {}), + "sources": plan.get("sources", []), + "vector_ids": plan.get("vector_ids", []), + "deleted_relation_rows": deleted_relations, + }, + ) + + if plan.get("sources"): + self.metadata_store._enqueue_episode_source_rebuilds(list(plan.get("sources") or []), reason="delete_admin_execute") + self._rebuild_graph_from_metadata() + self._persist() + deleted_count = ( + len(source_tokens) + if act_mode == "source" + else len(paragraph_hashes) + if act_mode == "paragraph" + else len(entity_hashes) + if act_mode == "entity" + else len(relation_hashes) + ) + result = { + "success": True, + "mode": act_mode, + "operation_id": operation.get("operation_id", ""), + "counts": plan.get("counts", {}), + "sources": plan.get("sources", []), + "deleted_count": deleted_count, + "deleted_vector_count": int(deleted_vectors or 0), + "deleted_relation_count": len(relation_hashes), + } + if act_mode == "source": + result["deleted_source_count"] = len(source_tokens) + result["deleted_paragraph_count"] = len(paragraph_hashes) + return result 
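+ # On failure, the handler below rolls back the pending SQLite transaction, logs a warning,
+ # and reports the error to the caller instead of raising.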
+ except Exception as exc: + conn.rollback() + logger.warning(f"delete_admin execute 失败: {exc}") + return {"success": False, "error": str(exc)} + + async def _restore_delete_action( + self, + *, + mode: str, + selector: Any, + operation_id: str = "", + requested_by: str = "", + reason: str = "", + ) -> Dict[str, Any]: + del requested_by + del reason + assert self.metadata_store + + op_id = str(operation_id or "").strip() + if op_id: + operation = self.metadata_store.get_delete_operation(op_id) + if operation is None: + return {"success": False, "error": "operation 不存在"} + return await self._restore_delete_operation(operation) + + act_mode = str(mode or "").strip().lower() + if act_mode != "relation": + return {"success": False, "error": "paragraph/entity/source 恢复必须提供 operation_id"} + + raw = self._selector_dict(selector) + target = str(raw.get("query", "") or raw.get("target", "") or raw.get("hash", "") or "").strip() + hashes = self._resolve_deleted_relation_hashes(target) + if not hashes: + return {"success": False, "error": "未命中可恢复关系"} + result = await self._restore_relation_hashes(hashes) + return {"success": bool(result.get("restored_count", 0) > 0), **result} + + async def _restore_delete_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]: + assert self.metadata_store + items = operation.get("items") if isinstance(operation.get("items"), list) else [] + entity_payloads: Dict[str, Dict[str, Any]] = {} + paragraph_payloads: Dict[str, Dict[str, Any]] = {} + relation_payloads: Dict[str, Dict[str, Any]] = {} + for item in items: + if not isinstance(item, dict): + continue + item_type = str(item.get("item_type", "") or "").strip() + item_hash = str(item.get("item_hash", "") or "").strip() + payload = item.get("payload") if isinstance(item.get("payload"), dict) else {} + if item_type == "entity" and item_hash: + entity_payloads[item_hash] = payload + elif item_type == "paragraph" and item_hash: + paragraph_payloads[item_hash] = payload + elif item_type == "relation" and item_hash: + relation_payloads[item_hash] = payload + + restored_entities: List[str] = [] + restored_paragraphs: List[str] = [] + for hash_value, payload in entity_payloads.items(): + entity_row = payload.get("entity") if isinstance(payload.get("entity"), dict) else {} + if entity_row: + self.metadata_store.restore_entity_by_hash(hash_value) + await self._ensure_entity_vector(entity_row) + restored_entities.append(hash_value) + for hash_value, payload in paragraph_payloads.items(): + paragraph_row = payload.get("paragraph") if isinstance(payload.get("paragraph"), dict) else {} + if paragraph_row: + self.metadata_store.restore_paragraph_by_hash(hash_value) + await self._ensure_paragraph_vector(paragraph_row) + restored_paragraphs.append(hash_value) + + restored_relations = await self._restore_relation_hashes(list(relation_payloads.keys()), payloads=relation_payloads, rebuild_graph=False, persist=False) + + conn = self.metadata_store.get_connection() + cursor = conn.cursor() + for payload in entity_payloads.values(): + for link in payload.get("paragraph_links") or []: + paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() + entity_hash = str(link.get("entity_hash", "") or "").strip() + mention_count = max(1, int(link.get("mention_count", 1) or 1)) + if not paragraph_hash or not entity_hash: + continue + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) 
+ """, + (paragraph_hash, entity_hash, mention_count), + ) + for payload in paragraph_payloads.values(): + for link in payload.get("entity_links") or []: + paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() + entity_hash = str(link.get("entity_hash", "") or "").strip() + mention_count = max(1, int(link.get("mention_count", 1) or 1)) + if not paragraph_hash or not entity_hash: + continue + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + """, + (paragraph_hash, entity_hash, mention_count), + ) + for relation_hash in self._tokens(payload.get("relation_hashes")): + paragraph_hash = str((payload.get("paragraph") or {}).get("hash", "") or "").strip() + if not paragraph_hash or not relation_hash: + continue + cursor.execute( + """ + INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) + VALUES (?, ?) + """, + (paragraph_hash, relation_hash), + ) + self.metadata_store.restore_external_memory_refs(list(payload.get("external_refs") or [])) + conn.commit() + + sources = self._tokens( + [ + str(((payload.get("paragraph") or {}).get("source", "") or "")).strip() + for payload in paragraph_payloads.values() + ] + ) + if sources: + self.metadata_store._enqueue_episode_source_rebuilds(sources, reason="delete_admin_restore") + self._rebuild_graph_from_metadata() + self._persist() + summary = { + "restored_entities": restored_entities, + "restored_paragraphs": restored_paragraphs, + "restored_relations": restored_relations.get("restored_hashes", []), + "sources": sources, + } + self.metadata_store.mark_delete_operation_restored(str(operation.get("operation_id", "") or ""), summary=summary) + return { + "success": True, + "operation_id": str(operation.get("operation_id", "") or ""), + **summary, + "restored_relation_count": restored_relations.get("restored_count", 0), + "relation_failures": restored_relations.get("failures", []), + } + + async def _purge_deleted_memory(self, *, grace_hours: Optional[float], limit: int) -> Dict[str, Any]: + assert self.metadata_store + orphan_cfg = self._cfg("memory.orphan", {}) or {} + grace = float(grace_hours) if grace_hours is not None else max( + 1.0, + float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0), + ) + cutoff = time.time() - grace * 3600.0 + deleted_relation_hashes = self.metadata_store.purge_deleted_relations(cutoff_time=cutoff, limit=limit) + dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace * 3600.0) + paragraph_hashes = [str(item[0] or "").strip() for item in dead_paragraphs if str(item[0] or "").strip()] + dead_entities = self.metadata_store.sweep_deleted_items("entity", grace * 3600.0) + entity_hashes = [str(item[0] or "").strip() for item in dead_entities if str(item[0] or "").strip()] + entity_names = [str(item[1] or "").strip() for item in dead_entities if str(item[1] or "").strip()] + + if paragraph_hashes: + self.metadata_store.physically_delete_paragraphs(paragraph_hashes) + if entity_hashes: + self.metadata_store.physically_delete_entities(entity_hashes) + if entity_names: + self.graph_store.delete_nodes(entity_names) + if self.vector_store is not None: + vector_ids = self._merge_tokens(paragraph_hashes, entity_hashes, deleted_relation_hashes) + if vector_ids: + self.vector_store.delete(vector_ids) + self._rebuild_graph_from_metadata() + self._persist() + return { + "success": True, + "grace_hours": grace, + "purged_deleted_relations": deleted_relation_hashes, + "purged_paragraph_hashes": 
paragraph_hashes, + "purged_entity_hashes": entity_hashes, + "purged_counts": { + "relations": len(deleted_relation_hashes), + "paragraphs": len(paragraph_hashes), + "entities": len(entity_hashes), + }, + } + + @staticmethod + def _optional_float(value: Any) -> Optional[float]: + if value in {None, ""}: + return None + try: + return float(value) + except Exception: + return None + + @staticmethod + def _optional_int(value: Any) -> Optional[int]: + if value in {None, ""}: + return None + try: + return int(value) + except Exception: + return None diff --git a/plugins/A_memorix/core/runtime/search_runtime_initializer.py b/plugins/A_memorix/core/runtime/search_runtime_initializer.py new file mode 100644 index 00000000..c3c7a81f --- /dev/null +++ b/plugins/A_memorix/core/runtime/search_runtime_initializer.py @@ -0,0 +1,240 @@ +"""Shared runtime initializer for Action/Tool/Command retrieval components.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from src.common.logger import get_logger + +from ..retrieval import ( + DualPathRetriever, + DualPathRetrieverConfig, + DynamicThresholdFilter, + FusionConfig, + GraphRelationRecallConfig, + RelationIntentConfig, + RetrievalStrategy, + SparseBM25Config, + ThresholdConfig, + ThresholdMethod, +) + +_logger = get_logger("A_Memorix.SearchRuntimeInitializer") + +_REQUIRED_COMPONENT_KEYS = ( + "vector_store", + "graph_store", + "metadata_store", + "embedding_manager", +) + + +def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: + if not isinstance(config, dict): + return default + current: Any = config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + +def _safe_dict(value: Any) -> Dict[str, Any]: + return value if isinstance(value, dict) else {} + + +def _resolve_debug_enabled(plugin_config: Optional[dict]) -> bool: + advanced = _get_config_value(plugin_config, "advanced", {}) + if isinstance(advanced, dict): + return bool(advanced.get("debug", False)) + return bool(_get_config_value(plugin_config, "debug", False)) + + +@dataclass +class SearchRuntimeBundle: + """Resolved runtime components and initialized retriever/filter.""" + + vector_store: Optional[Any] = None + graph_store: Optional[Any] = None + metadata_store: Optional[Any] = None + embedding_manager: Optional[Any] = None + sparse_index: Optional[Any] = None + retriever: Optional[DualPathRetriever] = None + threshold_filter: Optional[DynamicThresholdFilter] = None + error: str = "" + + @property + def ready(self) -> bool: + return ( + self.retriever is not None + and self.vector_store is not None + and self.graph_store is not None + and self.metadata_store is not None + and self.embedding_manager is not None + ) + + +def _resolve_runtime_components(plugin_config: Optional[dict]) -> SearchRuntimeBundle: + bundle = SearchRuntimeBundle( + vector_store=_get_config_value(plugin_config, "vector_store"), + graph_store=_get_config_value(plugin_config, "graph_store"), + metadata_store=_get_config_value(plugin_config, "metadata_store"), + embedding_manager=_get_config_value(plugin_config, "embedding_manager"), + sparse_index=_get_config_value(plugin_config, "sparse_index"), + ) + + missing_required = any( + getattr(bundle, key) is None for key in _REQUIRED_COMPONENT_KEYS + ) + if not missing_required: + return bundle + + try: + from ...plugin import AMemorixPlugin + + instances = 
AMemorixPlugin.get_storage_instances() + except Exception: + instances = {} + + if not isinstance(instances, dict) or not instances: + return bundle + + if bundle.vector_store is None: + bundle.vector_store = instances.get("vector_store") + if bundle.graph_store is None: + bundle.graph_store = instances.get("graph_store") + if bundle.metadata_store is None: + bundle.metadata_store = instances.get("metadata_store") + if bundle.embedding_manager is None: + bundle.embedding_manager = instances.get("embedding_manager") + if bundle.sparse_index is None: + bundle.sparse_index = instances.get("sparse_index") + return bundle + + +def build_search_runtime( + plugin_config: Optional[dict], + logger_obj: Optional[Any], + owner_tag: str, + *, + log_prefix: str = "", +) -> SearchRuntimeBundle: + """Build retriever + threshold filter with unified fallback/config parsing.""" + + log = logger_obj or _logger + owner = str(owner_tag or "runtime").strip().lower() or "runtime" + prefix = str(log_prefix or "").strip() + prefix_text = f"{prefix} " if prefix else "" + + runtime = _resolve_runtime_components(plugin_config) + if any(getattr(runtime, key) is None for key in _REQUIRED_COMPONENT_KEYS): + runtime.error = "存储组件未完全初始化" + log.warning(f"{prefix_text}[{owner}] 存储组件未完全初始化,无法使用检索功能") + return runtime + + sparse_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.sparse", {}) or {}) + fusion_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.fusion", {}) or {}) + relation_intent_cfg_raw = _safe_dict( + _get_config_value(plugin_config, "retrieval.search.relation_intent", {}) or {} + ) + graph_recall_cfg_raw = _safe_dict( + _get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {} + ) + + try: + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] sparse 配置非法,回退默认: {e}") + sparse_cfg = SparseBM25Config() + + try: + fusion_cfg = FusionConfig(**fusion_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] fusion 配置非法,回退默认: {e}") + fusion_cfg = FusionConfig() + + try: + relation_intent_cfg = RelationIntentConfig(**relation_intent_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] relation_intent 配置非法,回退默认: {e}") + relation_intent_cfg = RelationIntentConfig() + + try: + graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) + except Exception as e: + log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}") + graph_recall_cfg = GraphRelationRecallConfig() + + try: + config = DualPathRetrieverConfig( + top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20), + top_k_relations=_get_config_value(plugin_config, "retrieval.top_k_relations", 10), + top_k_final=_get_config_value(plugin_config, "retrieval.top_k_final", 10), + alpha=_get_config_value(plugin_config, "retrieval.alpha", 0.5), + enable_ppr=_get_config_value(plugin_config, "retrieval.enable_ppr", True), + ppr_alpha=_get_config_value(plugin_config, "retrieval.ppr_alpha", 0.85), + ppr_timeout_seconds=_get_config_value( + plugin_config, "retrieval.ppr_timeout_seconds", 1.5 + ), + ppr_concurrency_limit=_get_config_value( + plugin_config, "retrieval.ppr_concurrency_limit", 4 + ), + enable_parallel=_get_config_value(plugin_config, "retrieval.enable_parallel", True), + retrieval_strategy=RetrievalStrategy.DUAL_PATH, + debug=_resolve_debug_enabled(plugin_config), + sparse=sparse_cfg, + fusion=fusion_cfg, + relation_intent=relation_intent_cfg, + graph_recall=graph_recall_cfg, + ) 
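+        # Every knob above falls back to the default given in the call when it is
+        # missing from plugin_config, so once the four required storage components
+        # resolve, the retriever below can always be constructed. Illustrative call
+        # site (the `retrieve` coroutine name is an assumption, not defined here):
+        #     bundle = build_search_runtime(plugin_config, logger, "search_action")
+        #     if bundle.ready:
+        #         results = await bundle.retriever.retrieve(query, top_k=10)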
+ + runtime.retriever = DualPathRetriever( + vector_store=runtime.vector_store, + graph_store=runtime.graph_store, + metadata_store=runtime.metadata_store, + embedding_manager=runtime.embedding_manager, + sparse_index=runtime.sparse_index, + config=config, + ) + + threshold_config = ThresholdConfig( + method=ThresholdMethod.ADAPTIVE, + min_threshold=_get_config_value(plugin_config, "threshold.min_threshold", 0.3), + max_threshold=_get_config_value(plugin_config, "threshold.max_threshold", 0.95), + percentile=_get_config_value(plugin_config, "threshold.percentile", 75.0), + std_multiplier=_get_config_value(plugin_config, "threshold.std_multiplier", 1.5), + min_results=_get_config_value(plugin_config, "threshold.min_results", 3), + enable_auto_adjust=_get_config_value(plugin_config, "threshold.enable_auto_adjust", True), + ) + runtime.threshold_filter = DynamicThresholdFilter(threshold_config) + runtime.error = "" + log.info(f"{prefix_text}[{owner}] 检索运行时初始化完成") + except Exception as e: + runtime.retriever = None + runtime.threshold_filter = None + runtime.error = str(e) + log.error(f"{prefix_text}[{owner}] 检索运行时初始化失败: {e}") + + return runtime + + +class SearchRuntimeInitializer: + """Compatibility wrapper around the function style initializer.""" + + @staticmethod + def build_search_runtime( + plugin_config: Optional[dict], + logger_obj: Optional[Any], + owner_tag: str, + *, + log_prefix: str = "", + ) -> SearchRuntimeBundle: + return build_search_runtime( + plugin_config=plugin_config, + logger_obj=logger_obj, + owner_tag=owner_tag, + log_prefix=log_prefix, + ) diff --git a/plugins/A_memorix/core/storage/graph_store.py b/plugins/A_memorix/core/storage/graph_store.py index 8a075864..0a5fd95d 100644 --- a/plugins/A_memorix/core/storage/graph_store.py +++ b/plugins/A_memorix/core/storage/graph_store.py @@ -24,6 +24,20 @@ try: from scipy.sparse.linalg import norm HAS_SCIPY = True except ImportError: + class _SparseMatrixPlaceholder: + pass + + def _scipy_missing(*args, **kwargs): + raise ImportError("SciPy 未安装,请安装: pip install scipy") + + csr_matrix = _SparseMatrixPlaceholder + csc_matrix = _SparseMatrixPlaceholder + lil_matrix = _SparseMatrixPlaceholder + triu = _scipy_missing + save_npz = _scipy_missing + load_npz = _scipy_missing + bmat = _scipy_missing + norm = _scipy_missing HAS_SCIPY = False import contextlib diff --git a/plugins/A_memorix/core/storage/metadata_store.py b/plugins/A_memorix/core/storage/metadata_store.py index e94610f0..39f2701c 100644 --- a/plugins/A_memorix/core/storage/metadata_store.py +++ b/plugins/A_memorix/core/storage/metadata_store.py @@ -7,6 +7,8 @@ import sqlite3 import pickle import json +import uuid +import re from datetime import datetime from pathlib import Path from typing import Optional, Union, List, Dict, Any, Tuple @@ -24,7 +26,7 @@ from .knowledge_types import ( logger = get_logger("A_Memorix.MetadataStore") -SCHEMA_VERSION = 7 +SCHEMA_VERSION = 8 class MetadataStore: @@ -500,6 +502,63 @@ class MetadataStore: CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph ON external_memory_refs(paragraph_hash) """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_v5_operations ( + operation_id TEXT PRIMARY KEY, + action TEXT NOT NULL, + target TEXT, + reason TEXT, + updated_by TEXT, + created_at REAL NOT NULL, + resolved_hashes_json TEXT, + result_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created + ON memory_v5_operations(created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT 
EXISTS delete_operations ( + operation_id TEXT PRIMARY KEY, + mode TEXT NOT NULL, + selector TEXT, + reason TEXT, + requested_by TEXT, + status TEXT NOT NULL, + created_at REAL NOT NULL, + restored_at REAL, + summary_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_created + ON delete_operations(created_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_mode + ON delete_operations(mode, created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS delete_operation_items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + operation_id TEXT NOT NULL, + item_type TEXT NOT NULL, + item_hash TEXT, + item_key TEXT, + payload_json TEXT, + created_at REAL NOT NULL, + FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation + ON delete_operation_items(operation_id, id ASC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash + ON delete_operation_items(item_hash) + """) # 新版 schema 包含完整字段,直接写入版本信息 cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp())) self._conn.commit() @@ -618,6 +677,63 @@ class MetadataStore: CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph ON external_memory_refs(paragraph_hash) """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS memory_v5_operations ( + operation_id TEXT PRIMARY KEY, + action TEXT NOT NULL, + target TEXT, + reason TEXT, + updated_by TEXT, + created_at REAL NOT NULL, + resolved_hashes_json TEXT, + result_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created + ON memory_v5_operations(created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS delete_operations ( + operation_id TEXT PRIMARY KEY, + mode TEXT NOT NULL, + selector TEXT, + reason TEXT, + requested_by TEXT, + status TEXT NOT NULL, + created_at REAL NOT NULL, + restored_at REAL, + summary_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_created + ON delete_operations(created_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operations_mode + ON delete_operations(mode, created_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS delete_operation_items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + operation_id TEXT NOT NULL, + item_type TEXT NOT NULL, + item_hash TEXT, + item_key TEXT, + payload_json TEXT, + created_at REAL NOT NULL, + FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation + ON delete_operation_items(operation_id, id ASC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash + ON delete_operation_items(item_hash) + """) # 检查paragraphs表是否有knowledge_type列 cursor.execute("PRAGMA table_info(paragraphs)") @@ -2595,6 +2711,328 @@ class MetadataStore: "metadata": metadata or {}, } + @staticmethod + def _json_dumps(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, sort_keys=True) + + @staticmethod + def _json_loads(value: Any, default: Any) -> Any: + if value in {None, ""}: + return default + try: + return json.loads(value) + except Exception: + return default + + def list_external_memory_refs_by_paragraphs(self, 
paragraph_hashes: List[str]) -> List[Dict[str, Any]]: + hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] + if not hashes: + return [] + placeholders = ",".join(["?"] * len(hashes)) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT external_id, paragraph_hash, source_type, created_at, metadata_json + FROM external_memory_refs + WHERE paragraph_hash IN ({placeholders}) + ORDER BY created_at ASC, external_id ASC + """, + tuple(hashes), + ) + items: List[Dict[str, Any]] = [] + for row in cursor.fetchall(): + payload = dict(row) + payload["metadata"] = self._json_loads(payload.get("metadata_json"), {}) + items.append(payload) + return items + + def delete_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]: + items = self.list_external_memory_refs_by_paragraphs(paragraph_hashes) + hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] + if not hashes: + return items + placeholders = ",".join(["?"] * len(hashes)) + cursor = self._conn.cursor() + cursor.execute( + f"DELETE FROM external_memory_refs WHERE paragraph_hash IN ({placeholders})", + tuple(hashes), + ) + self._conn.commit() + return items + + def restore_external_memory_refs(self, refs: List[Dict[str, Any]]) -> int: + count = 0 + for item in refs or []: + external_id = str(item.get("external_id", "") or "").strip() + paragraph_hash = str(item.get("paragraph_hash", "") or "").strip() + if not external_id or not paragraph_hash: + continue + created_at = float(item.get("created_at") or datetime.now().timestamp()) + metadata_json = self._json_dumps(item.get("metadata") or {}) + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO external_memory_refs ( + external_id, paragraph_hash, source_type, created_at, metadata_json + ) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(external_id) DO UPDATE SET + paragraph_hash = excluded.paragraph_hash, + source_type = excluded.source_type, + created_at = excluded.created_at, + metadata_json = excluded.metadata_json + """, + ( + external_id, + paragraph_hash, + str(item.get("source_type", "") or "").strip() or None, + created_at, + metadata_json, + ), + ) + count += max(0, int(cursor.rowcount or 0)) + self._conn.commit() + return count + + def record_v5_operation( + self, + *, + action: str, + target: str, + resolved_hashes: List[str], + reason: str = "", + updated_by: str = "", + result: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + operation_id = f"v5_{uuid.uuid4().hex}" + created_at = datetime.now().timestamp() + payload = { + "operation_id": operation_id, + "action": str(action or "").strip(), + "target": str(target or "").strip(), + "reason": str(reason or "").strip(), + "updated_by": str(updated_by or "").strip(), + "created_at": created_at, + "resolved_hashes": [str(item or "").strip() for item in (resolved_hashes or []) if str(item or "").strip()], + "result": result or {}, + } + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO memory_v5_operations ( + operation_id, action, target, reason, updated_by, created_at, resolved_hashes_json, result_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
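+                -- eight positional parameters, one per column listed above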
+ """, + ( + operation_id, + payload["action"], + payload["target"] or None, + payload["reason"] or None, + payload["updated_by"] or None, + created_at, + self._json_dumps(payload["resolved_hashes"]), + self._json_dumps(payload["result"]), + ), + ) + self._conn.commit() + return payload + + def create_delete_operation( + self, + *, + mode: str, + selector: Any, + items: List[Dict[str, Any]], + reason: str = "", + requested_by: str = "", + status: str = "executed", + summary: Optional[Dict[str, Any]] = None, + operation_id: Optional[str] = None, + ) -> Dict[str, Any]: + op_id = str(operation_id or f"del_{uuid.uuid4().hex}").strip() + created_at = datetime.now().timestamp() + normalized_items: List[Dict[str, Any]] = [] + for item in items or []: + if not isinstance(item, dict): + continue + item_type = str(item.get("item_type", "") or "").strip() + if not item_type: + continue + normalized_items.append( + { + "item_type": item_type, + "item_hash": str(item.get("item_hash", "") or "").strip() or None, + "item_key": str(item.get("item_key", "") or item.get("item_hash", "") or "").strip() or None, + "payload": item.get("payload") if isinstance(item.get("payload"), dict) else {}, + } + ) + + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO delete_operations ( + operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, ?) + """, + ( + op_id, + str(mode or "").strip(), + self._json_dumps(selector if selector is not None else {}), + str(reason or "").strip() or None, + str(requested_by or "").strip() or None, + str(status or "executed").strip(), + created_at, + self._json_dumps(summary or {}), + ), + ) + if normalized_items: + cursor.executemany( + """ + INSERT INTO delete_operation_items ( + operation_id, item_type, item_hash, item_key, payload_json, created_at + ) VALUES (?, ?, ?, ?, ?, ?) + """, + [ + ( + op_id, + item["item_type"], + item["item_hash"], + item["item_key"], + self._json_dumps(item["payload"]), + created_at, + ) + for item in normalized_items + ], + ) + self._conn.commit() + return self.get_delete_operation(op_id) or { + "operation_id": op_id, + "mode": str(mode or "").strip(), + "selector": selector, + "reason": str(reason or "").strip(), + "requested_by": str(requested_by or "").strip(), + "status": str(status or "executed").strip(), + "created_at": created_at, + "summary": summary or {}, + "items": normalized_items, + } + + def mark_delete_operation_restored( + self, + operation_id: str, + *, + summary: Optional[Dict[str, Any]] = None, + ) -> bool: + token = str(operation_id or "").strip() + if not token: + return False + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE delete_operations + SET status = ?, restored_at = ?, summary_json = ? + WHERE operation_id = ? + """, + ( + "restored", + datetime.now().timestamp(), + self._json_dumps(summary or {}), + token, + ), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def list_delete_operations(self, *, limit: int = 50, mode: str = "") -> List[Dict[str, Any]]: + cursor = self._conn.cursor() + params: List[Any] = [] + where = "" + mode_token = str(mode or "").strip().lower() + if mode_token: + where = "WHERE LOWER(mode) = ?" + params.append(mode_token) + params.append(max(1, int(limit or 50))) + cursor.execute( + f""" + SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json + FROM delete_operations + {where} + ORDER BY created_at DESC + LIMIT ? 
+ """, + tuple(params), + ) + items: List[Dict[str, Any]] = [] + for row in cursor.fetchall(): + payload = dict(row) + payload["selector"] = self._json_loads(payload.get("selector"), {}) + payload["summary"] = self._json_loads(payload.get("summary_json"), {}) + items.append(payload) + return items + + def get_delete_operation(self, operation_id: str) -> Optional[Dict[str, Any]]: + token = str(operation_id or "").strip() + if not token: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json + FROM delete_operations + WHERE operation_id = ? + LIMIT 1 + """, + (token,), + ) + row = cursor.fetchone() + if row is None: + return None + + payload = dict(row) + payload["selector"] = self._json_loads(payload.get("selector"), {}) + payload["summary"] = self._json_loads(payload.get("summary_json"), {}) + + cursor.execute( + """ + SELECT item_type, item_hash, item_key, payload_json, created_at + FROM delete_operation_items + WHERE operation_id = ? + ORDER BY id ASC + """, + (token,), + ) + payload["items"] = [ + { + "item_type": str(item["item_type"] or ""), + "item_hash": str(item["item_hash"] or ""), + "item_key": str(item["item_key"] or ""), + "payload": self._json_loads(item["payload_json"], {}), + "created_at": item["created_at"], + } + for item in cursor.fetchall() + ] + return payload + + def purge_deleted_relations(self, *, cutoff_time: float, limit: int = 1000) -> List[str]: + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT hash + FROM deleted_relations + WHERE deleted_at IS NOT NULL AND deleted_at < ? + ORDER BY deleted_at ASC + LIMIT ? + """, + (float(cutoff_time), max(1, int(limit or 1000))), + ) + hashes = [str(row[0] or "").strip() for row in cursor.fetchall() if str(row[0] or "").strip()] + if not hashes: + return [] + placeholders = ",".join(["?"] * len(hashes)) + cursor.execute(f"DELETE FROM deleted_relations WHERE hash IN ({placeholders})", tuple(hashes)) + self._conn.commit() + return hashes + def get_statistics(self) -> Dict[str, int]: """ 获取统计信息 @@ -2956,6 +3394,18 @@ class MetadataStore: self._conn.commit() return changed + def restore_paragraph_by_hash(self, paragraph_hash: str) -> bool: + """恢复软删除段落。""" + cursor = self._conn.cursor() + cursor.execute( + "UPDATE paragraphs SET is_deleted=0, deleted_at=NULL WHERE hash=?", + (str(paragraph_hash),), + ) + changed = cursor.rowcount > 0 + if changed: + self._conn.commit() + return changed + def backfill_temporal_metadata_from_created_at( self, *, @@ -4698,6 +5148,29 @@ class MetadataStore: ) self._conn.commit() + def get_episode_pending_status_counts(self, source: str) -> Dict[str, int]: + """统计某个 source 当前 pending 队列中的状态分布。""" + token = self._normalize_episode_source(source) + if not token: + return {"pending": 0, "running": 0, "failed": 0, "done": 0} + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) AS count + FROM episode_pending_paragraphs + WHERE TRIM(COALESCE(source, '')) = ? 
+ GROUP BY status + """, + (token,), + ) + counts = {"pending": 0, "running": 0, "failed": 0, "done": 0} + for row in cursor.fetchall(): + status = str(row["status"] or "").strip().lower() + if status in counts: + counts[status] = int(row["count"] or 0) + return counts + def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: data = dict(row) @@ -4904,7 +5377,7 @@ class MetadataStore: SELECT 1 FROM episode_rebuild_sources ers WHERE ers.source = TRIM(COALESCE(e.source, '')) - AND ers.status IN ('pending', 'running', 'failed') + AND ers.status IN ('pending', 'running') ) """ ) @@ -4948,6 +5421,26 @@ class MetadataStore: return source_expr, effective_start, effective_end, conditions, params + @staticmethod + def _tokenize_episode_query(query: str) -> Tuple[str, List[str]]: + """将 episode 查询归一化为短语和 token。""" + normalized = normalize_text(str(query or "")).strip().lower() + if not normalized: + return "", [] + + token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") + tokens: List[str] = [] + seen = set() + for token in token_pattern.findall(normalized): + if token in seen: + continue + seen.add(token) + tokens.append(token) + + if not tokens and len(normalized) >= 2: + tokens = [normalized] + return normalized, tokens + def get_episode_rows_by_paragraph_hashes( self, paragraph_hashes: List[str], @@ -5097,28 +5590,58 @@ class MetadataStore: source=source, ) - q = str(query or "").strip().lower() + q, tokens = self._tokenize_episode_query(query) select_score_sql = "0.0 AS lexical_score" order_sql = f"{effective_end} DESC, e.updated_at DESC" select_params: List[Any] = [] query_params: List[Any] = [] if q: - like = f"%{q}%" - title_expr = "LOWER(COALESCE(e.title, '')) LIKE ?" - summary_expr = "LOWER(COALESCE(e.summary, '')) LIKE ?" - keywords_expr = "LOWER(COALESCE(e.keywords_json, '')) LIKE ?" - participants_expr = "LOWER(COALESCE(e.participants_json, '')) LIKE ?" - conditions.append( - f"({title_expr} OR {summary_expr} OR {keywords_expr} OR {participants_expr})" + field_exprs = { + "title": "LOWER(COALESCE(e.title, ''))", + "summary": "LOWER(COALESCE(e.summary, ''))", + "keywords": "LOWER(COALESCE(e.keywords_json, ''))", + "participants": "LOWER(COALESCE(e.participants_json, ''))", + } + + score_parts: List[str] = [] + phrase_like = f"%{q}%" + score_parts.extend( + [ + f"CASE WHEN {field_exprs['title']} LIKE ? THEN 6.0 ELSE 0.0 END", + f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 4.5 ELSE 0.0 END", + f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 3.0 ELSE 0.0 END", + f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 2.0 ELSE 0.0 END", + ] ) - select_score_sql = ( - f"(CASE WHEN {title_expr} THEN 4.0 ELSE 0.0 END + " - f"CASE WHEN {keywords_expr} THEN 3.0 ELSE 0.0 END + " - f"CASE WHEN {summary_expr} THEN 2.0 ELSE 0.0 END + " - f"CASE WHEN {participants_expr} THEN 1.0 ELSE 0.0 END) AS lexical_score" - ) - select_params.extend([like, like, like, like]) - query_params.extend([like, like, like, like]) + select_params.extend([phrase_like, phrase_like, phrase_like, phrase_like]) + + token_predicates: List[str] = [] + for token in tokens: + like = f"%{token}%" + token_any = ( + f"({field_exprs['title']} LIKE ? OR " + f"{field_exprs['summary']} LIKE ? OR " + f"{field_exprs['keywords']} LIKE ? OR " + f"{field_exprs['participants']} LIKE ?)" + ) + token_predicates.append(token_any) + query_params.extend([like, like, like, like]) + + score_parts.append( + "(" + f"CASE WHEN {field_exprs['title']} LIKE ? THEN 3.0 ELSE 0.0 END + " + f"CASE WHEN {field_exprs['keywords']} LIKE ? 
THEN 2.5 ELSE 0.0 END + " + f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 2.0 ELSE 0.0 END + " + f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 1.5 ELSE 0.0 END + " + f"CASE WHEN {token_any.replace('?', '?')} THEN 2.0 ELSE 0.0 END" + ")" + ) + select_params.extend([like, like, like, like, like, like, like, like]) + + if token_predicates: + conditions.append("(" + " OR ".join(token_predicates) + ")") + + select_score_sql = f"({' + '.join(score_parts)}) AS lexical_score" order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC" where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else "" diff --git a/plugins/A_memorix/core/utils/aggregate_query_service.py b/plugins/A_memorix/core/utils/aggregate_query_service.py index a87a4913..dcf64c34 100644 --- a/plugins/A_memorix/core/utils/aggregate_query_service.py +++ b/plugins/A_memorix/core/utils/aggregate_query_service.py @@ -302,7 +302,7 @@ class AggregateQueryService: ) for (branch_name, _), payload in zip(scheduled, done): if isinstance(payload, Exception): - logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload) + logger.error(f"aggregate branch failed: branch={branch_name} error={payload}") normalized = self._normalize_branch_payload( branch_name, { diff --git a/plugins/A_memorix/core/utils/episode_retrieval_service.py b/plugins/A_memorix/core/utils/episode_retrieval_service.py index 5a4cd24d..44b22854 100644 --- a/plugins/A_memorix/core/utils/episode_retrieval_service.py +++ b/plugins/A_memorix/core/utils/episode_retrieval_service.py @@ -70,7 +70,7 @@ class EpisodeRetrievalService: temporal=temporal, ) except Exception as exc: - logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc) + logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}") else: paragraph_rank_map: Dict[str, int] = {} relation_rank_map: Dict[str, int] = {} diff --git a/plugins/A_memorix/core/utils/episode_segmentation_service.py b/plugins/A_memorix/core/utils/episode_segmentation_service.py new file mode 100644 index 00000000..f42b1456 --- /dev/null +++ b/plugins/A_memorix/core/utils/episode_segmentation_service.py @@ -0,0 +1,304 @@ +""" +Episode 语义切分服务(LLM 主路径)。 + +职责: +1. 组装语义切分提示词 +2. 调用 LLM 生成结构化 episode JSON +3. 
严格校验输出结构,返回标准化结果 +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger +from src.config.model_configs import TaskConfig +from src.config.config import model_config as host_model_config +from src.services import llm_service as llm_api + +logger = get_logger("A_Memorix.EpisodeSegmentationService") + + +class EpisodeSegmentationService: + """基于 LLM 的 episode 语义切分服务。""" + + SEGMENTATION_VERSION = "episode_mvp_v1" + + def __init__(self, plugin_config: Optional[dict] = None): + self.plugin_config = plugin_config or {} + + def _cfg(self, key: str, default: Any = None) -> Any: + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + @staticmethod + def _is_task_config(obj: Any) -> bool: + return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", [])) + + def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig: + return TaskConfig( + model_list=[model_name], + max_tokens=template.max_tokens, + temperature=template.temperature, + slow_threshold=template.slow_threshold, + selection_strategy=template.selection_strategy, + ) + + def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]: + preferred = ("utils", "replyer", "planner", "tool_use") + for task_name in preferred: + cfg = available_tasks.get(task_name) + if self._is_task_config(cfg): + return cfg + for task_name, cfg in available_tasks.items(): + if task_name != "embedding" and self._is_task_config(cfg): + return cfg + for cfg in available_tasks.values(): + if self._is_task_config(cfg): + return cfg + return None + + def _resolve_model_config(self) -> Tuple[Optional[Any], str]: + available_tasks = llm_api.get_available_models() or {} + if not available_tasks: + return None, "unavailable" + + selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip() + model_dict = getattr(host_model_config, "models_dict", {}) or {} + + if selector and selector.lower() != "auto": + direct_task = available_tasks.get(selector) + if self._is_task_config(direct_task): + return direct_task, selector + + if selector in model_dict: + template = self._pick_template_task(available_tasks) + if template is not None: + return self._build_single_model_task(selector, template), selector + + logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto") + + for task_name in ("utils", "replyer", "planner", "tool_use"): + cfg = available_tasks.get(task_name) + if self._is_task_config(cfg): + return cfg, task_name + + fallback = self._pick_template_task(available_tasks) + if fallback is not None: + return fallback, "auto" + return None, "unavailable" + + @staticmethod + def _clamp_score(value: Any, default: float = 0.0) -> float: + try: + num = float(value) + except Exception: + num = default + if num < 0.0: + return 0.0 + if num > 1.0: + return 1.0 + return num + + @staticmethod + def _safe_json_loads(text: str) -> Dict[str, Any]: + raw = str(text or "").strip() + if not raw: + raise ValueError("empty_response") + + if "```" in raw: + raw = raw.replace("```json", "```").replace("```JSON", "```") + parts = raw.split("```") + for part in parts: + part = part.strip() + if part.startswith("{") and part.endswith("}"): + raw = part + break + + try: + data = json.loads(raw) + if isinstance(data, dict): + return data + except Exception: + pass + + start 
= raw.find("{") + end = raw.rfind("}") + if start >= 0 and end > start: + candidate = raw[start : end + 1] + data = json.loads(candidate) + if isinstance(data, dict): + return data + + raise ValueError("invalid_json_response") + + def _build_prompt( + self, + *, + source: str, + window_start: Optional[float], + window_end: Optional[float], + paragraphs: List[Dict[str, Any]], + ) -> str: + rows: List[str] = [] + for idx, item in enumerate(paragraphs, 1): + p_hash = str(item.get("hash", "") or "").strip() + content = str(item.get("content", "") or "").strip().replace("\r\n", "\n") + content = content[:800] + event_start = item.get("event_time_start") + event_end = item.get("event_time_end") + event_time = item.get("event_time") + rows.append( + ( + f"[{idx}] hash={p_hash}\n" + f"event_time={event_time}\n" + f"event_time_start={event_start}\n" + f"event_time_end={event_end}\n" + f"content={content}" + ) + ) + + source_text = str(source or "").strip() or "unknown" + return ( + "You are an episode segmentation engine.\n" + "Group the given paragraphs into one or more coherent episodes.\n" + "Return JSON ONLY. No markdown, no explanation.\n" + "\n" + "Hard JSON schema:\n" + "{\n" + ' "episodes": [\n' + " {\n" + ' "title": "string",\n' + ' "summary": "string",\n' + ' "paragraph_hashes": ["hash1", "hash2"],\n' + ' "participants": ["person1", "person2"],\n' + ' "keywords": ["kw1", "kw2"],\n' + ' "time_confidence": 0.0,\n' + ' "llm_confidence": 0.0\n' + " }\n" + " ]\n" + "}\n" + "\n" + "Rules:\n" + "1) paragraph_hashes must come from input only.\n" + "2) title and summary must be non-empty.\n" + "3) keep participants/keywords concise and deduplicated.\n" + "4) if uncertain, still provide best effort confidence values.\n" + "\n" + f"source={source_text}\n" + f"window_start={window_start}\n" + f"window_end={window_end}\n" + "paragraphs:\n" + + "\n\n".join(rows) + ) + + def _normalize_episodes( + self, + *, + payload: Dict[str, Any], + input_hashes: List[str], + ) -> List[Dict[str, Any]]: + raw_episodes = payload.get("episodes") + if not isinstance(raw_episodes, list): + raise ValueError("episodes_missing_or_not_list") + + valid_hashes = set(input_hashes) + normalized: List[Dict[str, Any]] = [] + for item in raw_episodes: + if not isinstance(item, dict): + continue + + title = str(item.get("title", "") or "").strip() + summary = str(item.get("summary", "") or "").strip() + if not title or not summary: + continue + + raw_hashes = item.get("paragraph_hashes") + if not isinstance(raw_hashes, list): + continue + + dedup_hashes: List[str] = [] + seen_hashes = set() + for h in raw_hashes: + token = str(h or "").strip() + if not token or token in seen_hashes or token not in valid_hashes: + continue + seen_hashes.add(token) + dedup_hashes.append(token) + + if not dedup_hashes: + continue + + participants = [] + for p in item.get("participants", []) or []: + token = str(p or "").strip() + if token: + participants.append(token) + + keywords = [] + for kw in item.get("keywords", []) or []: + token = str(kw or "").strip() + if token: + keywords.append(token) + + normalized.append( + { + "title": title, + "summary": summary, + "paragraph_hashes": dedup_hashes, + "participants": participants[:16], + "keywords": keywords[:20], + "time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0), + "llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5), + } + ) + + if not normalized: + raise ValueError("episodes_all_invalid") + return normalized + + async def segment( + self, + *, + 
source: str, + window_start: Optional[float], + window_end: Optional[float], + paragraphs: List[Dict[str, Any]], + ) -> Dict[str, Any]: + if not paragraphs: + raise ValueError("paragraphs_empty") + + model_config, model_label = self._resolve_model_config() + if model_config is None: + raise RuntimeError("episode segmentation model unavailable") + + prompt = self._build_prompt( + source=source, + window_start=window_start, + window_end=window_end, + paragraphs=paragraphs, + ) + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config, + request_type="A_Memorix.EpisodeSegmentation", + ) + if not success or not response: + raise RuntimeError("llm_generate_failed") + + payload = self._safe_json_loads(str(response)) + input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs] + episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes) + + return { + "episodes": episodes, + "segmentation_model": model_label, + "segmentation_version": self.SEGMENTATION_VERSION, + } + diff --git a/plugins/A_memorix/core/utils/episode_service.py b/plugins/A_memorix/core/utils/episode_service.py new file mode 100644 index 00000000..ca94dd96 --- /dev/null +++ b/plugins/A_memorix/core/utils/episode_service.py @@ -0,0 +1,558 @@ +""" +Episode 聚合与落库服务。 + +流程: +1. 从 pending 队列读取段落并组批 +2. 按 source + 时间窗口切组 +3. 调用 LLM 语义切分 +4. 写入 episodes + episode_paragraphs +5. LLM 失败时使用确定性 fallback +""" + +from __future__ import annotations + +import json +import re +from collections import Counter +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from .episode_segmentation_service import EpisodeSegmentationService +from .hash import compute_hash + +logger = get_logger("A_Memorix.EpisodeService") + + +class EpisodeService: + """Episode MVP 后台处理服务。""" + + def __init__( + self, + *, + metadata_store: Any, + plugin_config: Optional[Any] = None, + segmentation_service: Optional[EpisodeSegmentationService] = None, + ): + self.metadata_store = metadata_store + self.plugin_config = plugin_config or {} + self.segmentation_service = segmentation_service or EpisodeSegmentationService( + plugin_config=self._config_dict(), + ) + + def _config_dict(self) -> Dict[str, Any]: + if isinstance(self.plugin_config, dict): + return self.plugin_config + return {} + + def _cfg(self, key: str, default: Any = None) -> Any: + getter = getattr(self.plugin_config, "get_config", None) + if callable(getter): + return getter(key, default) + + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + @staticmethod + def _to_optional_float(value: Any) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except Exception: + return None + + @staticmethod + def _clamp_score(value: Any, default: float = 1.0) -> float: + try: + num = float(value) + except Exception: + num = default + if num < 0.0: + return 0.0 + if num > 1.0: + return 1.0 + return num + + @staticmethod + def _paragraph_anchor(paragraph: Dict[str, Any]) -> float: + for key in ("event_time_end", "event_time_start", "event_time", "created_at"): + value = paragraph.get(key) + try: + if value is not None: + return float(value) + except Exception: + continue + return 0.0 + + @staticmethod + def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]: + return ( + 
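+            # Order by the paragraph's time anchor first, then by hash so equal
+            # timestamps still sort deterministically.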
EpisodeService._paragraph_anchor(paragraph), + str(paragraph.get("hash", "") or ""), + ) + + def load_pending_paragraphs( + self, + pending_rows: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], List[str]]: + """ + 将 pending 行展开为段落上下文。 + + Returns: + (loaded_paragraphs, missing_hashes) + """ + loaded: List[Dict[str, Any]] = [] + missing: List[str] = [] + for row in pending_rows or []: + p_hash = str(row.get("paragraph_hash", "") or "").strip() + if not p_hash: + continue + + paragraph = self.metadata_store.get_paragraph(p_hash) + if not paragraph: + missing.append(p_hash) + continue + + loaded.append( + { + "hash": p_hash, + "source": str(row.get("source") or paragraph.get("source") or "").strip(), + "content": str(paragraph.get("content", "") or ""), + "created_at": self._to_optional_float(paragraph.get("created_at")) + or self._to_optional_float(row.get("created_at")) + or 0.0, + "event_time": self._to_optional_float(paragraph.get("event_time")), + "event_time_start": self._to_optional_float(paragraph.get("event_time_start")), + "event_time_end": self._to_optional_float(paragraph.get("event_time_end")), + "time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None, + "time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0), + } + ) + return loaded, missing + + def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + 按 source + 时间邻近窗口组批,并受段落数/字符数上限约束。 + """ + if not paragraphs: + return [] + + max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20))) + max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000))) + window_seconds = max( + 60.0, + float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0, + ) + + by_source: Dict[str, List[Dict[str, Any]]] = {} + for paragraph in paragraphs: + source = str(paragraph.get("source", "") or "").strip() + by_source.setdefault(source, []).append(paragraph) + + groups: List[Dict[str, Any]] = [] + for source, items in by_source.items(): + ordered = sorted(items, key=self._paragraph_sort_key) + + current: List[Dict[str, Any]] = [] + current_chars = 0 + last_anchor: Optional[float] = None + + def flush() -> None: + nonlocal current, current_chars, last_anchor + if not current: + return + sorted_current = sorted(current, key=self._paragraph_sort_key) + groups.append( + { + "source": source, + "paragraphs": sorted_current, + } + ) + current = [] + current_chars = 0 + last_anchor = None + + for paragraph in ordered: + anchor = self._paragraph_anchor(paragraph) + content_len = len(str(paragraph.get("content", "") or "")) + + need_flush = False + if current: + if len(current) >= max_paragraphs: + need_flush = True + elif current_chars + content_len > max_chars: + need_flush = True + elif last_anchor is not None and abs(anchor - last_anchor) > window_seconds: + need_flush = True + + if need_flush: + flush() + + current.append(paragraph) + current_chars += content_len + last_anchor = anchor + + flush() + + groups.sort( + key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0 + ) + return groups + + def _compute_time_meta(self, paragraphs: List[Dict[str, Any]]) -> Tuple[Optional[float], Optional[float], Optional[str], float]: + starts: List[float] = [] + ends: List[float] = [] + granularity_priority = { + "minute": 4, + "hour": 3, + "day": 2, + "month": 1, + "year": 0, + } + granularity = None + granularity_rank = -1 + conf_values: List[float] = [] + + for p in paragraphs: + s = 
self._to_optional_float(p.get("event_time_start")) + e = self._to_optional_float(p.get("event_time_end")) + t = self._to_optional_float(p.get("event_time")) + c = self._to_optional_float(p.get("created_at")) + + start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c)) + end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c)) + + if start_candidate is not None: + starts.append(start_candidate) + if end_candidate is not None: + ends.append(end_candidate) + + g = str(p.get("time_granularity", "") or "").strip().lower() + if g in granularity_priority and granularity_priority[g] > granularity_rank: + granularity_rank = granularity_priority[g] + granularity = g + + conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0)) + + time_start = min(starts) if starts else None + time_end = max(ends) if ends else None + time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0 + return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0) + + def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]: + seen = set() + participants: List[str] = [] + for p_hash in paragraph_hashes: + try: + entities = self.metadata_store.get_paragraph_entities(p_hash) + except Exception: + entities = [] + for item in entities: + name = str(item.get("name", "") or "").strip() + if not name: + continue + key = name.lower() + if key in seen: + continue + seen.add(key) + participants.append(name) + if len(participants) >= limit: + return participants + return participants + + @staticmethod + def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]: + token_counter: Counter[str] = Counter() + token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") + stop_words = { + "the", + "and", + "that", + "this", + "with", + "from", + "for", + "have", + "will", + "your", + "you", + "我们", + "你们", + "他们", + "以及", + "一个", + "这个", + "那个", + "然后", + "因为", + "所以", + } + for p in paragraphs: + text = str(p.get("content", "") or "").lower() + for token in token_pattern.findall(text): + if token in stop_words: + continue + token_counter[token] += 1 + + return [token for token, _ in token_counter.most_common(limit)] + + def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]: + paragraphs = group.get("paragraphs", []) or [] + source = str(group.get("source", "") or "").strip() + hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] + snippets = [] + for p in paragraphs[:3]: + text = str(p.get("content", "") or "").strip().replace("\n", " ") + if text: + snippets.append(text[:140]) + summary = ";".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。" + + time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs) + participants = self._collect_participants(hashes, limit=12) + keywords = self._derive_keywords(paragraphs, limit=10) + + if time_start is not None: + day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d") + title = f"{source or 'unknown'} {day_text} 情景片段" + else: + title = f"{source or 'unknown'} 情景片段" + + return { + "title": title[:80], + "summary": summary, + "paragraph_hashes": hashes, + "participants": participants, + "keywords": keywords, + "time_confidence": time_conf, + "llm_confidence": 0.0, + "event_time_start": time_start, + "event_time_end": time_end, + "time_granularity": granularity, + "segmentation_model": "fallback_rule", + 
"segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION, + } + + @staticmethod + def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]: + in_group = set(group_hashes_ordered) + dedup: List[str] = [] + seen = set() + for h in episode_hashes or []: + token = str(h or "").strip() + if not token or token not in in_group or token in seen: + continue + seen.add(token) + dedup.append(token) + return dedup + + async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]: + paragraphs = group.get("paragraphs", []) or [] + if not paragraphs: + return { + "payloads": [], + "done_hashes": [], + "episode_count": 0, + "fallback_count": 0, + } + + source = str(group.get("source", "") or "").strip() + group_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] + group_start, group_end, _, _ = self._compute_time_meta(paragraphs) + + fallback_used = False + segmentation_model = "fallback_rule" + segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION + + try: + llm_result = await self.segmentation_service.segment( + source=source, + window_start=group_start, + window_end=group_end, + paragraphs=paragraphs, + ) + episodes = list(llm_result.get("episodes") or []) + segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto" + segmentation_version = str(llm_result.get("segmentation_version", "") or "").strip() or EpisodeSegmentationService.SEGMENTATION_VERSION + if not episodes: + raise ValueError("llm_empty_episodes") + except Exception as e: + logger.warning( + "Episode segmentation fallback: " + f"source={source} " + f"size={len(group_hashes)} " + f"err={e}" + ) + episodes = [self._build_fallback_episode(group)] + fallback_used = True + + stored_payloads: List[Dict[str, Any]] = [] + for episode in episodes: + ordered_hashes = self._normalize_episode_hashes( + episode_hashes=episode.get("paragraph_hashes", []), + group_hashes_ordered=group_hashes, + ) + if not ordered_hashes: + continue + + sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in set(ordered_hashes)] + event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs) + + participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()] + keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()] + if not participants: + participants = self._collect_participants(ordered_hashes, limit=16) + if not keywords: + keywords = self._derive_keywords(sub_paragraphs, limit=12) + + title = str(episode.get("title", "") or "").strip()[:120] + summary = str(episode.get("summary", "") or "").strip()[:2000] + if not title or not summary: + continue + + seed = json.dumps( + { + "source": source, + "hashes": ordered_hashes, + "version": segmentation_version, + }, + ensure_ascii=False, + sort_keys=True, + ) + episode_id = compute_hash(seed) + + payload = { + "episode_id": episode_id, + "source": source or None, + "title": title, + "summary": summary, + "event_time_start": episode.get("event_time_start", event_start), + "event_time_end": episode.get("event_time_end", event_end), + "time_granularity": episode.get("time_granularity", granularity), + "time_confidence": self._clamp_score( + episode.get("time_confidence"), + default=time_conf_default, + ), + "participants": participants[:16], + "keywords": keywords[:20], + "evidence_ids": ordered_hashes, + 
"paragraph_count": len(ordered_hashes), + "llm_confidence": self._clamp_score( + episode.get("llm_confidence"), + default=0.0 if fallback_used else 0.6, + ), + "segmentation_model": ( + str(episode.get("segmentation_model", "") or "").strip() + or ("fallback_rule" if fallback_used else segmentation_model) + ), + "segmentation_version": ( + str(episode.get("segmentation_version", "") or "").strip() + or segmentation_version + ), + } + stored_payloads.append(payload) + + return { + "payloads": stored_payloads, + "done_hashes": group_hashes, + "episode_count": len(stored_payloads), + "fallback_count": 1 if fallback_used else 0, + } + + async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]: + result = await self._build_episode_payloads_for_group(group) + stored_count = 0 + for payload in result.get("payloads") or []: + stored = self.metadata_store.upsert_episode(payload) + final_id = str(stored.get("episode_id") or payload.get("episode_id") or "") + if final_id: + self.metadata_store.bind_episode_paragraphs( + final_id, + list(payload.get("evidence_ids") or []), + ) + stored_count += 1 + + result["episode_count"] = stored_count + return { + "done_hashes": list(result.get("done_hashes") or []), + "episode_count": stored_count, + "fallback_count": int(result.get("fallback_count") or 0), + } + + async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]: + loaded, missing_hashes = self.load_pending_paragraphs(pending_rows) + groups = self.group_paragraphs(loaded) + + done_hashes: List[str] = list(missing_hashes) + failed_hashes: Dict[str, str] = {} + episode_count = 0 + fallback_count = 0 + + for group in groups: + group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])] + try: + result = await self.process_group(group) + done_hashes.extend(result.get("done_hashes") or []) + episode_count += int(result.get("episode_count") or 0) + fallback_count += int(result.get("fallback_count") or 0) + except Exception as e: + err = str(e)[:500] + for h in group_hashes: + if h: + failed_hashes[h] = err + + dedup_done = list(dict.fromkeys([h for h in done_hashes if h])) + return { + "done_hashes": dedup_done, + "failed_hashes": failed_hashes, + "episode_count": episode_count, + "fallback_count": fallback_count, + "missing_count": len(missing_hashes), + "group_count": len(groups), + } + + async def rebuild_source(self, source: str) -> Dict[str, Any]: + token = str(source or "").strip() + if not token: + return { + "source": "", + "episode_count": 0, + "fallback_count": 0, + "group_count": 0, + "paragraph_count": 0, + } + + paragraphs = self.metadata_store.get_live_paragraphs_by_source(token) + if not paragraphs: + replace_result = self.metadata_store.replace_episodes_for_source(token, []) + return { + "source": token, + "episode_count": int(replace_result.get("episode_count") or 0), + "fallback_count": 0, + "group_count": 0, + "paragraph_count": 0, + } + + groups = self.group_paragraphs(paragraphs) + payloads: List[Dict[str, Any]] = [] + fallback_count = 0 + + for group in groups: + result = await self._build_episode_payloads_for_group(group) + payloads.extend(list(result.get("payloads") or [])) + fallback_count += int(result.get("fallback_count") or 0) + + replace_result = self.metadata_store.replace_episodes_for_source(token, payloads) + return { + "source": token, + "episode_count": int(replace_result.get("episode_count") or 0), + "fallback_count": fallback_count, + "group_count": len(groups), + "paragraph_count": 
len(paragraphs), + } diff --git a/plugins/A_memorix/core/utils/person_profile_service.py b/plugins/A_memorix/core/utils/person_profile_service.py index ccbbaf90..6460c013 100644 --- a/plugins/A_memorix/core/utils/person_profile_service.py +++ b/plugins/A_memorix/core/utils/person_profile_service.py @@ -9,7 +9,11 @@ import json import time from typing import Any, Dict, List, Optional, Tuple +from sqlalchemy import or_ +from sqlmodel import select + from src.common.logger import get_logger +from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from ..embedding import EmbeddingAPIAdapter @@ -120,31 +124,40 @@ class PersonProfileService: if not key: return "" + try: + with get_db_session(auto_commit=False) as session: + record = session.exec( + select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1) + ).first() + if record: + return str(record) + + record = session.exec( + select(PersonInfo.person_id) + .where( + or_( + PersonInfo.person_name == key, + PersonInfo.user_nickname == key, + ) + ) + .limit(1) + ).first() + if record: + return str(record) + + record = session.exec( + select(PersonInfo.person_id) + .where(PersonInfo.group_cardname.contains(key)) + .limit(1) + ).first() + if record: + return str(record) + except Exception as e: + logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}") + if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key): return key.lower() - try: - record = ( - PersonInfo.select(PersonInfo.person_id) - .where((PersonInfo.person_name == key) | (PersonInfo.nickname == key)) - .first() - ) - if record and record.person_id: - return str(record.person_id) - except Exception: - pass - - try: - record = ( - PersonInfo.select(PersonInfo.person_id) - .where(PersonInfo.group_nick_name.contains(key)) - .first() - ) - if record and record.person_id: - return str(record.person_id) - except Exception: - pass - return "" def _parse_group_nicks(self, raw_value: Any) -> List[str]: @@ -160,7 +173,7 @@ class PersonProfileService: names: List[str] = [] for item in items: if isinstance(item, dict): - value = str(item.get("group_nick_name", "")).strip() + value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip() if value: names.append(value) elif isinstance(item, str): @@ -193,6 +206,42 @@ class PersonProfileService: traits.append(text) return traits[:10] + def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]: + """当人物主档案缺失时,从已有记忆证据里回捞可用别名。""" + if not person_id: + return [], "" + + aliases: List[str] = [] + primary_name = "" + seen = set() + + try: + paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id) + except Exception as e: + logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}") + return [], "" + + for paragraph in paragraphs[:20]: + paragraph_hash = str(paragraph.get("hash", "") or "").strip() + if not paragraph_hash: + continue + try: + paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) + except Exception: + paragraph_entities = [] + for entity in paragraph_entities: + name = str(entity.get("name", "") or "").strip() + if not name or name == person_id: + continue + key = name.lower() + if key in seen: + continue + seen.add(key) + aliases.append(name) + if not primary_name: + primary_name = name + return aliases, primary_name + def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]: """获取人物别名集合、主展示名、记忆特征。""" aliases: List[str] = [] @@ -200,18 
+249,28 @@ class PersonProfileService: memory_traits: List[str] = [] if not person_id: return aliases, primary_name, memory_traits + recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id) try: - record = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - if not record: - return aliases, primary_name, memory_traits + with get_db_session(auto_commit=False) as session: + record = session.exec( + select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1) + ).first() + if not record: + return recovered_aliases, recovered_primary_name or person_id, memory_traits person_name = str(getattr(record, "person_name", "") or "").strip() - nickname = str(getattr(record, "nickname", "") or "").strip() - group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None)) + nickname = str(getattr(record, "user_nickname", "") or "").strip() + group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None)) memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None)) - primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id + primary_name = ( + person_name + or nickname + or recovered_primary_name + or str(getattr(record, "user_id", "") or "").strip() + or person_id + ) - candidates = [person_name, nickname] + group_nicks + candidates = [person_name, nickname] + group_nicks + recovered_aliases seen = set() for item in candidates: norm = str(item or "").strip() diff --git a/plugins/A_memorix/core/utils/relation_write_service.py b/plugins/A_memorix/core/utils/relation_write_service.py index b73e1260..6fa2e621 100644 --- a/plugins/A_memorix/core/utils/relation_write_service.py +++ b/plugins/A_memorix/core/utils/relation_write_service.py @@ -82,8 +82,9 @@ class RelationWriteService: ) self.metadata_store.set_relation_vector_state(hash_value, "ready") logger.info( - "metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s", - hash_value[:16], + "metric.relation_vector_write_success=1 " + "metric.relation_vector_write_success_count=1 " + f"hash={hash_value[:16]}" ) return RelationWriteResult( hash_value=hash_value, @@ -109,9 +110,10 @@ class RelationWriteService: bump_retry=True, ) logger.warning( - "metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s", - hash_value[:16], - err, + "metric.relation_vector_write_fail=1 " + "metric.relation_vector_write_fail_count=1 " + f"hash={hash_value[:16]} " + f"err={err}" ) return RelationWriteResult( hash_value=hash_value, diff --git a/plugins/A_memorix/core/utils/retrieval_tuning_manager.py b/plugins/A_memorix/core/utils/retrieval_tuning_manager.py new file mode 100644 index 00000000..e0e8ecd6 --- /dev/null +++ b/plugins/A_memorix/core/utils/retrieval_tuning_manager.py @@ -0,0 +1,1857 @@ +""" +Retrieval tuning manager for WebUI. 
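+
+As reflected by the constants and task records defined below, this module
+coordinates WebUI-driven retrieval tuning: it assembles a query case set
+(categories query_nl / query_kw / spo_relation / spo_search), evaluates
+candidate retrieval profiles against a baseline over a fixed number of rounds
+(quick=8, standard=20, deep=32) for one of three objectives
+(precision_priority / balanced / recall_priority), tracks the best-scoring
+profile, and records apply/rollback history.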
+""" + +from __future__ import annotations + +import asyncio +import copy +import json +import random +import re +import time +import uuid +from collections import Counter, deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from ..runtime.search_runtime_initializer import build_search_runtime +from .search_execution_service import SearchExecutionRequest, SearchExecutionService + +try: + from src.services import llm_service as llm_api +except Exception: # pragma: no cover + llm_api = None + +logger = get_logger("A_Memorix.RetrievalTuningManager") + + +OBJECTIVES = {"precision_priority", "balanced", "recall_priority"} +INTENSITIES = {"quick": 8, "standard": 20, "deep": 32} +CATEGORIES = {"query_nl", "query_kw", "spo_relation", "spo_search"} +_RUNTIME_CONFIG_INSTANCE_KEYS = { + "vector_store", + "graph_store", + "metadata_store", + "embedding_manager", + "sparse_index", + "relation_write_service", + "plugin_instance", +} + + +def _now() -> float: + return time.time() + + +def _clamp_int(value: Any, default: int, min_value: int, max_value: int) -> int: + try: + parsed = int(value) + except Exception: + parsed = int(default) + return max(min_value, min(max_value, parsed)) + + +def _clamp_float(value: Any, default: float, min_value: float, max_value: float) -> float: + try: + parsed = float(value) + except Exception: + parsed = float(default) + return max(min_value, min(max_value, parsed)) + + +def _coerce_bool(value: Any, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + text = str(value).strip().lower() + if text in {"1", "true", "yes", "y", "on"}: + return True + if text in {"0", "false", "no", "n", "off"}: + return False + return default + + +def _nested_get(data: Dict[str, Any], key: str, default: Any = None) -> Any: + cur: Any = data + for part in key.split("."): + if isinstance(cur, dict) and part in cur: + cur = cur[part] + else: + return default + return cur + + +def _nested_set(data: Dict[str, Any], key: str, value: Any) -> None: + parts = key.split(".") + cur = data + for part in parts[:-1]: + if part not in cur or not isinstance(cur[part], dict): + cur[part] = {} + cur = cur[part] + cur[parts[-1]] = value + + +def _deep_merge(base: Dict[str, Any], patch: Dict[str, Any]) -> Dict[str, Any]: + out = copy.deepcopy(base) + for key, value in (patch or {}).items(): + if isinstance(value, dict) and isinstance(out.get(key), dict): + out[key] = _deep_merge(out[key], value) + else: + out[key] = copy.deepcopy(value) + return out + + +def _safe_json_loads(text: str) -> Optional[Any]: + raw = str(text or "").strip() + if not raw: + return None + if "```" in raw: + raw = raw.replace("```json", "```") + for seg in raw.split("```"): + seg = seg.strip() + if seg.startswith("{") or seg.startswith("["): + raw = seg + break + try: + return json.loads(raw) + except Exception: + pass + s = raw.find("{") + e = raw.rfind("}") + if s >= 0 and e > s: + try: + return json.loads(raw[s : e + 1]) + except Exception: + return None + return None + + +@dataclass +class RetrievalQueryCase: + case_id: str + category: str + query: str + expected_hashes: List[str] = field(default_factory=list) + expected_spo: Dict[str, str] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "case_id": 
self.case_id, + "category": self.category, + "query": self.query, + "expected_hashes": list(self.expected_hashes), + "expected_spo": dict(self.expected_spo), + "metadata": dict(self.metadata), + } + + +@dataclass +class RetrievalTuningRoundRecord: + round_index: int + candidate_profile: Dict[str, Any] + metrics: Dict[str, Any] + score: float + latency_ms: float + failure_summary: Dict[str, Any] = field(default_factory=dict) + created_at: float = field(default_factory=_now) + + def to_dict(self) -> Dict[str, Any]: + return { + "round_index": self.round_index, + "candidate_profile": copy.deepcopy(self.candidate_profile), + "metrics": copy.deepcopy(self.metrics), + "score": float(self.score), + "latency_ms": float(self.latency_ms), + "failure_summary": copy.deepcopy(self.failure_summary), + "created_at": float(self.created_at), + } + + +@dataclass +class RetrievalTuningTaskRecord: + task_id: str + status: str + progress: float + objective: str + intensity: str + rounds_total: int + rounds_done: int = 0 + best_profile: Dict[str, Any] = field(default_factory=dict) + best_metrics: Dict[str, Any] = field(default_factory=dict) + best_score: float = -1.0 + baseline_profile: Dict[str, Any] = field(default_factory=dict) + baseline_metrics: Dict[str, Any] = field(default_factory=dict) + error: str = "" + params: Dict[str, Any] = field(default_factory=dict) + query_set_stats: Dict[str, Any] = field(default_factory=dict) + artifact_paths: Dict[str, str] = field(default_factory=dict) + rounds: List[RetrievalTuningRoundRecord] = field(default_factory=list) + cancel_requested: bool = False + created_at: float = field(default_factory=_now) + started_at: Optional[float] = None + finished_at: Optional[float] = None + updated_at: float = field(default_factory=_now) + apply_log: List[Dict[str, Any]] = field(default_factory=list) + + def to_summary(self) -> Dict[str, Any]: + return { + "task_id": self.task_id, + "status": self.status, + "progress": self.progress, + "objective": self.objective, + "intensity": self.intensity, + "rounds_total": self.rounds_total, + "rounds_done": self.rounds_done, + "best_score": self.best_score, + "error": self.error, + "query_set_stats": dict(self.query_set_stats), + "artifact_paths": dict(self.artifact_paths), + "created_at": self.created_at, + "started_at": self.started_at, + "finished_at": self.finished_at, + "updated_at": self.updated_at, + } + + def to_detail(self, include_rounds: bool = False) -> Dict[str, Any]: + payload = self.to_summary() + payload.update( + { + "params": copy.deepcopy(self.params), + "best_profile": copy.deepcopy(self.best_profile), + "best_metrics": copy.deepcopy(self.best_metrics), + "baseline_profile": copy.deepcopy(self.baseline_profile), + "baseline_metrics": copy.deepcopy(self.baseline_metrics), + "apply_log": copy.deepcopy(self.apply_log), + } + ) + if include_rounds: + payload["rounds"] = [x.to_dict() for x in self.rounds] + return payload + + +class RetrievalTuningManager: + def __init__( + self, + plugin: Any, + *, + import_write_blocked_provider: Optional[Callable[[], bool]] = None, + ): + self.plugin = plugin + self._import_write_blocked_provider = import_write_blocked_provider + + self._lock = asyncio.Lock() + self._tasks: Dict[str, RetrievalTuningTaskRecord] = {} + self._task_order: deque[str] = deque() + self._queue: deque[str] = deque() + self._active_task_id: Optional[str] = None + self._worker_task: Optional[asyncio.Task] = None + self._stopping = False + + self._rollback_snapshot: Optional[Dict[str, Any]] = None + + 
self._artifacts_root = Path(__file__).resolve().parents[2] / "artifacts" / "retrieval_tuning" + self._artifacts_root.mkdir(parents=True, exist_ok=True) + + def _cfg(self, key: str, default: Any = None) -> Any: + getter = getattr(self.plugin, "get_config", None) + if callable(getter): + return getter(key, default) + return default + + def _is_enabled(self) -> bool: + return bool(self._cfg("web.tuning.enabled", True)) + + def _queue_limit(self) -> int: + return _clamp_int(self._cfg("web.tuning.max_queue_size", 8), 8, 1, 100) + + def _poll_interval_s(self) -> float: + ms = _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000) + return max(0.2, ms / 1000.0) + + def _llm_retry_cfg(self) -> Dict[str, Any]: + return { + "max_attempts": _clamp_int(self._cfg("web.tuning.llm_retry.max_attempts", 3), 3, 1, 10), + "min_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.min_wait_seconds", 2), 2.0, 0.1, 60.0), + "max_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.max_wait_seconds", 20), 20.0, 0.2, 120.0), + "backoff_multiplier": _clamp_float(self._cfg("web.tuning.llm_retry.backoff_multiplier", 2), 2.0, 1.0, 10.0), + } + + def _eval_query_timeout_s(self) -> float: + return _clamp_float( + self._cfg("web.tuning.eval_query_timeout_seconds", 10.0), + 10.0, + 0.01, + 120.0, + ) + + def get_runtime_settings(self) -> Dict[str, Any]: + intensity = str(self._cfg("web.tuning.default_intensity", "standard") or "standard") + if intensity not in INTENSITIES: + intensity = "standard" + objective = str(self._cfg("web.tuning.default_objective", "precision_priority") or "precision_priority") + if objective not in OBJECTIVES: + objective = "precision_priority" + return { + "enabled": self._is_enabled(), + "poll_interval_ms": _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000), + "max_queue_size": self._queue_limit(), + "default_objective": objective, + "default_intensity": intensity, + "default_rounds": INTENSITIES[intensity], + "default_top_k_eval": _clamp_int(self._cfg("web.tuning.default_top_k_eval", 20), 20, 5, 100), + "default_sample_size": _clamp_int(self._cfg("web.tuning.default_sample_size", 24), 24, 4, 200), + "eval_query_timeout_seconds": self._eval_query_timeout_s(), + "llm_retry": self._llm_retry_cfg(), + } + + def _ensure_ready(self) -> None: + required = ("metadata_store", "vector_store", "graph_store", "embedding_manager") + missing = [x for x in required if getattr(self.plugin, x, None) is None] + if missing: + raise ValueError(f"调优依赖未初始化: {', '.join(missing)}") + checker = getattr(self.plugin, "is_runtime_ready", None) + if callable(checker) and not checker(): + raise ValueError("插件运行时未就绪") + provider = self._import_write_blocked_provider + if provider is not None and bool(provider()): + raise ValueError("导入任务运行中,当前禁止启动检索调优") + + def get_profile_snapshot(self) -> Dict[str, Any]: + cfg = getattr(self.plugin, "config", {}) or {} + profile = { + "retrieval": { + "top_k_paragraphs": _nested_get(cfg, "retrieval.top_k_paragraphs", 20), + "top_k_relations": _nested_get(cfg, "retrieval.top_k_relations", 10), + "top_k_final": _nested_get(cfg, "retrieval.top_k_final", 10), + "alpha": _nested_get(cfg, "retrieval.alpha", 0.5), + "enable_ppr": _nested_get(cfg, "retrieval.enable_ppr", True), + "search": {"smart_fallback": {"enabled": _nested_get(cfg, "retrieval.search.smart_fallback.enabled", True)}}, + "sparse": { + "enabled": _nested_get(cfg, "retrieval.sparse.enabled", True), + "mode": _nested_get(cfg, "retrieval.sparse.mode", "auto"), + 
"candidate_k": _nested_get(cfg, "retrieval.sparse.candidate_k", 80), + "relation_candidate_k": _nested_get(cfg, "retrieval.sparse.relation_candidate_k", 60), + }, + "fusion": { + "method": _nested_get(cfg, "retrieval.fusion.method", "weighted_rrf"), + "rrf_k": _nested_get(cfg, "retrieval.fusion.rrf_k", 60), + "vector_weight": _nested_get(cfg, "retrieval.fusion.vector_weight", 0.7), + "bm25_weight": _nested_get(cfg, "retrieval.fusion.bm25_weight", 0.3), + }, + }, + "threshold": { + "percentile": _nested_get(cfg, "threshold.percentile", 75.0), + "min_results": _nested_get(cfg, "threshold.min_results", 3), + }, + } + return self._normalize_profile(profile, fallback=profile) + + def _normalize_profile(self, profile: Optional[Dict[str, Any]], *, fallback: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + raw = copy.deepcopy(profile or {}) + base = copy.deepcopy(fallback or self.get_profile_snapshot()) + + def pick(path: str, default: Any) -> Any: + if _nested_get(raw, path, None) is not None: + return _nested_get(raw, path, default) + if path in raw: + return raw.get(path, default) + return _nested_get(base, path, default) + + fusion_method = str(pick("retrieval.fusion.method", "weighted_rrf") or "weighted_rrf").strip().lower() + if fusion_method not in {"weighted_rrf", "alpha_legacy"}: + fusion_method = "weighted_rrf" + + sparse_mode = str(pick("retrieval.sparse.mode", "auto") or "auto").strip().lower() + if sparse_mode not in {"auto", "hybrid", "fallback_only"}: + sparse_mode = "auto" + + vec_w = _clamp_float(pick("retrieval.fusion.vector_weight", 0.7), 0.7, 0.0, 1.0) + bm_w = _clamp_float(pick("retrieval.fusion.bm25_weight", 0.3), 0.3, 0.0, 1.0) + s = vec_w + bm_w + if s <= 1e-9: + vec_w, bm_w = 0.7, 0.3 + else: + vec_w, bm_w = vec_w / s, bm_w / s + + return { + "retrieval": { + "top_k_paragraphs": _clamp_int(pick("retrieval.top_k_paragraphs", 20), 20, 10, 1200), + "top_k_relations": _clamp_int(pick("retrieval.top_k_relations", 10), 10, 4, 512), + "top_k_final": _clamp_int(pick("retrieval.top_k_final", 10), 10, 4, 512), + "alpha": _clamp_float(pick("retrieval.alpha", 0.5), 0.5, 0.0, 1.0), + "enable_ppr": _coerce_bool(pick("retrieval.enable_ppr", True), True), + "search": {"smart_fallback": {"enabled": _coerce_bool(pick("retrieval.search.smart_fallback.enabled", True), True)}}, + "sparse": { + "enabled": _coerce_bool(pick("retrieval.sparse.enabled", True), True), + "mode": sparse_mode, + "candidate_k": _clamp_int(pick("retrieval.sparse.candidate_k", 80), 80, 20, 2000), + "relation_candidate_k": _clamp_int(pick("retrieval.sparse.relation_candidate_k", 60), 60, 20, 2000), + }, + "fusion": { + "method": fusion_method, + "rrf_k": _clamp_int(pick("retrieval.fusion.rrf_k", 60), 60, 1, 500), + "vector_weight": float(vec_w), + "bm25_weight": float(bm_w), + }, + }, + "threshold": { + "percentile": _clamp_float(pick("threshold.percentile", 75.0), 75.0, 1.0, 99.0), + "min_results": _clamp_int(pick("threshold.min_results", 3), 3, 1, 100), + }, + } + + def _apply_profile_to_runtime(self, normalized: Dict[str, Any]) -> None: + if not isinstance(getattr(self.plugin, "config", None), dict): + raise RuntimeError("插件 config 不可写") + for key, value in normalized.items(): + _nested_set(self.plugin.config, key, value) + plugin_cfg = getattr(self.plugin, "_plugin_config", None) + if isinstance(plugin_cfg, dict): + for key, value in normalized.items(): + _nested_set(plugin_cfg, key, value) + + async def apply_profile(self, profile: Dict[str, Any], *, reason: str = "manual") -> Dict[str, Any]: + normalized = 
self._normalize_profile(profile) + current = self.get_profile_snapshot() + self._rollback_snapshot = current + self._apply_profile_to_runtime(normalized) + return { + "applied": normalized, + "rollback_snapshot": current, + "reason": reason, + "applied_at": _now(), + } + + async def rollback_profile(self) -> Dict[str, Any]: + if not self._rollback_snapshot: + raise ValueError("暂无可回滚的参数快照") + target = self._normalize_profile(self._rollback_snapshot, fallback=self._rollback_snapshot) + self._apply_profile_to_runtime(target) + return {"rolled_back_to": target, "rolled_back_at": _now()} + + def export_toml_snippet(self, profile: Optional[Dict[str, Any]] = None) -> str: + p = self._normalize_profile(profile or self.get_profile_snapshot()) + r = p["retrieval"] + t = p["threshold"] + lines = [ + "[retrieval]", + f"top_k_paragraphs = {int(r['top_k_paragraphs'])}", + f"top_k_relations = {int(r['top_k_relations'])}", + f"top_k_final = {int(r['top_k_final'])}", + f"alpha = {float(r['alpha']):.4f}", + f"enable_ppr = {str(bool(r['enable_ppr'])).lower()}", + "", + "[retrieval.search.smart_fallback]", + f"enabled = {str(bool(r['search']['smart_fallback']['enabled'])).lower()}", + "", + "[retrieval.sparse]", + f"enabled = {str(bool(r['sparse']['enabled'])).lower()}", + f"mode = \"{r['sparse']['mode']}\"", + f"candidate_k = {int(r['sparse']['candidate_k'])}", + f"relation_candidate_k = {int(r['sparse']['relation_candidate_k'])}", + "", + "[retrieval.fusion]", + f"method = \"{r['fusion']['method']}\"", + f"rrf_k = {int(r['fusion']['rrf_k'])}", + f"vector_weight = {float(r['fusion']['vector_weight']):.4f}", + f"bm25_weight = {float(r['fusion']['bm25_weight']):.4f}", + "", + "[threshold]", + f"percentile = {float(t['percentile']):.4f}", + f"min_results = {int(t['min_results'])}", + ] + return "\n".join(lines).strip() + "\n" + + def _pending_task_count(self) -> int: + return sum(1 for t in self._tasks.values() if t.status in {"queued", "running", "cancel_requested"}) + + def _sample_triples_for_query_set( + self, + *, + triples: List[Tuple[Any, Any, Any, Any]], + sample_size: int, + seed: int, + ) -> Tuple[List[Tuple[str, str, str, str]], Dict[str, Any]]: + normalized: List[Tuple[str, str, str, str]] = [] + for row in triples: + try: + subject, predicate, obj, rel_hash = row + except Exception: + continue + relation_hash = str(rel_hash or "").strip() + if not relation_hash: + continue + normalized.append((str(subject or ""), str(predicate or ""), str(obj or ""), relation_hash)) + + if not normalized: + return [], {"error": "no_relations"} + + target = min(max(4, int(sample_size)), len(normalized)) + predicate_counter = Counter([str(x[1] or "").strip() or "__empty__" for x in normalized]) + entity_counter = Counter() + for subj, _, obj, _ in normalized: + entity_counter.update([str(subj or "").strip().lower() or "__empty__"]) + entity_counter.update([str(obj or "").strip().lower() or "__empty__"]) + + if target >= len(normalized): + return list(normalized), { + "strategy": "all", + "sample_size": int(target), + "total_triples": int(len(normalized)), + "predicate_total": int(len(predicate_counter)), + "predicate_sampled": int(len(predicate_counter)), + } + + rng = random.Random(f"{seed}:triple_sample") + by_predicate: Dict[str, List[int]] = {} + for idx, (_, predicate, _, _) in enumerate(normalized): + key = str(predicate or "").strip() or "__empty__" + by_predicate.setdefault(key, []).append(idx) + for pool in by_predicate.values(): + rng.shuffle(pool) + + predicate_order = sorted(by_predicate.keys()) + 
rng.shuffle(predicate_order) + + selected: List[int] = [] + selected_set = set() + + # First pass: predicate round-robin to avoid head predicate dominating query set. + while len(selected) < target: + progressed = False + for key in predicate_order: + pool = by_predicate.get(key, []) + if not pool: + continue + idx = int(pool.pop()) + if idx in selected_set: + continue + selected.append(idx) + selected_set.add(idx) + progressed = True + if len(selected) >= target: + break + if not progressed: + break + + if len(selected) < target: + remain = [idx for idx in range(len(normalized)) if idx not in selected_set] + rng.shuffle(remain) + + # Second pass: prefer lower-frequency entities and predicates for better diversity. + def _remain_score(idx: int) -> Tuple[int, int]: + subj, predicate, obj, _ = normalized[idx] + subject_freq = int(entity_counter.get(str(subj or "").strip().lower() or "__empty__", 0)) + object_freq = int(entity_counter.get(str(obj or "").strip().lower() or "__empty__", 0)) + pred_freq = int(predicate_counter.get(str(predicate or "").strip() or "__empty__", 0)) + return (subject_freq + object_freq, pred_freq) + + remain = sorted(remain, key=_remain_score) + need = target - len(selected) + for idx in remain[:need]: + selected.append(idx) + selected_set.add(idx) + + selected = selected[:target] + sampled = [normalized[idx] for idx in selected] + sampled_predicates = {str(x[1] or "").strip() or "__empty__" for x in sampled} + + return sampled, { + "strategy": "predicate_round_robin_entity_diversity", + "sample_size": int(target), + "total_triples": int(len(normalized)), + "predicate_total": int(len(predicate_counter)), + "predicate_sampled": int(len(sampled_predicates)), + } + + def _select_round_eval_cases( + self, + *, + cases: List[RetrievalQueryCase], + intensity: str, + round_index: int, + seed: int, + ) -> List[RetrievalQueryCase]: + if not cases: + return [] + mode = str(intensity or "standard").strip().lower() + if mode not in INTENSITIES: + mode = "standard" + if mode == "deep": + return list(cases) + + if mode == "quick": + ratio = 0.45 + min_total = 16 + else: + ratio = 0.70 + min_total = 24 + + total = len(cases) + target = max(min_total, int(total * ratio)) + if target >= total: + return list(cases) + + rng = random.Random(f"{seed}:{round_index}:subset") + by_cat: Dict[str, List[RetrievalQueryCase]] = {} + for item in cases: + by_cat.setdefault(str(item.category), []).append(item) + + selected: List[RetrievalQueryCase] = [] + selected_ids = set() + cat_names = sorted([x for x in by_cat.keys() if x in CATEGORIES]) + if not cat_names: + cat_names = sorted(by_cat.keys()) + per_cat = max(1, target // max(1, len(cat_names))) + + for cat in cat_names: + pool = by_cat.get(cat, []) + if not pool: + continue + picked = list(pool) if len(pool) <= per_cat else rng.sample(pool, per_cat) + for item in picked: + if item.case_id in selected_ids: + continue + selected.append(item) + selected_ids.add(item.case_id) + + if len(selected) < target: + remain = [x for x in cases if x.case_id not in selected_ids] + if len(remain) > (target - len(selected)): + remain = rng.sample(remain, target - len(selected)) + for item in remain: + selected.append(item) + selected_ids.add(item.case_id) + + return selected[:target] + + async def _ensure_worker(self) -> None: + async with self._lock: + if self._worker_task and not self._worker_task.done(): + return + self._stopping = False + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def shutdown(self) -> None: + self._stopping = 
True + worker = self._worker_task + if worker is None or worker.done(): + return + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning(f"Retrieval tuning worker shutdown failed: {e}") + + async def create_task(self, payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("检索调优中心已禁用") + self._ensure_ready() + + data = payload or {} + objective = str(data.get("objective") or self._cfg("web.tuning.default_objective", "precision_priority")) + if objective not in OBJECTIVES: + raise ValueError(f"objective 非法: {objective}") + + intensity = str(data.get("intensity") or self._cfg("web.tuning.default_intensity", "standard")) + if intensity not in INTENSITIES: + raise ValueError(f"intensity 非法: {intensity}") + + rounds_total = _clamp_int(data.get("rounds", INTENSITIES[intensity]), INTENSITIES[intensity], 1, 200) + sample_size = _clamp_int(data.get("sample_size", self._cfg("web.tuning.default_sample_size", 24)), 24, 4, 500) + top_k_eval = _clamp_int(data.get("top_k_eval", self._cfg("web.tuning.default_top_k_eval", 20)), 20, 5, 100) + eval_query_timeout_seconds = _clamp_float( + data.get("eval_query_timeout_seconds", self._eval_query_timeout_s()), + self._eval_query_timeout_s(), + 0.01, + 120.0, + ) + llm_enabled = _coerce_bool(data.get("llm_enabled", True), True) + seed = data.get("seed") + try: + seed = int(seed) + except Exception: + seed = int(time.time()) % 1000003 + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("调优任务队列已满,请稍后重试") + task = RetrievalTuningTaskRecord( + task_id=uuid.uuid4().hex, + status="queued", + progress=0.0, + objective=objective, + intensity=intensity, + rounds_total=rounds_total, + params={ + "sample_size": sample_size, + "top_k_eval": top_k_eval, + "eval_query_timeout_seconds": float(eval_query_timeout_seconds), + "llm_enabled": llm_enabled, + "seed": seed, + }, + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + task.updated_at = _now() + + await self._ensure_worker() + return task.to_summary() + + async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: + limit = _clamp_int(limit, 50, 1, 500) + async with self._lock: + items: List[Dict[str, Any]] = [] + for task_id in list(self._task_order)[:limit]: + task = self._tasks.get(task_id) + if task: + items.append(task.to_summary()) + return items + + async def get_task(self, task_id: str, include_rounds: bool = False) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + return task.to_detail(include_rounds=include_rounds) + + async def get_rounds(self, task_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: + offset = max(0, int(offset)) + limit = _clamp_int(limit, 50, 1, 500) + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + total = len(task.rounds) + sliced = task.rounds[offset : offset + limit] + return { + "total": total, + "offset": offset, + "limit": limit, + "items": [item.to_dict() for item in sliced], + } + + async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + if task.status in {"completed", "failed", "cancelled"}: + return task.to_summary() + if task.status == "queued": + task.status = "cancelled" + task.cancel_requested = True + 
task.finished_at = _now() + task.updated_at = task.finished_at + self._queue = deque([x for x in self._queue if x != task_id]) + return task.to_summary() + task.status = "cancel_requested" + task.cancel_requested = True + task.updated_at = _now() + return task.to_summary() + + async def apply_best(self, task_id: str) -> Dict[str, Any]: + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + raise ValueError("任务不存在") + if task.status != "completed": + raise ValueError("任务未完成,无法应用最优参数") + if not task.best_profile: + raise ValueError("任务没有可应用的最优参数") + best = copy.deepcopy(task.best_profile) + applied = await self.apply_profile(best, reason=f"task:{task_id}:apply_best") + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.apply_log.append({"applied_at": _now(), "reason": "apply_best", "profile": best}) + task.updated_at = _now() + return applied + + async def get_report(self, task_id: str, fmt: str = "md") -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return None + artifacts = dict(task.artifact_paths) + fmt = str(fmt or "md").strip().lower() + if fmt not in {"md", "json"}: + fmt = "md" + path_key = "report_md" if fmt == "md" else "report_json" + path = artifacts.get(path_key) + if not path: + return {"format": fmt, "content": "", "path": ""} + p = Path(path) + if not p.exists(): + return {"format": fmt, "content": "", "path": str(p)} + try: + content = p.read_text(encoding="utf-8") + except Exception: + content = "" + return {"format": fmt, "content": content, "path": str(p)} + + async def _worker_loop(self) -> None: + while not self._stopping: + task_id: Optional[str] = None + async with self._lock: + while self._queue: + candidate = self._queue.popleft() + task = self._tasks.get(candidate) + if task is None: + continue + if task.status != "queued": + continue + task_id = candidate + self._active_task_id = candidate + break + + if not task_id: + await asyncio.sleep(self._poll_interval_s()) + continue + + try: + await self._run_task(task_id) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Retrieval tuning task crashed: task_id={task_id}, err={e}") + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.status = "failed" + task.error = str(e) + task.finished_at = _now() + task.updated_at = task.finished_at + finally: + async with self._lock: + if self._active_task_id == task_id: + self._active_task_id = None + + async def _run_task(self, task_id: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + task.status = "running" + task.started_at = _now() + task.updated_at = task.started_at + + artifacts_dir = self._artifacts_root / task_id + artifacts_dir.mkdir(parents=True, exist_ok=True) + query_set_path = artifacts_dir / "query_set.json" + rounds_path = artifacts_dir / "round_metrics.jsonl" + best_profile_path = artifacts_dir / "best_profile.json" + report_json_path = artifacts_dir / "report.json" + report_md_path = artifacts_dir / "report.md" + + try: + params = dict(task.params) + cases, stats = await self._build_query_set( + sample_size=int(params["sample_size"]), + seed=int(params["seed"]), + llm_enabled=bool(params.get("llm_enabled", True)), + ) + if not cases: + raise ValueError("当前知识库样本不足,无法构建调优测试集") + + query_set_path.write_text( + json.dumps( + { + "task_id": task_id, + "created_at": _now(), + "stats": stats, + "items": [c.to_dict() for c in cases], + 
}, + ensure_ascii=False, + indent=2, + ), + encoding="utf-8", + ) + + baseline_profile = self.get_profile_snapshot() + top_k_eval = int(params["top_k_eval"]) + baseline_eval = await self._evaluate_profile( + profile=baseline_profile, + cases=cases, + objective=task.objective, + top_k_eval=top_k_eval, + query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), + ) + baseline_round = RetrievalTuningRoundRecord( + round_index=0, + candidate_profile=baseline_profile, + metrics=baseline_eval["metrics"], + score=float(baseline_eval["score"]), + latency_ms=float(baseline_eval["avg_elapsed_ms"]), + failure_summary=baseline_eval["failure_summary"], + ) + rounds_path.write_text(json.dumps(baseline_round.to_dict(), ensure_ascii=False) + "\n", encoding="utf-8") + + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + task.query_set_stats = stats + task.baseline_profile = copy.deepcopy(baseline_profile) + task.baseline_metrics = copy.deepcopy(baseline_eval["metrics"]) + task.rounds.append(baseline_round) + task.best_profile = copy.deepcopy(baseline_profile) + task.best_metrics = copy.deepcopy(baseline_eval["metrics"]) + task.best_score = float(baseline_eval["score"]) + task.progress = 0.0 + task.updated_at = _now() + + best_profile = copy.deepcopy(baseline_profile) + best_metrics = copy.deepcopy(baseline_eval["metrics"]) + best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) + best_score = float(baseline_eval["score"]) + llm_suggestions: List[Dict[str, Any]] = [] + task_cancelled = False + + for round_idx in range(1, int(task.rounds_total) + 1): + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + if task.cancel_requested or task.status == "cancel_requested": + task.status = "cancelled" + task.finished_at = _now() + task.updated_at = task.finished_at + task_cancelled = True + break + + if round_idx == 1 or (round_idx % 5 == 0 and not llm_suggestions): + llm_suggestions = await self._suggest_profiles_with_llm( + base_profile=best_profile, + failure_summary=best_failure_summary, + objective=task.objective, + max_count=3, + enabled=bool(params.get("llm_enabled", True)), + ) + + candidate_profile = self._generate_candidate_profile( + task_id=task_id, + round_index=round_idx, + objective=task.objective, + baseline_profile=baseline_profile, + best_profile=best_profile, + llm_suggestions=llm_suggestions, + ) + eval_cases = self._select_round_eval_cases( + cases=cases, + intensity=task.intensity, + round_index=round_idx, + seed=int(params.get("seed", 0)), + ) + eval_result = await self._evaluate_profile( + profile=candidate_profile, + cases=eval_cases, + objective=task.objective, + top_k_eval=top_k_eval, + query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), + ) + round_record = RetrievalTuningRoundRecord( + round_index=round_idx, + candidate_profile=candidate_profile, + metrics=eval_result["metrics"], + score=float(eval_result["score"]), + latency_ms=float(eval_result["avg_elapsed_ms"]), + failure_summary=eval_result["failure_summary"], + ) + with rounds_path.open("a", encoding="utf-8") as fp: + fp.write(json.dumps(round_record.to_dict(), ensure_ascii=False) + "\n") + + if float(eval_result["score"]) > float(best_score): + best_score = float(eval_result["score"]) + best_profile = copy.deepcopy(candidate_profile) + best_metrics = copy.deepcopy(eval_result["metrics"]) + best_failure_summary = copy.deepcopy(eval_result["failure_summary"]) + + 
async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + task.rounds_done = round_idx + task.rounds.append(round_record) + task.best_profile = copy.deepcopy(best_profile) + task.best_metrics = copy.deepcopy(best_metrics) + task.best_score = float(best_score) + task.progress = min(1.0, float(round_idx) / float(task.rounds_total)) + task.updated_at = _now() + + if best_profile and (not task_cancelled): + # 候选轮可能基于子样本评估,收官时用全量样本复核,确保最终指标可解释。 + best_full = await self._evaluate_profile( + profile=best_profile, + cases=cases, + objective=task.objective, + top_k_eval=top_k_eval, + query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), + ) + best_profile = copy.deepcopy(best_profile) + best_metrics = copy.deepcopy(best_full["metrics"]) + best_failure_summary = copy.deepcopy(best_full["failure_summary"]) + best_score = float(best_full["score"]) + if best_score < float(baseline_eval["score"]): + best_profile = copy.deepcopy(baseline_profile) + best_metrics = copy.deepcopy(baseline_eval["metrics"]) + best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) + best_score = float(baseline_eval["score"]) + + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.best_profile = copy.deepcopy(best_profile) + task.best_metrics = copy.deepcopy(best_metrics) + task.best_score = float(best_score) + task.updated_at = _now() + + async with self._lock: + task = self._tasks.get(task_id) + if task is None: + return + if task.status not in {"cancelled", "failed"}: + task.status = "completed" + task.progress = 1.0 + task.finished_at = _now() + task.updated_at = task.finished_at + final_task = copy.deepcopy(task) + + if final_task.status == "completed": + best_profile_path.write_text(json.dumps(final_task.best_profile, ensure_ascii=False, indent=2), encoding="utf-8") + report_payload = self._build_report_payload(final_task) + report_json_path.write_text(json.dumps(report_payload, ensure_ascii=False, indent=2), encoding="utf-8") + report_md_path.write_text(self._build_report_markdown(final_task, report_payload), encoding="utf-8") + + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.artifact_paths = { + "query_set": str(query_set_path), + "round_metrics_jsonl": str(rounds_path), + "best_profile": str(best_profile_path), + "report_json": str(report_json_path), + "report_md": str(report_md_path), + } + task.updated_at = _now() + except Exception as e: + logger.error(f"Retrieval tuning task failed: task_id={task_id}, err={e}") + async with self._lock: + task = self._tasks.get(task_id) + if task is not None: + task.status = "failed" + task.error = str(e) + task.finished_at = _now() + task.updated_at = task.finished_at + + async def _build_query_set(self, *, sample_size: int, seed: int, llm_enabled: bool) -> Tuple[List[RetrievalQueryCase], Dict[str, Any]]: + store = getattr(self.plugin, "metadata_store", None) + if store is None: + return [], {"error": "metadata_store_unavailable"} + + triples = list(store.get_all_triples() or []) + if not triples: + return [], {"error": "no_relations"} + + sampled, sample_info = self._sample_triples_for_query_set( + triples=triples, + sample_size=sample_size, + seed=seed, + ) + if not sampled: + return [], {"error": "no_relations"} + + anchors: List[Dict[str, Any]] = [] + for idx, row in enumerate(sampled): + subject, predicate, obj, relation_hash = row + paragraphs = store.get_paragraphs_by_relation(relation_hash) + para_hash = "" + 
para_content = "" + if paragraphs: + para_hash = str(paragraphs[0].get("hash") or "").strip() + para_content = str(paragraphs[0].get("content") or "") + anchors.append( + { + "anchor_id": f"a{idx+1:04d}", + "subject": str(subject or ""), + "predicate": str(predicate or ""), + "object": str(obj or ""), + "relation_hash": relation_hash, + "paragraph_hash": para_hash, + "paragraph_excerpt": para_content[:300], + } + ) + + if not anchors: + return [], {"error": "no_anchors"} + + predicate_groups: Dict[str, List[Dict[str, Any]]] = {} + for anchor in anchors: + predicate_groups.setdefault(str(anchor.get("predicate") or ""), []).append(anchor) + + nl_queries = await self._generate_nl_queries_with_llm(anchors, enabled=llm_enabled) + cases: List[RetrievalQueryCase] = [] + + seq = 0 + for anchor in anchors: + seq += 1 + subject = anchor["subject"] + predicate = anchor["predicate"] + obj = anchor["object"] + rel_hash = anchor["relation_hash"] + para_hash = anchor["paragraph_hash"] + expected = [rel_hash] + if para_hash: + expected.append(para_hash) + aid = anchor["anchor_id"] + + common_meta = { + "anchor_id": aid, + "relation_hash": rel_hash, + "paragraph_hash": para_hash, + "subject": subject, + "predicate": predicate, + "object": obj, + } + cases.append( + RetrievalQueryCase( + case_id=f"spo_relation_{seq:04d}", + category="spo_relation", + query=f"{subject}|{predicate}|{obj}", + expected_hashes=[rel_hash], + expected_spo={"subject": subject, "predicate": predicate, "object": obj}, + metadata=dict(common_meta), + ) + ) + cases.append( + RetrievalQueryCase( + case_id=f"spo_search_{seq:04d}", + category="spo_search", + query=self._build_spo_search_query( + anchor=anchor, + seq=seq, + predicate_groups=predicate_groups, + ), + expected_hashes=list(expected), + metadata=dict(common_meta), + ) + ) + cases.append( + RetrievalQueryCase( + case_id=f"query_kw_{seq:04d}", + category="query_kw", + query=self._build_keyword_query( + anchor=anchor, + seq=seq, + predicate_groups=predicate_groups, + ), + expected_hashes=list(expected), + metadata=dict(common_meta), + ) + ) + nl_query = nl_queries.get(aid) or self._build_nl_template( + anchor=anchor, + seq=seq, + predicate_groups=predicate_groups, + ) + cases.append( + RetrievalQueryCase( + case_id=f"query_nl_{seq:04d}", + category="query_nl", + query=nl_query, + expected_hashes=list(expected), + metadata=dict(common_meta), + ) + ) + + counts = Counter([c.category for c in cases]) + stats = { + "anchors": len(anchors), + "case_total": len(cases), + "category_counts": {k: int(v) for k, v in counts.items()}, + "seed": int(seed), + "sample_size": int(sample_info.get("sample_size", len(anchors))), + "sampling": dict(sample_info), + "llm_nl_enabled": bool(llm_enabled), + "llm_nl_generated": int(len(nl_queries)), + } + return cases, stats + + def _pick_contrast_anchor( + self, + *, + anchor: Dict[str, Any], + predicate_groups: Dict[str, List[Dict[str, Any]]], + seq: int, + ) -> Optional[Dict[str, Any]]: + predicate = str(anchor.get("predicate") or "") + pool = predicate_groups.get(predicate, []) + if not pool: + return None + candidates = [x for x in pool if x is not anchor and str(x.get("object") or "") != str(anchor.get("object") or "")] + if not candidates: + return None + return candidates[seq % len(candidates)] + + def _build_spo_search_query( + self, + *, + anchor: Dict[str, Any], + seq: int, + predicate_groups: Dict[str, List[Dict[str, Any]]], + ) -> str: + subject = str(anchor.get("subject") or "") + predicate = str(anchor.get("predicate") or "") + obj = 
str(anchor.get("object") or "") + contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) + contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" + + variants = [ + f"{subject} {predicate} {obj}", + f"{subject} {obj} relation {predicate}", + f"{predicate} {subject} {obj} evidence", + f"{subject} {predicate} {obj} not {contrast_obj}".strip(), + ] + return variants[seq % len(variants)].strip() + + def _build_keyword_query( + self, + *, + anchor: Dict[str, Any], + seq: int, + predicate_groups: Dict[str, List[Dict[str, Any]]], + ) -> str: + subject = str(anchor.get("subject") or "") + predicate = str(anchor.get("predicate") or "") + obj = str(anchor.get("object") or "") + excerpt = str(anchor.get("paragraph_excerpt") or "") + tokens = re.findall(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}", excerpt) + extras: List[str] = [] + seen = set() + for token in tokens: + key = token.lower() + if key in seen: + continue + if key in {subject.lower(), predicate.lower(), obj.lower()}: + continue + seen.add(key) + extras.append(token) + if len(extras) >= 2: + break + contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) + contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" + + variants = [ + [subject, obj] + extras[:2], + [predicate, obj] + extras[:2], + [subject, predicate] + extras[:2], + [subject, obj, predicate, contrast_obj] + extras[:1], + ] + parts = variants[seq % len(variants)] + return " ".join([x for x in parts if x]).strip() + + def _build_nl_template( + self, + *, + anchor: Dict[str, Any], + seq: int, + predicate_groups: Dict[str, List[Dict[str, Any]]], + ) -> str: + subject = str(anchor.get("subject") or "") + predicate = str(anchor.get("predicate") or "") + obj = str(anchor.get("object") or "") + contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) + contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" + templates = [ + f"请问 {subject} 与 {obj} 的关系是什么,是否是“{predicate}”?", + f"在当前知识库中,哪条信息说明 {subject} 对应的是 {obj},关系词接近“{predicate}”?", + f"我想确认:{subject} 和 {obj} 之间是不是“{predicate}”这层关系,而不是 {contrast_obj}?", + f"帮我查一下关于 {subject} 与 {obj} 的证据,重点看 {predicate} 相关描述。", + ] + return templates[seq % len(templates)] + + async def _select_llm_model(self) -> Optional[Any]: + if llm_api is None: + return None + try: + models = llm_api.get_available_models() or {} + except Exception: + return None + if not models: + return None + + cfg_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() + if cfg_model.lower() != "auto" and cfg_model in models: + return models[cfg_model] + for task_name in ["utils", "planner", "tool_use", "replyer", "embedding"]: + if task_name in models: + return models[task_name] + return models[next(iter(models))] + + async def _llm_call_text(self, prompt: str, *, request_type: str) -> str: + if llm_api is None: + raise RuntimeError("llm_api unavailable") + model_cfg = await self._select_llm_model() + if model_cfg is None: + raise RuntimeError("no_llm_model") + + retry = self._llm_retry_cfg() + max_attempts = int(retry["max_attempts"]) + min_wait = float(retry["min_wait_seconds"]) + max_wait = float(retry["max_wait_seconds"]) + backoff = float(retry["backoff_multiplier"]) + + last_error: Optional[Exception] = None + for idx in range(max_attempts): + try: + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_cfg, + 
request_type=request_type, + ) + if not success: + raise RuntimeError("llm_generation_failed") + text = str(response or "").strip() + if text: + return text + raise RuntimeError("empty_llm_response") + except Exception as e: + last_error = e + if idx >= max_attempts - 1: + break + delay = min(max_wait, min_wait * (backoff ** idx)) + await asyncio.sleep(max(0.05, delay)) + raise RuntimeError(f"LLM call failed: {last_error}") + + async def _generate_nl_queries_with_llm(self, anchors: List[Dict[str, Any]], *, enabled: bool) -> Dict[str, str]: + if not enabled or llm_api is None or not anchors: + return {} + payload = [ + { + "anchor_id": x["anchor_id"], + "subject": x["subject"], + "predicate": x["predicate"], + "object": x["object"], + "paragraph_excerpt": x["paragraph_excerpt"][:180], + } + for x in anchors[:60] + ] + prompt = ( + "你是检索评估问题生成器。" + "请基于给定 SPO 与简短上下文,为每条样本生成 1 条自然语言检索问题,返回 JSON:" + "{\"items\":[{\"anchor_id\":\"...\",\"query\":\"...\"}]}。\n" + f"样本:\n{json.dumps(payload, ensure_ascii=False)}" + ) + try: + raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.NLCaseGen") + obj = _safe_json_loads(raw) + if not isinstance(obj, dict): + return {} + items = obj.get("items") + if not isinstance(items, list): + return {} + out: Dict[str, str] = {} + for row in items: + if not isinstance(row, dict): + continue + anchor_id = str(row.get("anchor_id") or "").strip() + query = str(row.get("query") or "").strip() + if anchor_id and query: + out[anchor_id] = query + return out + except Exception: + return {} + + async def _suggest_profiles_with_llm( + self, + *, + base_profile: Dict[str, Any], + failure_summary: Dict[str, Any], + objective: str, + max_count: int, + enabled: bool, + ) -> List[Dict[str, Any]]: + if not enabled or llm_api is None or max_count <= 0: + return [] + prompt = ( + "你是检索调参专家。" + "请基于基础参数与失败摘要,给出最多 " + f"{int(max_count)} 组候选参数,返回 JSON: {{\"profiles\": [ ... 
]}}。\n" + "字段仅可包含:retrieval.top_k_paragraphs, retrieval.top_k_relations, retrieval.top_k_final, " + "retrieval.alpha, retrieval.enable_ppr, retrieval.search.smart_fallback.enabled, " + "retrieval.sparse.enabled, retrieval.sparse.mode, retrieval.sparse.candidate_k, retrieval.sparse.relation_candidate_k, " + "retrieval.fusion.method, retrieval.fusion.rrf_k, retrieval.fusion.vector_weight, retrieval.fusion.bm25_weight, " + "threshold.percentile, threshold.min_results。\n" + f"objective={objective}\n" + f"base={json.dumps(base_profile, ensure_ascii=False)}\n" + f"failure_summary={json.dumps(failure_summary, ensure_ascii=False)}" + ) + try: + raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.ProfileSuggest") + obj = _safe_json_loads(raw) + if not isinstance(obj, dict): + return [] + profiles = obj.get("profiles") + if not isinstance(profiles, list): + return [] + out = [] + for item in profiles[:max_count]: + if isinstance(item, dict): + out.append(self._normalize_profile(item, fallback=base_profile)) + return out + except Exception: + return [] + + def _generate_candidate_profile( + self, + *, + task_id: str, + round_index: int, + objective: str, + baseline_profile: Dict[str, Any], + best_profile: Dict[str, Any], + llm_suggestions: List[Dict[str, Any]], + ) -> Dict[str, Any]: + if llm_suggestions: + return self._normalize_profile(llm_suggestions.pop(0), fallback=best_profile) + + rng = random.Random(f"{task_id}:{round_index}") + base = baseline_profile if round_index % 4 == 1 else best_profile + candidate = copy.deepcopy(base) + + if objective == "precision_priority": + para_choices = [40, 80, 120, 180, 240, 320] + rel_choices = [4, 8, 12, 16, 24] + final_choices = [4, 8, 12, 16, 20, 32, 48, 64] + alpha_choices = [0.0, 0.35, 0.50, 0.62, 0.72, 0.82, 0.90] + pct_choices = [55, 60, 65, 72, 80] + min_results_choices = [1, 2] + elif objective == "recall_priority": + para_choices = [120, 220, 300, 420, 560, 720] + rel_choices = [8, 12, 16, 24, 32] + final_choices = [8, 16, 32, 48, 64, 96, 128] + alpha_choices = [0.20, 0.35, 0.45, 0.55, 0.65, 0.75] + pct_choices = [40, 48, 55, 62] + min_results_choices = [1, 2, 3] + else: + para_choices = [80, 160, 240, 320, 420, 520] + rel_choices = [6, 10, 14, 18, 24, 30] + final_choices = [6, 12, 20, 32, 48, 64, 80] + alpha_choices = [0.25, 0.45, 0.55, 0.65, 0.75, 0.85] + pct_choices = [48, 55, 62, 70] + min_results_choices = [1, 2, 3] + + _nested_set(candidate, "retrieval.top_k_paragraphs", rng.choice(para_choices)) + _nested_set(candidate, "retrieval.top_k_relations", rng.choice(rel_choices)) + _nested_set(candidate, "retrieval.top_k_final", rng.choice(final_choices)) + _nested_set(candidate, "retrieval.alpha", rng.choice(alpha_choices)) + # PPR 在 TestClient/异步评估场景下存在偶发长时阻塞风险,调优评估链路固定关闭。 + _nested_set(candidate, "retrieval.enable_ppr", False) + _nested_set(candidate, "retrieval.search.smart_fallback.enabled", bool(rng.choice([True, True, False]))) + _nested_set(candidate, "retrieval.sparse.enabled", bool(rng.choice([True, True, False]))) + _nested_set(candidate, "retrieval.sparse.mode", rng.choice(["auto", "hybrid", "fallback_only"])) + _nested_set(candidate, "retrieval.sparse.candidate_k", rng.choice([60, 80, 120, 160, 220, 320])) + _nested_set(candidate, "retrieval.sparse.relation_candidate_k", rng.choice([40, 60, 90, 120, 180, 260])) + _nested_set(candidate, "retrieval.fusion.method", rng.choice(["weighted_rrf", "weighted_rrf", "alpha_legacy"])) + _nested_set(candidate, "retrieval.fusion.rrf_k", rng.choice([30, 45, 60, 75, 90])) + 
vec_w = float(rng.choice([0.55, 0.65, 0.72, 0.80, 0.88])) + _nested_set(candidate, "retrieval.fusion.vector_weight", vec_w) + _nested_set(candidate, "retrieval.fusion.bm25_weight", 1.0 - vec_w) + _nested_set(candidate, "threshold.percentile", rng.choice(pct_choices)) + _nested_set(candidate, "threshold.min_results", rng.choice(min_results_choices)) + + return self._normalize_profile(candidate, fallback=base) + + def _build_runtime_config(self, normalized_profile: Dict[str, Any]) -> Dict[str, Any]: + raw_base = getattr(self.plugin, "config", {}) or {} + if isinstance(raw_base, dict): + base = { + key: value + for key, value in raw_base.items() + if key not in _RUNTIME_CONFIG_INSTANCE_KEYS + } + else: + base = {} + merged = _deep_merge(base, normalized_profile) + # 调优评估场景优先稳定性,避免并发访问共享 SQLite/Faiss 导致长时阻塞。 + _nested_set(merged, "retrieval.enable_parallel", False) + # 调优评估阶段关闭 PPR,规避 PageRank 线程计算偶发阻塞导致整轮卡死。 + _nested_set(merged, "retrieval.enable_ppr", False) + merged["vector_store"] = getattr(self.plugin, "vector_store", None) + merged["graph_store"] = getattr(self.plugin, "graph_store", None) + merged["metadata_store"] = getattr(self.plugin, "metadata_store", None) + merged["embedding_manager"] = getattr(self.plugin, "embedding_manager", None) + merged["sparse_index"] = getattr(self.plugin, "sparse_index", None) + merged["plugin_instance"] = self.plugin + return merged + + async def _evaluate_profile( + self, + *, + profile: Dict[str, Any], + cases: List[RetrievalQueryCase], + objective: str, + top_k_eval: int, + query_timeout_s: float, + ) -> Dict[str, Any]: + normalized = self._normalize_profile(profile) + eval_top_k = _clamp_int(top_k_eval, 20, 1, 1000) + # 评估时让 top_k_final 参与有效召回深度,避免该参数对评分无影响。 + request_top_k = min( + int(eval_top_k), + _clamp_int(_nested_get(normalized, "retrieval.top_k_final", eval_top_k), eval_top_k, 1, 512), + ) + eval_timeout_s = _clamp_float( + query_timeout_s, + self._eval_query_timeout_s(), + 0.01, + 120.0, + ) + runtime_cfg = self._build_runtime_config(normalized) + runtime = build_search_runtime( + plugin_config=runtime_cfg, + logger_obj=logger, + owner_tag="retrieval_tuning", + log_prefix="[RetrievalTuning]", + ) + if not runtime.ready: + metrics = { + "total_text_cases": 0, + "precision_at_1": 0.0, + "precision_at_3": 0.0, + "mrr": 0.0, + "recall_at_k": 0.0, + "spo_relation_hit_rate": 0.0, + "empty_rate": 1.0, + "avg_elapsed_ms": 0.0, + "category": {}, + "error": runtime.error or "runtime_not_ready", + } + return {"metrics": metrics, "score": -1.0, "avg_elapsed_ms": 0.0, "failure_summary": {"reason": metrics["error"]}} + + text_total = 0 + hit1 = 0 + hit3 = 0 + hitk = 0 + mrr_sum = 0.0 + empty_count = 0 + timeout_count = 0 + elapsed_total = 0.0 + text_failed: List[str] = [] + + spo_total = 0 + spo_hit = 0 + spo_failed: List[str] = [] + + category_stats: Dict[str, Dict[str, Any]] = {} + failed_predicates = Counter() + + for case in cases: + cat = str(case.category) + if cat not in CATEGORIES: + continue + if cat not in category_stats: + category_stats[cat] = { + "total": 0, + "hit": 0, + "hit_at_1": 0, + "hit_at_3": 0, + "empty": 0, + } + category_stats[cat]["total"] += 1 + + if cat == "spo_relation": + spo_total += 1 + spo = case.expected_spo or {} + rows = runtime.metadata_store.get_relations( + subject=str(spo.get("subject") or ""), + predicate=str(spo.get("predicate") or ""), + object=str(spo.get("object") or ""), + ) + expected_hash = str(case.expected_hashes[0]) if case.expected_hashes else "" + ok = False + for row in rows: + if not isinstance(row, 
dict): + continue + if expected_hash and str(row.get("hash") or "") == expected_hash: + ok = True + break + if not expected_hash: + ok = True + break + if ok: + spo_hit += 1 + category_stats[cat]["hit"] += 1 + category_stats[cat]["hit_at_1"] += 1 + category_stats[cat]["hit_at_3"] += 1 + else: + spo_failed.append(case.case_id) + failed_predicates.update([str(spo.get("predicate") or "").strip() or "__empty__"]) + continue + + text_total += 1 + req = SearchExecutionRequest( + caller="retrieval_tuning", + query_type="search", + query=str(case.query or "").strip(), + top_k=int(request_top_k), + use_threshold=True, + # 调优评估固定关闭 PPR,避免该链路阻塞拖挂整轮任务。 + enable_ppr=False, + ) + try: + execution = await asyncio.wait_for( + SearchExecutionService.execute( + retriever=runtime.retriever, + threshold_filter=runtime.threshold_filter, + plugin_config=runtime_cfg, + request=req, + enforce_chat_filter=False, + reinforce_access=False, + ), + timeout=float(eval_timeout_s), + ) + except asyncio.TimeoutError: + timeout_count += 1 + empty_count += 1 + category_stats[cat]["empty"] += 1 + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + continue + + if execution is None: + empty_count += 1 + category_stats[cat]["empty"] += 1 + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + continue + + elapsed_total += float(getattr(execution, "elapsed_ms", 0.0) or 0.0) + + if not bool(getattr(execution, "success", False)): + empty_count += 1 + category_stats[cat]["empty"] += 1 + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + continue + + hashes = [str(getattr(x, "hash_value", "") or "") for x in (getattr(execution, "results", None) or [])] + if not hashes: + empty_count += 1 + category_stats[cat]["empty"] += 1 + + expected_set = set(case.expected_hashes or []) + rank = 0 + for idx, hv in enumerate(hashes, start=1): + if hv and hv in expected_set: + rank = idx + break + + if rank > 0: + category_stats[cat]["hit"] += 1 + hitk += 1 + if rank <= 1: + hit1 += 1 + category_stats[cat]["hit_at_1"] += 1 + if rank <= 3: + hit3 += 1 + category_stats[cat]["hit_at_3"] += 1 + mrr_sum += 1.0 / float(rank) + else: + text_failed.append(case.case_id) + failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) + + p1 = (hit1 / text_total) if text_total else 0.0 + p3 = (hit3 / text_total) if text_total else 0.0 + recall = (hitk / text_total) if text_total else 0.0 + mrr = (mrr_sum / text_total) if text_total else 0.0 + spo_rate = (spo_hit / spo_total) if spo_total else 0.0 + empty_rate = (empty_count / text_total) if text_total else 1.0 + avg_elapsed = (elapsed_total / text_total) if text_total else 0.0 + + metrics = { + "total_text_cases": int(text_total), + "precision_at_1": float(round(p1, 6)), + "precision_at_3": float(round(p3, 6)), + "mrr": float(round(mrr, 6)), + "recall_at_k": float(round(recall, 6)), + "spo_relation_hit_rate": float(round(spo_rate, 6)), + "empty_rate": float(round(empty_rate, 6)), + "timeout_count": int(timeout_count), + "avg_elapsed_ms": float(round(avg_elapsed, 3)), + "category": category_stats, + } + metrics["category_floor_penalty"] = float(round(self._category_floor_penalty(metrics, objective=objective), 6)) + + score = self._score_metrics(metrics, objective=objective) + failure_summary = { + "text_failed_count": len(text_failed), + "spo_failed_count": len(spo_failed), + "failed_case_ids": 
text_failed[:50] + spo_failed[:50], + "failed_by_category": {k: int(v["total"] - v["hit"]) for k, v in category_stats.items()}, + "top_failed_predicates": [ + {"predicate": key, "count": int(cnt)} + for key, cnt in failed_predicates.most_common(5) + if key + ], + "query_timeout_seconds": float(eval_timeout_s), + "timeout_count": int(timeout_count), + "effective_top_k": int(request_top_k), + "ppr_forced_disabled": True, + } + return { + "metrics": metrics, + "score": float(round(score, 6)), + "avg_elapsed_ms": float(avg_elapsed), + "failure_summary": failure_summary, + } + + def _score_metrics(self, metrics: Dict[str, Any], *, objective: str) -> float: + p1 = float(metrics.get("precision_at_1", 0.0) or 0.0) + p3 = float(metrics.get("precision_at_3", 0.0) or 0.0) + mrr = float(metrics.get("mrr", 0.0) or 0.0) + recall = float(metrics.get("recall_at_k", 0.0) or 0.0) + spo = float(metrics.get("spo_relation_hit_rate", 0.0) or 0.0) + empty_rate = float(metrics.get("empty_rate", 1.0) or 1.0) + category_penalty = metrics.get("category_floor_penalty", None) + if category_penalty is None: + category_penalty = self._category_floor_penalty(metrics, objective=objective) + category_penalty = float(max(0.0, category_penalty)) + + if objective == "recall_priority": + raw = 0.15 * p1 + 0.15 * p3 + 0.15 * mrr + 0.40 * recall + 0.15 * spo + penalty = 0.05 * empty_rate + elif objective == "balanced": + raw = 0.25 * p1 + 0.20 * p3 + 0.15 * mrr + 0.25 * recall + 0.15 * spo + penalty = 0.10 * empty_rate + else: + raw = 0.40 * p1 + 0.20 * p3 + 0.15 * mrr + 0.15 * recall + 0.10 * spo + penalty = 0.15 * empty_rate + return float(raw - penalty - category_penalty) + + def _category_floor_penalty(self, metrics: Dict[str, Any], *, objective: str) -> float: + category = metrics.get("category") + if not isinstance(category, dict) or not category: + return 0.0 + + if objective == "recall_priority": + floors = {"query_nl": 0.60, "query_kw": 0.48, "spo_search": 0.52, "spo_relation": 0.88} + scale = 0.12 + elif objective == "balanced": + floors = {"query_nl": 0.65, "query_kw": 0.52, "spo_search": 0.55, "spo_relation": 0.90} + scale = 0.18 + else: + floors = {"query_nl": 0.70, "query_kw": 0.55, "spo_search": 0.58, "spo_relation": 0.92} + scale = 0.25 + + weights = {"query_nl": 1.0, "query_kw": 1.1, "spo_search": 1.0, "spo_relation": 1.2} + weighted_shortfall = 0.0 + weight_total = 0.0 + + for cat, floor in floors.items(): + row = category.get(cat) + if not isinstance(row, dict): + continue + total = int(row.get("total", 0) or 0) + if total <= 0: + continue + hit = float(row.get("hit", 0.0) or 0.0) + hit_rate = max(0.0, min(1.0, hit / float(max(1, total)))) + shortfall = max(0.0, float(floor) - hit_rate) + w = float(weights.get(cat, 1.0)) + weighted_shortfall += w * shortfall + weight_total += w + + if weight_total <= 1e-9: + return 0.0 + return float(scale * (weighted_shortfall / weight_total)) + + def _build_report_payload(self, task: RetrievalTuningTaskRecord) -> Dict[str, Any]: + baseline = task.baseline_metrics or {} + best = task.best_metrics or {} + + def delta(name: str) -> float: + return float(best.get(name, 0.0) or 0.0) - float(baseline.get(name, 0.0) or 0.0) + + return { + "task_id": task.task_id, + "objective": task.objective, + "intensity": task.intensity, + "status": task.status, + "created_at": task.created_at, + "started_at": task.started_at, + "finished_at": task.finished_at, + "rounds_total": task.rounds_total, + "rounds_done": task.rounds_done, + "best_score": task.best_score, + "baseline_score": 
self._score_metrics(baseline, objective=task.objective), + "query_set_stats": task.query_set_stats, + "baseline_metrics": baseline, + "best_metrics": best, + "deltas": { + "precision_at_1": delta("precision_at_1"), + "precision_at_3": delta("precision_at_3"), + "mrr": delta("mrr"), + "recall_at_k": delta("recall_at_k"), + "spo_relation_hit_rate": delta("spo_relation_hit_rate"), + "empty_rate": delta("empty_rate"), + "timeout_count": delta("timeout_count"), + "avg_elapsed_ms": delta("avg_elapsed_ms"), + }, + "best_profile": task.best_profile, + "baseline_profile": task.baseline_profile, + "apply_log": task.apply_log, + } + + def _build_report_markdown(self, task: RetrievalTuningTaskRecord, payload: Dict[str, Any]) -> str: + baseline = payload.get("baseline_metrics", {}) or {} + best = payload.get("best_metrics", {}) or {} + d = payload.get("deltas", {}) or {} + lines = [ + f"# 检索调优报告({task.task_id})", + "", + "## 1. 任务信息", + f"- 状态: {task.status}", + f"- 目标函数: {task.objective}", + f"- 强度: {task.intensity}", + f"- 轮次: baseline + {task.rounds_total}", + f"- 创建时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.created_at))}", + f"- 开始时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.started_at)) if task.started_at else '-'}", + f"- 完成时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.finished_at)) if task.finished_at else '-'}", + "", + "## 2. 基线 vs 最优", + f"- baseline score: {payload.get('baseline_score', 0.0):.6f}", + f"- best score: {task.best_score:.6f}", + f"- P@1: {baseline.get('precision_at_1', 0.0):.4f} -> {best.get('precision_at_1', 0.0):.4f} (Δ {d.get('precision_at_1', 0.0):+.4f})", + f"- P@3: {baseline.get('precision_at_3', 0.0):.4f} -> {best.get('precision_at_3', 0.0):.4f} (Δ {d.get('precision_at_3', 0.0):+.4f})", + f"- MRR: {baseline.get('mrr', 0.0):.4f} -> {best.get('mrr', 0.0):.4f} (Δ {d.get('mrr', 0.0):+.4f})", + f"- Recall@K: {baseline.get('recall_at_k', 0.0):.4f} -> {best.get('recall_at_k', 0.0):.4f} (Δ {d.get('recall_at_k', 0.0):+.4f})", + f"- SPO relation hit: {baseline.get('spo_relation_hit_rate', 0.0):.4f} -> {best.get('spo_relation_hit_rate', 0.0):.4f} (Δ {d.get('spo_relation_hit_rate', 0.0):+.4f})", + f"- 空结果率: {baseline.get('empty_rate', 0.0):.4f} -> {best.get('empty_rate', 0.0):.4f} (Δ {d.get('empty_rate', 0.0):+.4f})", + f"- 超时数: {int(baseline.get('timeout_count', 0) or 0)} -> {int(best.get('timeout_count', 0) or 0)} (Δ {int(d.get('timeout_count', 0) or 0):+d})", + f"- 平均耗时(ms): {baseline.get('avg_elapsed_ms', 0.0):.2f} -> {best.get('avg_elapsed_ms', 0.0):.2f} (Δ {d.get('avg_elapsed_ms', 0.0):+.2f})", + "", + "## 3. 最优参数", + "```json", + json.dumps(task.best_profile, ensure_ascii=False, indent=2), + "```", + "", + "## 4. 测试集规模", + f"- {json.dumps(task.query_set_stats, ensure_ascii=False)}", + "", + "## 5. 
说明", + "- 本报告仅对当前已存储图谱与向量状态有效。", + "- 参数应用策略:运行时生效,不自动写入 config.toml。", + ] + return "\n".join(lines).strip() + "\n" diff --git a/plugins/A_memorix/core/utils/runtime_self_check.py b/plugins/A_memorix/core/utils/runtime_self_check.py index 36a2cf7e..131ab32a 100644 --- a/plugins/A_memorix/core/utils/runtime_self_check.py +++ b/plugins/A_memorix/core/utils/runtime_self_check.py @@ -61,6 +61,29 @@ def _build_report( } +def _normalize_encoded_vector(encoded: Any) -> np.ndarray: + if encoded is None: + raise ValueError("embedding encode returned None") + + if isinstance(encoded, np.ndarray): + array = encoded + else: + array = np.asarray(encoded, dtype=np.float32) + + if array.ndim == 2: + if array.shape[0] != 1: + raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}") + array = array[0] + + if array.ndim != 1: + raise ValueError(f"embedding encode returned invalid ndim={array.ndim}") + if array.size <= 0: + raise ValueError("embedding encode returned empty vector") + if not np.all(np.isfinite(array)): + raise ValueError("embedding encode returned non-finite values") + return array.astype(np.float32, copy=False) + + async def run_embedding_runtime_self_check( *, config: Any, @@ -91,13 +114,11 @@ async def run_embedding_runtime_self_check( try: detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0) encoded = await embedding_manager.encode(sample_text) - if isinstance(encoded, np.ndarray): - encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1]) - else: - encoded_dimension = len(encoded) if encoded is not None else 0 + encoded_array = _normalize_encoded_vector(encoded) + encoded_dimension = int(encoded_array.shape[0]) except Exception as exc: elapsed_ms = (time.perf_counter() - start) * 1000.0 - logger.warning("embedding runtime self-check failed: %s", exc) + logger.warning(f"embedding runtime self-check failed: {exc}") return _build_report( ok=False, code="embedding_probe_failed", diff --git a/plugins/A_memorix/core/utils/search_execution_service.py b/plugins/A_memorix/core/utils/search_execution_service.py new file mode 100644 index 00000000..efb2093f --- /dev/null +++ b/plugins/A_memorix/core/utils/search_execution_service.py @@ -0,0 +1,442 @@ +""" +统一检索执行服务。 + +用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。 +""" + +from __future__ import annotations + +import hashlib +import json +import time +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +from ..retrieval import TemporalQueryOptions +from .search_postprocess import ( + apply_safe_content_dedup, + maybe_apply_smart_path_fallback, +) +from .time_parser import parse_query_time_range + +logger = get_logger("A_Memorix.SearchExecutionService") + + +def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: + if not isinstance(config, dict): + return default + current: Any = config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + +def _sanitize_text(value: Any) -> str: + if value is None: + return "" + return str(value).strip() + + +@dataclass +class SearchExecutionRequest: + caller: str + stream_id: Optional[str] = None + group_id: Optional[str] = None + user_id: Optional[str] = None + query_type: str = "search" # search|semantic|time|hybrid + query: str = "" + top_k: Optional[int] = None + time_from: Optional[str] = None + 
time_to: Optional[str] = None + person: Optional[str] = None + source: Optional[str] = None + use_threshold: bool = True + enable_ppr: bool = True + + +@dataclass +class SearchExecutionResult: + success: bool + error: str = "" + query_type: str = "search" + query: str = "" + top_k: int = 10 + time_from: Optional[str] = None + time_to: Optional[str] = None + person: Optional[str] = None + source: Optional[str] = None + temporal: Optional[TemporalQueryOptions] = None + results: List[Any] = field(default_factory=list) + elapsed_ms: float = 0.0 + chat_filtered: bool = False + dedup_hit: bool = False + + @property + def count(self) -> int: + return len(self.results) + + +class SearchExecutionService: + """统一检索执行服务。""" + + @staticmethod + def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]: + if isinstance(plugin_config, dict): + plugin_instance = plugin_config.get("plugin_instance") + if plugin_instance is not None: + return plugin_instance + + try: + from ...plugin import AMemorixPlugin + + return getattr(AMemorixPlugin, "get_global_instance", lambda: None)() + except Exception: + return None + + @staticmethod + def _normalize_query_type(raw_query_type: str) -> str: + query_type = _sanitize_text(raw_query_type).lower() or "search" + if query_type == "semantic": + return "search" + return query_type + + @staticmethod + def _resolve_runtime_component( + plugin_config: Optional[dict], + plugin_instance: Optional[Any], + key: str, + ) -> Optional[Any]: + if isinstance(plugin_config, dict): + value = plugin_config.get(key) + if value is not None: + return value + if plugin_instance is not None: + value = getattr(plugin_instance, key, None) + if value is not None: + return value + return None + + @staticmethod + def _resolve_top_k( + plugin_config: Optional[dict], + query_type: str, + top_k_raw: Optional[Any], + ) -> Tuple[bool, int, str]: + temporal_default_top_k = int( + _get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10) + ) + default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10 + if top_k_raw is None: + return True, max(1, min(50, default_top_k)), "" + try: + top_k = int(top_k_raw) + except (TypeError, ValueError): + return False, 0, "top_k 参数必须为整数" + return True, max(1, min(50, top_k)), "" + + @staticmethod + def _build_temporal( + plugin_config: Optional[dict], + query_type: str, + time_from_raw: Optional[str], + time_to_raw: Optional[str], + person: Optional[str], + source: Optional[str], + ) -> Tuple[bool, Optional[TemporalQueryOptions], str]: + if query_type not in {"time", "hybrid"}: + return True, None, "" + + temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True)) + if not temporal_enabled: + return False, None, "时序检索已禁用(retrieval.temporal.enabled=false)" + + if not time_from_raw and not time_to_raw: + return False, None, "time/hybrid 模式至少需要 time_from 或 time_to" + + try: + ts_from, ts_to = parse_query_time_range( + str(time_from_raw) if time_from_raw is not None else None, + str(time_to_raw) if time_to_raw is not None else None, + ) + except ValueError as e: + return False, None, f"时间参数错误: {e}" + + temporal = TemporalQueryOptions( + time_from=ts_from, + time_to=ts_to, + person=_sanitize_text(person) or None, + source=_sanitize_text(source) or None, + allow_created_fallback=bool( + _get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True) + ), + candidate_multiplier=int( + _get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8) 
+ ), + max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)), + ) + return True, temporal, "" + + @staticmethod + def _build_request_key( + request: SearchExecutionRequest, + query_type: str, + top_k: int, + temporal: Optional[TemporalQueryOptions], + ) -> str: + payload = { + "stream_id": _sanitize_text(request.stream_id), + "query_type": query_type, + "query": _sanitize_text(request.query), + "time_from": _sanitize_text(request.time_from), + "time_to": _sanitize_text(request.time_to), + "time_from_ts": temporal.time_from if temporal else None, + "time_to_ts": temporal.time_to if temporal else None, + "person": _sanitize_text(request.person), + "source": _sanitize_text(request.source), + "top_k": int(top_k), + "use_threshold": bool(request.use_threshold), + "enable_ppr": bool(request.enable_ppr), + } + payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True) + return hashlib.sha1(payload_json.encode("utf-8")).hexdigest() + + @staticmethod + async def execute( + *, + retriever: Any, + threshold_filter: Optional[Any], + plugin_config: Optional[dict], + request: SearchExecutionRequest, + enforce_chat_filter: bool = True, + reinforce_access: bool = True, + ) -> SearchExecutionResult: + if retriever is None: + return SearchExecutionResult(success=False, error="知识检索器未初始化") + + query_type = SearchExecutionService._normalize_query_type(request.query_type) + query = _sanitize_text(request.query) + if query_type not in {"search", "time", "hybrid"}: + return SearchExecutionResult( + success=False, + error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid)", + ) + + if query_type in {"search", "hybrid"} and not query: + return SearchExecutionResult( + success=False, + error="search/hybrid 模式必须提供 query", + ) + + top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k( + plugin_config, query_type, request.top_k + ) + if not top_k_ok: + return SearchExecutionResult(success=False, error=top_k_error) + + temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal( + plugin_config=plugin_config, + query_type=query_type, + time_from_raw=request.time_from, + time_to_raw=request.time_to, + person=request.person, + source=request.source, + ) + if not temporal_ok: + return SearchExecutionResult(success=False, error=temporal_error) + + plugin_instance = SearchExecutionService._resolve_plugin_instance(plugin_config) + if ( + enforce_chat_filter + and plugin_instance is not None + and hasattr(plugin_instance, "is_chat_enabled") + ): + if not plugin_instance.is_chat_enabled( + stream_id=request.stream_id, + group_id=request.group_id, + user_id=request.user_id, + ): + logger.info( + "检索请求被聊天过滤拦截: " + f"caller={request.caller}, " + f"stream_id={request.stream_id}" + ) + return SearchExecutionResult( + success=True, + query_type=query_type, + query=query, + top_k=top_k, + time_from=request.time_from, + time_to=request.time_to, + person=request.person, + source=request.source, + temporal=temporal, + results=[], + elapsed_ms=0.0, + chat_filtered=True, + dedup_hit=False, + ) + + request_key = SearchExecutionService._build_request_key( + request=request, + query_type=query_type, + top_k=top_k, + temporal=temporal, + ) + + async def _executor() -> Dict[str, Any]: + original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) + setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) + started_at = time.time() + try: + retrieved = await retriever.retrieve( + query=query, + top_k=top_k, + temporal=temporal, + ) + + 
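+                # Post-retrieval pipeline, in order: optional threshold filtering
+                # (skipped for empty-query "time" searches when configured),
+                # access reinforcement for relation hits, smart path fallback
+                # (search mode), and safe content de-duplication.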
should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None + if ( + query_type == "time" + and not query + and bool( + _get_config_value( + plugin_config, + "retrieval.time.skip_threshold_when_query_empty", + True, + ) + ) + ): + should_apply_threshold = False + + if should_apply_threshold: + retrieved = threshold_filter.filter(retrieved) + + if ( + reinforce_access + and plugin_instance is not None + and hasattr(plugin_instance, "reinforce_access") + ): + relation_hashes = [ + item.hash_value + for item in retrieved + if getattr(item, "result_type", "") == "relation" + ] + if relation_hashes: + await plugin_instance.reinforce_access(relation_hashes) + + if query_type == "search": + graph_store = SearchExecutionService._resolve_runtime_component( + plugin_config, plugin_instance, "graph_store" + ) + metadata_store = SearchExecutionService._resolve_runtime_component( + plugin_config, plugin_instance, "metadata_store" + ) + fallback_enabled = bool( + _get_config_value( + plugin_config, + "retrieval.search.smart_fallback.enabled", + True, + ) + ) + fallback_threshold = float( + _get_config_value( + plugin_config, + "retrieval.search.smart_fallback.threshold", + 0.6, + ) + ) + retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback( + query=query, + results=list(retrieved), + graph_store=graph_store, + metadata_store=metadata_store, + enabled=fallback_enabled, + threshold=fallback_threshold, + ) + if fallback_triggered: + logger.info( + "metric.smart_fallback_triggered_count=1 " + f"caller={request.caller} " + f"added={fallback_added}" + ) + + dedup_enabled = bool( + _get_config_value( + plugin_config, + "retrieval.search.safe_content_dedup.enabled", + True, + ) + ) + if dedup_enabled: + retrieved, removed_count = apply_safe_content_dedup(list(retrieved)) + if removed_count > 0: + logger.info( + f"metric.safe_dedup_removed_count={removed_count} " + f"caller={request.caller}" + ) + + elapsed_ms = (time.time() - started_at) * 1000.0 + return {"results": retrieved, "elapsed_ms": elapsed_ms} + finally: + setattr(retriever.config, "enable_ppr", original_ppr) + + dedup_hit = False + try: + # 调优评估需要逐轮真实执行,且应避免额外 dedup 锁竞争。 + bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning" + if ( + not bypass_request_dedup + and + plugin_instance is not None + and hasattr(plugin_instance, "execute_request_with_dedup") + ): + dedup_hit, payload = await plugin_instance.execute_request_with_dedup( + request_key, + _executor, + ) + else: + payload = await _executor() + except Exception as e: + return SearchExecutionResult(success=False, error=f"知识检索失败: {e}") + + if dedup_hit: + logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}") + + return SearchExecutionResult( + success=True, + query_type=query_type, + query=query, + top_k=top_k, + time_from=request.time_from, + time_to=request.time_to, + person=request.person, + source=request.source, + temporal=temporal, + results=payload.get("results", []), + elapsed_ms=float(payload.get("elapsed_ms", 0.0)), + chat_filtered=False, + dedup_hit=bool(dedup_hit), + ) + + @staticmethod + def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: + serialized: List[Dict[str, Any]] = [] + for item in results: + metadata = dict(getattr(item, "metadata", {}) or {}) + if "time_meta" not in metadata: + metadata["time_meta"] = {} + serialized.append( + { + "hash": getattr(item, "hash_value", ""), + "type": getattr(item, "result_type", ""), + "score": 
float(getattr(item, "score", 0.0)), + "content": getattr(item, "content", ""), + "metadata": metadata, + } + ) + return serialized diff --git a/plugins/A_memorix/core/utils/summary_importer.py b/plugins/A_memorix/core/utils/summary_importer.py new file mode 100644 index 00000000..b6271db4 --- /dev/null +++ b/plugins/A_memorix/core/utils/summary_importer.py @@ -0,0 +1,425 @@ +""" +聊天总结与知识导入工具 + +该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系 +导入到 A_memorix 的存储组件中。 +""" + +import time +import json +import re +import traceback +from typing import List, Dict, Any, Tuple, Optional +from pathlib import Path + +from src.common.logger import get_logger +from src.services import llm_service as llm_api +from src.services import message_service as message_api +from src.config.config import global_config, model_config as host_model_config +from src.config.model_configs import TaskConfig + +from ..storage import ( + KnowledgeType, + VectorStore, + GraphStore, + MetadataStore, + resolve_stored_knowledge_type, +) +from ..embedding import EmbeddingAPIAdapter +from .relation_write_service import RelationWriteService +from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check + +logger = get_logger("A_Memorix.SummaryImporter") + +# 默认总结提示词模版 +SUMMARY_PROMPT_TEMPLATE = """ +你是 {bot_name}。{personality_context} +现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。 + +聊天记录内容: +{chat_history} + +请完成以下任务: +1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。 +2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。 + +请严格以 JSON 格式输出,格式如下: +{{ + "summary": "总结文本内容", + "entities": ["张三", "李四"], + "relations": [ + {{"subject": "张三", "predicate": "认识", "object": "李四"}} + ] +}} + +注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。 +""" + +class SummaryImporter: + """总结并导入知识的工具类""" + + def __init__( + self, + vector_store: VectorStore, + graph_store: GraphStore, + metadata_store: MetadataStore, + embedding_manager: EmbeddingAPIAdapter, + plugin_config: dict + ): + self.vector_store = vector_store + self.graph_store = graph_store + self.metadata_store = metadata_store + self.embedding_manager = embedding_manager + self.plugin_config = plugin_config + self.relation_write_service: Optional[RelationWriteService] = ( + plugin_config.get("relation_write_service") + if isinstance(plugin_config, dict) + else None + ) + + def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]: + """标准化 summarization.model_name 配置(vNext 仅接受字符串数组)。""" + if raw_value is None: + return ["auto"] + if isinstance(raw_value, list): + selectors = [str(x).strip() for x in raw_value if str(x).strip()] + return selectors or ["auto"] + raise ValueError( + "summarization.model_name 在 vNext 必须为 List[str]。" + " 请执行 scripts/release_vnext_migrate.py migrate。" + ) + + def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]: + """ + 选择总结默认任务,避免错误落到 embedding 任务。 + 优先级:replyer > utils > planner > tool_use > 其他非 embedding。 + """ + preferred = ("replyer", "utils", "planner", "tool_use") + for name in preferred: + cfg = available_tasks.get(name) + if cfg and cfg.model_list: + return name, cfg + + for name, cfg in available_tasks.items(): + if name != "embedding" and cfg.model_list: + return name, cfg + + for name, cfg in available_tasks.items(): + if cfg.model_list: + return name, cfg + + return None, None + + def _resolve_summary_model_config(self) -> Optional[TaskConfig]: + """ + 解析 summarization.model_name 为 TaskConfig。 + 支持: + - "auto" + - "replyer"(任务名) + - "some-model-name"(具体模型名) + - 
["utils:model1", "utils:model2", "replyer"](数组混合语法) + """ + available_tasks = llm_api.get_available_models() + if not available_tasks: + return None + + raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", "auto") + selectors = self._normalize_summary_model_selectors(raw_cfg) + default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks) + + selected_models: List[str] = [] + base_cfg: Optional[TaskConfig] = None + model_dict = getattr(host_model_config, "models_dict", {}) + + def _append_models(models: List[str]): + for model_name in models: + if model_name and model_name not in selected_models: + selected_models.append(model_name) + + for raw_selector in selectors: + selector = raw_selector.strip() + if not selector: + continue + + if selector.lower() == "auto": + if default_task_cfg: + _append_models(default_task_cfg.model_list) + if base_cfg is None: + base_cfg = default_task_cfg + continue + + if ":" in selector: + task_name, model_name = selector.split(":", 1) + task_name = task_name.strip() + model_name = model_name.strip() + task_cfg = available_tasks.get(task_name) + if not task_cfg: + logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过") + continue + + if base_cfg is None: + base_cfg = task_cfg + + if not model_name or model_name.lower() == "auto": + _append_models(task_cfg.model_list) + continue + + if model_name in model_dict or model_name in task_cfg.model_list: + _append_models([model_name]) + else: + logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过") + continue + + task_cfg = available_tasks.get(selector) + if task_cfg: + _append_models(task_cfg.model_list) + if base_cfg is None: + base_cfg = task_cfg + continue + + if selector in model_dict: + _append_models([selector]) + continue + + logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过") + + if not selected_models: + if default_task_cfg: + _append_models(default_task_cfg.model_list) + if base_cfg is None: + base_cfg = default_task_cfg + else: + first_cfg = next(iter(available_tasks.values())) + _append_models(first_cfg.model_list) + if base_cfg is None: + base_cfg = first_cfg + + if not selected_models: + return None + + template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values())) + return TaskConfig( + model_list=selected_models, + max_tokens=template_cfg.max_tokens, + temperature=template_cfg.temperature, + slow_threshold=template_cfg.slow_threshold, + selection_strategy=template_cfg.selection_strategy, + ) + + async def import_from_stream( + self, + stream_id: str, + context_length: Optional[int] = None, + include_personality: Optional[bool] = None + ) -> Tuple[bool, str]: + """ + 从指定的聊天流中提取记录并执行总结导入 + + Args: + stream_id: 聊天流 ID + context_length: 总结的历史消息条数 + include_personality: 是否包含人设 + + Returns: + Tuple[bool, str]: (是否成功, 结果消息) + """ + try: + self_check_ok, self_check_msg = await self._ensure_runtime_self_check() + if not self_check_ok: + return False, f"导入前自检失败: {self_check_msg}" + + # 1. 获取配置 + if context_length is None: + context_length = self.plugin_config.get("summarization", {}).get("context_length", 50) + + if include_personality is None: + include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True) + + # 2. 
获取历史消息 + # 获取当前时间之前的消息 + now = time.time() + messages = message_api.get_messages_before_time_in_chat( + chat_id=stream_id, + timestamp=now, + limit=context_length + ) + + if not messages: + return False, "未找到有效的聊天记录进行总结" + + # 转换为可读文本 + chat_history_text = message_api.build_readable_messages(messages) + + # 3. 准备提示词内容 + bot_name = global_config.bot.nickname or "机器人" + personality_context = "" + if include_personality: + personality = getattr(global_config.bot, "personality", "") + if personality: + personality_context = f"你的性格设定是:{personality}" + + # 4. 调用 LLM + prompt = SUMMARY_PROMPT_TEMPLATE.format( + bot_name=bot_name, + personality_context=personality_context, + chat_history=chat_history_text + ) + + model_config_to_use = self._resolve_summary_model_config() + if model_config_to_use is None: + return False, "未找到可用的总结模型配置" + + logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}") + logger.info(f"总结模型候选列表: {model_config_to_use.model_list}") + + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config_to_use, + request_type="A_Memorix.ChatSummarization" + ) + + if not success or not response: + return False, "LLM 生成总结失败" + + # 5. 解析结果 + data = self._parse_llm_response(response) + if not data or "summary" not in data: + return False, "解析 LLM 响应失败或总结为空" + + summary_text = data["summary"] + entities = data.get("entities", []) + relations = data.get("relations", []) + msg_times = [ + float(getattr(getattr(msg, "timestamp", None), "timestamp", lambda: 0.0)()) + for msg in messages + if getattr(msg, "time", None) is not None + ] + time_meta = {} + if msg_times: + time_meta = { + "event_time_start": min(msg_times), + "event_time_end": max(msg_times), + "time_granularity": "minute", + "time_confidence": 0.95, + } + + # 6. 执行导入 + await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta) + + # 7. 
持久化 + self.vector_store.save() + self.graph_store.save() + + result_msg = ( + f"✅ 总结导入成功\n" + f"📝 总结长度: {len(summary_text)}\n" + f"📌 提取实体: {len(entities)}\n" + f"🔗 提取关系: {len(relations)}" + ) + return True, result_msg + + except Exception as e: + logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}") + return False, f"错误: {str(e)}" + + async def _ensure_runtime_self_check(self) -> Tuple[bool, str]: + plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None + if plugin_instance is not None: + report = await ensure_runtime_self_check(plugin_instance) + else: + report = await run_embedding_runtime_self_check( + config=self.plugin_config, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + if bool(report.get("ok", False)): + return True, "" + return ( + False, + f"{report.get('message', 'unknown')} " + f"(configured={report.get('configured_dimension', 0)}, " + f"store={report.get('vector_store_dimension', 0)}, " + f"encoded={report.get('encoded_dimension', 0)})", + ) + + def _parse_llm_response(self, response: str) -> Dict[str, Any]: + """解析 LLM 返回的 JSON""" + try: + # 尝试查找 JSON + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + return json.loads(json_match.group()) + return {} + except Exception as e: + logger.warning(f"解析总结 JSON 失败: {e}") + return {} + + async def _execute_import( + self, + summary: str, + entities: List[str], + relations: List[Dict[str, str]], + stream_id: str, + time_meta: Optional[Dict[str, Any]] = None, + ): + """将数据写入存储""" + # 获取默认知识类型 + type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative") + try: + knowledge_type = resolve_stored_knowledge_type(type_str, content=summary) + except ValueError: + logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative") + knowledge_type = KnowledgeType.NARRATIVE + + # 导入总结文本 + hash_value = self.metadata_store.add_paragraph( + content=summary, + source=f"chat_summary:{stream_id}", + knowledge_type=knowledge_type.value, + time_meta=time_meta, + ) + + embedding = await self.embedding_manager.encode(summary) + self.vector_store.add( + vectors=embedding.reshape(1, -1), + ids=[hash_value] + ) + + # 导入实体 + if entities: + self.graph_store.add_nodes(entities) + + # 导入关系 + rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {}) + if not isinstance(rv_cfg, dict): + rv_cfg = {} + write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) + for rel in relations: + s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object") + if all([s, p, o]): + if self.relation_write_service is not None: + await self.relation_write_service.upsert_relation_with_vector( + subject=s, + predicate=p, + obj=o, + confidence=1.0, + source_paragraph=summary, + write_vector=write_vector, + ) + else: + # 写入元数据 + rel_hash = self.metadata_store.add_relation( + subject=s, + predicate=p, + obj=o, + confidence=1.0, + source_paragraph=summary + ) + # 写入图数据库(写入 relation_hashes,确保后续可按关系精确修剪) + self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash]) + try: + self.metadata_store.set_relation_vector_state(rel_hash, "none") + except Exception: + pass + + logger.info(f"总结导入完成: hash={hash_value[:8]}") diff --git a/plugins/A_memorix/core/utils/web_import_manager.py b/plugins/A_memorix/core/utils/web_import_manager.py new file mode 100644 index 00000000..b088be1f --- /dev/null +++ b/plugins/A_memorix/core/utils/web_import_manager.py 
@@ -0,0 +1,3522 @@ +""" +Web Import Task Manager + +为 A_Memorix WebUI 提供导入任务队列、状态管理、并发调度与取消/重试能力。 +""" + +from __future__ import annotations + +import asyncio +import hashlib +import json +import os +import shutil +import sys +import time +import traceback +import uuid +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +from src.common.logger import get_logger +from src.services import llm_service as llm_api + +from ..storage import ( + parse_import_strategy, + resolve_stored_knowledge_type, + select_import_strategy, + KnowledgeType, + MetadataStore, +) +from ..storage.type_detection import looks_like_quote_text +from ..utils.import_payloads import normalize_paragraph_import_item +from ..utils.runtime_self_check import ensure_runtime_self_check +from ..utils.time_parser import normalize_time_meta +from ..storage.knowledge_types import ImportStrategy +from ..strategies.base import ProcessedChunk, KnowledgeType as StrategyKnowledgeType +from ..strategies.narrative import NarrativeStrategy +from ..strategies.factual import FactualStrategy +from ..strategies.quote import QuoteStrategy + +logger = get_logger("A_Memorix.WebImportManager") + + +TASK_STATUS = { + "queued", + "preparing", + "running", + "cancel_requested", + "cancelled", + "completed", + "completed_with_errors", + "failed", +} + +FILE_STATUS = { + "queued", + "preparing", + "splitting", + "extracting", + "writing", + "saving", + "completed", + "failed", + "cancelled", +} + +CHUNK_STATUS = { + "queued", + "extracting", + "writing", + "completed", + "failed", + "cancelled", +} + + +def _now() -> float: + return time.time() + + +def _coerce_int(value: Any, default: int) -> int: + try: + return int(value) + except Exception: + return default + + +def _coerce_bool(value: Any, default: bool) -> bool: + if value is None: + return default + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(value) + text = str(value).strip().lower() + if text in {"1", "true", "yes", "y", "on"}: + return True + if text in {"0", "false", "no", "n", "off", ""}: + return False + return default + + +def _clamp(value: int, min_value: int, max_value: int) -> int: + return max(min_value, min(max_value, value)) + + +def _coerce_list(value: Any) -> List[str]: + if value is None: + return [] + if isinstance(value, list): + raw_items = value + else: + text = str(value or "").replace("\r", "\n") + raw_items = [] + for seg in text.split("\n"): + raw_items.extend(seg.split(",")) + + out: List[str] = [] + seen = set() + for item in raw_items: + v = str(item or "").strip() + if not v: + continue + key = v.lower() + if key in seen: + continue + seen.add(key) + out.append(v) + return out + + +def _parse_optional_positive_int(value: Any, field_name: str) -> Optional[int]: + if value is None: + return None + text = str(value).strip() + if text == "": + return None + try: + parsed = int(text) + except Exception: + raise ValueError(f"{field_name} 必须为整数") + if parsed <= 0: + raise ValueError(f"{field_name} 必须 > 0") + return parsed + + +def _safe_filename(name: str) -> str: + base = os.path.basename(str(name or "").strip()) + if not base: + return f"unnamed_{uuid.uuid4().hex[:8]}.txt" + return base + + +def _storage_type_from_strategy(strategy_type: StrategyKnowledgeType) -> str: + if strategy_type == StrategyKnowledgeType.NARRATIVE: + return KnowledgeType.NARRATIVE.value + if strategy_type == 
StrategyKnowledgeType.FACTUAL: + return KnowledgeType.FACTUAL.value + if strategy_type == StrategyKnowledgeType.QUOTE: + return KnowledgeType.QUOTE.value + return KnowledgeType.MIXED.value + + +@dataclass +class ImportChunkRecord: + chunk_id: str + index: int + chunk_type: str + status: str = "queued" + step: str = "queued" + failed_at: str = "" + retryable: bool = False + error: str = "" + progress: float = 0.0 + content_preview: str = "" + updated_at: float = field(default_factory=_now) + + def to_dict(self) -> Dict[str, Any]: + return { + "chunk_id": self.chunk_id, + "index": self.index, + "chunk_type": self.chunk_type, + "status": self.status, + "step": self.step, + "failed_at": self.failed_at, + "retryable": self.retryable, + "error": self.error, + "progress": self.progress, + "content_preview": self.content_preview, + "updated_at": self.updated_at, + } + + +@dataclass +class ImportFileRecord: + file_id: str + name: str + source_kind: str + input_mode: str + status: str = "queued" + current_step: str = "queued" + detected_strategy_type: str = "unknown" + total_chunks: int = 0 + done_chunks: int = 0 + failed_chunks: int = 0 + cancelled_chunks: int = 0 + progress: float = 0.0 + error: str = "" + chunks: List[ImportChunkRecord] = field(default_factory=list) + created_at: float = field(default_factory=_now) + updated_at: float = field(default_factory=_now) + temp_path: Optional[str] = None + source_path: Optional[str] = None + inline_content: Optional[str] = None + content_hash: str = "" + retry_chunk_indexes: List[int] = field(default_factory=list) + retry_mode: str = "" + + def to_dict(self, include_chunks: bool = False) -> Dict[str, Any]: + payload = { + "file_id": self.file_id, + "name": self.name, + "source_kind": self.source_kind, + "input_mode": self.input_mode, + "status": self.status, + "current_step": self.current_step, + "detected_strategy_type": self.detected_strategy_type, + "total_chunks": self.total_chunks, + "done_chunks": self.done_chunks, + "failed_chunks": self.failed_chunks, + "cancelled_chunks": self.cancelled_chunks, + "progress": self.progress, + "error": self.error, + "created_at": self.created_at, + "updated_at": self.updated_at, + "source_path": self.source_path or "", + "content_hash": self.content_hash or "", + "retry_chunk_indexes": list(self.retry_chunk_indexes or []), + "retry_mode": self.retry_mode or "", + } + if include_chunks: + payload["chunks"] = [chunk.to_dict() for chunk in self.chunks] + return payload + + +@dataclass +class ImportTaskRecord: + task_id: str + source: str + params: Dict[str, Any] + status: str = "queued" + current_step: str = "queued" + total_chunks: int = 0 + done_chunks: int = 0 + failed_chunks: int = 0 + cancelled_chunks: int = 0 + progress: float = 0.0 + error: str = "" + files: List[ImportFileRecord] = field(default_factory=list) + created_at: float = field(default_factory=_now) + started_at: Optional[float] = None + finished_at: Optional[float] = None + updated_at: float = field(default_factory=_now) + schema_detected: str = "" + artifact_paths: Dict[str, str] = field(default_factory=dict) + rollback_info: Dict[str, Any] = field(default_factory=dict) + retry_parent_task_id: str = "" + retry_summary: Dict[str, Any] = field(default_factory=dict) + + def to_summary(self) -> Dict[str, Any]: + return { + "task_id": self.task_id, + "source": self.source, + "status": self.status, + "current_step": self.current_step, + "total_chunks": self.total_chunks, + "done_chunks": self.done_chunks, + "failed_chunks": self.failed_chunks, + 
"cancelled_chunks": self.cancelled_chunks, + "progress": self.progress, + "error": self.error, + "file_count": len(self.files), + "created_at": self.created_at, + "started_at": self.started_at, + "finished_at": self.finished_at, + "updated_at": self.updated_at, + "task_kind": str(self.params.get("task_kind") or self.source), + "schema_detected": self.schema_detected, + "artifact_paths": dict(self.artifact_paths), + "rollback_info": dict(self.rollback_info), + "retry_parent_task_id": self.retry_parent_task_id or "", + "retry_summary": dict(self.retry_summary), + } + + def to_detail(self, include_chunks: bool = False) -> Dict[str, Any]: + payload = self.to_summary() + payload["params"] = self.params + payload["files"] = [f.to_dict(include_chunks=include_chunks) for f in self.files] + return payload + + +class ImportTaskManager: + def __init__(self, plugin: Any): + self.plugin = plugin + self._lock = asyncio.Lock() + self._storage_lock = asyncio.Lock() + + self._tasks: Dict[str, ImportTaskRecord] = {} + self._task_order: deque[str] = deque() + self._queue: deque[str] = deque() + self._active_task_id: Optional[str] = None + + self._worker_task: Optional[asyncio.Task] = None + self._stopping = False + + self._temp_root = self._resolve_temp_root() + self._temp_root.mkdir(parents=True, exist_ok=True) + self._reports_root = self._resolve_reports_root() + self._reports_root.mkdir(parents=True, exist_ok=True) + self._manifest_path = self._resolve_manifest_path() + self._manifest_cache: Optional[Dict[str, Any]] = None + self._write_changed_callback: Optional[Callable[[Dict[str, Any]], Any]] = None + + def set_write_changed_callback(self, callback: Optional[Callable[[Dict[str, Any]], Any]]) -> None: + self._write_changed_callback = callback + + async def _notify_write_changed(self, payload: Dict[str, Any]) -> None: + callback = self._write_changed_callback + if callback is None: + return + try: + maybe_awaitable = callback(payload) + if asyncio.iscoroutine(maybe_awaitable): + await maybe_awaitable + except Exception as e: + logger.warning(f"写入变更回调执行失败: {e}") + + def _resolve_temp_root(self) -> Path: + data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) + if str(data_dir).startswith("."): + plugin_dir = Path(__file__).resolve().parents[2] + data_dir = (plugin_dir / data_dir).resolve() + return data_dir / "web_import_tmp" + + def _resolve_reports_root(self) -> Path: + return self._resolve_data_dir() / "web_import_reports" + + def _resolve_manifest_path(self) -> Path: + return self._resolve_data_dir() / "import_manifest.json" + + def _resolve_staging_root(self) -> Path: + return self._resolve_data_dir() / "import_staging" + + def _resolve_backup_root(self) -> Path: + return self._resolve_data_dir() / "import_backup" + + def _resolve_repo_root(self) -> Path: + return Path(__file__).resolve().parents[3] + + def _resolve_data_dir(self) -> Path: + data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) + if str(data_dir).startswith("."): + plugin_dir = Path(__file__).resolve().parents[2] + data_dir = (plugin_dir / data_dir).resolve() + return data_dir.resolve() + + def _resolve_migration_script(self) -> Path: + return Path(__file__).resolve().parents[2] / "scripts" / "migrate_maibot_memory.py" + + def _default_maibot_source_db(self) -> Path: + # A_memorix/core/utils -> workspace root + return self._resolve_repo_root() / "MaiBot" / "data" / "MaiBot.db" + + def _cfg(self, key: str, default: Any) -> Any: + return self.plugin.get_config(key, default) + + def _cfg_int(self, key: 
str, default: int) -> int: + return _coerce_int(self._cfg(key, default), default) + + def _is_enabled(self) -> bool: + return bool(self._cfg("web.import.enabled", True)) + + def _queue_limit(self) -> int: + return max(1, self._cfg_int("web.import.max_queue_size", 20)) + + def _max_files_per_task(self) -> int: + return max(1, self._cfg_int("web.import.max_files_per_task", 200)) + + def _max_file_size_bytes(self) -> int: + mb = max(1, self._cfg_int("web.import.max_file_size_mb", 20)) + return mb * 1024 * 1024 + + def _max_paste_chars(self) -> int: + return max(1000, self._cfg_int("web.import.max_paste_chars", 200000)) + + def _default_file_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.default_file_concurrency", 2)) + + def _default_chunk_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.default_chunk_concurrency", 4)) + + def _max_file_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.max_file_concurrency", 6)) + + def _max_chunk_concurrency(self) -> int: + return max(1, self._cfg_int("web.import.max_chunk_concurrency", 12)) + + def _llm_retry_config(self) -> Dict[str, float]: + retries = max(0, self._cfg_int("web.import.llm_retry.max_attempts", 4)) + min_wait = max(0.1, float(self._cfg("web.import.llm_retry.min_wait_seconds", 3) or 3)) + max_wait = max(min_wait, float(self._cfg("web.import.llm_retry.max_wait_seconds", 40) or 40)) + mult = max(1.0, float(self._cfg("web.import.llm_retry.backoff_multiplier", 3) or 3)) + return { + "retries": retries, + "min_wait": min_wait, + "max_wait": max_wait, + "multiplier": mult, + } + + def _default_path_aliases(self) -> Dict[str, str]: + plugin_dir = Path(__file__).resolve().parents[2] + repo_root = self._resolve_repo_root() + return { + "raw": str((plugin_dir / "data" / "raw").resolve()), + "lpmm": str((repo_root / "data" / "lpmm_storage").resolve()), + "plugin_data": str((plugin_dir / "data").resolve()), + } + + def get_path_aliases(self) -> Dict[str, str]: + configured = self._cfg("web.import.path_aliases", self._default_path_aliases()) + if not isinstance(configured, dict): + configured = self._default_path_aliases() + + repo_root = self._resolve_repo_root() + result: Dict[str, str] = {} + for alias, raw_path in configured.items(): + key = str(alias or "").strip() + if not key: + continue + text = str(raw_path or "").strip() + if not text: + continue + if text.startswith("\\\\"): + continue + p = Path(text) + if not p.is_absolute(): + p = (repo_root / p).resolve() + else: + p = p.resolve() + result[key] = str(p) + + defaults = self._default_path_aliases() + for key, path in defaults.items(): + result.setdefault(key, path) + return result + + def resolve_path_alias( + self, + alias: str, + relative_path: str = "", + *, + must_exist: bool = False, + ) -> Path: + alias_key = str(alias or "").strip() + aliases = self.get_path_aliases() + if alias_key not in aliases: + raise ValueError(f"未知路径别名: {alias_key}") + + root = Path(aliases[alias_key]).resolve() + rel = str(relative_path or "").strip().replace("\\", "/") + if rel.startswith("/") or rel.startswith("\\") or rel.startswith("//"): + raise ValueError("relative_path 不能为绝对路径") + if ":" in rel: + raise ValueError("relative_path 不允许包含盘符") + + candidate = (root / rel).resolve() if rel else root + try: + candidate.relative_to(root) + except ValueError: + raise ValueError("路径越界:relative_path 超出白名单目录") + if must_exist and not candidate.exists(): + raise ValueError(f"路径不存在: {candidate}") + return candidate + + async def resolve_path_request(self, 
payload: Dict[str, Any]) -> Dict[str, Any]: + alias = str(payload.get("alias") or "").strip() + relative_path = str(payload.get("relative_path") or "").strip() + must_exist = _coerce_bool(payload.get("must_exist"), True) + resolved = self.resolve_path_alias(alias, relative_path, must_exist=must_exist) + return { + "alias": alias, + "relative_path": relative_path, + "resolved_path": str(resolved), + "exists": resolved.exists(), + "is_file": resolved.is_file(), + "is_dir": resolved.is_dir(), + } + + def _load_manifest(self) -> Dict[str, Any]: + if self._manifest_cache is not None: + return self._manifest_cache + path = self._manifest_path + if not path.exists(): + self._manifest_cache = {} + return self._manifest_cache + try: + payload = json.loads(path.read_text(encoding="utf-8")) + if isinstance(payload, dict): + self._manifest_cache = payload + else: + self._manifest_cache = {} + except Exception: + self._manifest_cache = {} + return self._manifest_cache + + def _save_manifest(self, payload: Dict[str, Any]) -> None: + path = self._manifest_path + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + self._manifest_cache = payload + + def _clear_manifest(self) -> None: + self._save_manifest({}) + + def _normalize_manifest_path(self, raw_path: str) -> str: + text = str(raw_path or "").strip() + if not text: + return "" + return text.replace("\\", "/").strip().lower() + + def _match_manifest_item_for_source(self, source: str, item: Dict[str, Any]) -> bool: + source_text = str(source or "").strip() + if not source_text or ":" not in source_text: + return False + prefix, tail = source_text.split(":", 1) + source_kind = prefix.strip().lower() + source_value = tail.strip() + if not source_value: + return False + + item_kind = str(item.get("source_kind") or "").strip().lower() + item_name = str(item.get("name") or "").strip() + item_path_norm = self._normalize_manifest_path(item.get("source_path") or "") + + if source_kind in {"raw_scan", "lpmm_openie"}: + source_path_norm = self._normalize_manifest_path(source_value) + if source_path_norm and item_path_norm and source_path_norm == item_path_norm and item_kind == source_kind: + return True + + if source_kind == "web_import": + return item_kind in {"upload", "paste"} and item_name == source_value + + if source_kind == "lpmm_openie": + source_name = Path(source_value).name + return item_kind == "lpmm_openie" and item_name == source_name + + return False + + async def invalidate_manifest_for_sources(self, sources: List[str]) -> Dict[str, Any]: + requested_sources: List[str] = [] + seen_sources = set() + for raw in sources or []: + source = str(raw or "").strip() + if not source: + continue + key = source.lower() + if key in seen_sources: + continue + seen_sources.add(key) + requested_sources.append(source) + + result: Dict[str, Any] = { + "requested_sources": requested_sources, + "removed_count": 0, + "removed_keys": [], + "remaining_count": 0, + "unmatched_sources": [], + "warnings": [], + } + + async with self._lock: + manifest = self._load_manifest() + if not isinstance(manifest, dict): + manifest = {} + + valid_items: List[Tuple[str, Dict[str, Any]]] = [] + malformed_keys: List[str] = [] + for key, item in manifest.items(): + if isinstance(item, dict): + valid_items.append((str(key), item)) + else: + malformed_keys.append(str(key)) + + keys_to_remove = set() + for source in requested_sources: + matched = False + for key, item in valid_items: + if 
self._match_manifest_item_for_source(source, item): + keys_to_remove.add(key) + matched = True + if not matched: + result["unmatched_sources"].append(source) + + if keys_to_remove: + for key in keys_to_remove: + manifest.pop(key, None) + self._save_manifest(manifest) + + result["removed_keys"] = sorted(keys_to_remove) + result["removed_count"] = len(keys_to_remove) + result["remaining_count"] = len(manifest) + + if malformed_keys: + preview = ", ".join(malformed_keys[:5]) + extra = "" if len(malformed_keys) <= 5 else f" ... (+{len(malformed_keys) - 5})" + result["warnings"].append( + f"manifest 条目结构异常,已跳过 {len(malformed_keys)} 项: {preview}{extra}" + ) + + return result + + def _manifest_key_for_file(self, file_record: ImportFileRecord, content_hash: str, dedupe_policy: str) -> str: + if dedupe_policy == "content_hash": + return f"hash:{content_hash}" + if file_record.source_path: + return f"path:{Path(file_record.source_path).as_posix().lower()}" + return f"hash:{content_hash}" + + def _is_manifest_hit( + self, + file_record: ImportFileRecord, + content_hash: str, + dedupe_policy: str, + ) -> bool: + key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) + manifest = self._load_manifest() + item = manifest.get(key) + if not isinstance(item, dict): + return False + return str(item.get("hash") or "") == content_hash and bool(item.get("imported")) + + def _record_manifest_import( + self, + file_record: ImportFileRecord, + content_hash: str, + dedupe_policy: str, + task_id: str, + ) -> None: + key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) + manifest = self._load_manifest() + manifest[key] = { + "hash": content_hash, + "imported": True, + "timestamp": _now(), + "task_id": task_id, + "name": file_record.name, + "source_path": file_record.source_path or "", + "source_kind": file_record.source_kind, + } + self._save_manifest(manifest) + + def _normalize_common_import_params(self, payload: Dict[str, Any], *, default_dedupe: str) -> Dict[str, Any]: + input_mode = str(payload.get("input_mode", "text") or "text").strip().lower() + if input_mode not in {"text", "json"}: + raise ValueError("input_mode 必须为 text 或 json") + + file_concurrency = _coerce_int( + payload.get("file_concurrency", self._default_file_concurrency()), + self._default_file_concurrency(), + ) + chunk_concurrency = _coerce_int( + payload.get("chunk_concurrency", self._default_chunk_concurrency()), + self._default_chunk_concurrency(), + ) + file_concurrency = _clamp(file_concurrency, 1, self._max_file_concurrency()) + chunk_concurrency = _clamp(chunk_concurrency, 1, self._max_chunk_concurrency()) + + llm_enabled = _coerce_bool(payload.get("llm_enabled", True), True) + strategy_override = parse_import_strategy( + payload.get("strategy_override", "auto"), + default=ImportStrategy.AUTO, + ).value + + dedupe_policy = str(payload.get("dedupe_policy", default_dedupe) or default_dedupe).strip().lower() + if dedupe_policy not in {"content_hash", "manifest", "none"}: + raise ValueError("dedupe_policy 必须为 content_hash/manifest/none") + + chat_log = _coerce_bool(payload.get("chat_log"), False) + chat_reference_time = str(payload.get("chat_reference_time") or "").strip() or None + force = _coerce_bool(payload.get("force"), False) + clear_manifest = _coerce_bool(payload.get("clear_manifest"), False) + + return { + "input_mode": input_mode, + "file_concurrency": file_concurrency, + "chunk_concurrency": chunk_concurrency, + "llm_enabled": llm_enabled, + "strategy_override": strategy_override, + 
"chat_log": chat_log, + "chat_reference_time": chat_reference_time, + "force": force, + "clear_manifest": clear_manifest, + "dedupe_policy": dedupe_policy, + } + + def _normalize_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + params = self._normalize_common_import_params(payload, default_dedupe="content_hash") + params["task_kind"] = "upload" + return params + + def _normalize_raw_scan_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + params = self._normalize_common_import_params(payload, default_dedupe="manifest") + alias = str(payload.get("alias") or "raw").strip() + relative_path = str(payload.get("relative_path") or "").strip() + glob_pattern = str(payload.get("glob") or "*").strip() or "*" + recursive = _coerce_bool(payload.get("recursive"), True) + if ".." in relative_path.replace("\\", "/").split("/"): + raise ValueError("relative_path 不允许包含 ..") + params.update( + { + "task_kind": "raw_scan", + "alias": alias, + "relative_path": relative_path, + "glob": glob_pattern, + "recursive": recursive, + } + ) + return params + + def _normalize_lpmm_openie_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + params = self._normalize_common_import_params(payload, default_dedupe="manifest") + alias = str(payload.get("alias") or "lpmm").strip() + relative_path = str(payload.get("relative_path") or "").strip() + include_all_json = _coerce_bool(payload.get("include_all_json"), False) + params.update( + { + "task_kind": "lpmm_openie", + "alias": alias, + "relative_path": relative_path, + "include_all_json": include_all_json, + "input_mode": "json", + } + ) + return params + + def _normalize_temporal_backfill_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + alias = str(payload.get("alias") or "plugin_data").strip() + relative_path = str(payload.get("relative_path") or "").strip() + dry_run = _coerce_bool(payload.get("dry_run"), False) + no_created_fallback = _coerce_bool(payload.get("no_created_fallback"), False) + limit = _parse_optional_positive_int(payload.get("limit"), "limit") or 100000 + return { + "task_kind": "temporal_backfill", + "alias": alias, + "relative_path": relative_path, + "dry_run": dry_run, + "no_created_fallback": no_created_fallback, + "limit": limit, + } + + def _normalize_lpmm_convert_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + alias = str(payload.get("alias") or "lpmm").strip() + relative_path = str(payload.get("relative_path") or "").strip() + target_alias = str(payload.get("target_alias") or "plugin_data").strip() + target_relative_path = str(payload.get("target_relative_path") or "").strip() + dimension = _parse_optional_positive_int(payload.get("dimension"), "dimension") or _coerce_int( + self._cfg("embedding.dimension", 384), + 384, + ) + batch_size = _parse_optional_positive_int(payload.get("batch_size"), "batch_size") or 1024 + return { + "task_kind": "lpmm_convert", + "alias": alias, + "relative_path": relative_path, + "target_alias": target_alias, + "target_relative_path": target_relative_path, + "dimension": dimension, + "batch_size": batch_size, + } + + def _normalize_by_task_kind(self, task_kind: str, payload: Dict[str, Any]) -> Dict[str, Any]: + kind = str(task_kind or "").strip().lower() + if kind in {"upload", "paste"}: + params = self._normalize_params(payload) + params["task_kind"] = kind + return params + if kind == "maibot_migration": + return self._normalize_migration_params(payload) + if kind == "raw_scan": + return self._normalize_raw_scan_params(payload) + if kind == "lpmm_openie": + return 
self._normalize_lpmm_openie_params(payload) + if kind == "temporal_backfill": + return self._normalize_temporal_backfill_params(payload) + if kind == "lpmm_convert": + return self._normalize_lpmm_convert_params(payload) + # upload/paste 默认走通用文本导入参数 + return self._normalize_params(payload) + + def _normalize_migration_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: + source_db = str(payload.get("source_db") or "").strip() + if not source_db: + source_db = str(self._default_maibot_source_db()) + + time_from = str(payload.get("time_from") or "").strip() or None + time_to = str(payload.get("time_to") or "").strip() or None + + stream_ids = _coerce_list(payload.get("stream_ids")) + group_ids = _coerce_list(payload.get("group_ids")) + user_ids = _coerce_list(payload.get("user_ids")) + + start_id = _parse_optional_positive_int(payload.get("start_id"), "start_id") + end_id = _parse_optional_positive_int(payload.get("end_id"), "end_id") + if start_id is not None and end_id is not None and start_id > end_id: + raise ValueError("start_id 不能大于 end_id") + + read_batch_size = _parse_optional_positive_int(payload.get("read_batch_size"), "read_batch_size") or 2000 + commit_window_rows = _parse_optional_positive_int(payload.get("commit_window_rows"), "commit_window_rows") or 20000 + embed_batch_size = _parse_optional_positive_int(payload.get("embed_batch_size"), "embed_batch_size") or 256 + entity_embed_batch_size = ( + _parse_optional_positive_int(payload.get("entity_embed_batch_size"), "entity_embed_batch_size") or 512 + ) + embed_workers = _parse_optional_positive_int(payload.get("embed_workers"), "embed_workers") + max_errors = _parse_optional_positive_int(payload.get("max_errors"), "max_errors") or 500 + log_every = _parse_optional_positive_int(payload.get("log_every"), "log_every") or 5000 + preview_limit = _parse_optional_positive_int(payload.get("preview_limit"), "preview_limit") or 20 + + no_resume = _coerce_bool(payload.get("no_resume"), False) + reset_state = _coerce_bool(payload.get("reset_state"), False) + dry_run = _coerce_bool(payload.get("dry_run"), False) + verify_only = _coerce_bool(payload.get("verify_only"), False) + + return { + "task_kind": "maibot_migration", + "source_db": source_db, + "target_data_dir": str(self._resolve_data_dir()), + "time_from": time_from, + "time_to": time_to, + "stream_ids": stream_ids, + "group_ids": group_ids, + "user_ids": user_ids, + "start_id": start_id, + "end_id": end_id, + "read_batch_size": read_batch_size, + "commit_window_rows": commit_window_rows, + "embed_batch_size": embed_batch_size, + "entity_embed_batch_size": entity_embed_batch_size, + "embed_workers": embed_workers, + "max_errors": max_errors, + "log_every": log_every, + "preview_limit": preview_limit, + "no_resume": no_resume, + "reset_state": reset_state, + "dry_run": dry_run, + "verify_only": verify_only, + } + + def _pending_task_count(self) -> int: + pending = 0 + for task in self._tasks.values(): + if task.status in {"queued", "preparing", "running", "cancel_requested"}: + pending += 1 + return pending + + async def _ensure_worker(self) -> None: + async with self._lock: + if self._worker_task and not self._worker_task.done(): + return + self._stopping = False + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def get_runtime_settings(self) -> Dict[str, Any]: + llm_retry = self._llm_retry_config() + return { + "max_queue_size": self._queue_limit(), + "max_files_per_task": self._max_files_per_task(), + "max_file_size_mb": 
self._cfg_int("web.import.max_file_size_mb", 20), + "max_paste_chars": self._max_paste_chars(), + "default_file_concurrency": self._default_file_concurrency(), + "default_chunk_concurrency": self._default_chunk_concurrency(), + "max_file_concurrency": self._max_file_concurrency(), + "max_chunk_concurrency": self._max_chunk_concurrency(), + "poll_interval_ms": max(200, self._cfg_int("web.import.poll_interval_ms", 1000)), + "maibot_source_db_default": str(self._default_maibot_source_db()), + "maibot_target_data_dir": str(self._resolve_data_dir()), + "path_aliases": self.get_path_aliases(), + "llm_retry": llm_retry, + "convert_enable_staging_switch": _coerce_bool( + self._cfg("web.import.convert.enable_staging_switch", True), True + ), + "convert_keep_backup_count": max(0, self._cfg_int("web.import.convert.keep_backup_count", 3)), + } + + def is_write_blocked(self) -> bool: + task_id = self._active_task_id + if not task_id: + return False + task = self._tasks.get(task_id) + if not task: + return False + return task.status in {"preparing", "running", "cancel_requested"} + + def _ensure_ready(self) -> None: + required_attrs = ("metadata_store", "vector_store", "graph_store", "embedding_manager") + + def _collect_missing() -> List[str]: + missing_local: List[str] = [] + for attr in required_attrs: + if getattr(self.plugin, attr, None) is None: + missing_local.append(attr) + return missing_local + + missing = _collect_missing() + if missing: + raise ValueError(f"导入依赖未初始化: {', '.join(missing)}") + ready_checker = getattr(self.plugin, "is_runtime_ready", None) + if callable(ready_checker) and not ready_checker(): + raise ValueError("插件运行时未就绪,请先完成 on_enable 初始化") + + def _scan_files( + self, + base_path: Path, + *, + recursive: bool, + glob_pattern: str, + allowed_exts: Optional[set[str]] = None, + ) -> List[Path]: + if base_path.is_file(): + candidates = [base_path] + else: + if recursive: + candidates = list(base_path.rglob(glob_pattern)) + else: + candidates = list(base_path.glob(glob_pattern)) + out: List[Path] = [] + for p in candidates: + if not p.is_file(): + continue + ext = p.suffix.lower() + if allowed_exts and ext not in allowed_exts: + continue + out.append(p.resolve()) + out.sort(key=lambda x: x.as_posix().lower()) + return out + + async def create_upload_task(self, files: List[Any], payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + if not files: + raise ValueError("至少需要上传一个文件") + + params = self._normalize_params(payload) + max_files = self._max_files_per_task() + if len(files) > max_files: + raise ValueError(f"单任务文件数超过上限: {max_files}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="upload", + params=params, + status="queued", + current_step="queued", + ) + task_dir = self._temp_root / task.task_id + task_dir.mkdir(parents=True, exist_ok=True) + + max_size = self._max_file_size_bytes() + for idx, uploaded in enumerate(files): + file_id = uuid.uuid4().hex + if isinstance(uploaded, dict): + staged_path_raw = uploaded.get("staged_path") or uploaded.get("path") or "" + staged_path = Path(str(staged_path_raw or "")).expanduser().resolve() + if not staged_path.is_file(): + raise ValueError(f"上传暂存文件不存在: {staged_path}") + name = _safe_filename(uploaded.get("filename") or uploaded.get("name") or staged_path.name) + ext = Path(name).suffix.lower() + if ext not in {".txt", ".md", 
".json"}: + raise ValueError(f"不支持的文件类型: {name}") + if staged_path.stat().st_size > max_size: + raise ValueError(f"文件超过大小限制: {name}") + temp_path = task_dir / f"{file_id}_{name}" + shutil.copy2(staged_path, temp_path) + else: + name = _safe_filename(getattr(uploaded, "filename", f"file_{idx}.txt")) + ext = Path(name).suffix.lower() + if ext not in {".txt", ".md", ".json"}: + raise ValueError(f"不支持的文件类型: {name}") + content = await uploaded.read() + if len(content) > max_size: + raise ValueError(f"文件超过大小限制: {name}") + temp_path = task_dir / f"{file_id}_{name}" + temp_path.write_bytes(content) + file_mode = "json" if ext == ".json" else params["input_mode"] + task.files.append( + ImportFileRecord( + file_id=file_id, + name=name, + source_kind="upload", + input_mode=file_mode, + temp_path=str(temp_path), + ) + ) + + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_paste_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + + params = self._normalize_params(payload) + params["task_kind"] = "paste" + content = str(payload.get("content", "") or "") + if not content.strip(): + raise ValueError("content 不能为空") + if len(content) > self._max_paste_chars(): + raise ValueError(f"粘贴内容超过限制: {self._max_paste_chars()} 字符") + + name = _safe_filename(payload.get("name") or f"paste_{int(_now())}.txt") + if params["input_mode"] == "json" and Path(name).suffix.lower() != ".json": + name = f"{Path(name).stem}.json" + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="paste", + params=params, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=name, + source_kind="paste", + input_mode=params["input_mode"], + inline_content=content, + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_raw_scan_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + params = self._normalize_raw_scan_params(payload) + source_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + files = self._scan_files( + source_path, + recursive=bool(params["recursive"]), + glob_pattern=str(params["glob"] or "*"), + allowed_exts={".txt", ".md", ".json"}, + ) + if not files: + raise ValueError("未找到可导入文件") + if len(files) > self._max_files_per_task(): + raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="raw_scan", + params=params, + status="queued", + current_step="queued", + ) + for path in files: + mode = "json" if path.suffix.lower() == ".json" else params["input_mode"] + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=path.name, + source_kind="raw_scan", + input_mode=mode, + source_path=str(path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await 
self._ensure_worker() + return task.to_summary() + + async def create_lpmm_openie_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + params = self._normalize_lpmm_openie_params(payload) + source_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + files: List[Path] = [] + if source_path.is_file(): + files = [source_path] + else: + files = self._scan_files( + source_path, + recursive=True, + glob_pattern="*-openie.json", + allowed_exts={".json"}, + ) + if not files and params.get("include_all_json"): + files = self._scan_files( + source_path, + recursive=True, + glob_pattern="*.json", + allowed_exts={".json"}, + ) + if not files: + raise ValueError("未找到 LPMM OpenIE JSON 文件") + if len(files) > self._max_files_per_task(): + raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="lpmm_openie", + params=params, + status="queued", + current_step="queued", + schema_detected="lpmm_openie", + ) + for path in files: + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=path.name, + source_kind="lpmm_openie", + input_mode="json", + source_path=str(path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_temporal_backfill_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + params = self._normalize_temporal_backfill_params(payload) + target_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + if not target_path.is_dir(): + raise ValueError("temporal_backfill 目标路径必须为目录") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="temporal_backfill", + params=params, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=f"temporal_backfill_{int(_now())}", + source_kind="temporal_backfill", + input_mode="json", + source_path=str(target_path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_lpmm_convert_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + params = self._normalize_lpmm_convert_params(payload) + source_path = self.resolve_path_alias( + params["alias"], + params["relative_path"], + must_exist=True, + ) + if not source_path.is_dir(): + raise ValueError("lpmm_convert 输入路径必须为目录") + target_path = self.resolve_path_alias( + params["target_alias"], + params["target_relative_path"], + must_exist=False, + ) + target_path.mkdir(parents=True, exist_ok=True) + if not target_path.is_dir(): + raise ValueError("lpmm_convert 目标路径必须为目录") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="lpmm_convert", + params={**params, "source_path": str(source_path), "target_path": 
str(target_path)}, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=f"lpmm_convert_{int(_now())}", + source_kind="lpmm_convert", + input_mode="json", + source_path=str(source_path), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def create_maibot_migration_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: + if not self._is_enabled(): + raise ValueError("导入功能已禁用") + self._ensure_ready() + + params = self._normalize_migration_params(payload) + script_path = self._resolve_migration_script() + if not script_path.exists(): + raise ValueError(f"迁移脚本不存在: {script_path}") + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + + task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source="maibot_migration", + params=params, + status="queued", + current_step="queued", + ) + task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=f"maibot_migration_{int(_now())}", + source_kind="maibot_migration", + input_mode="text", + inline_content=json.dumps(params, ensure_ascii=False), + ) + ) + self._tasks[task.task_id] = task + self._task_order.appendleft(task.task_id) + self._queue.append(task.task_id) + + await self._ensure_worker() + return task.to_summary() + + async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: + async with self._lock: + task_ids = list(self._task_order)[: max(1, int(limit))] + return [self._tasks[task_id].to_summary() for task_id in task_ids if task_id in self._tasks] + + async def get_task(self, task_id: str, include_chunks: bool = False) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + return task.to_detail(include_chunks=include_chunks) + + async def get_chunks(self, task_id: str, file_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + file_obj = self._find_file(task, file_id) + if not file_obj: + return None + start = max(0, int(offset)) + size = max(1, min(500, int(limit))) + items = file_obj.chunks[start : start + size] + return { + "task_id": task_id, + "file_id": file_id, + "offset": start, + "limit": size, + "total": len(file_obj.chunks), + "items": [x.to_dict() for x in items], + } + + async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + if task.status == "queued": + self._mark_task_cancelled_locked(task, "任务已取消") + self._queue = deque([x for x in self._queue if x != task_id]) + elif task.status in {"preparing", "running"}: + task.status = "cancel_requested" + task.current_step = "cancel_requested" + task.updated_at = _now() + return task.to_summary() + + def _build_retry_plan(self, task: ImportTaskRecord) -> Dict[str, Any]: + chunk_retry_candidates: List[Tuple[ImportFileRecord, List[int]]] = [] + file_fallback_candidates: List[ImportFileRecord] = [] + skipped: List[Dict[str, str]] = [] + + for file_obj in task.files: + if file_obj.status == "cancelled": + continue + + failed_chunks = [c for c in file_obj.chunks if c.status == "failed"] + has_file_level_failure = file_obj.status == "failed" and not failed_chunks + if has_file_level_failure: + 
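+ # The file failed without any per-chunk failure records, so a chunk-level retry is impossible; fall back to re-running the whole file.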
file_fallback_candidates.append(file_obj) + continue + + if not failed_chunks: + continue + + retry_indexes: List[int] = [] + has_non_retryable = False + for chunk in failed_chunks: + failed_at = str(chunk.failed_at or "").strip().lower() + retryable = bool(chunk.retryable) or ( + file_obj.input_mode == "text" and failed_at == "extracting" + ) + if retryable: + try: + retry_indexes.append(int(chunk.index)) + except Exception: + has_non_retryable = True + else: + has_non_retryable = True + + if has_non_retryable: + file_fallback_candidates.append(file_obj) + continue + + retry_indexes = sorted(set(retry_indexes)) + if retry_indexes: + chunk_retry_candidates.append((file_obj, retry_indexes)) + else: + skipped.append( + { + "file_name": file_obj.name, + "source_kind": file_obj.source_kind, + "reason": "no_retryable_failed_chunks", + } + ) + + unique_fallback: List[ImportFileRecord] = [] + fallback_seen = set() + for file_obj in file_fallback_candidates: + if file_obj.file_id in fallback_seen: + continue + fallback_seen.add(file_obj.file_id) + unique_fallback.append(file_obj) + + return { + "chunk_retry_candidates": chunk_retry_candidates, + "file_fallback_candidates": unique_fallback, + "skipped": skipped, + } + + def _clone_failed_file_for_retry( + self, + retry_task: ImportTaskRecord, + failed_file: ImportFileRecord, + task_dir: Path, + *, + retry_mode: str, + retry_chunk_indexes: Optional[List[int]] = None, + ) -> Tuple[bool, str]: + source_kind = str(failed_file.source_kind or "").strip().lower() + retry_chunk_indexes = list(retry_chunk_indexes or []) + + if source_kind == "upload": + candidate_paths: List[Path] = [] + if failed_file.temp_path: + candidate_paths.append(Path(failed_file.temp_path)) + if failed_file.source_path: + candidate_paths.append(Path(failed_file.source_path)) + src_path = next((p for p in candidate_paths if p.exists() and p.is_file()), None) + if src_path is None: + return False, "upload_source_missing" + data = src_path.read_bytes() + file_id = uuid.uuid4().hex + name = _safe_filename(failed_file.name) + dst = task_dir / f"{file_id}_{name}" + dst.write_bytes(data) + retry_task.files.append( + ImportFileRecord( + file_id=file_id, + name=name, + source_kind="upload", + input_mode=failed_file.input_mode, + temp_path=str(dst), + retry_mode=retry_mode, + retry_chunk_indexes=retry_chunk_indexes, + ) + ) + return True, "" + + if source_kind == "paste": + if failed_file.inline_content is None: + return False, "paste_content_missing" + retry_task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=_safe_filename(failed_file.name), + source_kind="paste", + input_mode=failed_file.input_mode, + inline_content=failed_file.inline_content, + retry_mode=retry_mode, + retry_chunk_indexes=retry_chunk_indexes, + ) + ) + return True, "" + + if source_kind == "maibot_migration": + retry_task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=_safe_filename(failed_file.name), + source_kind="maibot_migration", + input_mode="text", + inline_content=failed_file.inline_content, + retry_mode="file_fallback", + retry_chunk_indexes=[], + ) + ) + return True, "" + + if source_kind in {"raw_scan", "lpmm_openie", "lpmm_convert", "temporal_backfill"}: + retry_task.files.append( + ImportFileRecord( + file_id=uuid.uuid4().hex, + name=_safe_filename(failed_file.name), + source_kind=source_kind, + input_mode=failed_file.input_mode, + source_path=failed_file.source_path, + inline_content=failed_file.inline_content, + retry_mode=retry_mode, + 
retry_chunk_indexes=retry_chunk_indexes, + ) + ) + return True, "" + + return False, f"unsupported_source_kind:{source_kind or 'unknown'}" + + async def retry_failed(self, task_id: str, overrides: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return None + retry_plan = self._build_retry_plan(task) + chunk_retry_candidates = list(retry_plan["chunk_retry_candidates"]) + file_fallback_candidates = list(retry_plan["file_fallback_candidates"]) + skipped_candidates = list(retry_plan["skipped"]) + if not chunk_retry_candidates and not file_fallback_candidates: + raise ValueError("当前任务没有可重试失败项") + base_params = dict(task.params) + task_kind = str(task.params.get("task_kind") or "").strip().lower() + + if overrides: + base_params.update(overrides) + params = self._normalize_by_task_kind(task_kind, base_params) + params["retry_parent_task_id"] = task_id + params["retry_strategy"] = "chunk_first_auto_file_fallback" + + async with self._lock: + if self._pending_task_count() >= self._queue_limit(): + raise ValueError("任务队列已满,请稍后重试") + retry_task = ImportTaskRecord( + task_id=uuid.uuid4().hex, + source=task.source, + params=params, + status="queued", + current_step="queued", + schema_detected=task.schema_detected, + retry_parent_task_id=task_id, + ) + + task_dir = self._temp_root / retry_task.task_id + task_dir.mkdir(parents=True, exist_ok=True) + + retry_summary = { + "chunk_retry_files": 0, + "chunk_retry_chunks": 0, + "file_fallback_files": 0, + "skipped_files": 0, + "parent_task_id": task_id, + } + skipped_details = list(skipped_candidates) + + for file_obj, chunk_indexes in chunk_retry_candidates: + ok, reason = self._clone_failed_file_for_retry( + retry_task, + file_obj, + task_dir, + retry_mode="chunk", + retry_chunk_indexes=chunk_indexes, + ) + if ok: + retry_summary["chunk_retry_files"] += 1 + retry_summary["chunk_retry_chunks"] += len(chunk_indexes) + else: + skipped_details.append( + { + "file_name": file_obj.name, + "source_kind": file_obj.source_kind, + "reason": reason, + } + ) + + for file_obj in file_fallback_candidates: + ok, reason = self._clone_failed_file_for_retry( + retry_task, + file_obj, + task_dir, + retry_mode="file_fallback", + retry_chunk_indexes=[], + ) + if ok: + retry_summary["file_fallback_files"] += 1 + else: + skipped_details.append( + { + "file_name": file_obj.name, + "source_kind": file_obj.source_kind, + "reason": reason, + } + ) + + retry_summary["skipped_files"] = len(skipped_details) + if skipped_details: + retry_summary["skipped_details"] = skipped_details + retry_task.retry_summary = retry_summary + + if not retry_task.files: + raise ValueError("无可执行的重试输入:失败项均无法构建重试任务") + + self._tasks[retry_task.task_id] = retry_task + self._task_order.appendleft(retry_task.task_id) + self._queue.append(retry_task.task_id) + logger.info( + "重试任务已创建 " + f"parent={task_id} retry={retry_task.task_id} " + f"chunk_files={retry_summary['chunk_retry_files']} " + f"chunk_chunks={retry_summary['chunk_retry_chunks']} " + f"file_fallback={retry_summary['file_fallback_files']} " + f"skipped={retry_summary['skipped_files']}" + ) + + await self._ensure_worker() + return retry_task.to_summary() + + async def shutdown(self) -> None: + async with self._lock: + self._stopping = True + for task in self._tasks.values(): + if task.status in {"queued", "preparing", "running", "cancel_requested"}: + self._mark_task_cancelled_locked(task, "服务关闭") + self._queue.clear() + worker = self._worker_task + 
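+ # Detach the worker reference while still holding the lock; it is cancelled and awaited after the lock is released.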
self._worker_task = None + + if worker: + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + except Exception: + pass + + self._cleanup_temp_root() + + def _cleanup_temp_root(self) -> None: + try: + if not self._temp_root.exists(): + return + for child in self._temp_root.rglob("*"): + if child.is_file(): + child.unlink(missing_ok=True) + for child in sorted(self._temp_root.rglob("*"), reverse=True): + if child.is_dir(): + child.rmdir() + self._temp_root.rmdir() + except Exception as e: + logger.warning(f"清理临时导入目录失败: {e}") + + async def _worker_loop(self) -> None: + logger.info("Web 导入任务 worker 已启动") + while True: + if self._stopping: + break + + task_id: Optional[str] = None + async with self._lock: + while self._queue: + candidate = self._queue.popleft() + t = self._tasks.get(candidate) + if not t: + continue + if t.status == "cancelled": + continue + task_id = candidate + self._active_task_id = candidate + break + + if not task_id: + await asyncio.sleep(0.2) + continue + + try: + await self._run_task(task_id) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"导入任务执行失败 task={task_id}: {e}\n{traceback.format_exc()}") + async with self._lock: + task = self._tasks.get(task_id) + if task and task.status not in {"cancelled", "completed", "completed_with_errors"}: + task.status = "failed" + task.current_step = "failed" + task.error = str(e) + task.finished_at = _now() + task.updated_at = _now() + finally: + should_cleanup = await self._should_cleanup_task_temp(task_id) + async with self._lock: + if self._active_task_id == task_id: + self._active_task_id = None + if should_cleanup: + await self._cleanup_task_temp_files(task_id) + + logger.info("Web 导入任务 worker 已停止") + + async def _cleanup_task_temp_files(self, task_id: str) -> None: + task_dir = self._temp_root / task_id + if not task_dir.exists(): + return + try: + for child in task_dir.rglob("*"): + if child.is_file(): + child.unlink(missing_ok=True) + for child in sorted(task_dir.rglob("*"), reverse=True): + if child.is_dir(): + child.rmdir() + task_dir.rmdir() + except Exception as e: + logger.warning(f"清理任务临时文件失败 task={task_id}: {e}") + + def _task_report_path(self, task_id: str) -> Path: + self._reports_root.mkdir(parents=True, exist_ok=True) + return self._reports_root / f"{task_id}_summary.json" + + def _write_task_report(self, task: ImportTaskRecord) -> None: + path = self._task_report_path(task.task_id) + payload = task.to_detail(include_chunks=False) + payload["generated_at"] = _now() + path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + task.artifact_paths["summary"] = str(path) + + async def _run_task(self, task_id: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + task.status = "preparing" + task.current_step = "preparing" + task.started_at = _now() + task.updated_at = _now() + if task.params.get("clear_manifest"): + self._clear_manifest() + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + if task.status == "cancel_requested": + task.status = "cancelled" + task.current_step = "cancelled" + task.finished_at = _now() + task.updated_at = _now() + return + task.status = "running" + task.current_step = "running" + task.updated_at = _now() + + task_kind = str(task.params.get("task_kind") or task.source).strip().lower() + if task_kind == "maibot_migration": + if not task.files: + raise RuntimeError("迁移任务缺少文件记录") + await self._process_maibot_migration(task_id, 
task.files[0]) + elif task_kind == "temporal_backfill": + if not task.files: + raise RuntimeError("回填任务缺少文件记录") + await self._process_temporal_backfill(task_id, task.files[0]) + elif task_kind == "lpmm_convert": + if not task.files: + raise RuntimeError("转换任务缺少文件记录") + await self._process_lpmm_convert(task_id, task.files[0]) + else: + file_semaphore = asyncio.Semaphore(task.params["file_concurrency"]) + chunk_semaphore = asyncio.Semaphore(task.params["chunk_concurrency"]) + jobs = [ + asyncio.create_task(self._process_file(task_id, f, file_semaphore, chunk_semaphore)) + for f in task.files + ] + await asyncio.gather(*jobs, return_exceptions=True) + + write_changed_payload: Optional[Dict[str, Any]] = None + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + self._recompute_task_progress(task) + has_failed = any( + (f.status == "failed") + or (f.failed_chunks > 0) + or bool(str(f.error or "").strip()) + for f in task.files + ) + has_cancelled = any(f.status == "cancelled" for f in task.files) + has_completed = any(f.status == "completed" for f in task.files) + + # 统一按文件真实终态收敛任务状态,避免出现“任务已取消但文件已完成”的矛盾结果。 + if has_failed and not has_cancelled: + task.status = "completed_with_errors" + task.current_step = "completed_with_errors" + elif has_cancelled and not has_completed: + task.status = "cancelled" + task.current_step = "cancelled" + elif has_cancelled and has_completed: + task.status = "cancelled" + task.current_step = "cancelled" + else: + task.status = "completed" + task.current_step = "completed" + task.finished_at = _now() + task.updated_at = _now() + try: + self._write_task_report(task) + except Exception as report_err: + logger.warning(f"写入任务报告失败 task={task_id}: {report_err}") + task_kind = str(task.params.get("task_kind") or task.source).strip().lower() + write_task_kinds = {"upload", "paste", "raw_scan", "lpmm_openie", "maibot_migration", "lpmm_convert"} + has_written_chunks = (task.done_chunks > 0) or any(f.done_chunks > 0 for f in task.files) + if task_kind in write_task_kinds and has_written_chunks: + write_changed_payload = { + "task_id": task.task_id, + "task_kind": task_kind, + "status": task.status, + "done_chunks": task.done_chunks, + "finished_at": task.finished_at, + } + + if write_changed_payload: + await self._notify_write_changed(write_changed_payload) + + def _build_maibot_migration_command(self, params: Dict[str, Any]) -> List[str]: + script_path = self._resolve_migration_script() + if not script_path.exists(): + raise RuntimeError(f"迁移脚本不存在: {script_path}") + + cmd = [ + sys.executable, + str(script_path), + "--source-db", + str(params["source_db"]), + "--target-data-dir", + str(params["target_data_dir"]), + "--read-batch-size", + str(params["read_batch_size"]), + "--commit-window-rows", + str(params["commit_window_rows"]), + "--embed-batch-size", + str(params["embed_batch_size"]), + "--entity-embed-batch-size", + str(params["entity_embed_batch_size"]), + "--max-errors", + str(params["max_errors"]), + "--log-every", + str(params["log_every"]), + "--preview-limit", + str(params["preview_limit"]), + "--yes", + ] + + if params.get("embed_workers") is not None: + cmd.extend(["--embed-workers", str(params["embed_workers"])]) + if params.get("start_id") is not None: + cmd.extend(["--start-id", str(params["start_id"])]) + if params.get("end_id") is not None: + cmd.extend(["--end-id", str(params["end_id"])]) + if params.get("time_from"): + cmd.extend(["--time-from", str(params["time_from"])]) + if params.get("time_to"): + 
cmd.extend(["--time-to", str(params["time_to"])]) + + for sid in params.get("stream_ids") or []: + cmd.extend(["--stream-id", str(sid)]) + for gid in params.get("group_ids") or []: + cmd.extend(["--group-id", str(gid)]) + for uid in params.get("user_ids") or []: + cmd.extend(["--user-id", str(uid)]) + + if params.get("reset_state"): + cmd.append("--reset-state") + if params.get("no_resume"): + cmd.append("--no-resume") + if params.get("dry_run"): + cmd.append("--dry-run") + if params.get("verify_only"): + cmd.append("--verify-only") + + return cmd + + async def _ensure_maibot_migration_chunk( + self, + task_id: str, + file_id: str, + *, + chunk_type: str = "maibot_migration", + preview: str = "MaiBot chat_history 迁移任务", + ) -> str: + chunk_id = f"{file_id}_{chunk_type}" + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return chunk_id + f = self._find_file(task, file_id) + if not f: + return chunk_id + if not f.chunks: + f.chunks = [ + ImportChunkRecord( + chunk_id=chunk_id, + index=0, + chunk_type=chunk_type, + status="queued", + step="queued", + progress=0.0, + content_preview=preview, + ) + ] + f.total_chunks = 1 + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 0.0 + f.updated_at = _now() + self._recompute_task_progress(task) + else: + chunk_id = f.chunks[0].chunk_id + return chunk_id + + async def _refresh_maibot_progress_from_state( + self, + task_id: str, + file_id: str, + chunk_id: str, + state_path: Path, + ) -> None: + if not state_path.exists(): + return + try: + payload = json.loads(state_path.read_text(encoding="utf-8")) + except Exception: + return + + stats = payload.get("stats", {}) if isinstance(payload, dict) else {} + if not isinstance(stats, dict): + stats = {} + + total = max(0, _coerce_int(stats.get("source_matched_total", 0), 0)) + scanned = max(0, _coerce_int(stats.get("scanned_rows", 0), 0)) + bad = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) + done = max(0, scanned - bad) + migrated = max(0, _coerce_int(stats.get("migrated_rows", 0), 0)) + last_id = max(0, _coerce_int(stats.get("last_committed_id", 0), 0)) + + if total <= 0: + total = max(1, scanned) + + progress = max(0.0, min(1.0, float(scanned) / float(total))) if total > 0 else 0.0 + preview = f"scanned={scanned}/{total}, migrated={migrated}, bad={bad}, last_id={last_id}" + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if c: + if c.status not in {"completed", "failed", "cancelled"}: + c.status = "writing" + c.step = "migrating" + c.progress = progress + c.content_preview = preview + c.updated_at = _now() + f.total_chunks = total + f.done_chunks = done + f.failed_chunks = bad + f.cancelled_chunks = 0 + f.progress = progress + if f.status not in {"failed", "cancelled"}: + f.status = "writing" + f.current_step = "migrating" + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _terminate_process(self, process: asyncio.subprocess.Process) -> None: + if process.returncode is not None: + return + try: + process.terminate() + await asyncio.wait_for(process.wait(), timeout=5.0) + except Exception: + try: + process.kill() + await asyncio.wait_for(process.wait(), timeout=3.0) + except Exception: + pass + + async def _reload_stores_after_external_migration(self) -> None: + async with self._storage_lock: + try: + if self.plugin.vector_store and self.plugin.vector_store.has_data(): + 
self.plugin.vector_store.load() + except Exception as e: + logger.warning(f"迁移后重载 VectorStore 失败: {e}") + try: + if self.plugin.graph_store and self.plugin.graph_store.has_data(): + self.plugin.graph_store.load() + except Exception as e: + logger.warning(f"迁移后重载 GraphStore 失败: {e}") + + async def _process_maibot_migration(self, task_id: str, file_record: ImportFileRecord) -> None: + await self._set_file_strategy(task_id, file_record.file_id, "maibot_migration") + await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") + chunk_id = await self._ensure_maibot_migration_chunk( + task_id, + file_record.file_id, + chunk_type="maibot_migration", + preview="MaiBot chat_history 迁移任务", + ) + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "migrating", 0.0) + + task = self._tasks.get(task_id) + if not task: + await self._set_file_failed(task_id, file_record.file_id, "任务不存在") + return + params = dict(task.params) + + command = self._build_maibot_migration_command(params) + project_root = self._resolve_repo_root() + state_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_resume.json" + report_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_report.json" + + logger.info(f"开始执行 MaiBot 迁移任务: {' '.join(command)}") + process = await asyncio.create_subprocess_exec( + *command, + cwd=str(project_root), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + stdout_lines: List[str] = [] + stderr_lines: List[str] = [] + + async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None: + if stream is None: + return + while True: + line = await stream.readline() + if not line: + break + text = line.decode("utf-8", errors="replace").strip() + if not text: + continue + target.append(text) + if len(target) > 120: + del target[:-120] + + drain_tasks = [ + asyncio.create_task(_drain(process.stdout, stdout_lines)), + asyncio.create_task(_drain(process.stderr, stderr_lines)), + ] + + cancelled = False + return_code: Optional[int] = None + try: + while True: + if await self._is_cancel_requested(task_id): + cancelled = True + await self._terminate_process(process) + break + + await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) + try: + return_code = await asyncio.wait_for(process.wait(), timeout=1.0) + break + except asyncio.TimeoutError: + continue + finally: + await asyncio.gather(*drain_tasks, return_exceptions=True) + + if cancelled: + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) + + report: Dict[str, Any] = {} + if report_path.exists(): + try: + report = json.loads(report_path.read_text(encoding="utf-8")) + except Exception: + report = {} + + stats = report.get("stats", {}) if isinstance(report, dict) else {} + if not isinstance(stats, dict): + stats = {} + bad_rows = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) + + if return_code in {0, 2}: + await self._set_file_state(task_id, file_record.file_id, "saving", "saving") + await self._reload_stores_after_external_migration() + + async with self._lock: + task2 = self._tasks.get(task_id) + if not task2: + return + f = self._find_file(task2, file_record.file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if c and c.status not in {"cancelled", 
"failed"}: + c.status = "completed" + c.step = "completed" + c.progress = 1.0 + c.updated_at = _now() + if f.total_chunks <= 0: + f.total_chunks = 1 + if f.done_chunks + f.failed_chunks <= 0: + f.done_chunks = f.total_chunks - bad_rows + f.failed_chunks = bad_rows + f.done_chunks = max(0, min(f.done_chunks, f.total_chunks)) + f.failed_chunks = max(0, min(f.failed_chunks, f.total_chunks)) + f.cancelled_chunks = 0 + f.progress = 1.0 + f.status = "completed" + f.current_step = "completed" + if bad_rows > 0 and not f.error: + f.error = f"迁移完成,但存在坏行: {bad_rows}" + f.updated_at = _now() + self._recompute_task_progress(task2) + return + + fail_reason = "" + if isinstance(report, dict): + fail_reason = str(report.get("fail_reason") or "").strip() + tail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "") + detail = fail_reason or tail or f"迁移进程退出码: {return_code}" + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail) + await self._set_file_failed(task_id, file_record.file_id, detail) + + def _resolve_convert_script(self) -> Path: + return Path(__file__).resolve().parents[2] / "scripts" / "convert_lpmm.py" + + def _cleanup_old_backups(self) -> None: + keep = max(0, self._cfg_int("web.import.convert.keep_backup_count", 3)) + backup_root = self._resolve_backup_root() + if not backup_root.exists() or keep <= 0: + return + dirs = [p for p in backup_root.iterdir() if p.is_dir() and p.name.startswith("lpmm_convert_")] + dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True) + for old in dirs[keep:]: + try: + shutil.rmtree(old, ignore_errors=True) + except Exception: + pass + + def _verify_convert_output(self, output_dir: Path) -> Dict[str, Any]: + vectors = output_dir / "vectors" + graph = output_dir / "graph" + metadata = output_dir / "metadata" + checks = { + "vectors_exists": vectors.exists(), + "graph_exists": graph.exists(), + "metadata_exists": metadata.exists(), + "vectors_nonempty": vectors.exists() and any(vectors.iterdir()), + "graph_nonempty": graph.exists() and any(graph.iterdir()), + "metadata_nonempty": metadata.exists() and any(metadata.iterdir()), + } + checks["ok"] = checks["vectors_exists"] and checks["graph_exists"] and checks["metadata_exists"] + return checks + + async def _preflight_convert_runtime(self) -> Tuple[bool, str]: + """使用当前服务解释器做 convert 依赖预检,避免子进程报错信息不透明。""" + probe_code = ( + "import importlib\n" + "mods=['networkx','scipy','pyarrow']\n" + "failed=[]\n" + "for m in mods:\n" + " try:\n" + " importlib.import_module(m)\n" + " except Exception as e:\n" + " failed.append(f'{m}:{e.__class__.__name__}:{e}')\n" + "print('OK' if not failed else ';'.join(failed))\n" + ) + try: + probe = await asyncio.create_subprocess_exec( + sys.executable, + "-c", + probe_code, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await asyncio.wait_for(probe.communicate(), timeout=20.0) + except Exception as e: + return False, f"依赖预检执行失败: {e}" + + out = (stdout or b"").decode("utf-8", errors="replace").strip() + err = (stderr or b"").decode("utf-8", errors="replace").strip() + if probe.returncode != 0: + detail = err or out or f"return_code={probe.returncode}" + return False, f"依赖预检失败 (python={sys.executable}): {detail}" + if out != "OK": + return False, f"依赖预检失败 (python={sys.executable}): {out}" + return True, "" + + async def _process_lpmm_convert(self, task_id: str, file_record: ImportFileRecord) -> None: + await self._set_file_strategy(task_id, file_record.file_id, "lpmm_convert") + await 
self._set_file_state(task_id, file_record.file_id, "preparing", "preflight") + chunk_id = await self._ensure_maibot_migration_chunk( + task_id, + file_record.file_id, + chunk_type="lpmm_convert", + preview="LPMM 二进制转换任务", + ) + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "converting", 0.05) + + task = self._tasks.get(task_id) + if not task: + await self._set_file_failed(task_id, file_record.file_id, "任务不存在") + return + params = dict(task.params) + source_dir = Path(params.get("source_path") or "") + target_dir = Path(params.get("target_path") or "") + if not source_dir.exists() or not source_dir.is_dir(): + await self._set_file_failed(task_id, file_record.file_id, f"输入目录无效: {source_dir}") + return + if not target_dir.exists() or not target_dir.is_dir(): + await self._set_file_failed(task_id, file_record.file_id, f"目标目录无效: {target_dir}") + return + + script_path = self._resolve_convert_script() + if not script_path.exists(): + await self._set_file_failed(task_id, file_record.file_id, f"转换脚本不存在: {script_path}") + return + + runtime_ok, runtime_detail = await self._preflight_convert_runtime() + if not runtime_ok: + await self._set_file_failed(task_id, file_record.file_id, runtime_detail) + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, runtime_detail) + return + + required_inputs = ["paragraph.parquet", "entity.parquet"] + if not any((source_dir / name).exists() for name in required_inputs): + await self._set_file_failed( + task_id, + file_record.file_id, + f"输入目录缺少必要文件,至少需要其一: {', '.join(required_inputs)}", + ) + return + + staging_root = self._resolve_staging_root() + staging_root.mkdir(parents=True, exist_ok=True) + staging_dir = staging_root / f"lpmm_convert_{task_id}" + if staging_dir.exists(): + shutil.rmtree(staging_dir, ignore_errors=True) + staging_dir.mkdir(parents=True, exist_ok=True) + + # 简单空间预检:至少保留 512MB + usage = shutil.disk_usage(str(target_dir)) + if usage.free < 512 * 1024 * 1024: + await self._set_file_failed(task_id, file_record.file_id, "磁盘剩余空间不足(<512MB)") + return + + cmd = [ + sys.executable, + str(script_path), + "--input", + str(source_dir), + "--output", + str(staging_dir), + "--dim", + str(params.get("dimension", 384)), + "--batch-size", + str(params.get("batch_size", 1024)), + ] + process = await asyncio.create_subprocess_exec( + *cmd, + cwd=str(self._resolve_repo_root()), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout_lines: List[str] = [] + stderr_lines: List[str] = [] + + async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None: + if stream is None: + return + while True: + line = await stream.readline() + if not line: + break + text = line.decode("utf-8", errors="replace").strip() + if text: + target.append(text) + if len(target) > 120: + del target[:-120] + + drain_tasks = [ + asyncio.create_task(_drain(process.stdout, stdout_lines)), + asyncio.create_task(_drain(process.stderr, stderr_lines)), + ] + + cancelled = False + return_code: Optional[int] = None + try: + while True: + if await self._is_cancel_requested(task_id): + cancelled = True + await self._terminate_process(process) + break + try: + return_code = await asyncio.wait_for(process.wait(), timeout=1.0) + break + except asyncio.TimeoutError: + continue + finally: + await asyncio.gather(*drain_tasks, return_exceptions=True) + + if cancelled: + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + await self._set_file_cancelled(task_id, file_record.file_id, 
"任务已取消") + return + if return_code != 0: + detail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "") + await self._set_file_failed(task_id, file_record.file_id, detail or f"转换失败,退出码: {return_code}") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail or f"退出码: {return_code}") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "verifying", 0.65) + verify = self._verify_convert_output(staging_dir) + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.artifact_paths["staging_dir"] = str(staging_dir) + t.artifact_paths["verify"] = json.dumps(verify, ensure_ascii=False) + if not verify.get("ok"): + await self._set_file_failed(task_id, file_record.file_id, f"校验失败: {verify}") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"校验失败: {verify}") + return + + enable_switch = _coerce_bool(self._cfg("web.import.convert.enable_staging_switch", True), True) + if not enable_switch: + await self._set_file_failed(task_id, file_record.file_id, "未启用 staging 切换") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, "未启用 staging 切换") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "switching", 0.85) + backup_root = self._resolve_backup_root() + backup_root.mkdir(parents=True, exist_ok=True) + backup_dir = backup_root / f"lpmm_convert_{task_id}_{int(_now())}" + backup_dir.mkdir(parents=True, exist_ok=True) + + switched = False + rollback_info: Dict[str, Any] = {"attempted": True, "restored": False, "error": ""} + moved_items: List[Tuple[Path, Path]] = [] + try: + for name in ("vectors", "graph", "metadata"): + src_current = target_dir / name + src_new = staging_dir / name + if not src_new.exists(): + raise RuntimeError(f"staging 缺少目录: {src_new}") + if src_current.exists(): + dst_backup = backup_dir / name + shutil.move(str(src_current), str(dst_backup)) + moved_items.append((dst_backup, src_current)) + shutil.move(str(src_new), str(src_current)) + switched = True + except Exception as switch_err: + rollback_info["error"] = str(switch_err) + # 尝试回滚 + for src_backup, dst_original in moved_items: + if src_backup.exists() and not dst_original.exists(): + try: + shutil.move(str(src_backup), str(dst_original)) + except Exception: + pass + rollback_info["restored"] = True + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.rollback_info = rollback_info + await self._set_file_failed(task_id, file_record.file_id, f"切换失败并回滚: {switch_err}") + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"switch failed: {switch_err}") + return + + if switched: + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.rollback_info = rollback_info + t.artifact_paths["backup_dir"] = str(backup_dir) + self._cleanup_old_backups() + try: + await self._reload_stores_after_external_migration() + except Exception as reload_err: + logger.warning(f"转换后重载存储失败: {reload_err}") + + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + async with self._lock: + t = self._tasks.get(task_id) + if not t: + return + f = self._find_file(t, file_record.file_id) + if not f: + return + f.total_chunks = 1 + f.done_chunks = 1 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 1.0 + f.status = "completed" + f.current_step = "completed" + f.updated_at = _now() + self._recompute_task_progress(t) + + async def _process_temporal_backfill(self, task_id: str, file_record: ImportFileRecord) 
-> None: + await self._set_file_strategy(task_id, file_record.file_id, "temporal_backfill") + await self._set_file_state(task_id, file_record.file_id, "preparing", "backfilling") + chunk_id = await self._ensure_maibot_migration_chunk( + task_id, + file_record.file_id, + chunk_type="temporal_backfill", + preview="时序字段回填任务", + ) + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "backfilling", 0.2) + + task = self._tasks.get(task_id) + if not task: + await self._set_file_failed(task_id, file_record.file_id, "任务不存在") + return + params = dict(task.params) + target_dir = Path(file_record.source_path or "") + metadata_dir = target_dir / "metadata" + if not metadata_dir.exists(): + await self._set_file_failed(task_id, file_record.file_id, f"metadata 目录不存在: {metadata_dir}") + return + + dry_run = bool(params.get("dry_run")) + no_created_fallback = bool(params.get("no_created_fallback")) + limit = max(1, _coerce_int(params.get("limit"), 100000)) + + store = MetadataStore(data_dir=metadata_dir) + updated = 0 + candidates = 0 + try: + store.connect() + summary = store.backfill_temporal_metadata_from_created_at( + limit=limit, + dry_run=dry_run, + no_created_fallback=no_created_fallback, + ) + candidates = int(summary.get("candidates", 0)) + updated = int(summary.get("updated", 0)) + finally: + try: + store.close() + except Exception: + pass + + async with self._lock: + t = self._tasks.get(task_id) + if t: + t.artifact_paths["temporal_backfill"] = json.dumps( + { + "target_dir": str(target_dir), + "dry_run": dry_run, + "no_created_fallback": no_created_fallback, + "limit": limit, + "candidates": candidates, + "updated": updated, + }, + ensure_ascii=False, + ) + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + async with self._lock: + t = self._tasks.get(task_id) + if not t: + return + f = self._find_file(t, file_record.file_id) + if not f: + return + f.total_chunks = 1 + f.done_chunks = 1 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 1.0 + f.status = "completed" + f.current_step = "completed" + f.updated_at = _now() + self._recompute_task_progress(t) + + async def _process_file( + self, + task_id: str, + file_record: ImportFileRecord, + file_semaphore: asyncio.Semaphore, + chunk_semaphore: asyncio.Semaphore, + ) -> None: + async with file_semaphore: + await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") + if await self._is_cancel_requested(task_id): + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + try: + content = await self._read_file_content(file_record) + content_hash = hashlib.md5(content.encode("utf-8", errors="ignore")).hexdigest() + file_record.content_hash = content_hash + task = self._tasks.get(task_id) + if task: + dedupe_policy = str(task.params.get("dedupe_policy") or "none") + force = bool(task.params.get("force")) + if dedupe_policy != "none" and not force: + async with self._lock: + if self._is_manifest_hit(file_record, content_hash, dedupe_policy): + task2 = self._tasks.get(task_id) + if task2: + f = self._find_file(task2, file_record.file_id) + if f: + f.status = "completed" + f.current_step = "skipped" + f.progress = 1.0 + f.total_chunks = 0 + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.detected_strategy_type = "skipped" + f.error = "" + f.updated_at = _now() + self._recompute_task_progress(task2) + return + if file_record.input_mode == "json": + await self._process_json_file(task_id, file_record, content, chunk_semaphore) + 
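+ # Non-JSON inputs are chunked and extracted through the strategy-based text pipeline below.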
else: + await self._process_text_file(task_id, file_record, content, chunk_semaphore) + task3 = self._tasks.get(task_id) + if task3: + dedupe_policy = str(task3.params.get("dedupe_policy") or "none") + f3 = self._find_file(task3, file_record.file_id) + if dedupe_policy != "none" and f3 and f3.status == "completed": + async with self._lock: + self._record_manifest_import(file_record, content_hash, dedupe_policy, task_id) + except Exception as e: + await self._set_file_failed(task_id, file_record.file_id, str(e)) + + async def _read_file_content(self, file_record: ImportFileRecord) -> str: + if file_record.inline_content is not None: + return file_record.inline_content + if file_record.source_path and Path(file_record.source_path).exists(): + data = Path(file_record.source_path).read_bytes() + try: + return data.decode("utf-8") + except UnicodeDecodeError: + return data.decode("utf-8", errors="replace") + if file_record.temp_path and Path(file_record.temp_path).exists(): + data = Path(file_record.temp_path).read_bytes() + try: + return data.decode("utf-8") + except UnicodeDecodeError: + return data.decode("utf-8", errors="replace") + raise RuntimeError("读取文件失败:输入内容缺失") + + async def _process_text_file( + self, + task_id: str, + file_record: ImportFileRecord, + content: str, + chunk_semaphore: asyncio.Semaphore, + ) -> None: + task = self._tasks[task_id] + async with self._lock: + t = self._tasks.get(task_id) + if t and not t.schema_detected: + t.schema_detected = "plain_text" + strategy = self._determine_strategy( + file_record.name, + content, + task.params["strategy_override"], + chat_log=bool(task.params.get("chat_log")), + ) + await self._set_file_strategy(task_id, file_record.file_id, strategy) + await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") + await self._ensure_embedding_runtime_ready() + + chunks = strategy.split(content) + selected_chunks = list(chunks) + if file_record.retry_mode == "chunk": + retry_index_set = set() + for idx in file_record.retry_chunk_indexes: + try: + retry_index_set.add(int(idx)) + except Exception: + continue + selected_chunks = [chunk for chunk in chunks if int(chunk.chunk.index) in retry_index_set] + if not selected_chunks: + raise RuntimeError("失败分块重试索引无效,未匹配到可执行分块") + logger.info( + "重试任务按失败分块执行: " + f"file={file_record.name} " + f"selected={len(selected_chunks)} " + f"total={len(chunks)}" + ) + + await self._register_chunks(task_id, file_record.file_id, selected_chunks) + + await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") + model_cfg = None + if task.params["llm_enabled"]: + model_cfg = await self._select_model() + + jobs = [] + for chunk in selected_chunks: + jobs.append( + asyncio.create_task( + self._process_text_chunk( + task_id=task_id, + file_record=file_record, + chunk=chunk, + strategy=strategy, + llm_enabled=task.params["llm_enabled"], + model_cfg=model_cfg, + chunk_semaphore=chunk_semaphore, + chat_log=bool(task.params.get("chat_log")), + chat_reference_time=str(task.params.get("chat_reference_time") or "").strip() or None, + ) + ) + ) + await asyncio.gather(*jobs, return_exceptions=True) + + if await self._is_cancel_requested(task_id): + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + await self._set_file_state(task_id, file_record.file_id, "saving", "saving") + async with self._storage_lock: + self.plugin.vector_store.save() + self.plugin.graph_store.save() + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + 
return + f = self._find_file(task, file_record.file_id) + if not f: + return + if f.failed_chunks > 0: + f.status = "failed" + f.current_step = "failed" + if not f.error: + f.error = f"存在失败分块: {f.failed_chunks}" + elif task.status == "cancel_requested": + f.status = "cancelled" + f.current_step = "cancelled" + else: + f.status = "completed" + f.current_step = "completed" + f.progress = 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + async def _process_text_chunk( + self, + task_id: str, + file_record: ImportFileRecord, + chunk: ProcessedChunk, + strategy: Any, + llm_enabled: bool, + model_cfg: Any, + chunk_semaphore: asyncio.Semaphore, + chat_log: bool = False, + chat_reference_time: Optional[str] = None, + ) -> None: + async with chunk_semaphore: + chunk_id = chunk.chunk.chunk_id + if await self._is_cancel_requested(task_id): + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "extracting", "extracting", 0.25) + + processed = chunk + rescue_strategy = self._chunk_rescue(chunk, file_record.name) + current_strategy = strategy + if rescue_strategy: + chunk.type = StrategyKnowledgeType.QUOTE + chunk.flags.verbatim = True + chunk.flags.requires_llm = False + current_strategy = rescue_strategy + try: + if llm_enabled and chunk.flags.requires_llm: + processed = await current_strategy.extract( + chunk, + lambda prompt: self._llm_call(prompt, model_cfg), + ) + elif chunk.type == StrategyKnowledgeType.QUOTE: + processed = await current_strategy.extract(chunk) + except Exception as e: + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"抽取失败: {e}") + return + + if await self._is_cancel_requested(task_id): + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) + try: + time_meta = None + if chat_log and llm_enabled and model_cfg is not None: + time_meta = await self._extract_chat_time_meta_with_llm( + processed.chunk.text, + model_cfg, + reference_time=chat_reference_time, + ) + async with self._storage_lock: + await self._persist_processed_chunk(file_record, processed, time_meta=time_meta) + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + except Exception as e: + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") + + async def _process_json_file( + self, + task_id: str, + file_record: ImportFileRecord, + content: str, + chunk_semaphore: asyncio.Semaphore, + ) -> None: + await self._set_file_strategy(task_id, file_record.file_id, "json") + await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") + await self._ensure_embedding_runtime_ready() + + try: + data = json.loads(content) + except Exception as e: + raise RuntimeError(f"JSON 解析失败: {e}") + + schema = self._detect_json_schema(data) + async with self._lock: + task = self._tasks.get(task_id) + if task: + task.schema_detected = schema + task.updated_at = _now() + units = self._build_json_units(data, file_record.file_id, file_record.name, schema) + await self._register_json_units(task_id, file_record.file_id, units) + + await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") + jobs = [ + asyncio.create_task(self._process_json_unit(task_id, file_record, unit, chunk_semaphore)) + for unit in units + ] + await asyncio.gather(*jobs, return_exceptions=True) + + 
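+ # All unit jobs have settled; honor any pending cancel request before flushing the vector and graph stores to disk.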
if await self._is_cancel_requested(task_id): + await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") + return + + await self._set_file_state(task_id, file_record.file_id, "saving", "saving") + async with self._storage_lock: + self.plugin.vector_store.save() + self.plugin.graph_store.save() + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_record.file_id) + if not f: + return + if f.failed_chunks > 0: + f.status = "failed" + f.current_step = "failed" + if not f.error: + f.error = f"存在失败分块: {f.failed_chunks}" + elif task.status == "cancel_requested": + f.status = "cancelled" + f.current_step = "cancelled" + else: + f.status = "completed" + f.current_step = "completed" + f.progress = 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + + def _detect_json_schema(self, data: Any) -> str: + if isinstance(data, dict) and isinstance(data.get("docs"), list): + return "lpmm_openie" + if isinstance(data, dict) and isinstance(data.get("paragraphs"), list): + paragraphs = data.get("paragraphs", []) + for p in paragraphs: + if isinstance(p, dict) and any( + key in p for key in ("entities", "relations", "time_meta", "source", "type", "knowledge_type") + ): + return "script_json" + return "web_json" + raise RuntimeError("不支持的 JSON 格式:需要 paragraphs 或 docs") + + def _build_json_units(self, data: Any, file_id: str, filename: str, schema: str) -> List[Dict[str, Any]]: + units: List[Dict[str, Any]] = [] + paragraphs: List[Any] = [] + entities: List[Any] = [] + relations: List[Any] = [] + + if schema in {"web_json", "script_json"}: + paragraphs = data.get("paragraphs", []) + entities = data.get("entities", []) + relations = data.get("relations", []) + elif schema == "lpmm_openie": + docs = data.get("docs", []) + for d in docs: + if not isinstance(d, dict): + continue + content = str(d.get("passage", "") or "").strip() + if not content: + continue + triples = d.get("extracted_triples", []) or [] + rels = [] + for t in triples: + if isinstance(t, list) and len(t) == 3: + rels.append( + { + "subject": str(t[0]), + "predicate": str(t[1]), + "object": str(t[2]), + } + ) + para_item = { + "content": content, + "source": f"lpmm_openie:{filename}", + "entities": d.get("extracted_entities", []) or [], + "relations": rels, + "knowledge_type": "factual", + } + paragraphs.append(para_item) + + for p in paragraphs: + paragraph = normalize_paragraph_import_item( + p, + default_source=f"web_import:{filename}", + ) + units.append( + { + "chunk_id": f"{file_id}_json_{len(units)}", + "kind": "paragraph", + "content": paragraph["content"], + "time_meta": paragraph["time_meta"], + "knowledge_type": paragraph["knowledge_type"], + "chunk_type": paragraph["knowledge_type"], + "source": paragraph["source"], + "entities": paragraph["entities"], + "relations": paragraph["relations"], + "preview": paragraph["content"][:120], + } + ) + + for e in entities: + name = str(e or "").strip() + if name: + units.append( + { + "chunk_id": f"{file_id}_json_{len(units)}", + "kind": "entity", + "name": name, + "chunk_type": "entity", + "preview": name[:120], + } + ) + + for r in relations: + if not isinstance(r, dict): + continue + s = str(r.get("subject", "")).strip() + p = str(r.get("predicate", "")).strip() + o = str(r.get("object", "")).strip() + if s and p and o: + units.append( + { + "chunk_id": f"{file_id}_json_{len(units)}", + "kind": "relation", + "subject": s, + "predicate": p, + "object": o, + "chunk_type": "relation", + "preview": f"{s} {p} 
{o}"[:120], + } + ) + return units + + async def _register_json_units(self, task_id: str, file_id: str, units: List[Dict[str, Any]]) -> None: + records = [ + ImportChunkRecord( + chunk_id=u["chunk_id"], + index=i, + chunk_type=u.get("chunk_type", "json"), + status="queued", + step="queued", + progress=0.0, + content_preview=str(u.get("preview", "")), + ) + for i, u in enumerate(units) + ] + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.chunks = records + f.total_chunks = len(records) + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 0.0 if records else 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _process_json_unit( + self, + task_id: str, + file_record: ImportFileRecord, + unit: Dict[str, Any], + chunk_semaphore: asyncio.Semaphore, + ) -> None: + chunk_id = unit["chunk_id"] + async with chunk_semaphore: + if await self._is_cancel_requested(task_id): + await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") + return + + await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) + try: + async with self._storage_lock: + kind = unit["kind"] + if kind == "paragraph": + content = str(unit.get("content", "")) + k_type = resolve_stored_knowledge_type( + unit.get("knowledge_type"), + content=content, + ).value + source = str(unit.get("source") or f"web_import:{file_record.name}") + para_hash = self.plugin.metadata_store.add_paragraph( + content=content, + source=source, + knowledge_type=k_type, + time_meta=unit.get("time_meta"), + ) + emb = await self.plugin.embedding_manager.encode(content) + try: + self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) + except ValueError: + pass + for name in unit.get("entities", []) or []: + n = str(name or "").strip() + if n: + await self._add_entity_with_vector(n, source_paragraph=para_hash) + for rel in unit.get("relations", []) or []: + if not isinstance(rel, dict): + continue + s = str(rel.get("subject", "")).strip() + p = str(rel.get("predicate", "")).strip() + o = str(rel.get("object", "")).strip() + if s and p and o: + await self._add_relation(s, p, o, source_paragraph=para_hash) + elif kind == "entity": + await self._add_entity_with_vector(unit["name"]) + elif kind == "relation": + await self._add_relation(unit["subject"], unit["predicate"], unit["object"]) + else: + raise RuntimeError(f"未知 JSON 导入单元类型: {kind}") + await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) + except Exception as e: + await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") + + def _source_label(self, file_record: ImportFileRecord) -> str: + if file_record.source_path: + return f"{file_record.source_kind}:{file_record.source_path}" + return f"web_import:{file_record.name}" + + async def _ensure_embedding_runtime_ready(self) -> None: + report = await ensure_runtime_self_check(self.plugin) + if bool(report.get("ok", False)): + return + raise RuntimeError( + "embedding runtime self-check failed: " + f"{report.get('message', 'unknown')} " + f"(configured={report.get('configured_dimension', 0)}, " + f"store={report.get('vector_store_dimension', 0)}, " + f"encoded={report.get('encoded_dimension', 0)})" + ) + + async def _persist_processed_chunk( + self, + file_record: ImportFileRecord, + processed: ProcessedChunk, + *, + time_meta: Optional[Dict[str, Any]] = None, + ) -> None: + content = processed.chunk.text 
+ para_hash = self.plugin.metadata_store.add_paragraph( + content=content, + source=self._source_label(file_record), + knowledge_type=_storage_type_from_strategy(processed.type), + time_meta=time_meta, + ) + + emb = await self.plugin.embedding_manager.encode(content) + try: + self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) + except ValueError: + pass + + data = processed.data or {} + entities: List[str] = [] + relations: List[Tuple[str, str, str]] = [] + + for triple in data.get("triples", []): + s = str(triple.get("subject", "")).strip() + p = str(triple.get("predicate", "")).strip() + o = str(triple.get("object", "")).strip() + if s and p and o: + relations.append((s, p, o)) + entities.extend([s, o]) + + for rel in data.get("relations", []): + s = str(rel.get("subject", "")).strip() + p = str(rel.get("predicate", "")).strip() + o = str(rel.get("object", "")).strip() + if s and p and o: + relations.append((s, p, o)) + entities.extend([s, o]) + + for k in ("entities", "events", "verbatim_entities"): + for e in data.get(k, []): + name = str(e or "").strip() + if name: + entities.append(name) + + uniq_entities = list({x.strip().lower(): x.strip() for x in entities if str(x).strip()}.values()) + for name in uniq_entities: + await self._add_entity_with_vector(name, source_paragraph=para_hash) + + for s, p, o in relations: + await self._add_relation(s, p, o, source_paragraph=para_hash) + + async def _add_entity_with_vector(self, name: str, source_paragraph: str = "") -> str: + hash_value = self.plugin.metadata_store.add_entity(name=name, source_paragraph=source_paragraph) + self.plugin.graph_store.add_nodes([name]) + if hash_value not in self.plugin.vector_store: + emb = await self.plugin.embedding_manager.encode(name) + try: + self.plugin.vector_store.add(emb.reshape(1, -1), [hash_value]) + except ValueError: + pass + return hash_value + + async def _add_relation(self, subject: str, predicate: str, obj: str, source_paragraph: str = "") -> str: + await self._add_entity_with_vector(subject, source_paragraph=source_paragraph) + await self._add_entity_with_vector(obj, source_paragraph=source_paragraph) + rv_cfg = self.plugin.get_config("retrieval.relation_vectorization", {}) or {} + if not isinstance(rv_cfg, dict): + rv_cfg = {} + write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) + + relation_service = getattr(self.plugin, "relation_write_service", None) + if relation_service is not None: + result = await relation_service.upsert_relation_with_vector( + subject=subject, + predicate=predicate, + obj=obj, + confidence=1.0, + source_paragraph=source_paragraph, + write_vector=write_vector, + ) + return result.hash_value + + rel_hash = self.plugin.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + source_paragraph=source_paragraph, + confidence=1.0, + ) + self.plugin.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) + try: + self.plugin.metadata_store.set_relation_vector_state(rel_hash, "none") + except Exception: + pass + return rel_hash + async def _select_model(self) -> Any: + models = llm_api.get_available_models() + if not models: + raise RuntimeError("没有可用 LLM 模型") + + config_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() + if config_model.lower() != "auto" and config_model in models: + return models[config_model] + + for task_name in [ + "lpmm_entity_extract", + "lpmm_rdf_build", + "embedding", + "replyer", + "utils", + "planner", + "tool_use", + ]: + if 
task_name in models: + return models[task_name] + + return models[next(iter(models))] + + async def _llm_call(self, prompt: str, model_config: Any) -> Dict[str, Any]: + cfg = self._llm_retry_config() + retries = int(cfg["retries"]) + last_error: Optional[Exception] = None + for attempt in range(retries + 1): + try: + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config, + request_type="A_Memorix.WebImport", + ) + if not success or not response: + raise RuntimeError("LLM 生成失败") + + txt = str(response or "").strip() + if "```" in txt: + txt = txt.split("```json")[-1].split("```")[0].strip() + if txt.startswith("json"): + txt = txt[4:].strip() + + try: + return json.loads(txt) + except Exception: + s = txt.find("{") + e = txt.rfind("}") + if s >= 0 and e > s: + return json.loads(txt[s : e + 1]) + raise + except Exception as err: + last_error = err + if attempt >= retries: + break + delay = min(cfg["max_wait"], cfg["min_wait"] * (cfg["multiplier"] ** attempt)) + await asyncio.sleep(max(0.0, float(delay))) + raise RuntimeError(f"LLM 抽取失败: {last_error}") + + def _parse_reference_time(self, value: Optional[str]) -> datetime: + if not value: + return datetime.now() + text = str(value).strip() + formats = [ + "%Y/%m/%d %H:%M:%S", + "%Y/%m/%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y/%m/%d", + "%Y-%m-%d", + ] + for fmt in formats: + try: + return datetime.strptime(text, fmt) + except ValueError: + continue + return datetime.now() + + async def _extract_chat_time_meta_with_llm( + self, + text: str, + model_config: Any, + *, + reference_time: Optional[str] = None, + ) -> Optional[Dict[str, Any]]: + if not str(text or "").strip(): + return None + ref_dt = self._parse_reference_time(reference_time) + reference_now = ref_dt.strftime("%Y/%m/%d %H:%M") + prompt = f"""You are a time extraction engine for chat logs. +Extract temporal information from the following chat paragraph. + +Rules: +1. Use semantic understanding, not regex matching. +2. Convert relative expressions to absolute local datetime using reference_now. +3. If a time span exists, return event_time_start/event_time_end. +4. If only one point in time exists, return event_time. +5. If no reliable time info exists, keep all event_time fields null. +6. Return JSON only. 
+ +reference_now: {reference_now} +text: +{text} + +JSON schema: +{{ + "event_time": null, + "event_time_start": null, + "event_time_end": null, + "time_range": null, + "time_granularity": null, + "time_confidence": 0.0 +}} +""" + try: + result = await self._llm_call(prompt, model_config) + except Exception as e: + logger.warning(f"chat_log 时间语义抽取失败: {e}") + return None + + raw_time_meta = { + "event_time": result.get("event_time"), + "event_time_start": result.get("event_time_start"), + "event_time_end": result.get("event_time_end"), + "time_range": result.get("time_range"), + "time_granularity": result.get("time_granularity"), + "time_confidence": result.get("time_confidence"), + } + try: + normalized = normalize_time_meta(raw_time_meta) + except Exception: + return None + has_effective = any(k in normalized for k in ("event_time", "event_time_start", "event_time_end")) + if not has_effective: + return None + return normalized + + def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[Any]: + if chunk.type == StrategyKnowledgeType.QUOTE: + return None + if looks_like_quote_text(chunk.chunk.text): + return QuoteStrategy(filename) + return None + + def _instantiate_strategy(self, filename: str, strategy: ImportStrategy) -> Any: + if strategy == ImportStrategy.FACTUAL: + return FactualStrategy(filename) + if strategy == ImportStrategy.QUOTE: + return QuoteStrategy(filename) + return NarrativeStrategy(filename) + + def _determine_strategy(self, filename: str, content: str, override: str, *, chat_log: bool = False) -> Any: + strategy = select_import_strategy( + content, + override=override, + chat_log=chat_log, + ) + return self._instantiate_strategy(filename, strategy) + + async def _set_file_strategy(self, task_id: str, file_id: str, strategy: Any) -> None: + if isinstance(strategy, str): + strategy_type = strategy + elif isinstance(strategy, NarrativeStrategy): + strategy_type = "narrative" + elif isinstance(strategy, FactualStrategy): + strategy_type = "factual" + elif isinstance(strategy, QuoteStrategy): + strategy_type = "quote" + else: + strategy_type = "unknown" + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.detected_strategy_type = strategy_type + f.updated_at = _now() + task.updated_at = _now() + + async def _register_chunks(self, task_id: str, file_id: str, chunks: List[ProcessedChunk]) -> None: + records = [ + ImportChunkRecord( + chunk_id=chunk.chunk.chunk_id, + index=index, + chunk_type=chunk.type.value, + status="queued", + step="queued", + progress=0.0, + content_preview=str(chunk.chunk.text or "")[:120], + ) + for index, chunk in enumerate(chunks) + ] + + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.chunks = records + f.total_chunks = len(records) + f.done_chunks = 0 + f.failed_chunks = 0 + f.cancelled_chunks = 0 + f.progress = 0.0 if records else 1.0 + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_file_state(self, task_id: str, file_id: str, status: str, step: str) -> None: + if status not in FILE_STATUS: + return + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.status = status + f.current_step = step + f.updated_at = _now() + task.updated_at = _now() + if step in {"preparing", "splitting", "extracting", "writing", "saving"} and 
task.status in {"queued", "preparing"}: + task.status = "running" + task.current_step = "running" + + async def _set_file_failed(self, task_id: str, file_id: str, error: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.status = "failed" + f.current_step = "failed" + f.error = str(error) + f.updated_at = _now() + task.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_file_cancelled(self, task_id: str, file_id: str, reason: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + f.status = "cancelled" + f.current_step = "cancelled" + f.error = reason + additional_cancelled = 0 + for chunk in f.chunks: + if chunk.status in {"completed", "failed", "cancelled"}: + continue + chunk.status = "cancelled" + chunk.step = "cancelled" + chunk.retryable = False + chunk.error = reason + chunk.progress = 1.0 + chunk.updated_at = _now() + additional_cancelled += 1 + if additional_cancelled > 0: + f.cancelled_chunks += additional_cancelled + f.progress = self._compute_ratio( + f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks + ) + f.updated_at = _now() + task.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_chunk_state( + self, + task_id: str, + file_id: str, + chunk_id: str, + status: str, + step: str, + progress: float, + ) -> None: + if status not in CHUNK_STATUS: + return + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c: + return + c.status = status + c.step = step + if status in {"queued", "extracting", "writing"}: + c.error = "" + c.failed_at = "" + c.retryable = False + c.progress = max(0.0, min(1.0, float(progress))) + c.updated_at = _now() + if f.status not in {"failed", "cancelled"}: + f.status = "extracting" if status == "extracting" else "writing" + f.current_step = step + f.updated_at = _now() + task.updated_at = _now() + + async def _set_chunk_completed(self, task_id: str, file_id: str, chunk_id: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c or c.status == "completed": + return + c.status = "completed" + c.step = "completed" + c.failed_at = "" + c.retryable = False + c.progress = 1.0 + c.updated_at = _now() + f.done_chunks += 1 + f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_chunk_failed(self, task_id: str, file_id: str, chunk_id: str, error: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c or c.status == "failed": + return + failed_stage = str(c.step or "").strip().lower() + if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: + failed_stage = str(f.current_step or "").strip().lower() + if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: + failed_stage = "unknown" + c.status = "failed" + c.step = "failed" + c.failed_at = failed_stage + c.retryable = bool(f.input_mode == "text" and failed_stage == 
"extracting") + c.error = str(error) + c.progress = 1.0 + c.updated_at = _now() + f.failed_chunks += 1 + f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) + if not f.error: + f.error = str(error) + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _set_chunk_cancelled(self, task_id: str, file_id: str, chunk_id: str, reason: str) -> None: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return + f = self._find_file(task, file_id) + if not f: + return + c = self._find_chunk(f, chunk_id) + if not c or c.status == "cancelled": + return + c.status = "cancelled" + c.step = "cancelled" + c.retryable = False + c.error = reason + c.progress = 1.0 + c.updated_at = _now() + f.cancelled_chunks += 1 + f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) + f.updated_at = _now() + self._recompute_task_progress(task) + + async def _is_cancel_requested(self, task_id: str) -> bool: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return True + return task.status == "cancel_requested" + + def _find_file(self, task: ImportTaskRecord, file_id: str) -> Optional[ImportFileRecord]: + for f in task.files: + if f.file_id == file_id: + return f + return None + + def _find_chunk(self, file_record: ImportFileRecord, chunk_id: str) -> Optional[ImportChunkRecord]: + for c in file_record.chunks: + if c.chunk_id == chunk_id: + return c + return None + + def _compute_ratio(self, done: int, total: int) -> float: + if total <= 0: + return 1.0 + return max(0.0, min(1.0, float(done) / float(total))) + + def _recompute_task_progress(self, task: ImportTaskRecord) -> None: + total = 0 + done = 0 + failed = 0 + cancelled = 0 + for f in task.files: + total += f.total_chunks + done += f.done_chunks + failed += f.failed_chunks + cancelled += f.cancelled_chunks + task.total_chunks = total + task.done_chunks = done + task.failed_chunks = failed + task.cancelled_chunks = cancelled + task.progress = self._compute_ratio(done + failed + cancelled, total) + task.updated_at = _now() + + async def _should_cleanup_task_temp(self, task_id: str) -> bool: + async with self._lock: + task = self._tasks.get(task_id) + if not task: + return True + for f in task.files: + if f.status == "failed": + return False + return True + + def _mark_task_cancelled_locked(self, task: ImportTaskRecord, reason: str) -> None: + for f in task.files: + if f.status in {"completed", "failed", "cancelled"}: + continue + f.status = "cancelled" + f.current_step = "cancelled" + f.error = reason + additional_cancelled = 0 + for c in f.chunks: + if c.status in {"completed", "failed", "cancelled"}: + continue + c.status = "cancelled" + c.step = "cancelled" + c.retryable = False + c.error = reason + c.progress = 1.0 + c.updated_at = _now() + additional_cancelled += 1 + if additional_cancelled > 0: + f.cancelled_chunks += additional_cancelled + f.progress = self._compute_ratio( + f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks + ) + f.updated_at = _now() + task.status = "cancelled" + task.current_step = "cancelled" + task.finished_at = _now() + task.updated_at = _now() + self._recompute_task_progress(task) diff --git a/plugins/A_memorix/plugin.py b/plugins/A_memorix/plugin.py index 56df45b9..390515f5 100644 --- a/plugins/A_memorix/plugin.py +++ b/plugins/A_memorix/plugin.py @@ -15,6 +15,12 @@ def _tool_param(name: str, param_type: ToolParamType, description: str, required return 
ToolParameterInfo(name=name, param_type=param_type, description=description, required=required) +_ADMIN_TOOL_PARAMS = [ + _tool_param("action", ToolParamType.STRING, "管理动作", True), + _tool_param("target", ToolParamType.STRING, "可选目标标识", False), +] + + class AMemorixPlugin(MaiBotPlugin): def __init__(self) -> None: super().__init__() @@ -33,7 +39,11 @@ class AMemorixPlugin(MaiBotPlugin): async def on_unload(self): if self._kernel is not None: - self._kernel.close() + shutdown = getattr(self._kernel, "shutdown", None) + if callable(shutdown): + await shutdown() + else: + self._kernel.close() self._kernel = None async def _get_kernel(self) -> SDKMemoryKernel: @@ -42,6 +52,11 @@ class AMemorixPlugin(MaiBotPlugin): await self._kernel.initialize() return self._kernel + async def _dispatch_admin_tool(self, method_name: str, action: str, **kwargs): + kernel = await self._get_kernel() + handler = getattr(kernel, method_name) + return await handler(action=action, **kwargs) + @Tool( "search_memory", description="搜索长期记忆", @@ -53,6 +68,7 @@ class AMemorixPlugin(MaiBotPlugin): _tool_param("person_id", ToolParamType.STRING, "人物 ID", False), _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), ], ) async def handle_search_memory( @@ -62,11 +78,11 @@ class AMemorixPlugin(MaiBotPlugin): mode: str = "hybrid", chat_id: str = "", person_id: str = "", - time_start: float | None = None, - time_end: float | None = None, + time_start: str | float | None = None, + time_end: str | float | None = None, + respect_filter: bool = True, **kwargs, ): - _ = kwargs kernel = await self._get_kernel() return await kernel.search_memory( KernelSearchRequest( @@ -77,6 +93,9 @@ class AMemorixPlugin(MaiBotPlugin): person_id=person_id, time_start=time_start, time_end=time_end, + respect_filter=respect_filter, + user_id=str(kwargs.get("user_id", "") or "").strip(), + group_id=str(kwargs.get("group_id", "") or "").strip(), ) ) @@ -89,6 +108,7 @@ class AMemorixPlugin(MaiBotPlugin): _tool_param("text", ToolParamType.STRING, "摘要文本", True), _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), ], ) async def handle_ingest_summary( @@ -101,9 +121,9 @@ class AMemorixPlugin(MaiBotPlugin): time_end: float | None = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, **kwargs, ): - _ = kwargs kernel = await self._get_kernel() return await kernel.ingest_summary( external_id=external_id, @@ -114,6 +134,9 @@ class AMemorixPlugin(MaiBotPlugin): time_end=time_end, tags=tags, metadata=metadata, + respect_filter=respect_filter, + user_id=str(kwargs.get("user_id", "") or "").strip(), + group_id=str(kwargs.get("group_id", "") or "").strip(), ) @Tool( @@ -125,6 +148,7 @@ class AMemorixPlugin(MaiBotPlugin): _tool_param("text", ToolParamType.STRING, "原始文本", True), _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), _tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False), + _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), ], ) async def handle_ingest_text( @@ -140,6 +164,7 @@ class AMemorixPlugin(MaiBotPlugin): time_end: float | None = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, + respect_filter: bool = True, **kwargs, ): 
relations = kwargs.get("relations") @@ -159,6 +184,9 @@ class AMemorixPlugin(MaiBotPlugin): metadata=metadata, entities=entities, relations=relations, + respect_filter=respect_filter, + user_id=str(kwargs.get("user_id", "") or "").strip(), + group_id=str(kwargs.get("group_id", "") or "").strip(), ) @Tool( @@ -179,22 +207,24 @@ class AMemorixPlugin(MaiBotPlugin): "maintain_memory", description="维护长期记忆关系状态", parameters=[ - _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore", True), - _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", True), + _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore/freeze/recycle_bin", True), + _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", False), _tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False), + _tool_param("limit", ToolParamType.INTEGER, "查询条数(用于 recycle_bin)", False), ], ) async def handle_maintain_memory( self, action: str, - target: str, + target: str = "", hours: float | None = None, reason: str = "", + limit: int = 50, **kwargs, ): _ = kwargs kernel = await self._get_kernel() - return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason) + return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason, limit=limit) @Tool("memory_stats", description="获取长期记忆统计", parameters=[]) async def handle_memory_stats(self, **kwargs): @@ -202,6 +232,42 @@ class AMemorixPlugin(MaiBotPlugin): kernel = await self._get_kernel() return kernel.memory_stats() + @Tool("memory_graph_admin", description="长期记忆图谱管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_graph_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_graph_admin", action=action, **kwargs) + + @Tool("memory_source_admin", description="长期记忆来源管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_source_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_source_admin", action=action, **kwargs) + + @Tool("memory_episode_admin", description="Episode 管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_episode_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_episode_admin", action=action, **kwargs) + + @Tool("memory_profile_admin", description="人物画像管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_profile_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_profile_admin", action=action, **kwargs) + + @Tool("memory_runtime_admin", description="长期记忆运行时管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_runtime_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_runtime_admin", action=action, **kwargs) + + @Tool("memory_import_admin", description="长期记忆导入管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_import_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_import_admin", action=action, **kwargs) + + @Tool("memory_tuning_admin", description="长期记忆调优管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_tuning_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_tuning_admin", action=action, **kwargs) + + @Tool("memory_v5_admin", description="长期记忆 V5 管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_v5_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_v5_admin", action=action, **kwargs) + + @Tool("memory_delete_admin", 
description="长期记忆删除管理接口", parameters=_ADMIN_TOOL_PARAMS) + async def handle_memory_delete_admin(self, action: str, **kwargs): + return await self._dispatch_admin_tool("memory_delete_admin", action=action, **kwargs) + def create_plugin(): return AMemorixPlugin() diff --git a/plugins/A_memorix/requirements.txt b/plugins/A_memorix/requirements.txt new file mode 100644 index 00000000..f737fdf4 --- /dev/null +++ b/plugins/A_memorix/requirements.txt @@ -0,0 +1,52 @@ +# A_Memorix 插件依赖 +# +# 核心依赖 (必需) +# ================== + +# 数值计算 - 用于向量操作、矩阵计算 +numpy>=1.20.0 + +# 稀疏矩阵 - 用于图存储的邻接矩阵 +scipy>=1.7.0 + +# 图结构处理(LPMM 转换) +networkx>=3.0.0 + +# Parquet 读取(LPMM 转换) +pyarrow>=10.0.0 + +# DataFrame 处理(LPMM 转换) +pandas>=1.5.0 + +# 异步事件循环嵌套 - 用于插件初始化时的异步操作 +nest-asyncio>=1.5.0 + +# 向量索引 - 用于向量存储和检索 +faiss-cpu>=1.7.0 + +# Web 服务器依赖 (可视化功能需要) +# ================== + +# ASGI 服务器 +uvicorn>=0.20.0 + +# Web 框架 +fastapi>=0.100.0 + +# 数据验证 +pydantic>=2.0.0 +python-multipart>=0.0.9 + +# 注意事项 +# ================== +# +# 1. sqlite3 是 Python 标准库,无需安装 +# 2. json, re, time, pathlib 等都是标准库 +# 3. sentence-transformers 不需要(使用主程序 Embedding API) + +# UI 交互 +rich>=14.0.0 +tenacity>=8.0.0 + +# 稀疏检索中文分词(可选,未安装时自动回退 char n-gram) +jieba>=0.42.1 diff --git a/plugins/A_memorix/scripts/audit_vector_consistency.py b/plugins/A_memorix/scripts/audit_vector_consistency.py new file mode 100644 index 00000000..c97806dc --- /dev/null +++ b/plugins/A_memorix/scripts/audit_vector_consistency.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +A_Memorix 一致性审计脚本。 + +输出内容: +1. paragraph/entity/relation 向量覆盖率 +2. relation vector_state 分布 +3. 孤儿向量数量(向量存在但 metadata 不存在) +4. 状态与向量文件不一致统计 +""" + +from __future__ import annotations + +import argparse +import json +import pickle +import sys +from pathlib import Path +from typing import Any, Dict, Set + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PLUGIN_ROOT)) + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="审计 A_Memorix 向量一致性") + parser.add_argument( + "--data-dir", + default=str(PLUGIN_ROOT / "data"), + help="A_Memorix 数据目录(默认: plugins/A_memorix/data)", + ) + parser.add_argument("--json-out", default="", help="可选:输出 JSON 文件路径") + parser.add_argument( + "--strict", + action="store_true", + help="若发现一致性异常则返回非 0 退出码", + ) + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + sys.exit(0) + +try: + from core.storage.vector_store import VectorStore + from core.storage.metadata_store import MetadataStore + from core.storage import QuantizationType +except Exception as e: # pragma: no cover + print(f"❌ 导入核心模块失败: {e}") + sys.exit(1) + + +def _safe_ratio(numerator: int, denominator: int) -> float: + if denominator <= 0: + return 0.0 + return float(numerator) / float(denominator) + + +def _load_vector_store(data_dir: Path) -> VectorStore: + meta_path = data_dir / "vectors" / "vectors_metadata.pkl" + if not meta_path.exists(): + raise FileNotFoundError(f"未找到向量元数据文件: {meta_path}") + + with open(meta_path, "rb") as f: + meta = pickle.load(f) + dimension = int(meta.get("dimension", 1024)) + + store = VectorStore( + dimension=max(1, dimension), + quantization_type=QuantizationType.INT8, + data_dir=data_dir / "vectors", + ) + if store.has_data(): + store.load() + return store + + +def 
_load_metadata_store(data_dir: Path) -> MetadataStore: + store = MetadataStore(data_dir=data_dir / "metadata") + store.connect() + return store + + +def _hash_set(metadata_store: MetadataStore, table: str) -> Set[str]: + return {str(h) for h in metadata_store.list_hashes(table)} + + +def _relation_state_stats(metadata_store: MetadataStore) -> Dict[str, int]: + return metadata_store.count_relations_by_vector_state() + + +def run_audit(data_dir: Path) -> Dict[str, Any]: + vector_store = _load_vector_store(data_dir) + metadata_store = _load_metadata_store(data_dir) + try: + paragraph_hashes = _hash_set(metadata_store, "paragraphs") + entity_hashes = _hash_set(metadata_store, "entities") + relation_hashes = _hash_set(metadata_store, "relations") + + known_hashes = set(getattr(vector_store, "_known_hashes", set())) + live_vector_hashes = {h for h in known_hashes if h in vector_store} + + para_vector_hits = len(paragraph_hashes & live_vector_hashes) + ent_vector_hits = len(entity_hashes & live_vector_hashes) + rel_vector_hits = len(relation_hashes & live_vector_hashes) + + orphan_vector_hashes = sorted( + live_vector_hashes - paragraph_hashes - entity_hashes - relation_hashes + ) + + relation_rows = metadata_store.get_relations() + ready_but_missing = 0 + not_ready_but_present = 0 + for row in relation_rows: + h = str(row.get("hash") or "") + state = str(row.get("vector_state") or "none").lower() + in_vector = h in live_vector_hashes + if state == "ready" and not in_vector: + ready_but_missing += 1 + if state != "ready" and in_vector: + not_ready_but_present += 1 + + relation_states = _relation_state_stats(metadata_store) + rel_total = max(0, int(relation_states.get("total", len(relation_hashes)))) + ready_count = max(0, int(relation_states.get("ready", 0))) + + result = { + "counts": { + "paragraphs": len(paragraph_hashes), + "entities": len(entity_hashes), + "relations": len(relation_hashes), + "vectors_live": len(live_vector_hashes), + }, + "coverage": { + "paragraph_vector_coverage": _safe_ratio(para_vector_hits, len(paragraph_hashes)), + "entity_vector_coverage": _safe_ratio(ent_vector_hits, len(entity_hashes)), + "relation_vector_coverage": _safe_ratio(rel_vector_hits, len(relation_hashes)), + "relation_ready_coverage": _safe_ratio(ready_count, rel_total), + }, + "relation_states": relation_states, + "orphans": { + "vector_only_count": len(orphan_vector_hashes), + "vector_only_sample": orphan_vector_hashes[:30], + }, + "consistency_checks": { + "ready_but_missing_vector": ready_but_missing, + "not_ready_but_vector_present": not_ready_but_present, + }, + } + return result + finally: + metadata_store.close() + + +def main() -> int: + parser = _build_arg_parser() + args = parser.parse_args() + + data_dir = Path(args.data_dir).resolve() + if not data_dir.exists(): + print(f"❌ 数据目录不存在: {data_dir}") + return 2 + + try: + result = run_audit(data_dir) + except Exception as e: + print(f"❌ 审计失败: {e}") + return 2 + + print("=== A_Memorix Vector Consistency Audit ===") + print(f"data_dir: {data_dir}") + print(f"paragraphs: {result['counts']['paragraphs']}") + print(f"entities: {result['counts']['entities']}") + print(f"relations: {result['counts']['relations']}") + print(f"vectors_live: {result['counts']['vectors_live']}") + print( + "coverage: " + f"paragraph={result['coverage']['paragraph_vector_coverage']:.3f}, " + f"entity={result['coverage']['entity_vector_coverage']:.3f}, " + f"relation={result['coverage']['relation_vector_coverage']:.3f}, " + 
f"relation_ready={result['coverage']['relation_ready_coverage']:.3f}" + ) + print(f"relation_states: {result['relation_states']}") + print( + "consistency_checks: " + f"ready_but_missing_vector={result['consistency_checks']['ready_but_missing_vector']}, " + f"not_ready_but_vector_present={result['consistency_checks']['not_ready_but_vector_present']}" + ) + print(f"orphan_vectors: {result['orphans']['vector_only_count']}") + + if args.json_out: + out_path = Path(args.json_out).resolve() + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as f: + json.dump(result, f, ensure_ascii=False, indent=2) + print(f"json_out: {out_path}") + + has_anomaly = ( + result["orphans"]["vector_only_count"] > 0 + or result["consistency_checks"]["ready_but_missing_vector"] > 0 + ) + if args.strict and has_anomaly: + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/backfill_relation_vectors.py b/plugins/A_memorix/scripts/backfill_relation_vectors.py new file mode 100644 index 00000000..7ba0ade0 --- /dev/null +++ b/plugins/A_memorix/scripts/backfill_relation_vectors.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +""" +关系向量一次性回填脚本(灰度/离线执行)。 + +用途: +1. 对 relations 中 vector_state in (none, failed, pending) 的记录补齐向量。 +2. 支持并发控制,降低总耗时。 +3. 可作为灰度阶段验证工具,与 audit_vector_consistency.py 配合使用。 +4. 可选自动纳入“ready 但向量缺失”的漂移记录进行修复。 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +import time +from pathlib import Path +from typing import Any, Dict, List + +import tomlkit + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PLUGIN_ROOT)) + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="关系向量一次性回填") + parser.add_argument( + "--config", + default=str(PLUGIN_ROOT / "config.toml"), + help="配置文件路径(默认 plugins/A_memorix/config.toml)", + ) + parser.add_argument( + "--data-dir", + default=str(PLUGIN_ROOT / "data"), + help="数据目录(默认 plugins/A_memorix/data)", + ) + parser.add_argument( + "--states", + default="none,failed,pending", + help="待处理状态列表,逗号分隔", + ) + parser.add_argument("--limit", type=int, default=50000, help="最大处理数量") + parser.add_argument("--concurrency", type=int, default=8, help="并发数") + parser.add_argument("--max-retry", type=int, default=None, help="最大重试次数过滤") + parser.add_argument( + "--include-ready-missing", + action="store_true", + help="额外纳入 vector_state=ready 但向量缺失的关系", + ) + parser.add_argument("--dry-run", action="store_true", help="仅统计候选,不写入") + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + raise SystemExit(0) + +from core.storage import ( + VectorStore, + GraphStore, + MetadataStore, + QuantizationType, + SparseMatrixFormat, +) +from core.embedding import create_embedding_api_adapter +from core.utils.relation_write_service import RelationWriteService + + +def _load_config(config_path: Path) -> Dict[str, Any]: + with open(config_path, "r", encoding="utf-8") as f: + raw = tomlkit.load(f) + return dict(raw) if isinstance(raw, dict) else {} + + +def _build_vector_store(data_dir: Path, emb_cfg: Dict[str, Any]) -> VectorStore: + q_type = str(emb_cfg.get("quantization_type", "int8")).lower() + if q_type != "int8": + raise ValueError( + 
"embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + dim = int(emb_cfg.get("dimension", 1024)) + store = VectorStore( + dimension=max(1, dim), + quantization_type=QuantizationType.INT8, + data_dir=data_dir / "vectors", + ) + if store.has_data(): + store.load() + return store + + +def _build_graph_store(data_dir: Path, graph_cfg: Dict[str, Any]) -> GraphStore: + fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() + fmt_map = { + "csr": SparseMatrixFormat.CSR, + "csc": SparseMatrixFormat.CSC, + } + store = GraphStore( + matrix_format=fmt_map.get(fmt, SparseMatrixFormat.CSR), + data_dir=data_dir / "graph", + ) + if store.has_data(): + store.load() + return store + + +def _build_metadata_store(data_dir: Path) -> MetadataStore: + store = MetadataStore(data_dir=data_dir / "metadata") + store.connect() + return store + + +def _build_embedding_manager(emb_cfg: Dict[str, Any]): + retry_cfg = emb_cfg.get("retry", {}) + if not isinstance(retry_cfg, dict): + retry_cfg = {} + return create_embedding_api_adapter( + batch_size=int(emb_cfg.get("batch_size", 32)), + max_concurrent=int(emb_cfg.get("max_concurrent", 5)), + default_dimension=int(emb_cfg.get("dimension", 1024)), + model_name=str(emb_cfg.get("model_name", "auto")), + retry_config=retry_cfg, + ) + + +async def _process_rows( + service: RelationWriteService, + rows: List[Dict[str, Any]], + concurrency: int, +) -> Dict[str, int]: + semaphore = asyncio.Semaphore(max(1, int(concurrency))) + stat = {"success": 0, "failed": 0, "skipped": 0} + + async def _worker(row: Dict[str, Any]) -> None: + async with semaphore: + result = await service.ensure_relation_vector( + hash_value=str(row["hash"]), + subject=str(row.get("subject", "")), + predicate=str(row.get("predicate", "")), + obj=str(row.get("object", "")), + ) + if result.vector_state == "ready": + if result.vector_written: + stat["success"] += 1 + else: + stat["skipped"] += 1 + else: + stat["failed"] += 1 + + await asyncio.gather(*[_worker(row) for row in rows]) + return stat + + +async def main_async(args: argparse.Namespace) -> int: + config_path = Path(args.config).resolve() + if not config_path.exists(): + print(f"❌ 配置文件不存在: {config_path}") + return 2 + + cfg = _load_config(config_path) + emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} + graph_cfg = cfg.get("graph", {}) if isinstance(cfg, dict) else {} + retrieval_cfg = cfg.get("retrieval", {}) if isinstance(cfg, dict) else {} + rv_cfg = retrieval_cfg.get("relation_vectorization", {}) if isinstance(retrieval_cfg, dict) else {} + if not isinstance(emb_cfg, dict): + emb_cfg = {} + if not isinstance(graph_cfg, dict): + graph_cfg = {} + if not isinstance(rv_cfg, dict): + rv_cfg = {} + + data_dir = Path(args.data_dir).resolve() + if not data_dir.exists(): + print(f"❌ 数据目录不存在: {data_dir}") + return 2 + + print(f"data_dir: {data_dir}") + print(f"config: {config_path}") + + vector_store = _build_vector_store(data_dir, emb_cfg) + graph_store = _build_graph_store(data_dir, graph_cfg) + metadata_store = _build_metadata_store(data_dir) + embedding_manager = _build_embedding_manager(emb_cfg) + service = RelationWriteService( + metadata_store=metadata_store, + graph_store=graph_store, + vector_store=vector_store, + embedding_manager=embedding_manager, + ) + + try: + states = [s.strip() for s in str(args.states).split(",") if s.strip()] + if not states: + states = ["none", "failed", "pending"] + max_retry = int(args.max_retry) if args.max_retry is not None else 
int(rv_cfg.get("max_retry", 3)) + limit = int(args.limit) + + rows = metadata_store.list_relations_by_vector_state( + states=states, + limit=max(1, limit), + max_retry=max(1, max_retry), + ) + added_ready_missing = 0 + if args.include_ready_missing: + ready_rows = metadata_store.list_relations_by_vector_state( + states=["ready"], + limit=max(1, limit), + max_retry=max(1, max_retry), + ) + ready_missing_rows = [ + row for row in ready_rows if str(row.get("hash", "")) not in vector_store + ] + added_ready_missing = len(ready_missing_rows) + if ready_missing_rows: + dedup: Dict[str, Dict[str, Any]] = {} + for row in rows: + dedup[str(row.get("hash", ""))] = row + for row in ready_missing_rows: + dedup.setdefault(str(row.get("hash", "")), row) + rows = list(dedup.values())[: max(1, limit)] + print(f"candidates: {len(rows)} (states={states}, max_retry={max_retry})") + if args.include_ready_missing: + print(f"ready_missing_candidates_added: {added_ready_missing}") + if not rows: + return 0 + + if args.dry_run: + print("dry_run=true,未执行写入。") + return 0 + + started = time.time() + stat = await _process_rows( + service=service, + rows=rows, + concurrency=int(args.concurrency), + ) + elapsed = (time.time() - started) * 1000.0 + + vector_store.save() + graph_store.save() + state_stats = metadata_store.count_relations_by_vector_state() + output = { + "processed": len(rows), + "success": int(stat["success"]), + "failed": int(stat["failed"]), + "skipped": int(stat["skipped"]), + "elapsed_ms": elapsed, + "state_stats": state_stats, + } + print(json.dumps(output, ensure_ascii=False, indent=2)) + return 0 if stat["failed"] == 0 else 1 + finally: + metadata_store.close() + + +def parse_args() -> argparse.Namespace: + return _build_arg_parser().parse_args() + + +if __name__ == "__main__": + arguments = parse_args() + raise SystemExit(asyncio.run(main_async(arguments))) diff --git a/plugins/A_memorix/scripts/backfill_temporal_metadata.py b/plugins/A_memorix/scripts/backfill_temporal_metadata.py new file mode 100644 index 00000000..b68820cd --- /dev/null +++ b/plugins/A_memorix/scripts/backfill_temporal_metadata.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +""" +回填段落时序字段。 + +默认策略: +1. 若段落缺失 event_time/event_time_start/event_time_end +2. 且存在 created_at +3. 
写入 event_time=created_at, time_granularity=day, time_confidence=0.2 +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +import sys + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from plugins.A_memorix.core.storage import MetadataStore # noqa: E402 + + +def backfill( + data_dir: Path, + dry_run: bool, + limit: int, + no_created_fallback: bool, +) -> int: + store = MetadataStore(data_dir=data_dir) + store.connect() + summary = store.backfill_temporal_metadata_from_created_at( + limit=limit, + dry_run=dry_run, + no_created_fallback=no_created_fallback, + ) + store.close() + if dry_run: + print(f"[dry-run] candidates={summary['candidates']}") + return int(summary["candidates"]) + if no_created_fallback: + print(f"skip update (no-created-fallback), candidates={summary['candidates']}") + return 0 + print(f"updated={summary['updated']}") + return int(summary["updated"]) + + +def main() -> int: + parser = argparse.ArgumentParser(description="Backfill temporal metadata for A_Memorix paragraphs") + parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="数据目录") + parser.add_argument("--dry-run", action="store_true", help="仅统计,不写入") + parser.add_argument("--limit", type=int, default=100000, help="最大处理条数") + parser.add_argument( + "--no-created-fallback", + action="store_true", + help="不使用 created_at 回填,仅输出候选数量", + ) + args = parser.parse_args() + + backfill( + data_dir=Path(args.data_dir), + dry_run=args.dry_run, + limit=max(1, int(args.limit)), + no_created_fallback=args.no_created_fallback, + ) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) + diff --git a/plugins/A_memorix/scripts/convert_lpmm.py b/plugins/A_memorix/scripts/convert_lpmm.py index 5ff284fb..2ef0b396 100644 --- a/plugins/A_memorix/scripts/convert_lpmm.py +++ b/plugins/A_memorix/scripts/convert_lpmm.py @@ -46,9 +46,14 @@ if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): _build_arg_parser().print_help() sys.exit(0) -# 设置日志 -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") -logger = logging.getLogger("LPMM_Converter") +# 设置日志:优先复用 MaiBot 统一日志体系,失败时回退到标准 logging。 +try: + from src.common.logger import get_logger + + logger = get_logger("A_Memorix.LPMMConverter") +except Exception: + logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") + logger = logging.getLogger("A_Memorix.LPMMConverter") try: import networkx as nx @@ -225,11 +230,11 @@ class LPMMConverter: failed += 1 logger.info( - "关系向量重建完成: total=%s success=%s skipped=%s failed=%s", - len(rows), - success, - skipped, - failed, + "关系向量重建完成: " + f"total={len(rows)} " + f"success={success} " + f"skipped={skipped} " + f"failed={failed}" ) @staticmethod @@ -317,8 +322,8 @@ class LPMMConverter: if p_type == "relation": relation_count = self._import_relation_metadata_from_parquet(p_path) logger.warning( - "跳过 relation.parquet 向量导入(保持一致性);已导入关系元数据: %s", - relation_count, + "跳过 relation.parquet 向量导入(保持一致性);" + f"已导入关系元数据: {relation_count}" ) continue diff --git a/plugins/A_memorix/scripts/import_lpmm_json.py b/plugins/A_memorix/scripts/import_lpmm_json.py new file mode 100644 index 00000000..2e458e16 --- /dev/null +++ b/plugins/A_memorix/scripts/import_lpmm_json.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +""" +LPMM OpenIE JSON 导入工具。 + +功能: +1. 读取符合 LPMM 规范的 OpenIE JSON 文件 +2. 
转换为 A_Memorix 的统一导入格式 +3. 复用 `process_knowledge.py` 中的 `AutoImporter` 直接入库 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +import traceback +from pathlib import Path +from typing import Any, Dict, List + +from rich.console import Console +from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn + +console = Console() + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +for path in (CURRENT_DIR, WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="将 LPMM OpenIE JSON 导入 A_Memorix") + parser.add_argument("path", help="LPMM JSON 文件路径或目录") + parser.add_argument("--force", action="store_true", help="强制重新导入") + parser.add_argument("--concurrency", "-c", type=int, default=5, help="并发数") + return parser + + +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + raise SystemExit(0) + + +try: + from process_knowledge import AutoImporter + from A_memorix.core.utils.hash import compute_paragraph_hash + from src.common.logger import get_logger +except ImportError as exc: # pragma: no cover - script bootstrap + print(f"导入模块失败,请确认 PYTHONPATH 与工作区结构: {exc}") + raise SystemExit(1) + + +logger = get_logger("A_Memorix.LPMMImport") + + +class LPMMConverter: + def convert_lpmm_to_memorix(self, lpmm_data: Dict[str, Any], filename: str) -> Dict[str, Any]: + memorix_data = {"paragraphs": [], "entities": []} + docs = lpmm_data.get("docs", []) or [] + if not docs: + logger.warning(f"文件中未找到 docs 字段: {filename}") + return memorix_data + + all_entities = set() + for doc in docs: + content = str(doc.get("passage", "") or "").strip() + if not content: + continue + + relations: List[Dict[str, str]] = [] + for triple in doc.get("extracted_triples", []) or []: + if isinstance(triple, list) and len(triple) == 3: + relations.append( + { + "subject": str(triple[0] or "").strip(), + "predicate": str(triple[1] or "").strip(), + "object": str(triple[2] or "").strip(), + } + ) + + entities = [str(item or "").strip() for item in doc.get("extracted_entities", []) or [] if str(item or "").strip()] + all_entities.update(entities) + for relation in relations: + if relation["subject"]: + all_entities.add(relation["subject"]) + if relation["object"]: + all_entities.add(relation["object"]) + + memorix_data["paragraphs"].append( + { + "hash": compute_paragraph_hash(content), + "content": content, + "source": filename, + "entities": entities, + "relations": relations, + } + ) + + memorix_data["entities"] = sorted(all_entities) + return memorix_data + + +async def main() -> None: + parser = _build_arg_parser() + args = parser.parse_args() + + target_path = Path(args.path) + if not target_path.exists(): + logger.error(f"路径不存在: {target_path}") + return + + if target_path.is_dir(): + files_to_process = list(target_path.glob("*-openie.json")) or list(target_path.glob("*.json")) + else: + files_to_process = [target_path] + + if not files_to_process: + logger.error("未找到可处理的 JSON 文件") + return + + importer = AutoImporter(force=bool(args.force), concurrency=int(args.concurrency)) + if not await importer.initialize(): + logger.error("初始化存储失败") + return + + converter = LPMMConverter() + with Progress( + SpinnerColumn(), + 
TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + console=console, + transient=False, + ) as progress: + for json_file in files_to_process: + logger.info(f"正在转换并导入: {json_file.name}") + try: + with open(json_file, "r", encoding="utf-8") as handle: + lpmm_data = json.load(handle) + memorix_data = converter.convert_lpmm_to_memorix(lpmm_data, json_file.name) + total_items = len(memorix_data.get("paragraphs", [])) + if total_items <= 0: + logger.warning(f"转换结果为空: {json_file.name}") + continue + + task_id = progress.add_task(f"Importing {json_file.name}", total=total_items) + + def update_progress(step: int = 1) -> None: + progress.advance(task_id, advance=step) + + await importer.import_json_data( + memorix_data, + filename=f"lpmm_{json_file.name}", + progress_callback=update_progress, + ) + except Exception as exc: + logger.error(f"处理文件 {json_file.name} 失败: {exc}\n{traceback.format_exc()}") + + await importer.close() + logger.info("全部处理完成") + + +if __name__ == "__main__": + if sys.platform == "win32": # pragma: no cover + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.run(main()) diff --git a/plugins/A_memorix/scripts/migrate_maibot_memory.py b/plugins/A_memorix/scripts/migrate_maibot_memory.py new file mode 100644 index 00000000..0b26a9cd --- /dev/null +++ b/plugins/A_memorix/scripts/migrate_maibot_memory.py @@ -0,0 +1,1714 @@ +#!/usr/bin/env python3 +""" +MaiBot 记忆迁移脚本(chat_history -> A_memorix) + +特性: +1. 高性能:分页读取 + 批量 embedding + 批量写入 +2. 断点续传:基于 last_committed_id 的窗口提交 +3. 精确一次语义:稳定哈希 + 幂等写入 + 向量存在性检查 +4. 可确认筛选:支持时间区间、聊天流(stream/group/user)筛选,并先预览后确认 +""" + +from __future__ import annotations + +import argparse +import asyncio +import hashlib +import importlib +import json +import logging +import os +import pickle +import sqlite3 +import sys +import time +import traceback +import types +from collections import defaultdict +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple + +import numpy as np +import tomlkit + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +RUNTIME_CORE_PACKAGE = "_a_memorix_runtime_core" + +VectorStore = None +GraphStore = None +MetadataStore = None +create_embedding_api_adapter = None +KnowledgeType = None +QuantizationType = None +SparseMatrixFormat = None +compute_hash = None +normalize_text = None +atomic_write = None +model_config = None +RelationWriteService = None + + +def _create_bootstrap_logger(): + fallback = logging.getLogger("A_Memorix.MaiBotMigration") + if not fallback.handlers: + fallback.addHandler(logging.NullHandler()) + try: + for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + from src.common.logger import get_logger + + return get_logger("A_Memorix.MaiBotMigration") + except Exception: + return fallback + + +logger = _create_bootstrap_logger() + + +def _ensure_import_paths() -> None: + for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _ensure_runtime_core_package() -> str: + existing = sys.modules.get(RUNTIME_CORE_PACKAGE) + if existing is not None and hasattr(existing, "__path__"): + 
return RUNTIME_CORE_PACKAGE + + pkg = types.ModuleType(RUNTIME_CORE_PACKAGE) + pkg.__path__ = [str(PLUGIN_ROOT / "core")] + pkg.__package__ = RUNTIME_CORE_PACKAGE + sys.modules[RUNTIME_CORE_PACKAGE] = pkg + return RUNTIME_CORE_PACKAGE + + +def _disable_unavailable_gemini_provider() -> None: + global model_config + try: + from google import genai # type: ignore # noqa: F401 + return + except Exception: + pass + + from src.config.config import model_config as loaded_model_config + + providers = list(getattr(loaded_model_config, "api_providers", [])) + if not providers: + model_config = loaded_model_config + return + + kept_providers = [p for p in providers if str(getattr(p, "client_type", "")).lower() != "gemini"] + if len(kept_providers) == len(providers): + model_config = loaded_model_config + return + + loaded_model_config.api_providers = kept_providers + loaded_model_config.api_providers_dict = {p.name: p for p in kept_providers} + + models = list(getattr(loaded_model_config, "models", [])) + kept_models = [m for m in models if m.api_provider in loaded_model_config.api_providers_dict] + loaded_model_config.models = kept_models + loaded_model_config.models_dict = {m.name: m for m in kept_models} + + task_cfg = loaded_model_config.model_task_config + for field_name in task_cfg.__dataclass_fields__.keys(): + task = getattr(task_cfg, field_name, None) + if task is None or not hasattr(task, "model_list"): + continue + task.model_list = [m for m in list(task.model_list) if m in loaded_model_config.models_dict] + + model_config = loaded_model_config + logger.warning("检测到缺少 google.genai,已临时禁用 gemini provider 以保证脚本可运行。") + + +def _bootstrap_runtime_symbols() -> None: + global VectorStore + global GraphStore + global MetadataStore + global KnowledgeType + global QuantizationType + global SparseMatrixFormat + global compute_hash + global normalize_text + global atomic_write + global RelationWriteService + global logger + + if VectorStore is not None and compute_hash is not None and atomic_write is not None: + return + + _ensure_import_paths() + + import src # noqa: F401 + from src.common.logger import get_logger + + logger = get_logger("A_Memorix.MaiBotMigration") + + pkg = _ensure_runtime_core_package() + + vector_store_module = importlib.import_module(f"{pkg}.storage.vector_store") + graph_store_module = importlib.import_module(f"{pkg}.storage.graph_store") + metadata_store_module = importlib.import_module(f"{pkg}.storage.metadata_store") + knowledge_types_module = importlib.import_module(f"{pkg}.storage.knowledge_types") + hash_module = importlib.import_module(f"{pkg}.utils.hash") + io_module = importlib.import_module(f"{pkg}.utils.io") + relation_write_service_module = importlib.import_module(f"{pkg}.utils.relation_write_service") + + VectorStore = vector_store_module.VectorStore + GraphStore = graph_store_module.GraphStore + MetadataStore = metadata_store_module.MetadataStore + KnowledgeType = knowledge_types_module.KnowledgeType + QuantizationType = vector_store_module.QuantizationType + SparseMatrixFormat = graph_store_module.SparseMatrixFormat + compute_hash = hash_module.compute_hash + normalize_text = hash_module.normalize_text + atomic_write = io_module.atomic_write + RelationWriteService = relation_write_service_module.RelationWriteService + + +def _load_embedding_adapter_factory() -> None: + global create_embedding_api_adapter + global model_config + + if create_embedding_api_adapter is not None: + return + + _ensure_import_paths() + + from src.config.config import model_config as 
loaded_model_config + + model_config = loaded_model_config + _disable_unavailable_gemini_provider() + + pkg = _ensure_runtime_core_package() + api_adapter_module = importlib.import_module(f"{pkg}.embedding.api_adapter") + create_embedding_api_adapter = api_adapter_module.create_embedding_api_adapter + + +DEFAULT_SOURCE_DB = MAIBOT_ROOT / "data" / "MaiBot.db" +DEFAULT_TARGET_DATA_DIR = PLUGIN_ROOT / "data" +DEFAULT_CONFIG_PATH = PLUGIN_ROOT / "config.toml" + +MIGRATION_STATE_DIRNAME = "migration_state" +STATE_FILENAME = "chat_history_resume.json" +BAD_ROWS_FILENAME = "chat_history_bad_rows.jsonl" +REPORT_FILENAME = "chat_history_report.json" + + +class MigrationError(Exception): + """迁移流程错误。""" + + +@dataclass +class SelectionFilter: + time_from_ts: Optional[float] + time_to_ts: Optional[float] + stream_ids: List[str] + stream_filter_requested: bool + start_id: Optional[int] + end_id: Optional[int] + time_from_raw: Optional[str] + time_to_raw: Optional[str] + + def fingerprint_payload(self) -> Dict[str, Any]: + return { + "time_from_ts": self.time_from_ts, + "time_to_ts": self.time_to_ts, + "time_from_raw": self.time_from_raw, + "time_to_raw": self.time_to_raw, + "stream_ids": sorted(self.stream_ids), + "stream_filter_requested": self.stream_filter_requested, + "start_id": self.start_id, + "end_id": self.end_id, + } + + +@dataclass +class PreviewResult: + total: int + distribution: List[Tuple[str, int]] + samples: List[Dict[str, Any]] + + +@dataclass +class MappedRow: + row_id: int + chat_id: str + paragraph_hash: str + content: str + source: str + time_meta: Dict[str, Any] + entities: List[str] + relations: List[Tuple[str, str, str]] + existing_paragraph_vector: bool + + +def _safe_int(value: Any, default: int) -> int: + try: + return int(value) + except Exception: + return default + + +def _safe_float(value: Any, default: float) -> float: + try: + return float(value) + except Exception: + return default + + +def _normalize_name(value: Any) -> str: + return str(value or "").strip() + + +def _canonical_name(value: Any) -> str: + return _normalize_name(value).lower() + + +def _dedup_keep_order(items: Iterable[str]) -> List[str]: + out: List[str] = [] + seen: set[str] = set() + for raw in items: + v = _normalize_name(raw) + if not v: + continue + k = v.lower() + if k in seen: + continue + seen.add(k) + out.append(v) + return out + + +def _format_ts(ts: Optional[float]) -> str: + if ts is None: + return "-" + try: + return datetime.fromtimestamp(float(ts)).strftime("%Y-%m-%d %H:%M:%S") + except Exception: + return str(ts) + + +def _parse_cli_datetime(text: str, is_end: bool = False) -> float: + value = str(text or "").strip() + if not value: + raise ValueError("时间不能为空") + + formats = [ + ("%Y-%m-%d %H:%M:%S", False), + ("%Y/%m/%d %H:%M:%S", False), + ("%Y-%m-%d %H:%M", False), + ("%Y/%m/%d %H:%M", False), + ("%Y-%m-%d", True), + ("%Y/%m/%d", True), + ] + + for fmt, is_date_only in formats: + try: + dt = datetime.strptime(value, fmt) + if is_date_only and is_end: + dt = dt.replace(hour=23, minute=59, second=59, microsecond=0) + return dt.timestamp() + except ValueError: + continue + + raise ValueError( + f"时间格式错误: {value},仅支持 YYYY-MM-DD、YYYY/MM/DD、YYYY-MM-DD HH:mm[:ss]、YYYY/MM/DD HH:mm[:ss]" + ) + + +def _json_hash(payload: Dict[str, Any]) -> str: + data = json.dumps(payload, ensure_ascii=False, sort_keys=True) + return hashlib.sha1(data.encode("utf-8")).hexdigest() + + +def _deep_merge_dict(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: + out = dict(base) + for key, value 
in override.items(): + if isinstance(value, dict) and isinstance(out.get(key), dict): + out[key] = _deep_merge_dict(out[key], value) + else: + out[key] = value + return out + + +def _extract_schema_defaults(schema_obj: Dict[str, Any]) -> Dict[str, Any]: + defaults: Dict[str, Any] = {} + if not isinstance(schema_obj, dict): + return defaults + + for key, spec in schema_obj.items(): + if not isinstance(spec, dict): + continue + if "default" in spec: + defaults[key] = spec.get("default") + continue + props = spec.get("properties") + if isinstance(props, dict): + defaults[key] = _extract_schema_defaults(props) + return defaults + + +def _load_manifest_defaults() -> Dict[str, Any]: + manifest_path = PLUGIN_ROOT / "_manifest.json" + if not manifest_path.exists(): + return {} + try: + with open(manifest_path, "r", encoding="utf-8") as f: + payload = json.load(f) + schema = payload.get("config_schema") + if isinstance(schema, dict): + return _extract_schema_defaults(schema) + except Exception as e: + logger.warning(f"读取 manifest 默认配置失败,已回退空配置: {e}") + return {} + + +def _build_source_db_fingerprint(db_path: Path) -> Dict[str, Any]: + stat = db_path.stat() + payload = { + "path": str(db_path.resolve()), + "size": stat.st_size, + "mtime": stat.st_mtime, + } + payload["sha1"] = _json_hash(payload) + return payload + + +def _state_path(target_data_dir: Path) -> Path: + return target_data_dir / MIGRATION_STATE_DIRNAME / STATE_FILENAME + + +def _bad_rows_path(target_data_dir: Path) -> Path: + return target_data_dir / MIGRATION_STATE_DIRNAME / BAD_ROWS_FILENAME + + +def _report_path(target_data_dir: Path) -> Path: + return target_data_dir / MIGRATION_STATE_DIRNAME / REPORT_FILENAME + + +def _dump_json_atomic(path: Path, payload: Dict[str, Any]) -> None: + if atomic_write is None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp = path.with_suffix(path.suffix + ".tmp") + with open(tmp, "w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + f.write("\n") + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + return + + with atomic_write(path, mode="w", encoding="utf-8") as f: + json.dump(payload, f, ensure_ascii=False, indent=2) + f.write("\n") + + +class SourceDB: + def __init__(self, db_path: Path): + self.db_path = db_path + self.conn: Optional[sqlite3.Connection] = None + + def connect(self) -> None: + if not self.db_path.exists(): + raise MigrationError(f"源数据库不存在: {self.db_path}") + + uri = f"file:{self.db_path.resolve().as_posix()}?mode=ro" + try: + self.conn = sqlite3.connect(uri, uri=True, check_same_thread=False) + except sqlite3.OperationalError: + self.conn = sqlite3.connect(str(self.db_path.resolve()), check_same_thread=False) + + self.conn.row_factory = sqlite3.Row + pragmas = [ + "PRAGMA query_only = ON", + "PRAGMA cache_size = -128000", + "PRAGMA temp_store = MEMORY", + "PRAGMA synchronous = OFF", + "PRAGMA journal_mode = WAL", + ] + for sql in pragmas: + try: + self.conn.execute(sql) + except sqlite3.OperationalError: + # 部分 PRAGMA 在 mode=ro 下会失败,不影响只读扫描能力 + continue + + def close(self) -> None: + if self.conn is not None: + self.conn.close() + self.conn = None + + def _require_conn(self) -> sqlite3.Connection: + if self.conn is None: + raise MigrationError("源数据库尚未连接") + return self.conn + + def resolve_stream_ids( + self, + stream_ids: Sequence[str], + group_ids: Sequence[str], + user_ids: Sequence[str], + ) -> List[str]: + conn = self._require_conn() + resolved: set[str] = set(_normalize_name(x) for x in stream_ids if _normalize_name(x)) + 
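# 未指定 --group-id/--user-id 时无需查询 chat_streams 表,直接返回去重后的 stream_id 列表。 + 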
has_group_or_user = any(_normalize_name(x) for x in group_ids) or any(_normalize_name(x) for x in user_ids) + if not has_group_or_user: + return sorted(resolved) + + table_exists = conn.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name='chat_streams' LIMIT 1" + ).fetchone() + if table_exists is None: + raise MigrationError("源库缺少 chat_streams 表,无法根据 --group-id/--user-id 映射 stream_id") + + def _select_by_field(field: str, values: Sequence[str]) -> None: + values_norm = [_normalize_name(v) for v in values if _normalize_name(v)] + if not values_norm: + return + placeholders = ",".join("?" for _ in values_norm) + sql = f"SELECT DISTINCT stream_id FROM chat_streams WHERE {field} IN ({placeholders})" + cur = conn.execute(sql, tuple(values_norm)) + for row in cur.fetchall(): + sid = _normalize_name(row["stream_id"]) + if sid: + resolved.add(sid) + + _select_by_field("group_id", group_ids) + _select_by_field("user_id", user_ids) + return sorted(resolved) + + @staticmethod + def _build_where( + selection: SelectionFilter, + start_after_id: Optional[int] = None, + ) -> Tuple[str, List[Any]]: + conditions: List[str] = [] + params: List[Any] = [] + + if selection.start_id is not None: + conditions.append("id >= ?") + params.append(selection.start_id) + if selection.end_id is not None: + conditions.append("id <= ?") + params.append(selection.end_id) + if start_after_id is not None: + conditions.append("id > ?") + params.append(start_after_id) + + if selection.stream_ids: + placeholders = ",".join("?" for _ in selection.stream_ids) + conditions.append(f"chat_id IN ({placeholders})") + params.extend(selection.stream_ids) + elif selection.stream_filter_requested: + conditions.append("1=0") + + if selection.time_from_ts is not None and selection.time_to_ts is not None: + conditions.append("(end_time >= ? AND start_time <= ?)") + params.extend([selection.time_from_ts, selection.time_to_ts]) + elif selection.time_from_ts is not None: + conditions.append("(end_time >= ?)") + params.append(selection.time_from_ts) + elif selection.time_to_ts is not None: + conditions.append("(start_time <= ?)") + params.append(selection.time_to_ts) + + where_sql = "WHERE " + " AND ".join(conditions) if conditions else "" + return where_sql, params + + def count_candidates(self, selection: SelectionFilter) -> int: + conn = self._require_conn() + where_sql, params = self._build_where(selection, start_after_id=None) + sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" + cur = conn.execute(sql, tuple(params)) + return int(cur.fetchone()["c"]) + + def preview(self, selection: SelectionFilter, preview_limit: int) -> PreviewResult: + conn = self._require_conn() + where_sql, params = self._build_where(selection, start_after_id=None) + + total_sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" + total = int(conn.execute(total_sql, tuple(params)).fetchone()["c"]) + + dist_sql = ( + f"SELECT chat_id, COUNT(*) AS c FROM chat_history {where_sql} " + "GROUP BY chat_id ORDER BY c DESC LIMIT 30" + ) + distribution = [ + (_normalize_name(row["chat_id"]), int(row["c"])) + for row in conn.execute(dist_sql, tuple(params)).fetchall() + ] + + sample_sql = ( + "SELECT id, chat_id, start_time, end_time, theme, summary " + f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" 
+ ) + sample_params = list(params) + sample_params.append(max(1, int(preview_limit))) + samples = [dict(row) for row in conn.execute(sample_sql, tuple(sample_params)).fetchall()] + + return PreviewResult(total=total, distribution=distribution, samples=samples) + + def iter_rows( + self, + selection: SelectionFilter, + batch_size: int, + start_after_id: int, + ) -> Generator[List[sqlite3.Row], None, None]: + conn = self._require_conn() + cursor = int(start_after_id) + while True: + where_sql, params = self._build_where(selection, start_after_id=cursor) + sql = ( + "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " + f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" + ) + bind = list(params) + bind.append(max(1, int(batch_size))) + rows = conn.execute(sql, tuple(bind)).fetchall() + if not rows: + break + yield rows + cursor = int(rows[-1]["id"]) + + def sample_rows_for_verify( + self, + selection: SelectionFilter, + sample_size: int, + ) -> List[sqlite3.Row]: + conn = self._require_conn() + where_sql, params = self._build_where(selection, start_after_id=None) + sql = ( + "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " + f"FROM chat_history {where_sql} ORDER BY RANDOM() LIMIT ?" + ) + bind = list(params) + bind.append(max(1, int(sample_size))) + return conn.execute(sql, tuple(bind)).fetchall() + + +class MigrationRunner: + def __init__(self, args: argparse.Namespace): + self.args = args + self.source_db_path = Path(args.source_db).resolve() + self.target_data_dir = Path(args.target_data_dir).resolve() + self.state_file = _state_path(self.target_data_dir) + self.bad_rows_file = _bad_rows_path(self.target_data_dir) + self.report_file = _report_path(self.target_data_dir) + + self.source_db = SourceDB(self.source_db_path) + + self.vector_store = None + self.graph_store = None + self.metadata_store = None + self.embedding_manager = None + self.relation_write_service = None + self.plugin_config: Dict[str, Any] = {} + self.embed_workers: int = 5 + + self.selection: Optional[SelectionFilter] = None + self.filter_fingerprint: str = "" + self.source_db_fingerprint: Dict[str, Any] = {} + self.source_db_fingerprint_hash: str = "" + self.state: Dict[str, Any] = {} + + self.started_at = time.time() + self.exit_code = 0 + self.failed = False + self.fail_reason: Optional[str] = None + + self.stats: Dict[str, Any] = { + "source_matched_total": 0, + "scanned_rows": 0, + "valid_rows": 0, + "migrated_rows": 0, + "skipped_existing_rows": 0, + "bad_rows": 0, + "paragraph_vectors_added": 0, + "entity_vectors_added": 0, + "relations_written": 0, + "relation_vectors_written": 0, + "relation_vectors_failed": 0, + "relation_vectors_skipped": 0, + "graph_edges_written": 0, + "windows_committed": 0, + "last_committed_id": 0, + "verify_sample_size": 0, + "verify_paragraph_missing": 0, + "verify_vector_missing": 0, + "verify_relation_missing": 0, + "verify_edge_missing": 0, + "verify_passed": False, + } + + async def run(self) -> int: + try: + _bootstrap_runtime_symbols() + self._prepare_paths() + + self.source_db.connect() + self.selection = self._build_selection_filter() + self.filter_fingerprint = _json_hash(self.selection.fingerprint_payload()) + + self.source_db_fingerprint = _build_source_db_fingerprint(self.source_db_path) + self.source_db_fingerprint_hash = str(self.source_db_fingerprint.get("sha1", "")) + + preview = self.source_db.preview(self.selection, preview_limit=self.args.preview_limit) + self.stats["source_matched_total"] = 
int(preview.total) + self._print_preview(preview) + + if preview.total <= 0: + logger.info("筛选后无数据,退出。") + self.stats["verify_passed"] = True + if self.args.verify_only: + self._load_plugin_config() + await self._init_target_stores(require_embedding=False) + await self._verify(strict=True) + return self._finalize() + + if self.args.verify_only: + self._load_plugin_config() + await self._init_target_stores(require_embedding=False) + await self._verify(strict=True) + return self._finalize() + + if self.args.dry_run: + logger.info("dry-run 模式:仅预览,不写入。") + return self._finalize() + + if not self.args.yes: + if not self._confirm(): + logger.info("用户取消执行。") + return self._finalize() + + self._load_plugin_config() + await self._init_target_stores(require_embedding=True) + self._load_or_init_state() + + start_after_id = self._resolve_start_after_id() + await self._migrate(start_after_id=start_after_id) + await self._verify(strict=True) + return self._finalize() + except Exception as e: + self.failed = True + self.fail_reason = str(e) + logger.error(f"迁移失败: {e}\n{traceback.format_exc()}") + return self._finalize() + finally: + self._close() + + def _prepare_paths(self) -> None: + (self.target_data_dir / MIGRATION_STATE_DIRNAME).mkdir(parents=True, exist_ok=True) + if self.args.reset_state and self.state_file.exists(): + self.state_file.unlink() + if self.args.reset_state and self.bad_rows_file.exists(): + self.bad_rows_file.unlink() + + def _load_plugin_config(self) -> None: + merged = _load_manifest_defaults() + + config_path = DEFAULT_CONFIG_PATH + if config_path.exists(): + try: + with open(config_path, "r", encoding="utf-8") as f: + raw = tomlkit.load(f) + if isinstance(raw, dict): + merged = _deep_merge_dict(merged, dict(raw)) + except Exception as e: + logger.warning(f"读取插件配置失败,继续使用默认配置: {e}") + + self.plugin_config = merged + + def _read_existing_vector_dimension(self, fallback_dimension: int) -> int: + meta_path = self.target_data_dir / "vectors" / "vectors_metadata.pkl" + if not meta_path.exists(): + return fallback_dimension + try: + with open(meta_path, "rb") as f: + payload = pickle.load(f) + value = _safe_int(payload.get("dimension"), fallback_dimension) + return max(1, value) + except Exception: + return fallback_dimension + + async def _init_target_stores(self, require_embedding: bool) -> None: + if VectorStore is None or GraphStore is None or MetadataStore is None: + raise MigrationError("运行时初始化失败:存储组件不可用") + + emb_cfg = self.plugin_config.get("embedding", {}) if isinstance(self.plugin_config, dict) else {} + graph_cfg = self.plugin_config.get("graph", {}) if isinstance(self.plugin_config, dict) else {} + + self.embed_workers = max(1, _safe_int(self.args.embed_workers, _safe_int(emb_cfg.get("max_concurrent"), 5))) + emb_batch_size = max(1, _safe_int(emb_cfg.get("batch_size"), 32)) + emb_default_dim = max(1, _safe_int(emb_cfg.get("dimension"), 1024)) + emb_model_name = str(emb_cfg.get("model_name", "auto")) + emb_retry = emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {} + + if require_embedding: + _load_embedding_adapter_factory() + if create_embedding_api_adapter is None: + raise MigrationError("运行时初始化失败:embedding 适配器不可用") + + if model_config is not None: + embedding_task = getattr(getattr(model_config, "model_task_config", None), "embedding", None) + if embedding_task is not None and hasattr(embedding_task, "model_list"): + if not list(embedding_task.model_list): + raise MigrationError( + "当前配置没有可用 embedding 模型。若你使用 gemini provider,请先安装 `google-genai` " 
+ "或切换到可用的 embedding provider。" + ) + + self.embedding_manager = create_embedding_api_adapter( + batch_size=emb_batch_size, + max_concurrent=self.embed_workers, + default_dimension=emb_default_dim, + model_name=emb_model_name, + retry_config=emb_retry, + ) + + try: + detected_dim = self._read_existing_vector_dimension(emb_default_dim) + has_existing_vectors = (self.target_data_dir / "vectors" / "vectors_metadata.pkl").exists() + if not has_existing_vectors: + detected_dim = await self.embedding_manager._detect_dimension() + except Exception as e: + logger.warning(f"嵌入维度探测失败,回退配置维度: {e}") + detected_dim = self._read_existing_vector_dimension(emb_default_dim) + else: + detected_dim = self._read_existing_vector_dimension(emb_default_dim) + self.embedding_manager = None + + q_type = str(emb_cfg.get("quantization_type", "int8")).lower() + if q_type != "int8": + raise MigrationError( + "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + quantization = QuantizationType.INT8 + + matrix_fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() + fmt_map = { + "csr": SparseMatrixFormat.CSR, + "csc": SparseMatrixFormat.CSC, + } + sparse_fmt = fmt_map.get(matrix_fmt, SparseMatrixFormat.CSR) + + self.vector_store = VectorStore( + dimension=detected_dim, + quantization_type=quantization, + data_dir=self.target_data_dir / "vectors", + ) + self.graph_store = GraphStore( + matrix_format=sparse_fmt, + data_dir=self.target_data_dir / "graph", + ) + self.metadata_store = MetadataStore(data_dir=self.target_data_dir / "metadata") + self.metadata_store.connect() + + if self.vector_store.has_data(): + self.vector_store.load() + if self.graph_store.has_data(): + self.graph_store.load() + + self.relation_write_service = None + if require_embedding and RelationWriteService is not None and self.embedding_manager is not None: + self.relation_write_service = RelationWriteService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + + logger.info( + f"目标存储初始化完成: dim={self.vector_store.dimension}, quant={q_type}, graph_fmt={matrix_fmt}, " + f"embed_workers={self.embed_workers}" + ) + + def _should_write_relation_vectors(self) -> bool: + retrieval_cfg = self.plugin_config.get("retrieval", {}) if isinstance(self.plugin_config, dict) else {} + if not isinstance(retrieval_cfg, dict): + return False + rv_cfg = retrieval_cfg.get("relation_vectorization", {}) + if not isinstance(rv_cfg, dict): + return False + return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) + + async def _ensure_relation_vectors_for_records( + self, + relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]], + ) -> None: + if not relation_records: + return + if self.relation_write_service is None: + return + + success = 0 + failed = 0 + skipped = 0 + for relation_hash, rel in relation_records.items(): + result = await self.relation_write_service.ensure_relation_vector( + hash_value=relation_hash, + subject=str(rel[0]), + predicate=str(rel[1]), + obj=str(rel[2]), + ) + if result.vector_state == "ready": + if result.vector_written: + success += 1 + else: + skipped += 1 + else: + failed += 1 + + self.stats["relation_vectors_written"] += success + self.stats["relation_vectors_failed"] += failed + self.stats["relation_vectors_skipped"] += skipped + + def _build_selection_filter(self) -> SelectionFilter: + if self.args.start_id is not None and 
self.args.start_id <= 0: + raise MigrationError("--start-id 必须 > 0") + if self.args.end_id is not None and self.args.end_id <= 0: + raise MigrationError("--end-id 必须 > 0") + if self.args.start_id is not None and self.args.end_id is not None and self.args.start_id > self.args.end_id: + raise MigrationError("--start-id 不能大于 --end-id") + + time_from_ts = _parse_cli_datetime(self.args.time_from, is_end=False) if self.args.time_from else None + time_to_ts = _parse_cli_datetime(self.args.time_to, is_end=True) if self.args.time_to else None + if time_from_ts is not None and time_to_ts is not None and time_from_ts > time_to_ts: + raise MigrationError("--time-from 不能晚于 --time-to") + + stream_filter_requested = bool( + (self.args.stream_id or []) or (self.args.group_id or []) or (self.args.user_id or []) + ) + stream_ids = self.source_db.resolve_stream_ids( + stream_ids=self.args.stream_id or [], + group_ids=self.args.group_id or [], + user_ids=self.args.user_id or [], + ) + if stream_filter_requested and not stream_ids: + logger.warning("已指定 stream/group/user 筛选,但未解析到任何 stream_id,结果将为空。") + + logger.info( + f"筛选条件: time_from={self.args.time_from or '-'}, time_to={self.args.time_to or '-'}, " + f"stream_ids={len(stream_ids)}, stream_filter_requested={stream_filter_requested}" + ) + + return SelectionFilter( + time_from_ts=time_from_ts, + time_to_ts=time_to_ts, + stream_ids=stream_ids, + stream_filter_requested=stream_filter_requested, + start_id=self.args.start_id, + end_id=self.args.end_id, + time_from_raw=self.args.time_from, + time_to_raw=self.args.time_to, + ) + + def _load_or_init_state(self) -> None: + if self.args.start_id is not None: + logger.info("检测到 --start-id,已按用户指定起点覆盖断点状态。") + self.state = self._new_state(last_committed_id=int(self.args.start_id) - 1) + return + + if self.args.no_resume: + self.state = self._new_state(last_committed_id=0) + return + + if not self.state_file.exists(): + self.state = self._new_state(last_committed_id=0) + return + + with open(self.state_file, "r", encoding="utf-8") as f: + loaded = json.load(f) + + loaded_filter_fp = str(loaded.get("filter_fingerprint", "")) + loaded_source_fp = str(loaded.get("source_db_fingerprint", "")) + + if loaded_filter_fp != self.filter_fingerprint or loaded_source_fp != self.source_db_fingerprint_hash: + if self.args.dry_run or self.args.verify_only: + logger.info("检测到断点与当前筛选不一致;当前为只读模式,将忽略旧断点。") + self.state = self._new_state(last_committed_id=0) + return + raise MigrationError( + "检测到筛选条件或源库指纹变化,已拒绝继续续传。请使用 --reset-state 或调整参数后重试。" + ) + + self.state = loaded + stored_stats = loaded.get("stats", {}) + if isinstance(stored_stats, dict): + for k, v in stored_stats.items(): + if k in self.stats and isinstance(v, (int, float, bool)): + self.stats[k] = v + + def _new_state(self, last_committed_id: int) -> Dict[str, Any]: + return { + "version": 1, + "updated_at": time.time(), + "last_committed_id": int(last_committed_id), + "filter_fingerprint": self.filter_fingerprint, + "source_db_fingerprint": self.source_db_fingerprint_hash, + "source_db_meta": self.source_db_fingerprint, + "stats": dict(self.stats), + } + + def _flush_state(self, last_committed_id: int) -> None: + self.stats["last_committed_id"] = int(last_committed_id) + self.state = { + "version": 1, + "updated_at": time.time(), + "last_committed_id": int(last_committed_id), + "filter_fingerprint": self.filter_fingerprint, + "source_db_fingerprint": self.source_db_fingerprint_hash, + "source_db_meta": self.source_db_fingerprint, + "stats": dict(self.stats), + } + 
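# 断点状态文件(chat_history_resume.json)内容示例,字段即上方 state 字典,数值仅作示意: + # {"version": 1, "updated_at": 1760000000.0, "last_committed_id": 20000, + #  "filter_fingerprint": "3f9c...", "source_db_fingerprint": "a41b...", + #  "source_db_meta": {...}, "stats": {"scanned_rows": 20000, "windows_committed": 1, ...}} + 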
_dump_json_atomic(self.state_file, self.state) + + def _resolve_start_after_id(self) -> int: + if self.selection is None: + raise MigrationError("selection 未初始化") + + if self.args.start_id is not None: + return int(self.args.start_id) - 1 + + if self.args.no_resume: + return 0 + + state_last = _safe_int(self.state.get("last_committed_id"), 0) if self.state else 0 + return max(0, state_last) + + def _print_preview(self, preview: PreviewResult) -> None: + print("\n=== Migration Preview ===") + print(f"source_db: {self.source_db_path}") + print(f"target_data_dir: {self.target_data_dir}") + if self.selection: + print( + f"time_window: [{self.selection.time_from_raw or '-'} ~ {self.selection.time_to_raw or '-'}] " + f"(ts: {_format_ts(self.selection.time_from_ts)} ~ {_format_ts(self.selection.time_to_ts)})" + ) + print( + f"id_window: [{self.selection.start_id or '-'} ~ {self.selection.end_id or '-'}], " + f"selected_streams={len(self.selection.stream_ids)}" + ) + print(f"matched_rows: {preview.total}") + + if preview.distribution: + print("top_chat_distribution:") + for cid, cnt in preview.distribution[:10]: + print(f" - {cid}: {cnt}") + else: + print("top_chat_distribution: (none)") + + if preview.samples: + print(f"samples (first {len(preview.samples)}):") + for row in preview.samples: + summary_preview = _normalize_name(row.get("summary", ""))[:60] + theme_preview = _normalize_name(row.get("theme", ""))[:30] + print( + f" - id={row.get('id')} chat_id={row.get('chat_id')} " + f"[{_format_ts(row.get('start_time'))} ~ {_format_ts(row.get('end_time'))}] " + f"theme={theme_preview!r} summary={summary_preview!r}" + ) + print("=========================\n") + + def _confirm(self) -> bool: + answer = input("确认按以上筛选执行迁移?输入 y 继续 [y/N]: ").strip().lower() + return answer in {"y", "yes"} + + def _parse_json_list_field(self, raw: Any, field_name: str, row_id: int) -> List[str]: + if raw is None: + return [] + if isinstance(raw, list): + data = raw + elif isinstance(raw, str): + try: + parsed = json.loads(raw) + except Exception as e: + raise ValueError(f"{field_name} JSON 解析失败: {e}") from e + if not isinstance(parsed, list): + raise ValueError(f"{field_name} JSON 必须是 list,当前为 {type(parsed).__name__}") + data = parsed + else: + raise ValueError(f"{field_name} 字段类型不支持: {type(raw).__name__}") + return _dedup_keep_order(str(x) for x in data if _normalize_name(x)) + + def _map_row(self, row: sqlite3.Row) -> MappedRow: + row_id = int(row["id"]) + chat_id = _normalize_name(row["chat_id"]) + theme = _normalize_name(row["theme"]) + summary = _normalize_name(row["summary"]) + + participants = self._parse_json_list_field(row["participants"], "participants", row_id) + keywords = self._parse_json_list_field(row["keywords"], "keywords", row_id) + keywords_top = keywords[:8] + + participants_text = "、".join(participants) if participants else "" + keywords_text = "、".join(keywords_top) if keywords_top else "" + + content = ( + f"话题:{theme}\n" + f"概括:{summary}\n" + f"参与者:{participants_text}\n" + f"关键词:{keywords_text}" + ).strip() + + paragraph_hash = compute_hash(normalize_text(content)) + source = f"maibot.chat_history:{chat_id}" + + start_time = _safe_float(row["start_time"], 0.0) + end_time = _safe_float(row["end_time"], start_time) + time_meta = { + "event_time_start": start_time, + "event_time_end": end_time, + "time_granularity": "minute", + "time_confidence": 0.95, + } + + entities = _dedup_keep_order([*participants, theme, *keywords_top]) + relations: List[Tuple[str, str, str]] = [] + if theme: + for participant 
in participants: + relations.append((participant, "参与话题", theme)) + for keyword in keywords_top: + relations.append((theme, "关键词", keyword)) + + existing_vector = paragraph_hash in self.vector_store + return MappedRow( + row_id=row_id, + chat_id=chat_id, + paragraph_hash=paragraph_hash, + content=content, + source=source, + time_meta=time_meta, + entities=entities, + relations=relations, + existing_paragraph_vector=existing_vector, + ) + + def _append_bad_row(self, row: sqlite3.Row, reason: str) -> None: + payload = { + "id": int(row["id"]), + "chat_id": _normalize_name(row["chat_id"]), + "start_time": row["start_time"], + "end_time": row["end_time"], + "participants": row["participants"], + "theme": _normalize_name(row["theme"]), + "keywords": row["keywords"], + "summary": row["summary"], + "error": reason, + "timestamp": time.time(), + } + self.bad_rows_file.parent.mkdir(parents=True, exist_ok=True) + with open(self.bad_rows_file, "a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False)) + f.write("\n") + + async def _migrate(self, start_after_id: int) -> None: + if self.selection is None: + raise MigrationError("selection 未初始化") + + read_batch_size = max(1, int(self.args.read_batch_size)) + commit_window_rows = max(1, int(self.args.commit_window_rows)) + log_every = max(1, int(self.args.log_every)) + + window_rows: List[MappedRow] = [] + window_scanned = 0 + last_seen_id = start_after_id + + logger.info( + f"开始迁移: start_after_id={start_after_id}, read_batch_size={read_batch_size}, " + f"commit_window_rows={commit_window_rows}" + ) + + for batch in self.source_db.iter_rows(self.selection, read_batch_size, start_after_id): + for row in batch: + row_id = int(row["id"]) + last_seen_id = row_id + self.stats["scanned_rows"] += 1 + window_scanned += 1 + + try: + mapped = self._map_row(row) + except Exception as e: + self.stats["bad_rows"] += 1 + self._append_bad_row(row, str(e)) + if self.stats["bad_rows"] > int(self.args.max_errors): + raise MigrationError( + f"坏行数量超过上限 max_errors={self.args.max_errors},已中止。" + ) + continue + + self.stats["valid_rows"] += 1 + if mapped.existing_paragraph_vector: + self.stats["skipped_existing_rows"] += 1 + else: + self.stats["migrated_rows"] += 1 + window_rows.append(mapped) + + if window_scanned >= commit_window_rows: + await self._commit_window(window_rows, last_seen_id) + window_rows = [] + window_scanned = 0 + + if self.stats["scanned_rows"] % log_every == 0: + logger.info( + f"迁移进度: scanned={self.stats['scanned_rows']}/{self.stats['source_matched_total']}, " + f"valid={self.stats['valid_rows']}, bad={self.stats['bad_rows']}, " + f"last_id={last_seen_id}" + ) + + if window_scanned > 0 or window_rows: + await self._commit_window(window_rows, last_seen_id) + + logger.info( + f"迁移主流程完成: scanned={self.stats['scanned_rows']}, valid={self.stats['valid_rows']}, " + f"bad={self.stats['bad_rows']}, last_committed_id={self.stats['last_committed_id']}" + ) + + async def _commit_window(self, rows: List[MappedRow], last_seen_id: int) -> None: + if not rows: + self._flush_state(last_seen_id) + self.stats["windows_committed"] += 1 + return + + now_ts = time.time() + empty_meta_blob = pickle.dumps({}) + + conn = self.metadata_store.get_connection() + + cursor = conn.cursor() + + # 批量查询本窗口内已存在的段落,保证重跑时 entity/mention 不重复累计 + existing_paragraph_hashes: set[str] = set() + all_hashes = [item.paragraph_hash for item in rows] + for i in range(0, len(all_hashes), 800): + batch_hashes = all_hashes[i : i + 800] + if not batch_hashes: + continue + 
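# 按 800 个 hash 一批拼接 IN (...) 占位符查询;批大小取 800 推测是为了留在 SQLite 旧版默认绑定变量上限(约 999)之内。 + 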
placeholders = ",".join("?" for _ in batch_hashes) + existing_rows = cursor.execute( + f"SELECT hash FROM paragraphs WHERE hash IN ({placeholders})", + tuple(batch_hashes), + ).fetchall() + for row in existing_rows: + existing_paragraph_hashes.add(str(row["hash"])) + + paragraph_records: List[Tuple[Any, ...]] = [] + paragraph_embed_map: Dict[str, str] = {} + + entity_display: Dict[str, str] = {} + entity_counts: Dict[str, int] = defaultdict(int) + paragraph_entity_mentions: Dict[Tuple[str, str], int] = defaultdict(int) + entity_embed_map: Dict[str, str] = {} + + relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]] = {} + paragraph_relation_links: set[Tuple[str, str]] = set() + + for item in rows: + is_new_paragraph = item.paragraph_hash not in existing_paragraph_hashes + + start_ts = _safe_float(item.time_meta.get("event_time_start"), 0.0) + end_ts = _safe_float(item.time_meta.get("event_time_end"), start_ts) + confidence = _safe_float(item.time_meta.get("time_confidence"), 0.95) + granularity = _normalize_name(item.time_meta.get("time_granularity")) or "minute" + + if is_new_paragraph: + paragraph_records.append( + ( + item.paragraph_hash, + item.content, + None, + now_ts, + now_ts, + empty_meta_blob, + item.source, + len(normalize_text(item.content).split()), + None, + start_ts, + end_ts, + granularity, + confidence, + KnowledgeType.NARRATIVE.value, + ) + ) + + if item.paragraph_hash not in self.vector_store: + paragraph_embed_map[item.paragraph_hash] = item.content + + for entity in item.entities: + name = _normalize_name(entity) + if not name: + continue + canon = _canonical_name(name) + if not canon: + continue + entity_hash = compute_hash(canon) + entity_display.setdefault(entity_hash, name) + if is_new_paragraph: + entity_counts[entity_hash] += 1 + paragraph_entity_mentions[(item.paragraph_hash, entity_hash)] += 1 + if entity_hash not in self.vector_store: + entity_embed_map.setdefault(entity_hash, name) + + for subject, predicate, obj in item.relations: + s = _normalize_name(subject) + p = _normalize_name(predicate) + o = _normalize_name(obj) + if not (s and p and o): + continue + + s_canon = _canonical_name(s) + p_canon = _canonical_name(p) + o_canon = _canonical_name(o) + relation_hash = compute_hash(f"{s_canon}|{p_canon}|{o_canon}") + + if is_new_paragraph: + relation_records.setdefault( + relation_hash, + (s, p, o, 1.0, item.paragraph_hash, empty_meta_blob), + ) + paragraph_relation_links.add((item.paragraph_hash, relation_hash)) + + for relation_entity in (s, o): + e_canon = _canonical_name(relation_entity) + if not e_canon: + continue + e_hash = compute_hash(e_canon) + entity_display.setdefault(e_hash, relation_entity) + if is_new_paragraph: + entity_counts[e_hash] += 1 + paragraph_entity_mentions[(item.paragraph_hash, e_hash)] += 1 + if e_hash not in self.vector_store: + entity_embed_map.setdefault(e_hash, relation_entity) + + try: + cursor.execute("BEGIN") + + if paragraph_records: + cursor.executemany( + """ + INSERT OR IGNORE INTO paragraphs + ( + hash, content, vector_index, created_at, updated_at, metadata, source, word_count, + event_time, event_time_start, event_time_end, time_granularity, time_confidence, knowledge_type + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + paragraph_records, + ) + + if entity_counts: + entity_rows = [ + ( + entity_hash, + entity_display[entity_hash], + None, + int(count), + now_ts, + empty_meta_blob, + ) + for entity_hash, count in entity_counts.items() + ] + try: + cursor.executemany( + """ + INSERT INTO entities + (hash, name, vector_index, appearance_count, created_at, metadata) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(hash) DO UPDATE SET + appearance_count = entities.appearance_count + excluded.appearance_count + """, + entity_rows, + ) + except sqlite3.OperationalError: + cursor.executemany( + """ + INSERT OR IGNORE INTO entities + (hash, name, vector_index, appearance_count, created_at, metadata) + VALUES (?, ?, ?, ?, ?, ?) + """, + entity_rows, + ) + cursor.executemany( + "UPDATE entities SET appearance_count = appearance_count + ? WHERE hash = ?", + [(int(count), entity_hash) for entity_hash, count in entity_counts.items()], + ) + + if paragraph_entity_mentions: + pe_rows = [ + (paragraph_hash, entity_hash, int(mentions)) + for (paragraph_hash, entity_hash), mentions in paragraph_entity_mentions.items() + ] + try: + cursor.executemany( + """ + INSERT INTO paragraph_entities + (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + ON CONFLICT(paragraph_hash, entity_hash) DO UPDATE SET + mention_count = paragraph_entities.mention_count + excluded.mention_count + """, + pe_rows, + ) + except sqlite3.OperationalError: + cursor.executemany( + """ + INSERT OR IGNORE INTO paragraph_entities + (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + """, + pe_rows, + ) + cursor.executemany( + """ + UPDATE paragraph_entities + SET mention_count = mention_count + ? + WHERE paragraph_hash = ? AND entity_hash = ? + """, + [(m, p, e) for (p, e, m) in pe_rows], + ) + + if relation_records: + relation_rows = [ + ( + relation_hash, + rel[0], + rel[1], + rel[2], + None, + rel[3], + now_ts, + rel[4], + rel[5], + ) + for relation_hash, rel in relation_records.items() + ] + cursor.executemany( + """ + INSERT OR IGNORE INTO relations + (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + relation_rows, + ) + + if paragraph_relation_links: + pr_rows = [(p_hash, r_hash) for p_hash, r_hash in paragraph_relation_links] + cursor.executemany( + """ + INSERT OR IGNORE INTO paragraph_relations + (paragraph_hash, relation_hash) + VALUES (?, ?) 
+ """, + pr_rows, + ) + + conn.commit() + except Exception: + conn.rollback() + raise + + self.stats["relations_written"] += len(relation_records) + + if relation_records: + edge_pairs = [] + relation_hashes = [] + for relation_hash, rel in relation_records.items(): + edge_pairs.append((rel[0], rel[2])) + relation_hashes.append(relation_hash) + + with self.graph_store.batch_update(): + self.graph_store.add_edges(edge_pairs, relation_hashes=relation_hashes) + self.stats["graph_edges_written"] += len(edge_pairs) + + if self._should_write_relation_vectors(): + await self._ensure_relation_vectors_for_records(relation_records) + + para_added = await self._embed_and_add_vectors( + id_to_text=paragraph_embed_map, + batch_size=max(1, int(self.args.embed_batch_size)), + workers=self.embed_workers, + ) + ent_added = await self._embed_and_add_vectors( + id_to_text=entity_embed_map, + batch_size=max(1, int(self.args.entity_embed_batch_size)), + workers=self.embed_workers, + ) + self.stats["paragraph_vectors_added"] += para_added + self.stats["entity_vectors_added"] += ent_added + + self.vector_store.save() + self.graph_store.save() + + self.stats["windows_committed"] += 1 + self._flush_state(last_seen_id) + + async def _embed_and_add_vectors( + self, + id_to_text: Dict[str, str], + batch_size: int, + workers: int, + ) -> int: + if not id_to_text: + return 0 + if self.embedding_manager is None: + raise MigrationError("embedding_manager 未初始化,无法写入向量") + + ids = [] + texts = [] + for hash_id, text in id_to_text.items(): + if hash_id in self.vector_store: + continue + ids.append(hash_id) + texts.append(text) + + if not ids: + return 0 + + total_added = 0 + chunk_size = max(1, int(batch_size)) + for i in range(0, len(ids), chunk_size): + chunk_ids = ids[i : i + chunk_size] + chunk_texts = texts[i : i + chunk_size] + + embeddings = await self.embedding_manager.encode_batch( + chunk_texts, + batch_size=chunk_size, + num_workers=max(1, int(workers)), + ) + + emb_arr = np.asarray(embeddings, dtype=np.float32) + if emb_arr.ndim == 1: + emb_arr = emb_arr.reshape(1, -1) + if emb_arr.shape[0] != len(chunk_ids): + logger.warning( + f"embedding 返回数量异常: expected={len(chunk_ids)}, got={emb_arr.shape[0]},跳过该批次" + ) + continue + + valid_vectors = [] + valid_ids = [] + for idx, vec in enumerate(emb_arr): + if vec.ndim != 1: + continue + if vec.shape[0] != self.vector_store.dimension: + logger.warning( + f"向量维度不匹配,跳过: id={chunk_ids[idx]}, got={vec.shape[0]}, expected={self.vector_store.dimension}" + ) + continue + if not np.all(np.isfinite(vec)): + logger.warning(f"向量含 NaN/Inf,跳过: id={chunk_ids[idx]}") + continue + if chunk_ids[idx] in self.vector_store: + continue + valid_vectors.append(vec) + valid_ids.append(chunk_ids[idx]) + + if valid_vectors: + batch_vectors = np.stack(valid_vectors).astype(np.float32, copy=False) + added = self.vector_store.add(batch_vectors, valid_ids) + total_added += int(added) + + return total_added + + async def _verify(self, strict: bool) -> None: + if self.selection is None: + raise MigrationError("selection 未初始化") + + sample_size = min(2000, max(0, int(self.stats.get("source_matched_total", 0)))) + self.stats["verify_sample_size"] = sample_size + + if sample_size <= 0: + self.stats["verify_passed"] = True + return + + sample_rows = self.source_db.sample_rows_for_verify(self.selection, sample_size) + para_missing = 0 + vec_missing = 0 + rel_missing = 0 + edge_missing = 0 + + for row in sample_rows: + try: + mapped = self._map_row(row) + except Exception: + continue + + paragraph = 
self.metadata_store.get_paragraph(mapped.paragraph_hash) + if paragraph is None: + para_missing += 1 + if mapped.paragraph_hash not in self.vector_store: + vec_missing += 1 + + for s, p, o in mapped.relations: + relation_hash = compute_hash(f"{_canonical_name(s)}|{_canonical_name(p)}|{_canonical_name(o)}") + relation = self.metadata_store.get_relation(relation_hash) + if relation is None: + rel_missing += 1 + if self.graph_store.get_edge_weight(s, o) <= 0.0: + edge_missing += 1 + + self.stats["verify_paragraph_missing"] = para_missing + self.stats["verify_vector_missing"] = vec_missing + self.stats["verify_relation_missing"] = rel_missing + self.stats["verify_edge_missing"] = edge_missing + + verify_passed = all(x == 0 for x in [para_missing, vec_missing, rel_missing, edge_missing]) + if strict and not verify_passed: + self.failed = True + self.fail_reason = ( + "严格校验失败: " + f"paragraph_missing={para_missing}, vector_missing={vec_missing}, " + f"relation_missing={rel_missing}, edge_missing={edge_missing}" + ) + + self.stats["verify_passed"] = verify_passed + + def _finalize(self) -> int: + elapsed = time.time() - self.started_at + self.stats["elapsed_seconds"] = elapsed + + report = { + "success": not self.failed, + "fail_reason": self.fail_reason, + "args": vars(self.args), + "source_db": str(self.source_db_path), + "target_data_dir": str(self.target_data_dir), + "selection": self.selection.fingerprint_payload() if self.selection else {}, + "filter_fingerprint": self.filter_fingerprint, + "source_db_fingerprint": self.source_db_fingerprint, + "state_file": str(self.state_file), + "bad_rows_file": str(self.bad_rows_file), + "stats": dict(self.stats), + "timestamp": time.time(), + } + + _dump_json_atomic(self.report_file, report) + + if self.failed: + self.exit_code = 1 + elif self.stats.get("bad_rows", 0) > 0: + self.exit_code = 2 + else: + self.exit_code = 0 + + print("\n=== Migration Report ===") + print(f"success: {not self.failed}") + if self.fail_reason: + print(f"fail_reason: {self.fail_reason}") + print(f"elapsed: {elapsed:.2f}s") + print(f"source_matched_total: {self.stats['source_matched_total']}") + print(f"scanned_rows: {self.stats['scanned_rows']}") + print(f"valid_rows: {self.stats['valid_rows']}") + print(f"migrated_rows: {self.stats['migrated_rows']}") + print(f"skipped_existing_rows: {self.stats['skipped_existing_rows']}") + print(f"bad_rows: {self.stats['bad_rows']}") + print(f"paragraph_vectors_added: {self.stats['paragraph_vectors_added']}") + print(f"entity_vectors_added: {self.stats['entity_vectors_added']}") + print(f"relations_written: {self.stats['relations_written']}") + print( + "relation_vectors: " + f"written={self.stats['relation_vectors_written']}, " + f"failed={self.stats['relation_vectors_failed']}, " + f"skipped={self.stats['relation_vectors_skipped']}" + ) + print(f"graph_edges_written: {self.stats['graph_edges_written']}") + print(f"windows_committed: {self.stats['windows_committed']}") + print(f"last_committed_id: {self.stats['last_committed_id']}") + print( + "verify: " + f"sample={self.stats['verify_sample_size']}, " + f"paragraph_missing={self.stats['verify_paragraph_missing']}, " + f"vector_missing={self.stats['verify_vector_missing']}, " + f"relation_missing={self.stats['verify_relation_missing']}, " + f"edge_missing={self.stats['verify_edge_missing']}, " + f"passed={self.stats['verify_passed']}" + ) + print(f"report_file: {self.report_file}") + print("========================\n") + + return self.exit_code + + def _close(self) -> None: + try: + if 
self.metadata_store is not None: + self.metadata_store.close() + except Exception: + pass + self.source_db.close() + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="迁移 MaiBot chat_history 到 A_memorix(高性能 + 可断点续传 + 可确认筛选)" + ) + + parser.add_argument("--source-db", default=str(DEFAULT_SOURCE_DB), help="源数据库路径(默认 data/MaiBot.db)") + parser.add_argument( + "--target-data-dir", + default=str(DEFAULT_TARGET_DATA_DIR), + help="A_memorix 数据目录(默认 plugins/A_memorix/data)", + ) + + resume_group = parser.add_mutually_exclusive_group() + resume_group.add_argument("--resume", dest="no_resume", action="store_false", help="启用断点续传(默认)") + resume_group.add_argument("--no-resume", dest="no_resume", action="store_true", help="禁用断点续传") + parser.set_defaults(no_resume=False) + + parser.add_argument("--reset-state", action="store_true", help="清空迁移状态文件后执行") + parser.add_argument("--start-id", type=int, default=None, help="从指定 chat_history.id 开始迁移(覆盖断点)") + parser.add_argument("--end-id", type=int, default=None, help="迁移到指定 chat_history.id") + + parser.add_argument("--read-batch-size", type=int, default=2000, help="源库分页读取大小(默认 2000)") + parser.add_argument("--commit-window-rows", type=int, default=20000, help="每窗口提交行数(默认 20000)") + parser.add_argument("--embed-batch-size", type=int, default=256, help="段落 embedding 批次大小(默认 256)") + parser.add_argument( + "--entity-embed-batch-size", + type=int, + default=512, + help="实体 embedding 批次大小(默认 512)", + ) + parser.add_argument("--embed-workers", type=int, default=None, help="embedding 并发数(默认读取配置)") + parser.add_argument("--max-errors", type=int, default=500, help="坏行上限(默认 500)") + parser.add_argument("--log-every", type=int, default=5000, help="日志输出步长(默认 5000)") + + parser.add_argument("--dry-run", action="store_true", help="仅预览不写入") + parser.add_argument("--verify-only", action="store_true", help="仅执行严格校验") + + parser.add_argument("--time-from", default=None, help="开始时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") + parser.add_argument("--time-to", default=None, help="结束时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") + parser.add_argument("--stream-id", action="append", default=[], help="聊天流 stream_id(可重复)") + parser.add_argument("--group-id", action="append", default=[], help="群号(可重复,自动映射 stream_id)") + parser.add_argument("--user-id", action="append", default=[], help="用户号(可重复,自动映射 stream_id)") + parser.add_argument("--yes", action="store_true", help="跳过交互确认") + parser.add_argument("--preview-limit", type=int, default=20, help="预览样本条数(默认 20)") + + return parser + + +async def async_main() -> int: + parser = build_parser() + args = parser.parse_args() + + runner = MigrationRunner(args) + return await runner.run() + + +def main() -> int: + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + return asyncio.run(async_main()) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/process_knowledge.py b/plugins/A_memorix/scripts/process_knowledge.py new file mode 100644 index 00000000..d9e6fe32 --- /dev/null +++ b/plugins/A_memorix/scripts/process_knowledge.py @@ -0,0 +1,728 @@ +#!/usr/bin/env python3 +""" +知识库自动导入脚本 (Strategy-Aware Version) + +功能: +1. 扫描 plugins/A_memorix/data/raw 下的 .txt 文件 +2. 检查 data/import_manifest.json 确认是否已导入 +3. 使用 Strategy 模式处理文件 (Narrative/Factual/Quote) +4. 将生成的数据直接存入 VectorStore/GraphStore/MetadataStore +5. 
更新 manifest +""" + +import sys +import os +import json +import asyncio +import time +import random +import hashlib +import tomlkit +import argparse +from pathlib import Path +from datetime import datetime +from typing import List, Dict, Any, Optional +from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn +from rich.console import Console +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type + +console = Console() + +class LLMGenerationError(Exception): + pass + +# 路径设置 +current_dir = Path(__file__).resolve().parent +plugin_root = current_dir.parent +workspace_root = plugin_root.parent +maibot_root = workspace_root / "MaiBot" +for path in (workspace_root, maibot_root, plugin_root): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + +# 数据目录 +DATA_DIR = plugin_root / "data" +RAW_DIR = DATA_DIR / "raw" +PROCESSED_DIR = DATA_DIR / "processed" +MANIFEST_PATH = DATA_DIR / "import_manifest.json" + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="A_Memorix Knowledge Importer (Strategy-Aware)") + parser.add_argument("--force", action="store_true", help="Force re-import") + parser.add_argument("--clear-manifest", action="store_true", help="Clear manifest") + parser.add_argument( + "--type", + "-t", + default="auto", + help="Target import strategy override (auto/narrative/factual/quote)", + ) + parser.add_argument("--concurrency", "-c", type=int, default=5) + parser.add_argument( + "--chat-log", + action="store_true", + help="聊天记录导入模式:强制 narrative 策略,并使用 LLM 语义抽取 event_time/event_time_range", + ) + parser.add_argument( + "--chat-reference-time", + default=None, + help="chat_log 模式的相对时间参考点(如 2026/02/12 10:30);不传则使用当前本地时间", + ) + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + sys.exit(0) + + +try: + import A_memorix.core as core_module + import A_memorix.core.storage as storage_module + from src.common.logger import get_logger + from src.services import llm_service as llm_api + from src.config.config import global_config, model_config + + VectorStore = core_module.VectorStore + GraphStore = core_module.GraphStore + MetadataStore = core_module.MetadataStore + ImportStrategy = core_module.ImportStrategy + create_embedding_api_adapter = core_module.create_embedding_api_adapter + RelationWriteService = getattr(core_module, "RelationWriteService", None) + + looks_like_quote_text = storage_module.looks_like_quote_text + parse_import_strategy = storage_module.parse_import_strategy + resolve_stored_knowledge_type = storage_module.resolve_stored_knowledge_type + select_import_strategy = storage_module.select_import_strategy + + from A_memorix.core.utils.time_parser import normalize_time_meta + from A_memorix.core.utils.import_payloads import normalize_paragraph_import_item + from A_memorix.core.strategies.base import BaseStrategy, ProcessedChunk, KnowledgeType as StratKnowledgeType + from A_memorix.core.strategies.narrative import NarrativeStrategy + from A_memorix.core.strategies.factual import FactualStrategy + from A_memorix.core.strategies.quote import QuoteStrategy + +except ImportError as e: + print(f"❌ 无法导入模块: {e}") + import traceback + traceback.print_exc() + sys.exit(1) + +logger = get_logger("A_Memorix.AutoImport") + + +def _log_before_retry(retry_state) -> None: + """使用项目统一日志风格记录重试信息。""" + exc = None + if 
getattr(retry_state, "outcome", None) is not None and retry_state.outcome.failed: + exc = retry_state.outcome.exception() + next_sleep = getattr(getattr(retry_state, "next_action", None), "sleep", None) + logger.warning( + "LLM 调用即将重试: " + f"attempt={getattr(retry_state, 'attempt_number', '?')} " + f"next_sleep={next_sleep} " + f"error={exc}" + ) + +class AutoImporter: + def __init__( + self, + force: bool = False, + clear_manifest: bool = False, + target_type: str = "auto", + concurrency: int = 5, + chat_log: bool = False, + chat_reference_time: Optional[str] = None, + ): + self.vector_store: Optional[VectorStore] = None + self.graph_store: Optional[GraphStore] = None + self.metadata_store: Optional[MetadataStore] = None + self.embedding_manager = None + self.relation_write_service = None + self.plugin_config = {} + self.manifest = {} + self.force = force + self.clear_manifest = clear_manifest + self.chat_log = chat_log + parsed_target_type = parse_import_strategy(target_type, default=ImportStrategy.AUTO) + self.target_type = ImportStrategy.NARRATIVE.value if chat_log else parsed_target_type.value + self.chat_reference_dt = self._parse_reference_time(chat_reference_time) + if self.chat_log and parsed_target_type not in {ImportStrategy.AUTO, ImportStrategy.NARRATIVE}: + logger.warning( + f"chat_log 模式已启用,target_type={target_type} 将被覆盖为 narrative" + ) + self.concurrency_limit = concurrency + self.semaphore = None + self.storage_lock = None + + async def initialize(self): + logger.info(f"正在初始化... (并发数: {self.concurrency_limit})") + self.semaphore = asyncio.Semaphore(self.concurrency_limit) + self.storage_lock = asyncio.Lock() + + RAW_DIR.mkdir(parents=True, exist_ok=True) + PROCESSED_DIR.mkdir(parents=True, exist_ok=True) + + if self.clear_manifest: + logger.info("🧹 清理 Manifest") + self.manifest = {} + self._save_manifest() + elif MANIFEST_PATH.exists(): + try: + with open(MANIFEST_PATH, "r", encoding="utf-8") as f: + self.manifest = json.load(f) + except Exception: + self.manifest = {} + + config_path = plugin_root / "config.toml" + try: + with open(config_path, "r", encoding="utf-8") as f: + self.plugin_config = tomlkit.load(f) + except Exception as e: + logger.error(f"加载插件配置失败: {e}") + return False + + try: + await self._init_stores() + except Exception as e: + logger.error(f"初始化存储失败: {e}") + return False + + return True + + async def _init_stores(self): + # ... 
(Same as original) + self.embedding_manager = create_embedding_api_adapter( + batch_size=self.plugin_config.get("embedding", {}).get("batch_size", 32), + default_dimension=self.plugin_config.get("embedding", {}).get("dimension", 384), + model_name=self.plugin_config.get("embedding", {}).get("model_name", "auto"), + retry_config=self.plugin_config.get("embedding", {}).get("retry", {}), + ) + try: + dim = await self.embedding_manager._detect_dimension() + except: + dim = self.embedding_manager.default_dimension + + q_type_str = str(self.plugin_config.get("embedding", {}).get("quantization_type", "int8") or "int8").lower() + # Need to access QuantizationType from storage_module if not imported globally + QuantizationType = storage_module.QuantizationType + if q_type_str != "int8": + raise ValueError( + "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + + self.vector_store = VectorStore( + dimension=dim, + quantization_type=QuantizationType.INT8, + data_dir=DATA_DIR / "vectors" + ) + + SparseMatrixFormat = storage_module.SparseMatrixFormat + m_fmt_str = self.plugin_config.get("graph", {}).get("sparse_matrix_format", "csr") + m_map = {"csr": SparseMatrixFormat.CSR, "csc": SparseMatrixFormat.CSC} + + self.graph_store = GraphStore( + matrix_format=m_map.get(m_fmt_str, SparseMatrixFormat.CSR), + data_dir=DATA_DIR / "graph" + ) + + self.metadata_store = MetadataStore(data_dir=DATA_DIR / "metadata") + self.metadata_store.connect() + + if RelationWriteService is not None: + self.relation_write_service = RelationWriteService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + + if self.vector_store.has_data(): self.vector_store.load() + if self.graph_store.has_data(): self.graph_store.load() + + def _should_write_relation_vectors(self) -> bool: + retrieval_cfg = self.plugin_config.get("retrieval", {}) + if not isinstance(retrieval_cfg, dict): + return False + rv_cfg = retrieval_cfg.get("relation_vectorization", {}) + if not isinstance(rv_cfg, dict): + return False + return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) + + def load_file(self, file_path: Path) -> str: + with open(file_path, "r", encoding="utf-8") as f: + return f.read() + + def get_file_hash(self, content: str) -> str: + return hashlib.md5(content.encode("utf-8")).hexdigest() + + def _parse_reference_time(self, value: Optional[str]) -> datetime: + """解析 chat_log 模式的参考时间(用于相对时间语义解析)。""" + if not value: + return datetime.now() + formats = [ + "%Y/%m/%d %H:%M:%S", + "%Y/%m/%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y/%m/%d", + "%Y-%m-%d", + ] + text = str(value).strip() + for fmt in formats: + try: + return datetime.strptime(text, fmt) + except ValueError: + continue + logger.warning( + f"无法解析 chat_reference_time={value},将回退为当前本地时间" + ) + return datetime.now() + + async def _extract_chat_time_meta_with_llm( + self, + text: str, + model_config: Any, + ) -> Optional[Dict[str, Any]]: + """ + 使用 LLM 从聊天文本语义中抽取时间信息。 + 支持将相对时间表达转换为绝对时间。 + """ + if not text.strip(): + return None + + reference_now = self.chat_reference_dt.strftime("%Y/%m/%d %H:%M") + prompt = f"""You are a time extraction engine for chat logs. +Extract temporal information from the following chat paragraph. + +Rules: +1. Use semantic understanding, not regex matching. +2. 
Convert relative expressions (e.g., yesterday evening, last Friday morning) to absolute local datetime using reference_now. +3. If a time span exists, return event_time_start/event_time_end. +4. If only one point in time exists, return event_time. +5. If no reliable time can be inferred, return all time fields as null. +6. Output ONLY valid JSON. No markdown, no explanation. + +reference_now: {reference_now} +timezone: local system timezone + +Allowed output formats for time values: +- "YYYY/MM/DD" +- "YYYY/MM/DD HH:mm" + +JSON schema: +{{ + "event_time": null, + "event_time_start": null, + "event_time_end": null, + "time_range": null, + "time_granularity": "day", + "time_confidence": 0.0 +}} + +Chat paragraph: +\"\"\"{text}\"\"\" +""" + try: + result = await self._llm_call(prompt, model_config) + except Exception as e: + logger.warning(f"chat_log 时间语义抽取失败: {e}") + return None + + if not isinstance(result, dict): + return None + + raw_time_meta = { + "event_time": result.get("event_time"), + "event_time_start": result.get("event_time_start"), + "event_time_end": result.get("event_time_end"), + "time_range": result.get("time_range"), + "time_granularity": result.get("time_granularity"), + "time_confidence": result.get("time_confidence"), + } + try: + normalized = normalize_time_meta(raw_time_meta) + except Exception as e: + logger.warning(f"chat_log 时间语义抽取结果不可用,已忽略: {e}") + return None + + has_effective_time = any( + key in normalized + for key in ("event_time", "event_time_start", "event_time_end") + ) + if not has_effective_time: + return None + + return normalized + + def _determine_strategy(self, filename: str, content: str) -> BaseStrategy: + """Layer 1: Global Strategy Routing""" + strategy = select_import_strategy( + content, + override=self.target_type, + chat_log=self.chat_log, + ) + if self.chat_log: + logger.info(f"chat_log 模式: {filename} 强制使用 NarrativeStrategy") + elif strategy == ImportStrategy.QUOTE: + logger.info(f"Auto-detected Quote/Lyric type for {filename}") + + if strategy == ImportStrategy.FACTUAL: + return FactualStrategy(filename) + if strategy == ImportStrategy.QUOTE: + return QuoteStrategy(filename) + return NarrativeStrategy(filename) + + def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[BaseStrategy]: + """Layer 2: Chunk-level rescue strategies""" + # If we are already in Quote strategy, no need to rescue + if chunk.type == StratKnowledgeType.QUOTE: + return None + + if looks_like_quote_text(chunk.chunk.text): + logger.info(f" > Rescuing chunk {chunk.chunk.index} as Quote") + return QuoteStrategy(filename) + + return None + + async def process_and_import(self): + if not await self.initialize(): return + + files = list(RAW_DIR.glob("*.txt")) + logger.info(f"扫描到 {len(files)} 个文件 in {RAW_DIR}") + + if not files: return + + tasks = [] + for file_path in files: + tasks.append(asyncio.create_task(self._process_single_file(file_path))) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + success_count = sum(1 for r in results if r is True) + logger.info(f"本次主处理完成,共成功处理 {success_count}/{len(files)} 个文件") + + if self.vector_store: self.vector_store.save() + if self.graph_store: self.graph_store.save() + + async def _process_single_file(self, file_path: Path) -> bool: + filename = file_path.name + async with self.semaphore: + try: + content = self.load_file(file_path) + file_hash = self.get_file_hash(content) + + if not self.force and filename in self.manifest: + record = self.manifest[filename] + if record.get("hash") == file_hash 
and record.get("imported"): + logger.info(f"跳过已导入文件: {filename}") + return False + + logger.info(f">>> 开始处理: {filename}") + + # 1. Strategy Selection + strategy = self._determine_strategy(filename, content) + logger.info(f" 策略: {strategy.__class__.__name__}") + + # 2. Split (Strategy-Aware) + initial_chunks = strategy.split(content) + logger.info(f" 初步分块: {len(initial_chunks)}") + + processed_data = {"paragraphs": [], "entities": [], "relations": []} + + # 3. Extract Loop + model_config = await self._select_model() + + for i, chunk in enumerate(initial_chunks): + current_strategy = strategy + # Layer 2: Chunk Rescue + rescue_strategy = self._chunk_rescue(chunk, filename) + if rescue_strategy: + # Re-split? No, just re-process this text as a single chunk using the rescue strategy + # But rescue strategy might want to split it further? + # Simplification: Treat the whole chunk text as one block for the rescue strategy + # OR create a single chunk object for it. + # Creating a new chunk using rescue strategy logic might be complex if split behavior differs. + # Let's just instantiate a chunk of the new type manually + chunk.type = StratKnowledgeType.QUOTE + chunk.flags.verbatim = True + chunk.flags.requires_llm = False # Quotes don't usually need LLM + current_strategy = rescue_strategy + + # Extraction + if chunk.flags.requires_llm: + result_chunk = await current_strategy.extract(chunk, lambda p: self._llm_call(p, model_config)) + else: + # For quotes, extract might be just pass through or regex + result_chunk = await current_strategy.extract(chunk) + + time_meta = None + if self.chat_log: + time_meta = await self._extract_chat_time_meta_with_llm( + result_chunk.chunk.text, + model_config, + ) + + # Normalize Data + self._normalize_and_aggregate( + result_chunk, + processed_data, + time_meta=time_meta, + ) + + logger.info(f" 已处理块 {i+1}/{len(initial_chunks)}") + + # 4. Save Json + json_path = PROCESSED_DIR / f"{file_path.stem}.json" + with open(json_path, "w", encoding="utf-8") as f: + json.dump(processed_data, f, ensure_ascii=False, indent=2) + + # 5. Import to DB + async with self.storage_lock: + await self._import_to_db(processed_data) + + self.manifest[filename] = { + "hash": file_hash, + "timestamp": time.time(), + "imported": True + } + self._save_manifest() + self.vector_store.save() + self.graph_store.save() + logger.info(f"✅ 文件 {filename} 处理并导入完成") + return True + + except Exception as e: + logger.error(f"❌ 处理失败 {filename}: {e}") + import traceback + traceback.print_exc() + return False + + def _normalize_and_aggregate( + self, + chunk: ProcessedChunk, + all_data: Dict, + time_meta: Optional[Dict[str, Any]] = None, + ): + """Convert strategy-specific data to unified generic format for storage.""" + # Generic fields + para_item = { + "content": chunk.chunk.text, + "source": chunk.source.file, + "knowledge_type": resolve_stored_knowledge_type( + chunk.type.value, + content=chunk.chunk.text, + ).value, + "entities": [], + "relations": [] + } + + data = chunk.data + + # 1. Triples (Factual) + if "triples" in data: + for t in data["triples"]: + para_item["relations"].append({ + "subject": t.get("subject"), + "predicate": t.get("predicate"), + "object": t.get("object") + }) + # Auto-add entities from triples + para_item["entities"].extend([t.get("subject"), t.get("object")]) + + # 2. Events & Relations (Narrative) + if "events" in data: + # Store events as content/metadata? Or entities? + # For now maybe just keep them in logic, or add as 'Event' entities? 
+ # Creating entities for events is good. + para_item["entities"].extend(data["events"]) + + if "relations" in data: # Narrative also outputs relations list + para_item["relations"].extend(data["relations"]) + for r in data["relations"]: + para_item["entities"].extend([r.get("subject"), r.get("object")]) + + # 3. Verbatim Entities (Quote) + if "verbatim_entities" in data: + para_item["entities"].extend(data["verbatim_entities"]) + + # Dedupe per paragraph + para_item["entities"] = list(set([e for e in para_item["entities"] if e])) + + if time_meta: + para_item["time_meta"] = time_meta + + all_data["paragraphs"].append(para_item) + all_data["entities"].extend(para_item["entities"]) + if "relations" in para_item: + all_data["relations"].extend(para_item["relations"]) + + @retry( + retry=retry_if_exception_type((LLMGenerationError, json.JSONDecodeError)), + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=2, max=10), + before_sleep=_log_before_retry + ) + async def _llm_call(self, prompt: str, model_config: Any) -> Dict: + """Generic LLM Caller""" + success, response, _, _ = await llm_api.generate_with_model( + prompt=prompt, + model_config=model_config, + request_type="Script.ProcessKnowledge" + ) + if success: + txt = response.strip() + if "```" in txt: + txt = txt.split("```json")[-1].split("```")[0].strip() + try: + return json.loads(txt) + except json.JSONDecodeError: + # Fallback: try to find first { and last } + start = txt.find('{') + end = txt.rfind('}') + if start != -1 and end != -1: + return json.loads(txt[start:end+1]) + raise + else: + raise LLMGenerationError("LLM generation failed") + + async def _select_model(self) -> Any: + models = llm_api.get_available_models() + if not models: raise ValueError("No LLM models") + + config_model = self.plugin_config.get("advanced", {}).get("extraction_model", "auto") + if config_model != "auto" and config_model in models: + return models[config_model] + + for task_key in ["lpmm_entity_extract", "lpmm_rdf_build", "embedding"]: + if task_key in models: return models[task_key] + + return models[list(models.keys())[0]] + + # Re-use existing methods + async def _add_entity_with_vector(self, name: str, source_paragraph: Optional[str] = None) -> str: + # Same as before + hash_value = self.metadata_store.add_entity(name, source_paragraph=source_paragraph) + self.graph_store.add_nodes([name]) + try: + emb = await self.embedding_manager.encode(name) + try: + self.vector_store.add(emb.reshape(1, -1), [hash_value]) + except ValueError: pass + except Exception: pass + return hash_value + + async def import_json_data(self, data: Dict, filename: str = "script_import", progress_callback=None): + """Public import entrypoint for pre-processed JSON payloads.""" + if not self.storage_lock: + raise RuntimeError("Importer is not initialized. 
Call initialize() first.") + + async with self.storage_lock: + await self._import_to_db(data, progress_callback=progress_callback) + self.manifest[filename] = { + "hash": self.get_file_hash(json.dumps(data, ensure_ascii=False, sort_keys=True)), + "timestamp": time.time(), + "imported": True, + } + self._save_manifest() + self.vector_store.save() + self.graph_store.save() + + async def _import_to_db(self, data: Dict, progress_callback=None): + # Same logic, but ensure robust + with self.graph_store.batch_update(): + for item in data.get("paragraphs", []): + paragraph = normalize_paragraph_import_item( + item, + default_source="script", + ) + content = paragraph["content"] + source = paragraph["source"] + k_type_val = paragraph["knowledge_type"] + + h_val = self.metadata_store.add_paragraph( + content=content, + source=source, + knowledge_type=k_type_val, + time_meta=paragraph["time_meta"], + ) + + if h_val not in self.vector_store: + try: + emb = await self.embedding_manager.encode(content) + self.vector_store.add(emb.reshape(1, -1), [h_val]) + except Exception as e: + logger.error(f" Vector fail: {e}") + + para_entities = paragraph["entities"] + for entity in para_entities: + if entity: + await self._add_entity_with_vector(entity, source_paragraph=h_val) + + para_relations = paragraph["relations"] + for rel in para_relations: + s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object") + if s and p and o: + await self._add_entity_with_vector(s, source_paragraph=h_val) + await self._add_entity_with_vector(o, source_paragraph=h_val) + confidence = float(rel.get("confidence", 1.0) or 1.0) + rel_meta = rel.get("metadata", {}) + write_vector = self._should_write_relation_vectors() + if self.relation_write_service is not None: + await self.relation_write_service.upsert_relation_with_vector( + subject=s, + predicate=p, + obj=o, + confidence=confidence, + source_paragraph=h_val, + metadata=rel_meta if isinstance(rel_meta, dict) else {}, + write_vector=write_vector, + ) + else: + rel_hash = self.metadata_store.add_relation( + s, + p, + o, + confidence=confidence, + source_paragraph=h_val, + metadata=rel_meta if isinstance(rel_meta, dict) else {}, + ) + self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash]) + try: + self.metadata_store.set_relation_vector_state(rel_hash, "none") + except Exception: + pass + + if progress_callback: progress_callback(1) + + async def close(self): + if self.metadata_store: self.metadata_store.close() + + def _save_manifest(self): + with open(MANIFEST_PATH, "w", encoding="utf-8") as f: + json.dump(self.manifest, f, ensure_ascii=False, indent=2) + +async def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + if not global_config: return + + importer = AutoImporter( + force=args.force, + clear_manifest=args.clear_manifest, + target_type=args.type, + concurrency=args.concurrency, + chat_log=args.chat_log, + chat_reference_time=args.chat_reference_time, + ) + await importer.process_and_import() + await importer.close() + +if __name__ == "__main__": + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.run(main()) diff --git a/plugins/A_memorix/scripts/rebuild_episodes.py b/plugins/A_memorix/scripts/rebuild_episodes.py new file mode 100644 index 00000000..b6adaa21 --- /dev/null +++ b/plugins/A_memorix/scripts/rebuild_episodes.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 +"""Episode source 级重建工具。""" + +from __future__ import annotations + +import argparse +import asyncio 
+import sys +from pathlib import Path +from typing import Any, Dict, List + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + +try: + import tomlkit # type: ignore +except Exception: # pragma: no cover + tomlkit = None + +from A_memorix.core.storage import MetadataStore +from A_memorix.core.utils.episode_service import EpisodeService + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Rebuild A_Memorix episodes by source") + parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="插件数据目录") + parser.add_argument("--source", type=str, help="指定单个 source 入队/重建") + parser.add_argument("--all", action="store_true", help="对所有 source 入队/重建") + parser.add_argument("--wait", action="store_true", help="在脚本内同步执行重建") + return parser + + +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + raise SystemExit(0) + + +def _load_plugin_config() -> Dict[str, Any]: + config_path = PLUGIN_ROOT / "config.toml" + if tomlkit is None or not config_path.exists(): + return {} + try: + with open(config_path, "r", encoding="utf-8") as handle: + parsed = tomlkit.load(handle) + return dict(parsed) if isinstance(parsed, dict) else {} + except Exception: + return {} + + +def _resolve_sources(store: MetadataStore, *, source: str | None, rebuild_all: bool) -> List[str]: + if rebuild_all: + return list(store.list_episode_sources_for_rebuild()) + token = str(source or "").strip() + if not token: + raise ValueError("必须提供 --source 或 --all") + return [token] + + +async def _run_rebuilds(store: MetadataStore, plugin_config: Dict[str, Any], sources: List[str]) -> int: + service = EpisodeService(metadata_store=store, plugin_config=plugin_config) + failures: List[str] = [] + for source in sources: + started = store.mark_episode_source_running(source) + if not started: + failures.append(f"{source}: unable_to_mark_running") + continue + try: + result = await service.rebuild_source(source) + store.mark_episode_source_done(source) + print( + "rebuilt" + f" source={source}" + f" paragraphs={int(result.get('paragraph_count') or 0)}" + f" groups={int(result.get('group_count') or 0)}" + f" episodes={int(result.get('episode_count') or 0)}" + f" fallback={int(result.get('fallback_count') or 0)}" + ) + except Exception as exc: + err = str(exc)[:500] + store.mark_episode_source_failed(source, err) + failures.append(f"{source}: {err}") + print(f"failed source={source} error={err}") + + if failures: + for item in failures: + print(item) + return 1 + return 0 + + +def main() -> int: + parser = _build_arg_parser() + args = parser.parse_args() + if bool(args.all) == bool(args.source): + parser.error("必须且只能选择一个:--source 或 --all") + + store = MetadataStore(data_dir=Path(args.data_dir) / "metadata") + store.connect() + try: + sources = _resolve_sources(store, source=args.source, rebuild_all=bool(args.all)) + if not sources: + print("no sources to rebuild") + return 0 + + enqueued = 0 + reason = "script_rebuild_all" if args.all else "script_rebuild_source" + for source in sources: + enqueued += int(store.enqueue_episode_source_rebuild(source, reason=reason)) + print(f"enqueued={enqueued} sources={len(sources)}") + + if not args.wait: + return 0 + + plugin_config = _load_plugin_config() + return 
asyncio.run(_run_rebuilds(store, plugin_config, sources)) + finally: + store.close() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/release_vnext_migrate.py b/plugins/A_memorix/scripts/release_vnext_migrate.py new file mode 100644 index 00000000..0922fd0b --- /dev/null +++ b/plugins/A_memorix/scripts/release_vnext_migrate.py @@ -0,0 +1,731 @@ +#!/usr/bin/env python3 +""" +vNext release migration entrypoint for A_Memorix. + +Subcommands: +- preflight: detect legacy config/data/schema risks +- migrate: offline migrate config + vectors + metadata schema + graph edge hash map +- verify: strict post-migration consistency checks +""" + +from __future__ import annotations + +import argparse +import json +import pickle +import sqlite3 +import sys +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +import tomlkit + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PLUGIN_ROOT)) + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="A_Memorix vNext release migration tool") + parser.add_argument( + "--config", + default=str(PLUGIN_ROOT / "config.toml"), + help="config.toml path (default: plugins/A_memorix/config.toml)", + ) + parser.add_argument( + "--data-dir", + default="", + help="optional data dir override; default resolved from config.storage.data_dir", + ) + parser.add_argument("--json-out", default="", help="optional JSON report output path") + + sub = parser.add_subparsers(dest="command", required=True) + + p_preflight = sub.add_parser("preflight", help="scan legacy risks") + p_preflight.add_argument("--strict", action="store_true", help="return 1 if any error check exists") + + p_migrate = sub.add_parser("migrate", help="run offline migration") + p_migrate.add_argument("--dry-run", action="store_true", help="only print planned changes") + p_migrate.add_argument( + "--verify-after", + action="store_true", + help="run verify automatically after migrate", + ) + + p_verify = sub.add_parser("verify", help="post-migration verification") + p_verify.add_argument("--strict", action="store_true", help="return 1 if any error check exists") + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + raise SystemExit(0) + +try: + from core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore + from core.storage.metadata_store import SCHEMA_VERSION +except Exception as e: # pragma: no cover + print(f"❌ failed to import storage modules: {e}") + raise SystemExit(2) + + +@dataclass +class CheckItem: + code: str + level: str + message: str + details: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + out = { + "code": self.code, + "level": self.level, + "message": self.message, + } + if self.details: + out["details"] = self.details + return out + + +def _read_toml(path: Path) -> Dict[str, Any]: + text = path.read_text(encoding="utf-8") + return tomlkit.parse(text) + + +def _write_toml(path: Path, data: Dict[str, Any]) -> None: + path.write_text(tomlkit.dumps(data), encoding="utf-8") + + +def _get_nested(obj: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: + cur: Any = obj + for k in keys: + if not 
isinstance(cur, dict) or k not in cur: + return default + cur = cur[k] + return cur + + +def _ensure_table(obj: Dict[str, Any], key: str) -> Dict[str, Any]: + if key not in obj or not isinstance(obj[key], dict): + obj[key] = tomlkit.table() + return obj[key] + + +def _resolve_data_dir(config_doc: Dict[str, Any], explicit_data_dir: Optional[str]) -> Path: + if explicit_data_dir: + return Path(explicit_data_dir).expanduser().resolve() + raw = str(_get_nested(config_doc, ("storage", "data_dir"), "./data") or "./data").strip() + if raw.startswith("."): + return (PLUGIN_ROOT / raw).resolve() + return Path(raw).expanduser().resolve() + + +def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool: + row = conn.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1", + (table,), + ).fetchone() + return row is not None + + +def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]: + hashes: List[str] = [] + if _sqlite_table_exists(conn, "relations"): + rows = conn.execute("SELECT hash FROM relations").fetchall() + hashes.extend(str(r[0]) for r in rows if r and r[0]) + if _sqlite_table_exists(conn, "deleted_relations"): + rows = conn.execute("SELECT hash FROM deleted_relations").fetchall() + hashes.extend(str(r[0]) for r in rows if r and r[0]) + + alias_map: Dict[str, str] = {} + conflicts: Dict[str, set[str]] = {} + for h in hashes: + if len(h) != 64: + continue + alias = h[:32] + old = alias_map.get(alias) + if old is None: + alias_map[alias] = h + continue + if old != h: + conflicts.setdefault(alias, set()).update({old, h}) + return {k: sorted(v) for k, v in conflicts.items()} + + +def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]: + if not _sqlite_table_exists(conn, "paragraphs"): + return [] + + allowed = {item.value for item in KnowledgeType} + rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall() + invalid: List[str] = [] + for row in rows: + raw = row[0] + value = str(raw).strip().lower() if raw is not None else "" + if value not in allowed: + invalid.append(str(raw) if raw is not None else "") + return sorted(set(invalid)) + + +def _guess_vector_dimension(config_doc: Dict[str, Any], vectors_dir: Path) -> int: + meta_path = vectors_dir / "vectors_metadata.pkl" + if meta_path.exists(): + try: + with open(meta_path, "rb") as f: + meta = pickle.load(f) + dim = int(meta.get("dimension", 0)) + if dim > 0: + return dim + except Exception: + pass + try: + dim_cfg = int(_get_nested(config_doc, ("embedding", "dimension"), 1024)) + if dim_cfg > 0: + return dim_cfg + except Exception: + pass + return 1024 + + +def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: + checks: List[CheckItem] = [] + facts: Dict[str, Any] = { + "config_path": str(config_path), + "data_dir": str(data_dir), + } + + if not config_path.exists(): + checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}")) + return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts} + + config_doc = _read_toml(config_path) + tool_mode = str(_get_nested(config_doc, ("routing", "tool_search_mode"), "forward") or "").strip().lower() + summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"]) + summary_knowledge_type = str( + _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative" + ).strip().lower() + quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower() + 
+ facts["routing.tool_search_mode"] = tool_mode + facts["summarization.model_name_type"] = type(summary_model).__name__ + facts["summarization.default_knowledge_type"] = summary_knowledge_type + facts["embedding.quantization_type"] = quantization + + if tool_mode == "legacy": + checks.append( + CheckItem( + "CP-04", + "error", + "routing.tool_search_mode=legacy is no longer accepted at runtime", + ) + ) + elif tool_mode not in {"forward", "disabled"}: + checks.append( + CheckItem( + "CP-04", + "error", + f"routing.tool_search_mode invalid value: {tool_mode}", + ) + ) + + if isinstance(summary_model, str): + checks.append( + CheckItem( + "CP-11", + "error", + "summarization.model_name must be List[str], string legacy format detected", + ) + ) + elif not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model): + checks.append( + CheckItem( + "CP-11", + "error", + "summarization.model_name must be List[str]", + ) + ) + + if summary_knowledge_type not in {item.value for item in KnowledgeType}: + checks.append( + CheckItem( + "CP-13", + "error", + f"invalid summarization.default_knowledge_type: {summary_knowledge_type}", + ) + ) + + if quantization != "int8": + checks.append( + CheckItem( + "UG-07", + "error", + "embedding.quantization_type must be int8 in vNext", + ) + ) + + vectors_dir = data_dir / "vectors" + npy_path = vectors_dir / "vectors.npy" + bin_path = vectors_dir / "vectors.bin" + ids_bin_path = vectors_dir / "vectors_ids.bin" + facts["vectors.npy_exists"] = npy_path.exists() + facts["vectors.bin_exists"] = bin_path.exists() + facts["vectors_ids.bin_exists"] = ids_bin_path.exists() + + if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): + checks.append( + CheckItem( + "CP-07", + "error", + "legacy vectors.npy detected; offline migrate required", + {"npy_path": str(npy_path)}, + ) + ) + + metadata_db = data_dir / "metadata" / "metadata.db" + facts["metadata_db_exists"] = metadata_db.exists() + relation_count = 0 + if metadata_db.exists(): + conn = sqlite3.connect(str(metadata_db)) + try: + has_schema_table = _sqlite_table_exists(conn, "schema_migrations") + facts["schema_migrations_exists"] = has_schema_table + if not has_schema_table: + checks.append( + CheckItem( + "CP-08", + "error", + "schema_migrations table missing (legacy metadata schema)", + ) + ) + else: + row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone() + version = int(row[0]) if row and row[0] is not None else 0 + facts["schema_version"] = version + if version != SCHEMA_VERSION: + checks.append( + CheckItem( + "CP-08", + "error", + f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}", + ) + ) + + if _sqlite_table_exists(conn, "relations"): + row = conn.execute("SELECT COUNT(*) FROM relations").fetchone() + relation_count = int(row[0]) if row and row[0] is not None else 0 + facts["relations_count"] = relation_count + + conflicts = _collect_hash_alias_conflicts(conn) + facts["alias_conflict_count"] = len(conflicts) + if conflicts: + checks.append( + CheckItem( + "CP-05", + "error", + "32-bit relation hash alias conflict detected", + {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)}, + ) + ) + + invalid_knowledge_types = _collect_invalid_knowledge_types(conn) + facts["invalid_knowledge_type_values"] = invalid_knowledge_types + if invalid_knowledge_types: + checks.append( + CheckItem( + "CP-12", + "error", + "invalid paragraph knowledge_type values detected", + {"values": invalid_knowledge_types[:20], "total": 
len(invalid_knowledge_types)}, + ) + ) + finally: + conn.close() + else: + checks.append( + CheckItem( + "META-00", + "warning", + "metadata.db not found, schema checks skipped", + ) + ) + + graph_meta_path = data_dir / "graph" / "graph_metadata.pkl" + facts["graph_metadata_exists"] = graph_meta_path.exists() + if relation_count > 0: + if not graph_meta_path.exists(): + checks.append( + CheckItem( + "CP-06", + "error", + "relations exist but graph metadata missing", + ) + ) + else: + try: + with open(graph_meta_path, "rb") as f: + graph_meta = pickle.load(f) + edge_hash_map = graph_meta.get("edge_hash_map", {}) + edge_hash_map_size = len(edge_hash_map) if isinstance(edge_hash_map, dict) else 0 + facts["edge_hash_map_size"] = edge_hash_map_size + if edge_hash_map_size <= 0: + checks.append( + CheckItem( + "CP-06", + "error", + "edge_hash_map missing/empty while relations exist", + ) + ) + except Exception as e: + checks.append( + CheckItem( + "CP-06", + "error", + f"failed to read graph metadata: {e}", + ) + ) + + has_error = any(c.level == "error" for c in checks) + return { + "ok": not has_error, + "checks": [c.to_dict() for c in checks], + "facts": facts, + } + + +def _migrate_config(config_doc: Dict[str, Any]) -> Dict[str, Any]: + changes: Dict[str, Any] = {} + + routing = _ensure_table(config_doc, "routing") + mode_raw = str(routing.get("tool_search_mode", "forward") or "").strip().lower() + mode_new = mode_raw + if mode_raw == "legacy" or mode_raw not in {"forward", "disabled"}: + mode_new = "forward" + if mode_new != mode_raw: + routing["tool_search_mode"] = mode_new + changes["routing.tool_search_mode"] = {"old": mode_raw, "new": mode_new} + + summary = _ensure_table(config_doc, "summarization") + summary_model = summary.get("model_name", ["auto"]) + if isinstance(summary_model, str): + normalized = [summary_model.strip() or "auto"] + summary["model_name"] = normalized + changes["summarization.model_name"] = {"old": summary_model, "new": normalized} + elif not isinstance(summary_model, list): + normalized = ["auto"] + summary["model_name"] = normalized + changes["summarization.model_name"] = {"old": str(type(summary_model)), "new": normalized} + elif any(not isinstance(x, str) for x in summary_model): + normalized = [str(x).strip() for x in summary_model if str(x).strip()] + if not normalized: + normalized = ["auto"] + summary["model_name"] = normalized + changes["summarization.model_name"] = {"old": summary_model, "new": normalized} + + default_knowledge_type = str(summary.get("default_knowledge_type", "narrative") or "").strip().lower() + allowed_knowledge_types = {item.value for item in KnowledgeType} + if default_knowledge_type not in allowed_knowledge_types: + summary["default_knowledge_type"] = "narrative" + changes["summarization.default_knowledge_type"] = { + "old": default_knowledge_type, + "new": "narrative", + } + + embedding = _ensure_table(config_doc, "embedding") + quantization = str(embedding.get("quantization_type", "int8") or "").strip().lower() + if quantization != "int8": + embedding["quantization_type"] = "int8" + changes["embedding.quantization_type"] = {"old": quantization, "new": "int8"} + + return changes + + +def _migrate_impl(config_path: Path, data_dir: Path, dry_run: bool) -> Dict[str, Any]: + config_doc = _read_toml(config_path) + result: Dict[str, Any] = { + "config_path": str(config_path), + "data_dir": str(data_dir), + "dry_run": bool(dry_run), + "steps": {}, + } + + config_changes = _migrate_config(config_doc) + result["steps"]["config"] = 
{"changed": bool(config_changes), "changes": config_changes} + if config_changes and not dry_run: + _write_toml(config_path, config_doc) + + vectors_dir = data_dir / "vectors" + vectors_dir.mkdir(parents=True, exist_ok=True) + npy_path = vectors_dir / "vectors.npy" + bin_path = vectors_dir / "vectors.bin" + ids_bin_path = vectors_dir / "vectors_ids.bin" + if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): + if dry_run: + result["steps"]["vector"] = {"migrated": False, "reason": "dry_run"} + else: + dim = _guess_vector_dimension(config_doc, vectors_dir) + store = VectorStore( + dimension=max(1, int(dim)), + quantization_type=QuantizationType.INT8, + data_dir=vectors_dir, + ) + result["steps"]["vector"] = store.migrate_legacy_npy(vectors_dir) + else: + result["steps"]["vector"] = {"migrated": False, "reason": "not_required"} + + metadata_dir = data_dir / "metadata" + metadata_dir.mkdir(parents=True, exist_ok=True) + metadata_db = metadata_dir / "metadata.db" + triples: List[Tuple[str, str, str, str]] = [] + relation_count = 0 + + metadata_result: Dict[str, Any] = {"migrated": False, "reason": "not_required"} + if metadata_db.exists(): + store = MetadataStore(data_dir=metadata_dir) + store.connect(enforce_schema=False) + try: + if dry_run: + metadata_result = {"migrated": False, "reason": "dry_run"} + else: + metadata_result = store.run_legacy_migration_for_vnext() + relation_count = int(store.count_relations()) + if relation_count > 0: + triples = [(str(s), str(p), str(o), str(h)) for s, p, o, h in store.get_all_triples()] + finally: + store.close() + result["steps"]["metadata"] = metadata_result + + graph_dir = data_dir / "graph" + graph_dir.mkdir(parents=True, exist_ok=True) + graph_matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr") + graph_store = GraphStore(matrix_format=graph_matrix_format, data_dir=graph_dir) + graph_step: Dict[str, Any] = { + "rebuilt": False, + "mapped_hashes": 0, + "relation_count": relation_count, + "topology_rebuilt_from_relations": False, + } + if relation_count > 0: + if dry_run: + graph_step["reason"] = "dry_run" + else: + if graph_store.has_data(): + graph_store.load() + + mapped = graph_store.rebuild_edge_hash_map(triples) + + # 兜底:历史数据里 graph 节点/边与 relations 脱节时,直接从 relations 重建图。 + if mapped <= 0 or not graph_store.has_edge_hash_map(): + nodes = sorted({s for s, _, o, _ in triples} | {o for _, _, o, _ in triples}) + edges = [(s, o) for s, _, o, _ in triples] + hashes = [h for _, _, _, h in triples] + + graph_store.clear() + if nodes: + graph_store.add_nodes(nodes) + if edges: + mapped = graph_store.add_edges(edges, relation_hashes=hashes) + else: + mapped = 0 + graph_step.update( + { + "topology_rebuilt_from_relations": True, + "rebuilt_nodes": len(nodes), + "rebuilt_edges": int(graph_store.num_edges), + } + ) + + graph_store.save() + graph_step.update({"rebuilt": True, "mapped_hashes": int(mapped)}) + else: + graph_step["reason"] = "no_relations" + result["steps"]["graph"] = graph_step + + return result + + +def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: + checks: List[CheckItem] = [] + facts: Dict[str, Any] = { + "config_path": str(config_path), + "data_dir": str(data_dir), + } + + if not config_path.exists(): + checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}")) + return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts} + + config_doc = _read_toml(config_path) + mode = str(_get_nested(config_doc, ("routing", 
"tool_search_mode"), "forward") or "").strip().lower() + if mode not in {"forward", "disabled"}: + checks.append(CheckItem("CP-04", "error", f"invalid routing.tool_search_mode: {mode}")) + + summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"]) + if not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model): + checks.append(CheckItem("CP-11", "error", "summarization.model_name must be List[str]")) + summary_knowledge_type = str( + _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative" + ).strip().lower() + if summary_knowledge_type not in {item.value for item in KnowledgeType}: + checks.append( + CheckItem("CP-13", "error", f"invalid summarization.default_knowledge_type: {summary_knowledge_type}") + ) + + quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower() + if quantization != "int8": + checks.append(CheckItem("UG-07", "error", "embedding.quantization_type must be int8")) + + vectors_dir = data_dir / "vectors" + npy_path = vectors_dir / "vectors.npy" + bin_path = vectors_dir / "vectors.bin" + ids_bin_path = vectors_dir / "vectors_ids.bin" + if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): + checks.append(CheckItem("CP-07", "error", "legacy vectors.npy still exists without bin migration")) + + metadata_dir = data_dir / "metadata" + store = MetadataStore(data_dir=metadata_dir) + try: + store.connect(enforce_schema=True) + schema_version = store.get_schema_version() + facts["schema_version"] = schema_version + if schema_version != SCHEMA_VERSION: + checks.append(CheckItem("CP-08", "error", f"schema version mismatch: {schema_version}")) + + relation_count = int(store.count_relations()) + facts["relations_count"] = relation_count + + conflicts = {} + invalid_knowledge_types: List[str] = [] + db_path = metadata_dir / "metadata.db" + if db_path.exists(): + conn = sqlite3.connect(str(db_path)) + try: + conflicts = _collect_hash_alias_conflicts(conn) + invalid_knowledge_types = _collect_invalid_knowledge_types(conn) + finally: + conn.close() + if conflicts: + checks.append( + CheckItem( + "CP-05", + "error", + "alias conflicts still exist after migration", + {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)}, + ) + ) + if invalid_knowledge_types: + checks.append( + CheckItem( + "CP-12", + "error", + "invalid paragraph knowledge_type values remain after migration", + {"values": invalid_knowledge_types[:20], "total": len(invalid_knowledge_types)}, + ) + ) + + if relation_count > 0: + graph_dir = data_dir / "graph" + if not (graph_dir / "graph_metadata.pkl").exists(): + checks.append(CheckItem("CP-06", "error", "graph metadata missing while relations exist")) + else: + matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr") + graph_store = GraphStore(matrix_format=matrix_format, data_dir=graph_dir) + graph_store.load() + if not graph_store.has_edge_hash_map(): + checks.append(CheckItem("CP-06", "error", "edge_hash_map is empty")) + except Exception as e: + checks.append(CheckItem("CP-08", "error", f"metadata strict connect failed: {e}")) + finally: + try: + store.close() + except Exception: + pass + + has_error = any(c.level == "error" for c in checks) + return { + "ok": not has_error, + "checks": [c.to_dict() for c in checks], + "facts": facts, + } + + +def _print_report(title: str, report: Dict[str, Any]) -> None: + print(f"=== {title} ===") + print(f"ok: 
{bool(report.get('ok', True))}") + facts = report.get("facts", {}) + if facts: + print("facts:") + for k in sorted(facts.keys()): + print(f" - {k}: {facts[k]}") + checks = report.get("checks", []) + if checks: + print("checks:") + for item in checks: + print(f" - [{item.get('level')}] {item.get('code')}: {item.get('message')}") + else: + print("checks: none") + + +def _write_json_if_needed(path: str, payload: Dict[str, Any]) -> None: + if not path: + return + out = Path(path).expanduser().resolve() + out.parent.mkdir(parents=True, exist_ok=True) + out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") + print(f"json_out: {out}") + + +def main() -> int: + parser = _build_arg_parser() + args = parser.parse_args() + config_path = Path(args.config).expanduser().resolve() + if not config_path.exists(): + print(f"❌ config not found: {config_path}") + return 2 + config_doc = _read_toml(config_path) + data_dir = _resolve_data_dir(config_doc, args.data_dir) + + if args.command == "preflight": + report = _preflight_impl(config_path, data_dir) + _print_report("vNext Preflight", report) + _write_json_if_needed(args.json_out, report) + has_error = any(item.get("level") == "error" for item in report.get("checks", [])) + if args.strict and has_error: + return 1 + return 0 + + if args.command == "migrate": + payload = _migrate_impl(config_path, data_dir, dry_run=bool(args.dry_run)) + print("=== vNext Migrate ===") + print(json.dumps(payload, ensure_ascii=False, indent=2)) + + verify_report = None + if args.verify_after and not args.dry_run: + verify_report = _verify_impl(config_path, data_dir) + _print_report("vNext Verify (after migrate)", verify_report) + payload["verify_after"] = verify_report + + _write_json_if_needed(args.json_out, payload) + if verify_report is not None: + has_error = any(item.get("level") == "error" for item in verify_report.get("checks", [])) + if has_error: + return 1 + return 0 + + if args.command == "verify": + report = _verify_impl(config_path, data_dir) + _print_report("vNext Verify", report) + _write_json_if_needed(args.json_out, report) + has_error = any(item.get("level") == "error" for item in report.get("checks", [])) + if args.strict and has_error: + return 1 + return 0 + + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/runtime_self_check.py b/plugins/A_memorix/scripts/runtime_self_check.py new file mode 100644 index 00000000..70c423ac --- /dev/null +++ b/plugins/A_memorix/scripts/runtime_self_check.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +"""Run A_Memorix runtime self-check against real embedding/runtime configuration.""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import sys +import tempfile +from pathlib import Path +from typing import Any + +import tomlkit + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +PROJECT_ROOT = PLUGIN_ROOT.parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PLUGIN_ROOT)) + + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="A_Memorix runtime self-check") + parser.add_argument( + "--config", + default=str(PLUGIN_ROOT / "config.toml"), + help="config.toml path (default: plugins/A_memorix/config.toml)", + ) + parser.add_argument( + "--data-dir", + default="", + help="optional data dir override; default resolved from config.storage.data_dir", + ) + parser.add_argument( + "--use-config-data-dir", + 
action="store_true", + help="use config.storage.data_dir directly instead of an isolated temp dir", + ) + parser.add_argument( + "--sample-text", + default="A_Memorix runtime self check", + help="sample text used for real embedding probe", + ) + parser.add_argument("--json", action="store_true", help="print JSON report") + return parser + + +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + raise SystemExit(0) + +from core.runtime.lifecycle_orchestrator import initialize_storage_async +from core.utils.runtime_self_check import run_embedding_runtime_self_check + + +def _load_config(path: Path) -> dict[str, Any]: + with open(path, "r", encoding="utf-8") as f: + raw = tomlkit.load(f) + return dict(raw) if isinstance(raw, dict) else {} + + +def _nested_get(config: dict[str, Any], key: str, default: Any = None) -> Any: + current: Any = config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + +class _PluginStub: + def __init__(self, config: dict[str, Any]): + self.config = config + self.vector_store = None + self.graph_store = None + self.metadata_store = None + self.embedding_manager = None + self.sparse_index = None + self.relation_write_service = None + + def get_config(self, key: str, default: Any = None) -> Any: + return _nested_get(self.config, key, default) + + +async def _main_async(args: argparse.Namespace) -> int: + config_path = Path(args.config).resolve() + if not config_path.exists(): + print(f"❌ 配置文件不存在: {config_path}") + return 2 + + config = _load_config(config_path) + temp_dir_ctx = None + if args.data_dir: + storage_dir = str(Path(args.data_dir).resolve()) + elif args.use_config_data_dir: + raw_data_dir = str(_nested_get(config, "storage.data_dir", "./data") or "./data").strip() + if raw_data_dir.startswith("."): + storage_dir = str((config_path.parent / raw_data_dir).resolve()) + else: + storage_dir = str(Path(raw_data_dir).resolve()) + else: + temp_dir_ctx = tempfile.TemporaryDirectory(prefix="memorix-runtime-self-check-") + storage_dir = temp_dir_ctx.name + + storage_cfg = config.setdefault("storage", {}) + storage_cfg["data_dir"] = storage_dir + + plugin = _PluginStub(config) + try: + await initialize_storage_async(plugin) + report = await run_embedding_runtime_self_check( + config=config, + vector_store=plugin.vector_store, + embedding_manager=plugin.embedding_manager, + sample_text=str(args.sample_text or "A_Memorix runtime self check"), + ) + report["data_dir"] = storage_dir + report["isolated_data_dir"] = temp_dir_ctx is not None + if args.json: + print(json.dumps(report, ensure_ascii=False, indent=2)) + else: + print("A_Memorix Runtime Self-Check") + print(f"ok: {report.get('ok')}") + print(f"code: {report.get('code')}") + print(f"message: {report.get('message')}") + print(f"configured_dimension: {report.get('configured_dimension')}") + print(f"vector_store_dimension: {report.get('vector_store_dimension')}") + print(f"detected_dimension: {report.get('detected_dimension')}") + print(f"encoded_dimension: {report.get('encoded_dimension')}") + print(f"elapsed_ms: {float(report.get('elapsed_ms', 0.0)):.2f}") + return 0 if bool(report.get("ok")) else 1 + finally: + if plugin.metadata_store is not None: + try: + plugin.metadata_store.close() + except Exception: + pass + if temp_dir_ctx is not None: + temp_dir_ctx.cleanup() + + +def main() -> int: + parser = _build_arg_parser() + args = parser.parse_args() + return 
asyncio.run(_main_async(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/pyproject.toml b/pyproject.toml index f6dd6646..d9ce5c5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "python-levenshtein", "quick-algo>=0.1.4", "rich>=14.0.0", + "scipy>=1.7.0", "sqlalchemy>=2.0.40", "sqlmodel>=0.0.24", "structlog>=25.4.0", diff --git a/requirements.txt b/requirements.txt index 6a72e1e4..50a5a746 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,6 +23,7 @@ python-multipart>=0.0.20 python-levenshtein quick-algo>=0.1.4 rich>=14.0.0 +scipy>=1.7.0 sqlalchemy>=2.0.40 sqlmodel>=0.0.24 structlog>=25.4.0 diff --git a/src/main.py b/src/main.py index c28b6025..1bfa91b0 100644 --- a/src/main.py +++ b/src/main.py @@ -167,6 +167,7 @@ async def main() -> None: system.schedule_tasks(), ) finally: + emoji_manager.shutdown() await memory_automation_service.shutdown() await get_plugin_runtime_manager().bridge_event("on_stop") await get_plugin_runtime_manager().stop() From a1540d7e1703bef48566906ec7a3bebc904deec7 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:38:36 +0800 Subject: [PATCH 04/14] =?UTF-8?q?feat:=20A=5FMemorix=EF=BC=9A=E5=8A=A0?= =?UTF-8?q?=E5=BC=BA=E4=B8=A5=E6=A0=BC=E6=A8=A1=E5=BC=8F=E3=80=81=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E4=B8=8E=E5=88=A0=E9=99=A4=E8=AF=AD=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在 A_Memorix 中强制更严格的检索语义,并改进错误传播与删除结果报告。 强制/校验受支持的搜索模式(search/time/hybrid/episode/aggregate);移除 semantic 模式,并对不支持的模式返回明确错误。将 kernel 和 plugin 构造函数中的默认值从 hybrid 改为 search。(plugins/A_memorix/core/runtime/sdk_memory_kernel.py, plugins/A_memorix/plugin.py) 对 time/hybrid 模式要求必须提供 time_start/time_end,并在文档、快速开始和 README 中体现该语义。(plugins/A_memorix/QUICK_START.md, plugins/A_memorix/README.md) 改进删除预览/执行语义:跟踪“请求的来源”与“匹配的来源”,基于匹配/删除项计算成功状态,并返回详细计数(requested_source_count、matched_source_count、deleted_paragraph_count、error)。修复来源删除逻辑,使其基于匹配到的来源执行删除。(plugins/A_memorix/core/runtime/sdk_memory_kernel.py) 在搜索执行中移除遗留的 semantic 映射,并规范化 query_type 处理。(plugins/A_memorix/core/utils/search_execution_service.py) 向调用方传播后端搜索错误:为 MemorySearchResult 增加 success/error 字段,兼容多种运行时响应封装,并在异常时返回失败结果。更新调用方以处理并报告搜索失败。(src/services/memory_service.py, src/plugin_runtime/capabilities/data.py, src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py, src/memory_system/retrieval_tools/query_long_term_memory.py) --- .gitignore | 1 + plugins/A_memorix/QUICK_START.md | 6 ++ plugins/A_memorix/README.md | 14 +++ .../core/runtime/sdk_memory_kernel.py | 88 +++++++++++++------ .../core/utils/search_execution_service.py | 7 +- plugins/A_memorix/plugin.py | 2 +- .../brain_chat/PFC/pfc_KnowledgeFetcher.py | 10 +++ .../retrieval_tools/query_long_term_memory.py | 5 +- src/plugin_runtime/capabilities/data.py | 2 + src/services/memory_service.py | 37 +++++++- 10 files changed, 135 insertions(+), 37 deletions(-) diff --git a/.gitignore b/.gitignore index 36853b8a..16ee78ca 100644 --- a/.gitignore +++ b/.gitignore @@ -361,3 +361,4 @@ packages/ ## Claude Code and OMC data .claude/ .omc/ +/.venv312 diff --git a/plugins/A_memorix/QUICK_START.md b/plugins/A_memorix/QUICK_START.md index 76750453..7159a35b 100644 --- a/plugins/A_memorix/QUICK_START.md +++ b/plugins/A_memorix/QUICK_START.md @@ -109,6 +109,11 @@ python plugins/A_memorix/scripts/migrate_person_memory_points.py --help `mode` 支持:`search/time/hybrid/episode/aggregate` +严格语义说明: + +- `semantic` 模式已移除,传入会返回参数错误。 +- `time/hybrid` 模式必须提供 
`time_start` 或 `time_end`,否则返回错误(不会再当作“未命中”)。 + ### 5.2 写入摘要 ```json @@ -190,6 +195,7 @@ python plugins/A_memorix/scripts/migrate_person_memory_points.py --help 1. 先看 `memory_stats` 是否有段落/关系 2. 检查 `chat_id`、`person_id` 过滤条件是否过严 3. 运行 `runtime_self_check.py --json` 确认 embedding 维度无误 +4. 若返回包含 `error` 字段,优先按错误提示修正 mode/时间参数 ### Q2: 启动时报向量维度不一致 diff --git a/plugins/A_memorix/README.md b/plugins/A_memorix/README.md index 1afb1b5f..2c59629a 100644 --- a/plugins/A_memorix/README.md +++ b/plugins/A_memorix/README.md @@ -59,6 +59,20 @@ A_Memorix 是面向 MaiBot SDK 的 `memory_provider` 插件。 | `memory_v5_admin` | `status/recycle_bin/restore/reinforce/weaken/remember_forever/forget` | | `memory_delete_admin` | `preview/execute/restore/get_operation/list_operations/purge` | +### 检索模式语义(严格) + +- `search_memory.mode` 仅支持:`search/time/hybrid/episode/aggregate`。 +- `semantic` 模式已移除,传入将返回参数错误。 +- `time/hybrid` 模式必须提供 `time_start` 或 `time_end`,否则返回错误,不再静默按“未命中”处理。 + +### 删除返回语义(source 模式) + +- `requested_source_count`:请求删除的 source 数。 +- `matched_source_count`:实际命中的 source 数(存在活跃段落)。 +- `deleted_paragraph_count`:实际删除段落数。 +- `deleted_count`:与实际删除对象一致;在 `source` 模式下等于 `deleted_paragraph_count`。 +- `success`:基于实际命中与实际删除判定,未命中 source 时返回 `false`。 + ## 调用示例 ```json diff --git a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py index 439afd3d..93c11bf7 100644 --- a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py +++ b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py @@ -36,7 +36,7 @@ logger = get_logger("A_Memorix.SDKMemoryKernel") class KernelSearchRequest: query: str = "" limit: int = 5 - mode: str = "hybrid" + mode: str = "search" chat_id: str = "" person_id: str = "" time_start: Optional[str | float] = None @@ -722,9 +722,19 @@ class SDKMemoryKernel: assert self.episode_retriever is not None assert self.aggregate_query_service is not None - mode = str(request.mode or "hybrid").strip().lower() or "hybrid" + mode = str(request.mode or "search").strip().lower() or "search" query = str(request.query or "").strip() limit = max(1, int(request.limit or 5)) + supported_modes = {"search", "time", "hybrid", "episode", "aggregate"} + if mode not in supported_modes: + return { + "summary": "", + "hits": [], + "error": ( + f"不支持的检索模式: {mode}(仅支持 search/time/hybrid/episode/aggregate," + "semantic 已移除)" + ), + } try: time_window = self._normalize_search_time_window(request.time_start, request.time_end) except ValueError as exc: @@ -760,7 +770,7 @@ class SDKMemoryKernel: filtered = self._filter_hits(hits, request.person_id) return {"summary": self._summary(filtered), "hits": filtered} - query_type = "search" if mode in {"search", "semantic"} else mode + query_type = mode runtime_config = self._build_runtime_config() result = await SearchExecutionService.execute( retriever=self.retriever, @@ -2691,7 +2701,13 @@ class SDKMemoryKernel: counts = {"relations": 0, "paragraphs": 0, "entities": 0, "sources": 0} vector_ids: List[str] = [] sources: List[str] = [] - target_hashes: Dict[str, List[str]] = {"relations": [], "paragraphs": [], "entities": [], "sources": []} + target_hashes: Dict[str, List[str]] = { + "relations": [], + "paragraphs": [], + "entities": [], + "sources": [], + "matched_sources": [], + } if act_mode == "relation": relation_rows = [row for row in (self.metadata_store.get_relation(hash_value) for hash_value in self._resolve_relation_hashes(str(normalized_selector.get("query", "") or ""))) if row] @@ -2721,21 +2737,26 @@ class SDKMemoryKernel: if act_mode == 
"source": source_tokens = self._resolve_source_targets(normalized_selector) target_hashes["sources"] = source_tokens - counts["sources"] = len(source_tokens) + counts["requested_sources"] = len(source_tokens) + matched_source_tokens: List[str] = [] for source in source_tokens: - sources.append(source) - paragraph_rows.extend( - self.metadata_store.query( - """ - SELECT * - FROM paragraphs - WHERE source = ? - AND (is_deleted IS NULL OR is_deleted = 0) - ORDER BY created_at ASC - """, - (source,), - ) + source_rows = self.metadata_store.query( + """ + SELECT * + FROM paragraphs + WHERE source = ? + AND (is_deleted IS NULL OR is_deleted = 0) + ORDER BY created_at ASC + """, + (source,), ) + if source_rows: + matched_source_tokens.append(source) + sources.append(source) + paragraph_rows.extend(source_rows) + target_hashes["matched_sources"] = matched_source_tokens + counts["sources"] = len(matched_source_tokens) + counts["matched_sources"] = len(matched_source_tokens) else: paragraph_rows = self._resolve_paragraph_targets(normalized_selector, include_deleted=False) paragraph_hashes = self._tokens([row.get("hash", "") for row in paragraph_rows]) @@ -2797,9 +2818,14 @@ class SDKMemoryKernel: sources = self._tokens(sources) vector_ids = self._tokens(vector_ids) - primary_count = counts.get(f"{act_mode}s", 0) if act_mode != "source" else counts.get("sources", 0) + primary_count = counts.get(f"{act_mode}s", 0) if act_mode != "source" else counts.get("matched_sources", 0) + success = ( + primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0 + if act_mode != "source" + else (counts.get("matched_sources", 0) > 0 and counts.get("paragraphs", 0) > 0) + ) return { - "success": primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0, + "success": success, "mode": act_mode, "selector": normalized_selector, "items": items, @@ -2807,7 +2833,9 @@ class SDKMemoryKernel: "vector_ids": vector_ids, "sources": sources, "target_hashes": target_hashes, - "error": "" if (primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0) else "未命中可删除内容", + "requested_source_count": counts.get("requested_sources", 0) if act_mode == "source" else 0, + "matched_source_count": counts.get("matched_sources", 0) if act_mode == "source" else 0, + "error": "" if success else "未命中可删除内容", } async def _preview_delete_action(self, *, mode: str, selector: Any) -> Dict[str, Any]: @@ -2826,6 +2854,8 @@ class SDKMemoryKernel: "mode": plan.get("mode"), "selector": plan.get("selector"), "counts": plan.get("counts", {}), + "requested_source_count": int(plan.get("requested_source_count", 0) or 0), + "matched_source_count": int(plan.get("matched_source_count", 0) or 0), "sources": plan.get("sources", []), "vector_ids": plan.get("vector_ids", []), "items": preview_items, @@ -2852,7 +2882,8 @@ class SDKMemoryKernel: paragraph_hashes = self._tokens((plan.get("target_hashes") or {}).get("paragraphs")) entity_hashes = self._tokens((plan.get("target_hashes") or {}).get("entities")) relation_hashes = self._tokens((plan.get("target_hashes") or {}).get("relations")) - source_tokens = self._tokens((plan.get("target_hashes") or {}).get("sources")) + requested_source_tokens = self._tokens((plan.get("target_hashes") or {}).get("sources")) + matched_source_tokens = self._tokens((plan.get("target_hashes") or {}).get("matched_sources")) try: if paragraph_hashes: @@ -2866,8 +2897,8 @@ class SDKMemoryKernel: tuple(paragraph_hashes), ) 
self.metadata_store.delete_external_memory_refs_by_paragraphs(paragraph_hashes) - if act_mode == "source" and source_tokens: - for source in source_tokens: + if act_mode == "source" and matched_source_tokens: + for source in matched_source_tokens: self.metadata_store.replace_episodes_for_source(source, []) if entity_hashes: @@ -2903,7 +2934,7 @@ class SDKMemoryKernel: self._rebuild_graph_from_metadata() self._persist() deleted_count = ( - len(source_tokens) + len(paragraph_hashes) if act_mode == "source" else len(paragraph_hashes) if act_mode == "paragraph" @@ -2911,8 +2942,9 @@ class SDKMemoryKernel: if act_mode == "entity" else len(relation_hashes) ) + success = bool(deleted_count > 0) result = { - "success": True, + "success": success, "mode": act_mode, "operation_id": operation.get("operation_id", ""), "counts": plan.get("counts", {}), @@ -2922,8 +2954,12 @@ class SDKMemoryKernel: "deleted_relation_count": len(relation_hashes), } if act_mode == "source": - result["deleted_source_count"] = len(source_tokens) + result["requested_source_count"] = len(requested_source_tokens) + result["matched_source_count"] = len(matched_source_tokens) + result["deleted_source_count"] = len(matched_source_tokens) result["deleted_paragraph_count"] = len(paragraph_hashes) + if not success: + result["error"] = "未命中可删除内容" return result except Exception as exc: conn.rollback() diff --git a/plugins/A_memorix/core/utils/search_execution_service.py b/plugins/A_memorix/core/utils/search_execution_service.py index efb2093f..7df243af 100644 --- a/plugins/A_memorix/core/utils/search_execution_service.py +++ b/plugins/A_memorix/core/utils/search_execution_service.py @@ -48,7 +48,7 @@ class SearchExecutionRequest: stream_id: Optional[str] = None group_id: Optional[str] = None user_id: Optional[str] = None - query_type: str = "search" # search|semantic|time|hybrid + query_type: str = "search" # search|time|hybrid query: str = "" top_k: Optional[int] = None time_from: Optional[str] = None @@ -100,10 +100,7 @@ class SearchExecutionService: @staticmethod def _normalize_query_type(raw_query_type: str) -> str: - query_type = _sanitize_text(raw_query_type).lower() or "search" - if query_type == "semantic": - return "search" - return query_type + return _sanitize_text(raw_query_type).lower() or "search" @staticmethod def _resolve_runtime_component( diff --git a/plugins/A_memorix/plugin.py b/plugins/A_memorix/plugin.py index 390515f5..841106a4 100644 --- a/plugins/A_memorix/plugin.py +++ b/plugins/A_memorix/plugin.py @@ -75,7 +75,7 @@ class AMemorixPlugin(MaiBotPlugin): self, query: str = "", limit: int = 5, - mode: str = "hybrid", + mode: str = "search", chat_id: str = "", person_id: str = "", time_start: str | float | None = None, diff --git a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py index 4d47f609..3136f8be 100644 --- a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py +++ b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py @@ -71,11 +71,21 @@ class KnowledgeFetcher: "respect_filter": True, } result = await memory_service.search(query, **search_kwargs) + if not result.success: + logger.warning( + f"[私聊][{self.private_name}]长期记忆查询失败: {result.error or '未知错误'}" + ) + return f"长期记忆检索失败:{result.error or '未知错误'}" if not result.filtered and not result.hits and search_kwargs["person_id"]: fallback_kwargs = dict(search_kwargs) fallback_kwargs["person_id"] = "" logger.debug(f"[私聊][{self.private_name}]人物过滤未命中,退回仅按会话检索长期记忆") result = await memory_service.search(query, 
**fallback_kwargs) + if not result.success: + logger.warning( + f"[私聊][{self.private_name}]长期记忆回退查询失败: {result.error or '未知错误'}" + ) + return f"长期记忆检索失败:{result.error or '未知错误'}" knowledge_info = result.to_text(limit=5) if result.filtered: logger.debug(f"[私聊][{self.private_name}]长期记忆查询被聊天过滤策略跳过") diff --git a/src/memory_system/retrieval_tools/query_long_term_memory.py b/src/memory_system/retrieval_tools/query_long_term_memory.py index 57202f34..bf39c0cd 100644 --- a/src/memory_system/retrieval_tools/query_long_term_memory.py +++ b/src/memory_system/retrieval_tools/query_long_term_memory.py @@ -169,6 +169,9 @@ def _format_tool_result( query: str, time_range_text: str = "", ) -> str: + if not result.success: + return f"长期记忆查询失败:{result.error or '未知错误'}" + if not result.hits: if mode == "time": return f"在指定时间范围内未找到相关的长期记忆{time_range_text}" @@ -225,7 +228,7 @@ async def query_long_term_memory( return str(exc) time_range_text = f"(时间范围:{time_start_text} 至 {time_end_text})" - backend_mode = "hybrid" if normalized_mode == "search" else normalized_mode + backend_mode = normalized_mode try: result = await memory_service.search( diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index 06ddf5de..c8139c16 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -675,6 +675,8 @@ class RuntimeDataCapabilityMixin: from src.services.memory_service import memory_service result = await memory_service.search(query, limit=limit_value) + if not result.success: + return {"success": False, "error": result.error or "长期记忆检索失败"} knowledge_info = result.to_text(limit=limit_value) content = f"你知道这些知识: {knowledge_info}" if knowledge_info else f"你不太了解有关{query}的知识" return {"success": True, "content": content} diff --git a/src/services/memory_service.py b/src/services/memory_service.py index 6cbecd63..04f08ff6 100644 --- a/src/services/memory_service.py +++ b/src/services/memory_service.py @@ -41,6 +41,8 @@ class MemorySearchResult: summary: str = "" hits: List[MemoryHit] = field(default_factory=list) filtered: bool = False + success: bool = True + error: str = "" def to_text(self, limit: int = 5) -> str: if not self.hits: @@ -55,6 +57,8 @@ class MemorySearchResult: def to_dict(self) -> Dict[str, Any]: return { + "success": self.success, + "error": self.error, "summary": self.summary, "hits": [item.to_dict() for item in self.hits], "filtered": self.filtered, @@ -92,13 +96,33 @@ class MemoryService: runtime = get_plugin_runtime_manager() if not runtime.is_running: raise RuntimeError("plugin_runtime 未启动") - return await runtime.invoke_plugin( + response = await runtime.invoke_plugin( method="plugin.invoke_tool", plugin_id=PLUGIN_ID, component_name=component_name, args=args or {}, timeout_ms=max(1000, int(timeout_ms or 30000)), ) + # 兼容新旧运行时返回: + # - 旧版: 直接返回工具结果(dict) + # - 新版: 返回 Envelope,工具结果在 payload.result 中 + if isinstance(response, dict): + return response + payload = getattr(response, "payload", None) + if isinstance(payload, dict): + if isinstance(payload.get("result"), dict): + return payload["result"] + return payload + model_dump = getattr(response, "model_dump", None) + if callable(model_dump): + dumped = model_dump() + if isinstance(dumped, dict): + inner_payload = dumped.get("payload") + if isinstance(inner_payload, dict): + if isinstance(inner_payload.get("result"), dict): + return inner_payload["result"] + return inner_payload + return response async def _invoke_admin( self, @@ -134,7 +158,7 @@ class MemoryService: 
@staticmethod def _coerce_search_result(payload: Any) -> MemorySearchResult: if not isinstance(payload, dict): - return MemorySearchResult() + return MemorySearchResult(success=False, error="invalid_payload") hits: List[MemoryHit] = [] for item in payload.get("hits", []) or []: if not isinstance(item, dict): @@ -158,10 +182,15 @@ class MemoryService: title=str(item.get("title", "") or ""), ) ) + success_raw = payload.get("success") + error = str(payload.get("error", "") or "") + success = (not bool(error)) if success_raw is None else bool(success_raw) return MemorySearchResult( summary=str(payload.get("summary", "") or ""), hits=hits, filtered=bool(payload.get("filtered", False)), + success=success, + error=error, ) @staticmethod @@ -179,7 +208,7 @@ class MemoryService: query: str, *, limit: int = 5, - mode: str = "hybrid", + mode: str = "search", chat_id: str = "", person_id: str = "", time_start: str | float | None = None, @@ -212,7 +241,7 @@ class MemoryService: return self._coerce_search_result(payload) except Exception as exc: logger.warning("长期记忆搜索失败: %s", exc) - return MemorySearchResult() + return MemorySearchResult(success=False, error=str(exc)) async def ingest_summary( self, From 9d1977446b8dbfe4b0eda1c9d0d72f007aa36608 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 24 Mar 2026 13:45:12 +0800 Subject: [PATCH 05/14] =?UTF-8?q?feat:=E5=B0=86A=5Fmemorix=E5=AF=BC?= =?UTF-8?q?=E5=85=A5=E4=BF=AE=E6=94=B9=E4=B8=BAsubmodule=E7=9A=84=E5=AF=BC?= =?UTF-8?q?=E5=85=A5=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/crowdin-bootstrap.yml | 1 + .github/workflows/crowdin-sync.yml | 3 + .github/workflows/docker-image-dev.yml | 24 +++ .github/workflows/docker-image-main.yml | 24 +++ .github/workflows/i18n-validate.yml | 2 + .github/workflows/precheck.yml | 58 ++++++ .github/workflows/publish-webui-dist.yml | 2 + .github/workflows/ruff-pr.yml | 1 + .github/workflows/ruff.yml | 1 + .gitmodules | 4 + docs-src/MAINTAIN_A_MEMORIX_SUBMODULE.md | 38 ++++ plugins/A_memorix/.gitattributes | 2 + plugins/A_memorix/.gitignore | 245 +++++++++++++++++++++++ scripts/run.sh | 8 + 14 files changed, 413 insertions(+) create mode 100644 .gitmodules create mode 100644 docs-src/MAINTAIN_A_MEMORIX_SUBMODULE.md create mode 100644 plugins/A_memorix/.gitattributes create mode 100644 plugins/A_memorix/.gitignore diff --git a/.github/workflows/crowdin-bootstrap.yml b/.github/workflows/crowdin-bootstrap.yml index 99e6895e..bb7226cc 100644 --- a/.github/workflows/crowdin-bootstrap.yml +++ b/.github/workflows/crowdin-bootstrap.yml @@ -36,6 +36,7 @@ jobs: - uses: actions/checkout@v4 with: ref: ${{ inputs.base_branch }} + submodules: recursive - name: Bootstrap committed target translations into Crowdin uses: crowdin/github-action@v2 diff --git a/.github/workflows/crowdin-sync.yml b/.github/workflows/crowdin-sync.yml index 126cfd9e..c01321e5 100644 --- a/.github/workflows/crowdin-sync.yml +++ b/.github/workflows/crowdin-sync.yml @@ -25,6 +25,8 @@ jobs: steps: - uses: actions/checkout@v4 + with: + submodules: recursive - name: Sync translations with Crowdin uses: crowdin/github-action@v2 with: @@ -57,6 +59,7 @@ jobs: - uses: actions/checkout@v4 with: ref: ${{ matrix.base_branch }} + submodules: recursive - name: Sync scheduled translations with Crowdin uses: crowdin/github-action@v2 with: diff --git a/.github/workflows/docker-image-dev.yml b/.github/workflows/docker-image-dev.yml index 
63fdafee..da43d3a2 100644 --- a/.github/workflows/docker-image-dev.yml +++ b/.github/workflows/docker-image-dev.yml @@ -25,6 +25,7 @@ jobs: with: ref: dev fetch-depth: 0 + submodules: recursive # Clone required dependencies # - name: Clone maim_message @@ -33,6 +34,17 @@ jobs: - name: Clone lpmm run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM + - name: Verify A_Memorix submodule alignment + run: | + set -euo pipefail + git -C plugins/A_memorix fetch origin MaiBot_branch --depth=1 + LOCAL_SHA=$(git -C plugins/A_memorix rev-parse HEAD) + REMOTE_SHA=$(git -C plugins/A_memorix rev-parse FETCH_HEAD) + if [ "${LOCAL_SHA}" != "${REMOTE_SHA}" ]; then + echo "plugins/A_memorix is stale: local=${LOCAL_SHA}, remote=${REMOTE_SHA}" >&2 + exit 1 + fi + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: @@ -79,6 +91,7 @@ jobs: with: ref: dev fetch-depth: 0 + submodules: recursive # Clone required dependencies # - name: Clone maim_message @@ -87,6 +100,17 @@ jobs: - name: Clone lpmm run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM + - name: Verify A_Memorix submodule alignment + run: | + set -euo pipefail + git -C plugins/A_memorix fetch origin MaiBot_branch --depth=1 + LOCAL_SHA=$(git -C plugins/A_memorix rev-parse HEAD) + REMOTE_SHA=$(git -C plugins/A_memorix rev-parse FETCH_HEAD) + if [ "${LOCAL_SHA}" != "${REMOTE_SHA}" ]; then + echo "plugins/A_memorix is stale: local=${LOCAL_SHA}, remote=${REMOTE_SHA}" >&2 + exit 1 + fi + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: diff --git a/.github/workflows/docker-image-main.yml b/.github/workflows/docker-image-main.yml index 3d9b14ab..d7c21bd5 100644 --- a/.github/workflows/docker-image-main.yml +++ b/.github/workflows/docker-image-main.yml @@ -29,6 +29,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + submodules: recursive # Clone required dependencies # - name: Clone maim_message @@ -37,6 +38,17 @@ jobs: - name: Clone lpmm run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM + - name: Verify A_Memorix submodule alignment + run: | + set -euo pipefail + git -C plugins/A_memorix fetch origin MaiBot_branch --depth=1 + LOCAL_SHA=$(git -C plugins/A_memorix rev-parse HEAD) + REMOTE_SHA=$(git -C plugins/A_memorix rev-parse FETCH_HEAD) + if [ "${LOCAL_SHA}" != "${REMOTE_SHA}" ]; then + echo "plugins/A_memorix is stale: local=${LOCAL_SHA}, remote=${REMOTE_SHA}" >&2 + exit 1 + fi + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: @@ -82,6 +94,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + submodules: recursive # Clone required dependencies # - name: Clone maim_message @@ -90,6 +103,17 @@ jobs: - name: Clone lpmm run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM + - name: Verify A_Memorix submodule alignment + run: | + set -euo pipefail + git -C plugins/A_memorix fetch origin MaiBot_branch --depth=1 + LOCAL_SHA=$(git -C plugins/A_memorix rev-parse HEAD) + REMOTE_SHA=$(git -C plugins/A_memorix rev-parse FETCH_HEAD) + if [ "${LOCAL_SHA}" != "${REMOTE_SHA}" ]; then + echo "plugins/A_memorix is stale: local=${LOCAL_SHA}, remote=${REMOTE_SHA}" >&2 + exit 1 + fi + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 with: diff --git a/.github/workflows/i18n-validate.yml b/.github/workflows/i18n-validate.yml index 577d7808..b737eb2f 100644 --- a/.github/workflows/i18n-validate.yml +++ b/.github/workflows/i18n-validate.yml @@ -31,6 +31,8 @@ jobs: steps: - uses: 
actions/checkout@v4 + with: + submodules: recursive - uses: actions/setup-python@v5 with: python-version: "3.12" diff --git a/.github/workflows/precheck.yml b/.github/workflows/precheck.yml index bf6f9529..ba5c4647 100644 --- a/.github/workflows/precheck.yml +++ b/.github/workflows/precheck.yml @@ -7,6 +7,63 @@ permissions: issues: write jobs: + submodule-alignment-check: + runs-on: ubuntu-24.04 + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha }} + submodules: recursive + - name: Validate A_Memorix submodule strict alignment + env: + SUBMODULE_PATH: plugins/A_memorix + SUBMODULE_URL: https://github.com/A-Dawn/A_memorix.git + SUBMODULE_BRANCH: MaiBot_branch + run: | + set -euo pipefail + + if [ ! -f .gitmodules ]; then + echo "::error::.gitmodules is missing." + exit 1 + fi + + actual_path=$(git config -f .gitmodules --get submodule.plugins/A_memorix.path || true) + actual_url=$(git config -f .gitmodules --get submodule.plugins/A_memorix.url || true) + actual_branch=$(git config -f .gitmodules --get submodule.plugins/A_memorix.branch || true) + + if [ "${actual_path}" != "${SUBMODULE_PATH}" ]; then + echo "::error::submodule path mismatch: expected ${SUBMODULE_PATH}, got ${actual_path:-}" + exit 1 + fi + + if [ "${actual_url}" != "${SUBMODULE_URL}" ]; then + echo "::error::submodule url mismatch: expected ${SUBMODULE_URL}, got ${actual_url:-}" + exit 1 + fi + + if [ "${actual_branch}" != "${SUBMODULE_BRANCH}" ]; then + echo "::error::submodule branch mismatch: expected ${SUBMODULE_BRANCH}, got ${actual_branch:-}" + exit 1 + fi + + if [ ! -f "${SUBMODULE_PATH}/_manifest.json" ]; then + echo "::error::${SUBMODULE_PATH}/_manifest.json is missing. Run: git submodule update --init --recursive" + exit 1 + fi + + git -C "${SUBMODULE_PATH}" remote set-url origin "${SUBMODULE_URL}" + git -C "${SUBMODULE_PATH}" fetch origin "${SUBMODULE_BRANCH}" --depth=1 + + local_sha=$(git -C "${SUBMODULE_PATH}" rev-parse HEAD) + remote_sha=$(git -C "${SUBMODULE_PATH}" rev-parse FETCH_HEAD) + + if [ "${local_sha}" != "${remote_sha}" ]; then + echo "::error::submodule ${SUBMODULE_PATH} must match origin/${SUBMODULE_BRANCH} HEAD." 
+ echo "local=${local_sha} remote=${remote_sha}" + echo "Please run: git submodule update --remote --recursive ${SUBMODULE_PATH} && git add ${SUBMODULE_PATH} && git commit" + exit 1 + fi conflict-check: runs-on: ubuntu-24.04 outputs: @@ -16,6 +73,7 @@ jobs: with: fetch-depth: 0 ref: ${{ github.event.pull_request.head.sha }} + submodules: recursive - name: Check Conflicts id: check-conflicts env: diff --git a/.github/workflows/publish-webui-dist.yml b/.github/workflows/publish-webui-dist.yml index 57424fb3..38526297 100644 --- a/.github/workflows/publish-webui-dist.yml +++ b/.github/workflows/publish-webui-dist.yml @@ -19,6 +19,8 @@ jobs: environment: webui steps: - uses: actions/checkout@v4 + with: + submodules: recursive - name: Setup Bun uses: oven-sh/setup-bun@v2 diff --git a/.github/workflows/ruff-pr.yml b/.github/workflows/ruff-pr.yml index 1176eb0c..06d1fc59 100644 --- a/.github/workflows/ruff-pr.yml +++ b/.github/workflows/ruff-pr.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + submodules: recursive - name: Install Ruff and Run Checks uses: astral-sh/ruff-action@v3 with: diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 22401da3..742c410d 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -23,6 +23,7 @@ jobs: with: fetch-depth: 0 ref: ${{ github.head_ref || github.ref_name }} + submodules: recursive - name: Install Ruff and Run Checks uses: astral-sh/ruff-action@v3 with: diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..3ddc6a14 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "plugins/A_memorix"] + path = plugins/A_memorix + url = https://github.com/A-Dawn/A_memorix.git + branch = MaiBot_branch diff --git a/docs-src/MAINTAIN_A_MEMORIX_SUBMODULE.md b/docs-src/MAINTAIN_A_MEMORIX_SUBMODULE.md new file mode 100644 index 00000000..ad0d7552 --- /dev/null +++ b/docs-src/MAINTAIN_A_MEMORIX_SUBMODULE.md @@ -0,0 +1,38 @@ +# A_Memorix 子模块维护说明(维护者内部文档) + +> 本文档用于维护者,不面向普通用户。 + +## 1. 基本约束 +- 子模块路径固定:`plugins/A_memorix` +- 子模块仓库固定:`https://github.com/A-Dawn/A_memorix.git` +- 子模块分支固定:`MaiBot_branch` +- 强约束:主仓内 `plugins/A_memorix` 指针必须等于远端 `origin/MaiBot_branch` 最新 HEAD + +## 2. 首次拉取/恢复子模块 +```bash +git submodule update --init --recursive +``` + +若目录为空或缺少 `_manifest.json`,先执行上面的命令再排查其他问题。 + +## 3. 维护者更新流程 +1. 先在外部仓 `MaiBot_branch` 完成目标功能合入。 +2. 在主仓执行: +```bash +git submodule update --remote --recursive plugins/A_memorix +git add plugins/A_memorix .gitmodules +git commit -m "chore(submodule): bump A_memorix" +``` + +## 4. CI 严格校验说明 +- PR Precheck 会校验: + - `.gitmodules` 的 path/url/branch 必须匹配固定值 + - 子模块指针必须等于远端 `MaiBot_branch` 最新 HEAD +- Docker 构建工作流在构建前也会执行同样的 fail-fast 对齐检查 + +## 5. 
回滚策略 +- 回滚主仓提交会同时回滚子模块指针。 +- 但若回滚后的指针不再是远端 `MaiBot_branch` 最新 HEAD,CI 会阻断。 +- 处理方式: + - 先在外部仓移动/回滚 `MaiBot_branch` 到目标提交,再重跑; + - 或按团队流程申请一次性 CI 豁免。 diff --git a/plugins/A_memorix/.gitattributes b/plugins/A_memorix/.gitattributes new file mode 100644 index 00000000..dfe07704 --- /dev/null +++ b/plugins/A_memorix/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/plugins/A_memorix/.gitignore b/plugins/A_memorix/.gitignore new file mode 100644 index 00000000..bb349827 --- /dev/null +++ b/plugins/A_memorix/.gitignore @@ -0,0 +1,245 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor.`.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data +# refer to https://docs.cursor.com/context/ignore-files +.cursorignore +.cursorindexingignore + +# Python +__pycache__/ +*.pyc +*.pyo +*.pyd +*.egg-info/ + +# Data & Storage (Privacy & Runtime) +data/ +logs/ + +# Deprecated / Cleanup (Avoid uploading junk) +deprecated/ + +# OS / System +.DS_Store +Thumbs.db +ehthumbs.db + +# IDE settings +.idea/ +.vscode/ + +# Temporary Verification Scripts +verify_*.py +config.toml + +# Test Artifacts & Generated Files +MagicMock/ +benchmark_output.txt +e2e_debug.log +e2e_error.log +full_diff.txt + +# Large Test Data Files +机娘导论-openie.json +scripts/机娘导论-openie.json + +# A_memorix recall/tuning generated artifacts +artifacts/ +scripts/run_arc_light_recall_pipeline.py + +# Compressed Data Archives +data.zip +scripts/full_feature_smoke_test.py +ACL2026_DEMO_EVAL.md +.probe_write +tests/ +temp_verify_v5_data/metadata/metadata.db +sql2/t.db +sql2/t.db-journal +scripts/test.json +scripts/test1.json +scripts/test-sample.json +USAGE_ARCHITECTURE.md +scripts/test_conversion.py +scripts/debug_graph_vis.py +/.tmp_feature_e2e_real +/.tmp_sparse_tests +/.tmp_test_probe +/.tmp_test_sqlite +/.tmp_testdata +/scripts/tmp diff --git a/scripts/run.sh b/scripts/run.sh index 7862e2ab..954d2392 100755 --- a/scripts/run.sh +++ b/scripts/run.sh @@ -830,6 +830,14 @@ run_installation() { echo -e "${RED}克隆MaiCore仓库失败!${RESET}" exit 1 } + echo -e "${GREEN}初始化MaiCore子模块...${RESET}" + # 使用与主仓一致的 GitHub 加速前缀,避免子模块直连 github.com 失败 + git -C MaiBot config submodule.plugins/A_memorix.url "$GITHUB_REPO/A-Dawn/A_memorix.git" + git -C MaiBot submodule sync --recursive + git -C MaiBot submodule update --init --recursive || { + echo -e "${RED}初始化MaiCore子模块失败!${RESET}" + exit 1 + } echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}" git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || { From 308448171c36001a32a64f91ecad6e2ca7ce9d79 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:02:08 +0800 Subject: [PATCH 06/14] fix: convert A_memorix tree entries to submodule gitlink --- plugins/A_memorix | 1 + plugins/A_memorix/.gitattributes | 2 - plugins/A_memorix/.gitignore | 245 - plugins/A_memorix/CHANGELOG.md | 718 -- plugins/A_memorix/CONFIG_REFERENCE.md | 292 - plugins/A_memorix/IMPORT_GUIDE.md | 335 - plugins/A_memorix/LICENSE | 661 -- plugins/A_memorix/LICENSE-MAIBOT-GPL.md | 22 
- plugins/A_memorix/QUICK_START.md | 216 - plugins/A_memorix/README.md | 230 - plugins/A_memorix/__init__.py | 12 - plugins/A_memorix/_manifest.json | 107 - plugins/A_memorix/core/__init__.py | 84 - plugins/A_memorix/core/embedding/__init__.py | 18 - .../A_memorix/core/embedding/api_adapter.py | 368 -- plugins/A_memorix/core/embedding/manager.py | 510 -- plugins/A_memorix/core/embedding/presets.py | 72 - plugins/A_memorix/core/retrieval/__init__.py | 54 - plugins/A_memorix/core/retrieval/dual_path.py | 1796 ----- .../core/retrieval/graph_relation_recall.py | 272 - plugins/A_memorix/core/retrieval/pagerank.py | 482 -- .../A_memorix/core/retrieval/sparse_bm25.py | 401 -- plugins/A_memorix/core/retrieval/threshold.py | 450 -- plugins/A_memorix/core/runtime/__init__.py | 16 - .../core/runtime/lifecycle_orchestrator.py | 268 - .../core/runtime/sdk_memory_kernel.py | 3162 --------- .../runtime/search_runtime_initializer.py | 240 - plugins/A_memorix/core/storage/__init__.py | 53 - plugins/A_memorix/core/storage/graph_store.py | 1448 ----- .../A_memorix/core/storage/knowledge_types.py | 183 - .../A_memorix/core/storage/metadata_store.py | 5748 ----------------- .../A_memorix/core/storage/type_detection.py | 137 - .../A_memorix/core/storage/vector_store.py | 776 --- plugins/A_memorix/core/strategies/__init__.py | 0 plugins/A_memorix/core/strategies/base.py | 89 - plugins/A_memorix/core/strategies/factual.py | 98 - .../A_memorix/core/strategies/narrative.py | 126 - plugins/A_memorix/core/strategies/quote.py | 52 - plugins/A_memorix/core/utils/__init__.py | 33 - .../core/utils/aggregate_query_service.py | 360 -- .../core/utils/episode_retrieval_service.py | 182 - .../utils/episode_segmentation_service.py | 304 - .../A_memorix/core/utils/episode_service.py | 558 -- plugins/A_memorix/core/utils/hash.py | 129 - .../A_memorix/core/utils/import_payloads.py | 110 - plugins/A_memorix/core/utils/io.py | 84 - plugins/A_memorix/core/utils/matcher.py | 89 - plugins/A_memorix/core/utils/monitor.py | 189 - .../core/utils/path_fallback_service.py | 165 - .../core/utils/person_profile_service.py | 554 -- .../A_memorix/core/utils/plugin_id_policy.py | 27 - plugins/A_memorix/core/utils/quantization.py | 344 - .../A_memorix/core/utils/relation_query.py | 121 - .../core/utils/relation_write_service.py | 166 - .../core/utils/retrieval_tuning_manager.py | 1857 ------ .../core/utils/runtime_self_check.py | 218 - .../core/utils/search_execution_service.py | 439 -- .../core/utils/search_postprocess.py | 90 - .../A_memorix/core/utils/summary_importer.py | 425 -- plugins/A_memorix/core/utils/time_parser.py | 170 - .../core/utils/web_import_manager.py | 3522 ---------- plugins/A_memorix/plugin.py | 273 - plugins/A_memorix/requirements.txt | 52 - .../scripts/audit_vector_consistency.py | 213 - .../scripts/backfill_relation_vectors.py | 270 - .../scripts/backfill_temporal_metadata.py | 73 - plugins/A_memorix/scripts/convert_lpmm.py | 540 -- plugins/A_memorix/scripts/import_lpmm_json.py | 172 - .../A_memorix/scripts/migrate_chat_history.py | 110 - .../scripts/migrate_maibot_memory.py | 1714 ----- .../scripts/migrate_person_memory_points.py | 120 - .../A_memorix/scripts/process_knowledge.py | 728 --- plugins/A_memorix/scripts/rebuild_episodes.py | 127 - .../scripts/release_vnext_migrate.py | 731 --- .../A_memorix/scripts/runtime_self_check.py | 152 - 75 files changed, 1 insertion(+), 35154 deletions(-) create mode 160000 plugins/A_memorix delete mode 100644 plugins/A_memorix/.gitattributes delete mode 100644 
plugins/A_memorix/.gitignore delete mode 100644 plugins/A_memorix/CHANGELOG.md delete mode 100644 plugins/A_memorix/CONFIG_REFERENCE.md delete mode 100644 plugins/A_memorix/IMPORT_GUIDE.md delete mode 100644 plugins/A_memorix/LICENSE delete mode 100644 plugins/A_memorix/LICENSE-MAIBOT-GPL.md delete mode 100644 plugins/A_memorix/QUICK_START.md delete mode 100644 plugins/A_memorix/README.md delete mode 100644 plugins/A_memorix/__init__.py delete mode 100644 plugins/A_memorix/_manifest.json delete mode 100644 plugins/A_memorix/core/__init__.py delete mode 100644 plugins/A_memorix/core/embedding/__init__.py delete mode 100644 plugins/A_memorix/core/embedding/api_adapter.py delete mode 100644 plugins/A_memorix/core/embedding/manager.py delete mode 100644 plugins/A_memorix/core/embedding/presets.py delete mode 100644 plugins/A_memorix/core/retrieval/__init__.py delete mode 100644 plugins/A_memorix/core/retrieval/dual_path.py delete mode 100644 plugins/A_memorix/core/retrieval/graph_relation_recall.py delete mode 100644 plugins/A_memorix/core/retrieval/pagerank.py delete mode 100644 plugins/A_memorix/core/retrieval/sparse_bm25.py delete mode 100644 plugins/A_memorix/core/retrieval/threshold.py delete mode 100644 plugins/A_memorix/core/runtime/__init__.py delete mode 100644 plugins/A_memorix/core/runtime/lifecycle_orchestrator.py delete mode 100644 plugins/A_memorix/core/runtime/sdk_memory_kernel.py delete mode 100644 plugins/A_memorix/core/runtime/search_runtime_initializer.py delete mode 100644 plugins/A_memorix/core/storage/__init__.py delete mode 100644 plugins/A_memorix/core/storage/graph_store.py delete mode 100644 plugins/A_memorix/core/storage/knowledge_types.py delete mode 100644 plugins/A_memorix/core/storage/metadata_store.py delete mode 100644 plugins/A_memorix/core/storage/type_detection.py delete mode 100644 plugins/A_memorix/core/storage/vector_store.py delete mode 100644 plugins/A_memorix/core/strategies/__init__.py delete mode 100644 plugins/A_memorix/core/strategies/base.py delete mode 100644 plugins/A_memorix/core/strategies/factual.py delete mode 100644 plugins/A_memorix/core/strategies/narrative.py delete mode 100644 plugins/A_memorix/core/strategies/quote.py delete mode 100644 plugins/A_memorix/core/utils/__init__.py delete mode 100644 plugins/A_memorix/core/utils/aggregate_query_service.py delete mode 100644 plugins/A_memorix/core/utils/episode_retrieval_service.py delete mode 100644 plugins/A_memorix/core/utils/episode_segmentation_service.py delete mode 100644 plugins/A_memorix/core/utils/episode_service.py delete mode 100644 plugins/A_memorix/core/utils/hash.py delete mode 100644 plugins/A_memorix/core/utils/import_payloads.py delete mode 100644 plugins/A_memorix/core/utils/io.py delete mode 100644 plugins/A_memorix/core/utils/matcher.py delete mode 100644 plugins/A_memorix/core/utils/monitor.py delete mode 100644 plugins/A_memorix/core/utils/path_fallback_service.py delete mode 100644 plugins/A_memorix/core/utils/person_profile_service.py delete mode 100644 plugins/A_memorix/core/utils/plugin_id_policy.py delete mode 100644 plugins/A_memorix/core/utils/quantization.py delete mode 100644 plugins/A_memorix/core/utils/relation_query.py delete mode 100644 plugins/A_memorix/core/utils/relation_write_service.py delete mode 100644 plugins/A_memorix/core/utils/retrieval_tuning_manager.py delete mode 100644 plugins/A_memorix/core/utils/runtime_self_check.py delete mode 100644 plugins/A_memorix/core/utils/search_execution_service.py delete mode 100644 
plugins/A_memorix/core/utils/search_postprocess.py delete mode 100644 plugins/A_memorix/core/utils/summary_importer.py delete mode 100644 plugins/A_memorix/core/utils/time_parser.py delete mode 100644 plugins/A_memorix/core/utils/web_import_manager.py delete mode 100644 plugins/A_memorix/plugin.py delete mode 100644 plugins/A_memorix/requirements.txt delete mode 100644 plugins/A_memorix/scripts/audit_vector_consistency.py delete mode 100644 plugins/A_memorix/scripts/backfill_relation_vectors.py delete mode 100644 plugins/A_memorix/scripts/backfill_temporal_metadata.py delete mode 100644 plugins/A_memorix/scripts/convert_lpmm.py delete mode 100644 plugins/A_memorix/scripts/import_lpmm_json.py delete mode 100644 plugins/A_memorix/scripts/migrate_chat_history.py delete mode 100644 plugins/A_memorix/scripts/migrate_maibot_memory.py delete mode 100644 plugins/A_memorix/scripts/migrate_person_memory_points.py delete mode 100644 plugins/A_memorix/scripts/process_knowledge.py delete mode 100644 plugins/A_memorix/scripts/rebuild_episodes.py delete mode 100644 plugins/A_memorix/scripts/release_vnext_migrate.py delete mode 100644 plugins/A_memorix/scripts/runtime_self_check.py diff --git a/plugins/A_memorix b/plugins/A_memorix new file mode 160000 index 00000000..5fc5026a --- /dev/null +++ b/plugins/A_memorix @@ -0,0 +1 @@ +Subproject commit 5fc5026a540c1cfd55a7b824b43aaeef867e3228 diff --git a/plugins/A_memorix/.gitattributes b/plugins/A_memorix/.gitattributes deleted file mode 100644 index dfe07704..00000000 --- a/plugins/A_memorix/.gitattributes +++ /dev/null @@ -1,2 +0,0 @@ -# Auto detect text files and perform LF normalization -* text=auto diff --git a/plugins/A_memorix/.gitignore b/plugins/A_memorix/.gitignore deleted file mode 100644 index bb349827..00000000 --- a/plugins/A_memorix/.gitignore +++ /dev/null @@ -1,245 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. 
-#Pipfile.lock - -# UV -# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -#uv.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/latest/usage/project/#working-with-version-control -.pdm.toml -.pdm-python -.pdm-build/ - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ - -# Ruff stuff: -.ruff_cache/ - -# PyPI configuration file -.pypirc - -# Cursor -# Cursor is an AI-powered code editor.`.cursorignore` specifies files/directories to -# exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data -# refer to https://docs.cursor.com/context/ignore-files -.cursorignore -.cursorindexingignore - -# Python -__pycache__/ -*.pyc -*.pyo -*.pyd -*.egg-info/ - -# Data & Storage (Privacy & Runtime) -data/ -logs/ - -# Deprecated / Cleanup (Avoid uploading junk) -deprecated/ - -# OS / System -.DS_Store -Thumbs.db -ehthumbs.db - -# IDE settings -.idea/ -.vscode/ - -# Temporary Verification Scripts -verify_*.py -config.toml - -# Test Artifacts & Generated Files -MagicMock/ -benchmark_output.txt -e2e_debug.log -e2e_error.log -full_diff.txt - -# Large Test Data Files -机娘导论-openie.json -scripts/机娘导论-openie.json - -# A_memorix recall/tuning generated artifacts -artifacts/ -scripts/run_arc_light_recall_pipeline.py - -# Compressed Data Archives -data.zip -scripts/full_feature_smoke_test.py -ACL2026_DEMO_EVAL.md -.probe_write -tests/ -temp_verify_v5_data/metadata/metadata.db -sql2/t.db -sql2/t.db-journal -scripts/test.json -scripts/test1.json -scripts/test-sample.json -USAGE_ARCHITECTURE.md -scripts/test_conversion.py -scripts/debug_graph_vis.py -/.tmp_feature_e2e_real -/.tmp_sparse_tests -/.tmp_test_probe -/.tmp_test_sqlite -/.tmp_testdata -/scripts/tmp diff --git a/plugins/A_memorix/CHANGELOG.md b/plugins/A_memorix/CHANGELOG.md deleted file mode 100644 index 772cff46..00000000 --- a/plugins/A_memorix/CHANGELOG.md +++ /dev/null @@ -1,718 +0,0 @@ -# 更新日志 (Changelog) - -## [2.0.0] - 2026-03-18 - -本次 `2.0.0` 为架构收敛版本,主线是 **SDK Tool 接口统一**、**管理工具能力补齐**、**元数据 schema 升级到 v8** 与 **文档口径同步到 2.0.0**。 - -### 🔖 版本信息 - -- 插件版本:`1.0.1` → `2.0.0` -- 元数据 schema:`7` → `8` - -### 🚀 重点能力 - -- Tool 接口统一: - - `plugin.py` 统一通过 `SDKMemoryKernel` 对外提供 Tool 能力。 - - 保留基础工具:`search_memory / ingest_summary / ingest_text / get_person_profile / maintain_memory / memory_stats`。 - - 新增管理工具:`memory_graph_admin / memory_source_admin / memory_episode_admin / memory_profile_admin / memory_runtime_admin / memory_import_admin / memory_tuning_admin / memory_v5_admin / memory_delete_admin`。 -- 检索与写入治理增强: - - 检索/写入链路支持 `respect_filter + user_id/group_id` 的聊天过滤语义。 - - `maintain_memory` 支持 `freeze` 与 `recycle_bin`,并统一到内核维护流程。 -- 导入与调优能力收敛: - - `memory_import_admin` 提供任务化导入能力(上传、粘贴、扫描、OpenIE、LPMM 转换、时序回填、MaiBot 迁移)。 - - `memory_tuning_admin` 提供检索调优任务(创建、轮次查看、回滚、apply_best、报告导出)。 -- V5 与删除运维: - - 新增 `memory_v5_admin`(`reinforce/weaken/remember_forever/forget/restore/status`)。 - - 新增 `memory_delete_admin`(`preview/execute/restore/list/get/purge`),支持操作审计与恢复。 - -### 🛠️ 存储与运行时 - -- `metadata_store` 升级到 `SCHEMA_VERSION = 8`。 -- 新增/完善外部引用与运维记录能力(包括 `external_memory_refs`、`memory_v5_operations`、`delete_operations` 相关数据结构)。 -- `SDKMemoryKernel` 增加统一后台任务编排(自动保存、Episode pending 处理、画像刷新、记忆维护)。 - -### 📚 文档同步 - -- `README.md`、`QUICK_START.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md` 已切换到 `2.0.0` 口径。 -- 文档主入口统一为 SDK Tool 工作流,不再以旧版 slash 命令作为主说明路径。 - -## [1.0.1] - 2026-03-07 - -本次 `1.0.1` 为 `1.0.0` 发布后的热修复版本,主线是 **图谱 WebUI 取数稳定性修复**、**大图过滤性能修复** 与 **真实检索调优链路稳定性修复**。 - -### 🔖 版本信息 - -- 插件版本:`1.0.0` → `1.0.1` -- 配置版本:`4.1.0`(不变) - -### 🛠️ 代码修复 - -- 图谱接口稳定性: - - 修复 `/api/graph` 在“磁盘已有图文件但运行时尚未装载入内存”场景下返回空图的问题,接口现在会自动补加载持久化图数据。 - - 修复问题数据集下 WebUI 打开图谱页时看似“没有任何节点”的现象;根因不是图数据消失,而是后端过滤路径过慢。 -- 图谱过滤性能: - - 优化 `/api/graph?exclude_leaf=true` 的叶子过滤逻辑,改为预计算 hub 邻接关系,不再对每个节点反复做高成本边权查询。 - - 优化 `GraphStore.get_neighbors()` 并补充入邻居访问能力,避免稠密矩阵展开导致的大图性能退化。 -- 检索调优稳定性: - - 修复真实调优任务在构建运行时配置时深拷贝 `plugin.config`,误复制注入的存储实例并触发 `cannot pickle '_thread.RLock' object` 的问题。 - - 调优评估改为跳过顶层运行时实例键,仅保留纯配置字段后再附加运行时依赖,真实 WebUI 调优任务可正常启动。 - -### 📚 文档同步 - -- 同步更新 
`README.md`、`CHANGELOG.md`、`CONFIG_REFERENCE.md` 与版本元数据(`plugin.py`、`__init__.py`、`_manifest.json`)。 -- README 新增 `v1.0.1` 修复说明,并补充“调优前先做 runtime self-check”的建议。 - -## [1.0.0] - 2026-03-06 - -本次 `1.0.0` 为主版本升级,主线是 **运行时架构模块化**、**Episode 情景记忆闭环**、**聚合检索与图召回增强**、**离线迁移 / 运行时自检 / 检索调优中心**。 - -### 🔖 版本信息 - -- 插件版本:`0.7.0` → `1.0.0` -- 配置版本:`4.1.0`(不变) - -### 🚀 重点能力 - -- 运行时重构: - - `plugin.py` 大幅瘦身,生命周期、后台任务、请求路由、检索运行时初始化拆分到 `core/runtime/*`。 - - 配置 schema 抽离到 `core/config/plugin_config_schema.py`,`_manifest.json` 同步扩展新配置项。 -- 检索与查询增强: - - `KnowledgeQueryTool` 拆分为 query mode + orchestrator,新增长 `aggregate` / `episode` 查询模式。 - - 新增图辅助关系召回、统一 forward/runtime 构建与请求去重桥接。 -- Episode / 运维能力: - - `metadata_store` schema 升级到 `SCHEMA_VERSION = 7`,新增 `episodes` / `episode_paragraphs` / rebuild queue 等结构。 - - 新增 `release_vnext_migrate.py`、`runtime_self_check.py`、`rebuild_episodes.py` 与 Web 检索调优页 `web/tuning.html`。 - -### 📚 文档同步 - -- 版本号同步到 `plugin.py`、`__init__.py`、`_manifest.json`、`README.md` 与 `CONFIG_REFERENCE.md`。 -- 新增 `RELEASE_SUMMARY_1.0.0.md` - -## [0.7.0] - 2026-03-04 - -本次 `0.7.0` 为中版本升级,主线是 **关系向量化闭环(写入 + 状态机 + 回填 + 审计)**、**检索/命令链路增强** 与 **导入任务能力补齐**。 - -### 🔖 版本信息 - -- 插件版本:`0.6.1` → `0.7.0` -- 配置版本:`4.1.0`(不变) - -### 🚀 重点能力 - -- 关系向量化闭环: - - 新增统一关系写入服务 `RelationWriteService`(metadata 先写、向量后写,失败进入状态机而非回滚主数据)。 - - `relations` 侧补齐 `vector_state/retry_count/last_error/updated_at` 等状态字段,支持 `none/pending/ready/failed` 统一治理。 - - 插件新增后台回填循环与统计接口,可持续修复关系向量缺失并暴露覆盖率指标。 -- 检索与命令链路增强: - - 检索主链继续收敛到 `search/time` forward 路由,`legacy` 仅保留兼容别名。 - - relation 查询规格解析收口,结构化查询与语义回退边界更清晰。 - - `/query stats` 与 tool stats 补充关系向量化统计输出。 -- 导入与运维增强: - - Web Import 新增 `temporal_backfill` 任务入口与编排处理。 - - 新增一致性审计与离线回填脚本,支持灰度修复历史数据。 - -### 📚 文档同步 - -- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与本日志版本信息。 -- `README.md` 新增关系向量审计/回填脚本使用说明,并更新 `convert_lpmm.py` 的关系向量重建行为描述。 - -## [0.6.1] - 2026-03-03 - -本次 `0.6.1` 为热修复小版本,重点修复 WebUI 插件配置接口在 A_Memorix 场景下的 `tomlkit` 节点序列化兼容问题。 - -### 🔖 版本信息 - -- 插件版本:`0.6.0` → `0.6.1` -- 配置版本:`4.1.0`(不变) - -### 🛠️ 代码修复 - -- 新增运行时补丁 `_patch_webui_a_memorix_routes_for_tomlkit_serialization()`: - - 仅包裹 `/api/webui/plugins/config/{plugin_id}` 及其 schema 的 `GET` 路由。 - - 仅在 `plugin_id == "A_Memorix"` 时,将返回中的 `config/schema` 通过 `to_builtin_data` 原生化。 - - 保持 `/api/webui/config/*` 全局接口行为不变,避免对其他插件或核心配置路径产生副作用。 -- 在插件初始化时执行该补丁,确保 WebUI 读取插件配置时返回结构可稳定序列化。 - -### 📚 文档同步 - -- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与本日志中的版本信息及修复说明。 - -## [0.6.0] - 2026-03-02 - -本次 `0.6.0` 为中版本升级,主线是 **Web Import 导入中心上线与脚本能力对齐**、**失败重试机制升级**、**删除后 manifest 同步** 与 **导入链路稳定性增强**。 - -### 🔖 版本信息 - -- 插件版本:`0.5.1` → `0.6.0` -- 配置版本:`4.0.1` → `4.1.0` - -### 🚀 重点能力 - -- 新增 Web Import 导入中心(`/import`): - - 上传/粘贴/本地扫描/LPMM OpenIE/LPMM 转换/时序回填/MaiBot 迁移。 - - 任务/文件/分块三级状态展示,支持取消与失败重试。 - - 导入文档弹窗读取(远程优先,失败回退本地)。 -- 失败重试升级为“分块优先 + 文件回退”: - - `POST /api/import/tasks/{task_id}/retry_failed` 保持原路径,语义升级。 - - 支持对 `extracting` 失败分块进行子集重试。 - - `writing`/JSON 解析失败自动回退为文件级重试。 -- 删除后 manifest 同步失效: - - 覆盖 `/api/source/batch_delete` 与 `/api/source`。 - - 返回 `manifest_cleanup` 明细,避免误命中去重跳过重导入。 - -### 📂 变更文件清单(本次发布) - -新增文件: - -- `core/utils/web_import_manager.py` -- `scripts/migrate_maibot_memory.py` -- `web/import.html` - -修改文件: - -- `CHANGELOG.md` -- `CONFIG_REFERENCE.md` -- `IMPORT_GUIDE.md` -- `QUICK_START.md` -- `README.md` -- `__init__.py` -- `_manifest.json` -- `components/commands/debug_server_command.py` -- `core/embedding/api_adapter.py` -- `core/storage/graph_store.py` -- `core/utils/summary_importer.py` -- `plugin.py` -- `requirements.txt` -- 
`server.py` -- `web/index.html` - -删除文件: - -- 无 - -### 📚 文档同步 - -- 同步更新 `README.md`、`QUICK_START.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md` 与本日志。 -- `IMPORT_GUIDE.md` 新增 “Web Import 导入中心” 专区,统一说明能力范围、状态语义与安全边界。 - -## [0.5.1] - 2026-02-23 - -本次 `0.5.1` 为热修订小版本,重点修复“随主程序启动的后台任务拉起”“空名单过滤语义”以及“知识抽取模型选择”。 - -### 🔖 版本信息 - -- 插件版本:`0.5.0` → `0.5.1` -- 配置版本:`4.0.0` → `4.0.1` - -### 🛠️ 代码修复 - -- 生命周期接入主程序事件: - - 新增 `a_memorix_start_handler`(`ON_START`)调用 `plugin.on_enable()`; - - 新增 `a_memorix_stop_handler`(`ON_STOP`)调用 `plugin.on_disable()`; - - 解决仅注册插件但未触发生命周期时,定时导入任务不启动的问题。 -- 聊天过滤空列表策略调整: - - `whitelist + []`:全部拒绝; - - `blacklist + []`:全部放行。 -- 知识抽取模型选择逻辑调整(`import_command._select_model`): - - `advanced.extraction_model` 现在支持三种语义:任务名 / 模型名 / `auto`; - - `auto` 优先抽取相关任务(`lpmm_entity_extract`、`lpmm_rdf_build` 等),并避免误落到 `embedding`; - - 当配置无法识别时输出告警并回退自动选择,提高导入阶段的模型选择可预期性。 - -### 📚 文档同步 - -- 同步更新 `README.md`、`CONFIG_REFERENCE.md` 与 `CHANGELOG.md`。 -- 同步修正文档中的空名单过滤行为描述,保持与当前代码一致。 - -## [0.5.0] - 2026-02-15 - -本次 `0.5.0` 以提交 `66ddc1b98547df3c866b19a3f5dc96e1c8eb7731` 为核心,主线是“人物画像能力上线 + 工具/命令接入 + 版本与文档同步”。 - -### 🔖 版本信息 - -- 插件版本:`0.4.0` → `0.5.0` -- 配置版本:`3.1.0` → `4.0.0` - -### 🚀 人物画像主特性(核心) - -- 新增人物画像服务:`core/utils/person_profile_service.py` - - 支持 `person_id/姓名/别名` 解析。 - - 聚合图关系证据 + 向量证据,生成画像文本并版本化快照。 - - 支持手工覆盖(override)与 TTL 快照复用。 -- 存储层新增人物画像相关表与 API:`core/storage/metadata_store.py` - - `person_profile_switches` - - `person_profile_snapshots` - - `person_profile_active_persons` - - `person_profile_overrides` -- 新增命令:`/person_profile on|off|status` - - 文件:`components/commands/person_profile_command.py` - - 作用:按 `stream_id + user_id` 控制自动注入开关(opt-in 模式)。 -- 查询链路接入人物画像: - - `knowledge_query_tool` 新增 `query_type=person`,支持 `person_id` 或别名查询。 - - `/query person` 与 `/query p` 接入画像查询输出。 -- 插件生命周期接入画像刷新任务: - - 启动/停止统一管理 `person_profile_refresh` 后台任务。 - - 按活跃窗口自动刷新画像快照。 - -### 🛠️ 版本与 schema 同步 - -- `plugin.py`:`plugin_version` 更新为 `0.5.0`。 -- `plugin.py`:`plugin.config_version` 默认值更新为 `4.0.0`。 -- `config.toml`:`config_version` 基线同步为 `4.0.0`(本地配置文件)。 -- `__init__.py`:`__version__` 更新为 `0.5.0`。 -- `_manifest.json`:`version` 更新为 `0.5.0`,`manifest_version` 保持 `1` 。 -- `manifest_utils.py`:仓库内已兼容更高 manifest 版本;但插件发布默认保持 `manifest_version=1` 。 - -### 📚 文档同步 - -- 更新 `README.md`、`CONFIG_REFERENCE.md`、`QUICK_START.md`、`USAGE_ARCHITECTURE.md`。 -- 0.5.0 文档主线改为“人物画像能力 + 版本升级 + 检索链路补充说明”。 - -## [0.4.0] - 2026-02-13 - -本次 `0.4.0` 版本整合了时序检索增强与后续检索链路增强、稳定性修复和文档同步。 - -### 🔖 版本信息 - -- 插件版本:`0.3.3` → `0.4.0` -- 配置版本:`3.0.0` → `3.1.0` - -### 🚀 新增 - -- 新增 `core/retrieval/sparse_bm25.py` - - `SparseBM25Config` / `SparseBM25Index` - - FTS5 + BM25 稀疏检索 - - 支持 `jieba/mixed/char_2gram` 分词与懒加载 - - 支持 ngram 倒排回退与可选 LIKE 兜底 -- `DualPathRetriever` 新增 sparse/fusion 配置注入: - - embedding 不可用时自动 sparse 回退; - - `hybrid` 模式支持向量路 + sparse 路并行候选; - - 新增 `FusionConfig` 与 `weighted_rrf` 融合。 -- `MetadataStore` 新增 FTS/倒排能力: - - `paragraphs_fts`、`relations_fts` schema 与回填; - - `paragraph_ngrams` 倒排索引与回填; - - `fts_search_bm25` / `fts_search_relations_bm25` / `ngram_search_paragraphs`。 - -### 🛠️ 组件链路同步 - -- `plugin.py` - - 新增 `[retrieval.sparse]`、`[retrieval.fusion]` 默认配置; - - 初始化并向组件注入 `sparse_index`; - - `on_disable` 支持按配置卸载 sparse 连接并释放缓存。 -- `knowledge_search_action.py` / `query_command.py` / `knowledge_query_tool.py` - - 统一接入 sparse/fusion 配置; - - 统一注入 `sparse_index`; - - `stats` 输出新增 sparse 状态观测。 -- `requirements.txt` - - 新增 `jieba>=0.42.1`(未安装时自动回退 char n-gram)。 - -### 🧯 修复与行为调整 - -- 修复 `retrieval.ppr_concurrency_limit` 不生效问题: - - 
`DualPathRetriever` 使用配置值初始化 `_ppr_semaphore`,不再被固定值覆盖。 -- 修复 `char_2gram` 召回失效场景: - - FTS miss 时增加 `_fallback_substring_search`,优先 ngram 倒排回退,按配置可选 LIKE 兜底。 -- 提升可观测性与兼容性: - - `get_statistics()` 对向量规模字段兼容读取 `size -> num_vectors -> 0`,避免属性缺失导致异常。 - - `/query stats` 与 `knowledge_query` 输出包含 sparse 状态(enabled/loaded/tokenizer/doc_count)。 - -### 📚 文档 - -- `README.md` - - 新增检索增强说明、稀疏行为说明、时序回填脚本入口。 -- `CONFIG_REFERENCE.md` - - 补齐 sparse/fusion 参数与触发规则、回退链路、融合实现细节。 - -### ⏱️ 时序检索与导入增强 - -#### 时序检索能力(分钟级) - -- 新增统一时序查询入口: - - `/query time`(别名 `/query t`) - - `knowledge_query(query_type=time)` - - `knowledge_search(query_type=time|hybrid)` -- 查询时间参数统一支持: - - `YYYY/MM/DD` - - `YYYY/MM/DD HH:mm` -- 日期参数自动展开边界: - - `from/time_from` -> `00:00` - - `to/time_to` -> `23:59` -- 查询结果统一回传 `metadata.time_meta`,包含命中时间窗口与命中依据(事件时间或 `created_at` 回退)。 - -#### 存储与检索链路 - -- 段落存储层支持时序字段: - - `event_time` - - `event_time_start` - - `event_time_end` - - `time_granularity` - - `time_confidence` -- 时序命中采用区间相交逻辑,并遵循“双层时间语义”: - - 优先 `event_time/event_time_range` - - 缺失时回退 `created_at`(可配置关闭) -- 检索排序规则保持:语义优先,时间次排序(新到旧)。 -- `process_knowledge.py` 新增 `--chat-log` 参数: - - 启用后强制使用 `narrative` 策略; - - 使用 LLM 对聊天文本进行语义时间抽取(支持相对时间转绝对时间),写入 `event_time/event_time_start/event_time_end`。 - - 新增 `--chat-reference-time`,用于指定相对时间语义解析的参考时间点。 - -#### Schema 与文档同步 - -- `_manifest.json` 同步补齐 `retrieval.temporal` 配置 schema。 -- 配置 schema 版本升级:`config_version` 从 `3.0.0` 提升到 `3.1.0`(`plugin.py` / `config.toml` / 配置文档同步)。 -- 更新 `README.md`、`CONFIG_REFERENCE.md`、`IMPORT_GUIDE.md`,补充时序检索入口、参数格式与导入时间字段说明。 - -## [0.3.3] - 2026-02-11 - -本次更新为 **语言一致性补丁版本**,重点收敛知识抽取时的语言漂移问题,要求输出严格贴合原文语言,不做翻译改写。 - -### 🛠️ 关键修复 - -#### 抽取语言约束 - -- `BaseStrategy`: - - 移除按 `zh/en/mixed` 分支的语言类型判定逻辑; - - 统一为单一约束:抽取值保持原文语言、保留原始术语、禁止翻译。 -- `NarrativeStrategy` / `FactualStrategy`: - - 抽取提示词统一接入上述语言约束; - - 明确要求 JSON 键名固定、抽取值遵循原文语言表达。 - -#### 导入链路一致性 - -- `ImportCommand` 的 LLM 抽取提示词同步强化“优先原文语言、不要翻译”要求,避免脚本与指令导入行为不一致。 - -#### 测试与文档 - -- 更新 `test_strategies.py`,将语言判定测试调整为统一语言约束测试,并验证提示词中包含禁止翻译约束。 -- 同步更新注释与文档描述,确保实现与说明一致。 - -### 🔖 版本信息 - -- 插件版本:`0.3.2` → `0.3.3` - -## [0.3.2] - 2026-02-11 - -本次更新为 **V5 稳定性与兼容性修复版本**,在保持原有业务设计(强化→衰减→冷冻→修剪→回收)的前提下,修复关键链路断裂与误判问题。 - -### 🛠️ 关键修复 - -#### V5 记忆系统契约与链路 - -- `MetadataStore`: - - 统一 `mark_relations_inactive(hashes, inactive_since=None)` 调用契约,兼容不同调用方; - - 补充 `has_table(table_name)`; - - 增加 `restore_relation(hash)` 兼容别名,修复服务层恢复调用断裂; - - 修正 `get_entity_gc_candidates` 对孤立节点参数的处理(支持节点名映射到实体 hash)。 -- `GraphStore`: - - 清理 `deactivate_edges` 重复定义并统一返回冻结数量,保证上层日志与断言稳定。 -- `server.py`: - - 修复 `/api/memory/restore` relation 恢复链路; - - 清理不可达分支并统一异常路径; - - 回收站查询在表检测场景下不再出现错误退空。 - -#### 命令与模型选择 - -- `/memory` 命令修复 hash 长度判定:以 64 位 `sha256` 为标准,同时兼容历史 32 位输入。 -- 总结模型选择修复: - - 解决 `summarization.model_name = auto` 误命中 `embedding` 问题; - - 支持数组与选择器语法(`task:model` / task / model); - - 兼容逗号分隔字符串写法(如 `"utils:model1","utils:model2",replyer`)。 - -#### 生命周期与脚本稳定性 - -- `plugin.py` 修复后台任务生命周期管理: - - 增加 `_scheduled_import_task` / `_auto_save_task` / `_memory_maintenance_task` 句柄; - - 避免重复启动; - - 插件停用时统一 cancel + await 收敛。 -- `process_knowledge.py` 修复 tenacity 重试日志级别类型错误(`"WARNING"` → `logging.WARNING`),避免 `KeyError: 'WARNING'`。 - -### 🔖 版本信息 - -- 插件版本:`0.3.1` → `0.3.2` - -## [0.3.1] - 2026-02-07 - -本次更新为 **稳定性补丁版本**,主要修复脚本导入链路、删除安全性与 LPMM 转换一致性问题。 - -### 🛠️ 关键修复 - -#### 新增功能 - -- 新增 `scripts/convert_lpmm.py`: - - 支持将 LPMM 的 `parquet + graph` 数据直接转换为 A_Memorix 存储结构; - - 提供 LPMM ID 到 A_Memorix ID 的映射能力,用于图节点/边重写; - - 当前实现优先保证检索一致性,关系向量采用安全策略(不直接导入)。 - -#### 导入链路 - 
-- 修复 `import_lpmm_json.py` 依赖的 `AutoImporter.import_json_data` 公共入口缺失/不稳定问题,确保外部脚本可稳定调用 JSON 直导入流程。 - -#### 删除安全 - -- 修复按来源删除时“同一 `(subject, object)` 存在多关系”场景下的误删风险: - - `MetadataStore.delete_paragraph_atomic` 新增 `relation_prune_ops`; - - 仅在无兄弟关系时才回退删除整条边。 -- `delete_knowledge.py` 新增保守孤儿实体清理(仅对本次候选实体执行,且需同时满足无段落引用、无关系引用、图无邻居)。 -- `delete_knowledge.py` 改为读取向量元数据中的真实维度,避免 `dimension=1` 写回污染。 - -#### LPMM 转换修复 - -- 修复 `convert_lpmm.py` 中向量 ID 与 `MetadataStore` 哈希不一致导致的检索反查失败问题。 -- 为避免脏召回,转换阶段暂时跳过 `relation.parquet` 的直接向量导入(待关系元数据一一映射能力完善后再恢复)。 - -### 🔖 版本信息 - -- 插件版本:`0.3.0` → `0.3.1` - -## [0.3.0] - 2026-01-30 - -本次更新引入了 **V5 动态记忆系统**,实现了符合生物学特性的记忆衰减、强化与全声明周期管理,并提供了配套的指令与工具。 - -### 🧠 记忆系统 (V5) - -#### 核心机制 - -- **记忆衰减 (Decay)**: 引入"遗忘曲线",随时间推移自动降低图谱连接权重。 -- **访问强化 (Reinforcement)**: "越用越强",每次检索命中都会刷新记忆活跃度并增强权重。 -- **生命周期 (Lifecycle)**: - - **活跃 (Active)**: 正常参与计算与检索。 - - **冷冻 (Inactive)**: 权重过低被冻结,不再参与 PPR 计算,但保留语义映射 (Mapping)。 - - **修剪 (Prune)**: 过期且无保护的冷冻记忆将被移入回收站。 -- **多重保护**: 支持 **永久锁定 (Pin)** 与 **限时保护 (TTL)**,防止关键记忆被误删。 - -#### GraphStore - -- **多关系映射**: 实现 `(u,v) -> Set[Hash]` 映射,确保同一通道下的多重语义关系互不干扰。 -- **原子化操作**: 新增 `decay`, `deactivate_edges` (软删), `prune_relation_hashes` (硬删) 等原子操作。 - -### 🛠️ 指令与工具 - -#### Memory Command (`/memory`) - -新增全套记忆维护指令: - -- `/memory status`: 查看记忆系统健康状态(活跃/冷冻/回收站计数)。 -- `/memory protect [hours]`: 保护记忆。不填时间为永久锁定(Pin),填时间为临时保护(TTL)。 -- `/memory reinforce `: 手动强化记忆(绕过冷却时间)。 -- `/memory restore `: 从回收站恢复误删记忆(仅当节点存在时重建连接)。 - -#### MemoryModifierTool - -- **LLM 能力增强**: 更新工具逻辑,支持 LLM 自主触发 `reinforce`, `weaken`, `remember_forever`, `forget` 操作,并自动映射到 V5 底层逻辑。 - -### ⚙️ 配置 (`config.toml`) - -新增 `[memory]` 配置节: - -- `half_life_hours`: 记忆半衰期 (默认 24h)。 -- `enable_auto_reinforce`: 是否开启检索自动强化。 -- `prune_threshold`: 冷冻/修剪阈值 (默认 0.1)。 - -### 💻 WebUI (v1.4) - -实现了与 V5 记忆系统深度集成的全生命周期管理界面: - -- **可视化增强**: - - **冷冻状态**: 非活跃记忆以 **虚线 + 灰色 (Slate-300)** 显示。 - - **保护状态**: 被 Pin 或保护的记忆带有 **金色 (Amber) 光晕**。 -- **交互升级**: - - **记忆回收站**: 新增 Dock 入口与专用面板,支持浏览删除记录并一键恢复。 - - **快捷操作**: 边属性面板新增 **强化 (Reinforce)**、**保护 (Protect/Pin)**、**冷冻 (Freeze)** 按钮。 - - **实时反馈**: 操作后自动刷新图谱布局与样式。 - ---- - -## [0.2.3] - 2026-01-30 - -本次更新主要集中在 **WebUI 交互体验优化** 与 **文档/配置的规范化**。 - -### 🎨 WebUI (v1.3) - -#### 加载与同步体验升级 - -- **沉浸式加载**: 全新设计的加载遮罩,采用磨砂玻璃背景 (`backdrop-filter`) 与呼吸灯文字动效,提升视觉质感。 -- **精准状态反馈**: 优化加载逻辑,明确区分“网络同步”与“拓扑计算”阶段,解决数据加载时的闪烁问题。 -- **新手引导**: 在加载界面新增基础操作提示,降低新用户上手门槛。 - -#### 全功能帮助面板 - -- **操作指南重构**: 全面翻新“操作指南”面板,新增 Dock 栏功能详解、编辑管理操作及视图配置说明。 - -### 🛠️ 工程与规范 - -#### plugin.py - -- **配置描述补全**: 修复了 `config_section_descriptions` 中缺失 `summarization`, `schedule`, `filter` 节导致的问题。 -- **版本号**: `0.2.2` → `0.2.3` - -### ⚙️ 核心与服务 - -#### Core - -- **量化逻辑修正**: 修正了 `_scalar_quantize_int8` 函数,确保向量值正确映射到 `[-128, 127]` 区间,提高量化精度。 - -#### Server - -- **缓存一致性**: 在执行删除节点/边等修改操作后,显式清除 `_relation_cache`,确保前端获取的关系数据实时更新。 - -### 🤖 脚本与数据处理 - -#### process_knowledge.py - -- **策略模式重构**: 引入了 `Strategy-Aware` 架构,支持通过 `Narrative` (叙事), `Factual` (事实), `Quote` (引用) 三种策略差异化处理文本(准确说是确认实装)(默认采用 Narrative模式)。 -- **智能分块纠错**: 新增“分块拯救” (`Chunk Rescue`) 机制,可在长叙事文本中自动识别并提取内嵌的歌词或诗句。 - -#### import_lpmm_json.py - -- **LPMM 迁移工具**: 增加了对 LPMM OpenIE JSON 格式的完整支持,能够自动计算 Hash 并迁移实体/关系数据,确保与 A_Memorix 存储格式兼容。 - -#### Project - -- **构建清理**: 优化 `.gitignore` 规则 - ---- - -## [0.2.2] - 2026-01-27 - -本次更新专注于提高 **网络请求的鲁棒性**,特别是针对嵌入服务的调用。 - -### 🛠️ 稳定性与工程改进 - -#### EmbeddingAPI - -- **可配置重试机制**: 新增 `[embedding.retry]` 配置项,允许自定义最大重试次数和等待时间。默认重试次数从 3 次增加到 10 次,以更好应对网络波动。 -- **配置项**: - - `max_attempts`: 最大重试次数 (默认: 10) - - `max_wait_seconds`: 
最大等待时间 (默认: 30s) - - `min_wait_seconds`: 最小等待时间 (默认: 2s) - -#### plugin.py - -- **版本号**: `0.2.1` → `0.2.2` - ---- - -## [0.2.1] - 2026-01-26 - -本次更新重点在于 **可视化交互的全方位重构** 以及 **底层鲁棒性的进一步增强**。 - -### 🎨 可视化与交互重构 - -#### WebUI (Glassmorphism) - -- **全新视觉设计**: 采用深色磨砂玻璃 (Glassmorphism) 风格,配合动态渐变背景。 -- **Dock 菜单栏**: 底部新增 macOS 风格 Dock 栏,聚合所有常用功能。 -- **显著性视图 (Saliency View)**: 基于 **PageRank** 算法的“信息密度”滑块,支持以此过滤叶子节点,仅展示核心骨干或全量细节。 -- **功能面板**: - - **❓ 操作指南**: 内置交互说明与特性介绍。 - - **🔍 悬浮搜索**: 支持按拼音/ID 实时过滤节点。 - - **📂 记忆溯源**: 支持按源文件批量查看和删除记忆数据。 - - **📖 内容字典**: 列表化展示所有实体与关系,支持排序与筛选。 - -### 🛠️ 稳定性与工程改进 - -#### EmbeddingAPI - -- **鲁棒性增强**: 引入 `tenacity` 实现指数退避重试机制。 -- **错误处理**: 失败时返回 `NaN` 向量而非零向量,允许上层逻辑安全跳过。 - -#### MetadataStore - -- **自动修复**: 自动检测并修复 `vector_index` 列错位(文件名误存)的历史数据问题。 -- **数据统计**: 新增 `get_all_sources` 接口支持来源统计。 - -#### 脚本与工具 - -- **用户体验**: 引入 `rich` 库优化终端输出进度条与状态显示。 -- **接口开放**: `process_knowledge.py` 新增 `import_json_data` 供外部调用。 -- **LPMM 迁移**: 新增 `import_lpmm_json.py`,支持导入符合 LPMM 规范的 OpenIE JSON 数据。 - -#### plugin.py - -- **版本号**: `0.2.0` → `0.2.1` - ---- - -## [0.2.0] - 2026-01-22 - -> [!CAUTION] -> **不完全兼容变更**:v0.2.0 版本重构了底层存储架构。由于数据结构的重大调整,**旧版本的导入数据无法在新版本中完全无损兼容**。 -> 虽然部分组件支持自动迁移,但为确保数据一致性和检索质量,**强烈建议在升级后重新使用 `process_knowledge.py` 导入原始数据**。 - -本次更新为**重大版本升级**,包含向量存储架构重写、检索逻辑强化及多项稳定性改进。 - -### 🚀 核心架构重写 - -#### VectorStore: SQ8 量化 + Append-Only 存储 - -- **全新存储格式**: 从 `.npy` 迁移至 `vectors.bin`(float16 增量追加)和 `vectors_ids.bin`,大幅减少内存占用。 -- **原生 SQ8 量化**: 使用 Faiss `IndexScalarQuantizer(QT_8bit)`,替代手动 int8 量化逻辑。 -- **L2 Normalization 强制化**: 所有向量在存储和检索时统一执行 L2 归一化,确保 Inner Product 等价于 Cosine 相似度。 -- **Fallback 索引机制**: 新增 `IndexFlatIP` 回退索引,在 SQ8 训练完成前提供检索能力,避免冷启动无结果问题。 -- **Reservoir Sampling 训练采样**: 使用蓄水池采样收集训练数据(上限 10k),保证小数据集和流式导入场景下的训练样本多样性。 -- **线程安全**: 新增 `threading.RLock` 保护并发读写操作。 -- **自动迁移**: 支持从旧版 `.npy` 格式自动迁移至新 `.bin` 格式。 - -### ✨ 检索功能增强 - -#### KnowledgeQueryTool: 智能回退与多跳路径搜索 - -- **Smart Fallback (智能回退)**: 当向量检索置信度低于阈值 (默认 0.6) 时,自动尝试提取查询中的实体进行多跳路径搜索(`_path_search`),增强对间接关系的召回能力。 -- **结果去重 (`_deduplicate_results`)**: 新增基于内容相似度的安全去重逻辑,防止冗余结果污染 LLM 上下文,同时确保至少保留一条结果。 -- **语义关系检索 (`_semantic_search_relation`)**: 支持自然语言查询关系(无需 `S|P|O` 格式),内部使用 `REL_ONLY` 策略进行向量检索。 -- **路径搜索 (`_path_search`)**: 新增 `GraphStore.find_paths` 调用,支持查找两个实体间的间接连接路径(最大深度 3,最多 5 条路径)。 -- **Clean Output**: LLM 上下文中不再包含原始相似度分数,避免模型偏见。 - -#### DualPathRetriever: 并发控制与调试模式 - -- **PPR 并发限制 (`ppr_concurrency_limit`)**: 新增 Semaphore 控制 PageRank 计算并发数,防止 CPU 峰值过载。 -- **Debug 模式**: 新增 `debug` 配置项,启用时打印检索结果原文到日志。 -- **Entity-Pivot 关系检索**: 优化 `_retrieve_relations_only` 策略,通过检索实体后扩展其关联关系,替代直接检索关系向量。 - -### ⚙️ 配置与 Schema 扩展 - -#### plugin.py - -- **版本号**: `0.1.3` → `0.2.0` -- **默认配置版本**: `config_version` 默认值更新为 `2.0.0` -- **新增配置项**: - - `retrieval.relation_semantic_fallback` (bool): 是否启用关系查询的语义回退。 - - `retrieval.relation_fallback_min_score` (float): 语义回退的最小相似度阈值。 -- **相对路径支持**: `storage.data_dir` 现在支持相对路径(相对于插件目录),默认值改为 `./data`。 -- **全局实例获取**: 新增 `A_MemorixPlugin.get_global_instance()` 静态方法,供组件可靠获取插件实例。 - -#### config.toml / \_manifest.json - -- **新增 `ppr_concurrency_limit`**: 控制 PPR 算法并发数。 -- **新增训练阈值配置**: `embedding.min_train_threshold` 控制触发 SQ8 训练的最小样本数。 - -### 🛠️ 稳定性与工程改进 - -#### GraphStore - -- **`find_paths` 方法**: 新增多跳路径查找功能,支持 BFS 搜索指定深度内的实体间路径。 -- **`find_node` 方法**: 新增大小写不敏感的节点查找。 - -#### MetadataStore - -- **Schema 迁移**: 自动添加缺失的 `is_permanent`, `last_accessed`, `access_count` 字段。 - -#### 脚本与工具 - -- **新增脚本**: - - `scripts/diagnose_relations_source.py`: 诊断关系溯源问题。 - - `scripts/verify_search_robustness.py`: 验证检索鲁棒性。 
- - `scripts/run_stress_test.py`, `stress_test_data.py`: 压力测试套件。 - - `scripts/migrate_canonicalization.py`, `migrate_paragraph_relations.py`: 数据迁移工具。 -- **目录整理**: 将大量旧版测试脚本移动至 `deprecated/` 目录。 - -### 🗑️ 移除与废弃 - -- 废弃 `vectors.npy` 存储格式(自动迁移至 `.bin`)。 - ---- - -## [0.1.3] - 上一个稳定版本 - -- 初始发布,包含基础双路检索功能。 -- 手动 Int8 向量量化。 -- 基于 `.npy` 的向量存储。 diff --git a/plugins/A_memorix/CONFIG_REFERENCE.md b/plugins/A_memorix/CONFIG_REFERENCE.md deleted file mode 100644 index ada8aec5..00000000 --- a/plugins/A_memorix/CONFIG_REFERENCE.md +++ /dev/null @@ -1,292 +0,0 @@ -# A_Memorix 配置参考 (v2.0.0) - -本文档对应当前仓库代码(`__version__ = 2.0.0`、`SCHEMA_VERSION = 8`)。 - -说明: - -- 本文只覆盖 **当前运行时实际读取** 的配置键。 -- 旧版 `/query`、`/memory`、`/visualize` 命令体系相关配置,不再作为主路径说明。 -- 未配置的键会回退到代码默认值。 - -## 最小可用配置 - -```toml -[plugin] -enabled = true - -[storage] -data_dir = "./data" - -[embedding] -model_name = "auto" -dimension = 1024 -batch_size = 32 -max_concurrent = 5 -enable_cache = false -quantization_type = "int8" - -[retrieval] -top_k_paragraphs = 20 -top_k_relations = 10 -top_k_final = 10 -alpha = 0.5 -enable_ppr = true -ppr_alpha = 0.85 -ppr_timeout_seconds = 1.5 -ppr_concurrency_limit = 4 -enable_parallel = true - -[retrieval.sparse] -enabled = true - -[episode] -enabled = true -generation_enabled = true -pending_batch_size = 20 -pending_max_retry = 3 - -[person_profile] -enabled = true - -[memory] -enabled = true -half_life_hours = 24.0 -prune_threshold = 0.1 - -[advanced] -enable_auto_save = true -auto_save_interval_minutes = 5 - -[web.import] -enabled = true - -[web.tuning] -enabled = true -``` - -## 1. 存储与嵌入 - -### `storage` - -- `storage.data_dir` (默认 `./data`) -: 数据目录。相对路径按插件目录解析。 - -### `embedding` - -- `embedding.model_name` (默认 `auto`) -: embedding 模型选择。 -- `embedding.dimension` (默认 `1024`) -: 期望维度(运行时会做真实探测并校验)。 -- `embedding.batch_size` (默认 `32`) -- `embedding.max_concurrent` (默认 `5`) -- `embedding.enable_cache` (默认 `false`) -- `embedding.retry` (默认 `{}`) -: embedding 调用重试策略。 -- `embedding.quantization_type` -: 当前主路径仅建议 `int8`。 - -## 2. 
检索 - -### `retrieval` 主键 - -- `retrieval.top_k_paragraphs` (默认 `20`) -- `retrieval.top_k_relations` (默认 `10`) -- `retrieval.top_k_final` (默认 `10`) -- `retrieval.alpha` (默认 `0.5`) -- `retrieval.enable_ppr` (默认 `true`) -- `retrieval.ppr_alpha` (默认 `0.85`) -- `retrieval.ppr_timeout_seconds` (默认 `1.5`) -- `retrieval.ppr_concurrency_limit` (默认 `4`) -- `retrieval.enable_parallel` (默认 `true`) -- `retrieval.relation_vectorization.enabled` (默认 `false`) - -### `retrieval.sparse` (`SparseBM25Config`) - -常用键(默认值): - -- `enabled = true` -- `backend = "fts5"` -- `lazy_load = true` -- `mode = "auto"` (`auto`/`fallback_only`/`hybrid`) -- `tokenizer_mode = "jieba"` (`jieba`/`mixed`/`char_2gram`) -- `char_ngram_n = 2` -- `candidate_k = 80` -- `relation_candidate_k = 60` -- `enable_ngram_fallback_index = true` -- `enable_relation_sparse_fallback = true` - -### `retrieval.fusion` (`FusionConfig`) - -- `method` (默认 `weighted_rrf`) -- `rrf_k` (默认 `60`) -- `vector_weight` (默认 `0.7`) -- `bm25_weight` (默认 `0.3`) -- `normalize_score` (默认 `true`) -- `normalize_method` (默认 `minmax`) - -### `retrieval.search.relation_intent` (`RelationIntentConfig`) - -- `enabled` (默认 `true`) -- `alpha_override` (默认 `0.35`) -- `relation_candidate_multiplier` (默认 `4`) -- `preserve_top_relations` (默认 `3`) -- `force_relation_sparse` (默认 `true`) -- `pair_predicate_rerank_enabled` (默认 `true`) -- `pair_predicate_limit` (默认 `3`) - -### `retrieval.search.graph_recall` (`GraphRelationRecallConfig`) - -- `enabled` (默认 `true`) -- `candidate_k` (默认 `24`) -- `max_hop` (默认 `1`) -- `allow_two_hop_pair` (默认 `true`) -- `max_paths` (默认 `4`) - -### `retrieval.aggregate` - -- `retrieval.aggregate.rrf_k` -- `retrieval.aggregate.weights` - -用于聚合检索阶段混合策略;未配置时走代码默认行为。 - -## 3. 阈值过滤 - -### `threshold` (`ThresholdConfig`) - -- `threshold.min_threshold` (默认 `0.3`) -- `threshold.max_threshold` (默认 `0.95`) -- `threshold.percentile` (默认 `75.0`) -- `threshold.std_multiplier` (默认 `1.5`) -- `threshold.min_results` (默认 `3`) -- `threshold.enable_auto_adjust` (默认 `true`) - -## 4. 聊天过滤 - -### `filter` - -用于 `respect_filter=true` 场景(检索和写入都支持)。 - -```toml -[filter] -enabled = true -mode = "blacklist" # blacklist / whitelist -chats = ["group:123", "user:456", "stream:abc"] -``` - -规则: - -- `blacklist`:命中列表即拒绝 -- `whitelist`:仅列表内允许 -- 列表为空时: - - `blacklist` => 全允许 - - `whitelist` => 全拒绝 - -## 5. Episode - -### `episode` - -- `episode.enabled` (默认 `true`) -- `episode.generation_enabled` (默认 `true`) -- `episode.pending_batch_size` (默认 `20`,部分路径默认 `12`) -- `episode.pending_max_retry` (默认 `3`) -- `episode.max_paragraphs_per_call` (默认 `20`) -- `episode.max_chars_per_call` (默认 `6000`) -- `episode.source_time_window_hours` (默认 `24`) -- `episode.segmentation_model` (默认 `auto`) - -## 6. 人物画像 - -### `person_profile` - -- `person_profile.enabled` (默认 `true`) -- `person_profile.refresh_interval_minutes` (默认 `30`) -- `person_profile.active_window_hours` (默认 `72`) -- `person_profile.max_refresh_per_cycle` (默认 `50`) -- `person_profile.top_k_evidence` (默认 `12`) - -## 7. 记忆演化与回收 - -### `memory` - -- `memory.enabled` (默认 `true`) -- `memory.half_life_hours` (默认 `24.0`) -- `memory.base_decay_interval_hours` (默认 `1.0`) -- `memory.prune_threshold` (默认 `0.1`) -- `memory.freeze_duration_hours` (默认 `24.0`) - -### `memory.orphan` - -- `enable_soft_delete` (默认 `true`) -- `entity_retention_days` (默认 `7.0`) -- `paragraph_retention_days` (默认 `7.0`) -- `sweep_grace_hours` (默认 `24.0`) - -## 8. 
高级运行时 - -### `advanced` - -- `advanced.enable_auto_save` (默认 `true`) -- `advanced.auto_save_interval_minutes` (默认 `5`) -- `advanced.debug` (默认 `false`) -- `advanced.extraction_model` (默认 `auto`) - -## 9. 导入中心 (`web.import`) - -### 开关与限流 - -- `web.import.enabled` (默认 `true`) -- `web.import.max_queue_size` (默认 `20`) -- `web.import.max_files_per_task` (默认 `200`) -- `web.import.max_file_size_mb` (默认 `20`) -- `web.import.max_paste_chars` (默认 `200000`) -- `web.import.default_file_concurrency` (默认 `2`) -- `web.import.default_chunk_concurrency` (默认 `4`) -- `web.import.max_file_concurrency` (默认 `6`) -- `web.import.max_chunk_concurrency` (默认 `12`) -- `web.import.poll_interval_ms` (默认 `1000`) - -### 重试与路径 - -- `web.import.llm_retry.max_attempts` (默认 `4`) -- `web.import.llm_retry.min_wait_seconds` (默认 `3`) -- `web.import.llm_retry.max_wait_seconds` (默认 `40`) -- `web.import.llm_retry.backoff_multiplier` (默认 `3`) -- `web.import.path_aliases` (默认内置 `raw/lpmm/plugin_data`) - -### 转换阶段 - -- `web.import.convert.enable_staging_switch` (默认 `true`) -- `web.import.convert.keep_backup_count` (默认 `3`) - -## 10. 调优中心 (`web.tuning`) - -- `web.tuning.enabled` (默认 `true`) -- `web.tuning.max_queue_size` (默认 `8`) -- `web.tuning.poll_interval_ms` (默认 `1200`) -- `web.tuning.eval_query_timeout_seconds` (默认 `10.0`) -- `web.tuning.default_intensity` (默认 `standard`) -- `web.tuning.default_objective` (默认 `precision_priority`) -- `web.tuning.default_top_k_eval` (默认 `20`) -- `web.tuning.default_sample_size` (默认 `24`) -- `web.tuning.llm_retry.max_attempts` (默认 `3`) -- `web.tuning.llm_retry.min_wait_seconds` (默认 `2`) -- `web.tuning.llm_retry.max_wait_seconds` (默认 `20`) -- `web.tuning.llm_retry.backoff_multiplier` (默认 `2`) - -## 11. 兼容性提示 - -- 若你从 `1.x` 升级,请优先运行: - -```bash -python plugins/A_memorix/scripts/release_vnext_migrate.py preflight --strict -python plugins/A_memorix/scripts/release_vnext_migrate.py migrate --verify-after -python plugins/A_memorix/scripts/release_vnext_migrate.py verify --strict -``` - -- 启动前再执行: - -```bash -python plugins/A_memorix/scripts/runtime_self_check.py --json -``` - -以避免 embedding 维度与向量库不匹配导致运行时异常。 diff --git a/plugins/A_memorix/IMPORT_GUIDE.md b/plugins/A_memorix/IMPORT_GUIDE.md deleted file mode 100644 index 618690e0..00000000 --- a/plugins/A_memorix/IMPORT_GUIDE.md +++ /dev/null @@ -1,335 +0,0 @@ -# A_Memorix 导入指南 (v2.0.0) - -本文档对应当前 `2.0.0` 代码路径,覆盖两类导入方式: - -1. 脚本导入(离线批处理) -2. `memory_import_admin` 任务导入(在线任务化) - -## 1. 导入前检查 - -建议先执行: - -```bash -python plugins/A_memorix/scripts/runtime_self_check.py --json -``` - -再确认: - -- `storage.data_dir` 路径可写 -- embedding 配置可用 -- 若是升级项目,先完成迁移脚本 - -## 2. 
方式 A:脚本导入(推荐起步) - -## 2.1 原始文本导入 - -将 `.txt` 文件放入: - -```text -plugins/A_memorix/data/raw/ -``` - -执行: - -```bash -python plugins/A_memorix/scripts/process_knowledge.py -``` - -常用参数: - -```bash -python plugins/A_memorix/scripts/process_knowledge.py --force -python plugins/A_memorix/scripts/process_knowledge.py --chat-log -python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30" -``` - -## 2.2 OpenIE JSON 导入 - -```bash -python plugins/A_memorix/scripts/import_lpmm_json.py -``` - -## 2.3 LPMM 数据转换 - -```bash -python plugins/A_memorix/scripts/convert_lpmm.py -i -o plugins/A_memorix/data -``` - -## 2.4 历史数据迁移 - -```bash -python plugins/A_memorix/scripts/migrate_chat_history.py --help -python plugins/A_memorix/scripts/migrate_maibot_memory.py --help -python plugins/A_memorix/scripts/migrate_person_memory_points.py --help -``` - -## 2.5 导入后修复与重建 - -```bash -python plugins/A_memorix/scripts/backfill_temporal_metadata.py --dry-run -python plugins/A_memorix/scripts/backfill_relation_vectors.py --limit 1000 -python plugins/A_memorix/scripts/rebuild_episodes.py --all --wait -python plugins/A_memorix/scripts/audit_vector_consistency.py --json -``` - -## 3. 方式 B:`memory_import_admin` 任务导入 - -`memory_import_admin` 是在线任务化导入入口,适合宿主侧面板或自动化管道。 - -### 3.1 常用 action - -- `settings` / `get_settings` / `get_guide` -- `path_aliases` / `get_path_aliases` -- `resolve_path` -- `create_upload` -- `create_paste` -- `create_raw_scan` -- `create_lpmm_openie` -- `create_lpmm_convert` -- `create_temporal_backfill` -- `create_maibot_migration` -- `list` -- `get` -- `chunks` / `get_chunks` -- `cancel` -- `retry_failed` - -### 3.2 调用示例 - -查看运行时设置: - -```json -{ - "tool": "memory_import_admin", - "arguments": { - "action": "settings" - } -} -``` - -创建粘贴导入任务: - -```json -{ - "tool": "memory_import_admin", - "arguments": { - "action": "create_paste", - "content": "今天完成了检索调优回归。", - "input_mode": "plain_text", - "source": "manual:worklog" - } -} -``` - -查询任务列表: - -```json -{ - "tool": "memory_import_admin", - "arguments": { - "action": "list", - "limit": 20 - } -} -``` - -查看任务详情: - -```json -{ - "tool": "memory_import_admin", - "arguments": { - "action": "get", - "task_id": "", - "include_chunks": true - } -} -``` - -重试失败任务: - -```json -{ - "tool": "memory_import_admin", - "arguments": { - "action": "retry_failed", - "task_id": "" - } -} -``` - -## 4. 直接写入 Tool(非任务化) - -若你不需要任务编排,也可以直接调用: - -- `ingest_summary` -- `ingest_text` - -示例: - -```json -{ - "tool": "ingest_text", - "arguments": { - "external_id": "note:2026-03-18:001", - "source_type": "note", - "text": "新的召回阈值方案已通过评审", - "chat_id": "group:dev", - "tags": ["worklog", "review"] - } -} -``` - -`external_id` 建议全局唯一,用于幂等去重。 - -## 5. 时间字段建议 - -可用时间字段(按常见优先级): - -- `timestamp` -- `time_start` -- `time_end` - -建议: - -- 事件类记录优先写 `time_start/time_end` -- 仅有单点时间时写 `timestamp` -- 历史数据可先导入,再用 `backfill_temporal_metadata.py` 回填 - -## 6. source_type 建议 - -常见值: - -- `chat_summary` -- `note` -- `person_fact` -- `lpmm_openie` -- `migration` - -建议保持稳定枚举,便于后续按来源治理与重建 Episode。 - -## 7. 导入完成后的验证 - -建议执行以下顺序: - -1. `memory_stats` 看总量是否增长 -2. `search_memory`(`mode=search`/`aggregate`)抽检召回 -3. `memory_episode_admin` 的 `status`/`query` 检查 Episode 生成 -4. `memory_runtime_admin` 的 `self_check` 再确认运行时健康 - -## 8. 
常见问题 - -### Q1: 导入任务创建成功但无写入 - -- 检查聊天过滤配置 `filter`(若 `respect_filter=true` 可能被过滤) -- 检查任务详情中的失败原因与分块状态 - -### Q2: 任务反复失败 - -- 检查 embedding 与 LLM 可用性 -- 降低并发(`web.import.default_*_concurrency`) -- 调整重试参数(`web.import.llm_retry.*`) - -### Q3: 导入后检索效果差 - -- 先做 `runtime_self_check` -- 检查 `retrieval.sparse` 是否启用 -- 使用 `memory_tuning_admin` 创建调优任务做参数回归 - -## 9. 相关文档 - -- [QUICK_START.md](QUICK_START.md) -- [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md) -- [README.md](README.md) -- [CHANGELOG.md](CHANGELOG.md) - -## 10. 附录:策略模式参考 - -A_Memorix 导入链路仍然遵循策略模式(Strategy-Aware)。`process_knowledge.py` 会自动识别文本类型,也支持手动指定。 - -| 策略类型 | 适用场景 | 核心逻辑 | 自动识别特征 | -| :-- | :-- | :-- | :-- | -| `Narrative` (叙事) | 小说、同人文、剧本、长篇故事 | 按场景/章节切分,使用滑动窗口;提取事件与角色关系 | `#`、`Chapter`、`***` 等章节标记 | -| `Factual` (事实) | 设定集、百科、说明书 | 按语义块切分,保留列表/定义结构;提取 SPO 三元组 | 列表符号、`术语: 解释` | -| `Quote` (引用) | 歌词、诗歌、名言、台词 | 按双换行切分,原文即知识,不做概括 | 平均行长短、行数多 | - -## 11. 附录:参考用例(已恢复) - -以下样例可直接复制保存为文件测试,或作为 LLM few-shot 示例。 - -### 11.1 叙事文本 (`plugins/A_memorix/data/raw/story_demo.txt`) - -```text -# 第一章:星之子 - -艾瑞克在废墟中醒来,手中的星盘发出微弱的蓝光。他并不记得自己是如何来到这里的,只依稀记得莉莉丝最后的警告:“千万不要回头。” - -远处传来了机械守卫的轰鸣声。艾瑞克迅速收起星盘,向着北方的废弃都市奔去。他知道,那里有反抗军唯一的据点。 - -*** - -# 第二章:重逢 - -在反抗军的地下掩体中,艾瑞克见到了那个熟悉的身影。莉莉丝正站在全息地图前,眉头紧锁。 - -“你还是来了。”莉莉丝没有回头,但声音中带着一丝颤抖。 -“我必须来,”艾瑞克握紧了拳头,“为了解开星盘的秘密,也为了你。” -``` - -### 11.2 事实文本 (`plugins/A_memorix/data/raw/rules_demo.txt`) - -```text -# 联邦安全协议 v2.0 - -## 核心法则 -1. **第一公理**:任何人工智能不得伤害人类个体,或因不作为而使人类个体受到伤害。 -2. **第二公理**:人工智能必须服从人类的命令,除非该命令与第一公理冲突。 - -## 术语定义 -- **以太网络**:覆盖全联邦的高速量子通讯网络。 -- **黑色障壁**:用于隔离高危 AI 的物理防火墙设施。 -``` - -### 11.3 引用文本 (`plugins/A_memorix/data/raw/poem_demo.txt`) - -```text -致橡树 - -我如果爱你—— -绝不像攀援的凌霄花, -借你的高枝炫耀自己; - -我如果爱你—— -绝不学痴情的鸟儿, -为绿荫重复单调的歌曲; - -也不止像泉源, -常年送来清凉的慰籍; -也不止像险峰, -增加你的高度,衬托你的威仪。 -``` - -### 11.4 LPMM JSON (`lpmm_data-openie.json`) - -```json -{ - "docs": [ - { - "passage": "艾瑞克手中的星盘是打开遗迹的唯一钥匙。", - "extracted_triples": [ - ["星盘", "是", "唯一的钥匙"], - ["星盘", "属于", "艾瑞克"], - ["钥匙", "用于", "遗迹"] - ], - "extracted_entities": ["星盘", "艾瑞克", "遗迹", "钥匙"] - }, - { - "passage": "莉莉丝是反抗军的现任领袖。", - "extracted_triples": [ - ["莉莉丝", "是", "领袖"], - ["领袖", "所属", "反抗军"] - ] - } - ] -} -``` diff --git a/plugins/A_memorix/LICENSE b/plugins/A_memorix/LICENSE deleted file mode 100644 index e20b431b..00000000 --- a/plugins/A_memorix/LICENSE +++ /dev/null @@ -1,661 +0,0 @@ -GNU AFFERO GENERAL PUBLIC LICENSE - Version 3, 19 November 2007 - - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - Preamble - - The GNU Affero General Public License is a free, copyleft license for -software and other kinds of works, specifically designed to ensure -cooperation with the community in the case of network server software. - - The licenses for most software and other practical works are designed -to take away your freedom to share and change the works. By contrast, -our General Public Licenses are intended to guarantee your freedom to -share and change all versions of a program--to make sure it remains free -software for all its users. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -them if you wish), that you receive source code or can get it if you -want it, that you can change the software or use pieces of it in new -free programs, and that you know you can do these things. 
- - Developers that use our General Public Licenses protect your rights -with two steps: (1) assert copyright on the software, and (2) offer -you this License which gives you legal permission to copy, distribute -and/or modify the software. - - A secondary benefit of defending all users' freedom is that -improvements made in alternate versions of the program, if they -receive widespread use, become available for other developers to -incorporate. Many developers of free software are heartened and -encouraged by the resulting cooperation. However, in the case of -software used on network servers, this result may fail to come about. -The GNU General Public License permits making a modified version and -letting the public access it on a server without ever releasing its -source code to the public. - - The GNU Affero General Public License is designed specifically to -ensure that, in such cases, the modified source code becomes available -to the community. It requires the operator of a network server to -provide the source code of the modified version running there to the -users of that server. Therefore, public use of a modified version, on -a publicly accessible server, gives the public access to the source -code of the modified version. - - An older license, called the Affero General Public License and -published by Affero, was designed to accomplish similar goals. This is -a different license, not a version of the Affero GPL, but Affero has -released a new version of the Affero GPL which permits relicensing under -this license. - - The precise terms and conditions for copying, distribution and -modification follow. - - TERMS AND CONDITIONS - - 0. Definitions. - - "This License" refers to version 3 of the GNU Affero General Public License. - - "Copyright" also means copyright-like laws that apply to other kinds of -works, such as semiconductor masks. - - "The Program" refers to any copyrightable work licensed under this -License. Each licensee is addressed as "you". "Licensees" and -"recipients" may be individuals or organizations. - - To "modify" a work means to copy from or adapt all or part of the work -in a fashion requiring copyright permission, other than the making of an -exact copy. The resulting work is called a "modified version" of the -earlier work or a work "based on" the earlier work. - - A "covered work" means either the unmodified Program or a work based -on the Program. - - To "propagate" a work means to do anything with it that, without -permission, would make you directly or secondarily liable for -infringement under applicable copyright law, except executing it on a -computer or modifying a private copy. Propagation includes copying, -distribution (with or without modification), making available to the -public, and in some countries other activities as well. - - To "convey" a work means any kind of propagation that enables other -parties to make or receive copies. Mere interaction with a user through -a computer network, with no transfer of a copy, is not conveying. - - An interactive user interface displays "Appropriate Legal Notices" -to the extent that it includes a convenient and prominently visible -feature that (1) displays an appropriate copyright notice, and (2) -tells the user that there is no warranty for the work (except to the -extent that warranties are provided), that licensees may convey the -work under this License, and how to view a copy of this License. 
If -the interface presents a list of user commands or options, such as a -menu, a prominent item in the list meets this criterion. - - 1. Source Code. - - The "source code" for a work means the preferred form of the work -for making modifications to it. "Object code" means any non-source -form of a work. - - A "Standard Interface" means an interface that either is an official -standard defined by a recognized standards body, or, in the case of -interfaces specified for a particular programming language, one that -is widely used among developers working in that language. - - The "System Libraries" of an executable work include anything, other -than the work as a whole, that (a) is included in the normal form of -packaging a Major Component, but which is not part of that Major -Component, and (b) serves only to enable use of the work with that -Major Component, or to implement a Standard Interface for which an -implementation is available to the public in source code form. A -"Major Component", in this context, means a major essential component -(kernel, window system, and so on) of the specific operating system -(if any) on which the executable work runs, or a compiler used to -produce the work, or an object code interpreter used to run it. - - The "Corresponding Source" for a work in object code form means all -the source code needed to generate, install, and (for an executable -work) run the object code and to modify the work, including scripts to -control those activities. However, it does not include the work's -System Libraries, or general-purpose tools or generally available free -programs which are used unmodified in performing those activities but -which are not part of the work. For example, Corresponding Source -includes interface definition files associated with source files for -the work, and the source code for shared libraries and dynamically -linked subprograms that the work is specifically designed to require, -such as by intimate data communication or control flow between those -subprograms and other parts of the work. - - The Corresponding Source need not include anything that users -can regenerate automatically from other parts of the Corresponding -Source. - - The Corresponding Source for a work in source code form is that -same work. - - 2. Basic Permissions. - - All rights granted under this License are granted for the term of -copyright on the Program, and are irrevocable provided the stated -conditions are met. This License explicitly affirms your unlimited -permission to run the unmodified Program. The output from running a -covered work is covered by this License only if the output, given its -content, constitutes a covered work. This License acknowledges your -rights of fair use or other equivalent, as provided by copyright law. - - You may make, run and propagate covered works that you do not -convey, without conditions so long as your license otherwise remains -in force. You may convey covered works to others for the sole purpose -of having them make modifications exclusively for you, or provide you -with facilities for running those works, provided that you comply with -the terms of this License in conveying all material for which you do -not control copyright. Those thus making or running the covered works -for you must do so exclusively on your behalf, under your direction -and control, on terms that prohibit them from making any copies of -your copyrighted material outside their relationship with you. 
- - Conveying under any other circumstances is permitted solely under -the conditions stated below. Sublicensing is not allowed; section 10 -makes it unnecessary. - - 3. Protecting Users' Legal Rights From Anti-Circumvention Law. - - No covered work shall be deemed part of an effective technological -measure under any applicable law fulfilling obligations under article -11 of the WIPO copyright treaty adopted on 20 December 1996, or -similar laws prohibiting or restricting circumvention of such -measures. - - When you convey a covered work, you waive any legal power to forbid -circumvention of technological measures to the extent such circumvention -is effected by exercising rights under this License with respect to -the covered work, and you disclaim any intention to limit operation or -modification of the work as a means of enforcing, against the work's -users, your or third parties' legal rights to forbid circumvention of -technological measures. - - 4. Conveying Verbatim Copies. - - You may convey verbatim copies of the Program's source code as you -receive it, in any medium, provided that you conspicuously and -appropriately publish on each copy an appropriate copyright notice; -keep intact all notices stating that this License and any -non-permissive terms added in accord with section 7 apply to the code; -keep intact all notices of the absence of any warranty; and give all -recipients a copy of this License along with the Program. - - You may charge any price or no price for each copy that you convey, -and you may offer support or warranty protection for a fee. - - 5. Conveying Modified Source Versions. - - You may convey a work based on the Program, or the modifications to -produce it from the Program, in the form of source code under the -terms of section 4, provided that you also meet all of these conditions: - - a) The work must carry prominent notices stating that you modified - it, and giving a relevant date. - - b) The work must carry prominent notices stating that it is - released under this License and any conditions added under section - 7. This requirement modifies the requirement in section 4 to - "keep intact all notices". - - c) You must license the entire work, as a whole, under this - License to anyone who comes into possession of a copy. This - License will therefore apply, along with any applicable section 7 - additional terms, to the whole of the work, and all its parts, - regardless of how they are packaged. This License gives no - permission to license the work in any other way, but it does not - invalidate such permission if you have separately received it. - - d) If the work has interactive user interfaces, each must display - Appropriate Legal Notices; however, if the Program has interactive - interfaces that do not display Appropriate Legal Notices, your - work need not make them do so. - - A compilation of a covered work with other separate and independent -works, which are not by their nature extensions of the covered work, -and which are not combined with it such as to form a larger program, -in or on a volume of a storage or distribution medium, is called an -"aggregate" if the compilation and its resulting copyright are not -used to limit the access or legal rights of the compilation's users -beyond what the individual works permit. Inclusion of a covered work -in an aggregate does not cause this License to apply to the other -parts of the aggregate. - - 6. Conveying Non-Source Forms. 
- - You may convey a covered work in object code form under the terms -of sections 4 and 5, provided that you also convey the -machine-readable Corresponding Source under the terms of this License, -in one of these ways: - - a) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by the - Corresponding Source fixed on a durable physical medium - customarily used for software interchange. - - b) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by a - written offer, valid for at least three years and valid for as - long as you offer spare parts or customer support for that product - model, to give anyone who possesses the object code either (1) a - copy of the Corresponding Source for all the software in the - product that is covered by this License, on a durable physical - medium customarily used for software interchange, for a price no - more than your reasonable cost of physically performing this - conveying of source, or (2) access to copy the - Corresponding Source from a network server at no charge. - - c) Convey individual copies of the object code with a copy of the - written offer to provide the Corresponding Source. This - alternative is allowed only occasionally and noncommercially, and - only if you received the object code with such an offer, in accord - with subsection 6b. - - d) Convey the object code by offering access from a designated - place (gratis or for a charge), and offer equivalent access to the - Corresponding Source in the same way through the same place at no - further charge. You need not require recipients to copy the - Corresponding Source along with the object code. If the place to - copy the object code is a network server, the Corresponding Source - may be on a different server (operated by you or a third party) - that supports equivalent copying facilities, provided you maintain - clear directions next to the object code saying where to find the - Corresponding Source. Regardless of what server hosts the - Corresponding Source, you remain obligated to ensure that it is - available for as long as needed to satisfy these requirements. - - e) Convey the object code using peer-to-peer transmission, provided - you inform other peers where the object code and Corresponding - Source of the work are being offered to the general public at no - charge under subsection 6d. - - A separable portion of the object code, whose source code is excluded -from the Corresponding Source as a System Library, need not be -included in conveying the object code work. - - A "User Product" is either (1) a "consumer product", which means any -tangible personal property which is normally used for personal, family, -or household purposes, or (2) anything designed or sold for incorporation -into a dwelling. In determining whether a product is a consumer product, -doubtful cases shall be resolved in favor of coverage. For a particular -product received by a particular user, "normally used" refers to a -typical or common use of that class of product, regardless of the status -of the particular user or of the way in which the particular user -actually uses, or expects or is expected to use, the product. A product -is a consumer product regardless of whether the product has substantial -commercial, industrial or non-consumer uses, unless such uses represent -the only significant mode of use of the product. 
- - "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install -and execute modified versions of a covered work in that User Product from -a modified version of its Corresponding Source. The information must -suffice to ensure that the continued functioning of the modified object -code is in no case prevented or interfered with solely because -modification has been made. - - If you convey an object code work under this section in, or with, or -specifically for use in, a User Product, and the conveying occurs as -part of a transaction in which the right of possession and use of the -User Product is transferred to the recipient in perpetuity or for a -fixed term (regardless of how the transaction is characterized), the -Corresponding Source conveyed under this section must be accompanied -by the Installation Information. But this requirement does not apply -if neither you nor any third party retains the ability to install -modified object code on the User Product (for example, the work has -been installed in ROM). - - The requirement to provide Installation Information does not include a -requirement to continue to provide support service, warranty, or updates -for a work that has been modified or installed by the recipient, or for -the User Product in which it has been modified or installed. Access to a -network may be denied when the modification itself materially and -adversely affects the operation of the network or violates the rules and -protocols for communication across the network. - - Corresponding Source conveyed, and Installation Information provided, -in accord with this section must be in a format that is publicly -documented (and with an implementation available to the public in -source code form), and must require no special password or key for -unpacking, reading or copying. - - 7. Additional Terms. - - "Additional permissions" are terms that supplement the terms of this -License by making exceptions from one or more of its conditions. -Additional permissions that are applicable to the entire Program shall -be treated as though they were included in this License, to the extent -that they are valid under applicable law. If additional permissions -apply only to part of the Program, that part may be used separately -under those permissions, but the entire Program remains governed by -this License without regard to the additional permissions. - - When you convey a copy of a covered work, you may at your option -remove any additional permissions from that copy, or from any part of -it. (Additional permissions may be written to require their own -removal in certain cases when you modify the work.) You may place -additional permissions on material, added by you to a covered work, -for which you have or can give appropriate copyright permission. 
- - Notwithstanding any other provision of this License, for material you -add to a covered work, you may (if authorized by the copyright holders of -that material) supplement the terms of this License with terms: - - a) Disclaiming warranty or limiting liability differently from the - terms of sections 15 and 16 of this License; or - - b) Requiring preservation of specified reasonable legal notices or - author attributions in that material or in the Appropriate Legal - Notices displayed by works containing it; or - - c) Prohibiting misrepresentation of the origin of that material, or - requiring that modified versions of such material be marked in - reasonable ways as different from the original version; or - - d) Limiting the use for publicity purposes of names of licensors or - authors of the material; or - - e) Declining to grant rights under trademark law for use of some - trade names, trademarks, or service marks; or - - f) Requiring indemnification of licensors and authors of that - material by anyone who conveys the material (or modified versions of - it) with contractual assumptions of liability to the recipient, for - any liability that these contractual assumptions directly impose on - those licensors and authors. - - All other non-permissive additional terms are considered "further -restrictions" within the meaning of section 10. If the Program as you -received it, or any part of it, contains a notice stating that it is -governed by this License along with a term that is a further -restriction, you may remove that term. If a license document contains -a further restriction but permits relicensing or conveying under this -License, you may add to a covered work material governed by the terms -of that license document, provided that the further restriction does -not survive such relicensing or conveying. - - If you add terms to a covered work in accord with this section, you -must place, in the relevant source files, a statement of the -additional terms that apply to those files, or a notice indicating -where to find the applicable terms. - - Additional terms, permissive or non-permissive, may be stated in the -form of a separately written license, or stated as exceptions; -the above requirements apply either way. - - 8. Termination. - - You may not propagate or modify a covered work except as expressly -provided under this License. Any attempt otherwise to propagate or -modify it is void, and will automatically terminate your rights under -this License (including any patent licenses granted under the third -paragraph of section 11). - - However, if you cease all violation of this License, then your -license from a particular copyright holder is reinstated (a) -provisionally, unless and until the copyright holder explicitly and -finally terminates your license, and (b) permanently, if the copyright -holder fails to notify you of the violation by some reasonable means -prior to 60 days after the cessation. - - Moreover, your license from a particular copyright holder is -reinstated permanently if the copyright holder notifies you of the -violation by some reasonable means, this is the first time you have -received notice of violation of this License (for any work) from that -copyright holder, and you cure the violation prior to 30 days after -your receipt of the notice. - - Termination of your rights under this section does not terminate the -licenses of parties who have received copies or rights from you under -this License. 
If your rights have been terminated and not permanently -reinstated, you do not qualify to receive new licenses for the same -material under section 10. - - 9. Acceptance Not Required for Having Copies. - - You are not required to accept this License in order to receive or -run a copy of the Program. Ancillary propagation of a covered work -occurring solely as a consequence of using peer-to-peer transmission -to receive a copy likewise does not require acceptance. However, -nothing other than this License grants you permission to propagate or -modify any covered work. These actions infringe copyright if you do -not accept this License. Therefore, by modifying or propagating a -covered work, you indicate your acceptance of this License to do so. - - 10. Automatic Licensing of Downstream Recipients. - - Each time you convey a covered work, the recipient automatically -receives a license from the original licensors, to run, modify and -propagate that work, subject to this License. You are not responsible -for enforcing compliance by third parties with this License. - - An "entity transaction" is a transaction transferring control of an -organization, or substantially all assets of one, or subdividing an -organization, or merging organizations. If propagation of a covered -work results from an entity transaction, each party to that -transaction who receives a copy of the work also receives whatever -licenses to the work the party's predecessor in interest had or could -give under the previous paragraph, plus a right to possession of the -Corresponding Source of the work from the predecessor in interest, if -the predecessor has it or can get it with reasonable efforts. - - You may not impose any further restrictions on the exercise of the -rights granted or affirmed under this License. For example, you may -not impose a license fee, royalty, or other charge for exercise of -rights granted under this License, and you may not initiate litigation -(including a cross-claim or counterclaim in a lawsuit) alleging that -any patent claim is infringed by making, using, selling, offering for -sale, or importing the Program or any portion of it. - - 11. Patents. - - A "contributor" is a copyright holder who authorizes use under this -License of the Program or a work on which the Program is based. The -work thus licensed is called the contributor's "contributor version". - - A contributor's "essential patent claims" are all patent claims -owned or controlled by the contributor, whether already acquired or -hereafter acquired, that would be infringed by some manner, permitted -by this License, of making, using, or selling its contributor version, -but do not include claims that would be infringed only as a -consequence of further modification of the contributor version. For -purposes of this definition, "control" includes the right to grant -patent sublicenses in a manner consistent with the requirements of -this License. - - Each contributor grants you a non-exclusive, worldwide, royalty-free -patent license under the contributor's essential patent claims, to -make, use, sell, offer for sale, import and otherwise run, modify and -propagate the contents of its contributor version. - - In the following three paragraphs, a "patent license" is any express -agreement or commitment, however denominated, not to enforce a patent -(such as an express permission to practice a patent or covenant not to -sue for patent infringement). 
To "grant" such a patent license to a -party means to make such an agreement or commitment not to enforce a -patent against the party. - - If you convey a covered work, knowingly relying on a patent license, -and the Corresponding Source of the work is not available for anyone -to copy, free of charge and under the terms of this License, through a -publicly available network server or other readily accessible means, -then you must either (1) cause the Corresponding Source to be so -available, or (2) arrange to deprive yourself of the benefit of the -patent license for this particular work, or (3) arrange, in a manner -consistent with the requirements of this License, to extend the patent -license to downstream recipients. "Knowingly relying" means you have -actual knowledge that, but for the patent license, your conveying the -covered work in a country, or your recipient's use of the covered work -in a country, would infringe one or more identifiable patents in that -country that you have reason to believe are valid. - - If, pursuant to or in connection with a single transaction or -arrangement, you convey, or propagate by procuring conveyance of, a -covered work, and grant a patent license to some of the parties -receiving the covered work authorizing them to use, propagate, modify -or convey a specific copy of the covered work, then the patent license -you grant is automatically extended to all recipients of the covered -work and works based on it. - - A patent license is "discriminatory" if it does not include within -the scope of its coverage, prohibits the exercise of, or is -conditioned on the non-exercise of one or more of the rights that are -specifically granted under this License. You may not convey a covered -work if you are a party to an arrangement with a third party that is -in the business of distributing software, under which you make payment -to the third party based on the extent of your activity of conveying -the work, and under which the third party grants, to any of the -parties who would receive the covered work from you, a discriminatory -patent license (a) in connection with copies of the covered work -conveyed by you (or copies made from those copies), or (b) primarily -for and in connection with specific products or compilations that -contain the covered work, unless you entered into that arrangement, -or that patent license was granted, prior to 28 March 2007. - - Nothing in this License shall be construed as excluding or limiting -any implied license or other defenses to infringement that may -otherwise be available to you under applicable patent law. - - 12. No Surrender of Others' Freedom. - - If conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. If you cannot convey a -covered work so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you may -not convey it at all. For example, if you agree to terms that obligate you -to collect a royalty for further conveying from those to whom you convey -the Program, the only way you could satisfy both those terms and this -License would be to refrain entirely from conveying the Program. - - 13. Remote Network Interaction; Use with the GNU General Public License. 
- - Notwithstanding any other provision of this License, if you modify the -Program, your modified version must prominently offer all users -interacting with it remotely through a computer network (if your version -supports such interaction) an opportunity to receive the Corresponding -Source of your version by providing access to the Corresponding Source -from a network server at no charge, through some standard or customary -means of facilitating copying of software. This Corresponding Source -shall include the Corresponding Source for any work covered by version 3 -of the GNU General Public License that is incorporated pursuant to the -following paragraph. - - Notwithstanding any other provision of this License, you have -permission to link or combine any covered work with a work licensed -under version 3 of the GNU General Public License into a single -combined work, and to convey the resulting work. The terms of this -License will continue to apply to the part which is the covered work, -but the work with which it is combined will remain governed by version -3 of the GNU General Public License. - - 14. Revised Versions of this License. - - The Free Software Foundation may publish revised and/or new versions of -the GNU Affero General Public License from time to time. Such new versions -will be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - - Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU Affero General -Public License "or any later version" applies to it, you have the -option of following the terms and conditions either of that numbered -version or of any later version published by the Free Software -Foundation. If the Program does not specify a version number of the -GNU Affero General Public License, you may choose any version ever published -by the Free Software Foundation. - - If the Program specifies that a proxy can decide which future -versions of the GNU Affero General Public License can be used, that proxy's -public statement of acceptance of a version permanently authorizes you -to choose that version for the Program. - - Later license versions may give you additional or different -permissions. However, no additional obligations are imposed on any -author or copyright holder as a result of your choosing to follow a -later version. - - 15. Disclaimer of Warranty. - - THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY -APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT -HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY -OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, -THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM -IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF -ALL NECESSARY SERVICING, REPAIR OR CORRECTION. - - 16. Limitation of Liability. 
- - IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING -WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS -THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY -GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE -USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF -DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD -PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), -EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF -SUCH DAMAGES. - - 17. Interpretation of Sections 15 and 16. - - If the disclaimer of warranty and limitation of liability provided -above cannot be given local legal effect according to their terms, -reviewing courts shall apply local law that most closely approximates -an absolute waiver of all civil liability in connection with the -Program, unless a warranty or assumption of liability accompanies a -copy of the Program in return for a fee. - - END OF TERMS AND CONDITIONS - - How to Apply These Terms to Your New Programs - - If you develop a new program, and you want it to be of the greatest -possible use to the public, the best way to achieve this is to make it -free software which everyone can redistribute and change under these terms. - - To do so, attach the following notices to the program. It is safest -to attach them to the start of each source file to most effectively -state the exclusion of warranty; and each file should have at least -the "copyright" line and a pointer to where the full notice is found. - - - Copyright (C) - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published - by the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . - -Also add information on how to contact you by electronic and paper mail. - - If your software can interact with users remotely through a computer -network, you should also make sure that it provides a way for users to -get its source. For example, if your program is a web application, its -interface could display a "Source" link that leads users to an archive -of the code. There are many ways you could offer source, and different -solutions will be better for different programs; see section 13 for the -specific requirements. - - You should also get your employer (if you work as a programmer) or school, -if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU AGPL, see -. diff --git a/plugins/A_memorix/LICENSE-MAIBOT-GPL.md b/plugins/A_memorix/LICENSE-MAIBOT-GPL.md deleted file mode 100644 index 83108097..00000000 --- a/plugins/A_memorix/LICENSE-MAIBOT-GPL.md +++ /dev/null @@ -1,22 +0,0 @@ -Special GPL License Grant for MaiBot - -Licensor -- A_Dawn - -Effective date -- 2026-03-18 - -Default license -- This repository is licensed under AGPL-3.0 by default (see `LICENSE`). 
- -Additional grant for MaiBot -- The copyright holder(s) of this repository grant an additional, non-exclusive permission to - the project at `https://github.com/Mai-with-u/MaiBot` (including its maintainers and contributors) - to use, modify, and redistribute code from this repository under GPL-3.0. - -Scope -- This additional GPL grant is intended for use in the MaiBot project context. -- For all other uses not covered by the grant above, AGPL-3.0 remains the applicable license. - -No warranty -- This grant is provided without warranty, consistent with AGPL-3.0 and GPL-3.0. diff --git a/plugins/A_memorix/QUICK_START.md b/plugins/A_memorix/QUICK_START.md deleted file mode 100644 index 7159a35b..00000000 --- a/plugins/A_memorix/QUICK_START.md +++ /dev/null @@ -1,216 +0,0 @@ -# A_Memorix Quick Start (v2.0.0) - -本文档面向当前 `2.0.0` 架构(SDK Tool 接口)。 - -## 0. 版本与接口变更 - -- 当前插件版本:`2.0.0` -- 接口形态:`memory_provider` + Tool 调用 -- 旧版 slash 命令(如 `/query`、`/memory`、`/visualize`)不再作为本分支主文档入口 - -## 1. 环境准备 - -- Python 3.10+ -- 与 MaiBot 主程序相同的运行环境 -- 可访问你配置的 embedding 服务 - -安装依赖: - -```bash -pip install -r plugins/A_memorix/requirements.txt --upgrade -``` - -如果当前目录就是插件目录,也可以: - -```bash -pip install -r requirements.txt --upgrade -``` - -## 2. 启用插件 - -在主程序插件配置中启用 `A_Memorix`。 - -若你使用 `plugins/A_memorix/config.toml` 方式,最小示例: - -```toml -[plugin] -enabled = true - -[storage] -data_dir = "./data" - -[embedding] -model_name = "auto" -dimension = 1024 -batch_size = 32 -max_concurrent = 5 -quantization_type = "int8" -``` - -## 3. 运行时自检(强烈建议) - -先确认 embedding 实际输出维度与向量库兼容: - -```bash -python plugins/A_memorix/scripts/runtime_self_check.py --json -``` - -如果结果 `ok=false`,先修复 embedding 配置或向量库,再继续导入。 - -## 4. 导入数据 - -### 4.1 文本批量导入 - -把文本放到: - -```text -plugins/A_memorix/data/raw/ -``` - -执行: - -```bash -python plugins/A_memorix/scripts/process_knowledge.py -``` - -常用参数: - -```bash -python plugins/A_memorix/scripts/process_knowledge.py --force -python plugins/A_memorix/scripts/process_knowledge.py --chat-log -python plugins/A_memorix/scripts/process_knowledge.py --chat-log --chat-reference-time "2026/02/12 10:30" -``` - -### 4.2 其他导入脚本 - -```bash -python plugins/A_memorix/scripts/import_lpmm_json.py -python plugins/A_memorix/scripts/convert_lpmm.py -i -o plugins/A_memorix/data -python plugins/A_memorix/scripts/migrate_chat_history.py --help -python plugins/A_memorix/scripts/migrate_maibot_memory.py --help -python plugins/A_memorix/scripts/migrate_person_memory_points.py --help -``` - -## 5. 
核心 Tool 调用 - -### 5.1 检索 - -```json -{ - "tool": "search_memory", - "arguments": { - "query": "项目复盘", - "mode": "aggregate", - "limit": 5, - "chat_id": "group:dev" - } -} -``` - -`mode` 支持:`search/time/hybrid/episode/aggregate` - -严格语义说明: - -- `semantic` 模式已移除,传入会返回参数错误。 -- `time/hybrid` 模式必须提供 `time_start` 或 `time_end`,否则返回错误(不会再当作“未命中”)。 - -### 5.2 写入摘要 - -```json -{ - "tool": "ingest_summary", - "arguments": { - "external_id": "chat_summary:group-dev:2026-03-18", - "chat_id": "group:dev", - "text": "今天完成了检索调优评审" - } -} -``` - -### 5.3 写入普通记忆 - -```json -{ - "tool": "ingest_text", - "arguments": { - "external_id": "note:2026-03-18:001", - "source_type": "note", - "text": "模型切换后召回质量更稳定", - "chat_id": "group:dev", - "tags": ["worklog"] - } -} -``` - -### 5.4 画像与维护 - -```json -{ - "tool": "get_person_profile", - "arguments": { - "person_id": "Alice", - "limit": 8 - } -} -``` - -```json -{ - "tool": "maintain_memory", - "arguments": { - "action": "protect", - "target": "模型切换后召回质量更稳定", - "hours": 24 - } -} -``` - -```json -{ - "tool": "memory_stats", - "arguments": {} -} -``` - -## 6. 管理 Tool(进阶) - -`2.0.0` 提供完整管理工具: - -- `memory_graph_admin` -- `memory_source_admin` -- `memory_episode_admin` -- `memory_profile_admin` -- `memory_runtime_admin` -- `memory_import_admin` -- `memory_tuning_admin` -- `memory_v5_admin` -- `memory_delete_admin` - -可先用 `action=list` / `action=status` 等只读动作验证链路。 - -## 7. 常见问题 - -### Q1: 检索为空 - -1. 先看 `memory_stats` 是否有段落/关系 -2. 检查 `chat_id`、`person_id` 过滤条件是否过严 -3. 运行 `runtime_self_check.py --json` 确认 embedding 维度无误 -4. 若返回包含 `error` 字段,优先按错误提示修正 mode/时间参数 - -### Q2: 启动时报向量维度不一致 - -- 原因:现有向量库维度与当前 embedding 输出不一致 -- 处理:恢复原配置或重建向量数据后再启动 - -### Q3: Web 页面打不开 - -本分支不内置独立 `server.py`。 - -- `web/index.html`、`web/import.html`、`web/tuning.html` 由宿主侧路由/API 集成暴露 -- 请检查宿主是否已映射对应静态页与 `/api/*` 接口 - -## 8. 下一步 - -- 配置细节见 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md) -- 导入细节见 [IMPORT_GUIDE.md](IMPORT_GUIDE.md) -- 版本历史见 [CHANGELOG.md](CHANGELOG.md) diff --git a/plugins/A_memorix/README.md b/plugins/A_memorix/README.md deleted file mode 100644 index 2c59629a..00000000 --- a/plugins/A_memorix/README.md +++ /dev/null @@ -1,230 +0,0 @@ -# A_Memorix - -**长期记忆与认知增强插件** (v2.0.0) - -> 消えていかない感覚 , まだまだ足りてないみたい ! 
-
-A_Memorix 是面向 MaiBot SDK 的 `memory_provider` 插件。
-它把文本、关系、Episode、人物画像和检索调优统一在一套运行时里,适合长期运行的 Agent 记忆场景。
-
-## 快速导航
-
-- [快速入门](QUICK_START.md)
-- [配置参数详解](CONFIG_REFERENCE.md)
-- [导入指南与最佳实践](IMPORT_GUIDE.md)
-- [更新日志](CHANGELOG.md)
-
-## 2.0.0 版本定位
-
-`v2.0.0` 是一次架构收敛版本,当前分支以 **SDK Tool 接口** 为主:
-
-- 旧 `components/commands/*`、`components/tools/*` 与 `server.py` 已移除。
-- 统一入口为 [`plugin.py`](plugin.py) + [`core/runtime/sdk_memory_kernel.py`](core/runtime/sdk_memory_kernel.py)。
-- 元数据 schema 为 `v8`,新增外部引用与运维操作记录(如 `external_memory_refs`、`memory_v5_operations`、`delete_operations`)。
-
-如果你还在使用旧版 slash 命令(如 `/query`、`/memory`、`/visualize`),需要按本文的 Tool 接口迁移。
-
-## 核心能力
-
-- 双路检索:向量 + 图谱关系联合召回,支持 `search/time/hybrid/episode/aggregate`。
-- 写入与去重:`external_id` 幂等、段落/关系联合写入、Episode pending 队列处理。
-- Episode 能力:按 source 重建、状态查询、批处理 pending。
-- 人物画像:自动快照 + 手动 override。
-- 管理能力:图谱、来源、Episode、画像、导入、调优、V5 运维、删除恢复全套管理工具。
-
-## Tool 接口 (v2.0.0)
-
-### 基础工具
-
-| Tool | 说明 | 关键参数 |
-| --- | --- | --- |
-| `search_memory` | 检索长期记忆 | `query` `mode` `limit` `chat_id` `person_id` `time_start` `time_end` |
-| `ingest_summary` | 写入聊天摘要 | `external_id` `chat_id` `text` |
-| `ingest_text` | 写入普通文本记忆 | `external_id` `source_type` `text` |
-| `get_person_profile` | 获取人物画像 | `person_id` `chat_id` `limit` |
-| `maintain_memory` | 维护关系状态 | `action=reinforce/protect/restore/freeze/recycle_bin` |
-| `memory_stats` | 获取统计信息 | 无 |
-
-### 管理工具
-
-| Tool | 常用 action |
-| --- | --- |
-| `memory_graph_admin` | `get_graph/create_node/delete_node/rename_node/create_edge/delete_edge/update_edge_weight` |
-| `memory_source_admin` | `list/delete/batch_delete` |
-| `memory_episode_admin` | `query/list/get/status/rebuild/process_pending` |
-| `memory_profile_admin` | `query/list/set_override/delete_override` |
-| `memory_runtime_admin` | `save/get_config/self_check/refresh_self_check/set_auto_save` |
-| `memory_import_admin` | `settings/get_guide/create_upload/create_paste/create_raw_scan/create_lpmm_openie/create_lpmm_convert/create_temporal_backfill/create_maibot_migration/list/get/chunks/cancel/retry_failed` |
-| `memory_tuning_admin` | `settings/get_profile/apply_profile/rollback_profile/export_profile/create_task/list_tasks/get_task/get_rounds/cancel/apply_best/get_report` |
-| `memory_v5_admin` | `status/recycle_bin/restore/reinforce/weaken/remember_forever/forget` |
-| `memory_delete_admin` | `preview/execute/restore/get_operation/list_operations/purge` |
-
-### 检索模式语义(严格)
-
-- `search_memory.mode` 仅支持:`search/time/hybrid/episode/aggregate`。
-- `semantic` 模式已移除,传入将返回参数错误。
-- `time/hybrid` 模式必须提供 `time_start` 或 `time_end`,否则返回错误,不再静默按"未命中"处理。
-
-### 删除返回语义(source 模式)
-
-- `requested_source_count`:请求删除的 source 数。
-- `matched_source_count`:实际命中的 source 数(存在活跃段落)。
-- `deleted_paragraph_count`:实际删除段落数。
-- `deleted_count`:与实际删除对象一致;在 `source` 模式下等于 `deleted_paragraph_count`。
-- `success`:基于实际命中与实际删除判定,未命中 source 时返回 `false`。
-
-## 调用示例
-
-```json
-{
-  "tool": "search_memory",
-  "arguments": {
-    "query": "项目复盘",
-    "mode": "aggregate",
-    "limit": 5,
-    "chat_id": "group:dev"
-  }
-}
-```
-
-```json
-{
-  "tool": "ingest_text",
-  "arguments": {
-    "external_id": "note:2026-03-18:001",
-    "source_type": "note",
-    "text": "今天完成了检索调优评审",
-    "chat_id": "group:dev",
-    "tags": ["worklog"]
-  }
-}
-```
-
-```json
-{
-  "tool": "maintain_memory",
-  "arguments": {
-    "action": "protect",
-    "target": "完成了 检索调优评审",
-    "hours": 72
-  }
-}
-```
-
-## 快速开始
-
-### 1. 
安装依赖 - -在 MaiBot 主程序使用的同一个 Python 环境中执行: - -```bash -pip install -r plugins/A_memorix/requirements.txt --upgrade -``` - -如果当前目录已经是插件目录,也可以执行: - -```bash -pip install -r requirements.txt --upgrade -``` - -### 2. 启用插件 - -在 `config.toml` 中启用插件(路径取决于你的宿主部署): - -```toml -[plugin] -enabled = true -``` - -### 3. 先做运行时自检 - -```bash -python plugins/A_memorix/scripts/runtime_self_check.py --json -``` - -### 4. 导入文本并验证统计 - -```bash -python plugins/A_memorix/scripts/process_knowledge.py -``` - -然后调用 `memory_stats` 或 `search_memory` 检查是否有数据。 - -## Web 页面说明 - -仓库内保留了 Web 静态页面: - -- `web/index.html`(图谱与记忆管理) -- `web/import.html`(导入中心) -- `web/tuning.html`(检索调优) - -当前分支不再内置独立 `server.py`,页面路由与 API 暴露由宿主侧集成负责。 - -## 常用脚本 - -| 脚本 | 用途 | -| --- | --- | -| `process_knowledge.py` | 批量导入原始文本(策略感知) | -| `import_lpmm_json.py` | 导入 OpenIE JSON | -| `convert_lpmm.py` | 转换 LPMM 数据 | -| `migrate_chat_history.py` | 迁移 chat_history | -| `migrate_maibot_memory.py` | 迁移 MaiBot 历史记忆 | -| `migrate_person_memory_points.py` | 迁移 person memory points | -| `backfill_temporal_metadata.py` | 回填时间元数据 | -| `audit_vector_consistency.py` | 审计向量一致性 | -| `backfill_relation_vectors.py` | 回填关系向量 | -| `rebuild_episodes.py` | 按 source 重建 Episode | -| `release_vnext_migrate.py` | 升级预检/迁移/校验 | -| `runtime_self_check.py` | 真实 embedding 运行时自检 | - -## 配置重点 - -完整配置见 [CONFIG_REFERENCE.md](CONFIG_REFERENCE.md)。 - -高频配置项: - -- `storage.data_dir` -- `embedding.dimension` -- `embedding.quantization_type`(当前仅支持 `int8`) -- `retrieval.*` -- `retrieval.sparse.*` -- `episode.*` -- `person_profile.*` -- `memory.*` -- `web.import.*` -- `web.tuning.*` - -## Troubleshooting - -### SQLite 无 FTS5 - -如果环境中的 SQLite 未启用 `FTS5`,可关闭稀疏检索: - -```toml -[retrieval.sparse] -enabled = false -``` - -### 向量维度不一致 - -若日志提示当前 embedding 输出维度与既有向量库不一致,请先执行: - -```bash -python plugins/A_memorix/scripts/runtime_self_check.py --json -``` - -必要时重建向量或调整 embedding 配置后再启动插件。 - -## 许可证 - -默认许可证为 [AGPL-3.0](https://www.gnu.org/licenses/agpl-3.0)(见 `LICENSE`)。 - -针对 `Mai-with-u/MaiBot` 项目的 GPL 额外授权见 `LICENSE-MAIBOT-GPL.md`。 - -除上述额外授权外,其他使用场景仍适用 AGPL-3.0。 - -## 贡献说明 - -当前不接受 PR,只接受 issue。 - -**作者**: `A_Dawn` diff --git a/plugins/A_memorix/__init__.py b/plugins/A_memorix/__init__.py deleted file mode 100644 index d23a5bd5..00000000 --- a/plugins/A_memorix/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -A_Memorix - 轻量级知识库插件 - -完全独立的记忆增强系统,优化低资源环境下的知识存储与检索。 -""" - -__version__ = "2.0.0" -__author__ = "A_Dawn" - -from .plugin import AMemorixPlugin - -__all__ = ["AMemorixPlugin"] diff --git a/plugins/A_memorix/_manifest.json b/plugins/A_memorix/_manifest.json deleted file mode 100644 index e4217fdd..00000000 --- a/plugins/A_memorix/_manifest.json +++ /dev/null @@ -1,107 +0,0 @@ -{ - "manifest_version": 1, - "name": "A_Memorix", - "version": "2.0.0", - "description": "MaiBot SDK 长期记忆插件,负责统一检索、写入、画像与记忆维护。", - "author": { - "name": "A_Dawn" - }, - "license": "AGPL-3.0", - "repository_url": "https://github.com/A-Dawn/A_memorix/", - "host_application": { - "min_version": "1.0.0" - }, - "keywords": [ - "memory", - "knowledge", - "retrieval", - "profile", - "episode" - ], - "categories": [ - "Memory", - "Data" - ], - "plugin_info": { - "is_built_in": false, - "plugin_type": "memory_provider", - "components": [ - { - "type": "tool", - "name": "search_memory", - "description": "搜索长期记忆" - }, - { - "type": "tool", - "name": "ingest_summary", - "description": "写入聊天摘要" - }, - { - "type": "tool", - "name": "ingest_text", - "description": "写入普通长期记忆文本" - }, - { - "type": "tool", - "name": 
"get_person_profile", - "description": "查询人物画像" - }, - { - "type": "tool", - "name": "maintain_memory", - "description": "维护记忆关系" - }, - { - "type": "tool", - "name": "memory_stats", - "description": "查询记忆统计" - }, - { - "type": "tool", - "name": "memory_graph_admin", - "description": "图谱管理接口" - }, - { - "type": "tool", - "name": "memory_source_admin", - "description": "来源管理接口" - }, - { - "type": "tool", - "name": "memory_episode_admin", - "description": "Episode 管理接口" - }, - { - "type": "tool", - "name": "memory_profile_admin", - "description": "画像管理接口" - }, - { - "type": "tool", - "name": "memory_runtime_admin", - "description": "运行时管理接口" - }, - { - "type": "tool", - "name": "memory_import_admin", - "description": "导入管理接口" - }, - { - "type": "tool", - "name": "memory_tuning_admin", - "description": "调优管理接口" - }, - { - "type": "tool", - "name": "memory_v5_admin", - "description": "V5 记忆管理接口" - }, - { - "type": "tool", - "name": "memory_delete_admin", - "description": "删除管理接口" - } - ] - }, - "capabilities": [] -} diff --git a/plugins/A_memorix/core/__init__.py b/plugins/A_memorix/core/__init__.py deleted file mode 100644 index 3f87929c..00000000 --- a/plugins/A_memorix/core/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -"""核心模块 - 存储、嵌入、检索引擎""" - -# 存储模块(已实现) -from .storage import ( - VectorStore, - GraphStore, - MetadataStore, - ImportStrategy, - KnowledgeType, - parse_import_strategy, - resolve_stored_knowledge_type, - detect_knowledge_type, - select_import_strategy, - should_extract_relations, - get_type_display_name, -) - -# 嵌入模块(使用主程序 API) -from .embedding import ( - EmbeddingAPIAdapter, - create_embedding_api_adapter, -) - -# 检索模块(已实现) -from .retrieval import ( - DualPathRetriever, - RetrievalStrategy, - RetrievalResult, - DualPathRetrieverConfig, - TemporalQueryOptions, - FusionConfig, - GraphRelationRecallConfig, - RelationIntentConfig, - PersonalizedPageRank, - PageRankConfig, - create_ppr_from_graph, - DynamicThresholdFilter, - ThresholdMethod, - ThresholdConfig, - SparseBM25Index, - SparseBM25Config, -) -from .utils import ( - RelationWriteService, - RelationWriteResult, -) - -__all__ = [ - # Storage - "VectorStore", - "GraphStore", - "MetadataStore", - "ImportStrategy", - "KnowledgeType", - "parse_import_strategy", - "resolve_stored_knowledge_type", - "detect_knowledge_type", - "select_import_strategy", - "should_extract_relations", - "get_type_display_name", - # Embedding - "EmbeddingAPIAdapter", - "create_embedding_api_adapter", - # Retrieval - "DualPathRetriever", - "RetrievalStrategy", - "RetrievalResult", - "DualPathRetrieverConfig", - "TemporalQueryOptions", - "FusionConfig", - "GraphRelationRecallConfig", - "RelationIntentConfig", - "PersonalizedPageRank", - "PageRankConfig", - "create_ppr_from_graph", - "DynamicThresholdFilter", - "ThresholdMethod", - "ThresholdConfig", - "SparseBM25Index", - "SparseBM25Config", - "RelationWriteService", - "RelationWriteResult", -] - diff --git a/plugins/A_memorix/core/embedding/__init__.py b/plugins/A_memorix/core/embedding/__init__.py deleted file mode 100644 index 11a52db9..00000000 --- a/plugins/A_memorix/core/embedding/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -"""嵌入模块 - 向量生成与量化""" - -# 新的 API 适配器(主程序嵌入 API) -from .api_adapter import ( - EmbeddingAPIAdapter, - create_embedding_api_adapter, -) - -from ..utils.quantization import QuantizationType - -__all__ = [ - # 新的 API 适配器(推荐使用) - "EmbeddingAPIAdapter", - "create_embedding_api_adapter", - # 量化 - "QuantizationType", -] - diff --git a/plugins/A_memorix/core/embedding/api_adapter.py 
b/plugins/A_memorix/core/embedding/api_adapter.py deleted file mode 100644 index d11e2d05..00000000 --- a/plugins/A_memorix/core/embedding/api_adapter.py +++ /dev/null @@ -1,368 +0,0 @@ -""" -请求式嵌入 API 适配器。 - -恢复 v1.0.1 的真实 embedding 请求语义: -- 通过宿主模型配置探测/请求 embedding -- 支持 dimensions 参数 -- 支持批量与重试 -- 不再提供本地 hash fallback -""" - -from __future__ import annotations - -import asyncio -import time -from typing import Any, List, Optional, Union - -import aiohttp -import numpy as np -import openai - -from src.common.logger import get_logger -from src.config.config import config_manager -from src.config.model_configs import APIProvider, ModelInfo -from src.llm_models.exceptions import NetworkConnectionError -from src.llm_models.model_client.base_client import client_registry - -logger = get_logger("A_Memorix.EmbeddingAPIAdapter") - - -class EmbeddingAPIAdapter: - """适配宿主 embedding 请求接口。""" - - def __init__( - self, - batch_size: int = 32, - max_concurrent: int = 5, - default_dimension: int = 1024, - enable_cache: bool = False, - model_name: str = "auto", - retry_config: Optional[dict] = None, - ) -> None: - self.batch_size = max(1, int(batch_size)) - self.max_concurrent = max(1, int(max_concurrent)) - self.default_dimension = max(1, int(default_dimension)) - self.enable_cache = bool(enable_cache) - self.model_name = str(model_name or "auto") - - self.retry_config = retry_config or {} - self.max_attempts = max(1, int(self.retry_config.get("max_attempts", 5))) - self.max_wait_seconds = max(0.1, float(self.retry_config.get("max_wait_seconds", 40))) - self.min_wait_seconds = max(0.1, float(self.retry_config.get("min_wait_seconds", 3))) - self.backoff_multiplier = max(1.0, float(self.retry_config.get("backoff_multiplier", 3))) - - self._dimension: Optional[int] = None - self._dimension_detected = False - self._total_encoded = 0 - self._total_errors = 0 - self._total_time = 0.0 - - logger.info( - "EmbeddingAPIAdapter 初始化: " - f"batch_size={self.batch_size}, " - f"max_concurrent={self.max_concurrent}, " - f"default_dim={self.default_dimension}, " - f"model={self.model_name}" - ) - - def _get_current_model_config(self): - return config_manager.get_model_config() - - @staticmethod - def _find_model_info(model_name: str) -> ModelInfo: - model_cfg = config_manager.get_model_config() - for item in model_cfg.models: - if item.name == model_name: - return item - raise ValueError(f"未找到 embedding 模型: {model_name}") - - @staticmethod - def _find_provider(provider_name: str) -> APIProvider: - model_cfg = config_manager.get_model_config() - for item in model_cfg.api_providers: - if item.name == provider_name: - return item - raise ValueError(f"未找到 embedding provider: {provider_name}") - - def _resolve_candidate_model_names(self) -> List[str]: - task_config = self._get_current_model_config().model_task_config.embedding - configured = list(getattr(task_config, "model_list", []) or []) - if self.model_name and self.model_name != "auto": - return [self.model_name, *[name for name in configured if name != self.model_name]] - return configured - - @staticmethod - def _validate_embedding_vector(embedding: Any, *, source: str) -> np.ndarray: - array = np.asarray(embedding, dtype=np.float32) - if array.ndim != 1: - raise RuntimeError(f"{source} 返回的 embedding 维度非法: ndim={array.ndim}") - if array.size <= 0: - raise RuntimeError(f"{source} 返回了空 embedding") - if not np.all(np.isfinite(array)): - raise RuntimeError(f"{source} 返回了非有限 embedding 值") - return array - - async def _request_with_retry(self, client, model_info, text: 
str, extra_params: dict): - retriable_exceptions = ( - openai.APIConnectionError, - openai.APITimeoutError, - aiohttp.ClientError, - asyncio.TimeoutError, - NetworkConnectionError, - ) - - last_exc: Optional[BaseException] = None - for attempt in range(1, self.max_attempts + 1): - try: - return await client.get_embedding( - model_info=model_info, - embedding_input=text, - extra_params=extra_params, - ) - except retriable_exceptions as exc: - last_exc = exc - if attempt >= self.max_attempts: - raise - wait_seconds = min( - self.max_wait_seconds, - self.min_wait_seconds * (self.backoff_multiplier ** (attempt - 1)), - ) - logger.warning( - "Embedding 请求失败,重试 " - f"{attempt}/{max(1, self.max_attempts - 1)}," - f"{wait_seconds:.1f}s 后重试: {exc}" - ) - await asyncio.sleep(wait_seconds) - except Exception: - raise - - if last_exc is not None: - raise last_exc - raise RuntimeError("Embedding 请求失败:未知错误") - - async def _get_embedding_direct(self, text: str, dimensions: Optional[int] = None) -> Optional[List[float]]: - candidate_names = self._resolve_candidate_model_names() - if not candidate_names: - raise RuntimeError("embedding 任务未配置模型") - - last_exc: Optional[BaseException] = None - for candidate_name in candidate_names: - try: - model_info = self._find_model_info(candidate_name) - api_provider = self._find_provider(model_info.api_provider) - client = client_registry.get_client_class_instance(api_provider, force_new=True) - - extra_params = dict(getattr(model_info, "extra_params", {}) or {}) - if dimensions is not None: - extra_params["dimensions"] = int(dimensions) - - response = await self._request_with_retry( - client=client, - model_info=model_info, - text=text, - extra_params=extra_params, - ) - embedding = getattr(response, "embedding", None) - if embedding is None: - raise RuntimeError(f"模型 {candidate_name} 未返回 embedding") - vector = self._validate_embedding_vector( - embedding, - source=f"embedding 模型 {candidate_name}", - ) - return vector.tolist() - except Exception as exc: - last_exc = exc - logger.warning(f"embedding 模型 {candidate_name} 请求失败: {exc}") - - if last_exc is not None: - logger.error(f"通过直接 Client 获取 Embedding 失败: {last_exc}") - return None - - async def _detect_dimension(self) -> int: - if self._dimension_detected and self._dimension is not None: - return self._dimension - - logger.info("正在检测嵌入模型维度...") - try: - target_dim = self.default_dimension - logger.debug(f"尝试请求指定维度: {target_dim}") - test_embedding = await self._get_embedding_direct("test", dimensions=target_dim) - if test_embedding and isinstance(test_embedding, list): - detected_dim = len(test_embedding) - if detected_dim == target_dim: - logger.info(f"嵌入维度检测成功 (匹配配置): {detected_dim}") - else: - logger.warning( - f"请求维度 {target_dim} 但模型返回 {detected_dim},将使用模型自然维度" - ) - self._dimension = detected_dim - self._dimension_detected = True - return detected_dim - except Exception as exc: - logger.debug(f"带维度参数探测失败: {exc},尝试不带参数探测") - - try: - test_embedding = await self._get_embedding_direct("test", dimensions=None) - if test_embedding and isinstance(test_embedding, list): - detected_dim = len(test_embedding) - self._dimension = detected_dim - self._dimension_detected = True - logger.info(f"嵌入维度检测成功 (自然维度): {detected_dim}") - return detected_dim - logger.warning(f"嵌入维度检测失败,使用默认值: {self.default_dimension}") - except Exception as exc: - logger.error(f"嵌入维度检测异常: {exc},使用默认值: {self.default_dimension}") - - self._dimension = self.default_dimension - self._dimension_detected = True - return self.default_dimension - - async def 
encode( - self, - texts: Union[str, List[str]], - batch_size: Optional[int] = None, - show_progress: bool = False, - normalize: bool = True, - dimensions: Optional[int] = None, - ) -> np.ndarray: - del show_progress - del normalize - - start_time = time.time() - target_dim = int(dimensions) if dimensions is not None else int(await self._detect_dimension()) - - if isinstance(texts, str): - normalized_texts = [texts] - single_input = True - else: - normalized_texts = list(texts or []) - single_input = False - - if not normalized_texts: - empty = np.zeros((0, target_dim), dtype=np.float32) - return empty[0] if single_input else empty - - if batch_size is None: - batch_size = self.batch_size - - try: - embeddings = await self._encode_batch_internal( - normalized_texts, - batch_size=max(1, int(batch_size)), - dimensions=dimensions, - ) - if embeddings.ndim == 1: - embeddings = embeddings.reshape(1, -1) - self._total_encoded += len(normalized_texts) - elapsed = time.time() - start_time - self._total_time += elapsed - logger.debug( - "编码完成: " - f"{len(normalized_texts)} 个文本, " - f"耗时 {elapsed:.2f}s, " - f"平均 {elapsed / max(1, len(normalized_texts)):.3f}s/文本" - ) - return embeddings[0] if single_input else embeddings - except Exception as exc: - self._total_errors += 1 - logger.error(f"编码失败: {exc}") - raise RuntimeError(f"embedding encode failed: {exc}") from exc - - async def _encode_batch_internal( - self, - texts: List[str], - batch_size: int, - dimensions: Optional[int] = None, - ) -> np.ndarray: - all_embeddings: List[np.ndarray] = [] - for offset in range(0, len(texts), batch_size): - batch = texts[offset : offset + batch_size] - semaphore = asyncio.Semaphore(self.max_concurrent) - - async def encode_with_semaphore(text: str, index: int): - async with semaphore: - embedding = await self._get_embedding_direct(text, dimensions=dimensions) - if embedding is None: - raise RuntimeError(f"文本 {index} 编码失败:embedding 返回为空") - vector = self._validate_embedding_vector( - embedding, - source=f"文本 {index}", - ) - return index, vector - - tasks = [ - encode_with_semaphore(text, offset + index) - for index, text in enumerate(batch) - ] - results = await asyncio.gather(*tasks) - results.sort(key=lambda item: item[0]) - all_embeddings.extend(emb for _, emb in results) - - return np.array(all_embeddings, dtype=np.float32) - - async def encode_batch( - self, - texts: List[str], - batch_size: Optional[int] = None, - num_workers: Optional[int] = None, - show_progress: bool = False, - dimensions: Optional[int] = None, - ) -> np.ndarray: - del show_progress - if num_workers is not None: - previous = self.max_concurrent - self.max_concurrent = max(1, int(num_workers)) - try: - return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) - finally: - self.max_concurrent = previous - return await self.encode(texts, batch_size=batch_size, dimensions=dimensions) - - def get_embedding_dimension(self) -> int: - if self._dimension is not None: - return self._dimension - logger.warning(f"维度尚未检测,返回默认值: {self.default_dimension}") - return self.default_dimension - - def get_model_info(self) -> dict: - return { - "model_name": self.model_name, - "dimension": self._dimension or self.default_dimension, - "dimension_detected": self._dimension_detected, - "batch_size": self.batch_size, - "max_concurrent": self.max_concurrent, - "total_encoded": self._total_encoded, - "total_errors": self._total_errors, - "avg_time_per_text": self._total_time / self._total_encoded if self._total_encoded else 0.0, - } - - def 
get_statistics(self) -> dict: - return self.get_model_info() - - @property - def is_model_loaded(self) -> bool: - return True - - def __repr__(self) -> str: - return ( - f"EmbeddingAPIAdapter(dim={self._dimension or self.default_dimension}, " - f"detected={self._dimension_detected}, encoded={self._total_encoded})" - ) - - -def create_embedding_api_adapter( - batch_size: int = 32, - max_concurrent: int = 5, - default_dimension: int = 1024, - enable_cache: bool = False, - model_name: str = "auto", - retry_config: Optional[dict] = None, -) -> EmbeddingAPIAdapter: - return EmbeddingAPIAdapter( - batch_size=batch_size, - max_concurrent=max_concurrent, - default_dimension=default_dimension, - enable_cache=enable_cache, - model_name=model_name, - retry_config=retry_config, - ) diff --git a/plugins/A_memorix/core/embedding/manager.py b/plugins/A_memorix/core/embedding/manager.py deleted file mode 100644 index d161e23b..00000000 --- a/plugins/A_memorix/core/embedding/manager.py +++ /dev/null @@ -1,510 +0,0 @@ -""" -嵌入管理器 - -负责嵌入模型的加载、缓存和批量生成。 -""" - -import hashlib -import pickle -import threading -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from typing import Optional, Union, List, Dict, Any, Tuple - -import numpy as np - -try: - from sentence_transformers import SentenceTransformer - HAS_SENTENCE_TRANSFORMERS = True -except ImportError: - HAS_SENTENCE_TRANSFORMERS = False - -from src.common.logger import get_logger -from .presets import ( - EmbeddingModelConfig, - get_custom_config, - validate_config_compatibility, - are_models_compatible, -) -from ..utils.quantization import QuantizationType - -logger = get_logger("A_Memorix.EmbeddingManager") - - -class EmbeddingManager: - """ - 嵌入管理器 - - 功能: - - 模型加载与缓存 - - 批量生成嵌入 - - 多线程/多进程支持 - - 模型一致性检查 - - 智能分批 - - 参数: - config: 模型配置 - cache_dir: 缓存目录 - enable_cache: 是否启用缓存 - num_workers: 工作线程数 - """ - - def __init__( - self, - config: EmbeddingModelConfig, - cache_dir: Optional[Union[str, Path]] = None, - enable_cache: bool = True, - num_workers: int = 1, - ): - """ - 初始化嵌入管理器 - - Args: - config: 模型配置 - cache_dir: 缓存目录 - enable_cache: 是否启用缓存 - num_workers: 工作线程数 - """ - if not HAS_SENTENCE_TRANSFORMERS: - raise ImportError( - "sentence-transformers 未安装,请安装: " - "pip install sentence-transformers" - ) - - self.config = config - self.cache_dir = Path(cache_dir) if cache_dir else None - self.enable_cache = enable_cache - self.num_workers = max(1, num_workers) - - # 模型实例 - self._model: Optional[SentenceTransformer] = None - self._model_lock = threading.Lock() - - # 缓存 - self._embedding_cache: Dict[str, np.ndarray] = {} - self._cache_lock = threading.Lock() - - # 统计 - self._total_encoded = 0 - self._cache_hits = 0 - self._cache_misses = 0 - - logger.info( - f"EmbeddingManager 初始化: model={config.model_name}, " - f"dim={config.dimension}, workers={num_workers}" - ) - - def load_model(self) -> None: - """加载模型(懒加载)""" - if self._model is not None: - return - - with self._model_lock: - # 双重检查 - if self._model is not None: - return - - logger.info(f"正在加载模型: {self.config.model_name}") - - try: - # 构建模型参数 - model_kwargs = {} - if self.config.cache_dir: - model_kwargs["cache_folder"] = self.config.cache_dir - - # 加载模型 - self._model = SentenceTransformer( - self.config.model_path, - **model_kwargs, - ) - - logger.info(f"模型加载成功: {self.config.model_name}") - - except Exception as e: - logger.error(f"模型加载失败: {e}") - raise - - def encode( - self, - texts: Union[str, List[str]], - batch_size: Optional[int] = None, - 
show_progress: bool = False, - normalize: bool = True, - ) -> np.ndarray: - """ - 生成文本嵌入 - - Args: - texts: 文本或文本列表 - batch_size: 批次大小(默认使用配置值) - show_progress: 是否显示进度条 - normalize: 是否归一化 - - Returns: - 嵌入向量 (N x D) - """ - # 确保模型已加载 - self.load_model() - - # 标准化输入 - if isinstance(texts, str): - texts = [texts] - single_input = True - else: - single_input = False - - if not texts: - return np.zeros((0, self.config.dimension), dtype=np.float32) - - # 使用配置的批次大小 - if batch_size is None: - batch_size = self.config.batch_size - - # 生成嵌入 - try: - embeddings = self._model.encode( - texts, - batch_size=batch_size, - show_progress_bar=show_progress, - normalize_embeddings=normalize and self.config.normalization, - convert_to_numpy=True, - ) - - # 确保是2D数组 - if embeddings.ndim == 1: - embeddings = embeddings.reshape(1, -1) - - self._total_encoded += len(texts) - - # 如果是单个输入,返回1D数组 - if single_input: - return embeddings[0] - - return embeddings - - except Exception as e: - logger.error(f"生成嵌入失败: {e}") - raise - - def encode_batch( - self, - texts: List[str], - batch_size: Optional[int] = None, - num_workers: Optional[int] = None, - show_progress: bool = False, - ) -> np.ndarray: - """ - 批量生成嵌入(多线程优化) - - Args: - texts: 文本列表 - batch_size: 批次大小 - num_workers: 工作线程数(默认使用初始化时的值) - show_progress: 是否显示进度条 - - Returns: - 嵌入向量 (N x D) - """ - if not texts: - return np.zeros((0, self.config.dimension), dtype=np.float32) - - # 单线程模式 - num_workers = num_workers if num_workers is not None else self.num_workers - if num_workers == 1: - return self.encode(texts, batch_size=batch_size, show_progress=show_progress) - - # 多线程模式 - logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入") - - # 分批 - batch_size = batch_size or self.config.batch_size - batches = [ - texts[i:i + batch_size] - for i in range(0, len(texts), batch_size) - ] - - # 多线程生成 - all_embeddings = [] - with ThreadPoolExecutor(max_workers=num_workers) as executor: - # 提交任务 - future_to_batch = { - executor.submit( - self.encode, - batch, - batch_size, - False, # 不显示进度条(多线程时会混乱) - ): i - for i, batch in enumerate(batches) - } - - # 收集结果 - for future in as_completed(future_to_batch): - batch_idx = future_to_batch[future] - try: - embeddings = future.result() - all_embeddings.append((batch_idx, embeddings)) - except Exception as e: - logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}") - raise - - # 按顺序合并 - all_embeddings.sort(key=lambda x: x[0]) - final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0) - - return final_embeddings - - def encode_with_cache( - self, - texts: List[str], - batch_size: Optional[int] = None, - show_progress: bool = False, - ) -> np.ndarray: - """ - 生成嵌入(带缓存) - - Args: - texts: 文本列表 - batch_size: 批次大小 - show_progress: 是否显示进度条 - - Returns: - 嵌入向量 (N x D) - """ - if not self.enable_cache: - return self.encode(texts, batch_size, show_progress) - - # 分离缓存命中和未命中的文本 - cached_embeddings = [] - uncached_texts = [] - uncached_indices = [] - - for i, text in enumerate(texts): - cache_key = self._get_cache_key(text) - - with self._cache_lock: - if cache_key in self._embedding_cache: - cached_embeddings.append((i, self._embedding_cache[cache_key])) - self._cache_hits += 1 - else: - uncached_texts.append(text) - uncached_indices.append(i) - self._cache_misses += 1 - - # 生成未缓存的嵌入 - if uncached_texts: - new_embeddings = self.encode( - uncached_texts, - batch_size, - show_progress, - ) - - # 更新缓存 - with self._cache_lock: - for text, embedding in zip(uncached_texts, new_embeddings): - cache_key = self._get_cache_key(text) - 
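> Editor's note: `encode_batch` above fans batches out to a `ThreadPoolExecutor` and then re-sorts by batch index, because `as_completed` yields futures in completion order rather than submission order. A self-contained sketch of that pattern, assuming a hypothetical `encode_fn` callable (not the manager class itself):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, List

import numpy as np


def encode_in_parallel(
    texts: List[str],
    encode_fn: Callable[[List[str]], np.ndarray],
    batch_size: int = 32,
    num_workers: int = 4,
) -> np.ndarray:
    """Split texts into batches, encode them in worker threads, and restore the original order."""
    if not texts:
        return np.empty((0, 0), dtype=np.float32)  # caller decides how to treat the empty case
    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
    results = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        future_to_idx = {executor.submit(encode_fn, batch): i for i, batch in enumerate(batches)}
        for future in as_completed(future_to_idx):
            # Tag each result with its batch index; completion order is nondeterministic.
            results.append((future_to_idx[future], future.result()))
    results.sort(key=lambda item: item[0])
    return np.concatenate([emb for _, emb in results], axis=0)
```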
self._embedding_cache[cache_key] = embedding.copy() - - # 合并结果 - for idx, embedding in zip(uncached_indices, new_embeddings): - cached_embeddings.append((idx, embedding)) - - # 按原始顺序排序 - cached_embeddings.sort(key=lambda x: x[0]) - final_embeddings = np.array([emb for _, emb in cached_embeddings]) - - return final_embeddings - - def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: - """ - 保存缓存到磁盘 - - Args: - cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) - """ - if cache_path is None: - if self.cache_dir is None: - raise ValueError("未指定缓存目录") - cache_path = self.cache_dir / "embeddings_cache.pkl" - - cache_path = Path(cache_path) - cache_path.parent.mkdir(parents=True, exist_ok=True) - - with self._cache_lock: - with open(cache_path, "wb") as f: - pickle.dump(self._embedding_cache, f) - - logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)") - - def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: - """ - 从磁盘加载缓存 - - Args: - cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) - """ - if cache_path is None: - if self.cache_dir is None: - raise ValueError("未指定缓存目录") - cache_path = self.cache_dir / "embeddings_cache.pkl" - - cache_path = Path(cache_path) - if not cache_path.exists(): - logger.warning(f"缓存文件不存在: {cache_path}") - return - - with self._cache_lock: - with open(cache_path, "rb") as f: - self._embedding_cache = pickle.load(f) - - logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)") - - def clear_cache(self) -> None: - """清空缓存""" - with self._cache_lock: - count = len(self._embedding_cache) - self._embedding_cache.clear() - logger.info(f"已清空缓存: {count} 条") - - def check_model_consistency( - self, - stored_embeddings: np.ndarray, - sample_texts: List[str] = None, - ) -> Tuple[bool, str]: - """ - 检查模型一致性 - - Args: - stored_embeddings: 存储的嵌入向量 - sample_texts: 样本文本(用于重新生成对比) - - Returns: - (是否一致, 详细信息) - """ - # 检查维度 - if stored_embeddings.shape[1] != self.config.dimension: - return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}" - - # 如果提供了样本文本,重新生成并比较 - if sample_texts: - try: - new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个 - - # 计算相似度 - similarities = np.dot( - stored_embeddings[:5], - new_embeddings.T, - ).diagonal() - - # 检查相似度 - if np.mean(similarities) < 0.95: - return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}" - - return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}" - - except Exception as e: - return False, f"一致性检查失败: {e}" - - return True, "维度匹配" - - def get_model_info(self) -> Dict[str, Any]: - """ - 获取模型信息 - - Returns: - 模型信息字典 - """ - return { - "model_name": self.config.model_name, - "dimension": self.config.dimension, - "max_seq_length": self.config.max_seq_length, - "batch_size": self.config.batch_size, - "normalization": self.config.normalization, - "pooling": self.config.pooling, - "model_loaded": self._model is not None, - "cache_enabled": self.enable_cache, - "cache_size": len(self._embedding_cache), - "total_encoded": self._total_encoded, - "cache_hits": self._cache_hits, - "cache_misses": self._cache_misses, - } - - def get_embedding_dimension(self) -> int: - """获取嵌入维度""" - return self.config.dimension - - def _get_cache_key(self, text: str) -> str: - """ - 生成缓存键 - - Args: - text: 文本内容 - - Returns: - 缓存键(SHA256哈希) - """ - return hashlib.sha256(text.encode("utf-8")).hexdigest() - - @property - def is_model_loaded(self) -> bool: - """模型是否已加载""" - return self._model is not None - - @property - def cache_hit_rate(self) -> float: - 
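> Editor's note: `_get_cache_key` above keys the in-memory cache on a SHA-256 digest of the UTF-8 text, so identical strings are only sent to the embedding model once. A minimal sketch of that keying scheme, assuming a hypothetical `embed` callable rather than the adapter/manager classes:

```python
import hashlib
from typing import Callable, Dict, List

import numpy as np


def cached_encode(
    texts: List[str],
    embed: Callable[[List[str]], np.ndarray],
    cache: Dict[str, np.ndarray],
) -> np.ndarray:
    """Encode texts, reusing cached vectors keyed by the SHA-256 digest of each text."""
    keys = [hashlib.sha256(t.encode("utf-8")).hexdigest() for t in texts]
    missing = [i for i, k in enumerate(keys) if k not in cache]
    if missing:
        new_vecs = embed([texts[i] for i in missing])
        for i, vec in zip(missing, new_vecs):
            cache[keys[i]] = np.asarray(vec, dtype=np.float32)
    # Assemble results in the original order, whether cached or freshly encoded.
    return np.stack([cache[k] for k in keys])
```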
"""缓存命中率""" - total = self._cache_hits + self._cache_misses - if total == 0: - return 0.0 - return self._cache_hits / total - - def __repr__(self) -> str: - return ( - f"EmbeddingManager(model={self.config.model_name}, " - f"dim={self.config.dimension}, " - f"loaded={self.is_model_loaded}, " - f"cache={len(self._embedding_cache)})" - ) - - - - -def create_embedding_manager_from_config( - model_name: str, - model_path: str, - dimension: int, - cache_dir: Optional[Union[str, Path]] = None, - enable_cache: bool = True, - num_workers: int = 1, - **config_kwargs, -) -> EmbeddingManager: - """ - 从自定义配置创建嵌入管理器 - - Args: - model_name: 模型名称 - model_path: HuggingFace模型路径 - dimension: 输出维度 - cache_dir: 缓存目录 - enable_cache: 是否启用缓存 - num_workers: 工作线程数 - **config_kwargs: 其他配置参数 - - Returns: - 嵌入管理器实例 - """ - # 创建自定义配置 - config = get_custom_config( - model_name=model_name, - model_path=model_path, - dimension=dimension, - cache_dir=cache_dir, - **config_kwargs, - ) - - # 创建管理器 - return EmbeddingManager( - config=config, - cache_dir=cache_dir, - enable_cache=enable_cache, - num_workers=num_workers, - ) diff --git a/plugins/A_memorix/core/embedding/presets.py b/plugins/A_memorix/core/embedding/presets.py deleted file mode 100644 index 54e6f8b4..00000000 --- a/plugins/A_memorix/core/embedding/presets.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -嵌入模型配置模块 -""" - -from dataclasses import dataclass -from typing import Optional, Dict, Any, Union -from pathlib import Path - - -@dataclass -class EmbeddingModelConfig: - """ - 嵌入模型配置 - - 属性: - model_name: 模型描述名称 - model_path: 实际加载路径(Local or HF) - dimension: 嵌入向量维度 - max_seq_length: 最大序列长度 - batch_size: 编码批次大小 - model_size_mb: 估计显存占用 - description: 模型说明 - normalization: 是否自动归一化 - pooling: 池化策略 (mean, cls, max) - cache_dir: 模型缓存目录 - """ - - model_name: str - model_path: str - dimension: int - max_seq_length: int = 512 - batch_size: int = 32 - model_size_mb: int = 100 - description: str = "" - normalization: bool = True - pooling: str = "mean" - cache_dir: Optional[Union[str, Path]] = None - - -def validate_config_compatibility( - config1: EmbeddingModelConfig, config2: EmbeddingModelConfig -) -> bool: - """检查两个配置是否兼容(主要看维度)""" - return config1.dimension == config2.dimension - - -def are_models_compatible( - config1: EmbeddingModelConfig, config2: EmbeddingModelConfig -) -> bool: - """检查模型是否完全相同(用于热切换判断)""" - return ( - config1.model_path == config2.model_path - and config1.dimension == config2.dimension - and config1.pooling == config2.pooling - ) - - -def get_custom_config( - model_name: str, - model_path: str, - dimension: int, - cache_dir: Optional[Union[str, Path]] = None, - **kwargs, -) -> EmbeddingModelConfig: - """创建自定义模型配置""" - return EmbeddingModelConfig( - model_name=model_name, - model_path=model_path, - dimension=dimension, - cache_dir=cache_dir, - **kwargs, - ) diff --git a/plugins/A_memorix/core/retrieval/__init__.py b/plugins/A_memorix/core/retrieval/__init__.py deleted file mode 100644 index 6efce7f6..00000000 --- a/plugins/A_memorix/core/retrieval/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -"""检索模块 - 双路检索与排序""" - -from .dual_path import ( - DualPathRetriever, - RetrievalStrategy, - RetrievalResult, - DualPathRetrieverConfig, - TemporalQueryOptions, - FusionConfig, - RelationIntentConfig, -) -from .pagerank import ( - PersonalizedPageRank, - PageRankConfig, - create_ppr_from_graph, -) -from .threshold import ( - DynamicThresholdFilter, - ThresholdMethod, - ThresholdConfig, -) -from .sparse_bm25 import ( - SparseBM25Index, - SparseBM25Config, -) -from 
.graph_relation_recall import ( - GraphRelationRecallConfig, - GraphRelationRecallService, -) - -__all__ = [ - # DualPathRetriever - "DualPathRetriever", - "RetrievalStrategy", - "RetrievalResult", - "DualPathRetrieverConfig", - "TemporalQueryOptions", - "FusionConfig", - "RelationIntentConfig", - # PersonalizedPageRank - "PersonalizedPageRank", - "PageRankConfig", - "create_ppr_from_graph", - # DynamicThresholdFilter - "DynamicThresholdFilter", - "ThresholdMethod", - "ThresholdConfig", - # Sparse BM25 - "SparseBM25Index", - "SparseBM25Config", - # Graph relation recall - "GraphRelationRecallConfig", - "GraphRelationRecallService", -] diff --git a/plugins/A_memorix/core/retrieval/dual_path.py b/plugins/A_memorix/core/retrieval/dual_path.py deleted file mode 100644 index 6ed5e71a..00000000 --- a/plugins/A_memorix/core/retrieval/dual_path.py +++ /dev/null @@ -1,1796 +0,0 @@ -""" -双路检索器 - -同时检索关系和段落,实现知识图谱增强的检索。 -""" - -import asyncio -import re -from dataclasses import dataclass, field -from typing import Optional, List, Dict, Any, Tuple, Union -from enum import Enum - -import numpy as np - -from src.common.logger import get_logger -from ..storage import VectorStore, GraphStore, MetadataStore -from ..embedding import EmbeddingAPIAdapter -from ..utils.matcher import AhoCorasick -from ..utils.time_parser import format_timestamp -from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService -from .pagerank import PersonalizedPageRank, PageRankConfig -from .sparse_bm25 import SparseBM25Config, SparseBM25Index - -logger = get_logger("A_Memorix.DualPathRetriever") - - -class RetrievalStrategy(Enum): - """检索策略""" - - PARA_ONLY = "paragraph_only" # 仅段落检索 - REL_ONLY = "relation_only" # 仅关系检索 - DUAL_PATH = "dual_path" # 双路检索(推荐) - - -@dataclass -class RetrievalResult: - """ - 检索结果 - - 属性: - hash_value: 哈希值 - content: 内容(段落或关系) - score: 相似度分数 - result_type: 结果类型(paragraph/relation) - source: 来源(paragraph_search/relation_search/fusion) - metadata: 额外元数据 - """ - - hash_value: str - content: str - score: float - result_type: str # "paragraph" or "relation" - source: str # "paragraph_search", "relation_search", "fusion" - metadata: Dict[str, Any] - - def to_dict(self) -> Dict[str, Any]: - """转换为字典""" - return { - "hash": self.hash_value, - "content": self.content, - "score": self.score, - "type": self.result_type, - "source": self.source, - "metadata": self.metadata, - } - - -@dataclass -class DualPathRetrieverConfig: - """ - 双路检索器配置 - - 属性: - top_k_paragraphs: 段落检索数量 - top_k_relations: 关系检索数量 - top_k_final: 最终返回数量 - alpha: 段落和关系的融合权重(0-1) - - 0: 仅使用关系分数 - - 1: 仅使用段落分数 - - 0.5: 平均融合 - enable_ppr: 是否启用PageRank重排序 - ppr_alpha: PageRank的alpha参数 - ppr_concurrency_limit: PPR计算的最大并发数 - enable_parallel: 是否并行检索 - retrieval_strategy: 检索策略 - debug: 是否启用调试模式(打印搜索结果原文) - """ - - top_k_paragraphs: int = 20 - top_k_relations: int = 10 - top_k_final: int = 10 - alpha: float = 0.5 # 融合权重 - enable_ppr: bool = True - ppr_alpha: float = 0.85 - ppr_timeout_seconds: float = 1.5 - ppr_concurrency_limit: int = 4 - enable_parallel: bool = True - retrieval_strategy: RetrievalStrategy = RetrievalStrategy.DUAL_PATH - debug: bool = False - sparse: SparseBM25Config = field(default_factory=SparseBM25Config) - fusion: "FusionConfig" = field(default_factory=lambda: FusionConfig()) - relation_intent: "RelationIntentConfig" = field(default_factory=lambda: RelationIntentConfig()) - graph_recall: GraphRelationRecallConfig = field(default_factory=GraphRelationRecallConfig) - - def __post_init__(self): - """验证配置""" - 
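> Editor's note: the `alpha` field documented above weights paragraph scores against relation scores (1.0 keeps only paragraphs, 0.0 only relations, 0.5 averages the two). The `_fuse_results` implementation itself is not part of this hunk; a linear blend consistent with that docstring, with a hypothetical `alpha_fuse` helper, would look like:

```python
def alpha_fuse(paragraph_score: float, relation_score: float, alpha: float = 0.5) -> float:
    """Blend dual-path scores: alpha=1.0 -> paragraph only, alpha=0.0 -> relation only."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1]: {alpha}")
    return alpha * paragraph_score + (1.0 - alpha) * relation_score


# e.g. alpha_fuse(0.8, 0.4, alpha=0.5) == 0.6
```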
if isinstance(self.sparse, dict): - self.sparse = SparseBM25Config(**self.sparse) - if isinstance(self.fusion, dict): - self.fusion = FusionConfig(**self.fusion) - if isinstance(self.relation_intent, dict): - self.relation_intent = RelationIntentConfig(**self.relation_intent) - if isinstance(self.graph_recall, dict): - self.graph_recall = GraphRelationRecallConfig(**self.graph_recall) - - if not 0 <= self.alpha <= 1: - raise ValueError(f"alpha必须在[0, 1]之间: {self.alpha}") - - if self.top_k_paragraphs <= 0: - raise ValueError(f"top_k_paragraphs必须大于0: {self.top_k_paragraphs}") - - if self.top_k_relations <= 0: - raise ValueError(f"top_k_relations必须大于0: {self.top_k_relations}") - - if self.top_k_final <= 0: - raise ValueError(f"top_k_final必须大于0: {self.top_k_final}") - if self.ppr_timeout_seconds <= 0: - raise ValueError(f"ppr_timeout_seconds必须大于0: {self.ppr_timeout_seconds}") - - -@dataclass -class TemporalQueryOptions: - """时序查询选项。""" - - time_from: Optional[float] = None - time_to: Optional[float] = None - person: Optional[str] = None - source: Optional[str] = None - allow_created_fallback: bool = True - candidate_multiplier: int = 8 - max_scan: int = 1000 - - -@dataclass -class RelationIntentConfig: - """关系意图增强配置。""" - - enabled: bool = True - alpha_override: float = 0.35 - relation_candidate_multiplier: int = 4 - preserve_top_relations: int = 3 - force_relation_sparse: bool = True - pair_predicate_rerank_enabled: bool = True - pair_predicate_limit: int = 3 - - def __post_init__(self): - self.alpha_override = min(1.0, max(0.0, float(self.alpha_override))) - self.relation_candidate_multiplier = max(1, int(self.relation_candidate_multiplier)) - self.preserve_top_relations = max(0, int(self.preserve_top_relations)) - self.force_relation_sparse = bool(self.force_relation_sparse) - self.pair_predicate_rerank_enabled = bool(self.pair_predicate_rerank_enabled) - self.pair_predicate_limit = max(1, int(self.pair_predicate_limit)) - - -@dataclass -class FusionConfig: - """融合配置。""" - - method: str = "weighted_rrf" # weighted_rrf | alpha_legacy - rrf_k: int = 60 - vector_weight: float = 0.7 - bm25_weight: float = 0.3 - normalize_score: bool = True - normalize_method: str = "minmax" - - def __post_init__(self): - self.method = str(self.method or "weighted_rrf").strip().lower() - self.normalize_method = str(self.normalize_method or "minmax").strip().lower() - self.rrf_k = max(1, int(self.rrf_k)) - self.vector_weight = max(0.0, float(self.vector_weight)) - self.bm25_weight = max(0.0, float(self.bm25_weight)) - s = self.vector_weight + self.bm25_weight - if s <= 0: - self.vector_weight = 0.7 - self.bm25_weight = 0.3 - elif abs(s - 1.0) > 1e-8: - self.vector_weight /= s - self.bm25_weight /= s - - -class DualPathRetriever: - """ - 双路检索器 - - 功能: - - 并行检索段落和关系 - - 结果融合与排序 - - PageRank重排序 - - 实体识别与加权 - - 参数: - vector_store: 向量存储 - graph_store: 图存储 - metadata_store: 元数据存储 - embedding_manager: 嵌入管理器 - config: 检索配置 - """ - - def __init__( - self, - vector_store: VectorStore, - graph_store: GraphStore, - metadata_store: MetadataStore, - embedding_manager: EmbeddingAPIAdapter, - sparse_index: Optional[SparseBM25Index] = None, - config: Optional[DualPathRetrieverConfig] = None, - ): - """ - 初始化双路检索器 - - Args: - vector_store: 向量存储 - graph_store: 图存储 - metadata_store: 元数据存储 - embedding_manager: 嵌入管理器 - config: 检索配置 - """ - self.vector_store = vector_store - self.graph_store = graph_store - self.metadata_store = metadata_store - self.embedding_manager = embedding_manager - self.config = config or 
DualPathRetrieverConfig() - self.sparse_index = sparse_index - - # PageRank计算器 - ppr_config = PageRankConfig(alpha=self.config.ppr_alpha) - self._ppr = PersonalizedPageRank( - graph_store=graph_store, - config=ppr_config, - ) - self._ppr_semaphore = asyncio.Semaphore(self.config.ppr_concurrency_limit) - self._graph_relation_recall = GraphRelationRecallService( - graph_store=graph_store, - metadata_store=metadata_store, - config=self.config.graph_recall, - ) - - logger.info( - f"DualPathRetriever 初始化: " - f"strategy={self.config.retrieval_strategy.value}, " - f"top_k_para={self.config.top_k_paragraphs}, " - f"top_k_rel={self.config.top_k_relations}" - ) - - # 缓存 Aho-Corasick 匹配器 - self._ac_matcher: Optional[AhoCorasick] = None - self._ac_nodes_count = 0 - self._relation_intent_pattern = re.compile( - r"(什么关系|有哪些关系|和.+关系|关联|关系网|subject|predicate|object|" - r"relation|related|between.+and)", - re.IGNORECASE, - ) - - async def retrieve( - self, - query: str, - top_k: Optional[int] = None, - strategy: Optional[RetrievalStrategy] = None, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 执行检索(异步方法) - - Args: - query: 查询文本 - top_k: 返回结果数量(默认使用配置值) - strategy: 检索策略(默认使用配置值) - temporal: 时序查询选项(可选) - - Returns: - 检索结果列表 - """ - top_k = top_k or self.config.top_k_final - strategy = strategy or self.config.retrieval_strategy - relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k) - - logger.info( - "执行检索: " - f"query='{query[:50]}...', " - f"strategy={strategy.value}, " - f"relation_intent={relation_intent_ctx.get('enabled', False)}" - ) - - if temporal and not (query or "").strip(): - return self._retrieve_temporal_only(temporal, top_k) - - # 根据策略执行检索 - if strategy == RetrievalStrategy.PARA_ONLY: - results = await self._retrieve_paragraphs_only(query, top_k, temporal=temporal) - elif strategy == RetrievalStrategy.REL_ONLY: - results = await self._retrieve_relations_only(query, top_k, temporal=temporal) - else: # DUAL_PATH - results = await self._retrieve_dual_path( - query, - top_k, - temporal=temporal, - relation_intent=relation_intent_ctx, - ) - - logger.info(f"检索完成: 返回 {len(results)} 条结果") - - # 调试模式:打印结果原文 - if self.config.debug: - logger.info(f"[DEBUG] 检索结果内容原文:") - for i, res in enumerate(results): - logger.info(f" {i+1}. 
[{res.result_type}] (Score: {res.score:.4f}) {res.content}") - - return results - - def _is_relation_intent_query(self, query: str) -> bool: - q = str(query or "").strip() - if not q: - return False - if "|" in q or "->" in q: - return True - return self._relation_intent_pattern.search(q) is not None - - def _build_relation_intent_context(self, query: str, top_k: int) -> Dict[str, Any]: - cfg = self.config.relation_intent - enabled = bool(cfg.enabled) and self._is_relation_intent_query(query) - base_relation_k = max(1, int(self.config.top_k_relations)) - relation_top_k = max(base_relation_k, int(top_k)) - if enabled: - relation_top_k = max( - relation_top_k, - relation_top_k * int(cfg.relation_candidate_multiplier), - ) - return { - "enabled": enabled, - "alpha_override": float(cfg.alpha_override) if enabled else None, - "relation_top_k": int(relation_top_k), - "preserve_top_relations": int(cfg.preserve_top_relations) if enabled else 0, - "force_relation_sparse": bool(cfg.force_relation_sparse) if enabled else False, - "pair_predicate_rerank_enabled": bool(cfg.pair_predicate_rerank_enabled) if enabled else False, - "pair_predicate_limit": int(cfg.pair_predicate_limit) if enabled else 0, - } - - def _cap_temporal_scan_k( - self, - candidate_k: int, - temporal: Optional[TemporalQueryOptions], - ) -> int: - """对 temporal 模式候选召回数应用 max_scan 上限。""" - k = max(1, int(candidate_k)) - if temporal and temporal.max_scan and temporal.max_scan > 0: - k = min(k, int(temporal.max_scan)) - return max(1, k) - - def _is_valid_embedding(self, emb: Optional[np.ndarray]) -> bool: - if emb is None: - return False - arr = np.asarray(emb, dtype=np.float32) - if arr.ndim == 0 or arr.size == 0: - return False - return bool(np.all(np.isfinite(arr))) - - def _get_embedding_dim(self, emb: Optional[np.ndarray]) -> Optional[int]: - if emb is None: - return None - arr = np.asarray(emb) - if arr.ndim == 1: - return int(arr.shape[0]) if arr.size > 0 else None - if arr.ndim == 2: - if arr.shape[0] == 0: - return None - return int(arr.shape[1]) - return None - - def _is_embedding_dimension_compatible(self, emb: Optional[np.ndarray]) -> bool: - got_dim = self._get_embedding_dim(emb) - expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) - if got_dim is None or expected_dim <= 0: - return False - return got_dim == expected_dim - - def _is_embedding_ready_for_vector_search( - self, - emb: Optional[np.ndarray], - *, - stage: str, - ) -> bool: - if not self._is_valid_embedding(emb): - return False - if self._is_embedding_dimension_compatible(emb): - return True - - expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) - got_dim = self._get_embedding_dim(emb) - logger.warning( - "metric.embedding_dim_mismatch_fallback_count=1 " - f"stage={stage} expected_dim={expected_dim} got_dim={got_dim}" - ) - return False - - def _should_use_sparse( - self, - embedding_ok: bool, - vector_results: Optional[List[RetrievalResult]] = None, - ) -> bool: - if not self.config.sparse.enabled or self.sparse_index is None: - return False - - mode = self.config.sparse.mode - if mode == "hybrid": - return True - if mode == "fallback_only": - return not embedding_ok - # auto - if not embedding_ok: - return True - if not vector_results: - return True - best = max((float(r.score) for r in vector_results), default=0.0) - return best < 0.45 - - def _should_use_sparse_relations( - self, - embedding_ok: bool, - relation_results: Optional[List[RetrievalResult]] = None, - force_enable: bool = False, - ) -> bool: - if force_enable and 
self.config.sparse.enabled and self.sparse_index is not None: - return True - if not self.config.sparse.enable_relation_sparse_fallback: - return False - return self._should_use_sparse(embedding_ok, relation_results) - - def _normalize_scores_minmax(self, results: List[RetrievalResult]) -> None: - if not results: - return - vals = [float(r.score) for r in results] - lo = min(vals) - hi = max(vals) - if hi - lo < 1e-12: - for r in results: - r.score = 1.0 - return - for r in results: - r.score = (float(r.score) - lo) / (hi - lo) - - def _build_minmax_score_map(self, results: List[RetrievalResult]) -> Dict[str, float]: - if not results: - return {} - vals = [float(r.score) for r in results] - lo = min(vals) - hi = max(vals) - if hi - lo < 1e-12: - return {r.hash_value: 1.0 for r in results} - return { - r.hash_value: (float(r.score) - lo) / (hi - lo) - for r in results - } - - @staticmethod - def _clone_retrieval_result(item: RetrievalResult) -> RetrievalResult: - return RetrievalResult( - hash_value=item.hash_value, - content=item.content, - score=float(item.score), - result_type=item.result_type, - source=item.source, - metadata=dict(item.metadata or {}), - ) - - def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: - entities = self._extract_entities(query) - if not entities: - return [] - ranked = sorted( - entities.items(), - key=lambda x: (-float(x[1]), -len(str(x[0])), str(x[0]).lower()), - ) - return [str(name) for name, _ in ranked[: max(0, int(limit))]] - - def _search_relations_graph( - self, - query: str, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - service = getattr(self, "_graph_relation_recall", None) - if service is None or not bool(getattr(self.config.graph_recall, "enabled", True)): - return [] - - seed_entities = self._extract_graph_seed_entities(query, limit=2) - if not seed_entities: - return [] - - payloads = service.recall(seed_entities=seed_entities) - results: List[RetrievalResult] = [] - for payload in payloads: - meta = payload.to_payload() - results.append( - RetrievalResult( - hash_value=str(meta["hash"]), - content=str(meta["content"]), - score=0.0, - result_type="relation", - source="graph_relation_recall", - metadata={ - "subject": meta["subject"], - "predicate": meta["predicate"], - "object": meta["object"], - "confidence": float(meta["confidence"]), - "graph_seed_entities": list(meta["graph_seed_entities"]), - "graph_hops": int(meta["graph_hops"]), - "graph_candidate_type": str(meta["graph_candidate_type"]), - "supporting_paragraph_count": int(meta["supporting_paragraph_count"]), - }, - ) - ) - return self._apply_temporal_filter_to_relations(results, temporal) - - def _fuse_ranked_lists_weighted_rrf( - self, - vector_results: List[RetrievalResult], - sparse_results: List[RetrievalResult], - ) -> List[RetrievalResult]: - """按 weighted RRF 融合两路段落召回。""" - if not vector_results: - out = sparse_results[:] - if self.config.fusion.normalize_score: - self._normalize_scores_minmax(out) - return out - if not sparse_results: - out = vector_results[:] - if self.config.fusion.normalize_score: - self._normalize_scores_minmax(out) - return out - - k = self.config.fusion.rrf_k - w_vec = self.config.fusion.vector_weight - w_sparse = self.config.fusion.bm25_weight - merged: Dict[str, RetrievalResult] = {} - score_map: Dict[str, float] = {} - - for rank, item in enumerate(vector_results, start=1): - h = item.hash_value - if h not in merged: - merged[h] = item - merged[h].source = "fusion_rrf" - score_map[h] = 
score_map.get(h, 0.0) + w_vec * (1.0 / (k + rank)) - - for rank, item in enumerate(sparse_results, start=1): - h = item.hash_value - if h not in merged: - merged[h] = item - merged[h].source = "fusion_rrf" - score_map[h] = score_map.get(h, 0.0) + w_sparse * (1.0 / (k + rank)) - - out = list(merged.values()) - for item in out: - item.score = float(score_map.get(item.hash_value, 0.0)) - - out.sort(key=lambda x: x.score, reverse=True) - if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": - self._normalize_scores_minmax(out) - return out - - def _search_paragraphs_sparse( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """BM25 段落召回。""" - if not self.sparse_index or not self.config.sparse.enabled: - return [] - - candidate_k = max(top_k, self.config.sparse.candidate_k) - candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) - sparse_rows = self.sparse_index.search(query=query, k=candidate_k) - results: List[RetrievalResult] = [] - for row in sparse_rows: - hash_value = row["hash"] - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - continue - time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) - results.append( - RetrievalResult( - hash_value=hash_value, - content=paragraph["content"], - score=float(row.get("score", 0.0)), - result_type="paragraph", - source="sparse_bm25", - metadata={ - "word_count": paragraph.get("word_count", 0), - "time_meta": time_meta, - "bm25_score": float(row.get("bm25_score", 0.0)), - }, - ) - ) - results = self._apply_temporal_filter_to_paragraphs(results, temporal) - if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": - self._normalize_scores_minmax(results) - return results - - def _search_relations_sparse( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """关系 BM25 召回。""" - if not self.sparse_index or not self.config.sparse.enabled: - return [] - if not self.config.sparse.enable_relation_sparse_fallback: - return [] - - candidate_k = max(top_k, self.config.sparse.relation_candidate_k) - candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) - rows = self.sparse_index.search_relations(query=query, k=candidate_k) - results: List[RetrievalResult] = [] - for row in rows: - hash_value = row["hash"] - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - continue - - relation_time_meta = None - if temporal: - relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) - if relation_time_meta is None: - continue - - content = f"{relation['subject']} {relation['predicate']} {relation['object']}" - results.append( - RetrievalResult( - hash_value=hash_value, - content=content, - score=float(row.get("score", 0.0)), - result_type="relation", - source="sparse_relation_bm25", - metadata={ - "subject": relation["subject"], - "predicate": relation["predicate"], - "object": relation["object"], - "confidence": relation.get("confidence", 1.0), - "time_meta": relation_time_meta, - "bm25_score": float(row.get("bm25_score", 0.0)), - }, - ) - ) - - if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": - self._normalize_scores_minmax(results) - return self._apply_temporal_filter_to_relations(results, temporal) - - def _merge_relation_results( - self, - vector_results: List[RetrievalResult], - sparse_results: 
List[RetrievalResult], - ) -> List[RetrievalResult]: - """合并关系候选,按 hash 去重并保留更高分。""" - merged: Dict[str, RetrievalResult] = {} - for item in vector_results: - merged[item.hash_value] = item - for item in sparse_results: - old = merged.get(item.hash_value) - if old is None or float(item.score) > float(old.score): - merged[item.hash_value] = item - elif old is not None and old.source != item.source: - old.source = "relation_fusion" - out = list(merged.values()) - out.sort(key=lambda x: x.score, reverse=True) - return out - - def _merge_relation_results_graph_enhanced( - self, - vector_results: List[RetrievalResult], - sparse_results: List[RetrievalResult], - graph_results: List[RetrievalResult], - ) -> List[RetrievalResult]: - """Graph-aware relation fusion with semantic + graph + evidence scoring.""" - vector_norm = self._build_minmax_score_map(vector_results) - sparse_norm = self._build_minmax_score_map(sparse_results) - graph_score_map = { - "direct_pair": 1.0, - "one_hop_seed": 0.75, - "two_hop_pair": 0.55, - } - - merged: Dict[str, RetrievalResult] = {} - source_sets: Dict[str, set[str]] = {} - support_cache: Dict[str, int] = {} - - for group in (vector_results, sparse_results, graph_results): - for item in group: - existing = merged.get(item.hash_value) - if existing is None: - existing = self._clone_retrieval_result(item) - merged[item.hash_value] = existing - else: - for key, value in dict(item.metadata or {}).items(): - if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): - existing.metadata[key] = value - source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") - - out = list(merged.values()) - for item in out: - meta = item.metadata if isinstance(item.metadata, dict) else {} - semantic_norm = max( - float(vector_norm.get(item.hash_value, 0.0)), - float(sparse_norm.get(item.hash_value, 0.0)), - ) - graph_candidate_type = str(meta.get("graph_candidate_type", "") or "") - graph_score = float(graph_score_map.get(graph_candidate_type, 0.0)) - - if item.hash_value not in support_cache: - cached = meta.get("supporting_paragraph_count") - if cached is None: - support_cache[item.hash_value] = len( - self.metadata_store.get_paragraphs_by_relation(item.hash_value) - ) - else: - support_cache[item.hash_value] = max(0, int(cached)) - supporting_paragraph_count = support_cache[item.hash_value] - evidence_score = min(1.0, supporting_paragraph_count / 3.0) - - meta["supporting_paragraph_count"] = supporting_paragraph_count - meta["graph_seed_entities"] = list(meta.get("graph_seed_entities") or []) - if "graph_hops" in meta: - meta["graph_hops"] = int(meta.get("graph_hops") or 0) - item.score = 0.60 * semantic_norm + 0.30 * graph_score + 0.10 * evidence_score - - sources = source_sets.get(item.hash_value, set()) - if len(sources) > 1: - item.source = "relation_fusion" - elif sources: - item.source = next(iter(sources)) - - out.sort(key=lambda x: x.score, reverse=True) - return out - - async def _retrieve_paragraphs_only( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 仅检索段落(异步方法) - - Args: - query: 查询文本 - top_k: 返回数量 - - Returns: - 检索结果列表 - """ - query_emb = None - embedding_ok = False - vector_results: List[RetrievalResult] = [] - - try: - query_emb = await self.embedding_manager.encode(query) - embedding_ok = self._is_embedding_ready_for_vector_search( - query_emb, - stage="paragraph_only", - ) - except Exception as e: - logger.warning(f"段落检索 
embedding 生成失败,将尝试 sparse 回退: {e}") - - if embedding_ok: - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * 2 * multiplier, temporal) - para_ids, para_scores = self.vector_store.search( - query_emb, # type: ignore[arg-type] - k=candidate_k, - ) - - for hash_value, score in zip(para_ids, para_scores): - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - continue - time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) - vector_results.append( - RetrievalResult( - hash_value=hash_value, - content=paragraph["content"], - score=float(score), - result_type="paragraph", - source="paragraph_search", - metadata={ - "word_count": paragraph.get("word_count", 0), - "time_meta": time_meta, - }, - ) - ) - vector_results = self._apply_temporal_filter_to_paragraphs(vector_results, temporal) - - sparse_results: List[RetrievalResult] = [] - if self._should_use_sparse(embedding_ok, vector_results): - sparse_results = self._search_paragraphs_sparse(query, top_k, temporal=temporal) - - if self.config.fusion.method == "weighted_rrf" and (vector_results and sparse_results): - results = self._fuse_ranked_lists_weighted_rrf(vector_results, sparse_results) - elif vector_results and sparse_results: - results = vector_results + sparse_results - results.sort(key=lambda x: x.score, reverse=True) - else: - results = vector_results if vector_results else sparse_results - - return results[:top_k] - - async def _retrieve_relations_only( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 仅检索关系 (通过实体枢纽 Entity-Pivot) - - 策略: - 1. 检索向量库中的 Top-K 实体 (Entity) - 2. 通过图结构/元数据扩展出与实体关联的关系 (Relation) - 3. 以实体相似度作为基础分返回关系 - - Args: - query: 查询文本 - top_k: 返回数量 - - Returns: - 检索结果列表 - """ - query_emb = None - embedding_ok = False - vector_results: List[RetrievalResult] = [] - try: - query_emb = await self.embedding_manager.encode(query) - embedding_ok = self._is_embedding_ready_for_vector_search( - query_emb, - stage="relation_only", - ) - except Exception as e: - logger.warning(f"关系检索 embedding 生成失败,将尝试 sparse 回退: {e}") - - if embedding_ok: - # 1. 
检索向量 (混合了段落和实体,所以扩大检索范围以召回足够多实体) - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * 3 * multiplier, temporal) - ids, scores = self.vector_store.search( - query_emb, # type: ignore[arg-type] - k=candidate_k, - ) - - seen_relations = set() - for hash_value, score in zip(ids, scores): - entity = self.metadata_store.get_entity(hash_value) - if not entity: - continue - entity_name = entity["name"] - - related_rels = [] - related_rels.extend(self.metadata_store.get_relations(subject=entity_name)) - related_rels.extend(self.metadata_store.get_relations(object=entity_name)) - - for rel in related_rels: - if rel["hash"] in seen_relations: - continue - seen_relations.add(rel["hash"]) - - relation_time_meta = None - if temporal: - relation_time_meta = self._best_supporting_time_meta(rel["hash"], temporal) - if relation_time_meta is None: - continue - - content = f"{rel['subject']} {rel['predicate']} {rel['object']}" - vector_results.append( - RetrievalResult( - hash_value=rel["hash"], - content=content, - score=float(score), - result_type="relation", - source="relation_search (via entity)", - metadata={ - "subject": rel["subject"], - "predicate": rel["predicate"], - "object": rel["object"], - "confidence": rel.get("confidence", 1.0), - "pivot_entity": entity_name, - "time_meta": relation_time_meta, - }, - ) - ) - - vector_results = self._apply_temporal_filter_to_relations(vector_results, temporal) - - sparse_results: List[RetrievalResult] = [] - if self._should_use_sparse_relations(embedding_ok, vector_results): - sparse_results = self._search_relations_sparse(query=query, top_k=top_k, temporal=temporal) - - graph_results = self._search_relations_graph(query=query, temporal=temporal) - if graph_results: - results = self._merge_relation_results_graph_enhanced( - vector_results, - sparse_results, - graph_results, - ) - elif vector_results and sparse_results: - results = self._merge_relation_results(vector_results, sparse_results) - else: - results = vector_results if vector_results else sparse_results - - return results[:top_k] - - async def _retrieve_dual_path( - self, - query: str, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - relation_intent: Optional[Dict[str, Any]] = None, - ) -> List[RetrievalResult]: - """ - 双路检索(段落+关系)(异步方法) - - Args: - query: 查询文本 - top_k: 返回数量 - - Returns: - 融合后的检索结果列表 - """ - query_emb = None - embedding_ok = False - relation_intent = relation_intent or {} - relation_top_k = max( - 1, - int(relation_intent.get("relation_top_k", self.config.top_k_relations)), - ) - force_relation_sparse = bool(relation_intent.get("force_relation_sparse", False)) - preserve_top_relations = max( - 0, - int(relation_intent.get("preserve_top_relations", 0)), - ) - pair_predicate_rerank_enabled = bool( - relation_intent.get("pair_predicate_rerank_enabled", False) - ) - pair_predicate_limit = max( - 1, - int( - relation_intent.get( - "pair_predicate_limit", - self.config.relation_intent.pair_predicate_limit, - ) - ), - ) - alpha_override = relation_intent.get("alpha_override") - try: - query_emb = await self.embedding_manager.encode(query) - embedding_ok = self._is_embedding_ready_for_vector_search( - query_emb, - stage="dual_path", - ) - except Exception as e: - logger.warning(f"双路检索 embedding 生成失败,将尝试 sparse 回退: {e}") - - para_results: List[RetrievalResult] = [] - rel_results: List[RetrievalResult] = [] - if embedding_ok: - # 并行检索(使用 asyncio) - if self.config.enable_parallel: - para_results, rel_results = 
await self._parallel_retrieve( - query_emb, - temporal=temporal, - relation_top_k=relation_top_k, - ) # type: ignore[arg-type] - else: - para_results, rel_results = self._sequential_retrieve( - query_emb, - temporal=temporal, - relation_top_k=relation_top_k, - ) # type: ignore[arg-type] - else: - logger.warning("embedding 不可用,跳过向量段落/关系召回") - - sparse_para_results: List[RetrievalResult] = [] - if self._should_use_sparse(embedding_ok, para_results): - sparse_para_results = self._search_paragraphs_sparse( - query=query, - top_k=max(top_k * 2, self.config.sparse.candidate_k), - temporal=temporal, - ) - sparse_rel_results: List[RetrievalResult] = [] - if self._should_use_sparse_relations( - embedding_ok, - rel_results, - force_enable=force_relation_sparse, - ): - sparse_rel_results = self._search_relations_sparse( - query=query, - top_k=max( - top_k, - self.config.sparse.relation_candidate_k, - relation_top_k, - ), - temporal=temporal, - ) - - graph_rel_results: List[RetrievalResult] = [] - if bool(relation_intent.get("enabled", False)): - graph_rel_results = self._search_relations_graph(query=query, temporal=temporal) - - if self.config.fusion.method == "weighted_rrf" and para_results and sparse_para_results: - para_results = self._fuse_ranked_lists_weighted_rrf(para_results, sparse_para_results) - elif para_results and sparse_para_results: - para_results = para_results + sparse_para_results - para_results.sort(key=lambda x: x.score, reverse=True) - elif sparse_para_results and (not para_results or not embedding_ok): - para_results = sparse_para_results - - if graph_rel_results: - rel_results = self._merge_relation_results_graph_enhanced( - rel_results, - sparse_rel_results, - graph_rel_results, - ) - elif rel_results and sparse_rel_results: - rel_results = self._merge_relation_results(rel_results, sparse_rel_results) - elif sparse_rel_results and (not rel_results or not embedding_ok): - rel_results = sparse_rel_results - - # 融合结果 - fused_results = self._fuse_results( - para_results, - rel_results, - query_emb, - alpha_override=alpha_override, - preserve_top_relations=preserve_top_relations, - ) - - # PageRank重排序 - if self.config.enable_ppr: - fused_results = await self._rerank_with_ppr( - fused_results, - query, - ) - - if temporal: - fused_results = self._sort_results_with_temporal(fused_results, temporal) - - fused_results = self._apply_relation_intent_pair_rerank( - fused_results, - enabled=bool(relation_intent.get("enabled", False)), - pair_rerank_enabled=pair_predicate_rerank_enabled, - pair_limit=pair_predicate_limit, - ) - - return fused_results[:top_k] - - async def _parallel_retrieve( - self, - query_emb: np.ndarray, - temporal: Optional[TemporalQueryOptions] = None, - relation_top_k: Optional[int] = None, - ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: - """ - 并行检索段落和关系(异步方法) - - Args: - query_emb: 查询嵌入 - - Returns: - (段落结果, 关系结果) - """ - # 使用 asyncio.gather 并发执行两个搜索任务 - # 由于 _search_paragraphs 和 _search_relations 是 CPU 密集型同步函数, - # 使用 asyncio.to_thread 在线程池中执行 - try: - para_task = asyncio.to_thread( - self._search_paragraphs, - query_emb, - self.config.top_k_paragraphs, - temporal, - ) - rel_task = asyncio.to_thread( - self._search_relations, - query_emb, - relation_top_k if relation_top_k is not None else self.config.top_k_relations, - temporal, - ) - - para_results, rel_results = await asyncio.gather( - para_task, rel_task, return_exceptions=True - ) - - # 处理异常 - if isinstance(para_results, Exception): - logger.error(f"段落检索失败: {para_results}") - para_results = [] - if 
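# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# _parallel_retrieve offloads the two synchronous, CPU-bound searches to the
# default thread pool via asyncio.to_thread and gathers them with
# return_exceptions=True, so a failure on one path degrades to an empty list
# instead of cancelling the other. Standalone illustration (the two search
# functions here are stand-ins):
import asyncio

def search_paragraphs(query):
    return [("p1", 0.8)]

def search_relations(query):
    raise RuntimeError("index unavailable")

async def parallel_retrieve(query):
    para, rel = await asyncio.gather(
        asyncio.to_thread(search_paragraphs, query),
        asyncio.to_thread(search_relations, query),
        return_exceptions=True,
    )
    if isinstance(para, Exception):
        para = []
    if isinstance(rel, Exception):
        rel = []
    return para, rel

print(asyncio.run(parallel_retrieve("demo")))  # ([('p1', 0.8)], [])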
isinstance(rel_results, Exception): - logger.error(f"关系检索失败: {rel_results}") - rel_results = [] - - return para_results, rel_results - - except Exception as e: - logger.error(f"并行检索失败: {e}") - return [], [] - - def _sequential_retrieve( - self, - query_emb: np.ndarray, - temporal: Optional[TemporalQueryOptions] = None, - relation_top_k: Optional[int] = None, - ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: - """ - 顺序检索段落和关系 - - Args: - query_emb: 查询嵌入 - - Returns: - (段落结果, 关系结果) - """ - para_results = self._search_paragraphs( - query_emb, - self.config.top_k_paragraphs, - temporal, - ) - - rel_results = self._search_relations( - query_emb, - relation_top_k if relation_top_k is not None else self.config.top_k_relations, - temporal, - ) - - return para_results, rel_results - - def _search_paragraphs( - self, - query_emb: np.ndarray, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 搜索段落 - - Args: - query_emb: 查询嵌入 - top_k: 返回数量 - - Returns: - 段落结果列表 - """ - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) - para_ids, para_scores = self.vector_store.search(query_emb, k=candidate_k) - - results = [] - for hash_value, score in zip(para_ids, para_scores): - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - continue - - time_meta = self._build_time_meta_from_paragraph( - paragraph, - temporal=temporal, - ) - results.append(RetrievalResult( - hash_value=hash_value, - content=paragraph["content"], - score=float(score), - result_type="paragraph", - source="paragraph_search", - metadata={ - "word_count": paragraph.get("word_count", 0), - "time_meta": time_meta, - }, - )) - - return self._apply_temporal_filter_to_paragraphs(results, temporal) - - def _search_relations( - self, - query_emb: np.ndarray, - top_k: int, - temporal: Optional[TemporalQueryOptions] = None, - ) -> List[RetrievalResult]: - """ - 搜索关系 - - Args: - query_emb: 查询嵌入 - top_k: 返回数量 - - Returns: - 关系结果列表 - """ - multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 - candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) - rel_ids, rel_scores = self.vector_store.search(query_emb, k=candidate_k) - - results = [] - for hash_value, score in zip(rel_ids, rel_scores): - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - continue - - relation_time_meta = None - if temporal: - relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) - if relation_time_meta is None: - continue - - content = f"{relation['subject']} {relation['predicate']} {relation['object']}" - - results.append(RetrievalResult( - hash_value=hash_value, - content=content, - score=float(score), - result_type="relation", - source="relation_search", - metadata={ - "subject": relation["subject"], - "predicate": relation["predicate"], - "object": relation["object"], - "confidence": relation.get("confidence", 1.0), - "time_meta": relation_time_meta, - }, - )) - - return self._apply_temporal_filter_to_relations(results, temporal) - - def _fuse_results( - self, - para_results: List[RetrievalResult], - rel_results: List[RetrievalResult], - query_emb: Optional[np.ndarray] = None, - alpha_override: Optional[float] = None, - preserve_top_relations: int = 0, - ) -> List[RetrievalResult]: - """ - 融合段落和关系结果 - - 融合策略: - 1. 计算加权分数 - 2. 去重(基于段落和关系的关联) - 3. 
排序 - - Args: - para_results: 段落结果 - rel_results: 关系结果 - query_emb: 查询嵌入(兼容参数,当前未使用) - - Returns: - 融合后的结果列表 - """ - del query_emb # 参数保留用于兼容 - alpha = float(alpha_override) if alpha_override is not None else self.config.alpha - - # 为段落结果计算加权分数 - for result in para_results: - result.score = result.score * alpha - result.source = "fusion" - - # 为关系结果计算加权分数 - for result in rel_results: - result.score = result.score * (1 - alpha) - result.source = "fusion" - - preserve_top_relations = max(0, int(preserve_top_relations)) - preserved_relation_hashes = set() - if preserve_top_relations > 0 and rel_results: - rel_ranked = sorted(rel_results, key=lambda x: x.score, reverse=True) - preserved_relation_hashes = { - item.hash_value for item in rel_ranked[:preserve_top_relations] - } - - # 合并结果 - all_results = para_results + rel_results - all_results.sort(key=lambda x: x.score, reverse=True) - - # 去重:如果段落有关联的关系,只保留分数更高的 - seen_paragraphs = set() - seen_items = set() - deduplicated_results = [] - - for result in all_results: - if result.hash_value in seen_items: - continue - if result.result_type == "paragraph": - hash_val = result.hash_value - if hash_val not in seen_paragraphs: - seen_paragraphs.add(hash_val) - seen_items.add(hash_val) - deduplicated_results.append(result) - else: # relation - if result.hash_value in preserved_relation_hashes: - seen_items.add(result.hash_value) - deduplicated_results.append(result) - continue - # 检查关系关联的段落是否已存在 - relation = self.metadata_store.get_relation(result.hash_value) - if relation: - # 获取关联的段落 - para_rels = self.metadata_store.query(""" - SELECT paragraph_hash FROM paragraph_relations - WHERE relation_hash = ? - """, (result.hash_value,)) - - if para_rels: - # 检查段落是否已在结果中 - for para_rel in para_rels: - if para_rel["paragraph_hash"] in seen_paragraphs: - # 段落已存在,跳过此关系 - break - else: - # 所有段落都不存在,添加关系 - seen_items.add(result.hash_value) - deduplicated_results.append(result) - else: - # 没有关联段落,直接添加 - seen_items.add(result.hash_value) - deduplicated_results.append(result) - else: - seen_items.add(result.hash_value) - deduplicated_results.append(result) - - # 按分数排序 - deduplicated_results.sort(key=lambda x: x.score, reverse=True) - - return deduplicated_results - - def _apply_relation_intent_pair_rerank( - self, - results: List[RetrievalResult], - *, - enabled: bool, - pair_rerank_enabled: bool, - pair_limit: int, - ) -> List[RetrievalResult]: - """仅在 relation-intent 下对关系项执行同主客体多谓词重排。""" - if not enabled or not pair_rerank_enabled: - return results - return self._rerank_relation_items_by_pair(results, pair_limit=pair_limit) - - def _rerank_relation_items_by_pair( - self, - results: List[RetrievalResult], - pair_limit: int, - ) -> List[RetrievalResult]: - """ - 同主客体多谓词重排: - 1. 关系项按 (subject, object) 分组 - 2. 组内按分数降序 + 原始位置升序 - 3. 组间按组最高分降序 + 组最早位置升序 - 4. 先拼接每组前 N 条,再拼接每组 overflow 条目 - 5. 
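# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# _fuse_results weights paragraph scores by alpha and relation scores by
# (1 - alpha), then deduplicates: a relation is dropped when one of its
# supporting paragraphs already made the merged list, unless it was explicitly
# preserved as a top relation. A compact standalone sketch of that policy
# (data shapes here are hypothetical):
def fuse(paragraphs, relations, supports, alpha=0.6, preserve_top=0):
    # paragraphs / relations: [(hash, score)]; supports: relation hash -> set of paragraph hashes
    scored = [("paragraph", h, s * alpha) for h, s in paragraphs]
    scored += [("relation", h, s * (1 - alpha)) for h, s in relations]
    scored.sort(key=lambda x: x[2], reverse=True)
    preserved = {h for _, h, _ in sorted(
        [i for i in scored if i[0] == "relation"], key=lambda x: x[2], reverse=True
    )[:preserve_top]}
    kept_paragraphs, out = set(), []
    for kind, h, s in scored:
        if kind == "paragraph":
            kept_paragraphs.add(h)
            out.append((kind, h, s))
        elif h in preserved or not (supports.get(h, set()) & kept_paragraphs):
            out.append((kind, h, s))
    return out

# r1 is dropped because its supporting paragraph p1 already made the list
print(fuse([("p1", 0.9)], [("r1", 0.8), ("r2", 0.7)], {"r1": {"p1"}}, alpha=0.6))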
回填到原关系槽位,段落槽位不变 - """ - if len(results) <= 1: - return results - - relation_positions: List[int] = [] - relation_items: List[Tuple[int, RetrievalResult]] = [] - for idx, item in enumerate(results): - if item.result_type == "relation": - relation_positions.append(idx) - relation_items.append((idx, item)) - - if len(relation_items) <= 1: - return results - - pair_limit = max(1, int(pair_limit)) - - grouped: Dict[Tuple[str, str], List[Tuple[int, RetrievalResult]]] = {} - for original_idx, item in relation_items: - metadata = item.metadata if isinstance(item.metadata, dict) else {} - subject = str(metadata.get("subject", "")).strip().lower() - obj = str(metadata.get("object", "")).strip().lower() - if subject and obj: - key = (subject, obj) - else: - key = ("__missing__", item.hash_value) - grouped.setdefault(key, []).append((original_idx, item)) - - for grouped_items in grouped.values(): - grouped_items.sort(key=lambda x: (-float(x[1].score), x[0])) - - ordered_groups = sorted( - grouped.values(), - key=lambda grouped_items: ( - -float(grouped_items[0][1].score), - grouped_items[0][0], - ), - ) - - prioritized: List[RetrievalResult] = [] - overflow: List[RetrievalResult] = [] - for grouped_items in ordered_groups: - prioritized.extend([item for _, item in grouped_items[:pair_limit]]) - overflow.extend([item for _, item in grouped_items[pair_limit:]]) - - reordered_relations = prioritized + overflow - if len(reordered_relations) != len(relation_items): - return results - - logger.debug( - "relation_rerank_applied=1 " - f"relation_pair_groups={len(ordered_groups)} " - f"relation_pair_overflow_count={len(overflow)} " - f"relation_pair_limit={pair_limit}" - ) - - rebuilt = list(results) - for slot_idx, relation_item in zip(relation_positions, reordered_relations): - rebuilt[slot_idx] = relation_item - return rebuilt - - async def _rerank_with_ppr( - self, - results: List[RetrievalResult], - query: str, - ) -> List[RetrievalResult]: - """ - 使用PageRank重排序结果 (异步 + 线程池) - - Args: - results: 检索结果 - query: 查询文本 - - Returns: - 重排序后的结果 - """ - # 从查询中提取实体 - entities = self._extract_entities(query) - - if not entities: - logger.debug("未识别到实体,跳过PPR重排序") - return results - - # 计算PPR分数 (放入线程池运行,避免阻塞主循环) - ppr_timeout_s = max(0.1, float(getattr(self.config, "ppr_timeout_seconds", 1.5) or 1.5)) - try: - async with self._ppr_semaphore: - ppr_scores = await asyncio.wait_for( - asyncio.to_thread( - self._ppr.compute, - personalization=entities, - normalize=True, - ), - timeout=ppr_timeout_s, - ) - except asyncio.TimeoutError: - logger.warning( - "metric.ppr_timeout_skip_count=1 " - f"timeout_s={ppr_timeout_s} " - f"entities={len(entities)}" - ) - return results - except Exception as e: - logger.warning(f"PPR 重排序失败,回退原排序: {e}") - return results - - # 调整结果分数 - ppr_scores_by_name = { - str(name).strip().lower(): float(score) - for name, score in ppr_scores.items() - } - for result in results: - if result.result_type == "paragraph": - # 获取段落的实体 - para_entities = self.metadata_store.get_paragraph_entities( - result.hash_value - ) - - # 计算实体的平均PPR分数 - if para_entities: - entity_scores = [] - for ent in para_entities: - ent_name = str(ent.get("name", "")).strip().lower() - if ent_name in ppr_scores_by_name: - entity_scores.append(ppr_scores_by_name[ent_name]) - - if entity_scores: - avg_ppr = np.mean(entity_scores) - # 融合原始分数和PPR分数 - result.score = result.score * 0.7 + avg_ppr * 0.3 - - # 重新排序 - results.sort(key=lambda x: x.score, reverse=True) - - return results - - def _retrieve_temporal_only( - self, - temporal: 
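# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# The (subject, object) pair rerank above groups relation items that share the
# same endpoints, keeps at most pair_limit of each group up front, and appends
# the overflow afterwards, so one noisy entity pair cannot crowd out every other
# relation. Standalone sketch (plain tuples stand in for RetrievalResult):
def pair_rerank(relations, pair_limit=1):
    # relations: [(subject, object, score)] in current ranked order
    groups = {}
    for idx, item in enumerate(relations):
        groups.setdefault((item[0].lower(), item[1].lower()), []).append((idx, item))
    for members in groups.values():
        members.sort(key=lambda m: (-m[1][2], m[0]))       # score desc, original position asc
    ordered = sorted(groups.values(), key=lambda m: (-m[0][1][2], m[0][0]))
    prioritized = [it for members in ordered for _, it in members[:pair_limit]]
    overflow = [it for members in ordered for _, it in members[pair_limit:]]
    return prioritized + overflow

ranked = [("Alice", "Bob", 0.9), ("Alice", "Bob", 0.85), ("Bob", "Tea", 0.8)]
print(pair_rerank(ranked, pair_limit=1))
# -> the second Alice/Bob item moves behind Bob/Tea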
TemporalQueryOptions, - top_k: int, - ) -> List[RetrievalResult]: - """无语义 query 时,直接走时序索引查询。""" - limit = self._cap_temporal_scan_k( - top_k * max(1, temporal.candidate_multiplier), - temporal, - ) - paragraphs = self.metadata_store.query_paragraphs_temporal( - start_ts=temporal.time_from, - end_ts=temporal.time_to, - person=temporal.person, - source=temporal.source, - limit=limit, - allow_created_fallback=temporal.allow_created_fallback, - ) - results: List[RetrievalResult] = [] - for para in paragraphs: - time_meta = self._build_time_meta_from_paragraph(para, temporal=temporal) - results.append( - RetrievalResult( - hash_value=para["hash"], - content=para["content"], - score=1.0, - result_type="paragraph", - source="temporal_scan", - metadata={ - "word_count": para.get("word_count", 0), - "time_meta": time_meta, - }, - ) - ) - - results = self._sort_results_with_temporal(results, temporal) - return results[:top_k] - - def _extract_effective_time( - self, - paragraph: Dict[str, Any], - temporal: Optional[TemporalQueryOptions] = None, - ) -> Tuple[Optional[float], Optional[float], Optional[str]]: - """提取段落有效时间区间与命中依据。""" - event_time = paragraph.get("event_time") - event_start = paragraph.get("event_time_start") - event_end = paragraph.get("event_time_end") - - if event_start is not None or event_end is not None: - effective_start = event_start if event_start is not None else ( - event_time if event_time is not None else event_end - ) - effective_end = event_end if event_end is not None else ( - event_time if event_time is not None else event_start - ) - return effective_start, effective_end, "event_time_range" - - if event_time is not None: - return event_time, event_time, "event_time" - - allow_fallback = True - if temporal is not None: - allow_fallback = temporal.allow_created_fallback - - created_at = paragraph.get("created_at") - if allow_fallback and created_at is not None: - return created_at, created_at, "created_at_fallback" - - return None, None, None - - def _build_time_meta_from_paragraph( - self, - paragraph: Dict[str, Any], - temporal: Optional[TemporalQueryOptions] = None, - ) -> Dict[str, Any]: - """构建统一 time_meta 结构。""" - effective_start, effective_end, match_basis = self._extract_effective_time( - paragraph, - temporal=temporal, - ) - return { - "event_time": paragraph.get("event_time"), - "event_time_start": paragraph.get("event_time_start"), - "event_time_end": paragraph.get("event_time_end"), - "ingest_time": paragraph.get("created_at"), - "time_granularity": paragraph.get("time_granularity"), - "time_confidence": paragraph.get("time_confidence", 1.0), - "effective_start": effective_start, - "effective_end": effective_end, - "effective_start_text": format_timestamp(effective_start), - "effective_end_text": format_timestamp(effective_end), - "match_basis": match_basis or "none", - } - - def _matches_person_filter(self, paragraph_hash: str, person: Optional[str]) -> bool: - if not person: - return True - target = person.strip().lower() - if not target: - return True - para_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) - for ent in para_entities: - name = str(ent.get("name", "")).strip().lower() - if target in name: - return True - return False - - def _is_temporal_match( - self, - paragraph: Dict[str, Any], - temporal: TemporalQueryOptions, - ) -> bool: - """判断段落是否命中时序筛选。""" - if temporal.source and paragraph.get("source") != temporal.source: - return False - - if not self._matches_person_filter(paragraph.get("hash", ""), temporal.person): - 
return False - - effective_start, effective_end, _ = self._extract_effective_time(paragraph, temporal=temporal) - if effective_start is None or effective_end is None: - return False - - if temporal.time_from is not None and temporal.time_to is not None: - return effective_end >= temporal.time_from and effective_start <= temporal.time_to - if temporal.time_from is not None: - return effective_end >= temporal.time_from - if temporal.time_to is not None: - return effective_start <= temporal.time_to - return True - - def _apply_temporal_filter_to_paragraphs( - self, - results: List[RetrievalResult], - temporal: Optional[TemporalQueryOptions], - ) -> List[RetrievalResult]: - if not temporal: - return results - - filtered: List[RetrievalResult] = [] - for result in results: - paragraph = self.metadata_store.get_paragraph(result.hash_value) - if not paragraph: - continue - if not self._is_temporal_match(paragraph, temporal): - continue - result.metadata["time_meta"] = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) - filtered.append(result) - - return self._sort_results_with_temporal(filtered, temporal) - - def _best_supporting_time_meta( - self, - relation_hash: str, - temporal: TemporalQueryOptions, - ) -> Optional[Dict[str, Any]]: - """获取关系在时序窗口内最优支撑段落的 time_meta。""" - supports = self.metadata_store.get_paragraphs_by_relation(relation_hash) - if not supports: - return None - - best_meta: Optional[Dict[str, Any]] = None - best_time = float("-inf") - for para in supports: - if not self._is_temporal_match(para, temporal): - continue - meta = self._build_time_meta_from_paragraph(para, temporal=temporal) - eff = meta.get("effective_end") - score = float(eff) if eff is not None else float("-inf") - if score >= best_time: - best_time = score - best_meta = meta - - return best_meta - - def _apply_temporal_filter_to_relations( - self, - results: List[RetrievalResult], - temporal: Optional[TemporalQueryOptions], - ) -> List[RetrievalResult]: - if not temporal: - return results - - filtered: List[RetrievalResult] = [] - for result in results: - meta = result.metadata.get("time_meta") - if meta is None: - meta = self._best_supporting_time_meta(result.hash_value, temporal) - if meta is None: - continue - result.metadata["time_meta"] = meta - filtered.append(result) - - return self._sort_results_with_temporal(filtered, temporal) - - def _sort_results_with_temporal( - self, - results: List[RetrievalResult], - temporal: TemporalQueryOptions, - ) -> List[RetrievalResult]: - """语义优先,时间次排序(新到旧)。""" - del temporal # temporal 保留给未来扩展,目前只使用结果内 time_meta - - def _temporal_key(item: RetrievalResult) -> float: - time_meta = item.metadata.get("time_meta", {}) - effective = time_meta.get("effective_end") - if effective is None: - effective = time_meta.get("effective_start") - if effective is None: - return float("-inf") - return float(effective) - - results.sort(key=lambda x: (x.score, _temporal_key(x)), reverse=True) - return results - - def _extract_entities(self, text: str) -> Dict[str, float]: - """ - 从文本中提取实体(简化版本) - - Args: - text: 输入文本 - - Returns: - 实体字典 {实体名: 权重} - """ - # 获取所有实体 - all_entities = self.graph_store.get_nodes() - if not all_entities: - return {} - - # 检查是否需要更新 Aho-Corasick 匹配器 - if self._ac_matcher is None or self._ac_nodes_count != len(all_entities): - self._ac_matcher = AhoCorasick() - for entity in all_entities: - self._ac_matcher.add_pattern(entity.lower()) - self._ac_matcher.build() - self._ac_nodes_count = len(all_entities) - - # 执行匹配 - text_lower = text.lower() - stats = 
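# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# Temporal filtering above resolves an "effective" interval per paragraph with a
# fixed precedence (explicit event range, then single event_time, then
# created_at as an opt-in fallback), keeps results whose interval overlaps the
# query window, and finally sorts semantically first with recency as the
# tiebreaker. Sketch under those assumptions:
def effective_interval(p, allow_created_fallback=True):
    start, end, t = p.get("event_time_start"), p.get("event_time_end"), p.get("event_time")
    if start is not None or end is not None:
        return (start if start is not None else (t if t is not None else end),
                end if end is not None else (t if t is not None else start))
    if t is not None:
        return (t, t)
    if allow_created_fallback and p.get("created_at") is not None:
        return (p["created_at"], p["created_at"])
    return (None, None)

def overlaps(interval, time_from=None, time_to=None):
    start, end = interval
    if start is None or end is None:
        return False
    return (time_from is None or end >= time_from) and (time_to is None or start <= time_to)

p = {"event_time": 1_700_000_000, "created_at": 1_700_100_000}
print(effective_interval(p), overlaps(effective_interval(p), time_from=1_699_999_000))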
self._ac_matcher.find_all(text_lower) - - # 映射回原始名称并使用出现次数作为权重 - node_map = {node.lower(): node for node in all_entities} - entities = {node_map[low_name]: float(count) for low_name, count in stats.items()} - - return entities - - def get_statistics(self) -> Dict[str, Any]: - """ - 获取检索统计信息 - - Returns: - 统计信息字典 - """ - vector_size = getattr(self.vector_store, "size", None) - if vector_size is None: - vector_size = getattr(self.vector_store, "num_vectors", 0) - - return { - "config": { - "top_k_paragraphs": self.config.top_k_paragraphs, - "top_k_relations": self.config.top_k_relations, - "top_k_final": self.config.top_k_final, - "alpha": self.config.alpha, - "enable_ppr": self.config.enable_ppr, - "enable_parallel": self.config.enable_parallel, - "strategy": self.config.retrieval_strategy.value, - "sparse_mode": self.config.sparse.mode, - "fusion_method": self.config.fusion.method, - "relation_intent_enabled": self.config.relation_intent.enabled, - "relation_intent_alpha_override": self.config.relation_intent.alpha_override, - "relation_intent_candidate_multiplier": self.config.relation_intent.relation_candidate_multiplier, - "relation_intent_preserve_top_relations": self.config.relation_intent.preserve_top_relations, - "relation_intent_force_sparse": self.config.relation_intent.force_relation_sparse, - "relation_intent_pair_rerank_enabled": self.config.relation_intent.pair_predicate_rerank_enabled, - "relation_intent_pair_predicate_limit": self.config.relation_intent.pair_predicate_limit, - "graph_recall_enabled": self.config.graph_recall.enabled, - "graph_recall_candidate_k": self.config.graph_recall.candidate_k, - "graph_recall_allow_two_hop_pair": self.config.graph_recall.allow_two_hop_pair, - "graph_recall_max_paths": self.config.graph_recall.max_paths, - }, - "vector_store": { - "size": int(vector_size), - }, - "graph_store": { - "num_nodes": self.graph_store.num_nodes, - "num_edges": self.graph_store.num_edges, - }, - "metadata_store": self.metadata_store.get_statistics(), - "sparse": self.sparse_index.stats() if self.sparse_index else None, - } - - def __repr__(self) -> str: - return ( - f"DualPathRetriever(" - f"strategy={self.config.retrieval_strategy.value}, " - f"para_k={self.config.top_k_paragraphs}, " - f"rel_k={self.config.top_k_relations})" - ) diff --git a/plugins/A_memorix/core/retrieval/graph_relation_recall.py b/plugins/A_memorix/core/retrieval/graph_relation_recall.py deleted file mode 100644 index 9af862f3..00000000 --- a/plugins/A_memorix/core/retrieval/graph_relation_recall.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Graph-assisted relation candidate recall for relation-oriented queries.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Sequence, Set - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.GraphRelationRecall") - - -@dataclass -class GraphRelationRecallConfig: - """Configuration for controlled graph relation recall.""" - - enabled: bool = True - candidate_k: int = 24 - max_hop: int = 1 - allow_two_hop_pair: bool = True - max_paths: int = 4 - - def __post_init__(self) -> None: - self.enabled = bool(self.enabled) - self.candidate_k = max(1, int(self.candidate_k)) - self.max_hop = max(1, int(self.max_hop)) - self.allow_two_hop_pair = bool(self.allow_two_hop_pair) - self.max_paths = max(1, int(self.max_paths)) - - -@dataclass -class GraphRelationCandidate: - """A graph-derived relation candidate before retriever-side fusion.""" - - hash_value: str - subject: str - 
predicate: str - object: str - confidence: float - graph_seed_entities: List[str] - graph_hops: int - graph_candidate_type: str - supporting_paragraph_count: int - - def to_payload(self) -> Dict[str, Any]: - content = f"{self.subject} {self.predicate} {self.object}" - return { - "hash": self.hash_value, - "content": content, - "subject": self.subject, - "predicate": self.predicate, - "object": self.object, - "confidence": self.confidence, - "graph_seed_entities": list(self.graph_seed_entities), - "graph_hops": int(self.graph_hops), - "graph_candidate_type": self.graph_candidate_type, - "supporting_paragraph_count": int(self.supporting_paragraph_count), - } - - -class GraphRelationRecallService: - """Collect relation candidates from the entity graph in a controlled way.""" - - def __init__( - self, - *, - graph_store: Any, - metadata_store: Any, - config: Optional[GraphRelationRecallConfig] = None, - ) -> None: - self.graph_store = graph_store - self.metadata_store = metadata_store - self.config = config or GraphRelationRecallConfig() - - def recall( - self, - *, - seed_entities: Sequence[str], - ) -> List[GraphRelationCandidate]: - if not self.config.enabled: - return [] - if self.graph_store is None or self.metadata_store is None: - return [] - - seeds = self._normalize_seed_entities(seed_entities) - if not seeds: - return [] - - seen_hashes: Set[str] = set() - candidates: List[GraphRelationCandidate] = [] - - if len(seeds) >= 2: - self._collect_direct_pair_candidates( - seed_a=seeds[0], - seed_b=seeds[1], - seen_hashes=seen_hashes, - out=candidates, - ) - if ( - len(candidates) < 3 - and self.config.allow_two_hop_pair - and len(candidates) < self.config.candidate_k - ): - self._collect_two_hop_pair_candidates( - seed_a=seeds[0], - seed_b=seeds[1], - seen_hashes=seen_hashes, - out=candidates, - ) - else: - self._collect_one_hop_seed_candidates( - seed=seeds[0], - seen_hashes=seen_hashes, - out=candidates, - ) - - return candidates[: self.config.candidate_k] - - def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]: - out: List[str] = [] - seen = set() - for raw in list(seed_entities)[:2]: - resolved = None - try: - resolved = self.graph_store.find_node(str(raw), ignore_case=True) - except Exception: - resolved = None - if not resolved: - continue - canon = str(resolved).strip().lower() - if not canon or canon in seen: - continue - seen.add(canon) - out.append(str(resolved)) - return out - - def _collect_direct_pair_candidates( - self, - *, - seed_a: str, - seed_b: str, - seen_hashes: Set[str], - out: List[GraphRelationCandidate], - ) -> None: - relation_hashes = [] - relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b)) - relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a)) - self._append_relation_hashes( - relation_hashes=relation_hashes, - seen_hashes=seen_hashes, - out=out, - candidate_type="direct_pair", - graph_hops=1, - graph_seed_entities=[seed_a, seed_b], - ) - - def _collect_two_hop_pair_candidates( - self, - *, - seed_a: str, - seed_b: str, - seen_hashes: Set[str], - out: List[GraphRelationCandidate], - ) -> None: - try: - paths = self.graph_store.find_paths( - seed_a, - seed_b, - max_depth=2, - max_paths=self.config.max_paths, - ) - except Exception as e: - logger.debug(f"graph two-hop recall skipped: {e}") - return - - for path_nodes in paths: - if len(out) >= self.config.candidate_k: - break - if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3: - continue - if len(path_nodes) 
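# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# Graph-assisted recall above is deliberately narrow: with two resolved seed
# entities it first pulls relation hashes stored on the direct edge (in both
# orientations), and only widens to exactly-two-hop paths when the direct pass
# yields too little. A toy sketch of the direct-pair step (the edge map below is
# a hypothetical stand-in for the graph store):
edge_relations = {("alice", "bob"): ["r1"], ("bob", "alice"): ["r9"]}

def direct_pair_candidates(seed_a, seed_b, candidate_k=24):
    hashes = edge_relations.get((seed_a, seed_b), []) + edge_relations.get((seed_b, seed_a), [])
    out, seen = [], set()
    for h in sorted(set(hashes)):
        if h in seen or len(out) >= candidate_k:
            continue
        seen.add(h)
        out.append(h)
    return out

print(direct_pair_candidates("alice", "bob"))  # ['r1', 'r9']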
!= 3: - continue - for idx in range(len(path_nodes) - 1): - if len(out) >= self.config.candidate_k: - break - u = str(path_nodes[idx]) - v = str(path_nodes[idx + 1]) - relation_hashes = [] - relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v)) - relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u)) - self._append_relation_hashes( - relation_hashes=relation_hashes, - seen_hashes=seen_hashes, - out=out, - candidate_type="two_hop_pair", - graph_hops=2, - graph_seed_entities=[seed_a, seed_b], - ) - - def _collect_one_hop_seed_candidates( - self, - *, - seed: str, - seen_hashes: Set[str], - out: List[GraphRelationCandidate], - ) -> None: - try: - relation_hashes = self.graph_store.get_incident_relation_hashes( - seed, - limit=self.config.candidate_k, - ) - except Exception as e: - logger.debug(f"graph one-hop recall skipped: {e}") - return - self._append_relation_hashes( - relation_hashes=relation_hashes, - seen_hashes=seen_hashes, - out=out, - candidate_type="one_hop_seed", - graph_hops=min(1, self.config.max_hop), - graph_seed_entities=[seed], - ) - - def _append_relation_hashes( - self, - *, - relation_hashes: Sequence[str], - seen_hashes: Set[str], - out: List[GraphRelationCandidate], - candidate_type: str, - graph_hops: int, - graph_seed_entities: Sequence[str], - ) -> None: - for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}): - if len(out) >= self.config.candidate_k: - break - if relation_hash in seen_hashes: - continue - candidate = self._build_candidate( - relation_hash=relation_hash, - candidate_type=candidate_type, - graph_hops=graph_hops, - graph_seed_entities=graph_seed_entities, - ) - if candidate is None: - continue - seen_hashes.add(relation_hash) - out.append(candidate) - - def _build_candidate( - self, - *, - relation_hash: str, - candidate_type: str, - graph_hops: int, - graph_seed_entities: Sequence[str], - ) -> Optional[GraphRelationCandidate]: - relation = self.metadata_store.get_relation(relation_hash) - if relation is None: - return None - supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash) - return GraphRelationCandidate( - hash_value=relation_hash, - subject=str(relation.get("subject", "")), - predicate=str(relation.get("predicate", "")), - object=str(relation.get("object", "")), - confidence=float(relation.get("confidence", 1.0) or 1.0), - graph_seed_entities=[str(x) for x in graph_seed_entities], - graph_hops=int(graph_hops), - graph_candidate_type=str(candidate_type), - supporting_paragraph_count=len(supporting_paragraphs), - ) diff --git a/plugins/A_memorix/core/retrieval/pagerank.py b/plugins/A_memorix/core/retrieval/pagerank.py deleted file mode 100644 index c8ee48bb..00000000 --- a/plugins/A_memorix/core/retrieval/pagerank.py +++ /dev/null @@ -1,482 +0,0 @@ -""" -Personalized PageRank实现 - -提供个性化的图节点排序功能。 -""" - -from typing import Dict, List, Optional, Tuple, Union, Any -from dataclasses import dataclass -import numpy as np - -from src.common.logger import get_logger -from ..storage import GraphStore -from ..utils.matcher import AhoCorasick - -logger = get_logger("A_Memorix.PersonalizedPageRank") - - -@dataclass -class PageRankConfig: - """ - PageRank配置 - - 属性: - alpha: 阻尼系数(0-1之间) - max_iter: 最大迭代次数 - tol: 收敛阈值 - normalize: 是否归一化结果 - min_iterations: 最小迭代次数 - """ - - alpha: float = 0.85 - max_iter: int = 100 - tol: float = 1e-6 - normalize: bool = True - min_iterations: int = 20 - - def __post_init__(self): - """验证配置""" - if not 0 <= self.alpha < 1: - 
raise ValueError(f"alpha必须在[0, 1)之间: {self.alpha}") - - if self.max_iter <= 0: - raise ValueError(f"max_iter必须大于0: {self.max_iter}") - - if self.tol <= 0: - raise ValueError(f"tol必须大于0: {self.tol}") - - if self.min_iterations < 0: - raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}") - - if self.min_iterations >= self.max_iter: - raise ValueError(f"min_iterations必须小于max_iter") - - -class PersonalizedPageRank: - """ - Personalized PageRank计算器 - - 功能: - - 个性化向量支持 - - 快速收敛检测 - - 结果归一化 - - 批量计算 - - 统计信息 - - 参数: - graph_store: 图存储 - config: PageRank配置 - """ - - def __init__( - self, - graph_store: GraphStore, - config: Optional[PageRankConfig] = None, - ): - """ - 初始化PPR计算器 - - Args: - graph_store: 图存储 - config: PageRank配置 - """ - self.graph_store = graph_store - self.config = config or PageRankConfig() - - # 统计信息 - self._total_computations = 0 - self._total_iterations = 0 - self._convergence_history: List[int] = [] - - logger.info( - f"PersonalizedPageRank 初始化: " - f"alpha={self.config.alpha}, " - f"max_iter={self.config.max_iter}" - ) - - # 缓存 Aho-Corasick 匹配器 - self._ac_matcher: Optional[AhoCorasick] = None - self._ac_nodes_count = 0 - - def compute( - self, - personalization: Optional[Dict[str, float]] = None, - alpha: Optional[float] = None, - max_iter: Optional[int] = None, - normalize: Optional[bool] = None, - ) -> Dict[str, float]: - """ - 计算Personalized PageRank - - Args: - personalization: 个性化向量 {节点名: 权重} - alpha: 阻尼系数(覆盖配置值) - max_iter: 最大迭代次数(覆盖配置值) - normalize: 是否归一化(覆盖配置值) - - Returns: - 节点PageRank值字典 {节点名: 分数} - """ - # 使用覆盖值或配置值 - alpha = alpha if alpha is not None else self.config.alpha - max_iter = max_iter if max_iter is not None else self.config.max_iter - normalize = normalize if normalize is not None else self.config.normalize - - # 调用GraphStore的compute_pagerank - scores = self.graph_store.compute_pagerank( - personalization=personalization, - alpha=alpha, - max_iter=max_iter, - tol=self.config.tol, - ) - - # 归一化(如果需要) - if normalize and scores: - total = sum(scores.values()) - if total > 0: - scores = {node: score / total for node, score in scores.items()} - - # 更新统计 - self._total_computations += 1 - - logger.debug( - f"PPR计算完成: {len(scores)} 个节点, " - f"personalization_nodes={len(personalization) if personalization else 0}" - ) - - return scores - - def compute_batch( - self, - personalization_list: List[Dict[str, float]], - normalize: bool = True, - ) -> List[Dict[str, float]]: - """ - 批量计算PPR - - Args: - personalization_list: 个性化向量列表 - normalize: 是否归一化 - - Returns: - PageRank值字典列表 - """ - results = [] - - for i, personalization in enumerate(personalization_list): - logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR") - scores = self.compute(personalization=personalization, normalize=normalize) - results.append(scores) - - return results - - def compute_for_entities( - self, - entities: List[str], - weights: Optional[List[float]] = None, - normalize: bool = True, - ) -> Dict[str, float]: - """ - 为实体列表计算PPR - - Args: - entities: 实体列表 - weights: 权重列表(默认均匀权重) - normalize: 是否归一化 - - Returns: - PageRank值字典 - """ - if not entities: - logger.warning("实体列表为空,返回均匀PPR") - return self.compute(personalization=None, normalize=normalize) - - # 构建个性化向量 - if weights is None: - weights = [1.0] * len(entities) - - if len(weights) != len(entities): - raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}") - - personalization = {entity: weight for entity, weight in zip(entities, weights)} - - return self.compute(personalization=personalization, 
normalize=normalize) - - def compute_for_query( - self, - query: str, - entity_extractor: Optional[callable] = None, - normalize: bool = True, - ) -> Dict[str, float]: - """ - 为查询计算PPR - - Args: - query: 查询文本 - entity_extractor: 实体提取函数(可选) - normalize: 是否归一化 - - Returns: - PageRank值字典 - """ - # 提取实体 - if entity_extractor is not None: - entities = entity_extractor(query) - else: - # 简单实现:基于图中的节点匹配 - entities = self._extract_entities_from_query(query) - - if not entities: - logger.debug(f"未从查询中提取到实体: '{query}'") - return self.compute(personalization=None, normalize=normalize) - - # 计算PPR - return self.compute_for_entities(entities, normalize=normalize) - - def rank_nodes( - self, - scores: Dict[str, float], - top_k: Optional[int] = None, - min_score: float = 0.0, - ) -> List[Tuple[str, float]]: - """ - 对节点排序 - - Args: - scores: PageRank分数字典 - top_k: 返回前k个节点(None表示全部) - min_score: 最小分数阈值 - - Returns: - 排序后的节点列表 [(节点名, 分数), ...] - """ - # 过滤低分节点 - filtered = [(node, score) for node, score in scores.items() if score >= min_score] - - # 按分数降序排序 - sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True) - - # 返回top_k - if top_k is not None: - sorted_nodes = sorted_nodes[:top_k] - - return sorted_nodes - - def get_personalization_vector( - self, - nodes: List[str], - method: str = "uniform", - ) -> Dict[str, float]: - """ - 生成个性化向量 - - Args: - nodes: 节点列表 - method: 生成方法 - - "uniform": 均匀权重 - - "degree": 按度数加权 - - "inverse_degree": 按度数反比加权 - - Returns: - 个性化向量 {节点名: 权重} - """ - if not nodes: - return {} - - if method == "uniform": - # 均匀权重 - weight = 1.0 / len(nodes) - return {node: weight for node in nodes} - - elif method == "degree": - # 按度数加权 - node_degrees = {} - for node in nodes: - neighbors = self.graph_store.get_neighbors(node) - node_degrees[node] = len(neighbors) - - total_degree = sum(node_degrees.values()) - if total_degree > 0: - return {node: degree / total_degree for node, degree in node_degrees.items()} - else: - return {node: 1.0 / len(nodes) for node in nodes} - - elif method == "inverse_degree": - # 按度数反比加权 - node_degrees = {} - for node in nodes: - neighbors = self.graph_store.get_neighbors(node) - node_degrees[node] = len(neighbors) - - # 反度数 - inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()} - total_inv = sum(inv_degrees.values()) - - if total_inv > 0: - return {node: inv / total_inv for node, inv in inv_degrees.items()} - else: - return {node: 1.0 / len(nodes) for node in nodes} - - else: - raise ValueError(f"不支持的个性化向量生成方法: {method}") - - def compare_scores( - self, - scores1: Dict[str, float], - scores2: Dict[str, float], - ) -> Dict[str, Dict[str, float]]: - """ - 比较两组PPR分数 - - Args: - scores1: 第一组分数 - scores2: 第二组分数 - - Returns: - 比较结果 { - "common_nodes": {节点: (score1, score2)}, - "only_in_1": {节点: score1}, - "only_in_2": {节点: score2}, - } - """ - common_nodes = {} - only_in_1 = {} - only_in_2 = {} - - all_nodes = set(scores1.keys()) | set(scores2.keys()) - - for node in all_nodes: - if node in scores1 and node in scores2: - common_nodes[node] = (scores1[node], scores2[node]) - elif node in scores1: - only_in_1[node] = scores1[node] - else: - only_in_2[node] = scores2[node] - - return { - "common_nodes": common_nodes, - "only_in_1": only_in_1, - "only_in_2": only_in_2, - } - - def get_statistics(self) -> Dict[str, Any]: - """ - 获取统计信息 - - Returns: - 统计信息字典 - """ - avg_iterations = ( - self._total_iterations / self._total_computations - if self._total_computations > 0 - else 0 - ) - - return { - "config": { - "alpha": 
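# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# get_personalization_vector supports three weightings for the seed nodes:
# uniform, proportional to node degree, and inversely proportional to degree
# (degree + 1 to avoid division by zero). Minimal standalone version:
def personalization_vector(degrees, method="uniform"):
    # degrees: {node: degree}, assumed non-empty
    if method == "uniform":
        return {n: 1.0 / len(degrees) for n in degrees}
    if method == "degree":
        total = sum(degrees.values())
        if total == 0:
            return personalization_vector(degrees, "uniform")
        return {n: d / total for n, d in degrees.items()}
    if method == "inverse_degree":
        inv = {n: 1.0 / (d + 1) for n, d in degrees.items()}
        total = sum(inv.values())
        return {n: v / total for n, v in inv.items()}
    raise ValueError(f"unsupported method: {method}")

print(personalization_vector({"alice": 3, "bob": 1}, method="inverse_degree"))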
self.config.alpha, - "max_iter": self.config.max_iter, - "tol": self.config.tol, - "normalize": self.config.normalize, - "min_iterations": self.config.min_iterations, - }, - "statistics": { - "total_computations": self._total_computations, - "total_iterations": self._total_iterations, - "avg_iterations": avg_iterations, - "convergence_history": self._convergence_history.copy(), - }, - "graph": { - "num_nodes": self.graph_store.num_nodes, - "num_edges": self.graph_store.num_edges, - }, - } - - def reset_statistics(self) -> None: - """重置统计信息""" - self._total_computations = 0 - self._total_iterations = 0 - self._convergence_history.clear() - logger.info("统计信息已重置") - - def _extract_entities_from_query(self, query: str) -> List[str]: - """ - 从查询中提取实体(简化实现) - - Args: - query: 查询文本 - - Returns: - 实体列表 - """ - # 获取所有节点 - all_nodes = self.graph_store.get_nodes() - if not all_nodes: - return [] - - # 检查是否需要更新 Aho-Corasick 匹配器 - if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes): - self._ac_matcher = AhoCorasick() - for node in all_nodes: - # 统一转为小写进行不区分大小写匹配 - self._ac_matcher.add_pattern(node.lower()) - self._ac_matcher.build() - self._ac_nodes_count = len(all_nodes) - - # 执行匹配 - query_lower = query.lower() - stats = self._ac_matcher.find_all(query_lower) - - # 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值) - # 为了简单,AC 中 add_pattern 存的是小写 - # 我们需要一个映射:小写 -> 原始 - node_map = {node.lower(): node for node in all_nodes} - entities = [node_map[low_name] for low_name in stats.keys()] - - return entities - - @property - def num_computations(self) -> int: - """计算次数""" - return self._total_computations - - @property - def avg_iterations(self) -> float: - """平均迭代次数""" - if self._total_computations == 0: - return 0.0 - return self._total_iterations / self._total_computations - - def __repr__(self) -> str: - return ( - f"PersonalizedPageRank(" - f"alpha={self.config.alpha}, " - f"computations={self._total_computations})" - ) - - -def create_ppr_from_graph( - graph_store: GraphStore, - alpha: float = 0.85, - max_iter: int = 100, -) -> PersonalizedPageRank: - """ - 从图存储创建PPR计算器 - - Args: - graph_store: 图存储 - alpha: 阻尼系数 - max_iter: 最大迭代次数 - - Returns: - PPR计算器实例 - """ - config = PageRankConfig( - alpha=alpha, - max_iter=max_iter, - ) - - return PersonalizedPageRank( - graph_store=graph_store, - config=config, - ) diff --git a/plugins/A_memorix/core/retrieval/sparse_bm25.py b/plugins/A_memorix/core/retrieval/sparse_bm25.py deleted file mode 100644 index 1fef9f80..00000000 --- a/plugins/A_memorix/core/retrieval/sparse_bm25.py +++ /dev/null @@ -1,401 +0,0 @@ -""" -稀疏检索组件(FTS5 + BM25) - -支持: -- 懒加载索引连接 -- jieba / char n-gram 分词 -- 可卸载并收缩 SQLite 内存缓存 -""" - -from __future__ import annotations - -import re -import sqlite3 -from dataclasses import dataclass -from typing import Any, Dict, List, Optional - -from src.common.logger import get_logger -from ..storage import MetadataStore - -logger = get_logger("A_Memorix.SparseBM25") - -try: - import jieba # type: ignore - - HAS_JIEBA = True -except Exception: - HAS_JIEBA = False - jieba = None - - -@dataclass -class SparseBM25Config: - """BM25 稀疏检索配置。""" - - enabled: bool = True - backend: str = "fts5" - lazy_load: bool = True - mode: str = "auto" # auto | fallback_only | hybrid - tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram - jieba_user_dict: str = "" - char_ngram_n: int = 2 - candidate_k: int = 80 - max_doc_len: int = 2000 - enable_ngram_fallback_index: bool = True - enable_like_fallback: bool = False - enable_relation_sparse_fallback: bool = 
True - relation_candidate_k: int = 60 - relation_max_doc_len: int = 512 - unload_on_disable: bool = True - shrink_memory_on_unload: bool = True - - def __post_init__(self) -> None: - self.backend = str(self.backend or "fts5").strip().lower() - self.mode = str(self.mode or "auto").strip().lower() - self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower() - self.char_ngram_n = max(1, int(self.char_ngram_n)) - self.candidate_k = max(1, int(self.candidate_k)) - self.max_doc_len = max(0, int(self.max_doc_len)) - self.relation_candidate_k = max(1, int(self.relation_candidate_k)) - self.relation_max_doc_len = max(0, int(self.relation_max_doc_len)) - if self.backend != "fts5": - raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}") - if self.mode not in {"auto", "fallback_only", "hybrid"}: - raise ValueError(f"sparse.mode 非法: {self.mode}") - if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}: - raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}") - - -class SparseBM25Index: - """ - 基于 SQLite FTS5 的 BM25 检索适配层。 - """ - - def __init__( - self, - metadata_store: MetadataStore, - config: Optional[SparseBM25Config] = None, - ): - self.metadata_store = metadata_store - self.config = config or SparseBM25Config() - self._conn: Optional[sqlite3.Connection] = None - self._loaded: bool = False - self._jieba_dict_loaded: bool = False - - @property - def loaded(self) -> bool: - return self._loaded and self._conn is not None - - def ensure_loaded(self) -> bool: - """按需加载 FTS 连接与索引。""" - if not self.config.enabled: - return False - if self.loaded: - return True - - db_path = self.metadata_store.get_db_path() - conn = sqlite3.connect( - str(db_path), - check_same_thread=False, - timeout=30.0, - ) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - conn.execute("PRAGMA temp_store=MEMORY") - - if not self.metadata_store.ensure_fts_schema(conn=conn): - conn.close() - return False - self.metadata_store.ensure_fts_backfilled(conn=conn) - # 关系稀疏检索按独立开关加载,避免不必要的初始化开销。 - if self.config.enable_relation_sparse_fallback: - self.metadata_store.ensure_relations_fts_schema(conn=conn) - self.metadata_store.ensure_relations_fts_backfilled(conn=conn) - if self.config.enable_ngram_fallback_index: - self.metadata_store.ensure_paragraph_ngram_schema(conn=conn) - self.metadata_store.ensure_paragraph_ngram_backfilled( - n=self.config.char_ngram_n, - conn=conn, - ) - - self._conn = conn - self._loaded = True - self._prepare_tokenizer() - logger.info( - "SparseBM25Index loaded: " - f"backend=fts5, tokenizer={self.config.tokenizer_mode}, mode={self.config.mode}" - ) - return True - - def _prepare_tokenizer(self) -> None: - if self._jieba_dict_loaded: - return - if self.config.tokenizer_mode not in {"jieba", "mixed"}: - return - if not HAS_JIEBA: - logger.warning("jieba 不可用,tokenizer 将退化为 char n-gram") - return - user_dict = str(self.config.jieba_user_dict or "").strip() - if user_dict: - try: - jieba.load_userdict(user_dict) # type: ignore[union-attr] - logger.info(f"已加载 jieba 用户词典: {user_dict}") - except Exception as e: - logger.warning(f"加载 jieba 用户词典失败: {e}") - self._jieba_dict_loaded = True - - def _tokenize_jieba(self, text: str) -> List[str]: - if not HAS_JIEBA: - return [] - try: - tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr] - return [t.strip().lower() for t in tokens if t and t.strip()] - except Exception: - return [] - - def _tokenize_char_ngram(self, text: str, n: int) -> List[str]: 
- compact = re.sub(r"\s+", "", text.lower()) - if not compact: - return [] - if len(compact) < n: - return [compact] - return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] - - def _tokenize(self, text: str) -> List[str]: - text = str(text or "").strip() - if not text: - return [] - - mode = self.config.tokenizer_mode - if mode == "jieba": - tokens = self._tokenize_jieba(text) - if tokens: - return list(dict.fromkeys(tokens)) - return self._tokenize_char_ngram(text, self.config.char_ngram_n) - - if mode == "mixed": - toks = self._tokenize_jieba(text) - toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n)) - return list(dict.fromkeys([t for t in toks if t])) - - return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n))) - - def _build_match_query(self, tokens: List[str]) -> str: - safe_tokens: List[str] = [] - for token in tokens: - t = token.replace('"', '""').strip() - if not t: - continue - safe_tokens.append(f'"{t}"') - if not safe_tokens: - return "" - # 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。 - return " OR ".join(safe_tokens[:64]) - - def _fallback_substring_search( - self, - tokens: List[str], - limit: int, - ) -> List[Dict[str, Any]]: - """ - 当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。 - - 说明: - - FTS 索引当前采用 unicode61 tokenizer。 - - 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。 - - 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。 - """ - if not tokens: - return [] - - # 去重并裁剪 token 数量,避免生成超长 SQL。 - uniq_tokens = [t for t in dict.fromkeys(tokens) if t] - uniq_tokens = uniq_tokens[:32] - if not uniq_tokens: - return [] - - if self.config.enable_ngram_fallback_index: - try: - # 允许运行时切换开关后按需补齐 schema/回填。 - self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn) - self.metadata_store.ensure_paragraph_ngram_backfilled( - n=self.config.char_ngram_n, - conn=self._conn, - ) - rows = self.metadata_store.ngram_search_paragraphs( - tokens=uniq_tokens, - limit=limit, - max_doc_len=self.config.max_doc_len, - conn=self._conn, - ) - if rows: - return rows - except Exception as e: - logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}") - - if not self.config.enable_like_fallback: - return [] - - conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens)) - params: List[Any] = [f"%{tok}%" for tok in uniq_tokens] - scan_limit = max(int(limit) * 8, 200) - params.append(scan_limit) - - sql = f""" - SELECT p.hash, p.content - FROM paragraphs p - WHERE (p.is_deleted IS NULL OR p.is_deleted = 0) - AND ({conditions}) - LIMIT ? 
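# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# When jieba is unavailable (or the mode is char_2gram) the query is split into
# overlapping character n-grams, and the FTS5 MATCH string is built by quoting
# each token (doubling embedded quotes) and OR-joining at most 64 of them; recall
# is kept deliberately wide here and precision is recovered later by fusion and
# thresholding. Standalone sketch of those two helpers:
import re

def char_ngrams(text, n=2):
    compact = re.sub(r"\s+", "", text.lower())
    if not compact:
        return []
    if len(compact) < n:
        return [compact]
    return [compact[i:i + n] for i in range(len(compact) - n + 1)]

def build_match_query(tokens, limit=64):
    safe = ['"' + t.replace('"', '""') + '"' for t in tokens if t.strip()]
    return " OR ".join(safe[:limit])

tokens = char_ngrams("记忆 检索")
print(tokens, build_match_query(tokens))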
- """ - rows = self.metadata_store.query(sql, tuple(params)) - if not rows: - return [] - - scored: List[Dict[str, Any]] = [] - token_count = max(1, len(uniq_tokens)) - for row in rows: - content = str(row.get("content") or "") - content_low = content.lower() - matched = [tok for tok in uniq_tokens if tok in content_low] - if not matched: - continue - coverage = len(matched) / token_count - length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low)) - # 兜底路径使用相对分,保持与上层接口兼容。 - fallback_score = coverage * 0.8 + length_bonus * 0.2 - scored.append( - { - "hash": row["hash"], - "content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content, - "bm25_score": -float(fallback_score), - "fallback_score": float(fallback_score), - } - ) - - scored.sort(key=lambda x: x["fallback_score"], reverse=True) - return scored[:limit] - - def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]: - """执行 BM25 检索。""" - if not self.config.enabled: - return [] - if self.config.lazy_load and not self.loaded: - if not self.ensure_loaded(): - return [] - if not self.loaded: - return [] - # 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。 - self.metadata_store.ensure_relations_fts_schema(conn=self._conn) - self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn) - - tokens = self._tokenize(query) - match_query = self._build_match_query(tokens) - if not match_query: - return [] - - limit = max(1, int(k)) - rows = self.metadata_store.fts_search_bm25( - match_query=match_query, - limit=limit, - max_doc_len=self.config.max_doc_len, - conn=self._conn, - ) - if not rows: - rows = self._fallback_substring_search(tokens=tokens, limit=limit) - - results: List[Dict[str, Any]] = [] - for rank, row in enumerate(rows, start=1): - bm25_score = float(row.get("bm25_score", 0.0)) - results.append( - { - "hash": row["hash"], - "content": row["content"], - "rank": rank, - "bm25_score": bm25_score, - "score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数 - } - ) - return results - - def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]: - """执行关系稀疏检索(FTS5 + BM25)。""" - if not self.config.enabled or not self.config.enable_relation_sparse_fallback: - return [] - if self.config.lazy_load and not self.loaded: - if not self.ensure_loaded(): - return [] - if not self.loaded: - return [] - - tokens = self._tokenize(query) - match_query = self._build_match_query(tokens) - if not match_query: - return [] - - rows = self.metadata_store.fts_search_relations_bm25( - match_query=match_query, - limit=max(1, int(k)), - max_doc_len=self.config.relation_max_doc_len, - conn=self._conn, - ) - out: List[Dict[str, Any]] = [] - for rank, row in enumerate(rows, start=1): - bm25_score = float(row.get("bm25_score", 0.0)) - out.append( - { - "hash": row["hash"], - "subject": row["subject"], - "predicate": row["predicate"], - "object": row["object"], - "content": row["content"], - "rank": rank, - "bm25_score": bm25_score, - "score": -bm25_score, - } - ) - return out - - def upsert_paragraph(self, paragraph_hash: str) -> bool: - if not self.loaded: - return False - return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn) - - def delete_paragraph(self, paragraph_hash: str) -> bool: - if not self.loaded: - return False - return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn) - - def unload(self) -> None: - """卸载 BM25 连接并尽量释放内存。""" - if self._conn is not None: - try: - if self.config.shrink_memory_on_unload: - 
self.metadata_store.shrink_memory(conn=self._conn) - except Exception: - pass - try: - self._conn.close() - except Exception: - pass - self._conn = None - self._loaded = False - logger.info("SparseBM25Index unloaded") - - def stats(self) -> Dict[str, Any]: - doc_count = 0 - if self.loaded: - doc_count = self.metadata_store.fts_doc_count(conn=self._conn) - return { - "enabled": self.config.enabled, - "backend": self.config.backend, - "mode": self.config.mode, - "tokenizer_mode": self.config.tokenizer_mode, - "enable_ngram_fallback_index": self.config.enable_ngram_fallback_index, - "enable_like_fallback": self.config.enable_like_fallback, - "enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback, - "loaded": self.loaded, - "has_jieba": HAS_JIEBA, - "doc_count": doc_count, - } diff --git a/plugins/A_memorix/core/retrieval/threshold.py b/plugins/A_memorix/core/retrieval/threshold.py deleted file mode 100644 index 87a0094b..00000000 --- a/plugins/A_memorix/core/retrieval/threshold.py +++ /dev/null @@ -1,450 +0,0 @@ -""" -动态阈值过滤器 - -根据检索结果的分布特征自适应调整过滤阈值。 -""" - -import numpy as np -from typing import List, Dict, Any, Optional, Tuple, Union -from dataclasses import dataclass -from enum import Enum - -from src.common.logger import get_logger -from .dual_path import RetrievalResult - -logger = get_logger("A_Memorix.DynamicThresholdFilter") - - -class ThresholdMethod(Enum): - """阈值计算方法""" - - PERCENTILE = "percentile" # 百分位数 - STD_DEV = "std_dev" # 标准差 - GAP_DETECTION = "gap_detection" # 跳变检测 - ADAPTIVE = "adaptive" # 自适应(综合多种方法) - - -@dataclass -class ThresholdConfig: - """ - 阈值配置 - - 属性: - method: 阈值计算方法 - min_threshold: 最小阈值(绝对值) - max_threshold: 最大阈值(绝对值) - percentile: 百分位数(用于percentile方法) - std_multiplier: 标准差倍数(用于std_dev方法) - min_results: 最少保留结果数 - enable_auto_adjust: 是否自动调整参数 - """ - - method: ThresholdMethod = ThresholdMethod.ADAPTIVE - min_threshold: float = 0.3 - max_threshold: float = 0.95 - percentile: float = 75.0 # 百分位数 - std_multiplier: float = 1.5 # 标准差倍数 - min_results: int = 3 # 最少保留结果数 - enable_auto_adjust: bool = True - - def __post_init__(self): - """验证配置""" - if not 0 <= self.min_threshold <= 1: - raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}") - - if not 0 <= self.max_threshold <= 1: - raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}") - - if self.min_threshold >= self.max_threshold: - raise ValueError(f"min_threshold必须小于max_threshold") - - if not 0 <= self.percentile <= 100: - raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}") - - if self.std_multiplier <= 0: - raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}") - - if self.min_results < 0: - raise ValueError(f"min_results必须大于等于0: {self.min_results}") - - -class DynamicThresholdFilter: - """ - 动态阈值过滤器 - - 功能: - - 基于结果分布自适应计算阈值 - - 多种阈值计算方法 - - 自动参数调整 - - 统计信息收集 - - 参数: - config: 阈值配置 - """ - - def __init__( - self, - config: Optional[ThresholdConfig] = None, - ): - """ - 初始化动态阈值过滤器 - - Args: - config: 阈值配置 - """ - self.config = config or ThresholdConfig() - - # 统计信息 - self._total_filtered = 0 - self._total_processed = 0 - self._threshold_history: List[float] = [] - - logger.info( - f"DynamicThresholdFilter 初始化: " - f"method={self.config.method.value}, " - f"min_threshold={self.config.min_threshold}" - ) - - def filter( - self, - results: List[RetrievalResult], - return_threshold: bool = False, - ) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]: - """ - 过滤检索结果 - - Args: - results: 检索结果列表 - return_threshold: 是否返回使用的阈值 - - 
Returns: - 过滤后的结果列表,或 (结果列表, 阈值) 元组 - """ - if not results: - logger.debug("结果列表为空,无需过滤") - return ([], 0.0) if return_threshold else [] - - self._total_processed += len(results) - - # 提取分数 - scores = np.array([r.score for r in results]) - - # 计算阈值 - threshold = self._compute_threshold(scores, results) - - # 记录阈值 - self._threshold_history.append(threshold) - - # 应用阈值过滤 - filtered_results = [ - r for r in results - if r.score >= threshold - ] - - # 确保至少保留min_results个结果 - if len(filtered_results) < self.config.min_results: - # 按分数排序,取前min_results个 - sorted_results = sorted(results, key=lambda x: x.score, reverse=True) - filtered_results = sorted_results[:self.config.min_results] - threshold = filtered_results[-1].score if filtered_results else 0.0 - - self._total_filtered += len(results) - len(filtered_results) - - logger.info( - f"过滤完成: {len(results)} -> {len(filtered_results)} " - f"(threshold={threshold:.3f})" - ) - - if return_threshold: - return filtered_results, threshold - return filtered_results - - def _compute_threshold( - self, - scores: np.ndarray, - results: List[RetrievalResult], - ) -> float: - """ - 计算阈值 - - Args: - scores: 分数数组 - results: 检索结果列表 - - Returns: - 阈值 - """ - if self.config.method == ThresholdMethod.PERCENTILE: - threshold = self._percentile_threshold(scores) - elif self.config.method == ThresholdMethod.STD_DEV: - threshold = self._std_dev_threshold(scores) - elif self.config.method == ThresholdMethod.GAP_DETECTION: - threshold = self._gap_detection_threshold(scores) - else: # ADAPTIVE - # 自适应方法:综合多种方法 - thresholds = [ - self._percentile_threshold(scores), - self._std_dev_threshold(scores), - self._gap_detection_threshold(scores), - ] - # 使用中位数作为最终阈值 - threshold = float(np.median(thresholds)) - - # 限制在[min_threshold, max_threshold]范围内 - threshold = np.clip( - threshold, - self.config.min_threshold, - self.config.max_threshold, - ) - - # 自动调整 - if self.config.enable_auto_adjust: - threshold = self._auto_adjust_threshold(threshold, scores) - - return float(threshold) - - def _percentile_threshold(self, scores: np.ndarray) -> float: - """ - 基于百分位数计算阈值 - - Args: - scores: 分数数组 - - Returns: - 阈值 - """ - percentile = self.config.percentile - threshold = float(np.percentile(scores, percentile)) - - logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})") - return threshold - - def _std_dev_threshold(self, scores: np.ndarray) -> float: - """ - 基于标准差计算阈值 - - threshold = mean - std_multiplier * std - - Args: - scores: 分数数组 - - Returns: - 阈值 - """ - mean = float(np.mean(scores)) - std = float(np.std(scores)) - multiplier = self.config.std_multiplier - - threshold = mean - multiplier * std - - logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})") - return threshold - - def _gap_detection_threshold(self, scores: np.ndarray) -> float: - """ - 基于跳变检测计算阈值 - - 找到分数分布中最大的"跳变"位置,以此为阈值 - - Args: - scores: 分数数组(降序排列) - - Returns: - 阈值 - """ - # 降序排列 - sorted_scores = np.sort(scores)[::-1] - - if len(sorted_scores) < 2: - return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0 - - # 计算相邻分数的差值 - gaps = np.diff(sorted_scores) - - # 找到最大的跳变位置 - max_gap_idx = int(np.argmax(gaps)) - - # 阈值为跳变后的分数 - threshold = float(sorted_scores[max_gap_idx + 1]) - - logger.debug( - f"跳变检测阈值: {threshold:.3f} " - f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})" - ) - return threshold - - def _auto_adjust_threshold( - self, - threshold: float, - scores: np.ndarray, - ) -> float: - """ - 自动调整阈值 - - 基于历史阈值和当前分数分布调整 - - Args: - threshold: 当前阈值 - scores: 分数数组 - - 
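# --- Editor's note (illustrative sketch, not part of the patch) ---------------
# The adaptive mode above computes three candidate cutoffs -- a percentile of the
# score distribution, mean - k * std, and the score just below the largest gap in
# the sorted scores -- and takes their median, clipped to [min, max]. The sketch
# computes the drop magnitudes explicitly; note that np.diff on a descending
# array yields negative values, so taking argmax over the raw diffs (as in the
# patched _gap_detection_threshold) would select the smallest drop and may be
# worth double-checking.
import numpy as np

def adaptive_threshold(scores, percentile=75.0, std_multiplier=1.5,
                       min_threshold=0.3, max_threshold=0.95):
    scores = np.asarray(scores, dtype=float)
    candidates = [float(np.percentile(scores, percentile)),
                  float(np.mean(scores) - std_multiplier * np.std(scores))]
    ordered = np.sort(scores)[::-1]
    if len(ordered) >= 2:
        drops = ordered[:-1] - ordered[1:]            # positive drop sizes
        candidates.append(float(ordered[int(np.argmax(drops)) + 1]))
    return float(np.clip(np.median(candidates), min_threshold, max_threshold))

print(adaptive_threshold([0.92, 0.90, 0.88, 0.52, 0.49]))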
Returns: - 调整后的阈值 - """ - if not self._threshold_history: - return threshold - - # 计算历史阈值的移动平均 - recent_thresholds = self._threshold_history[-10:] # 最近10次 - avg_threshold = float(np.mean(recent_thresholds)) - - # 当前阈值与历史平均的差异 - diff = threshold - avg_threshold - - # 如果差异过大(>0.2),向历史平均靠拢 - if abs(diff) > 0.2: - adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50% - logger.debug( - f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} " - f"(历史平均={avg_threshold:.3f})" - ) - return adjusted_threshold - - return threshold - - def filter_by_confidence( - self, - results: List[RetrievalResult], - min_confidence: float = 0.5, - ) -> List[RetrievalResult]: - """ - 基于置信度过滤结果 - - Args: - results: 检索结果列表 - min_confidence: 最小置信度 - - Returns: - 过滤后的结果列表 - """ - filtered = [] - for result in results: - # 对于关系结果,使用confidence字段 - if result.result_type == "relation": - confidence = result.metadata.get("confidence", 1.0) - if confidence >= min_confidence: - filtered.append(result) - else: - # 对于段落结果,直接使用分数 - if result.score >= min_confidence: - filtered.append(result) - - logger.info( - f"置信度过滤: {len(results)} -> {len(filtered)} " - f"(min_confidence={min_confidence})" - ) - - return filtered - - def filter_by_diversity( - self, - results: List[RetrievalResult], - similarity_threshold: float = 0.9, - top_k: int = 10, - ) -> List[RetrievalResult]: - """ - 基于多样性过滤结果(去除重复) - - Args: - results: 检索结果列表 - similarity_threshold: 相似度阈值(高于此值视为重复) - top_k: 最多保留结果数 - - Returns: - 过滤后的结果列表 - """ - if not results: - return [] - - # 按分数排序 - sorted_results = sorted(results, key=lambda x: x.score, reverse=True) - - # 贪心选择:选择与已选结果相似度低的结果 - selected = [] - selected_hashes = [] - - for result in sorted_results: - if len(selected) >= top_k: - break - - # 检查与已选结果的相似度 - is_duplicate = False - for selected_hash in selected_hashes: - # 简单判断:基于hash的前缀 - if result.hash_value[:8] == selected_hash[:8]: - is_duplicate = True - break - - if not is_duplicate: - selected.append(result) - selected_hashes.append(result.hash_value) - - logger.info( - f"多样性过滤: {len(results)} -> {len(selected)} " - f"(similarity_threshold={similarity_threshold})" - ) - - return selected - - def get_statistics(self) -> Dict[str, Any]: - """ - 获取统计信息 - - Returns: - 统计信息字典 - """ - filter_rate = ( - self._total_filtered / self._total_processed - if self._total_processed > 0 - else 0.0 - ) - - stats = { - "config": { - "method": self.config.method.value, - "min_threshold": self.config.min_threshold, - "max_threshold": self.config.max_threshold, - "percentile": self.config.percentile, - "std_multiplier": self.config.std_multiplier, - "min_results": self.config.min_results, - "enable_auto_adjust": self.config.enable_auto_adjust, - }, - "statistics": { - "total_processed": self._total_processed, - "total_filtered": self._total_filtered, - "filter_rate": filter_rate, - "avg_threshold": float(np.mean(self._threshold_history)) - if self._threshold_history else 0.0, - "threshold_count": len(self._threshold_history), - }, - } - - if self._threshold_history: - stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history)) - stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history)) - - return stats - - def reset_statistics(self) -> None: - """重置统计信息""" - self._total_filtered = 0 - self._total_processed = 0 - self._threshold_history.clear() - logger.info("统计信息已重置") - - def __repr__(self) -> str: - return ( - f"DynamicThresholdFilter(" - f"method={self.config.method.value}, " - f"min_threshold={self.config.min_threshold}, " - 
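[Editor's note] The adaptive mode of the deleted DynamicThresholdFilter combines three candidate cut-offs (a percentile, mean minus k standard deviations, and the score just below the largest drop in the sorted list) and keeps their median, clipped to the configured bounds. A minimal self-contained sketch of that idea follows; parameter values and example scores are illustrative, and on a descending sort the largest drop is taken as the most negative adjacent difference.

import numpy as np

def adaptive_threshold(scores, percentile=75.0, std_multiplier=1.5,
                       min_threshold=0.3, max_threshold=0.95):
    """Median of three candidate thresholds, clipped to [min_threshold, max_threshold]."""
    scores = np.asarray(scores, dtype=float)
    # Candidate 1: percentile cut (roughly the top quarter by default).
    by_percentile = float(np.percentile(scores, percentile))
    # Candidate 2: mean minus a multiple of the standard deviation.
    by_std = float(np.mean(scores) - std_multiplier * np.std(scores))
    # Candidate 3: score just after the largest drop in the descending sort.
    ordered = np.sort(scores)[::-1]
    if ordered.size >= 2:
        largest_drop = int(np.argmin(np.diff(ordered)))  # most negative difference
        by_gap = float(ordered[largest_drop + 1])
    else:
        by_gap = float(ordered[0]) if ordered.size else 0.0
    return float(np.clip(np.median([by_percentile, by_std, by_gap]),
                         min_threshold, max_threshold))

# Example: a clear gap separates the top three results from the rest.
print(adaptive_threshold([0.92, 0.90, 0.88, 0.55, 0.52, 0.20]))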
f"filtered={self._total_filtered}/{self._total_processed})" - ) diff --git a/plugins/A_memorix/core/runtime/__init__.py b/plugins/A_memorix/core/runtime/__init__.py deleted file mode 100644 index eece6d21..00000000 --- a/plugins/A_memorix/core/runtime/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -"""SDK runtime exports for A_Memorix.""" - -from .search_runtime_initializer import ( - SearchRuntimeBundle, - SearchRuntimeInitializer, - build_search_runtime, -) -from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel - -__all__ = [ - "SearchRuntimeBundle", - "SearchRuntimeInitializer", - "build_search_runtime", - "KernelSearchRequest", - "SDKMemoryKernel", -] diff --git a/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py b/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py deleted file mode 100644 index 423b55c4..00000000 --- a/plugins/A_memorix/core/runtime/lifecycle_orchestrator.py +++ /dev/null @@ -1,268 +0,0 @@ -"""Lifecycle bootstrap/teardown helpers extracted from plugin.py.""" - -from __future__ import annotations - -import asyncio -from pathlib import Path -from typing import Any - -from src.common.logger import get_logger - -from ..embedding import create_embedding_api_adapter -from ..retrieval import SparseBM25Config, SparseBM25Index -from ..storage import ( - GraphStore, - MetadataStore, - QuantizationType, - SparseMatrixFormat, - VectorStore, -) -from ..utils.runtime_self_check import ensure_runtime_self_check -from ..utils.relation_write_service import RelationWriteService - -logger = get_logger("A_Memorix.LifecycleOrchestrator") - - -async def ensure_initialized(plugin: Any) -> None: - if plugin._initialized: - plugin._runtime_ready = plugin._check_storage_ready() - return - - async with plugin._init_lock: - if plugin._initialized: - plugin._runtime_ready = plugin._check_storage_ready() - return - - logger.info("A_Memorix 插件正在异步初始化存储组件...") - plugin._validate_runtime_config() - await initialize_storage_async(plugin) - report = await ensure_runtime_self_check(plugin, force=True) - if not bool(report.get("ok", False)): - logger.error( - "A_Memorix runtime self-check failed: " - f"{report.get('message', 'unknown')}; " - "建议执行 python plugins/A_memorix/scripts/runtime_self_check.py --json" - ) - - if plugin.graph_store and plugin.metadata_store: - relation_count = plugin.metadata_store.count_relations() - if relation_count > 0 and not plugin.graph_store.has_edge_hash_map(): - raise RuntimeError( - "检测到 relations 数据存在但 edge-hash-map 为空。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - - plugin._initialized = True - plugin._runtime_ready = plugin._check_storage_ready() - plugin._update_plugin_config() - logger.info("A_Memorix 插件异步初始化成功") - - -def start_background_tasks(plugin: Any) -> None: - """Start background tasks idempotently.""" - if not hasattr(plugin, "_episode_generation_task"): - plugin._episode_generation_task = None - - if ( - plugin.get_config("summarization.enabled", True) - and plugin.get_config("schedule.enabled", True) - and (plugin._scheduled_import_task is None or plugin._scheduled_import_task.done()) - ): - plugin._scheduled_import_task = asyncio.create_task(plugin._scheduled_import_loop()) - - if ( - plugin.get_config("advanced.enable_auto_save", True) - and (plugin._auto_save_task is None or plugin._auto_save_task.done()) - ): - plugin._auto_save_task = asyncio.create_task(plugin._auto_save_loop()) - - if ( - plugin.get_config("person_profile.enabled", True) - and (plugin._person_profile_refresh_task is None or 
plugin._person_profile_refresh_task.done()) - ): - plugin._person_profile_refresh_task = asyncio.create_task(plugin._person_profile_refresh_loop()) - - if plugin._memory_maintenance_task is None or plugin._memory_maintenance_task.done(): - plugin._memory_maintenance_task = asyncio.create_task(plugin._memory_maintenance_loop()) - - rv_cfg = plugin.get_config("retrieval.relation_vectorization", {}) or {} - if isinstance(rv_cfg, dict): - rv_enabled = bool(rv_cfg.get("enabled", False)) - rv_backfill = bool(rv_cfg.get("backfill_enabled", False)) - else: - rv_enabled = False - rv_backfill = False - if rv_enabled and rv_backfill and ( - plugin._relation_vector_backfill_task is None or plugin._relation_vector_backfill_task.done() - ): - plugin._relation_vector_backfill_task = asyncio.create_task(plugin._relation_vector_backfill_loop()) - - episode_task = getattr(plugin, "_episode_generation_task", None) - episode_loop = getattr(plugin, "_episode_generation_loop", None) - if ( - callable(episode_loop) - and bool(plugin.get_config("episode.enabled", True)) - and bool(plugin.get_config("episode.generation_enabled", True)) - and (episode_task is None or episode_task.done()) - ): - plugin._episode_generation_task = asyncio.create_task(episode_loop()) - - -async def cancel_background_tasks(plugin: Any) -> None: - """Cancel all background tasks and wait for cleanup.""" - tasks = [ - ("scheduled_import", plugin._scheduled_import_task), - ("auto_save", plugin._auto_save_task), - ("person_profile_refresh", plugin._person_profile_refresh_task), - ("memory_maintenance", plugin._memory_maintenance_task), - ("relation_vector_backfill", plugin._relation_vector_backfill_task), - ("episode_generation", getattr(plugin, "_episode_generation_task", None)), - ] - for _, task in tasks: - if task and not task.done(): - task.cancel() - - for name, task in tasks: - if not task: - continue - try: - await task - except asyncio.CancelledError: - pass - except Exception as e: - logger.warning(f"后台任务 {name} 退出异常: {e}") - - plugin._scheduled_import_task = None - plugin._auto_save_task = None - plugin._person_profile_refresh_task = None - plugin._memory_maintenance_task = None - plugin._relation_vector_backfill_task = None - plugin._episode_generation_task = None - - -async def initialize_storage_async(plugin: Any) -> None: - """Initialize storage components asynchronously.""" - data_dir_str = plugin.get_config("storage.data_dir", "./data") - if data_dir_str.startswith("."): - plugin_dir = Path(__file__).resolve().parents[2] - data_dir = (plugin_dir / data_dir_str).resolve() - else: - data_dir = Path(data_dir_str) - - logger.info(f"A_Memorix 数据存储路径: {data_dir}") - data_dir.mkdir(parents=True, exist_ok=True) - - plugin.embedding_manager = create_embedding_api_adapter( - batch_size=plugin.get_config("embedding.batch_size", 32), - max_concurrent=plugin.get_config("embedding.max_concurrent", 5), - default_dimension=plugin.get_config("embedding.dimension", 1024), - model_name=plugin.get_config("embedding.model_name", "auto"), - retry_config=plugin.get_config("embedding.retry", {}), - ) - logger.info("嵌入 API 适配器初始化完成") - - try: - detected_dimension = await plugin.embedding_manager._detect_dimension() - logger.info(f"嵌入维度检测成功: {detected_dimension}") - except Exception as e: - logger.warning(f"嵌入维度检测失败: {e},使用默认值") - detected_dimension = plugin.embedding_manager.default_dimension - - quantization_str = plugin.get_config("embedding.quantization_type", "int8") - if str(quantization_str or "").strip().lower() != "int8": - raise 
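[Editor's note] start_background_tasks/cancel_background_tasks above follow a common asyncio pattern: create a task only when the stored handle is None or already done, and on shutdown cancel every handle, then await it while swallowing CancelledError. A small generic sketch of that pattern, with hypothetical loop names, is shown here; it is not the plugin's implementation.

import asyncio

class BackgroundTasks:
    """Idempotent start plus graceful cancel for long-running loops (sketch)."""

    def __init__(self):
        self._tasks = {}  # name -> asyncio.Task

    def ensure(self, name, coro_factory):
        task = self._tasks.get(name)
        if task is None or task.done():      # restart only if missing or finished
            self._tasks[name] = asyncio.create_task(coro_factory())

    async def cancel_all(self):
        for task in self._tasks.values():
            if not task.done():
                task.cancel()
        for name, task in self._tasks.items():
            try:
                await task                    # wait for cleanup to finish
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                print(f"background task {name} exited abnormally: {exc}")
        self._tasks.clear()

async def demo():
    async def tick():
        while True:
            await asyncio.sleep(0.1)

    tasks = BackgroundTasks()
    tasks.ensure("auto_save", tick)
    tasks.ensure("auto_save", tick)   # second call is a no-op while the task runs
    await asyncio.sleep(0.3)
    await tasks.cancel_all()

asyncio.run(demo())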
ValueError("embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。") - quantization_type = QuantizationType.INT8 - - plugin.vector_store = VectorStore( - dimension=detected_dimension, - quantization_type=quantization_type, - data_dir=data_dir / "vectors", - ) - plugin.vector_store.min_train_threshold = plugin.get_config("embedding.min_train_threshold", 40) - logger.info( - "向量存储初始化完成(" - f"维度: {detected_dimension}, " - f"训练阈值: {plugin.vector_store.min_train_threshold})" - ) - - matrix_format_str = plugin.get_config("graph.sparse_matrix_format", "csr") - matrix_format_map = { - "csr": SparseMatrixFormat.CSR, - "csc": SparseMatrixFormat.CSC, - } - matrix_format = matrix_format_map.get(matrix_format_str, SparseMatrixFormat.CSR) - - plugin.graph_store = GraphStore( - matrix_format=matrix_format, - data_dir=data_dir / "graph", - ) - logger.info("图存储初始化完成") - - plugin.metadata_store = MetadataStore(data_dir=data_dir / "metadata") - plugin.metadata_store.connect() - logger.info("元数据存储初始化完成") - - plugin.relation_write_service = RelationWriteService( - metadata_store=plugin.metadata_store, - graph_store=plugin.graph_store, - vector_store=plugin.vector_store, - embedding_manager=plugin.embedding_manager, - ) - logger.info("关系写入服务初始化完成") - - sparse_cfg_raw = plugin.get_config("retrieval.sparse", {}) or {} - if not isinstance(sparse_cfg_raw, dict): - sparse_cfg_raw = {} - try: - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - except Exception as e: - logger.warning(f"sparse 配置非法,回退默认配置: {e}") - sparse_cfg = SparseBM25Config() - plugin.sparse_index = SparseBM25Index( - metadata_store=plugin.metadata_store, - config=sparse_cfg, - ) - logger.info( - "稀疏检索组件初始化完成: " - f"enabled={sparse_cfg.enabled}, " - f"lazy_load={sparse_cfg.lazy_load}, " - f"mode={sparse_cfg.mode}, " - f"tokenizer={sparse_cfg.tokenizer_mode}" - ) - if sparse_cfg.enabled and not sparse_cfg.lazy_load: - plugin.sparse_index.ensure_loaded() - - if plugin.vector_store.has_data(): - try: - plugin.vector_store.load() - logger.info(f"向量数据已加载,共 {plugin.vector_store.num_vectors} 个向量") - except Exception as e: - logger.warning(f"加载向量数据失败: {e}") - - try: - warmup_summary = plugin.vector_store.warmup_index(force_train=True) - if warmup_summary.get("ok"): - logger.info( - "向量索引预热完成: " - f"trained={warmup_summary.get('trained')}, " - f"index_ntotal={warmup_summary.get('index_ntotal')}, " - f"fallback_ntotal={warmup_summary.get('fallback_ntotal')}, " - f"bin_count={warmup_summary.get('bin_count')}, " - f"duration_ms={float(warmup_summary.get('duration_ms', 0.0)):.2f}" - ) - else: - logger.warning( - "向量索引预热失败,继续启用 sparse 降级路径: " - f"{warmup_summary.get('error', 'unknown')}" - ) - except Exception as e: - logger.warning(f"向量索引预热异常,继续启用 sparse 降级路径: {e}") - - if plugin.graph_store.has_data(): - try: - plugin.graph_store.load() - logger.info(f"图数据已加载,共 {plugin.graph_store.num_nodes} 个节点") - except Exception as e: - logger.warning(f"加载图数据失败: {e}") - - logger.info(f"知识库数据目录: {data_dir}") diff --git a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py deleted file mode 100644 index 93c11bf7..00000000 --- a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py +++ /dev/null @@ -1,3162 +0,0 @@ -from __future__ import annotations - -import asyncio -import json -import pickle -import time -import uuid -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Awaitable, Callable, Dict, Iterable, List, Optional, Sequence - -from src.common.logger import get_logger - -from ..embedding 
import create_embedding_api_adapter -from ..retrieval import RetrievalResult, SparseBM25Config, SparseBM25Index, TemporalQueryOptions -from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore -from ..utils.aggregate_query_service import AggregateQueryService -from ..utils.episode_retrieval_service import EpisodeRetrievalService -from ..utils.episode_segmentation_service import EpisodeSegmentationService -from ..utils.episode_service import EpisodeService -from ..utils.hash import compute_hash, normalize_text -from ..utils.person_profile_service import PersonProfileService -from ..utils.relation_write_service import RelationWriteService -from ..utils.retrieval_tuning_manager import RetrievalTuningManager -from ..utils.runtime_self_check import run_embedding_runtime_self_check -from ..utils.search_execution_service import SearchExecutionRequest, SearchExecutionService -from ..utils.summary_importer import SummaryImporter -from ..utils.time_parser import format_timestamp, parse_query_datetime_to_timestamp -from ..utils.web_import_manager import ImportTaskManager -from .search_runtime_initializer import SearchRuntimeBundle, build_search_runtime - -logger = get_logger("A_Memorix.SDKMemoryKernel") - - -@dataclass -class KernelSearchRequest: - query: str = "" - limit: int = 5 - mode: str = "search" - chat_id: str = "" - person_id: str = "" - time_start: Optional[str | float] = None - time_end: Optional[str | float] = None - respect_filter: bool = True - user_id: str = "" - group_id: str = "" - - -@dataclass -class _NormalizedSearchTimeWindow: - numeric_start: Optional[float] = None - numeric_end: Optional[float] = None - query_start: Optional[str] = None - query_end: Optional[str] = None - - -class _KernelRuntimeFacade: - def __init__(self, kernel: "SDKMemoryKernel") -> None: - self._kernel = kernel - self.config = kernel.config - self._plugin_config = kernel.config - self._runtime_self_check_report: Dict[str, Any] = {} - - def get_config(self, key: str, default: Any = None) -> Any: - return self._kernel._cfg(key, default) - - def is_runtime_ready(self) -> bool: - return self._kernel.is_runtime_ready() - - def is_chat_enabled(self, stream_id: str, group_id: str | None = None, user_id: str | None = None) -> bool: - return self._kernel.is_chat_enabled(stream_id=stream_id, group_id=group_id, user_id=user_id) - - async def reinforce_access(self, relation_hashes: Sequence[str]) -> None: - if self._kernel.metadata_store is None: - return - hashes = [str(item or "").strip() for item in relation_hashes if str(item or "").strip()] - if not hashes: - return - self._kernel.metadata_store.reinforce_relations(hashes) - self._kernel._last_maintenance_at = time.time() - - async def execute_request_with_dedup( - self, - request_key: str, - executor: Callable[[], Awaitable[Dict[str, Any]]], - ) -> tuple[bool, Dict[str, Any]]: - return await self._kernel.execute_request_with_dedup(request_key, executor) - - @property - def vector_store(self) -> Optional[VectorStore]: - return self._kernel.vector_store - - @property - def graph_store(self) -> Optional[GraphStore]: - return self._kernel.graph_store - - @property - def metadata_store(self) -> Optional[MetadataStore]: - return self._kernel.metadata_store - - @property - def embedding_manager(self): - return self._kernel.embedding_manager - - @property - def sparse_index(self): - return self._kernel.sparse_index - - @property - def relation_write_service(self) -> Optional[RelationWriteService]: - return 
self._kernel.relation_write_service - - -class SDKMemoryKernel: - def __init__(self, *, plugin_root: Path, config: Optional[Dict[str, Any]] = None) -> None: - self.plugin_root = Path(plugin_root).resolve() - self.config = config or {} - storage_cfg = self._cfg("storage", {}) or {} - data_dir = str(storage_cfg.get("data_dir", "./data") or "./data") - self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir) - self.embedding_dimension = max(1, int(self._cfg("embedding.dimension", 1024))) - self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False)) - - self.embedding_manager = None - self.vector_store: Optional[VectorStore] = None - self.graph_store: Optional[GraphStore] = None - self.metadata_store: Optional[MetadataStore] = None - self.relation_write_service: Optional[RelationWriteService] = None - self.sparse_index: Optional[SparseBM25Index] = None - self.retriever = None - self.threshold_filter = None - self.episode_retriever: Optional[EpisodeRetrievalService] = None - self.aggregate_query_service: Optional[AggregateQueryService] = None - self.person_profile_service: Optional[PersonProfileService] = None - self.episode_segmentation_service: Optional[EpisodeSegmentationService] = None - self.episode_service: Optional[EpisodeService] = None - self.summary_importer: Optional[SummaryImporter] = None - self.import_task_manager: Optional[ImportTaskManager] = None - self.retrieval_tuning_manager: Optional[RetrievalTuningManager] = None - self._runtime_bundle: Optional[SearchRuntimeBundle] = None - self._runtime_facade = _KernelRuntimeFacade(self) - self._initialized = False - self._last_maintenance_at: Optional[float] = None - self._request_dedup_tasks: Dict[str, asyncio.Task] = {} - self._background_tasks: Dict[str, asyncio.Task] = {} - self._background_lock = asyncio.Lock() - self._background_stopping = False - self._active_person_timestamps: Dict[str, float] = {} - - def _cfg(self, key: str, default: Any = None) -> Any: - current: Any = self.config - if key in {"storage", "embedding", "retrieval", "graph", "episode", "web", "advanced", "threshold", "summarization"} and isinstance(current, dict): - return current.get(key, default) - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - def _set_cfg(self, key: str, value: Any) -> None: - current: Dict[str, Any] = self.config - parts = [part for part in str(key or "").split(".") if part] - if not parts: - return - for part in parts[:-1]: - next_value = current.get(part) - if not isinstance(next_value, dict): - next_value = {} - current[part] = next_value - current = next_value - current[parts[-1]] = value - - def _build_runtime_config(self) -> Dict[str, Any]: - runtime_config = dict(self.config) - runtime_config.update( - { - "vector_store": self.vector_store, - "graph_store": self.graph_store, - "metadata_store": self.metadata_store, - "embedding_manager": self.embedding_manager, - "sparse_index": self.sparse_index, - "relation_write_service": self.relation_write_service, - "plugin_instance": self._runtime_facade, - } - ) - return runtime_config - - def is_runtime_ready(self) -> bool: - return bool( - self._initialized - and self.vector_store is not None - and self.graph_store is not None - and self.metadata_store is not None - and self.embedding_manager is not None - and self.retriever is not None - ) - - def is_chat_enabled(self, stream_id: str, group_id: str | 
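[Editor's note] The kernel's _cfg/_set_cfg helpers above walk a nested config dict with dotted keys (for example "embedding.batch_size") and create intermediate dicts on write. A minimal standalone version of that lookup, without the kernel's special-cased top-level sections, could look like this; names are illustrative.

from typing import Any

def cfg_get(config: dict, key: str, default: Any = None) -> Any:
    """Resolve a dotted key against nested dicts, returning default on any miss."""
    current: Any = config
    for part in key.split("."):
        if isinstance(current, dict) and part in current:
            current = current[part]
        else:
            return default
    return current

def cfg_set(config: dict, key: str, value: Any) -> None:
    """Create intermediate dicts as needed, then assign the leaf value."""
    parts = [p for p in key.split(".") if p]
    if not parts:
        return
    current = config
    for part in parts[:-1]:
        nxt = current.get(part)
        if not isinstance(nxt, dict):
            nxt = {}
            current[part] = nxt
        current = nxt
    current[parts[-1]] = value

conf = {"embedding": {"batch_size": 32}}
print(cfg_get(conf, "embedding.batch_size"))       # 32
cfg_set(conf, "advanced.enable_auto_save", False)
print(cfg_get(conf, "advanced.enable_auto_save"))  # False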
None = None, user_id: str | None = None) -> bool: - filter_config = self._cfg("filter", {}) or {} - if not isinstance(filter_config, dict) or not filter_config: - return True - - if not bool(filter_config.get("enabled", True)): - return True - - mode = str(filter_config.get("mode", "blacklist") or "blacklist").strip().lower() - patterns = filter_config.get("chats") or [] - if not isinstance(patterns, list): - patterns = [] - - if not patterns: - return mode == "blacklist" - - stream_token = str(stream_id or "").strip() - group_token = str(group_id or "").strip() - user_token = str(user_id or "").strip() - candidates = {token for token in (stream_token, group_token, user_token) if token} - - matched = False - for raw_pattern in patterns: - pattern = str(raw_pattern or "").strip() - if not pattern: - continue - if ":" in pattern: - prefix, value = pattern.split(":", 1) - prefix = prefix.strip().lower() - value = value.strip() - if prefix == "group" and value and value == group_token: - matched = True - elif prefix in {"user", "private"} and value and value == user_token: - matched = True - elif prefix == "stream" and value and value == stream_token: - matched = True - elif pattern in candidates: - matched = True - - if matched: - break - - if mode == "blacklist": - return not matched - return matched - - def _is_chat_filtered( - self, - *, - respect_filter: bool, - stream_id: str = "", - group_id: str = "", - user_id: str = "", - ) -> bool: - if not bool(respect_filter): - return False - - stream_token = str(stream_id or "").strip() - group_token = str(group_id or "").strip() - user_token = str(user_id or "").strip() - if not (stream_token or group_token or user_token): - return False - return not self.is_chat_enabled(stream_token, group_token, user_token) - - def _stored_vector_dimension(self) -> Optional[int]: - meta_path = self.data_dir / "vectors" / "vectors_metadata.pkl" - if not meta_path.exists(): - return None - try: - with open(meta_path, "rb") as handle: - meta = pickle.load(handle) - except Exception as exc: - logger.warning(f"读取向量元数据失败,将回退到 runtime self-check: {exc}") - return None - try: - value = int(meta.get("dimension") or 0) - except Exception: - return None - return value if value > 0 else None - - def _vector_mismatch_error(self, *, stored_dimension: int, detected_dimension: int) -> str: - return ( - "检测到现有向量库与当前 embedding 输出维度不一致:" - f"stored={stored_dimension}, encoded={detected_dimension}。" - " 当前版本不会兼容 hash 时代或其他维度的旧向量,请改回原 embedding 配置," - "或执行重嵌入/重建向量。" - ) - - async def initialize(self) -> None: - if self._initialized: - await self._start_background_tasks() - return - - self.data_dir.mkdir(parents=True, exist_ok=True) - self.embedding_manager = create_embedding_api_adapter( - batch_size=int(self._cfg("embedding.batch_size", 32)), - max_concurrent=int(self._cfg("embedding.max_concurrent", 5)), - default_dimension=self.embedding_dimension, - enable_cache=bool(self._cfg("embedding.enable_cache", False)), - model_name=str(self._cfg("embedding.model_name", "auto") or "auto"), - retry_config=self._cfg("embedding.retry", {}) or {}, - ) - detected_dimension = int(await self.embedding_manager._detect_dimension()) - self.embedding_dimension = detected_dimension - - stored_dimension = self._stored_vector_dimension() - if stored_dimension is not None and stored_dimension != detected_dimension: - raise RuntimeError( - self._vector_mismatch_error( - stored_dimension=stored_dimension, - detected_dimension=detected_dimension, - ) - ) - - matrix_format = 
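[Editor's note] is_chat_enabled above implements a blacklist/whitelist filter over chat identifiers, where each pattern is either a bare token or prefixed as group:<id>, user:<id>/private:<id>, or stream:<id>. The shape of the filter config (enabled/mode/chats) is taken from the deleted code; the condensed sketch below is otherwise illustrative.

def chat_enabled(filter_cfg: dict, stream_id="", group_id="", user_id="") -> bool:
    """Blacklist: enabled unless matched. Whitelist: enabled only if matched."""
    if not filter_cfg or not filter_cfg.get("enabled", True):
        return True
    mode = str(filter_cfg.get("mode", "blacklist")).strip().lower()
    patterns = filter_cfg.get("chats") or []
    if not patterns:
        return mode == "blacklist"

    candidates = {t for t in (stream_id, group_id, user_id) if t}
    matched = False
    for raw in patterns:
        pattern = str(raw).strip()
        if ":" in pattern:
            prefix, value = (s.strip() for s in pattern.split(":", 1))
            prefix = prefix.lower()
            if value:
                matched = (
                    (prefix == "group" and value == group_id)
                    or (prefix in {"user", "private"} and value == user_id)
                    or (prefix == "stream" and value == stream_id)
                )
        else:
            matched = pattern in candidates
        if matched:
            break
    return (not matched) if mode == "blacklist" else matched

cfg = {"enabled": True, "mode": "blacklist", "chats": ["group:12345"]}
print(chat_enabled(cfg, group_id="12345"))  # False: group is blacklisted
print(chat_enabled(cfg, group_id="67890"))  # True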
str(self._cfg("graph.sparse_matrix_format", "csr") or "csr").strip().lower() - graph_format = SparseMatrixFormat.CSC if matrix_format == "csc" else SparseMatrixFormat.CSR - - self.vector_store = VectorStore( - dimension=detected_dimension, - quantization_type=QuantizationType.INT8, - data_dir=self.data_dir / "vectors", - ) - self.graph_store = GraphStore(matrix_format=graph_format, data_dir=self.data_dir / "graph") - self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") - self.metadata_store.connect() - - if self.vector_store.has_data(): - self.vector_store.load() - self.vector_store.warmup_index(force_train=True) - if self.graph_store.has_data(): - self.graph_store.load() - - sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} - try: - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - except Exception as exc: - logger.warning(f"sparse 配置非法,回退默认: {exc}") - sparse_cfg = SparseBM25Config() - self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=sparse_cfg) - if getattr(self.sparse_index.config, "enabled", False): - self.sparse_index.ensure_loaded() - - self.relation_write_service = RelationWriteService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - - runtime_config = self._build_runtime_config() - self._runtime_bundle = build_search_runtime( - plugin_config=runtime_config, - logger_obj=logger, - owner_tag="sdk_kernel", - log_prefix="[sdk]", - ) - if not self._runtime_bundle.ready: - raise RuntimeError(self._runtime_bundle.error or "检索运行时初始化失败") - - self.retriever = self._runtime_bundle.retriever - self.threshold_filter = self._runtime_bundle.threshold_filter - self.sparse_index = self._runtime_bundle.sparse_index or self.sparse_index - - runtime_config = self._build_runtime_config() - self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) - self.aggregate_query_service = AggregateQueryService(plugin_config=runtime_config) - self.person_profile_service = PersonProfileService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - sparse_index=self.sparse_index, - plugin_config=runtime_config, - retriever=self.retriever, - ) - self.episode_segmentation_service = EpisodeSegmentationService(plugin_config=runtime_config) - self.episode_service = EpisodeService( - metadata_store=self.metadata_store, - plugin_config=runtime_config, - segmentation_service=self.episode_segmentation_service, - ) - self.summary_importer = SummaryImporter( - vector_store=self.vector_store, - graph_store=self.graph_store, - metadata_store=self.metadata_store, - embedding_manager=self.embedding_manager, - plugin_config=runtime_config, - ) - self.import_task_manager = ImportTaskManager(self._runtime_facade) - self.retrieval_tuning_manager = RetrievalTuningManager( - self._runtime_facade, - import_write_blocked_provider=self.import_task_manager.is_write_blocked, - ) - - report = await run_embedding_runtime_self_check( - config=runtime_config, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - sample_text="A_Memorix runtime self check", - ) - self._runtime_facade._runtime_self_check_report = dict(report) - if not bool(report.get("ok", False)): - message = str(report.get("message", "runtime self-check failed") or "runtime self-check failed") - raise RuntimeError(f"{message};请改回原 embedding 
配置,或执行重嵌入/重建向量。") - - self._initialized = True - await self._start_background_tasks() - - async def shutdown(self) -> None: - await self._stop_background_tasks() - if self.import_task_manager is not None: - try: - await self.import_task_manager.shutdown() - except Exception as exc: - logger.warning(f"关闭导入任务管理器失败: {exc}") - if self.retrieval_tuning_manager is not None: - try: - await self.retrieval_tuning_manager.shutdown() - except Exception as exc: - logger.warning(f"关闭调优任务管理器失败: {exc}") - self.close() - - def close(self) -> None: - try: - self._persist() - finally: - if self.metadata_store is not None: - self.metadata_store.close() - self._initialized = False - self._request_dedup_tasks.clear() - self._runtime_facade._runtime_self_check_report = {} - self._background_tasks.clear() - self._active_person_timestamps.clear() - - async def execute_request_with_dedup( - self, - request_key: str, - executor: Callable[[], Awaitable[Dict[str, Any]]], - ) -> tuple[bool, Dict[str, Any]]: - token = str(request_key or "").strip() - if not token: - return False, await executor() - - existing = self._request_dedup_tasks.get(token) - if existing is not None: - return True, await existing - - task = asyncio.create_task(executor()) - self._request_dedup_tasks[token] = task - try: - payload = await task - return False, payload - finally: - current = self._request_dedup_tasks.get(token) - if current is task: - self._request_dedup_tasks.pop(token, None) - - async def summarize_chat_stream( - self, - *, - chat_id: str, - context_length: Optional[int] = None, - include_personality: Optional[bool] = None, - ) -> Dict[str, Any]: - await self.initialize() - assert self.summary_importer - success, detail = await self.summary_importer.import_from_stream( - stream_id=str(chat_id or "").strip(), - context_length=context_length, - include_personality=include_personality, - ) - if success: - await self.rebuild_episodes_for_sources([self._build_source("chat_summary", chat_id, [])]) - self._persist() - return {"success": bool(success), "detail": detail} - - async def ingest_summary( - self, - *, - external_id: str, - chat_id: str, - text: str, - participants: Optional[Sequence[str]] = None, - time_start: Optional[float] = None, - time_end: Optional[float] = None, - tags: Optional[Sequence[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - respect_filter: bool = True, - user_id: str = "", - group_id: str = "", - ) -> Dict[str, Any]: - external_token = str(external_id or "").strip() or compute_hash(f"chat_summary:{chat_id}:{text}") - if self._is_chat_filtered( - respect_filter=respect_filter, - stream_id=chat_id, - group_id=group_id, - user_id=user_id, - ): - return { - "success": True, - "stored_ids": [], - "skipped_ids": [external_token], - "detail": "chat_filtered", - } - - summary_meta = dict(metadata or {}) - summary_meta.setdefault("kind", "chat_summary") - if not str(text or "").strip() or bool(summary_meta.get("generate_from_chat", False)): - result = await self.summarize_chat_stream( - chat_id=chat_id, - context_length=self._optional_int(summary_meta.get("context_length")), - include_personality=summary_meta.get("include_personality"), - ) - result.setdefault("external_id", external_id) - result.setdefault("chat_id", chat_id) - return result - return await self.ingest_text( - external_id=external_id, - source_type="chat_summary", - text=text, - chat_id=chat_id, - participants=participants, - time_start=time_start, - time_end=time_end, - tags=tags, - metadata=summary_meta, - respect_filter=respect_filter, 
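[Editor's note] execute_request_with_dedup above collapses concurrent identical requests onto a single in-flight asyncio.Task: the first caller creates the task, later callers with the same key await that same task and are reported as deduplicated, and only the owner removes the entry afterwards. A generic sketch of the pattern follows; function and variable names are illustrative.

import asyncio
from typing import Any, Awaitable, Callable, Dict, Tuple

_inflight: Dict[str, asyncio.Task] = {}

async def run_deduplicated(key: str, executor: Callable[[], Awaitable[Any]]) -> Tuple[bool, Any]:
    """Return (was_deduplicated, result); identical keys share one running task."""
    key = key.strip()
    if not key:
        return False, await executor()
    existing = _inflight.get(key)
    if existing is not None:
        return True, await existing          # piggy-back on the in-flight call
    task = asyncio.create_task(executor())
    _inflight[key] = task
    try:
        return False, await task
    finally:
        if _inflight.get(key) is task:       # only the owner removes the entry
            _inflight.pop(key, None)

async def demo():
    async def slow_lookup():
        await asyncio.sleep(0.2)
        return {"hits": 3}

    results = await asyncio.gather(
        run_deduplicated("search:q1", slow_lookup),
        run_deduplicated("search:q1", slow_lookup),
    )
    print(results)  # exactly one of the two is marked as deduplicated

asyncio.run(demo())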
- user_id=user_id, - group_id=group_id, - ) - - async def ingest_text( - self, - *, - external_id: str, - source_type: str, - text: str, - chat_id: str = "", - person_ids: Optional[Sequence[str]] = None, - participants: Optional[Sequence[str]] = None, - timestamp: Optional[float] = None, - time_start: Optional[float] = None, - time_end: Optional[float] = None, - tags: Optional[Sequence[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - entities: Optional[Sequence[str]] = None, - relations: Optional[Sequence[Dict[str, Any]]] = None, - respect_filter: bool = True, - user_id: str = "", - group_id: str = "", - ) -> Dict[str, Any]: - content = normalize_text(text) - external_token = str(external_id or "").strip() or compute_hash(f"{source_type}:{chat_id}:{content}") - if self._is_chat_filtered( - respect_filter=respect_filter, - stream_id=chat_id, - group_id=group_id, - user_id=user_id, - ): - return { - "success": True, - "stored_ids": [], - "skipped_ids": [external_token], - "detail": "chat_filtered", - } - - await self.initialize() - assert self.metadata_store is not None - assert self.vector_store is not None - assert self.graph_store is not None - assert self.embedding_manager is not None - assert self.relation_write_service is not None - - if not content: - return {"stored_ids": [], "skipped_ids": [external_token], "reason": "empty_text"} - - existing_ref = self.metadata_store.get_external_memory_ref(external_token) - if existing_ref: - return { - "stored_ids": [], - "skipped_ids": [str(existing_ref.get("paragraph_hash", "") or "")], - "reason": "exists", - } - - person_tokens = self._tokens(person_ids) - participant_tokens = self._tokens(participants) - entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) - source = self._build_source(source_type, chat_id, person_tokens) - paragraph_meta = dict(metadata or {}) - paragraph_meta.update( - { - "external_id": external_token, - "source_type": str(source_type or "").strip(), - "chat_id": str(chat_id or "").strip(), - "person_ids": person_tokens, - "participants": participant_tokens, - "tags": self._tokens(tags), - } - ) - - paragraph_hash = self.metadata_store.add_paragraph( - content=content, - source=source, - metadata=paragraph_meta, - knowledge_type=self._resolve_knowledge_type(source_type), - time_meta=self._time_meta(timestamp, time_start, time_end), - ) - embedding = await self.embedding_manager.encode(content) - self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash]) - - for name in entity_tokens: - self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash) - - stored_relations: List[str] = [] - for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]: - subject = str(row.get("subject", "") or "").strip() - predicate = str(row.get("predicate", "") or "").strip() - obj = str(row.get("object", "") or "").strip() - if not (subject and predicate and obj): - continue - result = await self.relation_write_service.upsert_relation_with_vector( - subject=subject, - predicate=predicate, - obj=obj, - confidence=float(row.get("confidence", 1.0) or 1.0), - source_paragraph=paragraph_hash, - metadata=row.get("metadata") if isinstance(row.get("metadata"), dict) else {"external_id": external_token, "source_type": source_type}, - write_vector=self.relation_vectors_enabled, - ) - self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value) - stored_relations.append(result.hash_value) - - self.metadata_store.upsert_external_memory_ref( - 
external_id=external_token, - paragraph_hash=paragraph_hash, - source_type=source_type, - metadata={"chat_id": chat_id, "person_ids": person_tokens}, - ) - self.metadata_store.enqueue_episode_pending(paragraph_hash, source=source) - self._persist() - await self.process_episode_pending_batch( - limit=max(1, int(self._cfg("episode.pending_batch_size", 12))), - max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3))), - ) - for person_id in person_tokens: - self._mark_person_active(person_id) - await self.refresh_person_profile(person_id) - return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []} - - async def process_episode_pending_batch(self, *, limit: int = 20, max_retry: int = 3) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store is not None - assert self.episode_service is not None - - pending_rows = self.metadata_store.fetch_episode_pending_batch(limit=max(1, int(limit)), max_retry=max(1, int(max_retry))) - if not pending_rows: - return {"processed": 0, "episode_count": 0, "fallback_count": 0, "failed": 0} - - source_to_hashes: Dict[str, List[str]] = {} - pending_hashes = [str(row.get("paragraph_hash", "") or "").strip() for row in pending_rows if str(row.get("paragraph_hash", "") or "").strip()] - for row in pending_rows: - paragraph_hash = str(row.get("paragraph_hash", "") or "").strip() - source = str(row.get("source", "") or "").strip() - if not paragraph_hash or not source: - continue - source_to_hashes.setdefault(source, []).append(paragraph_hash) - - if pending_hashes: - self.metadata_store.mark_episode_pending_running(pending_hashes) - - result = await self.episode_service.process_pending_rows(pending_rows) - done_hashes = [str(item or "").strip() for item in result.get("done_hashes", []) if str(item or "").strip()] - failed_hashes = { - str(hash_value or "").strip(): str(error or "").strip() - for hash_value, error in (result.get("failed_hashes", {}) or {}).items() - if str(hash_value or "").strip() - } - - if done_hashes: - self.metadata_store.mark_episode_pending_done(done_hashes) - for hash_value, error in failed_hashes.items(): - self.metadata_store.mark_episode_pending_failed(hash_value, error) - - untouched = [hash_value for hash_value in pending_hashes if hash_value not in set(done_hashes) and hash_value not in failed_hashes] - for hash_value in untouched: - self.metadata_store.mark_episode_pending_failed(hash_value, "episode processing finished without explicit status") - - for source, paragraph_hashes in source_to_hashes.items(): - counts = self.metadata_store.get_episode_pending_status_counts(source) - if counts.get("failed", 0) > 0: - source_error = next( - ( - failed_hashes.get(hash_value) - for hash_value in paragraph_hashes - if failed_hashes.get(hash_value) - ), - "episode pending source contains failed rows", - ) - self.metadata_store.mark_episode_source_failed(source, str(source_error or "episode pending source contains failed rows")) - elif counts.get("pending", 0) == 0 and counts.get("running", 0) == 0: - self.metadata_store.mark_episode_source_done(source) - - self._persist() - return { - "processed": len(done_hashes) + len(failed_hashes), - "episode_count": int(result.get("episode_count") or 0), - "fallback_count": int(result.get("fallback_count") or 0), - "failed": len(failed_hashes) + len(untouched), - "group_count": int(result.get("group_count") or 0), - "missing_count": int(result.get("missing_count") or 0), - } - - async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: - if 
self._is_chat_filtered( - respect_filter=request.respect_filter, - stream_id=request.chat_id, - group_id=request.group_id, - user_id=request.user_id, - ): - return {"summary": "", "hits": [], "filtered": True} - - await self.initialize() - assert self.retriever is not None - assert self.episode_retriever is not None - assert self.aggregate_query_service is not None - - mode = str(request.mode or "search").strip().lower() or "search" - query = str(request.query or "").strip() - limit = max(1, int(request.limit or 5)) - supported_modes = {"search", "time", "hybrid", "episode", "aggregate"} - if mode not in supported_modes: - return { - "summary": "", - "hits": [], - "error": ( - f"不支持的检索模式: {mode}(仅支持 search/time/hybrid/episode/aggregate," - "semantic 已移除)" - ), - } - try: - time_window = self._normalize_search_time_window(request.time_start, request.time_end) - except ValueError as exc: - return {"summary": "", "hits": [], "error": str(exc)} - - if mode == "episode": - rows = await self.episode_retriever.query( - query=query, - top_k=limit, - time_from=time_window.numeric_start, - time_to=time_window.numeric_end, - person=request.person_id or None, - source=self._chat_source(request.chat_id), - ) - hits = [self._episode_hit(row) for row in rows] - return {"summary": self._summary(hits), "hits": hits} - - if mode == "aggregate": - payload = await self.aggregate_query_service.execute( - query=query, - top_k=limit, - mix=True, - mix_top_k=limit, - time_from=time_window.query_start, - time_to=time_window.query_end, - search_runner=lambda: self._aggregate_search(query, limit, request), - time_runner=lambda: self._aggregate_time(query, limit, request, time_window), - episode_runner=lambda: self._aggregate_episode(query, limit, request, time_window), - ) - hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)] - for item in hits: - item.setdefault("metadata", {}) - filtered = self._filter_hits(hits, request.person_id) - return {"summary": self._summary(filtered), "hits": filtered} - - query_type = mode - runtime_config = self._build_runtime_config() - result = await SearchExecutionService.execute( - retriever=self.retriever, - threshold_filter=self.threshold_filter, - plugin_config=runtime_config, - request=SearchExecutionRequest( - caller="sdk_memory_kernel", - stream_id=str(request.chat_id or "") or None, - group_id=str(request.group_id or "") or None, - user_id=str(request.user_id or "") or None, - query_type=query_type, - query=query, - top_k=limit, - time_from=time_window.query_start, - time_to=time_window.query_end, - person=str(request.person_id or "") or None, - source=self._chat_source(request.chat_id), - use_threshold=True, - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ), - enforce_chat_filter=bool(request.respect_filter), - reinforce_access=True, - ) - if not result.success: - return {"summary": "", "hits": [], "error": result.error} - if result.chat_filtered: - return {"summary": "", "hits": [], "filtered": True} - - hits = [self._retrieval_result_hit(item) for item in result.results] - filtered = self._filter_hits(hits, request.person_id) - return {"summary": self._summary(filtered), "hits": filtered} - - async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: - del chat_id - await self.initialize() - assert self.metadata_store is not None - assert self.person_profile_service is not None - self._mark_person_active(person_id) - profile = await 
self.person_profile_service.query_person_profile( - person_id=person_id, - top_k=max(4, int(limit or 10)), - source_note="sdk_memory_kernel.get_person_profile", - ) - if not profile.get("success"): - return {"summary": "", "traits": [], "evidence": []} - - evidence = [] - for hash_value in profile.get("evidence_ids", [])[: max(1, int(limit))]: - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is not None: - evidence.append( - { - "hash": hash_value, - "content": str(paragraph.get("content", "") or "")[:220], - "metadata": paragraph.get("metadata", {}) or {}, - "type": "paragraph", - } - ) - continue - - relation = self.metadata_store.get_relation(hash_value) - if relation is not None: - evidence.append( - { - "hash": hash_value, - "content": " ".join( - [ - str(relation.get("subject", "") or "").strip(), - str(relation.get("predicate", "") or "").strip(), - str(relation.get("object", "") or "").strip(), - ] - ).strip(), - "metadata": { - "confidence": relation.get("confidence"), - "source_paragraph": relation.get("source_paragraph"), - }, - "type": "relation", - } - ) - - text = str(profile.get("profile_text", "") or "").strip() - traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8] - return { - "summary": text, - "traits": traits, - "evidence": evidence, - "person_id": str(profile.get("person_id", "") or person_id), - "person_name": str(profile.get("person_name", "") or ""), - "profile_source": str(profile.get("profile_source", "") or "auto_snapshot"), - "has_manual_override": bool(profile.get("has_manual_override", False)), - } - - async def refresh_person_profile(self, person_id: str, limit: int = 10, *, mark_active: bool = True) -> Dict[str, Any]: - await self.initialize() - assert self.person_profile_service - if mark_active: - self._mark_person_active(person_id) - profile = await self.person_profile_service.query_person_profile( - person_id=person_id, - top_k=max(4, int(limit or 10)), - force_refresh=True, - source_note="sdk_memory_kernel.refresh_person_profile", - ) - return profile if isinstance(profile, dict) else {} - - async def maintain_memory( - self, - *, - action: str, - target: str = "", - hours: Optional[float] = None, - reason: str = "", - limit: int = 50, - ) -> Dict[str, Any]: - del reason - await self.initialize() - assert self.metadata_store - act = str(action or "").strip().lower() - if act == "recycle_bin": - items = self.metadata_store.get_deleted_relations(limit=max(1, int(limit or 50))) - return {"success": True, "items": items, "count": len(items)} - - hashes = self._resolve_deleted_relation_hashes(target) if act == "restore" else self._resolve_relation_hashes(target) - if not hashes: - return {"success": False, "detail": "未命中可维护关系"} - - if act == "reinforce": - self.metadata_store.reinforce_relations(hashes) - elif act == "freeze": - self.metadata_store.mark_relations_inactive(hashes) - self._rebuild_graph_from_metadata() - elif act == "protect": - ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0 - self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0) - elif act == "restore": - restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value)) - if restored <= 0: - return {"success": False, "detail": "未恢复任何关系"} - self._rebuild_graph_from_metadata() - else: - return {"success": False, "detail": f"不支持的维护动作: {act}"} - - self._last_maintenance_at = time.time() - self._persist() - return {"success": True, "detail": f"{act} 
{len(hashes)} 条关系"} - - async def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store is not None - assert self.episode_service is not None - - items: List[Dict[str, Any]] = [] - failures: List[Dict[str, str]] = [] - for source in self._tokens(sources): - self.metadata_store.mark_episode_source_running(source) - try: - result = await self.episode_service.rebuild_source(source) - self.metadata_store.mark_episode_source_done(source) - items.append(result) - except Exception as exc: - err = str(exc)[:500] - self.metadata_store.mark_episode_source_failed(source, err) - failures.append({"source": source, "error": err}) - self._persist() - return { - "rebuilt": len(items), - "items": items, - "failures": failures, - "sources": [str(item.get("source", "") or "") for item in items] or self._tokens(sources), - } - - def memory_stats(self) -> Dict[str, Any]: - assert self.metadata_store - stats = self.metadata_store.get_statistics() - episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"] - profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"] - pending = self.metadata_store.query( - "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" - )[0]["c"] - return { - "paragraphs": int(stats.get("paragraph_count", 0) or 0), - "relations": int(stats.get("relation_count", 0) or 0), - "episodes": int(episodes or 0), - "profiles": int(profiles or 0), - "episode_pending": int(pending or 0), - "last_maintenance_at": self._last_maintenance_at, - } - - async def memory_graph_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store is not None - assert self.graph_store is not None - - act = str(action or "").strip().lower() - if act == "get_graph": - return {"success": True, **self._serialize_graph(limit=max(1, int(kwargs.get("limit", 200) or 200)))} - - if act == "create_node": - name = str(kwargs.get("name", "") or kwargs.get("node", "") or "").strip() - if not name: - return {"success": False, "error": "node name 不能为空"} - entity_hash = self.metadata_store.add_entity(name=name, metadata=kwargs.get("metadata") or {}) - self._rebuild_graph_from_metadata() - self._persist() - return {"success": True, "node": {"name": name, "hash": entity_hash}} - - if act == "delete_node": - name = str(kwargs.get("name", "") or kwargs.get("node", "") or kwargs.get("hash_or_name", "") or "").strip() - if not name: - return {"success": False, "error": "node name 不能为空"} - result = await self._execute_delete_action( - mode="entity", - selector={"query": name}, - requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), - reason=str(kwargs.get("reason", "") or "graph_delete_node"), - ) - return { - "success": bool(result.get("success", False)), - "deleted": bool(result.get("deleted_count", 0)), - "node": name, - "operation_id": result.get("operation_id", ""), - "counts": result.get("counts", {}), - "error": result.get("error", ""), - } - - if act == "rename_node": - old_name = str(kwargs.get("name", "") or kwargs.get("old_name", "") or kwargs.get("node", "") or "").strip() - new_name = str(kwargs.get("new_name", "") or kwargs.get("target_name", "") or "").strip() - return self._rename_node(old_name, new_name) - - if act == "create_edge": - subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() - predicate = str(kwargs.get("predicate", "") 
or kwargs.get("label", "") or "").strip() - obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() - if not all([subject, predicate, obj]): - return {"success": False, "error": "subject/predicate/object 不能为空"} - if self.relation_write_service is not None: - result = await self.relation_write_service.upsert_relation_with_vector( - subject=subject, - predicate=predicate, - obj=obj, - confidence=float(kwargs.get("confidence", 1.0) or 1.0), - source_paragraph=str(kwargs.get("source_paragraph", "") or "") or None, - metadata=kwargs.get("metadata") or {}, - write_vector=self.relation_vectors_enabled, - ) - relation_hash = result.hash_value - else: - relation_hash = self.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - confidence=float(kwargs.get("confidence", 1.0) or 1.0), - source_paragraph=kwargs.get("source_paragraph"), - metadata=kwargs.get("metadata") or {}, - ) - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "edge": { - "hash": relation_hash, - "subject": subject, - "predicate": predicate, - "object": obj, - "weight": float(kwargs.get("confidence", 1.0) or 1.0), - }, - } - - if act == "delete_edge": - relation_hash = str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip() - if relation_hash: - result = await self._execute_delete_action( - mode="relation", - selector={"query": relation_hash}, - requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), - reason=str(kwargs.get("reason", "") or "graph_delete_edge"), - ) - return { - "success": bool(result.get("success", False)), - "deleted": int(result.get("deleted_count", 0)), - "hash": relation_hash, - "operation_id": result.get("operation_id", ""), - "counts": result.get("counts", {}), - "error": result.get("error", ""), - } - - subject = str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip() - obj = str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip() - deleted_hashes = [ - str(row.get("hash", "") or "") - for row in self.metadata_store.get_relations(subject=subject) - if str(row.get("object", "") or "").strip() == obj - ] - result = await self._execute_delete_action( - mode="relation", - selector={"hashes": deleted_hashes, "subject": subject, "object": obj}, - requested_by=str(kwargs.get("requested_by", "") or "memory_graph_admin"), - reason=str(kwargs.get("reason", "") or "graph_delete_edge"), - ) - return { - "success": bool(result.get("success", False)), - "deleted": int(result.get("deleted_count", 0)), - "subject": subject, - "object": obj, - "operation_id": result.get("operation_id", ""), - "counts": result.get("counts", {}), - "error": result.get("error", ""), - } - - if act == "update_edge_weight": - return self._update_edge_weight( - relation_hash=str(kwargs.get("hash", "") or kwargs.get("relation_hash", "") or "").strip(), - subject=str(kwargs.get("subject", "") or kwargs.get("source", "") or "").strip(), - obj=str(kwargs.get("object", "") or kwargs.get("target", "") or "").strip(), - weight=float(kwargs.get("weight", kwargs.get("confidence", 1.0)) or 1.0), - ) - - return {"success": False, "error": f"不支持的 graph action: {act}"} - - async def memory_source_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store - - act = str(action or "").strip().lower() - if act == "list": - sources = self.metadata_store.get_all_sources() - items = [] - for row in sources: - source_name = str(row.get("source", "") or 
"").strip() - items.append( - { - **row, - "episode_rebuild_blocked": self.metadata_store.is_episode_source_query_blocked(source_name), - } - ) - return {"success": True, "items": items, "count": len(items)} - - if act == "delete": - source = str(kwargs.get("source", "") or "").strip() - return await self._execute_delete_action( - mode="source", - selector={"sources": [source]}, - requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), - reason=str(kwargs.get("reason", "") or "source_delete"), - ) - - if act == "batch_delete": - return await self._execute_delete_action( - mode="source", - selector={"sources": list(kwargs.get("sources") or [])}, - requested_by=str(kwargs.get("requested_by", "") or "memory_source_admin"), - reason=str(kwargs.get("reason", "") or "source_batch_delete"), - ) - - return {"success": False, "error": f"不支持的 source action: {act}"} - - async def memory_episode_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store - - act = str(action or "").strip().lower() - if act in {"query", "list"}: - items = self.metadata_store.query_episodes( - query=str(kwargs.get("query", "") or "").strip(), - time_from=self._optional_float(kwargs.get("time_start", kwargs.get("time_from"))), - time_to=self._optional_float(kwargs.get("time_end", kwargs.get("time_to"))), - person=str(kwargs.get("person_id", "") or kwargs.get("person", "") or "").strip() or None, - source=str(kwargs.get("source", "") or "").strip() or None, - limit=max(1, int(kwargs.get("limit", 20) or 20)), - ) - return {"success": True, "items": items, "count": len(items)} - - if act == "get": - episode_id = str(kwargs.get("episode_id", "") or "").strip() - if not episode_id: - return {"success": False, "error": "episode_id 不能为空"} - episode = self.metadata_store.get_episode_by_id(episode_id) - if episode is None: - return {"success": False, "error": "episode 不存在"} - episode["paragraphs"] = self.metadata_store.get_episode_paragraphs( - episode_id, - limit=max(1, int(kwargs.get("paragraph_limit", 100) or 100)), - ) - return {"success": True, "episode": episode} - - if act == "status": - summary = self.metadata_store.get_episode_source_rebuild_summary( - failed_limit=max(1, int(kwargs.get("limit", 20) or 20)) - ) - summary["pending_queue"] = self.metadata_store.query( - "SELECT COUNT(*) AS c FROM episode_pending_paragraphs WHERE status IN ('pending', 'running', 'failed')" - )[0]["c"] - return {"success": True, **summary} - - if act == "rebuild": - sources = self._tokens(kwargs.get("sources")) - if not sources: - source = str(kwargs.get("source", "") or "").strip() - if source: - sources = [source] - if not sources and bool(kwargs.get("all", False)): - sources = self.metadata_store.list_episode_sources_for_rebuild() - if not sources: - sources = [str(row.get("source", "") or "").strip() for row in self.metadata_store.get_all_sources()] - if not sources: - return {"success": False, "error": "未提供可重建的 source"} - result = await self.rebuild_episodes_for_sources(sources) - return {"success": len(result.get("failures", [])) == 0, **result} - - if act == "process_pending": - result = await self.process_episode_pending_batch( - limit=max(1, int(kwargs.get("limit", 20) or 20)), - max_retry=max(1, int(kwargs.get("max_retry", 3) or 3)), - ) - return {"success": True, **result} - - return {"success": False, "error": f"不支持的 episode action: {act}"} - - async def memory_profile_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert 
self.metadata_store is not None
-        assert self.person_profile_service is not None
-
-        act = str(action or "").strip().lower()
-        if act == "query":
-            profile = await self.person_profile_service.query_person_profile(
-                person_id=str(kwargs.get("person_id", "") or "").strip(),
-                person_keyword=str(kwargs.get("person_keyword", "") or kwargs.get("keyword", "") or "").strip(),
-                top_k=max(1, int(kwargs.get("limit", kwargs.get("top_k", 12)) or 12)),
-                force_refresh=bool(kwargs.get("force_refresh", False)),
-                source_note="sdk_memory_kernel.memory_profile_admin.query",
-            )
-            return profile if isinstance(profile, dict) else {"success": False, "error": "invalid profile payload"}
-
-        if act == "list":
-            limit = max(1, int(kwargs.get("limit", 50) or 50))
-            rows = self.metadata_store.query(
-                """
-                SELECT s.person_id, s.profile_version, s.profile_text, s.updated_at, s.expires_at, s.source_note
-                FROM person_profile_snapshots s
-                JOIN (
-                    SELECT person_id, MAX(profile_version) AS max_version
-                    FROM person_profile_snapshots
-                    GROUP BY person_id
-                ) latest
-                    ON latest.person_id = s.person_id
-                    AND latest.max_version = s.profile_version
-                ORDER BY s.updated_at DESC
-                LIMIT ?
-                """,
-                (limit,),
-            )
-            items = []
-            for row in rows:
-                person_id = str(row.get("person_id", "") or "").strip()
-                override = self.metadata_store.get_person_profile_override(person_id)
-                items.append(
-                    {
-                        "person_id": person_id,
-                        "profile_version": int(row.get("profile_version", 0) or 0),
-                        "profile_text": str(row.get("profile_text", "") or ""),
-                        "updated_at": row.get("updated_at"),
-                        "expires_at": row.get("expires_at"),
-                        "source_note": str(row.get("source_note", "") or ""),
-                        "has_manual_override": bool(override),
-                        "manual_override": override,
-                    }
-                )
-            return {"success": True, "items": items, "count": len(items)}
-
-        if act == "set_override":
-            person_id = str(kwargs.get("person_id", "") or "").strip()
-            override = self.metadata_store.set_person_profile_override(
-                person_id=person_id,
-                override_text=str(kwargs.get("override_text", "") or kwargs.get("text", "") or ""),
-                updated_by=str(kwargs.get("updated_by", "") or ""),
-                source=str(kwargs.get("source", "") or "memory_profile_admin"),
-            )
-            return {"success": True, "override": override}
-
-        if act == "delete_override":
-            person_id = str(kwargs.get("person_id", "") or "").strip()
-            deleted = self.metadata_store.delete_person_profile_override(person_id)
-            return {"success": bool(deleted), "deleted": bool(deleted), "person_id": person_id}
-
-        return {"success": False, "error": f"不支持的 profile action: {act}"}
-
-    async def memory_runtime_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
-        await self.initialize()
-        act = str(action or "").strip().lower()
-
-        if act == "save":
-            self._persist()
-            return {"success": True, "saved": True, "data_dir": str(self.data_dir)}
-
-        if act == "get_config":
-            return {
-                "success": True,
-                "config": self.config,
-                "data_dir": str(self.data_dir),
-                "embedding_dimension": int(self.embedding_dimension),
-                "auto_save": bool(self._cfg("advanced.enable_auto_save", True)),
-                "relation_vectors_enabled": bool(self.relation_vectors_enabled),
-                "runtime_ready": self.is_runtime_ready(),
-            }
-
-        if act in {"self_check", "refresh_self_check"}:
-            report = await run_embedding_runtime_self_check(
-                config=self._build_runtime_config(),
-                vector_store=self.vector_store,
-                embedding_manager=self.embedding_manager,
-                sample_text=str(kwargs.get("sample_text", "") or "A_Memorix runtime self check"),
-            )
-            self._runtime_facade._runtime_self_check_report = dict(report)
-            return {"success": bool(report.get("ok", False)), "report": report}
-
-        if act == "set_auto_save":
-            enabled = bool(kwargs.get("enabled", False))
-            self._set_cfg("advanced.enable_auto_save", enabled)
-            return {"success": True, "auto_save": enabled}
-
-        return {"success": False, "error": f"不支持的 runtime action: {act}"}
-
-    async def memory_import_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
-        await self.initialize()
-        manager = self.import_task_manager
-        if manager is None:
-            return {"success": False, "error": "import manager 未初始化"}
-
-        act = str(action or "").strip().lower()
-        if act in {"settings", "get_settings", "get_guide"}:
-            return {"success": True, "settings": await manager.get_runtime_settings()}
-        if act in {"path_aliases", "get_path_aliases"}:
-            return {"success": True, "path_aliases": manager.get_path_aliases()}
-        if act in {"resolve_path", "resolve"}:
-            return await manager.resolve_path_request(kwargs)
-        if act == "create_upload":
-            task = await manager.create_upload_task(
-                list(kwargs.get("staged_files") or kwargs.get("files") or kwargs.get("uploads") or []),
-                kwargs,
-            )
-            return {"success": True, "task": task}
-        if act == "create_paste":
-            return {"success": True, "task": await manager.create_paste_task(kwargs)}
-        if act == "create_raw_scan":
-            return {"success": True, "task": await manager.create_raw_scan_task(kwargs)}
-        if act == "create_lpmm_openie":
-            return {"success": True, "task": await manager.create_lpmm_openie_task(kwargs)}
-        if act == "create_lpmm_convert":
-            return {"success": True, "task": await manager.create_lpmm_convert_task(kwargs)}
-        if act == "create_temporal_backfill":
-            return {"success": True, "task": await manager.create_temporal_backfill_task(kwargs)}
-        if act == "create_maibot_migration":
-            return {"success": True, "task": await manager.create_maibot_migration_task(kwargs)}
-        if act == "list":
-            items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50)))
-            return {"success": True, "items": items, "count": len(items)}
-        if act == "get":
-            task = await manager.get_task(
-                str(kwargs.get("task_id", "") or ""),
-                include_chunks=bool(kwargs.get("include_chunks", False)),
-            )
-            return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"}
-        if act in {"chunks", "get_chunks"}:
-            payload = await manager.get_chunks(
-                str(kwargs.get("task_id", "") or ""),
-                str(kwargs.get("file_id", "") or ""),
-                offset=max(0, int(kwargs.get("offset", 0) or 0)),
-                limit=max(1, int(kwargs.get("limit", 50) or 50)),
-            )
-            return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务或文件不存在"}
-        if act == "cancel":
-            task = await manager.cancel_task(str(kwargs.get("task_id", "") or ""))
-            return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"}
-        if act == "retry_failed":
-            overrides = kwargs.get("overrides") if isinstance(kwargs.get("overrides"), dict) else kwargs
-            task = await manager.retry_failed(str(kwargs.get("task_id", "") or ""), overrides=overrides)
-            return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"}
-        return {"success": False, "error": f"不支持的 import action: {act}"}
-
-    async def memory_tuning_admin(self, *, action: str, **kwargs) -> Dict[str, Any]:
-        await self.initialize()
-        manager = self.retrieval_tuning_manager
-        if manager is None:
-            return {"success": False, "error": "tuning manager 未初始化"}
-
-        act = str(action or "").strip().lower()
-        if act in
{"settings", "get_settings"}: - return {"success": True, "settings": manager.get_runtime_settings()} - if act == "get_profile": - profile = manager.get_profile_snapshot() - return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} - if act == "apply_profile": - profile = kwargs.get("profile") if isinstance(kwargs.get("profile"), dict) else kwargs - return {"success": True, **await manager.apply_profile(profile, reason=str(kwargs.get("reason", "manual") or "manual"))} - if act == "rollback_profile": - return {"success": True, **await manager.rollback_profile()} - if act == "export_profile": - profile = manager.get_profile_snapshot() - return {"success": True, "profile": profile, "toml": manager.export_toml_snippet(profile)} - if act == "create_task": - payload = kwargs.get("payload") if isinstance(kwargs.get("payload"), dict) else kwargs - return {"success": True, "task": await manager.create_task(payload)} - if act == "list_tasks": - items = await manager.list_tasks(limit=max(1, int(kwargs.get("limit", 50) or 50))) - return {"success": True, "items": items, "count": len(items)} - if act == "get_task": - task = await manager.get_task( - str(kwargs.get("task_id", "") or ""), - include_rounds=bool(kwargs.get("include_rounds", False)), - ) - return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} - if act == "get_rounds": - payload = await manager.get_rounds( - str(kwargs.get("task_id", "") or ""), - offset=max(0, int(kwargs.get("offset", 0) or 0)), - limit=max(1, int(kwargs.get("limit", 50) or 50)), - ) - return {"success": payload is not None, **(payload or {}), "error": "" if payload is not None else "任务不存在"} - if act == "cancel": - task = await manager.cancel_task(str(kwargs.get("task_id", "") or "")) - return {"success": task is not None, "task": task, "error": "" if task is not None else "任务不存在"} - if act == "apply_best": - return {"success": True, **await manager.apply_best(str(kwargs.get("task_id", "") or ""))} - if act == "get_report": - report = await manager.get_report(str(kwargs.get("task_id", "") or ""), fmt=str(kwargs.get("format", "md") or "md")) - return {"success": report is not None, "report": report, "error": "" if report is not None else "任务不存在"} - return {"success": False, "error": f"不支持的 tuning action: {act}"} - - async def memory_v5_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - assert self.metadata_store - - act = str(action or "").strip().lower() - target = str(kwargs.get("target", "") or kwargs.get("query", "") or "").strip() - reason = str(kwargs.get("reason", "") or "").strip() - updated_by = str(kwargs.get("updated_by", "") or kwargs.get("requested_by", "") or "").strip() - limit = max(1, int(kwargs.get("limit", 50) or 50)) - - if act == "recycle_bin": - items = self.metadata_store.get_deleted_relations(limit=limit) - return {"success": True, "items": items, "count": len(items)} - - if act == "status": - return self._memory_v5_status(target=target, limit=limit) - - if act == "restore": - hashes = self._resolve_deleted_relation_hashes(target) - if not hashes: - return {"success": False, "error": "未命中可恢复关系"} - result = await self._restore_relation_hashes(hashes) - operation = self.metadata_store.record_v5_operation( - action=act, - target=target, - resolved_hashes=hashes, - reason=reason, - updated_by=updated_by, - result=result, - ) - return {"success": bool(result.get("restored_count", 0) > 0), "operation": operation, **result} - - hashes = 
self._resolve_relation_hashes(target) - if not hashes: - return {"success": False, "error": "未命中可维护关系"} - - result = self._apply_v5_relation_action( - action=act, - hashes=hashes, - strength=float(kwargs.get("strength", 1.0) or 1.0), - ) - operation = self.metadata_store.record_v5_operation( - action=act, - target=target, - resolved_hashes=hashes, - reason=reason, - updated_by=updated_by, - result=result, - ) - return {"success": bool(result.get("success", False)), "operation": operation, **result} - - async def memory_delete_admin(self, *, action: str, **kwargs) -> Dict[str, Any]: - await self.initialize() - act = str(action or "").strip().lower() - mode = str(kwargs.get("mode", "") or "").strip().lower() - selector = kwargs.get("selector") - if selector is None: - selector = { - key: value - for key, value in kwargs.items() - if key - not in { - "action", - "mode", - "dry_run", - "cascade", - "operation_id", - "reason", - "requested_by", - } - } - reason = str(kwargs.get("reason", "") or "").strip() - requested_by = str(kwargs.get("requested_by", "") or "").strip() - - if act == "preview": - return await self._preview_delete_action(mode=mode, selector=selector) - if act == "execute": - return await self._execute_delete_action( - mode=mode, - selector=selector, - requested_by=requested_by, - reason=reason, - ) - if act == "restore": - return await self._restore_delete_action( - mode=mode, - selector=selector, - operation_id=str(kwargs.get("operation_id", "") or "").strip(), - requested_by=requested_by, - reason=reason, - ) - if act == "get_operation": - operation = self.metadata_store.get_delete_operation(str(kwargs.get("operation_id", "") or "").strip()) - return {"success": operation is not None, "operation": operation, "error": "" if operation is not None else "operation 不存在"} - if act == "list_operations": - items = self.metadata_store.list_delete_operations( - limit=max(1, int(kwargs.get("limit", 50) or 50)), - mode=mode, - ) - return {"success": True, "items": items, "count": len(items)} - if act == "purge": - return await self._purge_deleted_memory( - grace_hours=self._optional_float(kwargs.get("grace_hours")), - limit=max(1, int(kwargs.get("limit", 1000) or 1000)), - ) - return {"success": False, "error": f"不支持的 delete action: {act}"} - - def get_import_task_manager(self) -> Optional[ImportTaskManager]: - return self.import_task_manager - - def get_retrieval_tuning_manager(self) -> Optional[RetrievalTuningManager]: - return self.retrieval_tuning_manager - - async def _aggregate_search(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: - result = await SearchExecutionService.execute( - retriever=self.retriever, - threshold_filter=self.threshold_filter, - plugin_config=self._build_runtime_config(), - request=SearchExecutionRequest( - caller="sdk_memory_kernel.aggregate", - stream_id=str(request.chat_id or "") or None, - query_type="search", - query=query, - top_k=limit, - person=str(request.person_id or "") or None, - source=self._chat_source(request.chat_id), - use_threshold=True, - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ), - enforce_chat_filter=False, - reinforce_access=True, - ) - hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] - return {"success": result.success, "results": hits, "count": len(hits), "query_type": "search", "error": result.error} - - async def _aggregate_time( - self, - query: str, - limit: int, - request: KernelSearchRequest, - time_window: 
_NormalizedSearchTimeWindow, - ) -> Dict[str, Any]: - result = await SearchExecutionService.execute( - retriever=self.retriever, - threshold_filter=self.threshold_filter, - plugin_config=self._build_runtime_config(), - request=SearchExecutionRequest( - caller="sdk_memory_kernel.aggregate", - stream_id=str(request.chat_id or "") or None, - query_type="time", - query=query, - top_k=limit, - time_from=time_window.query_start, - time_to=time_window.query_end, - person=str(request.person_id or "") or None, - source=self._chat_source(request.chat_id), - use_threshold=True, - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ), - enforce_chat_filter=False, - reinforce_access=True, - ) - hits = [self._retrieval_result_hit(item) for item in result.results] if result.success else [] - return {"success": result.success, "results": hits, "count": len(hits), "query_type": "time", "error": result.error} - - async def _aggregate_episode( - self, - query: str, - limit: int, - request: KernelSearchRequest, - time_window: _NormalizedSearchTimeWindow, - ) -> Dict[str, Any]: - assert self.episode_retriever - rows = await self.episode_retriever.query( - query=query, - top_k=limit, - time_from=time_window.numeric_start, - time_to=time_window.numeric_end, - person=request.person_id or None, - source=self._chat_source(request.chat_id), - ) - hits = [self._episode_hit(row) for row in rows] - return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"} - - def _persist(self) -> None: - if self.vector_store is not None: - self.vector_store.save() - if self.graph_store is not None: - self.graph_store.save() - if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False): - self.sparse_index.ensure_loaded() - - async def _start_background_tasks(self) -> None: - async with self._background_lock: - self._background_stopping = False - self._ensure_background_task("auto_save", self._auto_save_loop) - self._ensure_background_task("episode_pending", self._episode_pending_loop) - self._ensure_background_task("memory_maintenance", self._memory_maintenance_loop) - self._ensure_background_task("person_profile_refresh", self._person_profile_refresh_loop) - - def _ensure_background_task(self, name: str, factory: Callable[[], Awaitable[None]]) -> None: - task = self._background_tasks.get(name) - if task is not None and not task.done(): - return - self._background_tasks[name] = asyncio.create_task(factory(), name=f"A_Memorix.{name}") - - async def _stop_background_tasks(self) -> None: - async with self._background_lock: - self._background_stopping = True - tasks = [task for task in self._background_tasks.values() if task is not None and not task.done()] - for task in tasks: - task.cancel() - for task in tasks: - try: - await task - except asyncio.CancelledError: - pass - except Exception as exc: - logger.warning(f"后台任务退出异常: {exc}") - self._background_tasks.clear() - - async def _auto_save_loop(self) -> None: - try: - while not self._background_stopping: - interval_minutes = max(1.0, float(self._cfg("advanced.auto_save_interval_minutes", 5) or 5)) - await asyncio.sleep(interval_minutes * 60.0) - if self._background_stopping: - break - if bool(self._cfg("advanced.enable_auto_save", True)): - self._persist() - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"auto_save loop 异常: {exc}") - - async def _episode_pending_loop(self) -> None: - try: - while not self._background_stopping: - await asyncio.sleep(60.0) - if 
self._background_stopping: - break - if not bool(self._cfg("episode.enabled", True)): - continue - if not bool(self._cfg("episode.generation_enabled", True)): - continue - await self.process_episode_pending_batch( - limit=max(1, int(self._cfg("episode.pending_batch_size", 20) or 20)), - max_retry=max(1, int(self._cfg("episode.pending_max_retry", 3) or 3)), - ) - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"episode_pending loop 异常: {exc}") - - async def _person_profile_refresh_loop(self) -> None: - try: - while not self._background_stopping: - interval_minutes = max(1.0, float(self._cfg("person_profile.refresh_interval_minutes", 30) or 30)) - await asyncio.sleep(max(60.0, interval_minutes * 60.0)) - if self._background_stopping: - break - if not bool(self._cfg("person_profile.enabled", True)): - continue - active_window_hours = max(1.0, float(self._cfg("person_profile.active_window_hours", 72.0) or 72.0)) - max_refresh = max(1, int(self._cfg("person_profile.max_refresh_per_cycle", 50) or 50)) - cutoff = time.time() - active_window_hours * 3600.0 - candidates = [ - person_id - for person_id, seen_at in sorted( - self._active_person_timestamps.items(), - key=lambda item: item[1], - reverse=True, - ) - if seen_at >= cutoff - ][:max_refresh] - for person_id in candidates: - try: - await self.refresh_person_profile(person_id, limit=max(4, int(self._cfg("person_profile.top_k_evidence", 12) or 12)), mark_active=False) - except Exception as exc: - logger.warning(f"刷新人物画像失败: {exc}") - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"person_profile_refresh loop 异常: {exc}") - - async def _memory_maintenance_loop(self) -> None: - try: - while not self._background_stopping: - interval_hours = max(1.0 / 60.0, float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0)) - await asyncio.sleep(max(60.0, interval_hours * 3600.0)) - if self._background_stopping: - break - if not bool(self._cfg("memory.enabled", True)): - continue - await self._run_memory_maintenance_cycle(interval_hours=interval_hours) - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning(f"memory_maintenance loop 异常: {exc}") - - async def _run_memory_maintenance_cycle(self, *, interval_hours: float) -> None: - assert self.graph_store is not None - assert self.metadata_store is not None - half_life = float(self._cfg("memory.half_life_hours", 24.0) or 24.0) - if half_life > 0: - factor = 0.5 ** (float(interval_hours) / half_life) - self.graph_store.decay(factor) - - await self._process_freeze_and_prune() - await self._orphan_gc_phase() - self._last_maintenance_at = time.time() - self._persist() - - async def _process_freeze_and_prune(self) -> None: - assert self.metadata_store is not None - assert self.graph_store is not None - prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) - freeze_duration = max(0.0, float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0)) * 3600.0 - now = time.time() - - low_edges = self.graph_store.get_low_weight_edges(prune_threshold) - hashes_to_freeze: List[str] = [] - edges_to_deactivate: List[tuple[str, str]] = [] - for src, tgt in low_edges: - relation_hashes = list(self.graph_store.get_relation_hashes_for_edge(src, tgt)) - if not relation_hashes: - continue - statuses = self.metadata_store.get_relation_status_batch(relation_hashes) - current_hashes: List[str] = [] - protected = False - for hash_value, status in statuses.items(): - if bool(status.get("is_pinned")) 
or float(status.get("protected_until") or 0.0) > now: - protected = True - break - current_hashes.append(hash_value) - if protected or not current_hashes: - continue - hashes_to_freeze.extend(current_hashes) - edges_to_deactivate.append((src, tgt)) - - if hashes_to_freeze: - self.metadata_store.mark_relations_inactive(hashes_to_freeze, inactive_since=now) - self.graph_store.deactivate_edges(edges_to_deactivate) - - cutoff = now - freeze_duration - expired_hashes = self.metadata_store.get_prune_candidates(cutoff) - if not expired_hashes: - return - relation_info = self.metadata_store.get_relations_subject_object_map(expired_hashes) - operations = [(src, tgt, hash_value) for hash_value, (src, tgt) in relation_info.items()] - if operations: - self.graph_store.prune_relation_hashes(operations) - deleted_hashes = [hash_value for hash_value in expired_hashes if hash_value in relation_info] - if deleted_hashes: - self.metadata_store.backup_and_delete_relations(deleted_hashes) - if self.vector_store is not None: - self.vector_store.delete(deleted_hashes) - - async def _orphan_gc_phase(self) -> None: - assert self.metadata_store is not None - assert self.graph_store is not None - orphan_cfg = self._cfg("memory.orphan", {}) or {} - if not bool(orphan_cfg.get("enable_soft_delete", True)): - return - entity_retention = max(0.0, float(orphan_cfg.get("entity_retention_days", 7.0) or 7.0)) * 86400.0 - paragraph_retention = max(0.0, float(orphan_cfg.get("paragraph_retention_days", 7.0) or 7.0)) * 86400.0 - grace_period = max(0.0, float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0)) * 3600.0 - - isolated = self.graph_store.get_isolated_nodes(include_inactive=True) - if isolated: - entity_hashes = self.metadata_store.get_entity_gc_candidates(isolated, retention_seconds=entity_retention) - if entity_hashes: - self.metadata_store.mark_as_deleted(entity_hashes, "entity") - - paragraph_hashes = self.metadata_store.get_paragraph_gc_candidates(retention_seconds=paragraph_retention) - if paragraph_hashes: - self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") - - dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace_period) - if dead_paragraphs: - hashes = [str(item[0] or "").strip() for item in dead_paragraphs if item and str(item[0] or "").strip()] - if hashes: - self.metadata_store.physically_delete_paragraphs(hashes) - if self.vector_store is not None: - self.vector_store.delete(hashes) - - dead_entities = self.metadata_store.sweep_deleted_items("entity", grace_period) - if dead_entities: - entity_hashes = [str(item[0] or "").strip() for item in dead_entities if item and str(item[0] or "").strip()] - entity_names = [str(item[1] or "").strip() for item in dead_entities if item and str(item[1] or "").strip()] - if entity_names: - self.graph_store.delete_nodes(entity_names) - if entity_hashes: - self.metadata_store.physically_delete_entities(entity_hashes) - if self.vector_store is not None: - self.vector_store.delete(entity_hashes) - - def _mark_person_active(self, person_id: str) -> None: - token = str(person_id or "").strip() - if not token: - return - self._active_person_timestamps[token] = time.time() - - def _serialize_graph(self, *, limit: int = 200) -> Dict[str, Any]: - assert self.graph_store is not None - assert self.metadata_store is not None - nodes = self.graph_store.get_nodes() - if limit > 0: - nodes = nodes[:limit] - node_set = set(nodes) - node_payload = [] - for name in nodes: - attrs = self.graph_store.get_node_attributes(name) or {} - 
node_payload.append({"id": name, "name": name, "attributes": attrs}) - - edge_payload = [] - for source, target, relation_hashes in self.graph_store.iter_edge_hash_entries(): - if source not in node_set or target not in node_set: - continue - edge_payload.append( - { - "source": source, - "target": target, - "weight": float(self.graph_store.get_edge_weight(source, target)), - "relation_hashes": sorted(str(item) for item in relation_hashes if str(item).strip()), - } - ) - return { - "nodes": node_payload, - "edges": edge_payload, - "total_nodes": int(self.graph_store.num_nodes), - "total_edges": int(self.graph_store.num_edges), - } - - def _delete_sources(self, sources: Iterable[Any]) -> Dict[str, Any]: - assert self.metadata_store - source_tokens = self._tokens(sources) - if not source_tokens: - return {"success": False, "error": "source 不能为空"} - - deleted_paragraphs = 0 - deleted_sources: List[str] = [] - for source in source_tokens: - paragraphs = self.metadata_store.get_paragraphs_by_source(source) - if not paragraphs: - self.metadata_store.replace_episodes_for_source(source, []) - continue - for row in paragraphs: - paragraph_hash = str(row.get("hash", "") or "").strip() - if not paragraph_hash: - continue - cleanup = self.metadata_store.delete_paragraph_atomic(paragraph_hash) - self._apply_cleanup_plan(cleanup) - deleted_paragraphs += 1 - self.metadata_store.replace_episodes_for_source(source, []) - deleted_sources.append(source) - - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "sources": deleted_sources, - "deleted_source_count": len(deleted_sources), - "deleted_paragraph_count": deleted_paragraphs, - } - - def _apply_cleanup_plan(self, cleanup: Dict[str, Any]) -> None: - if not isinstance(cleanup, dict): - return - if self.vector_store is not None: - vector_ids: List[str] = [] - paragraph_hash = str(cleanup.get("vector_id_to_remove", "") or "").strip() - if paragraph_hash: - vector_ids.append(paragraph_hash) - for _, _, relation_hash in cleanup.get("relation_prune_ops", []) or []: - token = str(relation_hash or "").strip() - if token: - vector_ids.append(token) - if vector_ids: - self.vector_store.delete(list(dict.fromkeys(vector_ids))) - - def _rebuild_graph_from_metadata(self) -> Dict[str, int]: - assert self.metadata_store is not None - assert self.graph_store is not None - entity_rows = self.metadata_store.query( - """ - SELECT name - FROM entities - WHERE is_deleted IS NULL OR is_deleted = 0 - ORDER BY name ASC - """ - ) - raw_relation_rows = self.metadata_store.query( - """ - SELECT subject, object, confidence, hash - FROM relations - WHERE is_inactive IS NULL OR is_inactive = 0 - """ - ) - relation_rows = [ - row - for row in raw_relation_rows - if str(row.get("subject", "") or "").strip() and str(row.get("object", "") or "").strip() - ] - - names = list( - dict.fromkeys( - [ - str(row.get("name", "") or "").strip() - for row in entity_rows - if str(row.get("name", "") or "").strip() - ] - + [ - str(row.get("subject", "") or "").strip() - for row in relation_rows - if str(row.get("subject", "") or "").strip() - ] - + [ - str(row.get("object", "") or "").strip() - for row in relation_rows - if str(row.get("object", "") or "").strip() - ] - ) - ) - self.graph_store.clear() - if names: - self.graph_store.add_nodes(names) - if relation_rows: - self.graph_store.add_edges( - [ - ( - str(row.get("subject", "") or "").strip(), - str(row.get("object", "") or "").strip(), - ) - for row in relation_rows - ], - weights=[float(row.get("confidence", 
1.0) or 1.0) for row in relation_rows], - relation_hashes=[str(row.get("hash", "") or "") for row in relation_rows], - ) - return {"node_count": int(self.graph_store.num_nodes), "edge_count": int(self.graph_store.num_edges)} - - def _rename_node(self, old_name: str, new_name: str) -> Dict[str, Any]: - assert self.metadata_store - source = str(old_name or "").strip() - target = str(new_name or "").strip() - if not source or not target: - return {"success": False, "error": "old_name/new_name 不能为空"} - if source == target: - return {"success": True, "renamed": False, "old_name": source, "new_name": target} - - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - old_hash = compute_hash(source.lower()) - target_hash = compute_hash(target.lower()) - - cursor.execute( - """ - SELECT hash, name, vector_index, appearance_count, created_at, metadata - FROM entities - WHERE hash = ? - OR LOWER(TRIM(name)) = LOWER(TRIM(?)) - LIMIT 1 - """, - (old_hash, source), - ) - old_row = cursor.fetchone() - if old_row is None: - return {"success": False, "error": "原节点不存在"} - - cursor.execute( - """ - SELECT hash, appearance_count - FROM entities - WHERE hash = ? - OR LOWER(TRIM(name)) = LOWER(TRIM(?)) - LIMIT 1 - """, - (target_hash, target), - ) - target_row = cursor.fetchone() - - try: - cursor.execute("BEGIN IMMEDIATE") - if target_row is None: - cursor.execute( - """ - INSERT INTO entities (hash, name, vector_index, appearance_count, created_at, metadata, is_deleted, deleted_at) - VALUES (?, ?, ?, ?, ?, ?, 0, NULL) - """, - ( - target_hash, - target, - old_row["vector_index"], - old_row["appearance_count"], - old_row["created_at"], - old_row["metadata"], - ), - ) - resolved_target_hash = target_hash - else: - resolved_target_hash = str(target_row["hash"] or "").strip() - cursor.execute( - """ - UPDATE entities - SET name = ?, - appearance_count = COALESCE(appearance_count, 0) + ?, - is_deleted = 0, - deleted_at = NULL - WHERE hash = ? - """, - ( - target, - int(old_row["appearance_count"] or 0), - resolved_target_hash, - ), - ) - - cursor.execute( - "UPDATE OR IGNORE paragraph_entities SET entity_hash = ? WHERE entity_hash = ?", - (resolved_target_hash, old_row["hash"]), - ) - cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (old_row["hash"],)) - cursor.execute( - "UPDATE relations SET subject = ? WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?))", - (target, old_row["name"]), - ) - cursor.execute( - "UPDATE relations SET object = ? WHERE LOWER(TRIM(object)) = LOWER(TRIM(?))", - (target, old_row["name"]), - ) - cursor.execute("DELETE FROM entities WHERE hash = ?", (old_row["hash"],)) - conn.commit() - except Exception as exc: - conn.rollback() - return {"success": False, "error": f"rename failed: {exc}"} - - self._rebuild_graph_from_metadata() - self._persist() - return {"success": True, "renamed": True, "old_name": source, "new_name": target} - - def _update_edge_weight( - self, - *, - relation_hash: str, - subject: str, - obj: str, - weight: float, - ) -> Dict[str, Any]: - assert self.metadata_store - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - target_weight = max(0.0, float(weight or 0.0)) - if relation_hash: - cursor.execute("UPDATE relations SET confidence = ? WHERE hash = ?", (target_weight, relation_hash)) - updated = cursor.rowcount - else: - cursor.execute( - """ - UPDATE relations - SET confidence = ? 
- WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?)) - AND LOWER(TRIM(object)) = LOWER(TRIM(?)) - """, - (target_weight, subject, obj), - ) - updated = cursor.rowcount - conn.commit() - if updated <= 0: - return {"success": False, "error": "未找到可更新的关系"} - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "updated": int(updated), - "weight": target_weight, - "hash": relation_hash, - "subject": subject, - "object": obj, - } - - @staticmethod - def _tokens(values: Optional[Iterable[Any]]) -> List[str]: - result: List[str] = [] - seen = set() - for item in values or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - result.append(token) - return result - - @classmethod - def _merge_tokens(cls, *groups: Optional[Iterable[Any]]) -> List[str]: - merged: List[str] = [] - seen = set() - for group in groups: - for item in cls._tokens(group): - if item in seen: - continue - seen.add(item) - merged.append(item) - return merged - - @staticmethod - def _build_source(source_type: str, chat_id: str, person_ids: Sequence[str]) -> str: - clean_type = str(source_type or "").strip() or "memory" - if clean_type == "chat_summary" and chat_id: - return f"chat_summary:{chat_id}" - if clean_type == "person_fact" and person_ids: - return f"person_fact:{person_ids[0]}" - return f"{clean_type}:{chat_id}" if chat_id else clean_type - - @staticmethod - def _chat_source(chat_id: str) -> Optional[str]: - clean = str(chat_id or "").strip() - return f"chat_summary:{clean}" if clean else None - - @staticmethod - def _resolve_knowledge_type(source_type: str) -> str: - clean_type = str(source_type or "").strip().lower() - if clean_type == "person_fact": - return "factual" - if clean_type == "chat_summary": - return "narrative" - return "mixed" - - @staticmethod - def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]: - payload: Dict[str, Any] = {} - if timestamp is not None: - payload["event_time"] = float(timestamp) - if time_start is not None: - payload["event_time_start"] = float(time_start) - if time_end is not None: - payload["event_time_end"] = float(time_end) - if payload: - payload["time_granularity"] = "minute" - payload["time_confidence"] = 0.95 - return payload - - @classmethod - def _normalize_search_time_bound(cls, value: Any, *, is_end: bool) -> tuple[Optional[float], Optional[str]]: - if value in {None, ""}: - return None, None - if isinstance(value, (int, float)): - ts = float(value) - return ts, format_timestamp(ts) - - text = str(value or "").strip() - if not text: - return None, None - - numeric = cls._optional_float(text) - if numeric is not None: - return numeric, format_timestamp(numeric) - - try: - ts = parse_query_datetime_to_timestamp(text, is_end=is_end) - except ValueError as exc: - raise ValueError(f"时间参数错误: {exc}") from exc - return ts, text - - @classmethod - def _normalize_search_time_window(cls, time_start: Any, time_end: Any) -> _NormalizedSearchTimeWindow: - numeric_start, query_start = cls._normalize_search_time_bound(time_start, is_end=False) - numeric_end, query_end = cls._normalize_search_time_bound(time_end, is_end=True) - if numeric_start is not None and numeric_end is not None and numeric_start > numeric_end: - raise ValueError("时间参数错误: time_start 不能晚于 time_end") - return _NormalizedSearchTimeWindow( - numeric_start=numeric_start, - numeric_end=numeric_end, - query_start=query_start, - query_end=query_end, - ) - - @staticmethod - def 
_retrieval_result_hit(item: RetrievalResult) -> Dict[str, Any]: - payload = item.to_dict() - return { - "hash": payload.get("hash", ""), - "content": payload.get("content", ""), - "score": payload.get("score", 0.0), - "type": payload.get("type", ""), - "source": payload.get("source", ""), - "metadata": payload.get("metadata", {}) or {}, - } - - @staticmethod - def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]: - return { - "type": "episode", - "episode_id": str(row.get("episode_id", "") or ""), - "title": str(row.get("title", "") or ""), - "content": str(row.get("summary", "") or ""), - "score": float(row.get("lexical_score", 0.0) or 0.0), - "source": "episode", - "metadata": { - "participants": row.get("participants", []) or [], - "keywords": row.get("keywords", []) or [], - "source": row.get("source"), - "event_time_start": row.get("event_time_start"), - "event_time_end": row.get("event_time_end"), - }, - } - - @staticmethod - def _summary(hits: Sequence[Dict[str, Any]]) -> str: - if not hits: - return "" - lines = [] - for index, item in enumerate(hits[:5], start=1): - content = str(item.get("content", "") or "").strip().replace("\n", " ") - lines.append(f"{index}. {(content[:120] + '...') if len(content) > 120 else content}") - return "\n".join(lines) - - @staticmethod - def _filter_hits(hits: List[Dict[str, Any]], person_id: str) -> List[Dict[str, Any]]: - if not person_id: - return hits - filtered = [] - for item in hits: - metadata = item.get("metadata", {}) or {} - if person_id in (metadata.get("person_ids", []) or []): - filtered.append(item) - continue - if person_id and person_id in str(item.get("content", "") or ""): - filtered.append(item) - return filtered or hits - - def _resolve_relation_hashes(self, target: str) -> List[str]: - assert self.metadata_store - token = str(target or "").strip() - if not token: - return [] - if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): - return [token] - hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10) - if hashes: - return hashes - return [ - str(row.get("hash", "") or "") - for row in self.metadata_store.get_relations(subject=token)[:10] - if str(row.get("hash", "")).strip() - ] - - def _resolve_deleted_relation_hashes(self, target: str) -> List[str]: - assert self.metadata_store - token = str(target or "").strip() - if not token: - return [] - if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): - return [token] - return self.metadata_store.search_deleted_relation_hashes_by_text(token, limit=10) - - def _memory_v5_status(self, *, target: str = "", limit: int = 50) -> Dict[str, Any]: - assert self.metadata_store - now = time.time() - summary = self.metadata_store.get_memory_status_summary(now) - payload: Dict[str, Any] = { - "success": True, - **summary, - "config": { - "half_life_hours": float(self._cfg("memory.half_life_hours", 24.0) or 24.0), - "base_decay_interval_hours": float(self._cfg("memory.base_decay_interval_hours", 1.0) or 1.0), - "prune_threshold": float(self._cfg("memory.prune_threshold", 0.1) or 0.1), - "freeze_duration_hours": float(self._cfg("memory.freeze_duration_hours", 24.0) or 24.0), - }, - "last_maintenance_at": self._last_maintenance_at, - } - token = str(target or "").strip() - if not token: - return payload - - active_hashes = self._resolve_relation_hashes(token)[:limit] - deleted_hashes = self._resolve_deleted_relation_hashes(token)[:limit] - active_statuses = self.metadata_store.get_relation_status_batch(active_hashes) - 
items: List[Dict[str, Any]] = [] - for hash_value in active_hashes: - relation = self.metadata_store.get_relation(hash_value) or {} - status = active_statuses.get(hash_value, {}) - items.append( - { - "hash": hash_value, - "subject": str(relation.get("subject", "") or ""), - "predicate": str(relation.get("predicate", "") or ""), - "object": str(relation.get("object", "") or ""), - "state": "inactive" if bool(status.get("is_inactive")) else "active", - "is_pinned": bool(status.get("is_pinned", False)), - "temp_protected": bool(float(status.get("protected_until") or 0.0) > now), - "protected_until": status.get("protected_until"), - "last_reinforced": status.get("last_reinforced"), - "weight": float(status.get("weight", relation.get("confidence", 0.0)) or 0.0), - } - ) - for hash_value in deleted_hashes: - relation = self.metadata_store.get_deleted_relation(hash_value) or {} - items.append( - { - "hash": hash_value, - "subject": str(relation.get("subject", "") or ""), - "predicate": str(relation.get("predicate", "") or ""), - "object": str(relation.get("object", "") or ""), - "state": "deleted", - "is_pinned": bool(relation.get("is_pinned", False)), - "temp_protected": False, - "protected_until": relation.get("protected_until"), - "last_reinforced": relation.get("last_reinforced"), - "weight": float(relation.get("confidence", 0.0) or 0.0), - "deleted_at": relation.get("deleted_at"), - } - ) - payload["items"] = items[:limit] - payload["count"] = len(payload["items"]) - payload["target"] = token - return payload - - def _adjust_relation_confidence(self, hashes: List[str], *, delta: float) -> Dict[str, float]: - assert self.metadata_store - normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] - if not normalized: - return {} - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - chunk_size = 200 - for index in range(0, len(normalized), chunk_size): - chunk = normalized[index : index + chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - cursor.execute( - f""" - UPDATE relations - SET confidence = MAX(0.0, COALESCE(confidence, 0.0) + ?) 
- WHERE hash IN ({placeholders}) - """, - tuple([float(delta)] + chunk), - ) - conn.commit() - statuses = self.metadata_store.get_relation_status_batch(normalized) - return {hash_value: float((statuses.get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} - - def _apply_v5_relation_action(self, *, action: str, hashes: List[str], strength: float = 1.0) -> Dict[str, Any]: - assert self.metadata_store - act = str(action or "").strip().lower() - normalized = [str(item or "").strip() for item in hashes if str(item or "").strip()] - if not normalized: - return {"success": False, "error": "未命中可维护关系"} - - now = time.time() - strength_value = max(0.1, float(strength or 1.0)) - prune_threshold = max(0.0, float(self._cfg("memory.prune_threshold", 0.1) or 0.1)) - detail = "" - - if act == "reinforce": - weights = self._adjust_relation_confidence(normalized, delta=0.5 * strength_value) - protect_hours = max(1.0, 24.0 * strength_value) - self.metadata_store.reinforce_relations(normalized) - self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) - self.metadata_store.update_relations_protection( - normalized, - protected_until=now + protect_hours * 3600.0, - last_reinforced=now, - ) - detail = f"reinforce {len(normalized)} 条关系" - elif act == "weaken": - weights = self._adjust_relation_confidence(normalized, delta=-0.5 * strength_value) - to_freeze = [hash_value for hash_value, weight in weights.items() if weight <= prune_threshold] - if to_freeze: - self.metadata_store.mark_relations_inactive(to_freeze, inactive_since=now) - detail = f"weaken {len(normalized)} 条关系" - elif act == "remember_forever": - self.metadata_store.mark_relations_active(normalized, boost_weight=max(prune_threshold, 0.1)) - self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=True) - weights = {hash_value: float((self.metadata_store.get_relation_status_batch([hash_value]).get(hash_value) or {}).get("weight", 0.0) or 0.0) for hash_value in normalized} - detail = f"remember_forever {len(normalized)} 条关系" - elif act == "forget": - weights = self._adjust_relation_confidence(normalized, delta=-2.0 * strength_value) - self.metadata_store.update_relations_protection(normalized, protected_until=0.0, is_pinned=False) - self.metadata_store.mark_relations_inactive(normalized, inactive_since=now) - detail = f"forget {len(normalized)} 条关系" - else: - return {"success": False, "error": f"不支持的 V5 动作: {act}"} - - self._rebuild_graph_from_metadata() - self._last_maintenance_at = now - self._persist() - statuses = self.metadata_store.get_relation_status_batch(normalized) - return { - "success": True, - "detail": detail, - "hashes": normalized, - "count": len(normalized), - "weights": weights, - "statuses": statuses, - } - - async def _ensure_vector_for_text(self, *, item_hash: str, text: str) -> bool: - if self.vector_store is None or self.embedding_manager is None: - return False - token = str(item_hash or "").strip() - content = str(text or "").strip() - if not token or not content: - return False - embedding = await self.embedding_manager.encode([content], dimensions=self.embedding_dimension) - if getattr(embedding, "ndim", 1) == 1: - embedding = embedding.reshape(1, -1) - if getattr(embedding, "size", 0) <= 0: - return False - try: - self.vector_store.add(embedding, [token]) - return True - except Exception as exc: - logger.warning(f"重建向量失败: {exc}") - return False - - async def _ensure_relation_vector(self, relation: Dict[str, Any]) -> bool: - if 
not bool(self.relation_vectors_enabled): - return False - return await self._ensure_vector_for_text( - item_hash=str(relation.get("hash", "") or ""), - text=" ".join( - [ - str(relation.get("subject", "") or "").strip(), - str(relation.get("predicate", "") or "").strip(), - str(relation.get("object", "") or "").strip(), - ] - ).strip(), - ) - - async def _ensure_paragraph_vector(self, paragraph: Dict[str, Any]) -> bool: - return await self._ensure_vector_for_text( - item_hash=str(paragraph.get("hash", "") or ""), - text=str(paragraph.get("content", "") or ""), - ) - - async def _ensure_entity_vector(self, entity: Dict[str, Any]) -> bool: - return await self._ensure_vector_for_text( - item_hash=str(entity.get("hash", "") or ""), - text=str(entity.get("name", "") or ""), - ) - - async def _restore_relation_hashes( - self, - hashes: List[str], - *, - payloads: Optional[Dict[str, Dict[str, Any]]] = None, - rebuild_graph: bool = True, - persist: bool = True, - ) -> Dict[str, Any]: - assert self.metadata_store - restored: List[str] = [] - failures: List[Dict[str, str]] = [] - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - payload_map = payloads or {} - for hash_value in [str(item or "").strip() for item in hashes if str(item or "").strip()]: - relation = self.metadata_store.restore_relation(hash_value) - if relation is None: - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - failures.append({"hash": hash_value, "error": "relation 不存在"}) - continue - payload = payload_map.get(hash_value) if isinstance(payload_map.get(hash_value), dict) else {} - paragraph_hashes = self._tokens(payload.get("paragraph_hashes")) - for paragraph_hash in paragraph_hashes: - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) - VALUES (?, ?) 
- """, - (paragraph_hash, hash_value), - ) - await self._ensure_relation_vector({**relation, "hash": hash_value}) - restored.append(hash_value) - conn.commit() - if restored and rebuild_graph: - self._rebuild_graph_from_metadata() - if restored and persist: - self._persist() - return {"restored_hashes": restored, "restored_count": len(restored), "failures": failures} - - @staticmethod - def _selector_dict(selector: Any) -> Dict[str, Any]: - if isinstance(selector, dict): - return dict(selector) - if isinstance(selector, (list, tuple)): - return {"items": list(selector)} - token = str(selector or "").strip() - return {"query": token} if token else {} - - def _resolve_paragraph_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: - assert self.metadata_store - raw = self._selector_dict(selector) - rows: List[Dict[str, Any]] = [] - hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) - for hash_value in hashes: - row = self.metadata_store.get_paragraph(hash_value) - if row is None: - continue - if not include_deleted and bool(row.get("is_deleted", 0)): - continue - rows.append(row) - if rows: - return rows - query = str(raw.get("query", "") or raw.get("content", "") or "").strip() - if not query: - return [] - if len(query) == 64 and all(ch in "0123456789abcdef" for ch in query.lower()): - row = self.metadata_store.get_paragraph(query) - if row is None: - return [] - if not include_deleted and bool(row.get("is_deleted", 0)): - return [] - return [row] - matches = self.metadata_store.search_paragraphs_by_content(query) - return [row for row in matches if include_deleted or not bool(row.get("is_deleted", 0))] - - def _resolve_entity_targets(self, selector: Any, *, include_deleted: bool = False) -> List[Dict[str, Any]]: - assert self.metadata_store - raw = self._selector_dict(selector) - rows: List[Dict[str, Any]] = [] - hashes = self._merge_tokens(raw.get("hashes"), raw.get("items"), [raw.get("hash")]) - for hash_value in hashes: - row = self.metadata_store.get_entity(hash_value) - if row is None: - continue - if not include_deleted and bool(row.get("is_deleted", 0)): - continue - rows.append(row) - names = self._merge_tokens(raw.get("names"), [raw.get("name")], [raw.get("query")]) - for name in names: - if not name: - continue - matches = self.metadata_store.query( - """ - SELECT * - FROM entities - WHERE LOWER(TRIM(name)) = LOWER(TRIM(?)) - OR hash = ? 
- ORDER BY appearance_count DESC, created_at ASC - """, - (name, compute_hash(str(name).strip().lower())), - ) - for row in matches: - if not include_deleted and bool(row.get("is_deleted", 0)): - continue - rows.append(self.metadata_store._row_to_dict(row, "entity") if hasattr(self.metadata_store, "_row_to_dict") else row) - dedup: Dict[str, Dict[str, Any]] = {} - for row in rows: - token = str(row.get("hash", "") or "").strip() - if token and token not in dedup: - dedup[token] = row - return list(dedup.values()) - - def _resolve_source_targets(self, selector: Any) -> List[str]: - raw = self._selector_dict(selector) - return self._merge_tokens(raw.get("sources"), [raw.get("source")], [raw.get("query")], raw.get("items")) - - def _snapshot_relation_item(self, hash_value: str) -> Optional[Dict[str, Any]]: - assert self.metadata_store - relation = self.metadata_store.get_relation(hash_value) - if relation is None: - relation = self.metadata_store.get_deleted_relation(hash_value) - if relation is None: - return None - paragraph_hashes = [ - str(row.get("paragraph_hash", "") or "").strip() - for row in self.metadata_store.query( - "SELECT paragraph_hash FROM paragraph_relations WHERE relation_hash = ? ORDER BY paragraph_hash ASC", - (hash_value,), - ) - if str(row.get("paragraph_hash", "") or "").strip() - ] - return { - "item_type": "relation", - "item_hash": hash_value, - "item_key": hash_value, - "payload": { - "relation": relation, - "paragraph_hashes": paragraph_hashes, - }, - } - - def _snapshot_paragraph_item(self, hash_value: str) -> Optional[Dict[str, Any]]: - assert self.metadata_store - paragraph = self.metadata_store.get_paragraph(hash_value) - if paragraph is None: - return None - entity_links = [ - { - "paragraph_hash": hash_value, - "entity_hash": str(row.get("entity_hash", "") or ""), - "mention_count": int(row.get("mention_count", 1) or 1), - } - for row in self.metadata_store.query( - """ - SELECT paragraph_hash, entity_hash, mention_count - FROM paragraph_entities - WHERE paragraph_hash = ? - ORDER BY entity_hash ASC - """, - (hash_value,), - ) - ] - relation_hashes = [ - str(row.get("relation_hash", "") or "").strip() - for row in self.metadata_store.query( - """ - SELECT relation_hash - FROM paragraph_relations - WHERE paragraph_hash = ? - ORDER BY relation_hash ASC - """, - (hash_value,), - ) - if str(row.get("relation_hash", "") or "").strip() - ] - return { - "item_type": "paragraph", - "item_hash": hash_value, - "item_key": hash_value, - "payload": { - "paragraph": paragraph, - "entity_links": entity_links, - "relation_hashes": relation_hashes, - "external_refs": self.metadata_store.list_external_memory_refs_by_paragraphs([hash_value]), - }, - } - - def _snapshot_entity_item(self, hash_value: str) -> Optional[Dict[str, Any]]: - assert self.metadata_store - entity = self.metadata_store.get_entity(hash_value) - if entity is None: - return None - paragraph_links = [ - { - "paragraph_hash": str(row.get("paragraph_hash", "") or ""), - "entity_hash": hash_value, - "mention_count": int(row.get("mention_count", 1) or 1), - } - for row in self.metadata_store.query( - """ - SELECT paragraph_hash, mention_count - FROM paragraph_entities - WHERE entity_hash = ? 
- ORDER BY paragraph_hash ASC - """, - (hash_value,), - ) - ] - return { - "item_type": "entity", - "item_hash": hash_value, - "item_key": hash_value, - "payload": { - "entity": entity, - "paragraph_links": paragraph_links, - }, - } - - def _relation_has_remaining_paragraphs(self, relation_hash: str, removing_hashes: Sequence[str]) -> bool: - assert self.metadata_store - excluded = [str(item or "").strip() for item in removing_hashes if str(item or "").strip()] - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - if excluded: - placeholders = ",".join(["?"] * len(excluded)) - cursor.execute( - f""" - SELECT 1 - FROM paragraph_relations pr - JOIN paragraphs p ON p.hash = pr.paragraph_hash - WHERE pr.relation_hash = ? - AND pr.paragraph_hash NOT IN ({placeholders}) - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - LIMIT 1 - """, - tuple([relation_hash] + excluded), - ) - else: - cursor.execute( - """ - SELECT 1 - FROM paragraph_relations pr - JOIN paragraphs p ON p.hash = pr.paragraph_hash - WHERE pr.relation_hash = ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - LIMIT 1 - """, - (relation_hash,), - ) - return cursor.fetchone() is not None - - async def _build_delete_plan(self, *, mode: str, selector: Any) -> Dict[str, Any]: - assert self.metadata_store - act_mode = str(mode or "").strip().lower() - normalized_selector = self._selector_dict(selector) - items: List[Dict[str, Any]] = [] - counts = {"relations": 0, "paragraphs": 0, "entities": 0, "sources": 0} - vector_ids: List[str] = [] - sources: List[str] = [] - target_hashes: Dict[str, List[str]] = { - "relations": [], - "paragraphs": [], - "entities": [], - "sources": [], - "matched_sources": [], - } - - if act_mode == "relation": - relation_rows = [row for row in (self.metadata_store.get_relation(hash_value) for hash_value in self._resolve_relation_hashes(str(normalized_selector.get("query", "") or ""))) if row] - if normalized_selector.get("hashes"): - relation_rows = [ - row - for hash_value in self._tokens(normalized_selector.get("hashes")) - for row in [self.metadata_store.get_relation(hash_value)] - if row is not None - ] - dedup_hashes: List[str] = [] - seen = set() - for row in relation_rows: - hash_value = str(row.get("hash", "") or "").strip() - if hash_value and hash_value not in seen: - seen.add(hash_value) - dedup_hashes.append(hash_value) - snap = self._snapshot_relation_item(hash_value) - if snap: - items.append(snap) - vector_ids.append(hash_value) - counts["relations"] = len(dedup_hashes) - target_hashes["relations"] = dedup_hashes - - elif act_mode in {"paragraph", "source"}: - paragraph_rows: List[Dict[str, Any]] = [] - if act_mode == "source": - source_tokens = self._resolve_source_targets(normalized_selector) - target_hashes["sources"] = source_tokens - counts["requested_sources"] = len(source_tokens) - matched_source_tokens: List[str] = [] - for source in source_tokens: - source_rows = self.metadata_store.query( - """ - SELECT * - FROM paragraphs - WHERE source = ? 
- AND (is_deleted IS NULL OR is_deleted = 0) - ORDER BY created_at ASC - """, - (source,), - ) - if source_rows: - matched_source_tokens.append(source) - sources.append(source) - paragraph_rows.extend(source_rows) - target_hashes["matched_sources"] = matched_source_tokens - counts["sources"] = len(matched_source_tokens) - counts["matched_sources"] = len(matched_source_tokens) - else: - paragraph_rows = self._resolve_paragraph_targets(normalized_selector, include_deleted=False) - paragraph_hashes = self._tokens([row.get("hash", "") for row in paragraph_rows]) - target_hashes["paragraphs"] = paragraph_hashes - counts["paragraphs"] = len(paragraph_hashes) - for hash_value in paragraph_hashes: - snap = self._snapshot_paragraph_item(hash_value) - if snap: - items.append(snap) - vector_ids.append(hash_value) - paragraph = snap["payload"].get("paragraph") or {} - source = str(paragraph.get("source", "") or "").strip() - if source: - sources.append(source) - - orphan_relations: List[str] = [] - for item in items: - if item.get("item_type") != "paragraph": - continue - for relation_hash in self._tokens((item.get("payload") or {}).get("relation_hashes")): - if relation_hash in orphan_relations: - continue - if not self._relation_has_remaining_paragraphs(relation_hash, paragraph_hashes): - orphan_relations.append(relation_hash) - for relation_hash in orphan_relations: - snap = self._snapshot_relation_item(relation_hash) - if snap: - items.append(snap) - vector_ids.append(relation_hash) - target_hashes["relations"] = orphan_relations - counts["relations"] = len(orphan_relations) - - elif act_mode == "entity": - entity_rows = self._resolve_entity_targets(normalized_selector, include_deleted=False) - entity_hashes = self._tokens([row.get("hash", "") for row in entity_rows]) - target_hashes["entities"] = entity_hashes - counts["entities"] = len(entity_hashes) - entity_names = [str(row.get("name", "") or "").strip() for row in entity_rows if str(row.get("name", "") or "").strip()] - for hash_value in entity_hashes: - snap = self._snapshot_entity_item(hash_value) - if snap: - items.append(snap) - vector_ids.append(hash_value) - relation_hashes: List[str] = [] - for entity_name in entity_names: - for relation in self.metadata_store.get_relations(subject=entity_name) + self.metadata_store.get_relations(object=entity_name): - hash_value = str(relation.get("hash", "") or "").strip() - if hash_value and hash_value not in relation_hashes: - relation_hashes.append(hash_value) - for relation_hash in relation_hashes: - snap = self._snapshot_relation_item(relation_hash) - if snap: - items.append(snap) - vector_ids.append(relation_hash) - target_hashes["relations"] = relation_hashes - counts["relations"] = len(relation_hashes) - else: - return {"success": False, "error": f"不支持的 delete mode: {act_mode}"} - - sources = self._tokens(sources) - vector_ids = self._tokens(vector_ids) - primary_count = counts.get(f"{act_mode}s", 0) if act_mode != "source" else counts.get("matched_sources", 0) - success = ( - primary_count > 0 or counts.get("paragraphs", 0) > 0 or counts.get("relations", 0) > 0 - if act_mode != "source" - else (counts.get("matched_sources", 0) > 0 and counts.get("paragraphs", 0) > 0) - ) - return { - "success": success, - "mode": act_mode, - "selector": normalized_selector, - "items": items, - "counts": counts, - "vector_ids": vector_ids, - "sources": sources, - "target_hashes": target_hashes, - "requested_source_count": counts.get("requested_sources", 0) if act_mode == "source" else 0, - 
"matched_source_count": counts.get("matched_sources", 0) if act_mode == "source" else 0, - "error": "" if success else "未命中可删除内容", - } - - async def _preview_delete_action(self, *, mode: str, selector: Any) -> Dict[str, Any]: - plan = await self._build_delete_plan(mode=mode, selector=selector) - if not plan.get("success", False): - return {"success": False, "error": plan.get("error", "未命中可删除内容")} - preview_items = [ - { - "item_type": str(item.get("item_type", "") or ""), - "item_hash": str(item.get("item_hash", "") or ""), - } - for item in plan.get("items", [])[:100] - ] - return { - "success": True, - "mode": plan.get("mode"), - "selector": plan.get("selector"), - "counts": plan.get("counts", {}), - "requested_source_count": int(plan.get("requested_source_count", 0) or 0), - "matched_source_count": int(plan.get("matched_source_count", 0) or 0), - "sources": plan.get("sources", []), - "vector_ids": plan.get("vector_ids", []), - "items": preview_items, - "item_count": len(plan.get("items", [])), - "dry_run": True, - } - - async def _execute_delete_action( - self, - *, - mode: str, - selector: Any, - requested_by: str = "", - reason: str = "", - ) -> Dict[str, Any]: - assert self.metadata_store - plan = await self._build_delete_plan(mode=mode, selector=selector) - if not plan.get("success", False): - return {"success": False, "error": plan.get("error", "未命中可删除内容")} - - act_mode = str(plan.get("mode", "") or "").strip().lower() - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - paragraph_hashes = self._tokens((plan.get("target_hashes") or {}).get("paragraphs")) - entity_hashes = self._tokens((plan.get("target_hashes") or {}).get("entities")) - relation_hashes = self._tokens((plan.get("target_hashes") or {}).get("relations")) - requested_source_tokens = self._tokens((plan.get("target_hashes") or {}).get("sources")) - matched_source_tokens = self._tokens((plan.get("target_hashes") or {}).get("matched_sources")) - - try: - if paragraph_hashes: - self.metadata_store.mark_as_deleted(paragraph_hashes, "paragraph") - cursor.execute( - f"DELETE FROM paragraph_entities WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", - tuple(paragraph_hashes), - ) - cursor.execute( - f"DELETE FROM paragraph_relations WHERE paragraph_hash IN ({','.join(['?'] * len(paragraph_hashes))})", - tuple(paragraph_hashes), - ) - self.metadata_store.delete_external_memory_refs_by_paragraphs(paragraph_hashes) - if act_mode == "source" and matched_source_tokens: - for source in matched_source_tokens: - self.metadata_store.replace_episodes_for_source(source, []) - - if entity_hashes: - self.metadata_store.mark_as_deleted(entity_hashes, "entity") - cursor.execute( - f"DELETE FROM paragraph_entities WHERE entity_hash IN ({','.join(['?'] * len(entity_hashes))})", - tuple(entity_hashes), - ) - - conn.commit() - - deleted_relations = self.metadata_store.backup_and_delete_relations(relation_hashes) - deleted_vectors = 0 - if self.vector_store is not None and plan.get("vector_ids"): - deleted_vectors = self.vector_store.delete(list(plan.get("vector_ids") or [])) - - operation = self.metadata_store.create_delete_operation( - mode=act_mode, - selector=plan.get("selector"), - items=plan.get("items", []), - reason=reason, - requested_by=requested_by, - summary={ - "counts": plan.get("counts", {}), - "sources": plan.get("sources", []), - "vector_ids": plan.get("vector_ids", []), - "deleted_relation_rows": deleted_relations, - }, - ) - - if plan.get("sources"): - 
self.metadata_store._enqueue_episode_source_rebuilds(list(plan.get("sources") or []), reason="delete_admin_execute") - self._rebuild_graph_from_metadata() - self._persist() - deleted_count = ( - len(paragraph_hashes) - if act_mode == "source" - else len(paragraph_hashes) - if act_mode == "paragraph" - else len(entity_hashes) - if act_mode == "entity" - else len(relation_hashes) - ) - success = bool(deleted_count > 0) - result = { - "success": success, - "mode": act_mode, - "operation_id": operation.get("operation_id", ""), - "counts": plan.get("counts", {}), - "sources": plan.get("sources", []), - "deleted_count": deleted_count, - "deleted_vector_count": int(deleted_vectors or 0), - "deleted_relation_count": len(relation_hashes), - } - if act_mode == "source": - result["requested_source_count"] = len(requested_source_tokens) - result["matched_source_count"] = len(matched_source_tokens) - result["deleted_source_count"] = len(matched_source_tokens) - result["deleted_paragraph_count"] = len(paragraph_hashes) - if not success: - result["error"] = "未命中可删除内容" - return result - except Exception as exc: - conn.rollback() - logger.warning(f"delete_admin execute 失败: {exc}") - return {"success": False, "error": str(exc)} - - async def _restore_delete_action( - self, - *, - mode: str, - selector: Any, - operation_id: str = "", - requested_by: str = "", - reason: str = "", - ) -> Dict[str, Any]: - del requested_by - del reason - assert self.metadata_store - - op_id = str(operation_id or "").strip() - if op_id: - operation = self.metadata_store.get_delete_operation(op_id) - if operation is None: - return {"success": False, "error": "operation 不存在"} - return await self._restore_delete_operation(operation) - - act_mode = str(mode or "").strip().lower() - if act_mode != "relation": - return {"success": False, "error": "paragraph/entity/source 恢复必须提供 operation_id"} - - raw = self._selector_dict(selector) - target = str(raw.get("query", "") or raw.get("target", "") or raw.get("hash", "") or "").strip() - hashes = self._resolve_deleted_relation_hashes(target) - if not hashes: - return {"success": False, "error": "未命中可恢复关系"} - result = await self._restore_relation_hashes(hashes) - return {"success": bool(result.get("restored_count", 0) > 0), **result} - - async def _restore_delete_operation(self, operation: Dict[str, Any]) -> Dict[str, Any]: - assert self.metadata_store - items = operation.get("items") if isinstance(operation.get("items"), list) else [] - entity_payloads: Dict[str, Dict[str, Any]] = {} - paragraph_payloads: Dict[str, Dict[str, Any]] = {} - relation_payloads: Dict[str, Dict[str, Any]] = {} - for item in items: - if not isinstance(item, dict): - continue - item_type = str(item.get("item_type", "") or "").strip() - item_hash = str(item.get("item_hash", "") or "").strip() - payload = item.get("payload") if isinstance(item.get("payload"), dict) else {} - if item_type == "entity" and item_hash: - entity_payloads[item_hash] = payload - elif item_type == "paragraph" and item_hash: - paragraph_payloads[item_hash] = payload - elif item_type == "relation" and item_hash: - relation_payloads[item_hash] = payload - - restored_entities: List[str] = [] - restored_paragraphs: List[str] = [] - for hash_value, payload in entity_payloads.items(): - entity_row = payload.get("entity") if isinstance(payload.get("entity"), dict) else {} - if entity_row: - self.metadata_store.restore_entity_by_hash(hash_value) - await self._ensure_entity_vector(entity_row) - restored_entities.append(hash_value) - for hash_value, 
payload in paragraph_payloads.items(): - paragraph_row = payload.get("paragraph") if isinstance(payload.get("paragraph"), dict) else {} - if paragraph_row: - self.metadata_store.restore_paragraph_by_hash(hash_value) - await self._ensure_paragraph_vector(paragraph_row) - restored_paragraphs.append(hash_value) - - restored_relations = await self._restore_relation_hashes(list(relation_payloads.keys()), payloads=relation_payloads, rebuild_graph=False, persist=False) - - conn = self.metadata_store.get_connection() - cursor = conn.cursor() - for payload in entity_payloads.values(): - for link in payload.get("paragraph_links") or []: - paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() - entity_hash = str(link.get("entity_hash", "") or "").strip() - mention_count = max(1, int(link.get("mention_count", 1) or 1)) - if not paragraph_hash or not entity_hash: - continue - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, - (paragraph_hash, entity_hash, mention_count), - ) - for payload in paragraph_payloads.values(): - for link in payload.get("entity_links") or []: - paragraph_hash = str(link.get("paragraph_hash", "") or "").strip() - entity_hash = str(link.get("entity_hash", "") or "").strip() - mention_count = max(1, int(link.get("mention_count", 1) or 1)) - if not paragraph_hash or not entity_hash: - continue - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_entities (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, - (paragraph_hash, entity_hash, mention_count), - ) - for relation_hash in self._tokens(payload.get("relation_hashes")): - paragraph_hash = str((payload.get("paragraph") or {}).get("hash", "") or "").strip() - if not paragraph_hash or not relation_hash: - continue - cursor.execute( - """ - INSERT OR IGNORE INTO paragraph_relations (paragraph_hash, relation_hash) - VALUES (?, ?) 
- """, - (paragraph_hash, relation_hash), - ) - self.metadata_store.restore_external_memory_refs(list(payload.get("external_refs") or [])) - conn.commit() - - sources = self._tokens( - [ - str(((payload.get("paragraph") or {}).get("source", "") or "")).strip() - for payload in paragraph_payloads.values() - ] - ) - if sources: - self.metadata_store._enqueue_episode_source_rebuilds(sources, reason="delete_admin_restore") - self._rebuild_graph_from_metadata() - self._persist() - summary = { - "restored_entities": restored_entities, - "restored_paragraphs": restored_paragraphs, - "restored_relations": restored_relations.get("restored_hashes", []), - "sources": sources, - } - self.metadata_store.mark_delete_operation_restored(str(operation.get("operation_id", "") or ""), summary=summary) - return { - "success": True, - "operation_id": str(operation.get("operation_id", "") or ""), - **summary, - "restored_relation_count": restored_relations.get("restored_count", 0), - "relation_failures": restored_relations.get("failures", []), - } - - async def _purge_deleted_memory(self, *, grace_hours: Optional[float], limit: int) -> Dict[str, Any]: - assert self.metadata_store - orphan_cfg = self._cfg("memory.orphan", {}) or {} - grace = float(grace_hours) if grace_hours is not None else max( - 1.0, - float(orphan_cfg.get("sweep_grace_hours", 24.0) or 24.0), - ) - cutoff = time.time() - grace * 3600.0 - deleted_relation_hashes = self.metadata_store.purge_deleted_relations(cutoff_time=cutoff, limit=limit) - dead_paragraphs = self.metadata_store.sweep_deleted_items("paragraph", grace * 3600.0) - paragraph_hashes = [str(item[0] or "").strip() for item in dead_paragraphs if str(item[0] or "").strip()] - dead_entities = self.metadata_store.sweep_deleted_items("entity", grace * 3600.0) - entity_hashes = [str(item[0] or "").strip() for item in dead_entities if str(item[0] or "").strip()] - entity_names = [str(item[1] or "").strip() for item in dead_entities if str(item[1] or "").strip()] - - if paragraph_hashes: - self.metadata_store.physically_delete_paragraphs(paragraph_hashes) - if entity_hashes: - self.metadata_store.physically_delete_entities(entity_hashes) - if entity_names: - self.graph_store.delete_nodes(entity_names) - if self.vector_store is not None: - vector_ids = self._merge_tokens(paragraph_hashes, entity_hashes, deleted_relation_hashes) - if vector_ids: - self.vector_store.delete(vector_ids) - self._rebuild_graph_from_metadata() - self._persist() - return { - "success": True, - "grace_hours": grace, - "purged_deleted_relations": deleted_relation_hashes, - "purged_paragraph_hashes": paragraph_hashes, - "purged_entity_hashes": entity_hashes, - "purged_counts": { - "relations": len(deleted_relation_hashes), - "paragraphs": len(paragraph_hashes), - "entities": len(entity_hashes), - }, - } - - @staticmethod - def _optional_float(value: Any) -> Optional[float]: - if value in {None, ""}: - return None - try: - return float(value) - except Exception: - return None - - @staticmethod - def _optional_int(value: Any) -> Optional[int]: - if value in {None, ""}: - return None - try: - return int(value) - except Exception: - return None diff --git a/plugins/A_memorix/core/runtime/search_runtime_initializer.py b/plugins/A_memorix/core/runtime/search_runtime_initializer.py deleted file mode 100644 index c3c7a81f..00000000 --- a/plugins/A_memorix/core/runtime/search_runtime_initializer.py +++ /dev/null @@ -1,240 +0,0 @@ -"""Shared runtime initializer for Action/Tool/Command retrieval components.""" - -from __future__ 
import annotations - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from src.common.logger import get_logger - -from ..retrieval import ( - DualPathRetriever, - DualPathRetrieverConfig, - DynamicThresholdFilter, - FusionConfig, - GraphRelationRecallConfig, - RelationIntentConfig, - RetrievalStrategy, - SparseBM25Config, - ThresholdConfig, - ThresholdMethod, -) - -_logger = get_logger("A_Memorix.SearchRuntimeInitializer") - -_REQUIRED_COMPONENT_KEYS = ( - "vector_store", - "graph_store", - "metadata_store", - "embedding_manager", -) - - -def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: - if not isinstance(config, dict): - return default - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -def _safe_dict(value: Any) -> Dict[str, Any]: - return value if isinstance(value, dict) else {} - - -def _resolve_debug_enabled(plugin_config: Optional[dict]) -> bool: - advanced = _get_config_value(plugin_config, "advanced", {}) - if isinstance(advanced, dict): - return bool(advanced.get("debug", False)) - return bool(_get_config_value(plugin_config, "debug", False)) - - -@dataclass -class SearchRuntimeBundle: - """Resolved runtime components and initialized retriever/filter.""" - - vector_store: Optional[Any] = None - graph_store: Optional[Any] = None - metadata_store: Optional[Any] = None - embedding_manager: Optional[Any] = None - sparse_index: Optional[Any] = None - retriever: Optional[DualPathRetriever] = None - threshold_filter: Optional[DynamicThresholdFilter] = None - error: str = "" - - @property - def ready(self) -> bool: - return ( - self.retriever is not None - and self.vector_store is not None - and self.graph_store is not None - and self.metadata_store is not None - and self.embedding_manager is not None - ) - - -def _resolve_runtime_components(plugin_config: Optional[dict]) -> SearchRuntimeBundle: - bundle = SearchRuntimeBundle( - vector_store=_get_config_value(plugin_config, "vector_store"), - graph_store=_get_config_value(plugin_config, "graph_store"), - metadata_store=_get_config_value(plugin_config, "metadata_store"), - embedding_manager=_get_config_value(plugin_config, "embedding_manager"), - sparse_index=_get_config_value(plugin_config, "sparse_index"), - ) - - missing_required = any( - getattr(bundle, key) is None for key in _REQUIRED_COMPONENT_KEYS - ) - if not missing_required: - return bundle - - try: - from ...plugin import AMemorixPlugin - - instances = AMemorixPlugin.get_storage_instances() - except Exception: - instances = {} - - if not isinstance(instances, dict) or not instances: - return bundle - - if bundle.vector_store is None: - bundle.vector_store = instances.get("vector_store") - if bundle.graph_store is None: - bundle.graph_store = instances.get("graph_store") - if bundle.metadata_store is None: - bundle.metadata_store = instances.get("metadata_store") - if bundle.embedding_manager is None: - bundle.embedding_manager = instances.get("embedding_manager") - if bundle.sparse_index is None: - bundle.sparse_index = instances.get("sparse_index") - return bundle - - -def build_search_runtime( - plugin_config: Optional[dict], - logger_obj: Optional[Any], - owner_tag: str, - *, - log_prefix: str = "", -) -> SearchRuntimeBundle: - """Build retriever + threshold filter with unified fallback/config parsing.""" - - log = logger_obj or _logger - owner = str(owner_tag or 
"runtime").strip().lower() or "runtime" - prefix = str(log_prefix or "").strip() - prefix_text = f"{prefix} " if prefix else "" - - runtime = _resolve_runtime_components(plugin_config) - if any(getattr(runtime, key) is None for key in _REQUIRED_COMPONENT_KEYS): - runtime.error = "存储组件未完全初始化" - log.warning(f"{prefix_text}[{owner}] 存储组件未完全初始化,无法使用检索功能") - return runtime - - sparse_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.sparse", {}) or {}) - fusion_cfg_raw = _safe_dict(_get_config_value(plugin_config, "retrieval.fusion", {}) or {}) - relation_intent_cfg_raw = _safe_dict( - _get_config_value(plugin_config, "retrieval.search.relation_intent", {}) or {} - ) - graph_recall_cfg_raw = _safe_dict( - _get_config_value(plugin_config, "retrieval.search.graph_recall", {}) or {} - ) - - try: - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] sparse 配置非法,回退默认: {e}") - sparse_cfg = SparseBM25Config() - - try: - fusion_cfg = FusionConfig(**fusion_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] fusion 配置非法,回退默认: {e}") - fusion_cfg = FusionConfig() - - try: - relation_intent_cfg = RelationIntentConfig(**relation_intent_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] relation_intent 配置非法,回退默认: {e}") - relation_intent_cfg = RelationIntentConfig() - - try: - graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) - except Exception as e: - log.warning(f"{prefix_text}[{owner}] graph_recall 配置非法,回退默认: {e}") - graph_recall_cfg = GraphRelationRecallConfig() - - try: - config = DualPathRetrieverConfig( - top_k_paragraphs=_get_config_value(plugin_config, "retrieval.top_k_paragraphs", 20), - top_k_relations=_get_config_value(plugin_config, "retrieval.top_k_relations", 10), - top_k_final=_get_config_value(plugin_config, "retrieval.top_k_final", 10), - alpha=_get_config_value(plugin_config, "retrieval.alpha", 0.5), - enable_ppr=_get_config_value(plugin_config, "retrieval.enable_ppr", True), - ppr_alpha=_get_config_value(plugin_config, "retrieval.ppr_alpha", 0.85), - ppr_timeout_seconds=_get_config_value( - plugin_config, "retrieval.ppr_timeout_seconds", 1.5 - ), - ppr_concurrency_limit=_get_config_value( - plugin_config, "retrieval.ppr_concurrency_limit", 4 - ), - enable_parallel=_get_config_value(plugin_config, "retrieval.enable_parallel", True), - retrieval_strategy=RetrievalStrategy.DUAL_PATH, - debug=_resolve_debug_enabled(plugin_config), - sparse=sparse_cfg, - fusion=fusion_cfg, - relation_intent=relation_intent_cfg, - graph_recall=graph_recall_cfg, - ) - - runtime.retriever = DualPathRetriever( - vector_store=runtime.vector_store, - graph_store=runtime.graph_store, - metadata_store=runtime.metadata_store, - embedding_manager=runtime.embedding_manager, - sparse_index=runtime.sparse_index, - config=config, - ) - - threshold_config = ThresholdConfig( - method=ThresholdMethod.ADAPTIVE, - min_threshold=_get_config_value(plugin_config, "threshold.min_threshold", 0.3), - max_threshold=_get_config_value(plugin_config, "threshold.max_threshold", 0.95), - percentile=_get_config_value(plugin_config, "threshold.percentile", 75.0), - std_multiplier=_get_config_value(plugin_config, "threshold.std_multiplier", 1.5), - min_results=_get_config_value(plugin_config, "threshold.min_results", 3), - enable_auto_adjust=_get_config_value(plugin_config, "threshold.enable_auto_adjust", True), - ) - runtime.threshold_filter = DynamicThresholdFilter(threshold_config) - runtime.error = "" - 
log.info(f"{prefix_text}[{owner}] 检索运行时初始化完成") - except Exception as e: - runtime.retriever = None - runtime.threshold_filter = None - runtime.error = str(e) - log.error(f"{prefix_text}[{owner}] 检索运行时初始化失败: {e}") - - return runtime - - -class SearchRuntimeInitializer: - """Compatibility wrapper around the function style initializer.""" - - @staticmethod - def build_search_runtime( - plugin_config: Optional[dict], - logger_obj: Optional[Any], - owner_tag: str, - *, - log_prefix: str = "", - ) -> SearchRuntimeBundle: - return build_search_runtime( - plugin_config=plugin_config, - logger_obj=logger_obj, - owner_tag=owner_tag, - log_prefix=log_prefix, - ) diff --git a/plugins/A_memorix/core/storage/__init__.py b/plugins/A_memorix/core/storage/__init__.py deleted file mode 100644 index d878b8e7..00000000 --- a/plugins/A_memorix/core/storage/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -"""存储层""" - -from .vector_store import VectorStore, QuantizationType -from .graph_store import GraphStore, SparseMatrixFormat -from .metadata_store import MetadataStore -from .knowledge_types import ( - ImportStrategy, - KnowledgeType, - allowed_import_strategy_values, - allowed_knowledge_type_values, - get_knowledge_type_from_string, - get_import_strategy_from_string, - parse_import_strategy, - resolve_stored_knowledge_type, - should_extract_relations, - get_default_chunk_size, - get_type_display_name, - validate_stored_knowledge_type, -) -from .type_detection import ( - detect_knowledge_type, - get_type_from_user_input, - looks_like_factual_text, - looks_like_quote_text, - looks_like_structured_text, - select_import_strategy, -) - -__all__ = [ - "VectorStore", - "GraphStore", - "MetadataStore", - "QuantizationType", - "SparseMatrixFormat", - "ImportStrategy", - "KnowledgeType", - "allowed_import_strategy_values", - "allowed_knowledge_type_values", - "get_knowledge_type_from_string", - "get_import_strategy_from_string", - "parse_import_strategy", - "resolve_stored_knowledge_type", - "should_extract_relations", - "get_default_chunk_size", - "get_type_display_name", - "validate_stored_knowledge_type", - "detect_knowledge_type", - "get_type_from_user_input", - "looks_like_factual_text", - "looks_like_quote_text", - "looks_like_structured_text", - "select_import_strategy", -] diff --git a/plugins/A_memorix/core/storage/graph_store.py b/plugins/A_memorix/core/storage/graph_store.py deleted file mode 100644 index 0a5fd95d..00000000 --- a/plugins/A_memorix/core/storage/graph_store.py +++ /dev/null @@ -1,1448 +0,0 @@ -""" -图存储模块 - -基于SciPy稀疏矩阵的知识图谱存储与计算。 -""" - -import pickle -from enum import Enum -from pathlib import Path -from typing import Optional, Union, Tuple, List, Dict, Set, Any -from collections import defaultdict -import threading -import asyncio - -import numpy as np - -class SparseMatrixFormat(Enum): - """稀疏矩阵格式""" - CSR = "csr" - CSC = "csc" - -try: - from scipy.sparse import csr_matrix, csc_matrix, triu, save_npz, load_npz, bmat, lil_matrix - from scipy.sparse.linalg import norm - HAS_SCIPY = True -except ImportError: - class _SparseMatrixPlaceholder: - pass - - def _scipy_missing(*args, **kwargs): - raise ImportError("SciPy 未安装,请安装: pip install scipy") - - csr_matrix = _SparseMatrixPlaceholder - csc_matrix = _SparseMatrixPlaceholder - lil_matrix = _SparseMatrixPlaceholder - triu = _scipy_missing - save_npz = _scipy_missing - load_npz = _scipy_missing - bmat = _scipy_missing - norm = _scipy_missing - HAS_SCIPY = False - -import contextlib -from src.common.logger import get_logger -from ..utils.hash import 
compute_hash -from ..utils.io import atomic_write - -logger = get_logger("A_Memorix.GraphStore") - - -class GraphModificationMode(Enum): - """图修改模式""" - BATCH = "batch" # 批量模式 (默认, 适合一次性加载) - INCREMENTAL = "incremental" # 增量模式 (适合频繁随机写入, 使用LIL) - READ_ONLY = "read_only" # 只读模式 (适合计算, CSR/CSC) - - -class GraphStore: - """ - 图存储类 - - 功能: - - CSR稀疏矩阵存储图结构 - - 节点和边的CRUD操作 - - Personalized PageRank计算 - - 同义词自动连接 - - 图持久化 - - 参数: - matrix_format: 稀疏矩阵格式(csr/csc) - data_dir: 数据目录 - """ - - def __init__( - self, - matrix_format: str = "csr", - data_dir: Optional[Union[str, Path]] = None, - ): - """ - 初始化图存储 - - Args: - matrix_format: 稀疏矩阵格式(csr/csc) - data_dir: 数据目录 - """ - if not HAS_SCIPY: - raise ImportError("SciPy 未安装,请安装: pip install scipy") - - if isinstance(matrix_format, SparseMatrixFormat): - self.matrix_format = matrix_format.value - else: - self.matrix_format = str(matrix_format).lower() - self.data_dir = Path(data_dir) if data_dir else None - - # 节点管理 - self._nodes: List[str] = [] # 节点列表 - self._node_to_idx: Dict[str, int] = {} # 节点名到索引的映射 - self._node_attrs: Dict[str, Dict[str, Any]] = {} # 节点属性 - - # 边管理(邻接矩阵) - self._adjacency: Optional[Union[csr_matrix, csc_matrix]] = None - - # 统计信息 - self._total_nodes_added = 0 - self._total_edges_added = 0 - self._total_nodes_deleted = 0 - self._total_edges_deleted = 0 - - # 状态管理 - self._modification_mode = GraphModificationMode.BATCH - - # 状态管理 - self._adjacency_T: Optional[Union[csr_matrix, csc_matrix]] = None - self._adjacency_dirty: bool = True - self._saliency_cache: Optional[Dict[str, float]] = None - - # V5: 多关系映射 (src_idx, dst_idx) -> Set[relation_hash] - self._edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) - # V5: 简单的异步锁 (实际上 asyncio 环境下单线程主循环可能不需要,但为了安全保留) - self._lock = asyncio.Lock() - - logger.info(f"GraphStore 初始化: format={matrix_format}") - - def _canonicalize(self, node: str) -> str: - """规范化节点名称 (用于去重和内部索引)""" - if not node: - return "" - return str(node).strip().lower() - - @contextlib.contextmanager - def batch_update(self): - """ - 批量更新上下文管理器 - - 进入时切换到 LIL 格式以优化随机/增量更新 - 退出时恢复到 CSR/CSC 格式以优化存储和计算 - """ - original_mode = self._modification_mode - self._switch_mode(GraphModificationMode.INCREMENTAL) - try: - yield - finally: - self._switch_mode(original_mode) - - def _switch_mode(self, new_mode: GraphModificationMode): - """切换修改模式并转换矩阵格式""" - if new_mode == self._modification_mode: - return - - if self._adjacency is None: - self._modification_mode = new_mode - return - - logger.debug(f"切换图模式: {self._modification_mode.value} -> {new_mode.value}") - - # 转换逻辑 - if new_mode == GraphModificationMode.INCREMENTAL: - # 转换为 LIL 格式 - if not isinstance(self._adjacency, lil_matrix): # 粗略检查是否非 lil - try: - self._adjacency = self._adjacency.tolil() - logger.debug("已转换为 LIL 格式") - except Exception as e: - logger.warning(f"转换为 LIL 失败: {e}") - - elif new_mode in [GraphModificationMode.BATCH, GraphModificationMode.READ_ONLY]: - # 转换回配置的格式 (CSR/CSC) - if self.matrix_format == "csr": - self._adjacency = self._adjacency.tocsr() - elif self.matrix_format == "csc": - self._adjacency = self._adjacency.tocsc() - logger.debug(f"已恢复为 {self.matrix_format.upper()} 格式") - - self._modification_mode = new_mode - - def add_nodes( - self, - nodes: List[str], - attributes: Optional[Dict[str, Dict[str, Any]]] = None, - ) -> int: - """ - 添加节点 - - Args: - nodes: 节点名称列表 - attributes: 节点属性字典 {node_name: {attr: value}} - - Returns: - 成功添加的节点数量 - """ - added = 0 - for node in nodes: - canon = self._canonicalize(node) - if canon in self._node_to_idx: 
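# Note: this duplicate check is case- and whitespace-insensitive, because
# _canonicalize() strips and lower-cases the name. For example,
# add_nodes(["Alice"]) followed by add_nodes(["  ALICE "]) adds nothing on the
# second call, while self._nodes keeps the original spelling "Alice" for display.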
- logger.debug(f"节点已存在,跳过: {node}") - continue - - # 添加到节点列表 - idx = len(self._nodes) - self._nodes.append(node) # 存储原始节点名 - self._node_to_idx[canon] = idx # 映射规范化节点名到索引 - self._adjacency_dirty = True - self._saliency_cache = None - - # 添加属性 - if attributes and node in attributes: - self._node_attrs[canon] = attributes[node] - else: - self._node_attrs[canon] = {} - - added += 1 - self._total_nodes_added += 1 - - # 扩展邻接矩阵 - if added > 0: - self._expand_adjacency_matrix(added) - - logger.debug(f"添加 {added} 个节点") - return added - - def add_edges( - self, - edges: List[Tuple[str, str]], - weights: Optional[List[float]] = None, - relation_hashes: Optional[List[str]] = None, # V5: 支持关系哈希映射 (Relation Hash Mapping) - ) -> int: - """ - 添加边 - - Args: - edges: 边列表 [(source, target), ...] - weights: 边权重列表(默认为1.0) - - Returns: - 成功添加的边数量 - """ - if not edges: - return 0 - - # 确保所有节点存在 - nodes_to_add = set() - for src, tgt in edges: - src_canon = self._canonicalize(src) - tgt_canon = self._canonicalize(tgt) - if src_canon not in self._node_to_idx: - nodes_to_add.add(src) - if tgt_canon not in self._node_to_idx: - nodes_to_add.add(tgt) - - if nodes_to_add: - self.add_nodes(list(nodes_to_add)) - - # 处理权重 - if weights is None: - weights = [1.0] * len(edges) - - if len(weights) != len(edges): - raise ValueError(f"边数量与权重数量不匹配: {len(edges)} vs {len(weights)}") - - # 如果仅仅是添加边且处于增量模式 (LIL),直接更新 - if self._modification_mode == GraphModificationMode.INCREMENTAL: - if self._adjacency is None: - # 初始化为空 LIL - n = len(self._nodes) - from scipy.sparse import lil_matrix - self._adjacency = lil_matrix((n, n), dtype=np.float32) - - # 尝试直接使用 LIL 索引更新 - try: - # 批量获取索引 - rows = [self._node_to_idx[self._canonicalize(src)] for src, _ in edges] - cols = [self._node_to_idx[self._canonicalize(tgt)] for _, tgt in edges] - - # 确保矩阵足够大 (如果 add_nodes 没有扩展它) - 通常 add_nodes 会处理 - # 这里直接赋值 - self._adjacency[rows, cols] = weights - - self._total_edges_added += len(edges) - - # V5: Update edge hash map - if relation_hashes: - for (src, tgt), r_hash in zip(edges, relation_hashes): - if r_hash: - s_idx = self._node_to_idx[self._canonicalize(src)] - t_idx = self._node_to_idx[self._canonicalize(tgt)] - self._edge_hash_map[(s_idx, t_idx)].add(r_hash) - - logger.debug(f"增量添加 {len(edges)} 条边 (LIL)") - return len(edges) - except Exception as e: - logger.warning(f"LIL 增量更新失败,回退到通用方法: {e}") - # Fallback to general method below - - # 通用方法 (构建 COO 然后合并) - # 构建边的三元组 - row_indices = [] - col_indices = [] - data_values = [] - - for (src, tgt), weight in zip(edges, weights): - src_idx = self._node_to_idx[self._canonicalize(src)] - tgt_idx = self._node_to_idx[self._canonicalize(tgt)] - - row_indices.append(src_idx) - col_indices.append(tgt_idx) - data_values.append(weight) - - # 创建新的边的矩阵 - n = len(self._nodes) - new_edges = csr_matrix( - (data_values, (row_indices, col_indices)), - shape=(n, n), - ) - - # 合并到邻接矩阵 - if self._adjacency is None: - self._adjacency = new_edges - else: - self._adjacency = self._adjacency + new_edges - - # 转换为指定格式 - if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): - self._adjacency = self._adjacency.tocsc() - elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): - self._adjacency = self._adjacency.tocsr() - - self._total_edges_added += len(edges) - self._adjacency_dirty = True # 标记脏位 - - # V5: 更新边哈希映射 (Edge Hash Map) - if relation_hashes: - for (src, tgt), r_hash in zip(edges, relation_hashes): - if r_hash: - try: - s_idx = self._node_to_idx[self._canonicalize(src)] - t_idx = 
self._node_to_idx[self._canonicalize(tgt)] - self._edge_hash_map[(s_idx, t_idx)].add(r_hash) - except KeyError: - pass # 正常情况下节点已在上方添加,此处仅作防错处理 - - logger.debug(f"添加 {len(edges)} 条边") - return len(edges) - - def update_edge_weight( - self, - source: str, - target: str, - delta: float, - min_weight: float = 0.1, - max_weight: float = 10.0, - ) -> float: - """ - 更新边权重 (增量/强化/弱化) - - Args: - source: 源节点 - target: 目标节点 - delta: 权重变化量 (+/-) - min_weight: 最小权重限制 - max_weight: 最大权重限制 - - Returns: - 更新后的权重 - """ - src_canon = self._canonicalize(source) - tgt_canon = self._canonicalize(target) - - if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: - logger.warning(f"节点不存在,无法更新权重: {source} -> {target}") - return 0.0 - - current_weight = self.get_edge_weight(source, target) - if current_weight == 0.0 and delta <= 0: - # 边不存在且试图减少权重,忽略 - return 0.0 - - # 如果边不存在但 delta > 0,相当于添加新边 (默认基础权重0 + delta) - # 但为了逻辑清晰,我们假设只更新存在的边,或者确实添加 - - new_weight = current_weight + delta - new_weight = max(min_weight, min(max_weight, new_weight)) - - # 使用 batch_update 上下文自动处理格式转换 - # 这里我们临时切换到 incremental 模式进行单次更新 - with self.batch_update(): - # add_edges 会覆盖或添加,我们需要覆盖 - self.add_edges([(source, target)], [new_weight]) - - logger.debug(f"更新权重 {source}->{target}: {current_weight:.2f} -> {new_weight:.2f}") - return new_weight - - def delete_nodes(self, nodes: List[str]) -> int: - """ - 删除节点(及相关的边) - - Args: - nodes: 要删除的节点列表 - - Returns: - 成功删除的节点数量 - """ - if not nodes: - return 0 - - # 检查哪些节点存在 - existing_nodes = [node for node in nodes if self._canonicalize(node) in self._node_to_idx] - if not existing_nodes: - logger.warning("所有节点都不存在,无法删除") - return 0 - - # 获取要删除的索引 - indices_to_delete = {self._node_to_idx[self._canonicalize(node)] for node in existing_nodes} - indices_to_keep = [ - i for i in range(len(self._nodes)) - if i not in indices_to_delete - ] - - # 创建索引映射 - old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(indices_to_keep)} - - # 重建节点列表 (存储原始节点名) - self._nodes = [self._nodes[i] for i in indices_to_keep] - # 重建规范化节点名到索引的映射 - self._node_to_idx = {self._canonicalize(self._nodes[new_idx]): new_idx for new_idx in range(len(self._nodes))} - - # 删除并重构节点属性 - new_node_attrs = {} - for idx, node_name in enumerate(self._nodes): - canon = self._canonicalize(node_name) - if canon in self._node_attrs: - new_node_attrs[canon] = self._node_attrs[canon] - self._node_attrs = new_node_attrs - - # 重建邻接矩阵 - if self._adjacency is not None: - # 转换为COO格式以进行切片,然后转换回原始格式 - self._adjacency = self._adjacency.tocoo() - mask_rows = np.isin(self._adjacency.row, list(indices_to_keep)) - mask_cols = np.isin(self._adjacency.col, list(indices_to_keep)) - - # 筛选出保留的行和列 - new_rows = self._adjacency.row[mask_rows & mask_cols] - new_cols = self._adjacency.col[mask_rows & mask_cols] - new_data = self._adjacency.data[mask_rows & mask_cols] - - # 更新索引 - new_rows = np.array([old_to_new[r] for r in new_rows]) - new_cols = np.array([old_to_new[c] for c in new_cols]) - - n = len(self._nodes) - if self.matrix_format == "csr": - self._adjacency = csr_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) - else: # csc - self._adjacency = csc_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) - - # 重建关系哈希映射,移除涉及已删除节点的记录并重映射索引。 - if self._edge_hash_map: - new_edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) - for (old_src, old_tgt), hashes in self._edge_hash_map.items(): - if old_src in indices_to_delete or old_tgt in indices_to_delete: - continue - if old_src in old_to_new and old_tgt in 
old_to_new and hashes: - new_edge_hash_map[(old_to_new[old_src], old_to_new[old_tgt])] = set(hashes) - self._edge_hash_map = new_edge_hash_map - - deleted_count = len(existing_nodes) - self._total_nodes_deleted += deleted_count - self._adjacency_dirty = True - self._saliency_cache = None - - logger.info(f"删除 {deleted_count} 个节点") - return deleted_count - - def remove_nodes(self, nodes: List[str]) -> int: - """兼容性别名:删除节点""" - return self.delete_nodes(nodes) - - def delete_edges( - self, - edges: List[Tuple[str, str]], - ) -> int: - """ - 删除边 - - Args: - edges: 要删除的边列表 [(source, target), ...] - - Returns: - 成功删除的边数量 - """ - if not edges: - return 0 - - deleted = 0 - # 构建要删除的边的索引集合 - edges_to_delete = set() - for src, tgt in edges: - src_canon = self._canonicalize(src) - tgt_canon = self._canonicalize(tgt) - if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: - src_idx = self._node_to_idx[src_canon] - tgt_idx = self._node_to_idx[tgt_canon] - edges_to_delete.add((src_idx, tgt_idx)) - - if self._adjacency is not None and edges_to_delete: - # 转换为COO格式便于修改 - adj_coo = self._adjacency.tocoo() - - # 过滤要删除的边 - new_row = [] - new_col = [] - new_data = [] - - for i, j, val in zip(adj_coo.row, adj_coo.col, adj_coo.data): - if (i, j) not in edges_to_delete: - new_row.append(i) - new_col.append(j) - new_data.append(val) - else: - deleted += 1 - - # 重建邻接矩阵 - n = len(self._nodes) - self._adjacency = csr_matrix((new_data, (new_row, new_col)), shape=(n, n)) - - # 转换回指定格式 - if self.matrix_format == "csc": - self._adjacency = self._adjacency.tocsc() - - # delete_edges 是“物理删除”语义,必须同步清理 edge_hash_map。 - if edges_to_delete and self._edge_hash_map: - for key in edges_to_delete: - self._edge_hash_map.pop(key, None) - - self._total_edges_deleted += deleted - self._adjacency_dirty = True - self._saliency_cache = None - logger.info(f"删除 {deleted} 条边") - return deleted - - def remove_edges(self, edges: List[Tuple[str, str]]) -> int: - """兼容性别名:删除边""" - return self.delete_edges(edges) - - def get_nodes(self) -> List[str]: - """ - 获取所有节点 - - Returns: - 节点列表 - """ - return self._nodes.copy() - - def has_node(self, node: str) -> bool: - """ - 检查节点是否存在 - - Args: - node: 节点名称 - """ - return self._canonicalize(node) in self._node_to_idx - - def find_node(self, node: str, ignore_case: bool = False) -> Optional[str]: - """ - 查找节点 (由于底层已统一规范化,ignore_case 始终有效) - - Args: - node: 节点名称 - ignore_case: 是否忽略大小写 (已默认忽略) - - Returns: - 真实节点名称 (如果存在),否则 None - """ - canon = self._canonicalize(node) - if canon in self._node_to_idx: - return self._nodes[self._node_to_idx[canon]] - return None - - def get_node_attributes(self, node: str) -> Optional[Dict[str, Any]]: - """ - 获取节点属性 - - Args: - node: 节点名称 - - Returns: - 节点属性字典,不存在则返回None - """ - canon = self._canonicalize(node) - return self._node_attrs.get(canon) - - def get_neighbors(self, node: str) -> List[str]: - """ - 获取节点的出邻居 - - Args: - node: 节点名称 - - Returns: - 出邻居节点列表 - """ - canon = self._canonicalize(node) - if canon not in self._node_to_idx or self._adjacency is None: - return [] - - idx = self._node_to_idx[canon] - neighbor_indices = self._row_neighbor_indices(self._adjacency, idx) - return [self._nodes[int(i)] for i in neighbor_indices] - - def get_in_neighbors(self, node: str) -> List[str]: - """ - 获取节点的入邻居 - - Args: - node: 节点名称 - - Returns: - 入邻居节点列表 - """ - canon = self._canonicalize(node) - if canon not in self._node_to_idx or self._adjacency is None: - return [] - - self._ensure_adjacency_T() - if self._adjacency_T is None: - return [] - - idx = 
self._node_to_idx[canon] - neighbor_indices = self._row_neighbor_indices(self._adjacency_T, idx) - return [self._nodes[int(i)] for i in neighbor_indices] - - def get_edge_weight(self, source: str, target: str) -> float: - """ - 获取边的权重 - """ - src_canon = self._canonicalize(source) - tgt_canon = self._canonicalize(target) - - if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: - return 0.0 - - if self._adjacency is None: - return 0.0 - - src_idx = self._node_to_idx[src_canon] - tgt_idx = self._node_to_idx[tgt_canon] - - return float(self._adjacency[src_idx, tgt_idx]) - - def canonicalize_node(self, node: str) -> str: - """公开节点规范化接口,避免外部访问私有方法。""" - return self._canonicalize(node) - - def has_edge_hash_map(self) -> bool: - """是否存在 relation-hash 映射。""" - return bool(self._edge_hash_map) - - def get_relation_hashes_for_edge(self, source: str, target: str) -> Set[str]: - """获取边 (source -> target) 关联的关系哈希集合。""" - src_canon = self._canonicalize(source) - tgt_canon = self._canonicalize(target) - if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: - return set() - src_idx = self._node_to_idx[src_canon] - tgt_idx = self._node_to_idx[tgt_canon] - return set(self._edge_hash_map.get((src_idx, tgt_idx), set())) - - def get_incident_relation_hashes(self, node: str, limit: Optional[int] = None) -> List[str]: - """获取与指定节点关联的关系哈希列表(入边 + 出边)。""" - canon = self._canonicalize(node) - if canon not in self._node_to_idx or not self._edge_hash_map: - return [] - - idx = self._node_to_idx[canon] - limit_val = max(1, int(limit)) if limit is not None else None - collected: List[Tuple[str, str, str]] = [] - idx_to_node = self._nodes - - for (src_idx, tgt_idx), hashes in self._edge_hash_map.items(): - if idx not in {src_idx, tgt_idx}: - continue - src_name = idx_to_node[src_idx] if 0 <= src_idx < len(idx_to_node) else "" - tgt_name = idx_to_node[tgt_idx] if 0 <= tgt_idx < len(idx_to_node) else "" - for hash_value in hashes: - hash_text = str(hash_value).strip() - if not hash_text: - continue - collected.append((src_name, tgt_name, hash_text)) - - collected.sort(key=lambda x: (x[0].lower(), x[1].lower(), x[2])) - out: List[str] = [] - seen = set() - for _, _, hash_value in collected: - if hash_value in seen: - continue - seen.add(hash_value) - out.append(hash_value) - if limit_val is not None and len(out) >= limit_val: - break - return out - - def edge_contains_relation_hash(self, source: str, target: str, hash_value: str) -> bool: - """判断边是否包含指定关系哈希。""" - if not hash_value: - return False - return str(hash_value) in self.get_relation_hashes_for_edge(source, target) - - def iter_edge_hash_entries(self) -> List[Tuple[str, str, Set[str]]]: - """以节点名形式遍历 edge-hash-map。""" - out: List[Tuple[str, str, Set[str]]] = [] - if not self._edge_hash_map: - return out - idx_to_node = self._nodes - for (s_idx, t_idx), hashes in self._edge_hash_map.items(): - if not hashes: - continue - if s_idx < 0 or t_idx < 0: - continue - if s_idx >= len(idx_to_node) or t_idx >= len(idx_to_node): - continue - out.append((idx_to_node[s_idx], idx_to_node[t_idx], set(hashes))) - return out - - def deactivate_edges(self, edges: List[Tuple[str, str]]) -> int: - """ - 冻结边 (将权重设为0.0,使其在计算意义上消失,但保留在Map中) - - Args: - edges: [(s1, t1), (s2, t2)...] - """ - if not edges or self._adjacency is None: - return 0 - - deactivated_count = 0 - with self.batch_update(): - # 我们需要 explicit set to 0. 
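# A hedged usage sketch of the "freeze" semantics implemented just below: the matrix
# weight is zeroed while the relation hashes stay in _edge_hash_map (assumes SciPy and
# the host's src.common.logger are importable so GraphStore can be constructed):
#
#     gs = GraphStore(matrix_format="csr")
#     gs.add_edges([("alice", "bob")], weights=[1.0], relation_hashes=["r1"])
#     gs.deactivate_edges([("alice", "bob")])
#     gs.get_edge_weight("alice", "bob")               # -> 0.0, edge no longer active
#     gs.get_relation_hashes_for_edge("alice", "bob")  # -> {"r1"}, hash mapping preserved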
- # 使用增量更新模式覆盖 - for s, t in edges: - s_canon = self._canonicalize(s) - t_canon = self._canonicalize(t) - if s_canon in self._node_to_idx and t_canon in self._node_to_idx: - idx_s = self._node_to_idx[s_canon] - idx_t = self._node_to_idx[t_canon] - self._adjacency[idx_s, idx_t] = 0.0 - deactivated_count += 1 - - self._adjacency_dirty = True - return deactivated_count - - def _ensure_adjacency_T(self): - """确保转置邻接矩阵是最新的""" - if self._adjacency is None: - self._adjacency_T = None - return - - if self._adjacency_dirty or self._adjacency_T is None: - # 只有在确实需要时才计算转置 - # find_paths 以“按行读取邻居”为主,因此统一缓存为 CSR,避免 - # CSR->CSC 转置后按行切片读出错误的索引视图。 - self._adjacency_T = self._adjacency.transpose().tocsr() - - self._adjacency_dirty = False - # logger.debug("重建转置邻接矩阵缓存") - - @staticmethod - def _row_neighbor_indices( - matrix: Optional[Union[csr_matrix, csc_matrix]], - row_idx: int, - ) -> np.ndarray: - """返回指定行的非零列索引。""" - if matrix is None: - return np.asarray([], dtype=np.int32) - - if isinstance(matrix, csr_matrix): - return matrix.indices[matrix.indptr[row_idx]:matrix.indptr[row_idx + 1]] - - row = matrix[row_idx, :] - _, indices = row.nonzero() - return np.asarray(indices, dtype=np.int32) - - def find_paths( - self, - start_node: str, - end_node: str, - max_depth: int = 3, - max_paths: int = 5, - max_expansions: int = 20000 - ) -> List[List[str]]: - """ - 查找两个节点之间的路径 (BFS) - 支持有向和无向 (视作双向) 探索 - - Args: - start_node: 起始节点 - end_node: 目标节点 - max_depth: 最大深度 - max_paths: 最大路径数 (找到这么多就停止) - max_expansions: 最大扩展次数 (防止爆炸) - - Returns: - 路径列表 [[n1, n2, n3], ...] - """ - start_canon = self._canonicalize(start_node) - end_canon = self._canonicalize(end_node) - - if start_canon not in self._node_to_idx or end_canon not in self._node_to_idx: - return [] - - if self._adjacency is None: - return [] - - # 确保转置矩阵可用 (用于查找入边) - self._ensure_adjacency_T() - - start_idx = self._node_to_idx[start_canon] - end_idx = self._node_to_idx[end_canon] - - # 队列: (current_idx, path_indices) - queue = [(start_idx, [start_idx])] - found_paths = [] - expansions = 0 - - unique_paths = set() - - while queue: - curr, path = queue.pop(0) - - if len(path) > max_depth + 1: - continue - - if curr == end_idx: - # 找到路径 - # 转换回节点名 - path_names = [self._nodes[i] for i in path] - path_tuple = tuple(path_names) - if path_tuple not in unique_paths: - found_paths.append(path_names) - unique_paths.add(path_tuple) - - if len(found_paths) >= max_paths: - break - continue - - if expansions >= max_expansions: - break - - expansions += 1 - - # 获取邻居 (出边 + 入边) - out_indices = self._row_neighbor_indices(self._adjacency, curr) - - # 2. 
入边 (使用转置矩阵) - if self._adjacency_T is not None: - in_indices = self._row_neighbor_indices(self._adjacency_T, curr) - neighbors = np.concatenate((out_indices, in_indices)) - else: - neighbors = out_indices - - # 去重并过滤已在路径中的节点 (防止环) - # 注意: 这里简单去重,可能包含重复的邻居(如果既是出又是入) - seen_in_path = set(path) - queued_neighbors = set() - - for neighbor_idx in neighbors: - neighbor = int(neighbor_idx) - if neighbor not in seen_in_path and neighbor not in queued_neighbors: - # 只有未访问过的才加入 - queue.append((neighbor, path + [neighbor])) - queued_neighbors.add(neighbor) - - return found_paths - - def compute_pagerank( - self, - personalization: Optional[Dict[str, float]] = None, - alpha: float = 0.85, - max_iter: int = 100, - tol: float = 1e-6, - ) -> Dict[str, float]: - """ - 计算Personalized PageRank - - Args: - personalization: 个性化向量 {node: weight},默认为均匀分布 - alpha: 阻尼系数(0-1之间) - max_iter: 最大迭代次数 - tol: 收敛阈值 - - Returns: - 节点PageRank值字典 {node: score} - """ - if self._adjacency is None or len(self._nodes) == 0: - logger.warning("图为空,无法计算PageRank") - return {} - - n = len(self._nodes) - - # 构建列归一化的转移矩阵 - adj = self._adjacency.astype(np.float32) - - # 计算出度 - out_degrees = np.array(adj.sum(axis=1)).flatten() - - # 处理悬挂节点(出度为0) - dangling = (out_degrees == 0) - out_degrees_inv = np.zeros_like(out_degrees) - out_degrees_inv[~dangling] = 1.0 / out_degrees[~dangling] - - # 归一化 (使用稀疏对角阵避免内存溢出) - from scipy.sparse import diags - D_inv = diags(out_degrees_inv) - M = adj.T @ D_inv # 转移矩阵 - - # 初始化个性化向量 - if personalization is None: - # 均匀分布 - p = np.ones(n) / n - else: - # 使用指定的个性化向量 - p = np.zeros(n) - total_weight = sum(personalization.values()) - for node, weight in personalization.items(): - if node in self._node_to_idx: - idx = self._node_to_idx[node] - p[idx] = weight / total_weight - - # 确保和为1 - if p.sum() == 0: - p = np.ones(n) / n - else: - p = p / p.sum() - - # 幂迭代法 - p_orig = p.copy() - for i in range(max_iter): - # p_new = alpha * M * p + (1-alpha) * personalization - p_new = alpha * (M @ p) + (1 - alpha) * p_orig - - # 处理因为悬挂节点导致的概率流失 - current_sum = p_new.sum() - if current_sum < 1.0: - p_new += (1.0 - current_sum) * p_orig - - # 检查收敛 - diff = np.linalg.norm(p_new - p, 1) - if diff < tol: - logger.debug(f"PageRank在 {i+1} 次迭代后收敛") - p = p_new - break - p = p_new - else: - logger.warning(f"PageRank未在 {max_iter} 次迭代内收敛") - - # 转换为真实节点名称字典 - return {self._nodes[idx]: float(val) for idx, val in enumerate(p)} - - def get_saliency_scores(self) -> Dict[str, float]: - """ - 获取节点显著性得分 (带有缓存机制) - """ - if self._saliency_cache is not None and not self._adjacency_dirty: - return self._saliency_cache - - logger.debug("正在计算节点显著性得分 (PageRank)...") - scores = self.compute_pagerank() - self._saliency_cache = scores - # 注意:这里我们不把 _adjacency_dirty 设为 False,因为其它逻辑(如_adjacency_T)也依赖它 - return scores - - def connect_synonyms( - self, - similarity_matrix: np.ndarray, - node_list: List[str], - threshold: float = 0.85, - ) -> int: - """ - 连接相似节点(同义词) - - Args: - similarity_matrix: 相似度矩阵 (N x N) - node_list: 对应的节点列表(长度为N) - threshold: 相似度阈值 - - Returns: - 添加的边数量 - """ - if len(node_list) != similarity_matrix.shape[0]: - raise ValueError( - f"节点列表长度与相似度矩阵维度不匹配: " - f"{len(node_list)} vs {similarity_matrix.shape[0]}" - ) - - # 找到相似的节点对(上三角,排除对角线) - similar_pairs = np.argwhere( - (triu(similarity_matrix, k=1) >= threshold) & - (triu(similarity_matrix, k=1) < 1.0) # 排除完全相同的 - ) - - # 添加边 - edges = [] - for i, j in similar_pairs: - if i < len(node_list) and j < len(node_list): - src = node_list[i] - tgt = node_list[j] - # 使用相似度作为权重 - weight = 
float(similarity_matrix[i, j]) - edges.append((src, tgt, weight)) - - if edges: - edge_pairs = [(src, tgt) for src, tgt, _ in edges] - weights = [w for _, _, w in edges] - count = self.add_edges(edge_pairs, weights) - logger.info(f"连接 {count} 对相似节点(阈值={threshold})") - return count - return 0 - - - # ========================================================================= - # V5 Memory System Methods (Graph Level) - # ========================================================================= - - def decay(self, factor: float, min_active_weight: float = 0.0) -> None: - """ - 全图衰减 (Atomic Decay) - - Args: - factor: 衰减因子 (0.0 < factor < 1.0) - min_active_weight: 最小活跃权重 (低于此值可能被视为无效,但在物理修剪前仍保留) - """ - if self._adjacency is None or factor >= 1.0 or factor <= 0.0: - return - - logger.debug(f"正在执行全图衰减,因子: {factor}") - - # 直接矩阵乘法,SciPy CSR/CSC 非常高效 - self._adjacency *= factor - - # 如果需要处理极小值 (可选,防止下溢,但通常浮点数足够小) - # if min_active_weight > 0: - # ... (复杂操作,暂不需要,由 prune 逻辑处理) - - self._adjacency_dirty = True - - def prune_relation_hashes(self, operations: List[Tuple[str, str, str]]) -> None: - """ - 修剪特定关系哈希 (从 _edge_hash_map 移除; 如果边变空则从矩阵移除) - - Args: - operations: List[(src, tgt, relation_hash)] - """ - if not operations: - return - - edges_to_check_removal = set() - - # 1. 更新映射 (Update Map) - for src, tgt, h in operations: - src_canon = self._canonicalize(src) - tgt_canon = self._canonicalize(tgt) - if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: - s_idx = self._node_to_idx[src_canon] - t_idx = self._node_to_idx[tgt_canon] - - key = (s_idx, t_idx) - if key in self._edge_hash_map: - if h in self._edge_hash_map[key]: - self._edge_hash_map[key].remove(h) - - if not self._edge_hash_map[key]: - del self._edge_hash_map[key] - edges_to_check_removal.add((src, tgt)) - - # 2. 
从矩阵中移除空边 (Remove Empty Edges from Matrix) - if edges_to_check_removal: - self.deactivate_edges(list(edges_to_check_removal)) - self._total_edges_deleted += len(edges_to_check_removal) - - def get_low_weight_edges(self, threshold: float) -> List[Tuple[str, str]]: - """ - 获取低于阈值的边 (candidates for pruning/freezing) - - Args: - threshold: 权重阈值 - - Returns: - List[(src, tgt)]: 边列表 - """ - if self._adjacency is None: - return [] - - # 获取所有非零元素 - rows, cols = self._adjacency.nonzero() - data = self._adjacency.data - - low_weight_indices = np.where(data < threshold)[0] - - results = [] - for idx in low_weight_indices: - r = rows[idx] - c = cols[idx] - src = self._nodes[r] - tgt = self._nodes[c] - results.append((src, tgt)) - - return results - - def get_isolated_nodes(self, include_inactive: bool = True) -> List[str]: - """ - 获取孤儿节点 (Active Degree = 0) - - Args: - include_inactive: 是否包含参与了inactive边(冻结边)的节点。 - 如果 True (默认推荐): 排除掉虽然active degree=0但存在于_edge_hash_map(冻结边)中的节点。 - 如果 False: 只要在 active matrix 里度为0就返回 (可能会误删冻结节点)。 - - Returns: - 孤儿节点名称列表 - """ - if self._adjacency is None: - # 如果全空,则所有节点都是孤儿 - return self._nodes.copy() - - n = len(self._nodes) - - # 计算 Active Degree (In + Out) - # 用 sum(axis) 会得到 dense matrix/array - active_adj = self._adjacency - out_degrees = np.array(active_adj.sum(axis=1)).flatten() - in_degrees = np.array(active_adj.sum(axis=0)).flatten() - - # 处理悬挂节点 (dangling node check not really needed here, just sum) - total_degrees = out_degrees + in_degrees - - # 找到 active degree = 0 的索引 - isolated_indices = np.where(total_degrees == 0)[0] - - if len(isolated_indices) == 0: - return [] - - isolated_nodes_set = {self._nodes[i] for i in isolated_indices} - - # 如果需要排除 Inactive 参与者 - if include_inactive and self._edge_hash_map: - # 收集所有在冻结边中的 unique 节点索引 - frozen_participant_indices = set() - for (u_idx, v_idx), hashes in self._edge_hash_map.items(): - if hashes: # 只要有 hash 存在 (哪怕 inactive) - frozen_participant_indices.add(u_idx) - frozen_participant_indices.add(v_idx) - - # 过滤 - final_isolated = [] - for idx in isolated_indices: - if idx not in frozen_participant_indices: - final_isolated.append(self._nodes[idx]) - return final_isolated - - else: - return list(isolated_nodes_set) - - def clear(self) -> None: - """清空所有数据""" - self._nodes.clear() - self._node_to_idx.clear() - self._node_attrs.clear() - self._adjacency = None - self._edge_hash_map.clear() - self._adjacency_T = None - self._adjacency_dirty = True - self._total_nodes_added = 0 - self._total_edges_added = 0 - self._total_nodes_deleted = 0 - self._total_edges_deleted = 0 - logger.info("图存储已清空") - - def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: - """ - 保存到磁盘 - - Args: - data_dir: 数据目录(默认使用初始化时的目录) - """ - if data_dir is None: - data_dir = self.data_dir - - if data_dir is None: - raise ValueError("未指定数据目录") - - data_dir = Path(data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - - # 保存邻接矩阵 - if self._adjacency is not None: - matrix_path = data_dir / "graph_adjacency.npz" - with atomic_write(matrix_path, "wb") as f: - save_npz(f, self._adjacency) - logger.debug(f"保存邻接矩阵: {matrix_path}") - - # 保存元数据 - metadata = { - "nodes": self._nodes, - "node_to_idx": self._node_to_idx, - "node_attrs": self._node_attrs, - "matrix_format": self.matrix_format, - "total_nodes_added": self._total_nodes_added, - "total_edges_added": self._total_edges_added, - "total_nodes_deleted": self._total_nodes_deleted, - "total_edges_deleted": self._total_edges_deleted, - "edge_hash_map": dict(self._edge_hash_map), # 持久化 V5 
映射 (将 defaultdict 转换为普通 dict) - } - - metadata_path = data_dir / "graph_metadata.pkl" - with atomic_write(metadata_path, "wb") as f: - pickle.dump(metadata, f) - logger.debug(f"保存元数据: {metadata_path}") - - logger.info(f"图存储已保存到: {data_dir}") - - def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: - """ - 从磁盘加载 - - Args: - data_dir: 数据目录(默认使用初始化时的目录) - """ - if data_dir is None: - data_dir = self.data_dir - - if data_dir is None: - raise ValueError("未指定数据目录") - - data_dir = Path(data_dir) - if not data_dir.exists(): - raise FileNotFoundError(f"数据目录不存在: {data_dir}") - - # 加载元数据 - metadata_path = data_dir / "graph_metadata.pkl" - if not metadata_path.exists(): - raise FileNotFoundError(f"元数据文件不存在: {metadata_path}") - - with open(metadata_path, "rb") as f: - metadata = pickle.load(f) - - # 恢复状态,并通过规范化处理旧数据中的重复项 - self._nodes = metadata["nodes"] - self._node_attrs = {} # 重新构建以确保键名 (Key) 规范化 - self._node_to_idx = {} # 重新构建以确保键名 (Key) 规范化 - - # 重新构建映射,处理旧数据中的碰撞 - for idx, node_name in enumerate(self._nodes): - canon = self._canonicalize(node_name) - if canon not in self._node_to_idx: - self._node_to_idx[canon] = idx - - # 处理属性 (优先保留已有的) - orig_attrs = metadata.get("node_attrs", {}) - if node_name in orig_attrs and canon not in self._node_attrs: - self._node_attrs[canon] = orig_attrs[node_name] - - self.matrix_format = metadata["matrix_format"] - self._total_nodes_added = metadata["total_nodes_added"] - self._total_edges_added = metadata["total_edges_added"] - self._total_nodes_deleted = metadata["total_nodes_deleted"] - self._total_edges_deleted = metadata["total_edges_deleted"] - - # 恢复 V5 边哈希映射 (Restore V5 edge hash map) - edge_map_data = metadata.get("edge_hash_map", {}) - # 重新初始化为 defaultdict(set) - self._edge_hash_map = defaultdict(set) - if edge_map_data: - for k, v in edge_map_data.items(): - self._edge_hash_map[k] = set(v) # 确保类型为 set - - # 加载邻接矩阵 - matrix_path = data_dir / "graph_adjacency.npz" - if matrix_path.exists(): - self._adjacency = load_npz(str(matrix_path)) - - # 确保格式正确 - if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): - self._adjacency = self._adjacency.tocsc() - elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): - self._adjacency = self._adjacency.tocsr() - - logger.debug(f"加载邻接矩阵: {matrix_path}, shape={self._adjacency.shape}") - - # 检查维度不匹配并修复 - if self._adjacency is not None: - adj_n = self._adjacency.shape[0] - current_n = len(self._nodes) - if current_n > adj_n: - logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 
正在自动修复...") - self._expand_adjacency_matrix(current_n - adj_n) - - self._adjacency_dirty = True - logger.info( - f"图存储已加载: {len(self._nodes)} 个节点, " - f"{self._adjacency.nnz if self._adjacency is not None else 0} 条边" - ) - - def _expand_adjacency_matrix(self, added_nodes: int) -> None: - """ - 扩展邻接矩阵以容纳新节点 - - Args: - added_nodes: 新增节点数量 - """ - if self._adjacency is None: - n = len(self._nodes) - # 根据模式初始化 - - if self._modification_mode == GraphModificationMode.INCREMENTAL: - self._adjacency = lil_matrix((n, n), dtype=np.float32) - else: - self._adjacency = csr_matrix((n, n), dtype=np.float32) - return - - old_n = self._adjacency.shape[0] - new_n = old_n + added_nodes - - # 优化:根据模式选择不同的扩容策略 - if self._modification_mode == GraphModificationMode.INCREMENTAL: - # LIL 格式可以直接 resize,非常高效 - try: - if not isinstance(self._adjacency, lil_matrix): - self._adjacency = self._adjacency.tolil() - - self._adjacency.resize((new_n, new_n)) - # logger.debug(f"扩展 LIL 矩阵: {old_n} -> {new_n}") - except Exception as e: - logger.warning(f"LIL resize 失败,回退到通用方法: {e}") - self._expand_generic(new_n, old_n) - - else: - # CSR/CSC 格式使用 bmat 避免结构破坏警告 - try: - # bmat 需要明确的形状,不能全部依赖 None - added = new_n - old_n - # 创建零矩阵块 - # 注意: 这里统一创建 CSR 零矩阵,bmat 会处理合并 - z_tr = csr_matrix((old_n, added), dtype=np.float32) - z_bl = csr_matrix((added, old_n), dtype=np.float32) - z_br = csr_matrix((added, added), dtype=np.float32) - - self._adjacency = bmat( - [[self._adjacency, z_tr], [z_bl, z_br]], - format=self.matrix_format, - dtype=np.float32 - ) - # logger.debug(f"扩展矩阵 (bmat): {old_n} -> {new_n}") - except Exception as e: - logger.warning(f"bmat 扩展失败: {e}") - self._expand_generic(new_n, old_n) - - def _expand_generic(self, new_n: int, old_n: int): - """通用扩展方法(回退方案)""" - if self.matrix_format == "csr": - new_adjacency = csr_matrix((new_n, new_n), dtype=np.float32) - new_adjacency[:old_n, :old_n] = self._adjacency - else: - new_adjacency = csc_matrix((new_n, new_n), dtype=np.float32) - new_adjacency[:old_n, :old_n] = self._adjacency - self._adjacency = new_adjacency - self._adjacency_dirty = True - - # 如果都在增量模式,确保是LIL - if self._modification_mode == GraphModificationMode.INCREMENTAL: - try: - self._adjacency = self._adjacency.tolil() - except: - pass - - @property - def num_nodes(self) -> int: - """节点数量""" - return len(self._nodes) - - @property - def num_edges(self) -> int: - """边数量""" - if self._adjacency is None: - return 0 - return int(self._adjacency.nnz) - - @property - def density(self) -> float: - """ - 图密度(实际边数 / 可能的最大边数) - - 有向图: E / (V * (V - 1)) - 无向图: 2E / (V * (V - 1)) - - 这里按有向图计算 - """ - if self.num_nodes < 2: - return 0.0 - - max_edges = self.num_nodes * (self.num_nodes - 1) - return self.num_edges / max_edges if max_edges > 0 else 0.0 - - def __len__(self) -> int: - """节点数量""" - return self.num_nodes - - def has_data(self) -> bool: - """检查磁盘上是否存在现有数据""" - if self.data_dir is None: - return False - return (self.data_dir / "graph_metadata.pkl").exists() - - def __repr__(self) -> str: - return ( - f"GraphStore(nodes={self.num_nodes}, edges={self.num_edges}, " - f"density={self.density:.4f}, format={self.matrix_format})" - ) - - def rebuild_edge_hash_map(self, triples: List[Tuple[str, str, str, str]]) -> int: - """ - 从元数据重建 V5 边哈希映射 (Migration Tool) - - Args: - triples: List of (s, p, o, hash) - - Returns: - count of mapped hashes - """ - count = 0 - self._edge_hash_map = defaultdict(set) - - for s, p, o, h in triples: - if not h: continue - - s_canon = self._canonicalize(s) - o_canon = self._canonicalize(o) - - if 
s_canon in self._node_to_idx and o_canon in self._node_to_idx: - u = self._node_to_idx[s_canon] - v = self._node_to_idx[o_canon] - - # 如果是双向的,通常在元数据中存储为有向,而 GraphStore 也通常是有向的。 - # 映射键对应特定的边方向。 - self._edge_hash_map[(u, v)].add(h) - count += 1 - - self._adjacency_dirty = True - logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边") - return count - diff --git a/plugins/A_memorix/core/storage/knowledge_types.py b/plugins/A_memorix/core/storage/knowledge_types.py deleted file mode 100644 index 4ab91218..00000000 --- a/plugins/A_memorix/core/storage/knowledge_types.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Knowledge type and import strategy helpers.""" - -from __future__ import annotations - -from enum import Enum -from typing import Any, Optional - - -class KnowledgeType(str, Enum): - """持久化到 paragraphs.knowledge_type 的合法类型。""" - - STRUCTURED = "structured" - NARRATIVE = "narrative" - FACTUAL = "factual" - QUOTE = "quote" - MIXED = "mixed" - - -class ImportStrategy(str, Enum): - """文本导入阶段的策略选择。""" - - AUTO = "auto" - NARRATIVE = "narrative" - FACTUAL = "factual" - QUOTE = "quote" - - -def allowed_knowledge_type_values() -> tuple[str, ...]: - return tuple(item.value for item in KnowledgeType) - - -def allowed_import_strategy_values() -> tuple[str, ...]: - return tuple(item.value for item in ImportStrategy) - - -def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]: - """从字符串解析合法的落库知识类型。""" - - if not isinstance(type_str, str): - return None - normalized = type_str.lower().strip() - try: - return KnowledgeType(normalized) - except ValueError: - return None - - -def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]: - """从字符串解析文本导入策略。""" - - if not isinstance(value, str): - return None - normalized = value.lower().strip() - try: - return ImportStrategy(normalized) - except ValueError: - return None - - -def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy: - """解析 import strategy;非法值直接报错。""" - - if value is None: - return default - if isinstance(value, ImportStrategy): - return value - - normalized = str(value or "").strip().lower() - if not normalized: - return default - - strategy = get_import_strategy_from_string(normalized) - if strategy is None: - allowed = "/".join(allowed_import_strategy_values()) - raise ValueError(f"strategy_override 必须为 {allowed}") - return strategy - - -def validate_stored_knowledge_type(value: Any) -> KnowledgeType: - """校验写库 knowledge_type,仅允许合法落库类型。""" - - if isinstance(value, KnowledgeType): - return value - - resolved = get_knowledge_type_from_string(value) - if resolved is None: - allowed = "/".join(allowed_knowledge_type_values()) - raise ValueError(f"knowledge_type 必须为 {allowed}") - return resolved - - -def resolve_stored_knowledge_type( - value: Any, - *, - content: str = "", - allow_legacy: bool = False, - unknown_fallback: Optional[KnowledgeType] = None, -) -> KnowledgeType: - """ - 将策略/字符串/旧值解析为合法落库类型。 - - `allow_legacy=True` 仅供迁移使用。 - """ - - if isinstance(value, KnowledgeType): - return value - - if isinstance(value, ImportStrategy): - if value == ImportStrategy.AUTO: - if not str(content or "").strip(): - raise ValueError("knowledge_type=auto 需要 content 才能推断") - from .type_detection import detect_knowledge_type - - return detect_knowledge_type(content) - return KnowledgeType(value.value) - - raw = str(value or "").strip() - if not raw: - if str(content or "").strip(): - from .type_detection import detect_knowledge_type - - return 
detect_knowledge_type(content) - raise ValueError("knowledge_type 不能为空") - - direct = get_knowledge_type_from_string(raw) - if direct is not None: - return direct - - strategy = get_import_strategy_from_string(raw) - if strategy is not None: - return resolve_stored_knowledge_type(strategy, content=content) - - if allow_legacy: - normalized = raw.lower() - if normalized == "imported": - return KnowledgeType.FACTUAL - if str(content or "").strip(): - from .type_detection import detect_knowledge_type - - detected = detect_knowledge_type(content) - if detected is not None: - return detected - if unknown_fallback is not None: - return unknown_fallback - - allowed = "/".join(allowed_knowledge_type_values()) - raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed})") - - -def should_extract_relations(knowledge_type: KnowledgeType) -> bool: - """判断是否应该做关系抽取。""" - - return knowledge_type in [ - KnowledgeType.STRUCTURED, - KnowledgeType.FACTUAL, - KnowledgeType.MIXED, - ] - - -def get_default_chunk_size(knowledge_type: KnowledgeType) -> int: - """获取默认分块大小。""" - - chunk_sizes = { - KnowledgeType.STRUCTURED: 300, - KnowledgeType.NARRATIVE: 800, - KnowledgeType.FACTUAL: 500, - KnowledgeType.QUOTE: 400, - KnowledgeType.MIXED: 500, - } - return chunk_sizes.get(knowledge_type, 500) - - -def get_type_display_name(knowledge_type: KnowledgeType) -> str: - """获取知识类型中文名称。""" - - display_names = { - KnowledgeType.STRUCTURED: "结构化知识", - KnowledgeType.NARRATIVE: "叙事性文本", - KnowledgeType.FACTUAL: "事实陈述", - KnowledgeType.QUOTE: "引用文本", - KnowledgeType.MIXED: "混合类型", - } - return display_names.get(knowledge_type, "未知类型") diff --git a/plugins/A_memorix/core/storage/metadata_store.py b/plugins/A_memorix/core/storage/metadata_store.py deleted file mode 100644 index 39f2701c..00000000 --- a/plugins/A_memorix/core/storage/metadata_store.py +++ /dev/null @@ -1,5748 +0,0 @@ -""" -元数据存储模块 - -基于SQLite的元数据管理,存储段落、实体、关系等信息。 -""" - -import sqlite3 -import pickle -import json -import uuid -import re -from datetime import datetime -from pathlib import Path -from typing import Optional, Union, List, Dict, Any, Tuple - -from src.common.logger import get_logger -from ..utils.hash import compute_hash, normalize_text -from ..utils.time_parser import normalize_time_meta -from .knowledge_types import ( - KnowledgeType, - allowed_knowledge_type_values, - resolve_stored_knowledge_type, - validate_stored_knowledge_type, -) - -logger = get_logger("A_Memorix.MetadataStore") - - -SCHEMA_VERSION = 8 - - -class MetadataStore: - """ - 元数据存储类 - - 功能: - - SQLite数据库管理 - - 段落/实体/关系元数据存储 - - 增删改查操作 - - 事务支持 - - 索引优化 - - 参数: - data_dir: 数据目录 - db_name: 数据库文件名(默认metadata.db) - """ - - def __init__( - self, - data_dir: Optional[Union[str, Path]] = None, - db_name: str = "metadata.db", - ): - """ - 初始化元数据存储 - - Args: - data_dir: 数据目录 - db_name: 数据库文件名 - """ - self.data_dir = Path(data_dir) if data_dir else None - self.db_name = db_name - self._conn: Optional[sqlite3.Connection] = None - self._is_initialized = False - self._db_path: Optional[Path] = None - - logger.info(f"MetadataStore 初始化: db={db_name}") - - def connect( - self, - data_dir: Optional[Union[str, Path]] = None, - *, - enforce_schema: bool = True, - ) -> None: - """ - 连接到数据库 - - Args: - data_dir: 数据目录(默认使用初始化时的目录) - """ - if data_dir is None: - data_dir = self.data_dir - - if data_dir is None: - raise ValueError("未指定数据目录") - - data_dir = Path(data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - - db_path = data_dir / self.db_name - db_existed = db_path.exists() - self._db_path = 
db_path - - # 连接数据库 - self._conn = sqlite3.connect( - str(db_path), - check_same_thread=False, - timeout=30.0, - ) - self._conn.row_factory = sqlite3.Row # 使用字典式访问 - - # 优化性能 - self._conn.execute("PRAGMA journal_mode=WAL") - self._conn.execute("PRAGMA synchronous=NORMAL") - self._conn.execute("PRAGMA cache_size=-64000") # 64MB缓存 - self._conn.execute("PRAGMA temp_store=MEMORY") - self._conn.execute("PRAGMA foreign_keys = ON") # 开启外键约束支持级联删除 - - logger.info(f"连接到数据库: {db_path}") - - # 初始化或校验 schema - if not self._is_initialized: - if not db_existed: - self._initialize_tables() - if enforce_schema: - self._assert_schema_compatible(db_existed=db_existed) - self._is_initialized = True - - # 初始化 FTS schema(幂等) - try: - self.ensure_fts_schema() - except Exception as e: - logger.warning(f"初始化 FTS schema 失败,将跳过 BM25 检索: {e}") - - def _assert_schema_compatible(self, db_existed: bool) -> None: - """vNext 运行时只做 schema 版本校验,不做隐式迁移。""" - cursor = self._conn.cursor() - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" - ) - has_version_table = cursor.fetchone() is not None - if not has_version_table: - if db_existed: - raise RuntimeError( - "检测到旧版 metadata schema(缺少 schema_migrations)。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - return - - cursor.execute("SELECT MAX(version) FROM schema_migrations") - row = cursor.fetchone() - version = int(row[0]) if row and row[0] is not None else 0 - if version != SCHEMA_VERSION: - raise RuntimeError( - f"metadata schema 版本不匹配: current={version}, expected={SCHEMA_VERSION}。" - " 请执行 scripts/release_vnext_migrate.py migrate。" - ) - - def close(self) -> None: - """关闭数据库连接""" - if self._conn: - self._conn.close() - self._conn = None - logger.info("数据库连接已关闭") - - def _initialize_tables(self) -> None: - """初始化数据库表结构""" - cursor = self._conn.cursor() - - # 段落表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS paragraphs ( - hash TEXT PRIMARY KEY, - content TEXT NOT NULL, - vector_index INTEGER, - created_at REAL, - updated_at REAL, - metadata TEXT, - source TEXT, - word_count INTEGER, - event_time REAL, - event_time_start REAL, - event_time_end REAL, - time_granularity TEXT, - time_confidence REAL DEFAULT 1.0, - knowledge_type TEXT DEFAULT 'mixed', - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_deleted INTEGER DEFAULT 0, - deleted_at REAL - ) - """) - - # 实体表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS entities ( - hash TEXT PRIMARY KEY, - name TEXT NOT NULL UNIQUE, - vector_index INTEGER, - appearance_count INTEGER DEFAULT 1, - created_at REAL, - metadata TEXT, - is_deleted INTEGER DEFAULT 0, - deleted_at REAL - ) - """) - - # 关系表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relations ( - hash TEXT PRIMARY KEY, - subject TEXT NOT NULL, - predicate TEXT NOT NULL, - object TEXT NOT NULL, - vector_index INTEGER, - confidence REAL DEFAULT 1.0, - vector_state TEXT DEFAULT 'none', - vector_updated_at REAL, - vector_error TEXT, - vector_retry_count INTEGER DEFAULT 0, - created_at REAL, - source_paragraph TEXT, - metadata TEXT, - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_inactive BOOLEAN DEFAULT 0, - inactive_since REAL, - is_pinned BOOLEAN DEFAULT 0, - protected_until REAL, - last_reinforced REAL, - UNIQUE(subject, predicate, object) - ) - """) - - # 回收站关系表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS deleted_relations ( - hash TEXT PRIMARY KEY, - subject TEXT NOT NULL, - predicate TEXT NOT NULL, - object 
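# --- Editorial sketch (not part of the patch): the connection profile connect() applies above.
# --- WAL plus synchronous=NORMAL trades a little durability for concurrent-read performance,
# --- a negative cache_size is interpreted in KiB (-64000 is roughly 64 MB), and foreign_keys
# --- must be enabled per connection for the ON DELETE CASCADE clauses below to take effect.
import sqlite3

conn = sqlite3.connect(":memory:", check_same_thread=False, timeout=30.0)
conn.row_factory = sqlite3.Row            # rows become dict-like: row["hash"]
conn.execute("PRAGMA journal_mode=WAL")   # not applicable to :memory:, shown for completeness
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA cache_size=-64000")
conn.execute("PRAGMA foreign_keys=ON")
print(conn.execute("PRAGMA foreign_keys").fetchone()[0])  # -> 1
conn.close()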
TEXT NOT NULL, - vector_index INTEGER, - confidence REAL DEFAULT 1.0, - vector_state TEXT DEFAULT 'none', - vector_updated_at REAL, - vector_error TEXT, - vector_retry_count INTEGER DEFAULT 0, - created_at REAL, - source_paragraph TEXT, - metadata TEXT, - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_inactive BOOLEAN DEFAULT 0, - inactive_since REAL, - is_pinned BOOLEAN DEFAULT 0, - protected_until REAL, - last_reinforced REAL, - deleted_at REAL - ) - """) - - # 32位哈希别名映射(用于 vNext 唯一解析) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relation_hash_aliases ( - alias32 TEXT PRIMARY KEY, - hash TEXT NOT NULL - ) - """) - - # Schema 版本 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - applied_at REAL NOT NULL - ) - """) - - # 三元组与段落的关联表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_relations ( - paragraph_hash TEXT NOT NULL, - relation_hash TEXT NOT NULL, - PRIMARY KEY (paragraph_hash, relation_hash), - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, - FOREIGN KEY (relation_hash) REFERENCES relations(hash) ON DELETE CASCADE - ) - """) - - # 实体与段落的关联表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_entities ( - paragraph_hash TEXT NOT NULL, - entity_hash TEXT NOT NULL, - mention_count INTEGER DEFAULT 1, - PRIMARY KEY (paragraph_hash, entity_hash), - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, - FOREIGN KEY (entity_hash) REFERENCES entities(hash) ON DELETE CASCADE - ) - """) - - # 创建索引 - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraphs_vector - ON paragraphs(vector_index) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_entities_vector - ON entities(vector_index) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_vector - ON relations(vector_index) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_subject - ON relations(subject) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_object - ON relations(object) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_entities_name - ON entities(name) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraphs_source - ON paragraphs(source) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraphs_deleted - ON paragraphs(is_deleted, deleted_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_entities_deleted - ON entities(is_deleted, deleted_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_inactive - ON relations(is_inactive, inactive_since) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_relations_protected - ON relations(is_pinned, protected_until) - """) - - # 人物画像开关表(按 stream_id + user_id 维度) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_switches ( - stream_id TEXT NOT NULL, - user_id TEXT NOT NULL, - enabled INTEGER NOT NULL DEFAULT 0, - updated_at REAL NOT NULL, - PRIMARY KEY (stream_id, user_id) - ) - """) - - # 人物画像快照表(版本化) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_snapshots ( - snapshot_id INTEGER PRIMARY KEY AUTOINCREMENT, - person_id TEXT NOT NULL, - profile_version INTEGER NOT NULL, - profile_text TEXT NOT NULL, - aliases_json TEXT, - relation_edges_json TEXT, - vector_evidence_json TEXT, - evidence_ids_json TEXT, - updated_at REAL NOT NULL, - expires_at REAL, - source_note TEXT, - UNIQUE(person_id, profile_version) - ) - """) - - # 已开启范围内的活跃人物集合 
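# --- Editorial sketch (not part of the patch): why the link tables declared above use
# --- ON DELETE CASCADE. With foreign_keys enabled on the connection, removing a paragraph
# --- row automatically removes its rows in the association table. Table definitions here
# --- are trimmed stand-ins for paragraphs / entities / paragraph_entities.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys=ON")
conn.executescript("""
    CREATE TABLE paragraphs (hash TEXT PRIMARY KEY, content TEXT);
    CREATE TABLE entities   (hash TEXT PRIMARY KEY, name TEXT);
    CREATE TABLE paragraph_entities (
        paragraph_hash TEXT NOT NULL,
        entity_hash    TEXT NOT NULL,
        PRIMARY KEY (paragraph_hash, entity_hash),
        FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE,
        FOREIGN KEY (entity_hash)    REFERENCES entities(hash)   ON DELETE CASCADE
    );
    INSERT INTO paragraphs VALUES ('p1', 'Alice met Bob');
    INSERT INTO entities   VALUES ('e1', 'alice');
    INSERT INTO paragraph_entities VALUES ('p1', 'e1');
""")
conn.execute("DELETE FROM paragraphs WHERE hash = 'p1'")
assert conn.execute("SELECT COUNT(*) FROM paragraph_entities").fetchone()[0] == 0
conn.close()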
- cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_active_persons ( - stream_id TEXT NOT NULL, - user_id TEXT NOT NULL, - person_id TEXT NOT NULL, - last_seen_at REAL NOT NULL, - PRIMARY KEY (stream_id, user_id, person_id) - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS person_profile_overrides ( - person_id TEXT PRIMARY KEY, - override_text TEXT NOT NULL, - updated_at REAL NOT NULL, - updated_by TEXT, - source TEXT - ) - """) - - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_switches_enabled - ON person_profile_switches(enabled) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_snapshots_person - ON person_profile_snapshots(person_id, updated_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_active_seen - ON person_profile_active_persons(last_seen_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_person_profile_overrides_updated - ON person_profile_overrides(updated_at DESC) - """) - - # Episode 情景记忆表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episodes ( - episode_id TEXT PRIMARY KEY, - source TEXT, - title TEXT NOT NULL, - summary TEXT NOT NULL, - event_time_start REAL, - event_time_end REAL, - time_granularity TEXT, - time_confidence REAL DEFAULT 1.0, - participants_json TEXT, - keywords_json TEXT, - evidence_ids_json TEXT, - paragraph_count INTEGER DEFAULT 0, - llm_confidence REAL DEFAULT 0.0, - segmentation_model TEXT, - segmentation_version TEXT, - created_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - - # Episode -> Paragraph 映射 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_paragraphs ( - episode_id TEXT NOT NULL, - paragraph_hash TEXT NOT NULL, - position INTEGER DEFAULT 0, - PRIMARY KEY (episode_id, paragraph_hash), - FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE - ) - """) - - # Episode 生成队列(异步) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( - paragraph_hash TEXT PRIMARY KEY, - source TEXT, - created_at REAL, - status TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( - source TEXT PRIMARY KEY, - status TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - reason TEXT, - requested_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end - ON episodes(source, event_time_end DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_updated_at - ON episodes(updated_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph - ON episode_paragraphs(paragraph_hash) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated - ON episode_pending_paragraphs(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created - ON episode_pending_paragraphs(source, created_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated - ON episode_rebuild_sources(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at - ON episode_rebuild_sources(updated_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS external_memory_refs ( - 
external_id TEXT PRIMARY KEY, - paragraph_hash TEXT NOT NULL, - source_type TEXT, - created_at REAL NOT NULL, - metadata_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph - ON external_memory_refs(paragraph_hash) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS memory_v5_operations ( - operation_id TEXT PRIMARY KEY, - action TEXT NOT NULL, - target TEXT, - reason TEXT, - updated_by TEXT, - created_at REAL NOT NULL, - resolved_hashes_json TEXT, - result_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created - ON memory_v5_operations(created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operations ( - operation_id TEXT PRIMARY KEY, - mode TEXT NOT NULL, - selector TEXT, - reason TEXT, - requested_by TEXT, - status TEXT NOT NULL, - created_at REAL NOT NULL, - restored_at REAL, - summary_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_created - ON delete_operations(created_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_mode - ON delete_operations(mode, created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operation_items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - operation_id TEXT NOT NULL, - item_type TEXT NOT NULL, - item_hash TEXT, - item_key TEXT, - payload_json TEXT, - created_at REAL NOT NULL, - FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation - ON delete_operation_items(operation_id, id ASC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash - ON delete_operation_items(item_hash) - """) - # 新版 schema 包含完整字段,直接写入版本信息 - cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp())) - self._conn.commit() - logger.debug("数据库表结构初始化完成") - - def _migrate_schema(self) -> None: - """执行数据库schema迁移""" - cursor = self._conn.cursor() - - # vNext 关键表兜底:历史库可能缺失,需在迁移阶段主动补齐。 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relation_hash_aliases ( - alias32 TEXT PRIMARY KEY, - hash TEXT NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, - applied_at REAL NOT NULL - ) - """) - - # Episode MVP 表结构补齐 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episodes ( - episode_id TEXT PRIMARY KEY, - source TEXT, - title TEXT NOT NULL, - summary TEXT NOT NULL, - event_time_start REAL, - event_time_end REAL, - time_granularity TEXT, - time_confidence REAL DEFAULT 1.0, - participants_json TEXT, - keywords_json TEXT, - evidence_ids_json TEXT, - paragraph_count INTEGER DEFAULT 0, - llm_confidence REAL DEFAULT 0.0, - segmentation_model TEXT, - segmentation_version TEXT, - created_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_paragraphs ( - episode_id TEXT NOT NULL, - paragraph_hash TEXT NOT NULL, - position INTEGER DEFAULT 0, - PRIMARY KEY (episode_id, paragraph_hash), - FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( - paragraph_hash TEXT PRIMARY KEY, - source TEXT, - created_at REAL, - status 
TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( - source TEXT PRIMARY KEY, - status TEXT DEFAULT 'pending', - retry_count INTEGER DEFAULT 0, - last_error TEXT, - reason TEXT, - requested_at REAL NOT NULL, - updated_at REAL NOT NULL - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end - ON episodes(source, event_time_end DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episodes_updated_at - ON episodes(updated_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph - ON episode_paragraphs(paragraph_hash) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated - ON episode_pending_paragraphs(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created - ON episode_pending_paragraphs(source, created_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated - ON episode_rebuild_sources(status, updated_at) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at - ON episode_rebuild_sources(updated_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS external_memory_refs ( - external_id TEXT PRIMARY KEY, - paragraph_hash TEXT NOT NULL, - source_type TEXT, - created_at REAL NOT NULL, - metadata_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph - ON external_memory_refs(paragraph_hash) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS memory_v5_operations ( - operation_id TEXT PRIMARY KEY, - action TEXT NOT NULL, - target TEXT, - reason TEXT, - updated_by TEXT, - created_at REAL NOT NULL, - resolved_hashes_json TEXT, - result_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_memory_v5_operations_created - ON memory_v5_operations(created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operations ( - operation_id TEXT PRIMARY KEY, - mode TEXT NOT NULL, - selector TEXT, - reason TEXT, - requested_by TEXT, - status TEXT NOT NULL, - created_at REAL NOT NULL, - restored_at REAL, - summary_json TEXT - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_created - ON delete_operations(created_at DESC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operations_mode - ON delete_operations(mode, created_at DESC) - """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS delete_operation_items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - operation_id TEXT NOT NULL, - item_type TEXT NOT NULL, - item_hash TEXT, - item_key TEXT, - payload_json TEXT, - created_at REAL NOT NULL, - FOREIGN KEY (operation_id) REFERENCES delete_operations(operation_id) ON DELETE CASCADE - ) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_operation - ON delete_operation_items(operation_id, id ASC) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_delete_operation_items_hash - ON delete_operation_items(item_hash) - """) - - # 检查paragraphs表是否有knowledge_type列 - cursor.execute("PRAGMA table_info(paragraphs)") - columns = [row[1] for row in cursor.fetchall()] - - if "knowledge_type" not in columns: - logger.info("检测到旧版schema,正在迁移添加knowledge_type字段...") - try: - cursor.execute(""" - ALTER TABLE paragraphs - ADD COLUMN knowledge_type TEXT 
DEFAULT 'mixed' - """) - self._conn.commit() - logger.info("Schema迁移完成:已添加knowledge_type字段") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败(可能已存在): {e}") - - # 问题2: 时序字段迁移 - cursor.execute("PRAGMA table_info(paragraphs)") - columns = [row[1] for row in cursor.fetchall()] - temporal_columns = { - "event_time": "ALTER TABLE paragraphs ADD COLUMN event_time REAL", - "event_time_start": "ALTER TABLE paragraphs ADD COLUMN event_time_start REAL", - "event_time_end": "ALTER TABLE paragraphs ADD COLUMN event_time_end REAL", - "time_granularity": "ALTER TABLE paragraphs ADD COLUMN time_granularity TEXT", - "time_confidence": "ALTER TABLE paragraphs ADD COLUMN time_confidence REAL DEFAULT 1.0", - } - for col, sql in temporal_columns.items(): - if col not in columns: - try: - cursor.execute(sql) - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败({col}): {e}") - - # 时序索引(仅在列存在时创建,兼容旧库迁移) - self._create_temporal_indexes_if_ready() - self._conn.commit() - - # 检查paragraphs表是否有is_permanent列 - cursor.execute("PRAGMA table_info(paragraphs)") - columns = [row[1] for row in cursor.fetchall()] - - if "is_permanent" not in columns: - logger.info("正在迁移: 添加记忆动态字段...") - try: - # 段落表 - cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_permanent BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE paragraphs ADD COLUMN last_accessed REAL") - cursor.execute("ALTER TABLE paragraphs ADD COLUMN access_count INTEGER DEFAULT 0") - - # 关系表 - cursor.execute("ALTER TABLE relations ADD COLUMN is_permanent BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE relations ADD COLUMN last_accessed REAL") - cursor.execute("ALTER TABLE relations ADD COLUMN access_count INTEGER DEFAULT 0") - - self._conn.commit() - logger.info("Schema迁移完成:已添加记忆动态字段") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败: {e}") - - # 检查relations表是否有is_inactive列 (V5 Memory System) - cursor.execute("PRAGMA table_info(relations)") - columns = [row[1] for row in cursor.fetchall()] - - if "is_inactive" not in columns: - logger.info("正在迁移: 添加V5记忆动态字段 (inactive, protected)...") - try: - # 关系表 V5 新增字段 - cursor.execute("ALTER TABLE relations ADD COLUMN is_inactive BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE relations ADD COLUMN inactive_since REAL") - cursor.execute("ALTER TABLE relations ADD COLUMN is_pinned BOOLEAN DEFAULT 0") - cursor.execute("ALTER TABLE relations ADD COLUMN protected_until REAL") - cursor.execute("ALTER TABLE relations ADD COLUMN last_reinforced REAL") - - # 为回收站创建 deleted_relations 表 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS deleted_relations ( - hash TEXT PRIMARY KEY, - subject TEXT NOT NULL, - predicate TEXT NOT NULL, - object TEXT NOT NULL, - vector_index INTEGER, - confidence REAL DEFAULT 1.0, - vector_state TEXT DEFAULT 'none', - vector_updated_at REAL, - vector_error TEXT, - vector_retry_count INTEGER DEFAULT 0, - created_at REAL, - source_paragraph TEXT, - metadata TEXT, - is_permanent BOOLEAN DEFAULT 0, - last_accessed REAL, - access_count INTEGER DEFAULT 0, - is_inactive BOOLEAN DEFAULT 0, - inactive_since REAL, - is_pinned BOOLEAN DEFAULT 0, - protected_until REAL, - last_reinforced REAL, - deleted_at REAL -- 用于记录删除时间的额外列 - ) - """) - - self._conn.commit() - logger.info("Schema迁移完成:已添加V5记忆动态字段及回收站表") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (V5): {e}") - - # 关系向量状态字段迁移 - cursor.execute("PRAGMA table_info(relations)") - relation_columns = {row[1] for row in cursor.fetchall()} - relation_vector_columns = { - "vector_state": "ALTER 
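# --- Editorial sketch (not part of the patch): the additive migration idiom used throughout
# --- _migrate_schema -- inspect PRAGMA table_info and only ALTER when the column is missing,
# --- so re-running the migration stays idempotent. ensure_column is a hypothetical helper.
import sqlite3

def ensure_column(conn: sqlite3.Connection, table: str, column: str, ddl: str) -> bool:
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    if column in existing:
        return False
    conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")
    return True

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE paragraphs (hash TEXT PRIMARY KEY, content TEXT)")
assert ensure_column(conn, "paragraphs", "knowledge_type", "knowledge_type TEXT DEFAULT 'mixed'")
assert not ensure_column(conn, "paragraphs", "knowledge_type", "knowledge_type TEXT DEFAULT 'mixed'")
conn.close()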
TABLE relations ADD COLUMN vector_state TEXT DEFAULT 'none'", - "vector_updated_at": "ALTER TABLE relations ADD COLUMN vector_updated_at REAL", - "vector_error": "ALTER TABLE relations ADD COLUMN vector_error TEXT", - "vector_retry_count": "ALTER TABLE relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", - } - for col, sql in relation_vector_columns.items(): - if col not in relation_columns: - try: - cursor.execute(sql) - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (relations.{col}): {e}") - - # 回收站同步字段迁移(用于 restore 保留向量状态) - cursor.execute("PRAGMA table_info(deleted_relations)") - deleted_relation_columns = {row[1] for row in cursor.fetchall()} - deleted_relation_vector_columns = { - "vector_state": "ALTER TABLE deleted_relations ADD COLUMN vector_state TEXT DEFAULT 'none'", - "vector_updated_at": "ALTER TABLE deleted_relations ADD COLUMN vector_updated_at REAL", - "vector_error": "ALTER TABLE deleted_relations ADD COLUMN vector_error TEXT", - "vector_retry_count": "ALTER TABLE deleted_relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", - } - for col, sql in deleted_relation_vector_columns.items(): - if col not in deleted_relation_columns: - try: - cursor.execute(sql) - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (deleted_relations.{col}): {e}") - - # 检查 entities 表是否有 is_deleted 列 (Soft Delete System) - cursor.execute("PRAGMA table_info(entities)") - columns = [row[1] for row in cursor.fetchall()] - - if "is_deleted" not in columns: - logger.info("正在迁移: 添加软删除字段 (Soft Delete)...") - try: - # 实体表 - cursor.execute("ALTER TABLE entities ADD COLUMN is_deleted INTEGER DEFAULT 0") - cursor.execute("ALTER TABLE entities ADD COLUMN deleted_at REAL") - - # 段落表 - cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_deleted INTEGER DEFAULT 0") - cursor.execute("ALTER TABLE paragraphs ADD COLUMN deleted_at REAL") - - self._conn.commit() - logger.info("Schema迁移完成:已添加软删除字段") - except sqlite3.OperationalError as e: - logger.warning(f"Schema迁移失败 (Soft Delete): {e}") - - # 数据修复: 检查是否存在 source/vector_index 列错位的情况 - # 症状: vector_index (本应是int) 变成了文件名字符串, source (本应是文件名) 变成了类型字符串 - try: - cursor.execute(""" - SELECT count(*) FROM paragraphs - WHERE typeof(vector_index) = 'text' - AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') - """) - count = cursor.fetchone()[0] - if count > 0: - logger.warning(f"检测到 {count} 条数据存在列错位(文件名误存入vector_index),正在自动修复...") - cursor.execute(""" - UPDATE paragraphs - SET - knowledge_type = source, - source = vector_index, - vector_index = NULL - WHERE typeof(vector_index) = 'text' - AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') - """) - self._conn.commit() - logger.info(f"自动修复完成: 已校正 {cursor.rowcount} 条数据") - except Exception as e: - logger.error(f"数据自动修复失败: {e}") - - def _create_temporal_indexes_if_ready(self) -> None: - """ - 仅当时序列已存在时创建索引。 - - 旧库升级时,_initialize_tables 不能提前对不存在的列建索引; - 因此统一在迁移阶段按列存在性安全创建。 - """ - cursor = self._conn.cursor() - cursor.execute("PRAGMA table_info(paragraphs)") - columns = {row[1] for row in cursor.fetchall()} - - if "event_time" in columns: - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_time ON paragraphs(event_time)" - ) - if "event_time_start" in columns: - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_start ON paragraphs(event_time_start)" - ) - if "event_time_end" in columns: - cursor.execute( - "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_end ON paragraphs(event_time_end)" - ) - - def 
run_legacy_migration_for_vnext(self) -> Dict[str, Any]: - """ - 离线迁移入口: - - 复用旧迁移逻辑补齐历史库字段 - - 重建 relation 32位别名 - - 归一化历史 knowledge_type - - 写入 vNext schema 版本 - """ - self._migrate_schema() - alias_result = self.rebuild_relation_hash_aliases() - knowledge_type_result = self.normalize_paragraph_knowledge_types() - self.set_schema_version(SCHEMA_VERSION) - return { - "schema_version": SCHEMA_VERSION, - "alias_result": alias_result, - "knowledge_type_result": knowledge_type_result, - } - - def list_invalid_paragraph_knowledge_types(self) -> List[str]: - """列出当前库中不合法的段落 knowledge_type。""" - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT DISTINCT knowledge_type - FROM paragraphs - WHERE knowledge_type IS NULL - OR TRIM(COALESCE(knowledge_type, '')) = '' - OR LOWER(TRIM(knowledge_type)) NOT IN ({placeholders}) - ORDER BY knowledge_type - """.format(placeholders=", ".join("?" for _ in allowed_knowledge_type_values())), - tuple(allowed_knowledge_type_values()), - ) - invalid: List[str] = [] - for row in cursor.fetchall(): - raw = row[0] - invalid.append(str(raw) if raw is not None else "") - return invalid - - def normalize_paragraph_knowledge_types(self) -> Dict[str, Any]: - """将历史非法 knowledge_type 归一化为合法值。""" - - cursor = self._conn.cursor() - cursor.execute("SELECT hash, content, knowledge_type FROM paragraphs") - rows = cursor.fetchall() - - normalized_count = 0 - normalized_map: Dict[str, int] = {} - invalid_before: List[str] = [] - invalid_seen = set() - - for row in rows: - paragraph_hash = str(row["hash"]) - content = str(row["content"] or "") - raw_value = row["knowledge_type"] - try: - validate_stored_knowledge_type(raw_value) - continue - except ValueError: - raw_text = str(raw_value) if raw_value is not None else "" - if raw_text not in invalid_seen: - invalid_seen.add(raw_text) - invalid_before.append(raw_text) - - normalized_type = resolve_stored_knowledge_type( - raw_value, - content=content, - allow_legacy=True, - unknown_fallback=KnowledgeType.MIXED, - ) - cursor.execute( - "UPDATE paragraphs SET knowledge_type = ? 
WHERE hash = ?", - (normalized_type.value, paragraph_hash), - ) - normalized_count += 1 - normalized_map[normalized_type.value] = normalized_map.get(normalized_type.value, 0) + 1 - - self._conn.commit() - return { - "normalized": normalized_count, - "invalid_before": sorted(invalid_before), - "normalized_to": normalized_map, - } - - def _resolve_conn(self, conn: Optional[sqlite3.Connection] = None) -> sqlite3.Connection: - """解析可用连接。""" - resolved = conn or self._conn - if resolved is None: - raise RuntimeError("MetadataStore 未连接数据库") - return resolved - - def get_db_path(self) -> Path: - """获取 SQLite 数据库文件路径。""" - if self._db_path is not None: - return self._db_path - if self.data_dir is None: - raise RuntimeError("MetadataStore 未配置 data_dir") - return Path(self.data_dir) / self.db_name - - def ensure_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """ - 确保 FTS5 schema 存在(幂等)。 - - 采用 external-content 方式,不在 FTS 表重复存储正文。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute(""" - CREATE VIRTUAL TABLE IF NOT EXISTS paragraphs_fts - USING fts5( - content, - content='paragraphs', - content_rowid='rowid', - tokenize='unicode61' - ) - """) - - # insert trigger - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS paragraphs_ai - AFTER INSERT ON paragraphs - BEGIN - INSERT INTO paragraphs_fts(rowid, content) - VALUES (new.rowid, new.content); - END - """) - - # delete trigger - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS paragraphs_ad - AFTER DELETE ON paragraphs - BEGIN - INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) - VALUES ('delete', old.rowid, old.content); - END - """) - - # update trigger - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS paragraphs_au - AFTER UPDATE OF content ON paragraphs - BEGIN - INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) - VALUES ('delete', old.rowid, old.content); - INSERT INTO paragraphs_fts(rowid, content) - VALUES (new.rowid, new.content); - END - """) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS5 schema 创建失败(可能不支持 FTS5): {e}") - c.rollback() - return False - - def ensure_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """ - 确保 FTS 索引已回填。 - - 当历史数据存在但 FTS 表为空/不一致时执行 rebuild。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute("SELECT COUNT(1) AS n FROM paragraphs") - para_count = int(cur.fetchone()[0]) - cur.execute("SELECT COUNT(1) AS n FROM paragraphs_fts") - fts_count = int(cur.fetchone()[0]) - - if para_count > 0 and fts_count != para_count: - cur.execute("INSERT INTO paragraphs_fts(paragraphs_fts) VALUES ('rebuild')") - c.commit() - logger.info(f"FTS 回填完成: paragraphs={para_count}, fts={para_count}") - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS 回填失败: {e}") - c.rollback() - return False - - def ensure_relations_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """ - 确保关系 FTS5 schema 存在(幂等)。 - - 注意:relations 表没有 content 列,因此使用独立 FTS 表并通过触发器同步。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute(""" - CREATE VIRTUAL TABLE IF NOT EXISTS relations_fts - USING fts5( - relation_hash UNINDEXED, - content, - tokenize='unicode61' - ) - """) - - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS relations_ai - AFTER INSERT ON relations - BEGIN - INSERT INTO relations_fts(relation_hash, content) - VALUES ( - new.hash, - COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') - ); - END - 
""") - - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS relations_ad - AFTER DELETE ON relations - BEGIN - DELETE FROM relations_fts WHERE relation_hash = old.hash; - END - """) - - cur.execute(""" - CREATE TRIGGER IF NOT EXISTS relations_au - AFTER UPDATE OF subject, predicate, object ON relations - BEGIN - DELETE FROM relations_fts WHERE relation_hash = new.hash; - INSERT INTO relations_fts(relation_hash, content) - VALUES ( - new.hash, - COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') - ); - END - """) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"relations FTS5 schema 创建失败(可能不支持 FTS5): {e}") - c.rollback() - return False - - def ensure_relations_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """确保关系 FTS 索引已回填。""" - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute("SELECT COUNT(1) AS n FROM relations") - rel_count = int(cur.fetchone()[0]) - cur.execute("SELECT COUNT(1) AS n FROM relations_fts") - fts_count = int(cur.fetchone()[0]) - - if rel_count != fts_count: - cur.execute("DELETE FROM relations_fts") - cur.execute(""" - INSERT INTO relations_fts(relation_hash, content) - SELECT - r.hash, - COALESCE(r.subject, '') || ' ' || COALESCE(r.predicate, '') || ' ' || COALESCE(r.object, '') - FROM relations r - """) - c.commit() - logger.info(f"relations FTS 回填完成: relations={rel_count}, fts={rel_count}") - return True - except sqlite3.OperationalError as e: - logger.warning(f"relations FTS 回填失败: {e}") - c.rollback() - return False - - def ensure_paragraph_ngram_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: - """确保段落 ngram 倒排表存在。""" - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_ngrams ( - term TEXT NOT NULL, - paragraph_hash TEXT NOT NULL, - PRIMARY KEY (term, paragraph_hash), - FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE - ) - """) - cur.execute(""" - CREATE INDEX IF NOT EXISTS idx_paragraph_ngrams_hash - ON paragraph_ngrams(paragraph_hash) - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS paragraph_ngram_meta ( - key TEXT PRIMARY KEY, - value TEXT - ) - """) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"paragraph ngram schema 创建失败: {e}") - c.rollback() - return False - - @staticmethod - def _char_ngrams(text: str, n: int) -> List[str]: - compact = "".join(str(text or "").lower().split()) - if not compact: - return [] - if len(compact) < n: - return [compact] - return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] - - def ensure_paragraph_ngram_backfilled( - self, - n: int = 2, - conn: Optional[sqlite3.Connection] = None, - ) -> bool: - """ - 确保段落 ngram 倒排索引已回填。 - - 仅在 n 变化或文档数量变化时重建,避免每次加载都全量重建。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - n = max(1, int(n)) - try: - cur.execute("SELECT value FROM paragraph_ngram_meta WHERE key='ngram_n'") - row = cur.fetchone() - current_n = int(row[0]) if row and row[0] is not None else None - - cur.execute("SELECT COUNT(1) FROM paragraphs WHERE is_deleted IS NULL OR is_deleted = 0") - para_count = int(cur.fetchone()[0]) - cur.execute("SELECT COUNT(DISTINCT paragraph_hash) FROM paragraph_ngrams") - indexed_docs = int(cur.fetchone()[0]) - - need_rebuild = (current_n != n) or (para_count != indexed_docs) - if not need_rebuild: - return True - - cur.execute("DELETE FROM paragraph_ngrams") - cur.execute(""" - SELECT hash, content - FROM 
paragraphs - WHERE is_deleted IS NULL OR is_deleted = 0 - """) - rows = cur.fetchall() - - batch: List[Tuple[str, str]] = [] - batch_size = 2000 - for row in rows: - p_hash = str(row["hash"]) - terms = list(dict.fromkeys(self._char_ngrams(str(row["content"] or ""), n))) - for term in terms: - batch.append((term, p_hash)) - if len(batch) >= batch_size: - cur.executemany( - "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", - batch, - ) - batch.clear() - if batch: - cur.executemany( - "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", - batch, - ) - - cur.execute(""" - INSERT INTO paragraph_ngram_meta(key, value) VALUES('ngram_n', ?) - ON CONFLICT(key) DO UPDATE SET value=excluded.value - """, (str(n),)) - cur.execute(""" - INSERT INTO paragraph_ngram_meta(key, value) VALUES('paragraph_count', ?) - ON CONFLICT(key) DO UPDATE SET value=excluded.value - """, (str(para_count),)) - c.commit() - logger.info(f"paragraph ngram 回填完成: n={n}, paragraphs={para_count}") - return True - except Exception as e: - logger.warning(f"paragraph ngram 回填失败: {e}") - c.rollback() - return False - - def fts_upsert_paragraph( - self, - paragraph_hash: str, - conn: Optional[sqlite3.Connection] = None, - ) -> bool: - """ - 将段落写入(或覆盖)到 FTS 索引。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - "SELECT rowid, content FROM paragraphs WHERE hash = ?", - (paragraph_hash,), - ) - row = cur.fetchone() - if not row: - return False - rowid = int(row[0]) - content = str(row[1] or "") - cur.execute( - "INSERT OR REPLACE INTO paragraphs_fts(rowid, content) VALUES (?, ?)", - (rowid, content), - ) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS upsert 失败: {e}") - c.rollback() - return False - - def fts_delete_paragraph( - self, - paragraph_hash: str, - conn: Optional[sqlite3.Connection] = None, - ) -> bool: - """ - 从 FTS 索引删除段落。 - """ - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - "SELECT rowid, content FROM paragraphs WHERE hash = ?", - (paragraph_hash,), - ) - row = cur.fetchone() - if not row: - return False - rowid = int(row[0]) - content = str(row[1] or "") - cur.execute( - "INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) VALUES ('delete', ?, ?)", - (rowid, content), - ) - c.commit() - return True - except sqlite3.OperationalError as e: - logger.warning(f"FTS delete 失败: {e}") - c.rollback() - return False - - def fts_search_bm25( - self, - match_query: str, - limit: int = 20, - max_doc_len: int = 2000, - conn: Optional[sqlite3.Connection] = None, - ) -> List[Dict[str, Any]]: - """ - 使用 FTS5 + bm25 执行全文检索。 - """ - if not match_query.strip(): - return [] - - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - """ - SELECT p.hash, p.content, bm25(paragraphs_fts) AS bm25_score - FROM paragraphs_fts - JOIN paragraphs p ON p.rowid = paragraphs_fts.rowid - WHERE paragraphs_fts MATCH ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - ORDER BY bm25_score ASC - LIMIT ? 
- """, - (match_query, max(1, int(limit))), - ) - rows = cur.fetchall() - results: List[Dict[str, Any]] = [] - for row in rows: - content = str(row["content"] or "") - if max_doc_len > 0: - content = content[:max_doc_len] - results.append( - { - "hash": row["hash"], - "content": content, - "bm25_score": float(row["bm25_score"]), - } - ) - return results - except sqlite3.OperationalError as e: - logger.warning(f"FTS 查询失败: {e}") - return [] - - def fts_search_relations_bm25( - self, - match_query: str, - limit: int = 20, - max_doc_len: int = 512, - conn: Optional[sqlite3.Connection] = None, - ) -> List[Dict[str, Any]]: - """使用 FTS5 + bm25 执行关系全文检索。""" - if not match_query.strip(): - return [] - - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute( - """ - SELECT - r.hash, - r.subject, - r.predicate, - r.object, - bm25(relations_fts) AS bm25_score - FROM relations_fts - JOIN relations r ON r.hash = relations_fts.relation_hash - WHERE relations_fts MATCH ? - ORDER BY bm25_score ASC - LIMIT ? - """, - (match_query, max(1, int(limit))), - ) - rows = cur.fetchall() - out: List[Dict[str, Any]] = [] - for row in rows: - content = f"{row['subject']} {row['predicate']} {row['object']}" - if max_doc_len > 0: - content = content[:max_doc_len] - out.append( - { - "hash": row["hash"], - "subject": row["subject"], - "predicate": row["predicate"], - "object": row["object"], - "content": content, - "bm25_score": float(row["bm25_score"]), - } - ) - return out - except sqlite3.OperationalError as e: - logger.warning(f"relations FTS 查询失败: {e}") - return [] - - def ngram_search_paragraphs( - self, - tokens: List[str], - limit: int = 20, - max_doc_len: int = 2000, - conn: Optional[sqlite3.Connection] = None, - ) -> List[Dict[str, Any]]: - """按 ngram 倒排索引检索段落,避免 LIKE 全表扫描。""" - uniq = [t for t in dict.fromkeys([str(x).strip().lower() for x in tokens]) if t] - if not uniq: - return [] - - c = self._resolve_conn(conn) - cur = c.cursor() - placeholders = ",".join(["?"] * len(uniq)) - try: - cur.execute( - f""" - SELECT - p.hash, - p.content, - COUNT(*) AS hit_terms - FROM paragraph_ngrams ng - JOIN paragraphs p ON p.hash = ng.paragraph_hash - WHERE ng.term IN ({placeholders}) - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - GROUP BY p.hash, p.content - ORDER BY hit_terms DESC - LIMIT ? 
- """, - tuple(uniq + [max(1, int(limit))]), - ) - rows = cur.fetchall() - out: List[Dict[str, Any]] = [] - token_count = max(1, len(uniq)) - for row in rows: - hit_terms = int(row["hit_terms"]) - score = float(hit_terms / token_count) - content = str(row["content"] or "") - if max_doc_len > 0: - content = content[:max_doc_len] - out.append( - { - "hash": row["hash"], - "content": content, - "bm25_score": -score, - "fallback_score": score, - } - ) - return out - except sqlite3.OperationalError as e: - logger.warning(f"ngram 倒排查询失败: {e}") - return [] - - def fts_doc_count(self, conn: Optional[sqlite3.Connection] = None) -> int: - """获取 FTS 文档数量。""" - c = self._resolve_conn(conn) - cur = c.cursor() - try: - cur.execute("SELECT COUNT(1) FROM paragraphs_fts") - return int(cur.fetchone()[0]) - except sqlite3.OperationalError: - return 0 - - def shrink_memory(self, conn: Optional[sqlite3.Connection] = None) -> None: - """请求 SQLite 收缩当前连接缓存。""" - c = self._resolve_conn(conn) - try: - c.execute("PRAGMA shrink_memory") - except sqlite3.OperationalError: - pass - - @staticmethod - def _normalize_episode_source(source: Any) -> str: - return str(source or "").strip() - - def _dedupe_episode_sources(self, sources: List[Any]) -> List[str]: - normalized: List[str] = [] - seen = set() - for item in sources or []: - token = self._normalize_episode_source(item) - if not token or token in seen: - continue - seen.add(token) - normalized.append(token) - return normalized - - def _get_sources_for_paragraph_hashes( - self, - hashes: List[str], - *, - include_deleted: bool = True, - ) -> List[str]: - normalized_hashes = [ - str(item or "").strip() - for item in (hashes or []) - if str(item or "").strip() - ] - if not normalized_hashes: - return [] - - placeholders = ",".join(["?"] * len(normalized_hashes)) - conditions = ["hash IN ({})".format(placeholders), "TRIM(COALESCE(source, '')) != ''"] - if not include_deleted: - conditions.append("(is_deleted IS NULL OR is_deleted = 0)") - - cursor = self._conn.cursor() - cursor.execute( - f""" - SELECT DISTINCT TRIM(source) AS source - FROM paragraphs - WHERE {' AND '.join(conditions)} - """, - tuple(normalized_hashes), - ) - return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) - - def _enqueue_episode_source_rebuilds(self, sources: List[Any], reason: str = "") -> int: - normalized_sources = self._dedupe_episode_sources(sources) - if not normalized_sources: - return 0 - - now = datetime.now().timestamp() - reason_text = str(reason or "").strip()[:200] or None - cursor = self._conn.cursor() - cursor.executemany( - """ - INSERT INTO episode_rebuild_sources ( - source, status, retry_count, last_error, reason, requested_at, updated_at - ) VALUES (?, 'pending', 0, NULL, ?, ?, ?) 
- ON CONFLICT(source) DO UPDATE SET - status = 'pending', - last_error = NULL, - reason = excluded.reason, - requested_at = excluded.requested_at, - updated_at = excluded.updated_at - """, - [ - (source, reason_text, now, now) - for source in normalized_sources - ], - ) - self._conn.commit() - return len(normalized_sources) - - def add_paragraph( - self, - content: str, - vector_index: Optional[int] = None, - source: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - knowledge_type: str = "mixed", - time_meta: Optional[Dict[str, Any]] = None, - ) -> str: - """ - 添加段落 - - Args: - content: 段落内容 - vector_index: 向量索引 - source: 来源 - metadata: 额外元数据 - knowledge_type: 知识类型 (narrative/factual/quote/structured/mixed) - time_meta: 时间元信息 (event_time/event_time_start/event_time_end/...) - - Returns: - 段落哈希值 - """ - content_normalized = normalize_text(content) - hash_value = compute_hash(content_normalized) - resolved_knowledge_type = validate_stored_knowledge_type(knowledge_type) - - now = datetime.now().timestamp() - word_count = len(content_normalized.split()) - normalized_time = normalize_time_meta(time_meta) - - cursor = self._conn.cursor() - try: - cursor.execute(""" - INSERT INTO paragraphs - ( - hash, content, vector_index, created_at, updated_at, metadata, source, word_count, - event_time, event_time_start, event_time_end, time_granularity, time_confidence, - knowledge_type - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - hash_value, - content, - vector_index, - now, - now, - pickle.dumps(metadata or {}), - source, - word_count, - normalized_time.get("event_time"), - normalized_time.get("event_time_start"), - normalized_time.get("event_time_end"), - normalized_time.get("time_granularity"), - normalized_time.get("time_confidence", 1.0), - resolved_knowledge_type.value, - )) - self._conn.commit() - try: - self.enqueue_episode_source_rebuild( - source=source, - reason="paragraph_added", - ) - except Exception as e: - logger.warning(f"Episode source 重建入队失败: hash={hash_value[:16]}..., err={e}") - logger.debug( - f"添加段落: hash={hash_value[:16]}..., words={word_count}, type={resolved_knowledge_type.value}" - ) - return hash_value - except sqlite3.IntegrityError: - logger.debug(f"段落已存在: {hash_value[:16]}...") - # 尝试复活 - self.revive_if_deleted(paragraph_hashes=[hash_value]) - return hash_value - - def _canonicalize_name(self, name: str) -> str: - """ - 规范化名称 (统一小写并去除首尾空格) - - Args: - name: 原始名称 - - Returns: - 规范化后的名称 - """ - if not name: - return "" - return name.strip().lower() - - def add_entity( - self, - name: str, - vector_index: Optional[int] = None, - source_paragraph: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> str: - """ - 添加实体 - - Args: - name: 实体名称 - vector_index: 向量索引 - source_paragraph: 来源段落哈希 (如果提供,将建立关联) - metadata: 额外元数据 - - Returns: - 实体哈希值 - """ - # 1. 规范化名称 - name_normalized = self._canonicalize_name(name) - if not name_normalized: - raise ValueError("Entity name cannot be empty") - - hash_value = compute_hash(name_normalized) - now = datetime.now().timestamp() - - cursor = self._conn.cursor() - - # 2. 插入实体 (INSERT OR IGNORE) - # 注意:这里我们保留原有的 name 字段存储,可以是 display name, - # 但 hash 必须由 canonical name 生成。 - # 如果实体已存在,我们其实不一定要更新 name (保留第一次的 display name 往往更好) - # 或者我们也可以选择不作为唯一键冲突,而是逻辑判断。 - # 考虑到 entities.hash 是主键,entities.name 是 UNIQUE。 - # 如果 name 大小写不同但 hash 相同 (冲突),或者 name 不同但 canonical name 相同? 
- # 由于 hash 是由 canonical name 算出来的,所以 hash 相同意味着 canonical name 相同。 - # 如果 db 中已存在的 name 是 "Apple",新来的 name 是 "apple",它们 canonical name 都是 "apple",hash 一样。 - # 此时 INSERT OR IGNORE 会忽略。 - - try: - cursor.execute(""" - INSERT INTO entities - (hash, name, vector_index, appearance_count, created_at, metadata) - VALUES (?, ?, ?, 1, ?, ?) - """, ( - hash_value, - name, - vector_index, - now, - pickle.dumps(metadata or {}), - )) - - logger.debug(f"添加实体: {name} ({hash_value[:8]})") - self._conn.commit() - - # 3. 建立来源关联 - if source_paragraph: - self.link_paragraph_entity(source_paragraph, hash_value) - - return hash_value - - except sqlite3.IntegrityError: - # 实体已存在 - # 1. 尝试复活 (自动复活) - self.revive_if_deleted(entity_hashes=[hash_value]) - - # 2. 更新计数 - cursor.execute(""" - UPDATE entities - SET appearance_count = appearance_count + 1 - WHERE hash = ? - """, (hash_value,)) - self._conn.commit() - - logger.debug(f"实体已存在(复活/计数+1): {name}") - - # 3. 建立来源关联 - if source_paragraph: - self.link_paragraph_entity(source_paragraph, hash_value) - - return hash_value - - def add_relation( - self, - subject: str, - predicate: str, - obj: str, - vector_index: Optional[int] = None, - confidence: float = 1.0, - source_paragraph: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - ) -> str: - """ - 添加关系 - - Args: - subject: 主语 - predicate: 谓语 - obj: 宾语 - vector_index: 向量索引 - confidence: 置信度 - source_paragraph: 来源段落哈希 - metadata: 额外元数据 - - Returns: - 关系哈希值 - """ - # 1. 规范化输入 - s_canon = self._canonicalize_name(subject) - p_canon = self._canonicalize_name(predicate) - o_canon = self._canonicalize_name(obj) - - if not all([s_canon, p_canon, o_canon]): - raise ValueError("Relation components cannot be empty") - - # 2. 计算组合哈希 - # 公式: md5(s|p|o) - relation_key = f"{s_canon}|{p_canon}|{o_canon}" - hash_value = compute_hash(relation_key) - - now = datetime.now().timestamp() - - # 记录原始 display name 到 metadata (如果需要的话,或者直接存到 DB 字段) - # 这里我们直接存入 subject, predicate, object 字段, - # 注意:如果 DB 里已存在该关系 (hash 相同),则不会更新这些字段,保留第一次的拼写。 - - cursor = self._conn.cursor() - try: - cursor.execute(""" - INSERT OR IGNORE INTO relations - (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - hash_value, - subject, # 原始拼写 - predicate, - obj, - vector_index, - confidence, - now, - source_paragraph, # 这里的 source_paragraph 仅作为 "首次发现地" 记录,也可留空 - pickle.dumps(metadata or {}), - )) - self._conn.commit() - - if cursor.rowcount > 0: - logger.debug(f"添加关系: {subject} -{predicate}-> {obj}") - else: - logger.debug(f"关系已存在: {subject} -{predicate}-> {obj}") - - # 3. 建立来源关联 (幂等) - # 无论关系是新创建的还是已存在的,只要提供了 source_paragraph,都要建立连接 - if source_paragraph: - self.link_paragraph_relation(source_paragraph, hash_value) - - return hash_value - - except sqlite3.IntegrityError as e: - logger.warning(f"添加关系异常: {e}") - return hash_value - - def link_paragraph_relation( - self, - paragraph_hash: str, - relation_hash: str, - ) -> bool: - """ - 关联段落和关系 (幂等) - """ - cursor = self._conn.cursor() - try: - # 使用 INSERT OR IGNORE 避免重复报错 - cursor.execute(""" - INSERT OR IGNORE INTO paragraph_relations - (paragraph_hash, relation_hash) - VALUES (?, ?) 
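# --- Editorial sketch (not part of the patch): how add_relation above derives a stable hash.
# --- Components are canonicalized (strip + lower) and joined with '|' before hashing, so
# --- "Alice -likes-> Bob" and "alice -LIKES-> bob " dedupe to the same row. md5 is used here
# --- only because the inline comment documents the key as md5(s|p|o); the real compute_hash
# --- lives in core/utils/hash.py and may differ.
import hashlib

def relation_hash(subject: str, predicate: str, obj: str) -> str:
    canon = [part.strip().lower() for part in (subject, predicate, obj)]
    if not all(canon):
        raise ValueError("Relation components cannot be empty")
    return hashlib.md5("|".join(canon).encode("utf-8")).hexdigest()

assert relation_hash("Alice", "likes", "Bob") == relation_hash("alice ", "LIKES", " bob")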
- """, (paragraph_hash, relation_hash)) - self._conn.commit() - self._enqueue_episode_source_rebuilds( - self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), - reason="paragraph_relation_linked", - ) - return True - except sqlite3.IntegrityError: - return False - - def link_paragraph_entity( - self, - paragraph_hash: str, - entity_hash: str, - mention_count: int = 1, - ) -> bool: - """ - 关联段落和实体 (幂等) - """ - cursor = self._conn.cursor() - try: - # 首先尝试插入 - cursor.execute(""" - INSERT OR IGNORE INTO paragraph_entities - (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, (paragraph_hash, entity_hash, mention_count)) - - if cursor.rowcount == 0: - # 如果已存在 (IGNORE生效),则更新计数 - cursor.execute(""" - UPDATE paragraph_entities - SET mention_count = mention_count + ? - WHERE paragraph_hash = ? AND entity_hash = ? - """, (mention_count, paragraph_hash, entity_hash)) - - self._conn.commit() - self._enqueue_episode_source_rebuilds( - self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), - reason="paragraph_entity_linked", - ) - return True - except sqlite3.IntegrityError: - return False - - def get_paragraph(self, hash_value: str) -> Optional[Dict[str, Any]]: - """ - 获取段落 - - Args: - hash_value: 段落哈希 - - Returns: - 段落信息字典,不存在则返回None - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM paragraphs WHERE hash = ? - """, (hash_value,)) - row = cursor.fetchone() - - if row: - return self._row_to_dict(row, "paragraph") - return None - - def update_paragraph_time_meta( - self, - paragraph_hash: str, - time_meta: Dict[str, Any], - ) -> bool: - """ - 更新段落时间元信息。 - """ - normalized = normalize_time_meta(time_meta) - if not normalized: - return False - source_to_rebuild = self._get_sources_for_paragraph_hashes( - [paragraph_hash], - include_deleted=True, - ) - - updates: List[str] = [] - params: List[Any] = [] - for key in [ - "event_time", - "event_time_start", - "event_time_end", - "time_granularity", - "time_confidence", - ]: - if key in normalized: - updates.append(f"{key} = ?") - params.append(normalized[key]) - - if not updates: - return False - - updates.append("updated_at = ?") - params.append(datetime.now().timestamp()) - params.append(paragraph_hash) - - cursor = self._conn.cursor() - cursor.execute( - f"UPDATE paragraphs SET {', '.join(updates)} WHERE hash = ?", - tuple(params), - ) - self._conn.commit() - changed = cursor.rowcount > 0 - if changed: - self._enqueue_episode_source_rebuilds( - source_to_rebuild, - reason="paragraph_time_updated", - ) - return changed - - def query_paragraphs_temporal( - self, - start_ts: Optional[float] = None, - end_ts: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - limit: int = 100, - allow_created_fallback: bool = True, - ) -> List[Dict[str, Any]]: - """ - 查询时序命中的段落(区间相交语义)。 - """ - if limit <= 0: - return [] - - effective_start = "COALESCE(p.event_time_start, p.event_time, p.event_time_end" - effective_end = "COALESCE(p.event_time_end, p.event_time, p.event_time_start" - if allow_created_fallback: - effective_start += ", p.created_at)" - effective_end += ", p.created_at)" - else: - effective_start += ")" - effective_end += ")" - - conditions = ["(p.is_deleted IS NULL OR p.is_deleted = 0)"] - params: List[Any] = [] - - if source: - conditions.append("p.source = ?") - params.append(source) - - if person: - conditions.append( - """ - EXISTS ( - SELECT 1 - FROM paragraph_entities pe - JOIN entities e ON e.hash = pe.entity_hash - WHERE 
pe.paragraph_hash = p.hash - AND LOWER(e.name) LIKE ? - ) - """ - ) - params.append(f"%{str(person).strip().lower()}%") - - if start_ts is not None and end_ts is not None: - conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") - params.extend([start_ts, end_ts]) - elif start_ts is not None: - conditions.append(f"({effective_end} >= ?)") - params.append(start_ts) - elif end_ts is not None: - conditions.append(f"({effective_start} <= ?)") - params.append(end_ts) - - where_sql = " AND ".join(conditions) - sql = f""" - SELECT p.* - FROM paragraphs p - WHERE {where_sql} - ORDER BY {effective_end} DESC, p.updated_at DESC - LIMIT ? - """ - params.append(limit) - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params)) - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def get_entity(self, hash_value: str) -> Optional[Dict[str, Any]]: - """ - 获取实体 - - Args: - hash_value: 实体哈希 - - Returns: - 实体信息字典,不存在则返回None - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM entities WHERE hash = ? - """, (hash_value,)) - row = cursor.fetchone() - - if row: - return self._row_to_dict(row, "entity") - return None - - def get_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: - """ - 获取关系 - - Args: - hash_value: 关系哈希 - - Returns: - 关系信息字典,不存在则返回None - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM relations WHERE hash = ? - """, (hash_value,)) - row = cursor.fetchone() - - if row: - return self._row_to_dict(row, "relation") - return None - - def get_paragraph_relations(self, paragraph_hash: str) -> List[Dict[str, Any]]: - """ - 获取段落的所有关系 - - Args: - paragraph_hash: 段落哈希 - - Returns: - 关系列表 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT r.* FROM relations r - JOIN paragraph_relations pr ON r.hash = pr.relation_hash - WHERE pr.paragraph_hash = ? - """, (paragraph_hash,)) - - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def get_paragraph_entities(self, paragraph_hash: str) -> List[Dict[str, Any]]: - """ - 获取段落的所有实体 - - Args: - paragraph_hash: 段落哈希 - - Returns: - 实体列表 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT e.*, pe.mention_count - FROM entities e - JOIN paragraph_entities pe ON e.hash = pe.entity_hash - WHERE pe.paragraph_hash = ? - """, (paragraph_hash,)) - - return [self._row_to_dict(row, "entity") for row in cursor.fetchall()] - - def get_paragraphs_by_entity(self, entity_name: str) -> List[Dict[str, Any]]: - """ - 获取包含指定实体的所有段落 (自动处理规范化) - - Args: - entity_name: 实体名称 (支持任意大小写) - - Returns: - 段落列表 - """ - # 1. 计算规范化 Hash - name_canon = self._canonicalize_name(entity_name) - if not name_canon: - return [] - - entity_hash = compute_hash(name_canon) - - cursor = self._conn.cursor() - # 2. 直接使用 Hash 查询中间表,完全避开 Name 匹配 - cursor.execute(""" - SELECT p.* - FROM paragraphs p - JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash - WHERE pe.entity_hash = ? 
- """, (entity_hash,)) - - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def get_relations( - self, - subject: Optional[str] = None, - predicate: Optional[str] = None, - object: Optional[str] = None, - ) -> List[Dict[str, Any]]: - """ - 查询关系(大小写不敏感) - - Args: - subject: 主语(可选) - predicate: 谓语(可选) - object: 宾语(可选) - - Returns: - 关系列表 - """ - # 构建查询条件 - conditions = [] - params = [] - - if subject: - conditions.append("LOWER(subject) = ?") - params.append(self._canonicalize_name(subject)) - if predicate: - conditions.append("LOWER(predicate) = ?") - params.append(self._canonicalize_name(predicate)) - if object: - conditions.append("LOWER(object) = ?") - params.append(self._canonicalize_name(object)) - - sql = "SELECT * FROM relations" - if conditions: - sql += " WHERE " + " AND ".join(conditions) - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params)) - - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def get_all_triples(self) -> List[Tuple[str, str, str, str]]: - """ - 高效获取所有三元组 (subject, predicate, object, hash) - 直接返回元组,跳过字典转换和pickle反序列化,用于构建 V5 Map 缓存。 - """ - cursor = self._conn.cursor() - cursor.execute("SELECT subject, predicate, object, hash FROM relations") - return list(cursor.fetchall()) - - def get_paragraphs_by_relation(self, relation_hash: str) -> List[Dict[str, Any]]: - """ - 获取支持指定关系的所有段落 - - Args: - relation_hash: 关系哈希 - - Returns: - 段落列表 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT p.* - FROM paragraphs p - JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash - WHERE pr.relation_hash = ? - """, (relation_hash,)) - - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def get_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: - """ - 按来源获取段落 - - Args: - source: 来源标识符 - - Returns: - 段落列表 - """ - return self.query("SELECT * FROM paragraphs WHERE source = ?", (source,)) - - def get_all_sources(self) -> List[Dict[str, Any]]: - """ - 获取所有来源文件统计信息 - - Returns: - 来源列表 [{'source': 'name', 'count': int, 'last_updated': timestamp}] - """ - cursor = self._conn.cursor() - # 排除 source 为 NULL 或空的记录 - cursor.execute(""" - SELECT source, COUNT(*) as count, MAX(created_at) as last_updated - FROM paragraphs - WHERE source IS NOT NULL AND source != '' - GROUP BY source - ORDER BY last_updated DESC - """) - - results = [] - for row in cursor.fetchall(): - results.append({ - "source": row[0], - "count": row[1], - "last_updated": row[2] - }) - return results - - - def search_paragraphs_by_content(self, content_query: str) -> List[Dict[str, Any]]: - """按内容模糊搜索段落""" - cursor = self._conn.cursor() - cursor.execute(""" - SELECT * FROM paragraphs WHERE content LIKE ? - """, (f"%{content_query}%",)) - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def delete_paragraph(self, hash_value: str) -> bool: - """ - 删除段落(级联删除相关关联) - - Args: - hash_value: 段落哈希 - - Returns: - 是否成功删除 - """ - cursor = self._conn.cursor() - cursor.execute(""" - DELETE FROM paragraphs WHERE hash = ? - """, (hash_value,)) - self._conn.commit() - - deleted = cursor.rowcount > 0 - if deleted: - logger.info(f"删除段落: {hash_value[:16]}...") - - return deleted - - def delete_entity(self, hash_or_name: str) -> bool: - """ - 删除实体(级联删除相关关联) - 支持通过哈希值或名称删除 - - 注意:会同时删除所有引用该实体(作为主语或宾语)的关系 - """ - cursor = self._conn.cursor() - - # 1. 
解析实体信息 (获取 Name 和 Hash) - entity_name = None - entity_hash = None - - # 尝试作为 Hash 查询 - cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (hash_or_name,)) - row = cursor.fetchone() - if row: - entity_name = row[0] - entity_hash = row[1] - else: - # 尝试作为 Name 查询 (原始匹配) - cursor.execute("SELECT name, hash FROM entities WHERE name = ?", (hash_or_name,)) - row = cursor.fetchone() - if row: - entity_name = row[0] - entity_hash = row[1] - else: - # 最后的最后:尝试规范化名称 (Canonical) 查询,解决大小写或 WebUI 手动输入导致的不匹配 - name_canon = self._canonicalize_name(hash_or_name) - canon_hash = compute_hash(name_canon) - cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (canon_hash,)) - row = cursor.fetchone() - if row: - entity_name = row[0] - entity_hash = row[1] - - if not entity_name or not entity_hash: - logger.debug(f"删除实体请求跳过:未在元数据记录中找到 {hash_or_name}") - return False - - logger.info(f"开始删除实体: {entity_name} (Hash: {entity_hash[:8]}...)") - - try: - # 2. 查找相关关系 (Subject 或 Object 为该实体) - cursor.execute(""" - SELECT hash FROM relations - WHERE subject = ? OR object = ? - """, (entity_name, entity_name)) - - relation_hashes = [r[0] for r in cursor.fetchall()] - - if relation_hashes: - logger.info(f"发现 {len(relation_hashes)} 个相关关系,准备级联删除") - - # 3. 删除这些关系与段落的关联 - # SQLite 不支持直接 DELETE ... WHERE ... IN (...) 的列表参数,需要拼接占位符 - placeholders = ','.join(['?'] * len(relation_hashes)) - - cursor.execute(f""" - DELETE FROM paragraph_relations - WHERE relation_hash IN ({placeholders}) - """, relation_hashes) - - # 4. 删除关系本体 - cursor.execute(f""" - DELETE FROM relations - WHERE hash IN ({placeholders}) - """, relation_hashes) - - logger.info("相关关系已级联删除") - - # 5. 删除实体与段落的关联 - cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (entity_hash,)) - - # 6. 删除实体本体 - cursor.execute("DELETE FROM entities WHERE hash = ?", (entity_hash,)) - - self._conn.commit() - logger.info("实体删除完成") - return True - - except Exception as e: - logger.error(f"删除实体时发生错误: {e}") - self._conn.rollback() - return False - - def delete_relation(self, hash_value: str) -> bool: - """ - 删除关系(级联删除相关关联) - - Args: - hash_value: 关系哈希 - - Returns: - 是否成功删除 - """ - cursor = self._conn.cursor() - cursor.execute(""" - DELETE FROM relations WHERE hash = ? - """, (hash_value,)) - self._conn.commit() - - deleted = cursor.rowcount > 0 - if deleted: - logger.info(f"删除关系: {hash_value[:16]}...") - - return deleted - - def set_relation_vector_state( - self, - hash_value: str, - state: str, - error: Optional[str] = None, - bump_retry: bool = False, - ) -> bool: - """ - 更新关系向量状态。 - """ - state_norm = str(state or "").strip().lower() - if state_norm not in {"none", "pending", "ready", "failed"}: - raise ValueError(f"无效 vector_state: {state}") - - now = datetime.now().timestamp() - err_text = (str(error).strip() if error is not None else None) - if err_text: - err_text = err_text[:500] - clear_error = state_norm in {"none", "pending", "ready"} - - cursor = self._conn.cursor() - if bump_retry: - cursor.execute( - """ - UPDATE relations - SET vector_state = ?, - vector_updated_at = ?, - vector_error = ?, - vector_retry_count = COALESCE(vector_retry_count, 0) + 1 - WHERE hash = ? - """, - (state_norm, now, None if clear_error else err_text, hash_value), - ) - else: - cursor.execute( - """ - UPDATE relations - SET vector_state = ?, - vector_updated_at = ?, - vector_error = ? - WHERE hash = ? 
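# Illustrative sketch (not part of the plugin): the placeholder expansion that
# the cascade delete above relies on. sqlite3 cannot bind a Python list to a
# single "?", so the code builds "?,?,?" itself and passes the list as
# positional parameters.
import sqlite3

def delete_by_hashes(conn: sqlite3.Connection, hashes: list) -> int:
    if not hashes:
        return 0
    placeholders = ",".join("?" * len(hashes))  # e.g. "?,?,?" for three values
    cur = conn.execute(f"DELETE FROM relations WHERE hash IN ({placeholders})", hashes)
    return cur.rowcount

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE relations (hash TEXT PRIMARY KEY)")
conn.executemany("INSERT INTO relations VALUES (?)", [("a",), ("b",), ("c",)])
print(delete_by_hashes(conn, ["a", "c"]))  # 2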
- """, - (state_norm, now, None if clear_error else err_text, hash_value), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def list_relations_by_vector_state( - self, - states: List[str], - limit: int = 200, - max_retry: Optional[int] = None, - ) -> List[Dict[str, Any]]: - """ - 根据向量状态列出关系,用于回填任务。 - """ - normalized_states = [ - str(s or "").strip().lower() - for s in (states or []) - if str(s or "").strip() - ] - normalized_states = [ - s for s in normalized_states - if s in {"none", "pending", "ready", "failed"} - ] - if not normalized_states: - return [] - - placeholders = ",".join(["?"] * len(normalized_states)) - params: List[Any] = list(normalized_states) - sql = f""" - SELECT hash, subject, predicate, object, confidence, source_paragraph, - vector_state, vector_updated_at, vector_error, vector_retry_count, created_at - FROM relations - WHERE vector_state IN ({placeholders}) - """ - if max_retry is not None: - sql += " AND COALESCE(vector_retry_count, 0) < ?" - params.append(int(max_retry)) - sql += " ORDER BY COALESCE(vector_updated_at, created_at, 0) ASC LIMIT ?" - params.append(max(1, int(limit))) - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params)) - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def count_relations_by_vector_state(self) -> Dict[str, int]: - """ - 统计关系向量状态分布。 - """ - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT COALESCE(vector_state, 'none') AS state, COUNT(*) AS cnt - FROM relations - GROUP BY COALESCE(vector_state, 'none') - """ - ) - result: Dict[str, int] = {"none": 0, "pending": 0, "ready": 0, "failed": 0} - total = 0 - for row in cursor.fetchall(): - state = str(row["state"] or "none").lower() - count = int(row["cnt"] or 0) - if state not in result: - result[state] = 0 - result[state] += count - total += count - result["total"] = total - return result - - def update_vector_index( - self, - item_type: str, - hash_value: str, - vector_index: int, - ) -> bool: - """ - 更新向量索引 - - Args: - item_type: 类型(paragraph/entity/relation) - hash_value: 哈希值 - vector_index: 向量索引 - - Returns: - 是否成功更新 - """ - valid_types = ["paragraph", "entity", "relation"] - if item_type not in valid_types: - raise ValueError(f"无效的类型: {item_type}") - - table_map = { - "paragraph": "paragraphs", - "entity": "entities", - "relation": "relations", - } - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE {table_map[item_type]} - SET vector_index = ? - WHERE hash = ? - """, (vector_index, hash_value)) - self._conn.commit() - - return cursor.rowcount > 0 - - def set_permanence(self, hash_value: str, item_type: str, is_permanent: bool) -> bool: - """设置永久记忆标记""" - table_map = { - "paragraph": "paragraphs", - "relation": "relations", - } - if item_type not in table_map: - raise ValueError(f"类型 {item_type} 不支持设置永久性") - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE {table_map[item_type]} - SET is_permanent = ? - WHERE hash = ? 
- """, (1 if is_permanent else 0, hash_value)) - self._conn.commit() - - if cursor.rowcount > 0: - logger.debug(f"设置永久记忆: {item_type}/{hash_value[:8]} -> {is_permanent}") - return True - return False - - def record_access(self, hash_value: str, item_type: str) -> bool: - """记录访问(更新时间和次数)""" - table_map = { - "paragraph": "paragraphs", - "relation": "relations", - } - if item_type not in table_map: - return False - - now = datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE {table_map[item_type]} - SET last_accessed = ?, access_count = access_count + 1 - WHERE hash = ? - """, (now, hash_value)) - self._conn.commit() - return cursor.rowcount > 0 - - def query( - self, - sql: str, - params: Optional[Tuple] = None, - ) -> List[Dict[str, Any]]: - """ - 执行自定义查询 - - Args: - sql: SQL语句 - params: 参数 - - Returns: - 查询结果列表 - """ - cursor = self._conn.cursor() - if params: - cursor.execute(sql, params) - else: - cursor.execute(sql) - - return [dict(row) for row in cursor.fetchall()] - - def get_external_memory_ref(self, external_id: str) -> Optional[Dict[str, Any]]: - """按 external_id 查询外部记忆映射。""" - token = str(external_id or "").strip() - if not token: - return None - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT external_id, paragraph_hash, source_type, created_at, metadata_json - FROM external_memory_refs - WHERE external_id = ? - LIMIT 1 - """, - (token,), - ) - row = cursor.fetchone() - if row is None: - return None - - payload = dict(row) - raw_metadata = payload.get("metadata_json") - if raw_metadata: - try: - payload["metadata"] = json.loads(raw_metadata) - except Exception: - payload["metadata"] = {} - else: - payload["metadata"] = {} - return payload - - def upsert_external_memory_ref( - self, - *, - external_id: str, - paragraph_hash: str, - source_type: str = "", - metadata: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - """注册 external_id 到段落哈希的幂等映射。""" - external_token = str(external_id or "").strip() - paragraph_token = str(paragraph_hash or "").strip() - if not external_token: - raise ValueError("external_id 不能为空") - if not paragraph_token: - raise ValueError("paragraph_hash 不能为空") - - now = datetime.now().timestamp() - metadata_json = json.dumps(metadata or {}, ensure_ascii=False) - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO external_memory_refs ( - external_id, paragraph_hash, source_type, created_at, metadata_json - ) - VALUES (?, ?, ?, ?, ?) 
- ON CONFLICT(external_id) DO UPDATE SET - paragraph_hash = excluded.paragraph_hash, - source_type = excluded.source_type, - metadata_json = excluded.metadata_json - """, - ( - external_token, - paragraph_token, - str(source_type or "").strip() or None, - now, - metadata_json, - ), - ) - self._conn.commit() - return self.get_external_memory_ref(external_token) or { - "external_id": external_token, - "paragraph_hash": paragraph_token, - "source_type": str(source_type or "").strip(), - "created_at": now, - "metadata": metadata or {}, - } - - @staticmethod - def _json_dumps(value: Any) -> str: - return json.dumps(value, ensure_ascii=False, sort_keys=True) - - @staticmethod - def _json_loads(value: Any, default: Any) -> Any: - if value in {None, ""}: - return default - try: - return json.loads(value) - except Exception: - return default - - def list_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]: - hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] - if not hashes: - return [] - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - cursor.execute( - f""" - SELECT external_id, paragraph_hash, source_type, created_at, metadata_json - FROM external_memory_refs - WHERE paragraph_hash IN ({placeholders}) - ORDER BY created_at ASC, external_id ASC - """, - tuple(hashes), - ) - items: List[Dict[str, Any]] = [] - for row in cursor.fetchall(): - payload = dict(row) - payload["metadata"] = self._json_loads(payload.get("metadata_json"), {}) - items.append(payload) - return items - - def delete_external_memory_refs_by_paragraphs(self, paragraph_hashes: List[str]) -> List[Dict[str, Any]]: - items = self.list_external_memory_refs_by_paragraphs(paragraph_hashes) - hashes = [str(item or "").strip() for item in (paragraph_hashes or []) if str(item or "").strip()] - if not hashes: - return items - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - cursor.execute( - f"DELETE FROM external_memory_refs WHERE paragraph_hash IN ({placeholders})", - tuple(hashes), - ) - self._conn.commit() - return items - - def restore_external_memory_refs(self, refs: List[Dict[str, Any]]) -> int: - count = 0 - for item in refs or []: - external_id = str(item.get("external_id", "") or "").strip() - paragraph_hash = str(item.get("paragraph_hash", "") or "").strip() - if not external_id or not paragraph_hash: - continue - created_at = float(item.get("created_at") or datetime.now().timestamp()) - metadata_json = self._json_dumps(item.get("metadata") or {}) - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO external_memory_refs ( - external_id, paragraph_hash, source_type, created_at, metadata_json - ) - VALUES (?, ?, ?, ?, ?) 
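# Illustrative sketch (not part of the plugin): the SQLite upsert idiom used by
# upsert_external_memory_ref and restore_external_memory_refs above -
# INSERT ... ON CONFLICT(key) DO UPDATE SET col = excluded.col (requires
# SQLite >= 3.24). Toy schema for demonstration only.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE refs (external_id TEXT PRIMARY KEY, paragraph_hash TEXT, created_at REAL)"
)

def upsert_ref(external_id: str, paragraph_hash: str, now: float) -> None:
    conn.execute(
        "INSERT INTO refs (external_id, paragraph_hash, created_at) VALUES (?, ?, ?) "
        "ON CONFLICT(external_id) DO UPDATE SET paragraph_hash = excluded.paragraph_hash",
        (external_id, paragraph_hash, now),
    )
    conn.commit()

upsert_ref("msg:1", "hash-a", 1.0)
upsert_ref("msg:1", "hash-b", 2.0)  # same external id is remapped; created_at keeps its original value
print(conn.execute("SELECT paragraph_hash, created_at FROM refs").fetchone())  # ('hash-b', 1.0)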
- ON CONFLICT(external_id) DO UPDATE SET - paragraph_hash = excluded.paragraph_hash, - source_type = excluded.source_type, - created_at = excluded.created_at, - metadata_json = excluded.metadata_json - """, - ( - external_id, - paragraph_hash, - str(item.get("source_type", "") or "").strip() or None, - created_at, - metadata_json, - ), - ) - count += max(0, int(cursor.rowcount or 0)) - self._conn.commit() - return count - - def record_v5_operation( - self, - *, - action: str, - target: str, - resolved_hashes: List[str], - reason: str = "", - updated_by: str = "", - result: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - operation_id = f"v5_{uuid.uuid4().hex}" - created_at = datetime.now().timestamp() - payload = { - "operation_id": operation_id, - "action": str(action or "").strip(), - "target": str(target or "").strip(), - "reason": str(reason or "").strip(), - "updated_by": str(updated_by or "").strip(), - "created_at": created_at, - "resolved_hashes": [str(item or "").strip() for item in (resolved_hashes or []) if str(item or "").strip()], - "result": result or {}, - } - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO memory_v5_operations ( - operation_id, action, target, reason, updated_by, created_at, resolved_hashes_json, result_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - operation_id, - payload["action"], - payload["target"] or None, - payload["reason"] or None, - payload["updated_by"] or None, - created_at, - self._json_dumps(payload["resolved_hashes"]), - self._json_dumps(payload["result"]), - ), - ) - self._conn.commit() - return payload - - def create_delete_operation( - self, - *, - mode: str, - selector: Any, - items: List[Dict[str, Any]], - reason: str = "", - requested_by: str = "", - status: str = "executed", - summary: Optional[Dict[str, Any]] = None, - operation_id: Optional[str] = None, - ) -> Dict[str, Any]: - op_id = str(operation_id or f"del_{uuid.uuid4().hex}").strip() - created_at = datetime.now().timestamp() - normalized_items: List[Dict[str, Any]] = [] - for item in items or []: - if not isinstance(item, dict): - continue - item_type = str(item.get("item_type", "") or "").strip() - if not item_type: - continue - normalized_items.append( - { - "item_type": item_type, - "item_hash": str(item.get("item_hash", "") or "").strip() or None, - "item_key": str(item.get("item_key", "") or item.get("item_hash", "") or "").strip() or None, - "payload": item.get("payload") if isinstance(item.get("payload"), dict) else {}, - } - ) - - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO delete_operations ( - operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json - ) VALUES (?, ?, ?, ?, ?, ?, ?, NULL, ?) - """, - ( - op_id, - str(mode or "").strip(), - self._json_dumps(selector if selector is not None else {}), - str(reason or "").strip() or None, - str(requested_by or "").strip() or None, - str(status or "executed").strip(), - created_at, - self._json_dumps(summary or {}), - ), - ) - if normalized_items: - cursor.executemany( - """ - INSERT INTO delete_operation_items ( - operation_id, item_type, item_hash, item_key, payload_json, created_at - ) VALUES (?, ?, ?, ?, ?, ?) 
- """, - [ - ( - op_id, - item["item_type"], - item["item_hash"], - item["item_key"], - self._json_dumps(item["payload"]), - created_at, - ) - for item in normalized_items - ], - ) - self._conn.commit() - return self.get_delete_operation(op_id) or { - "operation_id": op_id, - "mode": str(mode or "").strip(), - "selector": selector, - "reason": str(reason or "").strip(), - "requested_by": str(requested_by or "").strip(), - "status": str(status or "executed").strip(), - "created_at": created_at, - "summary": summary or {}, - "items": normalized_items, - } - - def mark_delete_operation_restored( - self, - operation_id: str, - *, - summary: Optional[Dict[str, Any]] = None, - ) -> bool: - token = str(operation_id or "").strip() - if not token: - return False - cursor = self._conn.cursor() - cursor.execute( - """ - UPDATE delete_operations - SET status = ?, restored_at = ?, summary_json = ? - WHERE operation_id = ? - """, - ( - "restored", - datetime.now().timestamp(), - self._json_dumps(summary or {}), - token, - ), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def list_delete_operations(self, *, limit: int = 50, mode: str = "") -> List[Dict[str, Any]]: - cursor = self._conn.cursor() - params: List[Any] = [] - where = "" - mode_token = str(mode or "").strip().lower() - if mode_token: - where = "WHERE LOWER(mode) = ?" - params.append(mode_token) - params.append(max(1, int(limit or 50))) - cursor.execute( - f""" - SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json - FROM delete_operations - {where} - ORDER BY created_at DESC - LIMIT ? - """, - tuple(params), - ) - items: List[Dict[str, Any]] = [] - for row in cursor.fetchall(): - payload = dict(row) - payload["selector"] = self._json_loads(payload.get("selector"), {}) - payload["summary"] = self._json_loads(payload.get("summary_json"), {}) - items.append(payload) - return items - - def get_delete_operation(self, operation_id: str) -> Optional[Dict[str, Any]]: - token = str(operation_id or "").strip() - if not token: - return None - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT operation_id, mode, selector, reason, requested_by, status, created_at, restored_at, summary_json - FROM delete_operations - WHERE operation_id = ? - LIMIT 1 - """, - (token,), - ) - row = cursor.fetchone() - if row is None: - return None - - payload = dict(row) - payload["selector"] = self._json_loads(payload.get("selector"), {}) - payload["summary"] = self._json_loads(payload.get("summary_json"), {}) - - cursor.execute( - """ - SELECT item_type, item_hash, item_key, payload_json, created_at - FROM delete_operation_items - WHERE operation_id = ? - ORDER BY id ASC - """, - (token,), - ) - payload["items"] = [ - { - "item_type": str(item["item_type"] or ""), - "item_hash": str(item["item_hash"] or ""), - "item_key": str(item["item_key"] or ""), - "payload": self._json_loads(item["payload_json"], {}), - "created_at": item["created_at"], - } - for item in cursor.fetchall() - ] - return payload - - def purge_deleted_relations(self, *, cutoff_time: float, limit: int = 1000) -> List[str]: - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT hash - FROM deleted_relations - WHERE deleted_at IS NOT NULL AND deleted_at < ? - ORDER BY deleted_at ASC - LIMIT ? 
- """, - (float(cutoff_time), max(1, int(limit or 1000))), - ) - hashes = [str(row[0] or "").strip() for row in cursor.fetchall() if str(row[0] or "").strip()] - if not hashes: - return [] - placeholders = ",".join(["?"] * len(hashes)) - cursor.execute(f"DELETE FROM deleted_relations WHERE hash IN ({placeholders})", tuple(hashes)) - self._conn.commit() - return hashes - - def get_statistics(self) -> Dict[str, int]: - """ - 获取统计信息 - - Returns: - 统计信息字典 - """ - cursor = self._conn.cursor() - - stats = {} - - # 段落数量 - cursor.execute("SELECT COUNT(*) FROM paragraphs") - stats["paragraph_count"] = cursor.fetchone()[0] - - # 实体数量 - cursor.execute("SELECT COUNT(*) FROM entities") - stats["entity_count"] = cursor.fetchone()[0] - - # 关系数量 - cursor.execute("SELECT COUNT(*) FROM relations") - stats["relation_count"] = cursor.fetchone()[0] - - # 总词数 - cursor.execute("SELECT SUM(word_count) FROM paragraphs") - result = cursor.fetchone()[0] - stats["total_words"] = result if result else 0 - - return stats - - def count_paragraphs(self, include_deleted: bool = False, only_deleted: bool = False) -> int: - """ - 获取段落数量 - """ - cursor = self._conn.cursor() - if only_deleted: - cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 1") - return cursor.fetchone()[0] - if include_deleted: - cursor.execute("SELECT COUNT(*) FROM paragraphs") - return cursor.fetchone()[0] - cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 0") - return cursor.fetchone()[0] - - def count_relations(self, include_deleted: bool = False, only_deleted: bool = False) -> int: - """ - 获取关系数量 - """ - cursor = self._conn.cursor() - if only_deleted: - cursor.execute("SELECT COUNT(*) FROM deleted_relations") - return cursor.fetchone()[0] - cursor.execute("SELECT COUNT(*) FROM relations") - active_count = cursor.fetchone()[0] - if not include_deleted: - return active_count - cursor.execute("SELECT COUNT(*) FROM deleted_relations") - deleted_count = cursor.fetchone()[0] - return int(active_count) + int(deleted_count) - - def count_entities(self) -> int: - """ - 获取实体数量 - - Returns: - 实体数量 - """ - cursor = self._conn.cursor() - cursor.execute("SELECT COUNT(*) FROM entities") - return cursor.fetchone()[0] - - def get_knowledge_type_distribution(self) -> Dict[str, int]: - """获取段落知识类型分布。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT knowledge_type, COUNT(*) as count - FROM paragraphs - WHERE is_deleted = 0 - GROUP BY knowledge_type - """ - ) - result: Dict[str, int] = {} - for row in cursor.fetchall(): - type_name = row[0] if row[0] else "未分类" - result[str(type_name)] = int(row[1] or 0) - return result - - def get_memory_status_summary(self, now_ts: Optional[float] = None) -> Dict[str, int]: - """聚合 memory status 统计。""" - now_ts = float(now_ts) if now_ts is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 0") - active_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 1") - inactive_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM deleted_relations") - deleted_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM relations WHERE is_pinned = 1") - pinned_count = int(cursor.fetchone()[0] or 0) - cursor.execute("SELECT COUNT(*) FROM relations WHERE protected_until > ?", (now_ts,)) - ttl_count = int(cursor.fetchone()[0] or 0) - return { - "active_count": active_count, - "inactive_count": inactive_count, - 
"deleted_count": deleted_count, - "pinned_count": pinned_count, - "temp_protected_count": ttl_count, - } - - def get_relations_subject_object_map(self, hashes: List[str]) -> Dict[str, Tuple[str, str]]: - """批量获取关系 hash 对应的 (subject, object)。""" - if not hashes: - return {} - cursor = self._conn.cursor() - placeholders = ",".join(["?"] * len(hashes)) - cursor.execute( - f"SELECT hash, subject, object FROM relations WHERE hash IN ({placeholders})", - hashes, - ) - return {str(row[0]): (str(row[1]), str(row[2])) for row in cursor.fetchall()} - - def get_connection(self) -> sqlite3.Connection: - """公开连接访问(用于离线脚本),替代外部访问私有字段。""" - return self._resolve_conn() - - def get_relation_db_snapshot(self) -> Tuple[int, float, str]: - """返回关系快照:(relation_count, max_created_at, max_hash)。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT - COUNT(*) AS relation_count, - COALESCE(MAX(created_at), 0) AS max_created_at, - COALESCE(MAX(hash), '') AS max_hash - FROM relations - """ - ) - row = cursor.fetchone() - if not row: - return (0, 0.0, "") - return ( - int(row[0] or 0), - float(row[1] or 0.0), - str(row[2] or ""), - ) - - def is_entity_still_referenced(self, entity_hash: str, entity_name: str = "") -> bool: - """ - 判断实体是否仍被引用: - 1) 被 paragraph_entities 引用 - 2) 在 relations.subject/object 中出现 - """ - token_hash = str(entity_hash or "").strip() - if token_hash: - cursor = self._conn.cursor() - cursor.execute( - "SELECT 1 FROM paragraph_entities WHERE entity_hash = ? LIMIT 1", - (token_hash,), - ) - if cursor.fetchone() is not None: - return True - - canon_name = self._canonicalize_name(entity_name) - if canon_name: - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT 1 - FROM relations - WHERE LOWER(TRIM(subject)) = ? OR LOWER(TRIM(object)) = ? - LIMIT 1 - """, - (canon_name, canon_name), - ) - if cursor.fetchone() is not None: - return True - return False - - def search_relations_by_subject_or_object( - self, - query: str, - *, - limit: int = 5, - include_deleted: bool = False, - ) -> List[Dict[str, Any]]: - """按 subject/object 模糊查询关系。""" - q = str(query or "").strip() - if not q: - return [] - max_limit = int(max(1, limit)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT * - FROM relations - WHERE subject LIKE ? OR object LIKE ? - LIMIT ? - """, - (f"%{q}%", f"%{q}%", max_limit), - ) - rows = [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - if rows or not include_deleted: - return rows - - cursor.execute( - """ - SELECT * - FROM deleted_relations - WHERE subject LIKE ? OR object LIKE ? - LIMIT ? - """, - (f"%{q}%", f"%{q}%", max_limit), - ) - return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] - - def list_hashes(self, table: str) -> List[str]: - """安全枚举指定表的 hash 列。""" - allowed = {"paragraphs", "entities", "relations", "deleted_relations"} - token = str(table or "").strip().lower() - if token not in allowed: - raise ValueError(f"unsupported table for list_hashes: {table}") - cursor = self._conn.cursor() - cursor.execute(f"SELECT hash FROM {token}") - return [str(row[0]) for row in cursor.fetchall()] - - def get_orphan_deleted_relation_hashes(self, limit: int = 200) -> List[str]: - """获取 deleted_relations 中已不在 relations 的孤儿 hash。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT d.hash - FROM deleted_relations d - LEFT JOIN relations r ON r.hash = d.hash - WHERE r.hash IS NULL - LIMIT ? 
- """, - (int(max(1, limit)),), - ) - return [str(row[0]) for row in cursor.fetchall()] - - def resolve_relation_hash_alias( - self, - value: str, - *, - include_deleted: bool = False, - ) -> List[str]: - """ - 解析关系哈希输入: - - 64位:直接校验存在性 - - 32位:通过 relation_hash_aliases 唯一映射 - """ - token = str(value or "").strip().lower() - if not token: - return [] - if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token): - cursor = self._conn.cursor() - cursor.execute("SELECT 1 FROM relations WHERE hash = ? LIMIT 1", (token,)) - if cursor.fetchone(): - return [token] - if include_deleted: - cursor.execute("SELECT 1 FROM deleted_relations WHERE hash = ? LIMIT 1", (token,)) - if cursor.fetchone(): - return [token] - return [] - - if len(token) != 32 or not all(ch in "0123456789abcdef" for ch in token): - return [] - - cursor = self._conn.cursor() - cursor.execute("SELECT hash FROM relation_hash_aliases WHERE alias32 = ?", (token,)) - row = cursor.fetchone() - if not row: - return [] - resolved = str(row[0]) - return [resolved] - - def rebuild_relation_hash_aliases(self) -> Dict[str, Any]: - """重建 32 位 relation hash 别名映射。""" - cursor = self._conn.cursor() - # 历史库兜底:缺表时先创建,避免迁移过程直接中断。 - cursor.execute(""" - CREATE TABLE IF NOT EXISTS relation_hash_aliases ( - alias32 TEXT PRIMARY KEY, - hash TEXT NOT NULL - ) - """) - cursor.execute("DELETE FROM relation_hash_aliases") - - cursor.execute("SELECT hash FROM relations") - hashes = [str(r[0]) for r in cursor.fetchall()] - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='deleted_relations'" - ) - has_deleted_relations = cursor.fetchone() is not None - if has_deleted_relations: - cursor.execute("SELECT hash FROM deleted_relations") - hashes.extend(str(r[0]) for r in cursor.fetchall()) - - alias_map: Dict[str, str] = {} - conflicts: Dict[str, set[str]] = {} - for h in hashes: - if len(h) != 64: - continue - alias = h[:32] - old = alias_map.get(alias) - if old is None: - alias_map[alias] = h - elif old != h: - conflicts.setdefault(alias, set()).update({old, h}) - - for alias, full_hash in alias_map.items(): - if alias in conflicts: - continue - cursor.execute( - "INSERT INTO relation_hash_aliases(alias32, hash) VALUES (?, ?)", - (alias, full_hash), - ) - self._conn.commit() - return { - "inserted": len(alias_map) - len(conflicts), - "conflict_count": len(conflicts), - "conflicts": sorted(conflicts.keys()), - } - - def search_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: - """按 relation 内容模糊查询 hash。""" - q = str(query or "").strip() - if not q: - return [] - cursor = self._conn.cursor() - cursor.execute( - "SELECT hash FROM relations WHERE subject LIKE ? OR object LIKE ? LIMIT ?", - (f"%{q}%", f"%{q}%", int(max(1, limit))), - ) - return [str(row[0]) for row in cursor.fetchall()] - - def search_deleted_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: - """按 deleted_relations 内容模糊查询 hash。""" - q = str(query or "").strip() - if not q: - return [] - cursor = self._conn.cursor() - cursor.execute( - "SELECT hash FROM deleted_relations WHERE subject LIKE ? OR object LIKE ? 
LIMIT ?", - (f"%{q}%", f"%{q}%", int(max(1, limit))), - ) - return [str(row[0]) for row in cursor.fetchall()] - - def restore_entity_by_hash(self, entity_hash: str) -> bool: - """恢复软删除实体。""" - cursor = self._conn.cursor() - cursor.execute( - "UPDATE entities SET is_deleted=0, deleted_at=NULL WHERE hash=?", - (str(entity_hash),), - ) - changed = cursor.rowcount > 0 - if changed: - self._conn.commit() - return changed - - def restore_paragraph_by_hash(self, paragraph_hash: str) -> bool: - """恢复软删除段落。""" - cursor = self._conn.cursor() - cursor.execute( - "UPDATE paragraphs SET is_deleted=0, deleted_at=NULL WHERE hash=?", - (str(paragraph_hash),), - ) - changed = cursor.rowcount > 0 - if changed: - self._conn.commit() - return changed - - def backfill_temporal_metadata_from_created_at( - self, - *, - limit: int = 100000, - dry_run: bool = False, - no_created_fallback: bool = False, - ) -> Dict[str, int]: - """回填段落 event_time 字段(created_at 兜底)。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT hash, created_at, source - FROM paragraphs - WHERE (event_time IS NULL AND event_time_start IS NULL AND event_time_end IS NULL) - ORDER BY created_at DESC - LIMIT ? - """, - (int(max(1, limit)),), - ) - rows = cursor.fetchall() - candidates = len(rows) - if dry_run: - return {"candidates": candidates, "updated": 0} - if no_created_fallback: - return {"candidates": candidates, "updated": 0} - - updated = 0 - touched_sources: List[str] = [] - for row in rows: - created_at = row["created_at"] - if created_at is None: - continue - cursor.execute( - """ - UPDATE paragraphs - SET event_time = ?, time_granularity = ?, time_confidence = ?, updated_at = ? - WHERE hash = ? - """, - (float(created_at), "day", 0.2, float(created_at), row["hash"]), - ) - if cursor.rowcount > 0: - updated += 1 - touched_sources.append(row["source"]) - self._conn.commit() - if updated > 0: - self._enqueue_episode_source_rebuilds( - touched_sources, - reason="paragraph_time_backfill", - ) - return {"candidates": candidates, "updated": updated} - - def get_schema_version(self) -> int: - cursor = self._conn.cursor() - cursor.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" - ) - if cursor.fetchone() is None: - return 0 - cursor.execute("SELECT MAX(version) FROM schema_migrations") - row = cursor.fetchone() - return int(row[0]) if row and row[0] is not None else 0 - - def set_schema_version(self, version: int = SCHEMA_VERSION) -> None: - cursor = self._conn.cursor() - cursor.execute( - "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at REAL NOT NULL)" - ) - cursor.execute( - "INSERT OR REPLACE INTO schema_migrations(version, applied_at) VALUES (?, ?)", - (int(version), datetime.now().timestamp()), - ) - self._conn.commit() - - def delete_paragraph_atomic(self, paragraph_hash: str) -> Dict[str, Any]: - """ - 两阶段删除段落:DB 事务内计算 + 提交后执行清理 - - Args: - paragraph_hash: 段落哈希 - - Returns: - cleanup_plan: 包含需要后续从 Vector/GraphStore 中移除的 ID 列表 - """ - cleanup_plan = { - "paragraph_hash": paragraph_hash, - "vector_id_to_remove": None, - "edges_to_remove": [], # (src, tgt) 元组列表 (fallback) - "relation_prune_ops": [], # (subject, object, relation_hash) 精准裁剪 - "episode_sources_to_rebuild": [], - } - - cursor = self._conn.cursor() - try: - # === Phase 1: DB Transaction (可回滚) === - # 使用 IMMEDIATE 模式,一旦开启事务立即锁定 DB (防止其他写操作插队导致幻读) - cursor.execute("BEGIN IMMEDIATE") - - # 1. 
[快照] 获取候选关系 - cursor.execute("SELECT relation_hash FROM paragraph_relations WHERE paragraph_hash = ?", (paragraph_hash,)) - candidate_relations = [row[0] for row in cursor.fetchall()] - - # 2. [快照] 确认该段落存在并记录 ID 用于向量删除 - cursor.execute("SELECT hash, source FROM paragraphs WHERE hash = ?", (paragraph_hash,)) - paragraph_row = cursor.fetchone() - if paragraph_row: - cleanup_plan["vector_id_to_remove"] = paragraph_hash - cleanup_plan["episode_sources_to_rebuild"] = self._dedupe_episode_sources( - [paragraph_row["source"]] - ) - - # 3. [主删除] 删除段落 (触发 CASCADE 删 paragraph_relations) - cursor.execute("DELETE FROM paragraphs WHERE hash = ?", (paragraph_hash,)) - - # 4. [计算孤儿] - orphaned_hashes = [] - for rel_hash in candidate_relations: - count = cursor.execute( - "SELECT count(*) FROM paragraph_relations WHERE relation_hash = ?", - (rel_hash,) - ).fetchone()[0] - - if count == 0: - # 是孤儿:记录边信息以便后续删 Graph - cursor.execute("SELECT subject, object FROM relations WHERE hash = ?", (rel_hash,)) - rel_info = cursor.fetchone() - if rel_info: - s_val, o_val = rel_info[0], rel_info[1] - cleanup_plan["relation_prune_ops"].append((s_val, o_val, rel_hash)) - - # 仅当 (subject, object) 不再有任何关系时,才计划删整条边(兼容旧实现)。 - sibling_count = cursor.execute( - """ - SELECT count(*) FROM relations - WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?)) - AND LOWER(TRIM(object)) = LOWER(TRIM(?)) - AND hash != ? - """, - (s_val, o_val, rel_hash) - ).fetchone()[0] - if sibling_count == 0: - cleanup_plan["edges_to_remove"].append((s_val, o_val)) - - orphaned_hashes.append(rel_hash) - - # 5. [DB清理] 删除孤儿关系记录 - if orphaned_hashes: - placeholders = ','.join(['?'] * len(orphaned_hashes)) - cursor.execute(f"DELETE FROM relations WHERE hash IN ({placeholders})", orphaned_hashes) - - self._conn.commit() - if cleanup_plan["episode_sources_to_rebuild"]: - self._enqueue_episode_source_rebuilds( - cleanup_plan["episode_sources_to_rebuild"], - reason="paragraph_deleted", - ) - if cleanup_plan["vector_id_to_remove"]: - logger.debug(f"原子删除段落成功: {paragraph_hash}, 计划清理 {len(orphaned_hashes)} 个孤儿关系") - return cleanup_plan - - except Exception as e: - self._conn.rollback() - logger.error(f"DB Transaction failed: {e}") - raise e - - - def clear_all(self) -> None: - """清空所有表数据""" - cursor = self._conn.cursor() - tables = [ - "paragraphs", "entities", "relations", - "paragraph_relations", "paragraph_entities", - "episodes", "episode_paragraphs", - "episode_rebuild_sources", "episode_pending_paragraphs", - ] - for table in tables: - cursor.execute(f"DELETE FROM {table}") - self._conn.commit() - logger.info("元数据存储所有表已清空") - - - - def update_relation_timestamp(self, hash_value: str, access_count_delta: int = 1) -> None: - """更新关系的访问时间和计数""" - now = datetime.now().timestamp() - - # 同时更新 last_accessed (旧) 和 last_reinforced (V5) - - cursor = self._conn.cursor() - cursor.execute(""" - UPDATE relations - SET last_accessed = ?, - access_count = access_count + ? - WHERE hash = ? 
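# Hedged usage sketch (not part of the plugin): the caller side of the
# two-phase delete. Phase 1 (delete_paragraph_atomic) only touches SQLite and
# either commits or rolls back; phase 2 applies the returned cleanup_plan to
# the other stores. vector_store and graph_store, and their remove/prune
# methods, are hypothetical stand-ins here.
def delete_paragraph_everywhere(metadata_store, vector_store, graph_store, paragraph_hash: str) -> None:
    plan = metadata_store.delete_paragraph_atomic(paragraph_hash)  # phase 1: transactional
    # Phase 2: best-effort cleanup outside the transaction.
    if plan["vector_id_to_remove"]:
        vector_store.remove(plan["vector_id_to_remove"])
    for subject, obj, relation_hash in plan["relation_prune_ops"]:
        graph_store.prune_relation(subject, obj, relation_hash)
    for subject, obj in plan["edges_to_remove"]:
        graph_store.remove_edge(subject, obj)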
- """, (now, access_count_delta, hash_value)) - self._conn.commit() - - # ========================================================================= - # V5 Memory System Methods - # ========================================================================= - - def get_relation_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]: - """ - 批量获取关系状态 (V5) - - Args: - hashes: 关系哈希列表 - - Returns: - Dict[hash, status_dict] - status_dict 包含: is_inactive, weight(confidence), is_pinned, protected_until, last_reinforced, inactive_since - """ - if not hashes: - return {} - - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - cursor.execute(f""" - SELECT hash, is_inactive, confidence, is_pinned, protected_until, last_reinforced, inactive_since - FROM relations - WHERE hash IN ({placeholders}) - """, hashes) - - result = {} - for row in cursor.fetchall(): - result[row["hash"]] = { - "is_inactive": bool(row["is_inactive"]), - "weight": row["confidence"], - "is_pinned": bool(row["is_pinned"]), - "protected_until": row["protected_until"], - "last_reinforced": row["last_reinforced"], - "inactive_since": row["inactive_since"] - } - return result - - def mark_relations_active(self, hashes: List[str], boost_weight: Optional[float] = None) -> None: - """ - 批量标记关系为活跃 (Active/Revive) - - Args: - hashes: 关系哈希列表 - boost_weight: 如果提供,将设置 confidence = max(confidence, boost_weight) - """ - if not hashes: - return - - placeholders = ",".join(["?"] * len(hashes)) - cursor = self._conn.cursor() - - if boost_weight is not None: - cursor.execute(f""" - UPDATE relations - SET is_inactive = 0, - inactive_since = NULL, - confidence = MAX(confidence, ?) - WHERE hash IN ({placeholders}) - """, (boost_weight, *hashes)) - else: - cursor.execute(f""" - UPDATE relations - SET is_inactive = 0, - inactive_since = NULL - WHERE hash IN ({placeholders}) - """, hashes) - - self._conn.commit() - - def update_relations_protection( - self, - hashes: List[str], - protected_until: Optional[float] = None, - is_pinned: Optional[bool] = None, - last_reinforced: Optional[float] = None - ) -> None: - """ - 批量更新关系保护状态 - """ - if not hashes: - return - - updates = [] - params = [] - - if protected_until is not None: - updates.append("protected_until = ?") - params.append(protected_until) - if is_pinned is not None: - updates.append("is_pinned = ?") - params.append(1 if is_pinned else 0) - if last_reinforced is not None: - updates.append("last_reinforced = ?") - params.append(last_reinforced) - - if not updates: - return - - sql_set = ", ".join(updates) - placeholders = ",".join(["?"] * len(hashes)) - - params.extend(hashes) - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE relations - SET {sql_set} - WHERE hash IN ({placeholders}) - """, params) - self._conn.commit() - - def get_prune_candidates(self, cutoff_time: float, limit: int = 1000) -> List[str]: - """ - 获取待修剪候选 (已过冷冻保留期) - - Args: - cutoff_time: 截止时间 (now - 冷冻时长) - limit: 限制数量 - """ - cursor = self._conn.cursor() - cursor.execute(""" - SELECT hash FROM relations - WHERE is_inactive = 1 - AND inactive_since < ? - LIMIT ? - """, (cutoff_time, limit)) - return [row[0] for row in cursor.fetchall()] - - def backup_and_delete_relations(self, hashes: List[str]) -> int: - """ - 备份并删除关系 (Prune) - - Returns: - 删除的数量 - """ - if not hashes: - return 0 - - placeholders = ",".join(["?"] * len(hashes)) - now = datetime.now().timestamp() - - cursor = self._conn.cursor() - try: - # 1. 
备份 - cursor.execute(f""" - INSERT OR REPLACE INTO deleted_relations - (hash, subject, predicate, object, vector_index, confidence, created_at, - vector_state, vector_updated_at, vector_error, vector_retry_count, - source_paragraph, metadata, is_permanent, last_accessed, access_count, - is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, deleted_at) - SELECT - hash, subject, predicate, object, vector_index, confidence, created_at, - vector_state, vector_updated_at, vector_error, vector_retry_count, - source_paragraph, metadata, is_permanent, last_accessed, access_count, - is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, ? - FROM relations - WHERE hash IN ({placeholders}) - """, (now, *hashes)) - - # 2. 删除 (级联删除会自动处理 paragraph_relations 关联) - cursor.execute(f""" - DELETE FROM relations - WHERE hash IN ({placeholders}) - """, hashes) - - deleted_count = cursor.rowcount - self._conn.commit() - return deleted_count - - except Exception as e: - logger.error(f"备份删除失败: {e}") - self._conn.rollback() - return 0 - - def restore_relation_metadata(self, hash_value: str) -> Optional[Dict[str, Any]]: - """ - 从回收站恢复关系元数据 - - Returns: - 恢复后的关系数据 (字典),失败返回 None - """ - cursor = self._conn.cursor() - try: - # 1. 查询备份数据 - cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,)) - row = cursor.fetchone() - if not row: - return None - - data = dict(row) - # 移除 deleted_at 字段 - if "deleted_at" in data: - del data["deleted_at"] - - # 2. 插入回 relations 表 - # 动态构建 SQL 以适应字段变化 - columns = list(data.keys()) - placeholders = ",".join(["?"] * len(columns)) - cols_str = ",".join(columns) - values = list(data.values()) - - cursor.execute(f""" - INSERT OR REPLACE INTO relations ({cols_str}) - VALUES ({placeholders}) - """, values) - - # 3. 从备份表删除 - cursor.execute("DELETE FROM deleted_relations WHERE hash = ?", (hash_value,)) - - self._conn.commit() - return self._row_to_dict(row, "relation") # 使用助手函数将原始行转换为字典 - - except Exception as e: - logger.error(f"恢复关系失败: {hash_value} - {e}") - self._conn.rollback() - return None - - def restore_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: - """兼容旧调用名:恢复关系。""" - return self.restore_relation_metadata(hash_value) - - def get_protected_relations_hashes(self) -> List[str]: - """获取所有受保护关系的哈希 (Pinned 或 Protected Until > Now)""" - now = datetime.now().timestamp() - - cursor = self._conn.cursor() - cursor.execute(""" - SELECT hash FROM relations - WHERE is_pinned = 1 OR protected_until > ? 
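# Illustrative sketch (not part of the plugin): the recycle-bin move used by
# backup_and_delete_relations - copy the rows into a shadow table with
# INSERT OR REPLACE ... SELECT, stamp deleted_at, and only then delete them, so
# the prune stays reversible. Two-column toy schema for demonstration.
import sqlite3, time

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE relations (hash TEXT PRIMARY KEY, subject TEXT);
CREATE TABLE deleted_relations (hash TEXT PRIMARY KEY, subject TEXT, deleted_at REAL);
INSERT INTO relations VALUES ('r1', 'alice'), ('r2', 'bob');
""")

def backup_and_delete(hashes):
    placeholders = ",".join("?" * len(hashes))
    now = time.time()
    conn.execute(
        f"INSERT OR REPLACE INTO deleted_relations (hash, subject, deleted_at) "
        f"SELECT hash, subject, ? FROM relations WHERE hash IN ({placeholders})",
        (now, *hashes),
    )
    cur = conn.execute(f"DELETE FROM relations WHERE hash IN ({placeholders})", hashes)
    conn.commit()
    return cur.rowcount

print(backup_and_delete(["r1"]))  # 1; 'r1' now lives only in deleted_relations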
- """, (now,)) - - return [row[0] for row in cursor.fetchall()] - - - - def get_deleted_relations(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取回收站中的关系记录""" - cursor = self._conn.cursor() - cursor.execute("SELECT * FROM deleted_relations ORDER BY deleted_at DESC LIMIT ?", (limit,)) - data = [] - for row in cursor.fetchall(): - d = dict(row) - # 是否需要解码元数据?是的,与普通行相同 - if "metadata" in d and d["metadata"]: - try: - d["metadata"] = pickle.loads(d["metadata"]) - except Exception: - d["metadata"] = {} - data.append(d) - return data - - def get_deleted_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: - """获取单条回收站记录""" - cursor = self._conn.cursor() - cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,)) - row = cursor.fetchone() - if not row: return None - - d = dict(row) - if "metadata" in d and d["metadata"]: - try: - d["metadata"] = pickle.loads(d["metadata"]) - except Exception: - d["metadata"] = {} - return d - - def reinforce_relations(self, hashes: List[str]) -> None: - """强化关系 (更新 last_reinforced, is_inactive=0)""" - if not hashes: return - now = datetime.now().timestamp() - - cursor = self._conn.cursor() - # Batch update? chunking - chunk_size = 500 - for i in range(0, len(hashes), chunk_size): - chunk = hashes[i:i+chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - sql = f""" - UPDATE relations - SET last_reinforced = ?, is_inactive = 0, inactive_since = NULL - WHERE hash IN ({placeholders}) - """ - cursor.execute(sql, [now] + chunk) - - self._conn.commit() - - def mark_relations_inactive(self, hashes: List[str], inactive_since: Optional[float] = None) -> None: - """标记关系为非活跃 (Freeze)。兼容显式 inactive_since 或默认当前时间。""" - if not hashes: - return - mark_time = inactive_since if inactive_since is not None else datetime.now().timestamp() - - cursor = self._conn.cursor() - chunk_size = 500 - for i in range(0, len(hashes), chunk_size): - chunk = hashes[i:i+chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - sql = f""" - UPDATE relations - SET is_inactive = 1, inactive_since = ? - WHERE hash IN ({placeholders}) - """ - cursor.execute(sql, [mark_time] + chunk) - - self._conn.commit() - - def protect_relations( - self, - hashes: List[str], - is_pinned: bool = False, - ttl_seconds: float = 0 - ) -> None: - """ - 设置保护状态 - """ - if not hashes: return - now = datetime.now().timestamp() - protected_until = (now + ttl_seconds) if ttl_seconds > 0 else 0 - - cursor = self._conn.cursor() - chunk_size = 500 - for i in range(0, len(hashes), chunk_size): - chunk = hashes[i:i+chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - - # 由于 is_pinned 和 protected_until 是分开的,如果请求固定(pin),我们会同时更新这两项, - # 但通常用户要么切换固定状态,要么设置 TTL。 - # 如果 is_pinned=True,TTL 通常就不重要了。 - # 但目前的逻辑是正交处理它们的。 - - # 如果用户取消固定 (is_pinned=False),我们是否应该尊重已设置的 TTL? - # 当前的 API 会同时设置这两项。 - - sql = f""" - UPDATE relations - SET is_pinned = ?, protected_until = ? 
- WHERE hash IN ({placeholders}) - """ - cursor.execute(sql, [is_pinned, protected_until] + chunk) - - self._conn.commit() - - def vacuum(self) -> None: - """优化数据库""" - cursor = self._conn.cursor() - cursor.execute("VACUUM") - self._conn.commit() - logger.info("数据库优化完成") - - def _row_to_dict(self, row: sqlite3.Row, row_type: str) -> Dict[str, Any]: - """ - 将数据库行转换为字典 - - Args: - row: 数据库行 - row_type: 行类型 - - Returns: - 字典 - """ - d = dict(row) - - # 解码pickle字段 - if "metadata" in d and d["metadata"]: - try: - d["metadata"] = pickle.loads(d["metadata"]) - except Exception: - d["metadata"] = {} - - return d - - @property - def is_connected(self) -> bool: - """是否已连接""" - return self._conn is not None - - def __enter__(self): - """上下文管理器入口""" - self.connect() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """上下文管理器出口""" - self.close() - - # ========================================================================= - # V5 Soft Delete & Garbage Collection - # ========================================================================= - - def get_entity_gc_candidates(self, isolated_hashes: List[str], retention_seconds: float) -> List[str]: - """ - 获取实体 GC 候选列表 (Soft Delete Candidates) - 条件: - 1. 在 isolated_hashes 列表中 (由 GraphStore 提供;通常是实体名称) - 2. is_deleted = 0 (未被标记) - 3. created_at < now - retention (过了新手保护期) - 4. 不被任何 active paragraph 引用 (paragraph_entities check) - - Args: - isolated_hashes: 孤儿实体名称列表(兼容传入 hash) - retention_seconds: 保留时间 (秒) - """ - if not isolated_hashes: - return [] - - # GraphStore.get_isolated_nodes 返回节点名,这里做 canonicalize -> entity hash 映射。 - # 同时兼容历史调用直接传 hash。 - normalized_hashes: List[str] = [] - for item in isolated_hashes: - if not item: - continue - v = str(item).strip() - if len(v) == 64 and all(c in "0123456789abcdefABCDEF" for c in v): - normalized_hashes.append(v.lower()) - else: - canon = self._canonicalize_name(v) - if canon: - normalized_hashes.append(compute_hash(canon)) - - normalized_hashes = list(dict.fromkeys(normalized_hashes)) - if not normalized_hashes: - return [] - - now = datetime.now().timestamp() - cutoff = now - retention_seconds - - candidates = [] - batch_size = 900 - - # 分批处理 IN 查询 - for i in range(0, len(normalized_hashes), batch_size): - batch = normalized_hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - # 使用 NOT EXISTS 子查询检查引用 - # 注意: paragraph_entities 中引用的 paragraph 如果被软删了,是否算引用? - # 这里的语义: 只要有 rows 存在于 paragraph_entities 且该 row 对应的 paragraph 没被彻底物理删除,就算引用。 - # 更严格: ... OR (EXISTS ... AND entity_hash=... AND is_deleted=0) - # 但 paragraph_entities 表没有 is_deleted 字段(它是关联表). 我们检查关联是否存在。 - # 如果 paragraph 本身 soft deleted, 它的引用应该失效吗? - # 策略: 只有当 paragraph 也是 active 时,引用才有效。 - # JOIN paragraphs p ON pe.paragraph_hash = p.hash WHERE p.is_deleted = 0 - - query = f""" - SELECT e.hash FROM entities e - WHERE e.hash IN ({placeholders}) - AND e.is_deleted = 0 - AND (e.created_at IS NULL OR e.created_at < ?) - AND NOT EXISTS ( - SELECT 1 FROM paragraph_entities pe - JOIN paragraphs p ON pe.paragraph_hash = p.hash - WHERE pe.entity_hash = e.hash - AND p.is_deleted = 0 - ) - """ - - cursor = self._conn.cursor() - cursor.execute(query, [*batch, cutoff]) - candidates.extend([row[0] for row in cursor.fetchall()]) - - return candidates - - def get_paragraph_gc_candidates(self, retention_seconds: float) -> List[str]: - """ - 获取段落 GC 候选列表 - 条件: - 1. is_deleted = 0 - 2. created_at < cutoff - 3. 没有 Relations (paragraph_relations empty) - 4. 没有 Entities 引用 (paragraph_entities empty) - OR 引用的 Entities 全是软删状态? 
(太复杂,简单点: 无引用) - - Refined Strategy: - 段落孤儿判定 = - (Left Join paragraph_relations -> NULL) AND - (Left Join paragraph_entities -> NULL) - """ - now = datetime.now().timestamp() - cutoff = now - retention_seconds - - query = """ - SELECT p.hash FROM paragraphs p - LEFT JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash - LEFT JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash - WHERE p.is_deleted = 0 - AND (p.created_at IS NULL OR p.created_at < ?) - AND pr.relation_hash IS NULL - AND pe.entity_hash IS NULL - """ - - cursor = self._conn.cursor() - cursor.execute(query, (cutoff,)) - return [row[0] for row in cursor.fetchall()] - - def mark_as_deleted(self, hashes: List[str], type_: str) -> int: - """ - 标记为软删除 (Mark Phase) - - Args: - hashes: Hash 列表 - type_: 'entity' | 'paragraph' - """ - if not hashes: - return 0 - - table = "entities" if type_ == "entity" else "paragraphs" - now = datetime.now().timestamp() - touched_sources: List[str] = [] - if type_ == "paragraph": - touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True) - - count = 0 - batch_size = 900 - for i in range(0, len(hashes), batch_size): - batch = hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - # 幂等更新: 只更那些 is_deleted=0 的 - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE {table} - SET is_deleted = 1, deleted_at = ? - WHERE is_deleted = 0 AND hash IN ({placeholders}) - """, [now] + batch) - count += cursor.rowcount - - self._conn.commit() - if type_ == "paragraph" and count > 0: - self._enqueue_episode_source_rebuilds( - touched_sources, - reason="paragraph_soft_deleted", - ) - if count > 0: - logger.info(f"软删除标记 ({table}): {count} 项") - return count - - def sweep_deleted_items(self, type_: str, grace_period_seconds: float) -> List[Tuple[str, str]]: - """ - 扫描可物理清理的项目 (Sweep Phase - Selection) - - Args: - type_: 'entity' | 'paragraph' - grace_period_seconds: 宽限期 - - Returns: - List[(hash, name)]: 待删除项列表 (paragraph name为空) - """ - table = "entities" if type_ == "entity" else "paragraphs" - now = datetime.now().timestamp() - cutoff = now - grace_period_seconds - - cols = "hash, name" if type_ == "entity" else "hash, '' as name" - - cursor = self._conn.cursor() - cursor.execute(f""" - SELECT {cols} FROM {table} - WHERE is_deleted = 1 - AND deleted_at < ? 
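# Hedged usage sketch (not part of the plugin): one paragraph GC pass built
# from the soft-delete methods in this class - get_paragraph_gc_candidates and
# mark_as_deleted above, plus the sweep_deleted_items / physically_delete_*
# helpers that follow. The retention and grace constants are placeholders.
def run_paragraph_gc(store, retention_seconds: float = 7 * 86400, grace_seconds: float = 3 * 86400) -> dict:
    # Mark: orphan paragraphs (no relations, no entities) past the retention window.
    candidates = store.get_paragraph_gc_candidates(retention_seconds)
    marked = store.mark_as_deleted(candidates, "paragraph")
    # Sweep: anything that has stayed soft-deleted longer than the grace period.
    stale = [h for h, _name in store.sweep_deleted_items("paragraph", grace_seconds)]
    purged = store.physically_delete_paragraphs(stale)
    return {"marked": marked, "purged": purged}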
- """, (cutoff,)) - - return [(row[0], row[1]) for row in cursor.fetchall()] - - def physically_delete_entities(self, hashes: List[str]) -> int: - """物理删除实体 (批量)""" - if not hashes: return 0 - - count = 0 - batch_size = 900 - for i in range(0, len(hashes), batch_size): - batch = hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f"DELETE FROM entities WHERE hash IN ({placeholders})", batch) - count += cursor.rowcount - - self._conn.commit() - return count - - def physically_delete_paragraphs(self, hashes: List[str]) -> int: - """物理删除段落 (批量)""" - if not hashes: return 0 - touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True) - - count = 0 - batch_size = 900 - for i in range(0, len(hashes), batch_size): - batch = hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f"DELETE FROM paragraphs WHERE hash IN ({placeholders})", batch) - count += cursor.rowcount - - self._conn.commit() - if count > 0: - self._enqueue_episode_source_rebuilds( - touched_sources, - reason="paragraph_physically_deleted", - ) - return count - - def revive_if_deleted(self, entity_hashes: List[str] = None, paragraph_hashes: List[str] = None) -> int: - """ - 复活已软删的项目 (Auto Revival) - 当数据被再次访问、引用或导入时调用。 - """ - count = 0 - - if entity_hashes: - batch_size = 900 - for i in range(0, len(entity_hashes), batch_size): - batch = entity_hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE entities - SET is_deleted = 0, deleted_at = NULL - WHERE is_deleted = 1 AND hash IN ({placeholders}) - """, batch) - count += cursor.rowcount - - if paragraph_hashes: - touched_sources = self._get_sources_for_paragraph_hashes(paragraph_hashes, include_deleted=True) - batch_size = 900 - for i in range(0, len(paragraph_hashes), batch_size): - batch = paragraph_hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f""" - UPDATE paragraphs - SET is_deleted = 0, deleted_at = NULL - WHERE is_deleted = 1 AND hash IN ({placeholders}) - """, batch) - count += cursor.rowcount - else: - touched_sources = [] - - if count > 0: - self._conn.commit() - if touched_sources: - self._enqueue_episode_source_rebuilds( - touched_sources, - reason="paragraph_revived", - ) - logger.info(f"自动复活: {count} 项 (Soft Delete Revived)") - - return count - - def revive_entities_by_names(self, names: List[str]) -> int: - """ - 根据名称复活实体 (Convenience wrapper) - """ - if not names: return 0 - - # 使用内部方法计算哈希 - hashes = [compute_hash(self._canonicalize_name(n)) for n in names] - return self.revive_if_deleted(entity_hashes=hashes) - - def get_entity_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]: - """批量获取实体状态 (WebUI用)""" - if not hashes: return {} - - result = {} - batch_size = 900 - for i in range(0, len(hashes), batch_size): - batch = hashes[i:i+batch_size] - placeholders = ",".join(["?"] * len(batch)) - - cursor = self._conn.cursor() - cursor.execute(f""" - SELECT hash, is_deleted, deleted_at - FROM entities - WHERE hash IN ({placeholders}) - """, batch) - - for row in cursor.fetchall(): - result[row[0]] = { - "is_deleted": bool(row[1]), - "deleted_at": row[2] - } - return result - - # ========================================================================= - # Person Profile (问题3) - Switches / Active Set / Snapshots - # 
========================================================================= - - def set_person_profile_switch( - self, - stream_id: str, - user_id: str, - enabled: bool, - updated_at: Optional[float] = None, - ) -> None: - """设置人物画像自动注入开关(按 stream_id + user_id)。""" - if not stream_id or not user_id: - raise ValueError("stream_id 和 user_id 不能为空") - - ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO person_profile_switches (stream_id, user_id, enabled, updated_at) - VALUES (?, ?, ?, ?) - ON CONFLICT(stream_id, user_id) DO UPDATE SET - enabled = excluded.enabled, - updated_at = excluded.updated_at - """, - (str(stream_id), str(user_id), 1 if enabled else 0, ts), - ) - self._conn.commit() - - def get_person_profile_switch(self, stream_id: str, user_id: str, default: bool = False) -> bool: - """读取人物画像自动注入开关。""" - if not stream_id or not user_id: - return bool(default) - - cursor = self._conn.cursor() - cursor.execute( - "SELECT enabled FROM person_profile_switches WHERE stream_id = ? AND user_id = ?", - (str(stream_id), str(user_id)), - ) - row = cursor.fetchone() - if not row: - return bool(default) - return bool(row[0]) - - def get_enabled_person_profile_switches(self, limit: int = 1000) -> List[Dict[str, Any]]: - """获取已开启人物画像注入开关的会话范围。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT stream_id, user_id, enabled, updated_at - FROM person_profile_switches - WHERE enabled = 1 - ORDER BY updated_at DESC - LIMIT ? - """, - (int(max(1, limit)),), - ) - return [ - { - "stream_id": row[0], - "user_id": row[1], - "enabled": bool(row[2]), - "updated_at": row[3], - } - for row in cursor.fetchall() - ] - - def mark_person_profile_active( - self, - stream_id: str, - user_id: str, - person_id: str, - seen_at: Optional[float] = None, - ) -> None: - """记录活跃人物(用于定时按需刷新)。""" - if not stream_id or not user_id or not person_id: - return - ts = float(seen_at) if seen_at is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO person_profile_active_persons (stream_id, user_id, person_id, last_seen_at) - VALUES (?, ?, ?, ?) - ON CONFLICT(stream_id, user_id, person_id) DO UPDATE SET - last_seen_at = excluded.last_seen_at - """, - (str(stream_id), str(user_id), str(person_id), ts), - ) - self._conn.commit() - - def get_active_person_ids_for_enabled_switches( - self, - active_after: Optional[float] = None, - limit: int = 200, - ) -> List[str]: - """获取“已开启开关范围内”的活跃人物集合。""" - cursor = self._conn.cursor() - sql = """ - SELECT a.person_id, MAX(a.last_seen_at) AS last_seen - FROM person_profile_active_persons a - JOIN person_profile_switches s - ON a.stream_id = s.stream_id AND a.user_id = s.user_id - WHERE s.enabled = 1 - """ - params: List[Any] = [] - if active_after is not None: - sql += " AND a.last_seen_at >= ?" - params.append(float(active_after)) - sql += """ - GROUP BY a.person_id - ORDER BY last_seen DESC - LIMIT ? 
- """ - params.append(int(max(1, limit))) - cursor.execute(sql, tuple(params)) - return [str(row[0]) for row in cursor.fetchall() if row and row[0]] - - def get_latest_person_profile_snapshot(self, person_id: str) -> Optional[Dict[str, Any]]: - """获取人物最新画像快照。""" - if not person_id: - return None - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT - snapshot_id, person_id, profile_version, profile_text, - aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, - updated_at, expires_at, source_note - FROM person_profile_snapshots - WHERE person_id = ? - ORDER BY profile_version DESC - LIMIT 1 - """, - (str(person_id),), - ) - row = cursor.fetchone() - if not row: - return None - - def _load_list(raw: Any) -> List[Any]: - if not raw: - return [] - try: - data = json.loads(raw) - return data if isinstance(data, list) else [] - except Exception: - return [] - - return { - "snapshot_id": row[0], - "person_id": row[1], - "profile_version": int(row[2]), - "profile_text": row[3] or "", - "aliases": _load_list(row[4]), - "relation_edges": _load_list(row[5]), - "vector_evidence": _load_list(row[6]), - "evidence_ids": _load_list(row[7]), - "updated_at": row[8], - "expires_at": row[9], - "source_note": row[10] or "", - } - - def upsert_person_profile_snapshot( - self, - person_id: str, - profile_text: str, - aliases: Optional[List[str]] = None, - relation_edges: Optional[List[Dict[str, Any]]] = None, - vector_evidence: Optional[List[Dict[str, Any]]] = None, - evidence_ids: Optional[List[str]] = None, - expires_at: Optional[float] = None, - source_note: str = "", - updated_at: Optional[float] = None, - ) -> Dict[str, Any]: - """写入人物画像快照(按 person_id 自动递增版本)。""" - if not person_id: - raise ValueError("person_id 不能为空") - - aliases = aliases or [] - relation_edges = relation_edges or [] - vector_evidence = vector_evidence or [] - evidence_ids = evidence_ids or [] - ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT profile_version - FROM person_profile_snapshots - WHERE person_id = ? - ORDER BY profile_version DESC - LIMIT 1 - """, - (str(person_id),), - ) - row = cursor.fetchone() - next_version = int(row[0]) + 1 if row else 1 - - cursor.execute( - """ - INSERT INTO person_profile_snapshots ( - person_id, profile_version, profile_text, - aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, - updated_at, expires_at, source_note - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
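
# Hedged usage sketch for the snapshot helpers above: each upsert appends a new row whose
# profile_version is the latest existing version + 1, and the freshest snapshot is read back
# with ORDER BY profile_version DESC LIMIT 1. How `store` (a MetadataStore) is constructed
# is not shown in this patch and is assumed here; argument names mirror the signatures above.
def refresh_profile(store, person_id: str, text: str) -> int:
    snap = store.upsert_person_profile_snapshot(
        person_id=person_id,
        profile_text=text,
        aliases=["小明"],               # stored as JSON-encoded list columns
        evidence_ids=["paragraph_hash_1"],
        source_note="scheduled-refresh",
    )
    latest = store.get_latest_person_profile_snapshot(person_id)
    assert latest is not None and latest["profile_version"] == snap["profile_version"]
    return latest["profile_version"]
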
- """, - ( - str(person_id), - next_version, - str(profile_text or ""), - json.dumps(aliases, ensure_ascii=False), - json.dumps(relation_edges, ensure_ascii=False), - json.dumps(vector_evidence, ensure_ascii=False), - json.dumps(evidence_ids, ensure_ascii=False), - ts, - float(expires_at) if expires_at is not None else None, - str(source_note or ""), - ), - ) - self._conn.commit() - latest = self.get_latest_person_profile_snapshot(person_id) - return latest or { - "person_id": person_id, - "profile_version": next_version, - "profile_text": str(profile_text or ""), - "aliases": aliases, - "relation_edges": relation_edges, - "vector_evidence": vector_evidence, - "evidence_ids": evidence_ids, - "updated_at": ts, - "expires_at": expires_at, - "source_note": source_note, - } - - def get_person_profile_override(self, person_id: str) -> Optional[Dict[str, Any]]: - """获取人物画像手工覆盖内容。""" - if not person_id: - return None - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT person_id, override_text, updated_at, updated_by, source - FROM person_profile_overrides - WHERE person_id = ? - LIMIT 1 - """, - (str(person_id),), - ) - row = cursor.fetchone() - if not row: - return None - return { - "person_id": str(row[0]), - "override_text": str(row[1] or ""), - "updated_at": row[2], - "updated_by": str(row[3] or ""), - "source": str(row[4] or ""), - } - - def set_person_profile_override( - self, - person_id: str, - override_text: str, - updated_by: str = "", - source: str = "webui", - updated_at: Optional[float] = None, - ) -> Dict[str, Any]: - """写入人物画像手工覆盖;空文本等价于清除覆盖。""" - if not person_id: - raise ValueError("person_id 不能为空") - - text = str(override_text or "").strip() - if not text: - self.delete_person_profile_override(person_id) - return { - "person_id": str(person_id), - "override_text": "", - "updated_at": None, - "updated_by": str(updated_by or ""), - "source": str(source or ""), - } - - ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO person_profile_overrides ( - person_id, override_text, updated_at, updated_by, source - ) VALUES (?, ?, ?, ?, ?) 
- ON CONFLICT(person_id) DO UPDATE SET - override_text = excluded.override_text, - updated_at = excluded.updated_at, - updated_by = excluded.updated_by, - source = excluded.source - """, - ( - str(person_id), - text, - ts, - str(updated_by or ""), - str(source or ""), - ), - ) - self._conn.commit() - return self.get_person_profile_override(person_id) or { - "person_id": str(person_id), - "override_text": text, - "updated_at": ts, - "updated_by": str(updated_by or ""), - "source": str(source or ""), - } - - def delete_person_profile_override(self, person_id: str) -> bool: - """删除人物画像手工覆盖。""" - if not person_id: - return False - cursor = self._conn.cursor() - cursor.execute( - "DELETE FROM person_profile_overrides WHERE person_id = ?", - (str(person_id),), - ) - self._conn.commit() - return cursor.rowcount > 0 - - # ========================================================================= - # Episode MVP - # ========================================================================= - - def enqueue_episode_source_rebuild(self, source: str, reason: str = "") -> bool: - """将 source 入队到 episode 重建队列。""" - return bool(self._enqueue_episode_source_rebuilds([source], reason=reason)) - - def fetch_episode_source_rebuild_batch( - self, - limit: int = 20, - max_retry: int = 3, - ) -> List[Dict[str, Any]]: - """获取待处理的 source 重建任务。""" - safe_limit = max(1, int(limit)) - safe_retry = max(0, int(max_retry)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT source, status, retry_count, last_error, reason, requested_at, updated_at - FROM episode_rebuild_sources - WHERE status = 'pending' - OR (status = 'failed' AND retry_count < ?) - ORDER BY requested_at ASC, updated_at ASC - LIMIT ? - """, - (safe_retry, safe_limit), - ) - return [dict(row) for row in cursor.fetchall()] - - def mark_episode_source_running( - self, - source: str, - *, - requested_at: Optional[float] = None, - ) -> bool: - """将 source 标记为 running。""" - token = self._normalize_episode_source(source) - if not token: - return False - - now = datetime.now().timestamp() - cursor = self._conn.cursor() - params: List[Any] = [now, token] - sql = """ - UPDATE episode_rebuild_sources - SET status = 'running', - updated_at = ? - WHERE source = ? - AND status IN ('pending', 'failed') - """ - if requested_at is not None: - sql += " AND requested_at = ?" - params.append(float(requested_at)) - cursor.execute(sql, tuple(params)) - self._conn.commit() - return cursor.rowcount > 0 - - def mark_episode_source_done( - self, - source: str, - *, - requested_at: Optional[float] = None, - ) -> bool: - """将 source 标记为 done;若运行期间发生新写入,则保持 pending。""" - token = self._normalize_episode_source(source) - if not token: - return False - - now = datetime.now().timestamp() - cursor = self._conn.cursor() - if requested_at is None: - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = 'done', - last_error = NULL, - updated_at = ? - WHERE source = ? - """, - (now, token), - ) - else: - req_ts = float(requested_at) - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = CASE - WHEN requested_at > ? THEN 'pending' - ELSE 'done' - END, - last_error = NULL, - updated_at = ? - WHERE source = ? 
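
# Sketch of a worker draining the source-rebuild queue above. Passing the row's requested_at
# back into mark_episode_source_done / mark_episode_source_failed is what lets their CASE
# logic keep a source 'pending' when new paragraphs arrived while the rebuild was running.
# Only the queue methods come from the store API above; rebuild_one_source is a placeholder.
def drain_rebuild_queue(store, rebuild_one_source, limit: int = 20) -> None:
    for task in store.fetch_episode_source_rebuild_batch(limit=limit, max_retry=3):
        source, requested_at = task["source"], task["requested_at"]
        if not store.mark_episode_source_running(source, requested_at=requested_at):
            continue  # already claimed by another worker, or the state changed under us
        try:
            rebuild_one_source(source)  # hypothetical: regenerate this source's episodes
            store.mark_episode_source_done(source, requested_at=requested_at)
        except Exception as exc:
            store.mark_episode_source_failed(source, error=str(exc), requested_at=requested_at)
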
- """, - (req_ts, now, token), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def mark_episode_source_failed( - self, - source: str, - error: str = "", - *, - requested_at: Optional[float] = None, - ) -> bool: - """标记 source 失败;若运行期间发生新写入,则重新回到 pending。""" - token = self._normalize_episode_source(source) - if not token: - return False - - err_text = str(error or "").strip()[:500] - now = datetime.now().timestamp() - cursor = self._conn.cursor() - if requested_at is None: - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = 'failed', - retry_count = COALESCE(retry_count, 0) + 1, - last_error = ?, - updated_at = ? - WHERE source = ? - """, - (err_text, now, token), - ) - else: - req_ts = float(requested_at) - cursor.execute( - """ - UPDATE episode_rebuild_sources - SET status = CASE - WHEN requested_at > ? THEN 'pending' - ELSE 'failed' - END, - retry_count = CASE - WHEN requested_at > ? THEN COALESCE(retry_count, 0) - ELSE COALESCE(retry_count, 0) + 1 - END, - last_error = CASE - WHEN requested_at > ? THEN NULL - ELSE ? - END, - updated_at = ? - WHERE source = ? - """, - (req_ts, req_ts, req_ts, err_text, now, token), - ) - self._conn.commit() - return cursor.rowcount > 0 - - def list_episode_source_rebuilds( - self, - *, - statuses: Optional[List[str]] = None, - limit: int = 100, - ) -> List[Dict[str, Any]]: - """列出 source 重建状态。""" - safe_limit = max(1, int(limit)) - params: List[Any] = [] - conditions: List[str] = [] - normalized_statuses = [ - str(item or "").strip().lower() - for item in (statuses or []) - if str(item or "").strip().lower() in {"pending", "running", "done", "failed"} - ] - if normalized_statuses: - placeholders = ",".join(["?"] * len(normalized_statuses)) - conditions.append(f"status IN ({placeholders})") - params.extend(normalized_statuses) - - where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else "" - params.append(safe_limit) - cursor = self._conn.cursor() - cursor.execute( - f""" - SELECT source, status, retry_count, last_error, reason, requested_at, updated_at - FROM episode_rebuild_sources - {where_sql} - ORDER BY updated_at DESC, source ASC - LIMIT ? - """, - tuple(params), - ) - return [dict(row) for row in cursor.fetchall()] - - def get_episode_source_rebuild_summary(self, failed_limit: int = 20) -> Dict[str, Any]: - """汇总 source 重建队列状态。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT status, COUNT(*) AS cnt - FROM episode_rebuild_sources - GROUP BY status - """ - ) - counts = {"pending": 0, "running": 0, "done": 0, "failed": 0, "total": 0} - for row in cursor.fetchall(): - status = str(row["status"] or "").strip().lower() - cnt = int(row["cnt"] or 0) - counts[status] = counts.get(status, 0) + cnt - counts["total"] += cnt - - running = self.list_episode_source_rebuilds(statuses=["running"], limit=20) - failed = self.list_episode_source_rebuilds( - statuses=["failed"], - limit=max(1, int(failed_limit)), - ) - return { - "counts": counts, - "running": running, - "failed": failed, - } - - def get_live_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: - """获取指定 source 下所有 live paragraphs。""" - token = self._normalize_episode_source(source) - if not token: - return [] - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT * - FROM paragraphs - WHERE TRIM(COALESCE(source, '')) = ? 
- AND (is_deleted IS NULL OR is_deleted = 0) - ORDER BY created_at ASC, hash ASC - """, - (token,), - ) - return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] - - def list_episode_sources_for_rebuild(self) -> List[str]: - """列出全量重建涉及的 source(live paragraphs + stale episodes)。""" - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT DISTINCT source - FROM ( - SELECT TRIM(source) AS source - FROM paragraphs - WHERE TRIM(COALESCE(source, '')) != '' - AND (is_deleted IS NULL OR is_deleted = 0) - UNION - SELECT TRIM(source) AS source - FROM episodes - WHERE TRIM(COALESCE(source, '')) != '' - ) - WHERE TRIM(COALESCE(source, '')) != '' - ORDER BY source ASC - """ - ) - return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) - - def is_episode_source_query_blocked(self, source: str) -> bool: - """判断 source 是否处于重建中或失败状态。""" - token = self._normalize_episode_source(source) - if not token: - return False - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT 1 - FROM episode_rebuild_sources - WHERE source = ? - AND status IN ('pending', 'running', 'failed') - LIMIT 1 - """, - (token,), - ) - return cursor.fetchone() is not None - - def replace_episodes_for_source( - self, - source: str, - episodes_payloads: List[Dict[str, Any]], - ) -> Dict[str, Any]: - """按 source 全量替换 episode 结果。""" - token = self._normalize_episode_source(source) - if not token: - return {"source": "", "episode_count": 0} - - payloads = [dict(item) for item in (episodes_payloads or []) if isinstance(item, dict)] - now = datetime.now().timestamp() - cursor = self._conn.cursor() - - try: - cursor.execute("BEGIN IMMEDIATE") - cursor.execute( - """ - SELECT episode_id, created_at - FROM episodes - WHERE TRIM(COALESCE(source, '')) = ? 
- """, - (token,), - ) - existing_created_at = { - str(row["episode_id"]): self._as_optional_float(row["created_at"]) - for row in cursor.fetchall() - } - - cursor.execute( - "DELETE FROM episodes WHERE TRIM(COALESCE(source, '')) = ?", - (token,), - ) - - inserted_count = 0 - for raw_payload in payloads: - title = str(raw_payload.get("title", "") or "").strip() - summary = str(raw_payload.get("summary", "") or "").strip() - evidence_ids = [ - str(item).strip() - for item in (raw_payload.get("evidence_ids") or []) - if str(item).strip() - ] - evidence_ids = list(dict.fromkeys(evidence_ids)) - if not title or not summary or not evidence_ids: - continue - - episode_id = str(raw_payload.get("episode_id", "") or "").strip() - if not episode_id: - seed = json.dumps( - { - "source": token, - "title": title, - "summary": summary, - "event_time_start": raw_payload.get("event_time_start"), - "event_time_end": raw_payload.get("event_time_end"), - "evidence_ids": evidence_ids, - }, - ensure_ascii=False, - sort_keys=True, - ) - episode_id = compute_hash(seed) - - participants = [ - str(item).strip() - for item in (raw_payload.get("participants") or []) - if str(item).strip() - ][:16] - keywords = [ - str(item).strip() - for item in (raw_payload.get("keywords") or []) - if str(item).strip() - ][:20] - paragraph_count = raw_payload.get("paragraph_count", len(evidence_ids)) - try: - paragraph_count = max(0, int(paragraph_count)) - except Exception: - paragraph_count = len(evidence_ids) - if paragraph_count <= 0: - paragraph_count = len(evidence_ids) - if paragraph_count <= 0: - continue - - time_confidence = raw_payload.get("time_confidence", 1.0) - llm_confidence = raw_payload.get("llm_confidence", 0.0) - try: - time_confidence = float(time_confidence) - except Exception: - time_confidence = 1.0 - try: - llm_confidence = float(llm_confidence) - except Exception: - llm_confidence = 0.0 - - created_at = existing_created_at.get(episode_id) - created_ts = created_at if created_at is not None else now - updated_ts = self._as_optional_float(raw_payload.get("updated_at")) or now - - cursor.execute( - """ - INSERT INTO episodes ( - episode_id, source, title, summary, - event_time_start, event_time_end, time_granularity, time_confidence, - participants_json, keywords_json, evidence_ids_json, - paragraph_count, llm_confidence, segmentation_model, segmentation_version, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - ( - episode_id, - token, - title[:120], - summary[:2000], - self._as_optional_float(raw_payload.get("event_time_start")), - self._as_optional_float(raw_payload.get("event_time_end")), - str(raw_payload.get("time_granularity", "") or "").strip() or None, - time_confidence, - json.dumps(participants, ensure_ascii=False), - json.dumps(keywords, ensure_ascii=False), - json.dumps(evidence_ids, ensure_ascii=False), - paragraph_count, - llm_confidence, - str(raw_payload.get("segmentation_model", "") or "").strip() or None, - str(raw_payload.get("segmentation_version", "") or "").strip() or None, - created_ts, - updated_ts, - ), - ) - cursor.executemany( - """ - INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) - VALUES (?, ?, ?) 
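
# Standalone sketch of the deterministic episode_id used above: when no id is supplied, the
# id is a hash of a canonical JSON seed (sorted keys, non-ASCII preserved), so re-running the
# same segmentation for a source produces stable ids and the source-level replace stays
# idempotent. compute_hash() is not shown in this patch, so sha256 stands in for it here.
import hashlib
import json

def derive_episode_id(source, title, summary, evidence_ids, event_time_start=None, event_time_end=None) -> str:
    seed = json.dumps(
        {
            "source": source,
            "title": title,
            "summary": summary,
            "event_time_start": event_time_start,
            "event_time_end": event_time_end,
            "evidence_ids": evidence_ids,
        },
        ensure_ascii=False,
        sort_keys=True,
    )
    return hashlib.sha256(seed.encode("utf-8")).hexdigest()

assert derive_episode_id("group:1", "初次见面", "两人在图书馆认识。", ["a", "b"]) == \
       derive_episode_id("group:1", "初次见面", "两人在图书馆认识。", ["a", "b"])
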
- """, - [(episode_id, hash_value, idx) for idx, hash_value in enumerate(evidence_ids)], - ) - inserted_count += 1 - - self._conn.commit() - return {"source": token, "episode_count": inserted_count} - except Exception: - self._conn.rollback() - raise - - def enqueue_episode_pending( - self, - paragraph_hash: str, - source: Optional[str] = None, - created_at: Optional[float] = None, - ) -> None: - """将段落入队到 episode 异步生成队列。""" - token = str(paragraph_hash or "").strip() - if not token: - return - now = datetime.now().timestamp() - created_ts = float(created_at) if created_at is not None else now - src = str(source or "").strip() or None - - cursor = self._conn.cursor() - cursor.execute( - """ - INSERT INTO episode_pending_paragraphs ( - paragraph_hash, source, created_at, status, retry_count, last_error, updated_at - ) VALUES (?, ?, ?, 'pending', 0, NULL, ?) - ON CONFLICT(paragraph_hash) DO UPDATE SET - source = excluded.source, - created_at = COALESCE(episode_pending_paragraphs.created_at, excluded.created_at), - status = CASE - WHEN episode_pending_paragraphs.status = 'done' THEN 'done' - ELSE 'pending' - END, - last_error = CASE - WHEN episode_pending_paragraphs.status = 'done' THEN episode_pending_paragraphs.last_error - ELSE NULL - END, - updated_at = excluded.updated_at - """, - (token, src, created_ts, now), - ) - self._conn.commit() - - def fetch_episode_pending_batch(self, limit: int = 20, max_retry: int = 3) -> List[Dict[str, Any]]: - """获取待处理 episode 队列批次。""" - safe_limit = max(1, int(limit)) - safe_retry = max(0, int(max_retry)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT paragraph_hash, source, created_at, status, retry_count, last_error, updated_at - FROM episode_pending_paragraphs - WHERE status = 'pending' - OR (status = 'failed' AND retry_count < ?) - ORDER BY updated_at ASC - LIMIT ? - """, - (safe_retry, safe_limit), - ) - return [dict(row) for row in cursor.fetchall()] - - def mark_episode_pending_running(self, hashes: List[str]) -> None: - """批量标记队列项为 running。""" - if not hashes: - return - now = datetime.now().timestamp() - cursor = self._conn.cursor() - chunk_size = 500 - uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) - for i in range(0, len(uniq), chunk_size): - chunk = uniq[i:i + chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - cursor.execute( - f""" - UPDATE episode_pending_paragraphs - SET status = 'running', updated_at = ? - WHERE paragraph_hash IN ({placeholders}) - AND status IN ('pending', 'failed') - """, - [now] + chunk, - ) - self._conn.commit() - - def mark_episode_pending_done(self, hashes: List[str]) -> None: - """批量标记队列项为 done。""" - if not hashes: - return - now = datetime.now().timestamp() - cursor = self._conn.cursor() - chunk_size = 500 - uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) - for i in range(0, len(uniq), chunk_size): - chunk = uniq[i:i + chunk_size] - placeholders = ",".join(["?"] * len(chunk)) - cursor.execute( - f""" - UPDATE episode_pending_paragraphs - SET status = 'done', - last_error = NULL, - updated_at = ? 
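
# Minimal sqlite3 sketch (needs SQLite >= 3.24 for UPSERT) of the enqueue semantics above:
# re-enqueueing a paragraph that already reached 'done' must not flip it back to 'pending',
# which the ON CONFLICT ... CASE expression guarantees. The table is reduced to the columns
# needed for the illustration.
import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE queue (paragraph_hash TEXT PRIMARY KEY, status TEXT, updated_at REAL)")

def enqueue(h: str) -> None:
    conn.execute(
        """
        INSERT INTO queue (paragraph_hash, status, updated_at) VALUES (?, 'pending', ?)
        ON CONFLICT(paragraph_hash) DO UPDATE SET
            status = CASE WHEN queue.status = 'done' THEN 'done' ELSE 'pending' END,
            updated_at = excluded.updated_at
        """,
        (h, time.time()),
    )

enqueue("p1")
conn.execute("UPDATE queue SET status = 'done' WHERE paragraph_hash = 'p1'")
enqueue("p1")  # stays 'done'
print(conn.execute("SELECT status FROM queue WHERE paragraph_hash = 'p1'").fetchone()[0])
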
- WHERE paragraph_hash IN ({placeholders}) - """, - [now] + chunk, - ) - self._conn.commit() - - def mark_episode_pending_failed(self, hash_value: str, error: str = "") -> None: - """标记单条队列项失败并累加重试次数。""" - token = str(hash_value or "").strip() - if not token: - return - now = datetime.now().timestamp() - cursor = self._conn.cursor() - cursor.execute( - """ - UPDATE episode_pending_paragraphs - SET status = 'failed', - retry_count = COALESCE(retry_count, 0) + 1, - last_error = ?, - updated_at = ? - WHERE paragraph_hash = ? - """, - (str(error or ""), now, token), - ) - self._conn.commit() - - def get_episode_pending_status_counts(self, source: str) -> Dict[str, int]: - """统计某个 source 当前 pending 队列中的状态分布。""" - token = self._normalize_episode_source(source) - if not token: - return {"pending": 0, "running": 0, "failed": 0, "done": 0} - - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT status, COUNT(*) AS count - FROM episode_pending_paragraphs - WHERE TRIM(COALESCE(source, '')) = ? - GROUP BY status - """, - (token,), - ) - counts = {"pending": 0, "running": 0, "failed": 0, "done": 0} - for row in cursor.fetchall(): - status = str(row["status"] or "").strip().lower() - if status in counts: - counts[status] = int(row["count"] or 0) - return counts - - def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: - data = dict(row) - - def _load_list(raw: Any) -> List[Any]: - if not raw: - return [] - try: - val = json.loads(raw) - return val if isinstance(val, list) else [] - except Exception: - return [] - - data["participants"] = _load_list(data.pop("participants_json", None)) - data["keywords"] = _load_list(data.pop("keywords_json", None)) - data["evidence_ids"] = _load_list(data.pop("evidence_ids_json", None)) - return data - - @staticmethod - def _as_optional_float(value: Any) -> Optional[float]: - if value is None: - return None - try: - return float(value) - except Exception: - return None - - def upsert_episode(self, payload: Dict[str, Any]) -> Dict[str, Any]: - """写入或更新 episode。""" - if not isinstance(payload, dict): - raise ValueError("payload 必须是字典") - - title = str(payload.get("title", "") or "").strip() - summary = str(payload.get("summary", "") or "").strip() - if not title: - raise ValueError("episode.title 不能为空") - if not summary: - raise ValueError("episode.summary 不能为空") - - source = str(payload.get("source", "") or "").strip() or None - participants_raw = payload.get("participants", []) or [] - keywords_raw = payload.get("keywords", []) or [] - evidence_ids_raw = payload.get("evidence_ids", []) or [] - participants = [str(x).strip() for x in participants_raw if str(x).strip()] - keywords = [str(x).strip() for x in keywords_raw if str(x).strip()] - evidence_ids = [str(x).strip() for x in evidence_ids_raw if str(x).strip()] - - now = datetime.now().timestamp() - created_at = self._as_optional_float(payload.get("created_at")) - updated_at = self._as_optional_float(payload.get("updated_at")) - created_ts = created_at if created_at is not None else now - updated_ts = updated_at if updated_at is not None else now - - episode_id = str(payload.get("episode_id", "") or "").strip() - if not episode_id: - seed = json.dumps( - { - "source": source, - "title": title, - "summary": summary, - "event_time_start": payload.get("event_time_start"), - "event_time_end": payload.get("event_time_end"), - "evidence_ids": evidence_ids, - }, - ensure_ascii=False, - sort_keys=True, - ) - episode_id = compute_hash(seed) - - paragraph_count = payload.get("paragraph_count") - if 
paragraph_count is None: - paragraph_count = len(evidence_ids) - try: - paragraph_count = int(paragraph_count) - except Exception: - paragraph_count = len(evidence_ids) - - time_conf = payload.get("time_confidence", 1.0) - llm_conf = payload.get("llm_confidence", 0.0) - try: - time_conf = float(time_conf) - except Exception: - time_conf = 1.0 - try: - llm_conf = float(llm_conf) - except Exception: - llm_conf = 0.0 - - cursor = self._conn.cursor() - cursor.execute( - "SELECT created_at FROM episodes WHERE episode_id = ? LIMIT 1", - (episode_id,), - ) - existed = cursor.fetchone() - if existed and existed[0] is not None: - created_ts = float(existed[0]) - - cursor.execute( - """ - INSERT INTO episodes ( - episode_id, source, title, summary, - event_time_start, event_time_end, time_granularity, time_confidence, - participants_json, keywords_json, evidence_ids_json, - paragraph_count, llm_confidence, segmentation_model, segmentation_version, - created_at, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(episode_id) DO UPDATE SET - source = excluded.source, - title = excluded.title, - summary = excluded.summary, - event_time_start = excluded.event_time_start, - event_time_end = excluded.event_time_end, - time_granularity = excluded.time_granularity, - time_confidence = excluded.time_confidence, - participants_json = excluded.participants_json, - keywords_json = excluded.keywords_json, - evidence_ids_json = excluded.evidence_ids_json, - paragraph_count = excluded.paragraph_count, - llm_confidence = excluded.llm_confidence, - segmentation_model = excluded.segmentation_model, - segmentation_version = excluded.segmentation_version, - updated_at = excluded.updated_at - """, - ( - episode_id, - source, - title, - summary, - self._as_optional_float(payload.get("event_time_start")), - self._as_optional_float(payload.get("event_time_end")), - str(payload.get("time_granularity", "") or "").strip() or None, - time_conf, - json.dumps(participants, ensure_ascii=False), - json.dumps(keywords, ensure_ascii=False), - json.dumps(evidence_ids, ensure_ascii=False), - max(0, paragraph_count), - llm_conf, - str(payload.get("segmentation_model", "") or "").strip() or None, - str(payload.get("segmentation_version", "") or "").strip() or None, - created_ts, - updated_ts, - ), - ) - self._conn.commit() - return self.get_episode_by_id(episode_id) or {"episode_id": episode_id} - - def bind_episode_paragraphs(self, episode_id: str, paragraph_hashes_ordered: List[str]) -> int: - """重建 episode 与段落映射。""" - token = str(episode_id or "").strip() - if not token: - raise ValueError("episode_id 不能为空") - - normalized: List[str] = [] - seen = set() - for item in paragraph_hashes_ordered or []: - h = str(item or "").strip() - if not h or h in seen: - continue - seen.add(h) - normalized.append(h) - - cursor = self._conn.cursor() - cursor.execute("DELETE FROM episode_paragraphs WHERE episode_id = ?", (token,)) - - if normalized: - cursor.executemany( - """ - INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) - VALUES (?, ?, ?) - """, - [(token, h, idx) for idx, h in enumerate(normalized)], - ) - - now = datetime.now().timestamp() - cursor.execute( - """ - UPDATE episodes - SET paragraph_count = ?, updated_at = ? - WHERE episode_id = ? 
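
# Hedged usage sketch for the episode writer above: upsert_episode() preserves the original
# created_at of an existing row (only updated_at advances), and bind_episode_paragraphs()
# rebuilds the ordered episode->paragraph mapping and refreshes paragraph_count. Construction
# of `store` is assumed; the payload keys mirror the fields read by upsert_episode() above.
def write_episode(store, source: str, title: str, summary: str, paragraph_hashes: list) -> str:
    episode = store.upsert_episode(
        {
            "source": source,
            "title": title,
            "summary": summary,
            "evidence_ids": paragraph_hashes,
            "keywords": ["示例"],
            "time_confidence": 0.8,
        }
    )
    store.bind_episode_paragraphs(episode["episode_id"], paragraph_hashes)
    return episode["episode_id"]
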
- """, - (len(normalized), now, token), - ) - self._conn.commit() - return len(normalized) - - def _build_episode_query_components( - self, - *, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - ) -> Tuple[str, str, str, List[str], List[Any]]: - source_expr = "TRIM(COALESCE(e.source, ''))" - effective_start = "COALESCE(e.event_time_start, e.event_time_end, e.updated_at)" - effective_end = "COALESCE(e.event_time_end, e.event_time_start, e.updated_at)" - conditions: List[str] = [] - params: List[Any] = [] - - conditions.append(f"{source_expr} != ''") - conditions.append("COALESCE(e.paragraph_count, 0) > 0") - conditions.append( - """ - NOT EXISTS ( - SELECT 1 - FROM episode_rebuild_sources ers - WHERE ers.source = TRIM(COALESCE(e.source, '')) - AND ers.status IN ('pending', 'running') - ) - """ - ) - - if source: - token = self._normalize_episode_source(source) - if not token: - return source_expr, effective_start, effective_end, ["1 = 0"], [] - conditions.append(f"{source_expr} = ?") - params.append(token) - - p = str(person or "").strip().lower() - if p: - like_person = f"%{p}%" - conditions.append( - """ - ( - LOWER(COALESCE(e.participants_json, '')) LIKE ? - OR EXISTS ( - SELECT 1 - FROM episode_paragraphs ep_person - JOIN paragraph_entities pe ON pe.paragraph_hash = ep_person.paragraph_hash - JOIN entities en ON en.hash = pe.entity_hash - WHERE ep_person.episode_id = e.episode_id - AND LOWER(en.name) LIKE ? - ) - ) - """ - ) - params.extend([like_person, like_person]) - - if time_from is not None and time_to is not None: - conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") - params.extend([float(time_from), float(time_to)]) - elif time_from is not None: - conditions.append(f"({effective_end} >= ?)") - params.append(float(time_from)) - elif time_to is not None: - conditions.append(f"({effective_start} <= ?)") - params.append(float(time_to)) - - return source_expr, effective_start, effective_end, conditions, params - - @staticmethod - def _tokenize_episode_query(query: str) -> Tuple[str, List[str]]: - """将 episode 查询归一化为短语和 token。""" - normalized = normalize_text(str(query or "")).strip().lower() - if not normalized: - return "", [] - - token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") - tokens: List[str] = [] - seen = set() - for token in token_pattern.findall(normalized): - if token in seen: - continue - seen.add(token) - tokens.append(token) - - if not tokens and len(normalized) >= 2: - tokens = [normalized] - return normalized, tokens - - def get_episode_rows_by_paragraph_hashes( - self, - paragraph_hashes: List[str], - *, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - ) -> List[Dict[str, Any]]: - normalized: List[str] = [] - seen = set() - for item in paragraph_hashes or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - normalized.append(token) - if not normalized: - return [] - - _, _, _, conditions, params = self._build_episode_query_components( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - placeholders = ",".join(["?"] * len(normalized)) - conditions.append(f"ep.paragraph_hash IN ({placeholders})") - conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") - where_sql = "WHERE " + " AND ".join(conditions) - - sql = f""" - SELECT e.*, ep.paragraph_hash AS matched_paragraph_hash - 
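
# Standalone sketch of the query tokenizer above: tokens are runs of two or more word
# characters or CJK ideographs, de-duplicated while preserving order; when nothing matches,
# the normalized phrase itself becomes the single token. normalize_text() is not shown in
# this patch, so a plain lower()/strip() stands in for it here.
import re

_TOKEN_RE = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}")

def tokenize(query: str):
    normalized = str(query or "").strip().lower()
    tokens, seen = [], set()
    for tok in _TOKEN_RE.findall(normalized):
        if tok not in seen:
            seen.add(tok)
            tokens.append(tok)
    if not tokens and len(normalized) >= 2:
        tokens = [normalized]
    return normalized, tokens

print(tokenize("昨天 和 Alice 去了 北京 alice")[1])  # ['昨天', 'alice', '去了', '北京']
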
FROM episodes e - JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id - JOIN paragraphs p ON p.hash = ep.paragraph_hash - {where_sql} - ORDER BY e.updated_at DESC - """ - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params + normalized)) - - grouped: Dict[str, Dict[str, Any]] = {} - for row in cursor.fetchall(): - episode_id = str(row["episode_id"] or "").strip() - if not episode_id: - continue - payload = grouped.get(episode_id) - if payload is None: - payload = self._episode_row_to_dict(row) - payload["matched_paragraph_hashes"] = [] - grouped[episode_id] = payload - matched_hash = str(row["matched_paragraph_hash"] or "").strip() - if matched_hash and matched_hash not in payload["matched_paragraph_hashes"]: - payload["matched_paragraph_hashes"].append(matched_hash) - - out = list(grouped.values()) - for item in out: - item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) - return out - - def get_episode_rows_by_relation_hashes( - self, - relation_hashes: List[str], - *, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - ) -> List[Dict[str, Any]]: - normalized: List[str] = [] - seen = set() - for item in relation_hashes or []: - token = str(item or "").strip() - if not token or token in seen: - continue - seen.add(token) - normalized.append(token) - if not normalized: - return [] - - _, _, _, conditions, params = self._build_episode_query_components( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - placeholders = ",".join(["?"] * len(normalized)) - conditions.append(f"pr.relation_hash IN ({placeholders})") - conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") - where_sql = "WHERE " + " AND ".join(conditions) - - sql = f""" - SELECT - e.*, - p.hash AS matched_paragraph_hash, - pr.relation_hash AS matched_relation_hash - FROM episodes e - JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id - JOIN paragraphs p ON p.hash = ep.paragraph_hash - JOIN paragraph_relations pr ON pr.paragraph_hash = p.hash - {where_sql} - ORDER BY e.updated_at DESC - """ - cursor = self._conn.cursor() - cursor.execute(sql, tuple(params + normalized)) - - grouped: Dict[str, Dict[str, Any]] = {} - for row in cursor.fetchall(): - episode_id = str(row["episode_id"] or "").strip() - if not episode_id: - continue - payload = grouped.get(episode_id) - if payload is None: - payload = self._episode_row_to_dict(row) - payload["matched_paragraph_hashes"] = [] - payload["matched_relation_hashes"] = [] - grouped[episode_id] = payload - matched_paragraph = str(row["matched_paragraph_hash"] or "").strip() - matched_relation = str(row["matched_relation_hash"] or "").strip() - if matched_paragraph and matched_paragraph not in payload["matched_paragraph_hashes"]: - payload["matched_paragraph_hashes"].append(matched_paragraph) - if matched_relation and matched_relation not in payload["matched_relation_hashes"]: - payload["matched_relation_hashes"].append(matched_relation) - - out = list(grouped.values()) - for item in out: - item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) - item["matched_relation_count"] = len(item.get("matched_relation_hashes", [])) - return out - - def query_episodes( - self, - query: str = "", - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - limit: int = 20, - ) -> List[Dict[str, Any]]: - """查询 episode 列表。""" - 
safe_limit = max(1, int(limit)) - _, effective_start, effective_end, conditions, params = self._build_episode_query_components( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - - q, tokens = self._tokenize_episode_query(query) - select_score_sql = "0.0 AS lexical_score" - order_sql = f"{effective_end} DESC, e.updated_at DESC" - select_params: List[Any] = [] - query_params: List[Any] = [] - if q: - field_exprs = { - "title": "LOWER(COALESCE(e.title, ''))", - "summary": "LOWER(COALESCE(e.summary, ''))", - "keywords": "LOWER(COALESCE(e.keywords_json, ''))", - "participants": "LOWER(COALESCE(e.participants_json, ''))", - } - - score_parts: List[str] = [] - phrase_like = f"%{q}%" - score_parts.extend( - [ - f"CASE WHEN {field_exprs['title']} LIKE ? THEN 6.0 ELSE 0.0 END", - f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 4.5 ELSE 0.0 END", - f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 3.0 ELSE 0.0 END", - f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 2.0 ELSE 0.0 END", - ] - ) - select_params.extend([phrase_like, phrase_like, phrase_like, phrase_like]) - - token_predicates: List[str] = [] - for token in tokens: - like = f"%{token}%" - token_any = ( - f"({field_exprs['title']} LIKE ? OR " - f"{field_exprs['summary']} LIKE ? OR " - f"{field_exprs['keywords']} LIKE ? OR " - f"{field_exprs['participants']} LIKE ?)" - ) - token_predicates.append(token_any) - query_params.extend([like, like, like, like]) - - score_parts.append( - "(" - f"CASE WHEN {field_exprs['title']} LIKE ? THEN 3.0 ELSE 0.0 END + " - f"CASE WHEN {field_exprs['keywords']} LIKE ? THEN 2.5 ELSE 0.0 END + " - f"CASE WHEN {field_exprs['summary']} LIKE ? THEN 2.0 ELSE 0.0 END + " - f"CASE WHEN {field_exprs['participants']} LIKE ? THEN 1.5 ELSE 0.0 END + " - f"CASE WHEN {token_any.replace('?', '?')} THEN 2.0 ELSE 0.0 END" - ")" - ) - select_params.extend([like, like, like, like, like, like, like, like]) - - if token_predicates: - conditions.append("(" + " OR ".join(token_predicates) + ")") - - select_score_sql = f"({' + '.join(score_parts)}) AS lexical_score" - order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC" - - where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else "" - sql = f""" - SELECT e.*, {select_score_sql} - FROM episodes e - {where_sql} - ORDER BY {order_sql} - LIMIT ? - """ - final_params = list(select_params) + list(params) + list(query_params) + [safe_limit] - - cursor = self._conn.cursor() - cursor.execute(sql, tuple(final_params)) - return [self._episode_row_to_dict(row) for row in cursor.fetchall()] - - def get_episode_by_id(self, episode_id: str) -> Optional[Dict[str, Any]]: - """获取单条 episode。""" - token = str(episode_id or "").strip() - if not token: - return None - cursor = self._conn.cursor() - cursor.execute( - "SELECT * FROM episodes WHERE episode_id = ? LIMIT 1", - (token,), - ) - row = cursor.fetchone() - if not row: - return None - return self._episode_row_to_dict(row) - - def get_episode_paragraphs(self, episode_id: str, limit: int = 100) -> List[Dict[str, Any]]: - """获取 episode 关联段落(按 position 排序)。""" - token = str(episode_id or "").strip() - if not token: - return [] - safe_limit = max(1, int(limit)) - cursor = self._conn.cursor() - cursor.execute( - """ - SELECT p.*, ep.position - FROM episode_paragraphs ep - JOIN paragraphs p ON p.hash = ep.paragraph_hash - WHERE ep.episode_id = ? - AND (p.is_deleted IS NULL OR p.is_deleted = 0) - ORDER BY ep.position ASC - LIMIT ? 
- """, - (token, safe_limit), - ) - items = [] - for row in cursor.fetchall(): - payload = self._row_to_dict(row, "paragraph") - payload["position"] = row["position"] - items.append(payload) - return items - - def has_table(self, table_name: str) -> bool: - """检查数据库是否存在指定表。""" - if not self._conn: - return False - cursor = self._conn.cursor() - cursor.execute( - "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? LIMIT 1", - (table_name,), - ) - return cursor.fetchone() is not None - - def get_deleted_entities(self, limit: int = 50) -> List[Dict[str, Any]]: - """获取已软删除的实体 (回收站用)""" - if not self.has_table("entities"): return [] - - cursor = self._conn.cursor() - cursor.execute(""" - SELECT hash, name, deleted_at - FROM entities - WHERE is_deleted = 1 - ORDER BY deleted_at DESC - LIMIT ? - """, (limit,)) - - items = [] - for row in cursor.fetchall(): - items.append({ - "hash": row[0], - "name": row[1], - "type": "entity", # 标记为实体 - "deleted_at": row[2] - }) - return items - - def __repr__(self) -> str: - stats = self.get_statistics() if self.is_connected else {} - return ( - f"MetadataStore(paragraphs={stats.get('paragraph_count', 0)}, " - f"entities={stats.get('entity_count', 0)}, " - f"relations={stats.get('relation_count', 0)})" - ) - - def has_data(self) -> bool: - """检查磁盘上是否存在现有数据""" - if self.data_dir is None: - return False - return (self.data_dir / self.db_name).exists() diff --git a/plugins/A_memorix/core/storage/type_detection.py b/plugins/A_memorix/core/storage/type_detection.py deleted file mode 100644 index c20d2cb4..00000000 --- a/plugins/A_memorix/core/storage/type_detection.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Heuristic detection for import strategies and stored knowledge types.""" - -from __future__ import annotations - -import re -from typing import Optional - -from .knowledge_types import ( - ImportStrategy, - KnowledgeType, - parse_import_strategy, - resolve_stored_knowledge_type, -) - - -_NARRATIVE_MARKERS = [ - r"然后", - r"接着", - r"于是", - r"后来", - r"最后", - r"突然", - r"一天", - r"曾经", - r"有一次", - r"从前", - r"说道", - r"问道", - r"想着", - r"觉得", -] -_FACTUAL_MARKERS = [ - r"是", - r"有", - r"在", - r"为", - r"属于", - r"位于", - r"包含", - r"拥有", - r"成立于", - r"出生于", -] - - -def _non_empty_lines(content: str) -> list[str]: - return [line for line in str(content or "").splitlines() if line.strip()] - - -def looks_like_structured_text(content: str) -> bool: - text = str(content or "").strip() - if "|" not in text or text.count("|") < 2: - return False - parts = text.split("|") - return len(parts) == 3 and all(part.strip() for part in parts) - - -def looks_like_quote_text(content: str) -> bool: - lines = _non_empty_lines(content) - if len(lines) < 5: - return False - avg_len = sum(len(line) for line in lines) / len(lines) - return avg_len < 20 - - -def looks_like_narrative_text(content: str) -> bool: - text = str(content or "").strip() - if not text: - return False - - narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text)) - has_dialogue = bool(re.search(r'["「『].*?["」』]', text)) - has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###")) - return has_chapter or has_dialogue or narrative_score >= 2 - - -def looks_like_factual_text(content: str) -> bool: - text = str(content or "").strip() - if not text: - return False - if looks_like_structured_text(text) or looks_like_quote_text(text): - return False - - factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text)) - if factual_score <= 0: - 
return False - - if len(text) <= 240: - return True - return factual_score >= 2 and not looks_like_narrative_text(text) - - -def select_import_strategy( - content: str, - *, - override: Optional[str | ImportStrategy] = None, - chat_log: bool = False, -) -> ImportStrategy: - """文本导入策略选择:override > quote > factual > narrative。""" - - if chat_log: - return ImportStrategy.NARRATIVE - - strategy = parse_import_strategy(override, default=ImportStrategy.AUTO) - if strategy != ImportStrategy.AUTO: - return strategy - - if looks_like_quote_text(content): - return ImportStrategy.QUOTE - if looks_like_factual_text(content): - return ImportStrategy.FACTUAL - return ImportStrategy.NARRATIVE - - -def detect_knowledge_type(content: str) -> KnowledgeType: - """自动检测落库 knowledge_type;无法可靠判断时回退 mixed。""" - - text = str(content or "").strip() - if not text: - return KnowledgeType.MIXED - if looks_like_structured_text(text): - return KnowledgeType.STRUCTURED - if looks_like_quote_text(text): - return KnowledgeType.QUOTE - if looks_like_factual_text(text): - return KnowledgeType.FACTUAL - if looks_like_narrative_text(text): - return KnowledgeType.NARRATIVE - return KnowledgeType.MIXED - - -def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType: - """优先使用显式 type_hint,否则自动检测。""" - - if type_hint: - return resolve_stored_knowledge_type(type_hint, content=content) - return detect_knowledge_type(content) diff --git a/plugins/A_memorix/core/storage/vector_store.py b/plugins/A_memorix/core/storage/vector_store.py deleted file mode 100644 index 97a9144c..00000000 --- a/plugins/A_memorix/core/storage/vector_store.py +++ /dev/null @@ -1,776 +0,0 @@ -""" -向量存储模块 - -基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。 -""" - -import os -import pickle -import hashlib -import shutil -import time -from pathlib import Path -from typing import Optional, Union, Tuple, List, Dict, Set, Any -import random -import threading # Added threading import - -import numpy as np - -try: - import faiss - HAS_FAISS = True -except ImportError: - HAS_FAISS = False - -from src.common.logger import get_logger -from ..utils.quantization import QuantizationType -from ..utils.io import atomic_write, atomic_save_path - -logger = get_logger("A_Memorix.VectorStore") - - -class VectorStore: - """ - 向量存储类 (SQ8 + Append-Only Disk) - - 特性: - - 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit)) - - 存储: float16 on-disk binary (vectors.bin) - - 内存: 仅索引常驻 RAM (<512MB for 100k vectors) - - ID: SHA1-based stable int64 IDs - - 一致性: 强制 L2 Normalization (IP == Cosine) - """ - - # 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化) - DEFAULT_MIN_TRAIN = 40 - # 强制训练样本量 - TRAIN_SIZE = 10000 - # 储水池采样上限 (流式处理前 50k 数据) - RESERVOIR_CAPACITY = 10000 - RESERVOIR_SAMPLE_SCOPE = 50000 - - def __init__( - self, - dimension: int, - quantization_type: QuantizationType = QuantizationType.INT8, - index_type: str = "sq8", - data_dir: Optional[Union[str, Path]] = None, - use_mmap: bool = True, - buffer_size: int = 1024, - ): - if not HAS_FAISS: - raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu") - - self.dimension = dimension - self.data_dir = Path(data_dir) if data_dir else None - if self.data_dir: - self.data_dir.mkdir(parents=True, exist_ok=True) - if quantization_type != QuantizationType.INT8: - raise ValueError( - "vNext 仅支持 quantization_type=int8(SQ8)。" - " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" - ) - normalized_index_type = str(index_type or "sq8").strip().lower() - if normalized_index_type not in {"sq8", "int8"}: - raise ValueError( - "vNext 
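
# Simplified standalone mirror of the detection cascade implemented above (structured >
# quote > factual > narrative > mixed). Only the structured and quote checks are reproduced
# so the example stays short; it illustrates the ordering and is not a replacement for the
# real helpers.
def classify(text: str) -> str:
    lines = [ln for ln in text.splitlines() if ln.strip()]
    parts = text.split("|")
    if text.count("|") >= 2 and len(parts) == 3 and all(p.strip() for p in parts):
        return "structured"            # e.g. "主体 | 关系 | 客体"
    if len(lines) >= 5 and sum(len(ln) for ln in lines) / len(lines) < 20:
        return "quote"                 # many short lines look like quoted one-liners
    return "mixed"                     # the real code also tries factual and narrative here

print(classify("爱因斯坦 | 出生于 | 德国"))   # structured
print(classify("\n".join(["短句"] * 6)))       # quote
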
仅支持 index_type=sq8。" - " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" - ) - self.quantization_type = QuantizationType.INT8 - self.index_type = "sq8" - self.buffer_size = buffer_size - - self._index: Optional[faiss.IndexIDMap2] = None - self._init_index() - - self._is_trained = False - self._vector_norm = "l2" - - # Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力 - # 必须使用 IndexIDMap2 以保证 ID 与主索引一致 - self._fallback_index: Optional[faiss.IndexIDMap2] = None - self._init_fallback_index() - - self._known_hashes: Set[str] = set() - self._deleted_ids: Set[int] = set() - - self._reservoir_buffer: List[np.ndarray] = [] - self._seen_count_for_reservoir = 0 - - self._write_buffer_vecs: List[np.ndarray] = [] - self._write_buffer_ids: List[int] = [] - - self._total_added = 0 - self._total_deleted = 0 - self._bin_count = 0 - - # Thread safety lock - self._lock = threading.RLock() - - logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage") - - def _init_index(self): - """初始化空的 Faiss 索引""" - quantizer = faiss.IndexScalarQuantizer( - self.dimension, - faiss.ScalarQuantizer.QT_8bit, - faiss.METRIC_INNER_PRODUCT - ) - self._index = faiss.IndexIDMap2(quantizer) - self._is_trained = False - - def _init_fallback_index(self): - """初始化 Flat 回退索引""" - flat_index = faiss.IndexFlatIP(self.dimension) - self._fallback_index = faiss.IndexIDMap2(flat_index) - logger.debug("Fallback index (Flat) initialized.") - - @staticmethod - def _generate_id(key: str) -> int: - """生成稳定的 int64 ID (SHA1 截断)""" - h = hashlib.sha1(key.encode("utf-8")).digest() - val = int.from_bytes(h[:8], byteorder="big", signed=False) - return val & 0x7FFFFFFFFFFFFFFF - - @property - def _bin_path(self) -> Path: - return self.data_dir / "vectors.bin" - - @property - def _ids_bin_path(self) -> Path: - return self.data_dir / "vectors_ids.bin" - - @property - def _int_to_str_map(self) -> Dict[int, str]: - """Lazy build volatile map from known hashes""" - # Note: This is read-heavy and cached, might need lock if _known_hashes updates concurrently - # But add/delete are now locked, so checking len mismatch is somewhat safe-ish for quick dirty cache - if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes): - with self._lock: # Protect cache rebuild - self._cached_map = {self._generate_id(k): k for k in self._known_hashes} - return self._cached_map - - def add(self, vectors: np.ndarray, ids: List[str]) -> int: - with self._lock: - if vectors.shape[1] != self.dimension: - raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}") - - vectors = np.ascontiguousarray(vectors, dtype=np.float32) - faiss.normalize_L2(vectors) - - processed_vecs = [] - processed_int_ids = [] - - for i, str_id in enumerate(ids): - if str_id in self._known_hashes: - continue - - int_id = self._generate_id(str_id) - self._known_hashes.add(str_id) - - processed_vecs.append(vectors[i]) - processed_int_ids.append(int_id) - - if not processed_vecs: - return 0 - - batch_vecs = np.array(processed_vecs, dtype=np.float32) - batch_ids = np.array(processed_int_ids, dtype=np.int64) - - self._write_buffer_vecs.append(batch_vecs) - self._write_buffer_ids.extend(processed_int_ids) - - if len(self._write_buffer_ids) >= self.buffer_size: - self._flush_write_buffer_unlocked() - - if not self._is_trained: - # 双写到回退索引 - self._fallback_index.add_with_ids(batch_vecs, batch_ids) - - self._update_reservoir(batch_vecs) - # 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断 - if len(self._reservoir_buffer) >= 10000: - 
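
# Standalone sketch of the two invariants established above: (1) string keys map to stable,
# non-negative int64 ids via SHA1 so IndexIDMap2 can address them, and (2) every vector is
# L2-normalized so inner-product search is equivalent to cosine similarity.
import hashlib
import numpy as np

def stable_int64_id(key: str) -> int:
    digest = hashlib.sha1(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big", signed=False) & 0x7FFFFFFFFFFFFFFF

v = np.random.default_rng(0).normal(size=(1, 8)).astype(np.float32)
v /= np.linalg.norm(v, axis=1, keepdims=True)       # same effect as faiss.normalize_L2
print(stable_int64_id("paragraph-hash-123") >= 0)   # True, and identical on every run
print(float(v[0] @ v[0]))                           # ~1.0: IP of a unit vector with itself
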
logger.info(f"训练样本达到上限,开始训练...") - self._train_and_replay_unlocked() - - self._total_added += len(batch_ids) - return len(batch_ids) - - def _flush_write_buffer(self): - with self._lock: - self._flush_write_buffer_unlocked() - - def _flush_write_buffer_unlocked(self): - if not self._write_buffer_vecs: - return - - batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0) - batch_ids = np.array(self._write_buffer_ids, dtype=np.int64) - - vecs_fp16 = batch_vecs.astype(np.float16) - - with open(self._bin_path, "ab") as f: - f.write(vecs_fp16.tobytes()) - - ids_bytes = batch_ids.astype('>i8').tobytes() - with open(self._ids_bin_path, "ab") as f: - f.write(ids_bytes) - - self._bin_count += len(batch_ids) - - if self._is_trained and self._index.is_trained: - self._index.add_with_ids(batch_vecs, batch_ids) - else: - # 即使在 flush 时,如果未训练,也要同步到 fallback - self._fallback_index.add_with_ids(batch_vecs, batch_ids) - - self._write_buffer_vecs.clear() - self._write_buffer_ids.clear() - - def _update_reservoir(self, vectors: np.ndarray): - for vec in vectors: - self._seen_count_for_reservoir += 1 - if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY: - self._reservoir_buffer.append(vec) - else: - if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE: - r = random.randint(0, self._seen_count_for_reservoir - 1) - if r < self.RESERVOIR_CAPACITY: - self._reservoir_buffer[r] = vec - - def _train_and_replay(self): - with self._lock: - self._train_and_replay_unlocked() - - def _train_and_replay_unlocked(self): - if not self._reservoir_buffer: - logger.warning("No training data available.") - return - - train_data = np.array(self._reservoir_buffer, dtype=np.float32) - logger.info(f"Training Index with {len(train_data)} samples...") - - try: - self._index.train(train_data) - except Exception as e: - logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.") - return - - self._is_trained = True - self._reservoir_buffer = [] - - logger.info("Replaying data from disk to populate index...") - try: - replay_count = self._replay_vectors_to_index() - # 只有当 replay 成功且数据量一致时,才释放回退索引 - if self._index.ntotal >= self._bin_count: - logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.") - self._fallback_index.reset() - else: - logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.") - except Exception as e: - logger.error(f"Replay failed: {e}. 
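
# Standalone sketch of the reservoir sampling used above to collect SQ8 training vectors:
# the first CAPACITY items fill the buffer, later items inside the sampling scope replace a
# random slot with probability CAPACITY / seen_so_far, yielding a near-uniform sample
# without holding every vector in memory. The tiny constants here are for illustration only.
import random

CAPACITY, SAMPLE_SCOPE = 4, 100
reservoir, seen = [], 0
for item in range(30):
    seen += 1
    if len(reservoir) < CAPACITY:
        reservoir.append(item)
    elif seen <= SAMPLE_SCOPE:
        r = random.randint(0, seen - 1)
        if r < CAPACITY:
            reservoir[r] = item

print(reservoir)  # 4 items drawn roughly uniformly from the 30 seen
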
Keeping fallback index as backup.") - - def _replay_vectors_to_index(self) -> int: - """从 vectors.bin 读取并添加到 index""" - if not self._bin_path.exists() or not self._ids_bin_path.exists(): - return 0 - - vec_item_size = self.dimension * 2 - id_item_size = 8 - chunk_size = 10000 - - with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: - while True: - vec_data = f_vec.read(chunk_size * vec_item_size) - id_data = f_id.read(chunk_size * id_item_size) - - if not vec_data: - break - - batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) - batch_fp32 = batch_fp16.astype(np.float32) - faiss.normalize_L2(batch_fp32) - - batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) - - valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] - if not all(valid_mask): - batch_fp32 = batch_fp32[valid_mask] - batch_ids = batch_ids[valid_mask] - - if len(batch_ids) > 0: - self._index.add_with_ids(batch_fp32, batch_ids) - - def search( - self, - query: np.ndarray, - k: int = 10, - filter_deleted: bool = True, - ) -> Tuple[List[str], List[float]]: - query_local = np.array(query, dtype=np.float32, order="C", copy=True) - if query_local.ndim == 1: - got_dim = int(query_local.shape[0]) - query_local = query_local.reshape(1, -1) - elif query_local.ndim == 2: - if query_local.shape[0] != 1: - raise ValueError( - f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" - ) - got_dim = int(query_local.shape[1]) - else: - raise ValueError( - f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" - ) - - if got_dim != self.dimension: - raise ValueError( - f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}" - ) - if not np.all(np.isfinite(query_local)): - raise ValueError("query embedding contains non-finite values") - - faiss.normalize_L2(query_local) - - # 查询路径仅负责检索,不在此触发训练/回放。 - # 训练/回放前置到 warmup_index(),并由插件启动阶段触发。 - # Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。 - with self._lock: - self._flush_write_buffer_unlocked() - search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index - if search_index.ntotal == 0: - logger.warning("Indices are empty. 
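
# Standalone sketch of the append-only layout replayed above: vectors.bin holds float16 rows
# (dimension * 2 bytes each) and vectors_ids.bin holds the matching big-endian int64 ids, so
# record i in one file lines up with record i in the other. File names follow the properties
# above, but the directory and data are illustrative.
import tempfile
from pathlib import Path
import numpy as np

dim = 4
vecs = np.random.default_rng(1).normal(size=(3, dim)).astype(np.float32)
ids = np.array([11, 22, 33], dtype=np.int64)

with tempfile.TemporaryDirectory() as tmp:
    vec_path, id_path = Path(tmp) / "vectors.bin", Path(tmp) / "vectors_ids.bin"
    vec_path.write_bytes(vecs.astype(np.float16).tobytes())
    id_path.write_bytes(ids.astype(">i8").tobytes())

    back_vecs = np.frombuffer(vec_path.read_bytes(), dtype=np.float16).reshape(-1, dim)
    back_ids = np.frombuffer(id_path.read_bytes(), dtype=">i8").astype(np.int64)
    print(back_ids.tolist(), back_vecs.shape)  # [11, 22, 33] (3, 4)
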
No data to search.") - return [], [] - # 执行检索 - dists, ids = search_index.search(query_local, k * 2) - - # Faiss search 返回的是 (1, K) 的数组,取第一行 - dists = dists[0] - ids = ids[0] - - results = [] - for id_val, score in zip(ids, dists): - if id_val == -1: continue - if filter_deleted and id_val in self._deleted_ids: - continue - - str_id = self._int_to_str_map.get(id_val) - if str_id: - results.append((str_id, float(score))) - - # Sort and trim just in case filtering reduced count - results.sort(key=lambda x: x[1], reverse=True) - results = results[:k] - - if not results: - return [], [] - - return [r[0] for r in results], [r[1] for r in results] - - def warmup_index(self, force_train: bool = True) -> Dict[str, Any]: - """ - 预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。 - - Args: - force_train: 是否在满足阈值时强制训练 SQ8 索引 - - Returns: - 预热状态摘要 - """ - started = time.perf_counter() - logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}") - - try: - with self._lock: - self._flush_write_buffer() - - if self._bin_path.exists(): - self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2) - else: - self._bin_count = 0 - - needs_fallback_bootstrap = ( - self._bin_count > 0 - and self._fallback_index.ntotal == 0 - and (not self._is_trained or self._index.ntotal == 0) - ) - if needs_fallback_bootstrap: - self._bootstrap_fallback_from_disk() - - min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN))) - needs_train = ( - bool(force_train) - and self._bin_count >= min_train - and not self._is_trained - ) - if needs_train: - self._force_train_small_data() - - duration_ms = (time.perf_counter() - started) * 1000.0 - summary = { - "ok": True, - "trained": bool(self._is_trained), - "index_ntotal": int(self._index.ntotal), - "fallback_ntotal": int(self._fallback_index.ntotal), - "bin_count": int(self._bin_count), - "duration_ms": duration_ms, - "error": None, - } - except Exception as e: - duration_ms = (time.perf_counter() - started) * 1000.0 - summary = { - "ok": False, - "trained": bool(self._is_trained), - "index_ntotal": int(self._index.ntotal) if self._index is not None else 0, - "fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0, - "bin_count": int(getattr(self, "_bin_count", 0)), - "duration_ms": duration_ms, - "error": str(e), - } - logger.error( - "metric.vector_index_prewarm_fail=1 " - f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} " - f"error={e}" - ) - return summary - - logger.info( - "metric.vector_index_prewarm_success=1 " - f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} " - f"trained={summary['trained']} " - f"index_ntotal={summary['index_ntotal']} " - f"fallback_ntotal={summary['fallback_ntotal']} " - f"bin_count={summary['bin_count']}" - ) - return summary - - def _bootstrap_fallback_from_disk(self): - with self._lock: - self._bootstrap_fallback_from_disk_unlocked() - - def _bootstrap_fallback_from_disk_unlocked(self): - """重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引""" - if not self._bin_path.exists() or not self._ids_bin_path.exists(): - return - - logger.info("Replaying all disk vectors to fallback index...") - vec_item_size = self.dimension * 2 - id_item_size = 8 - chunk_size = 10000 - - with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: - while True: - vec_data = f_vec.read(chunk_size * vec_item_size) - id_data = f_id.read(chunk_size * id_item_size) - if not vec_data: break - - batch_fp16 = np.frombuffer(vec_data, 
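
# Hedged end-to-end usage sketch of the store above (assumes VectorStore is importable from
# this module and that faiss-cpu plus numpy are installed): add() L2-normalizes and buffers
# writes, warmup_index() pre-trains/replays so the first live query does not pay that cost,
# and search() silently uses the Flat fallback until SQ8 training has happened.
import numpy as np

store = VectorStore(dimension=8, data_dir="./memorix_vectors_demo")
vectors = np.random.default_rng(2).normal(size=(50, 8)).astype(np.float32)
store.add(vectors, [f"para-{i}" for i in range(50)])
print(store.warmup_index())             # summary dict: trained / ntotal / duration_ms
ids, scores = store.search(vectors[0], k=3)
print(ids[0], round(scores[0], 3))      # the query's own id comes back with a score near 1.0
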
dtype=np.float16).reshape(-1, self.dimension) - batch_fp32 = batch_fp16.astype(np.float32) - faiss.normalize_L2(batch_fp32) - batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) - - valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] - if any(valid_mask): - self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask]) - - logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.") - - def _force_train_small_data(self): - with self._lock: - self._force_train_small_data_unlocked() - - def _force_train_small_data_unlocked(self): - logger.info("Forcing training on small dataset...") - self._reservoir_buffer = [] - - chunk_size = 10000 - vec_item_size = self.dimension * 2 - - with open(self._bin_path, "rb") as f: - while len(self._reservoir_buffer) < self.TRAIN_SIZE: - data = f.read(chunk_size * vec_item_size) - if not data: break - fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension) - fp32 = fp16.astype(np.float32) - faiss.normalize_L2(fp32) - - for vec in fp32: - self._reservoir_buffer.append(vec) - if len(self._reservoir_buffer) >= self.TRAIN_SIZE: - break - - self._train_and_replay_unlocked() - - def delete(self, ids: List[str]) -> int: - with self._lock: - count = 0 - for str_id in ids: - if str_id not in self._known_hashes: - continue - int_id = self._generate_id(str_id) - if int_id not in self._deleted_ids: - self._deleted_ids.add(int_id) - if self._index.is_trained: - self._index.remove_ids(np.array([int_id], dtype=np.int64)) - # 同步从 fallback 移除 - if self._fallback_index.ntotal > 0: - self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64)) - count += 1 - self._total_deleted += count - - # Check GC - self._check_rebuild_needed() - return count - - def _check_rebuild_needed(self): - """GC Excution Check""" - if self._bin_count == 0: return - ratio = len(self._deleted_ids) / self._bin_count - if ratio > 0.3 and len(self._deleted_ids) > 1000: - logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})") - self.rebuild_index() - - def rebuild_index(self): - """GC: 重建索引,压缩 bin 文件""" - with self._lock: - self._rebuild_index_locked() - - def _rebuild_index_locked(self): - """实际 GC 重建逻辑。""" - logger.info("Starting Compaction (GC)...") - - tmp_bin = self.data_dir / "vectors.bin.tmp" - tmp_ids = self.data_dir / "vectors_ids.bin.tmp" - - vec_item_size = self.dimension * 2 - id_item_size = 8 - chunk_size = 10000 - - new_count = 0 - - # 1. Compact Files - with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \ - open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id: - while True: - vec_data = f_vec.read(chunk_size * vec_item_size) - id_data = f_id.read(chunk_size * id_item_size) - if not vec_data: break - - batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) - batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) - - keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids] - - if any(keep_mask): - keep_vecs = batch_fp16[keep_mask] - keep_ids = batch_ids[keep_mask] - - w_vec.write(keep_vecs.tobytes()) - w_id.write(keep_ids.astype('>i8').tobytes()) - new_count += len(keep_ids) - - # 2. 
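
# Minimal sketch of the compaction trigger above: deletions only add tombstones, and a full
# rebuild is scheduled once more than 30% of the on-disk rows are tombstoned and at least
# 1000 rows are affected. The numbers below are illustrative.
def needs_compaction(deleted: int, stored: int, ratio: float = 0.3, floor: int = 1000) -> bool:
    if stored == 0:
        return False
    return deleted / stored > ratio and deleted > floor

print(needs_compaction(1200, 3000))   # True: 40% tombstones
print(needs_compaction(900, 3000))    # False: at the ratio boundary and below the 1000-row floor
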
Reset State & Atomic Swap - self._bin_count = new_count - - # Close current index - self._index.reset() - if self._fallback_index: self._fallback_index.reset() # Also clear fallback - self._is_trained = False - - # Swap files - shutil.move(str(tmp_bin), str(self._bin_path)) - shutil.move(str(tmp_ids), str(self._ids_bin_path)) - - # Reset Tombstones (Critical) - self._deleted_ids.clear() - - # 3. Reload/Rebuild Index (Fresh Train) - # We need to re-train because data distribution might have changed significantly after deletion - self._init_index() - self._init_fallback_index() # Re-init fallback too - self._force_train_small_data() # This will train and replay from the NEW compact file - - logger.info("Compaction Complete.") - - def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: - with self._lock: - if not data_dir: - data_dir = self.data_dir - if not data_dir: - raise ValueError("No data_dir") - - data_dir = Path(data_dir) - data_dir.mkdir(parents=True, exist_ok=True) - - self._flush_write_buffer_unlocked() - - if self._is_trained: - index_path = data_dir / "vectors.index" - with atomic_save_path(index_path) as tmp: - faiss.write_index(self._index, tmp) - - meta = { - "dimension": self.dimension, - "quantization_type": self.quantization_type.value, - "is_trained": self._is_trained, - "vector_norm": self._vector_norm, - "deleted_ids": list(self._deleted_ids), - "known_hashes": list(self._known_hashes), - } - - with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f: - pickle.dump(meta, f) - - logger.info("VectorStore saved.") - - def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]: - """ - 离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。 - """ - with self._lock: - target_dir = Path(data_dir) if data_dir else self.data_dir - if target_dir is None: - raise ValueError("No data_dir") - target_dir = Path(target_dir) - npy_path = target_dir / "vectors.npy" - idx_path = target_dir / "vectors.index" - bin_path = target_dir / "vectors.bin" - ids_bin_path = target_dir / "vectors_ids.bin" - meta_path = target_dir / "vectors_metadata.pkl" - - if not npy_path.exists(): - return {"migrated": False, "reason": "npy_missing"} - if not meta_path.exists(): - raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl") - if bin_path.exists() and ids_bin_path.exists(): - return {"migrated": False, "reason": "bin_exists"} - - # Reset in-memory state to avoid appending to stale runtime buffers. 
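Both the fallback bootstrap and the compaction pass above rely on the same streaming pattern: the paired vectors.bin (fp16, dimension*2 bytes per row) and vectors_ids.bin (big-endian int64, 8 bytes per row) files are read in 10k-row chunks, tombstoned ids are masked out, and the surviving rows are either replayed into the flat fallback index or rewritten to temporary files that replace the originals. A minimal, self-contained sketch of that chunked keep-mask pass under the same on-disk layout (the standalone helper name compact_vector_files and its signature are illustrative, not part of the patch):

    import numpy as np
    from pathlib import Path

    def compact_vector_files(data_dir: Path, dimension: int,
                             deleted_ids: set, chunk_size: int = 10_000) -> int:
        """Stream the paired bin files, drop tombstoned rows, swap files in, return new count."""
        vec_item = dimension * 2                      # fp16 -> 2 bytes per component
        src_vec, src_id = data_dir / "vectors.bin", data_dir / "vectors_ids.bin"
        tmp_vec, tmp_id = data_dir / "vectors.bin.tmp", data_dir / "vectors_ids.bin.tmp"
        kept = 0
        with open(src_vec, "rb") as fv, open(src_id, "rb") as fi, \
             open(tmp_vec, "wb") as wv, open(tmp_id, "wb") as wi:
            while True:
                vec_bytes = fv.read(chunk_size * vec_item)
                id_bytes = fi.read(chunk_size * 8)    # 8 bytes per big-endian int64 id
                if not vec_bytes:
                    break
                vecs = np.frombuffer(vec_bytes, dtype=np.float16).reshape(-1, dimension)
                ids = np.frombuffer(id_bytes, dtype=">i8").astype(np.int64)
                keep = np.isin(ids, list(deleted_ids), invert=True) if deleted_ids else np.ones(len(ids), bool)
                if keep.any():
                    wv.write(vecs[keep].tobytes())
                    wi.write(ids[keep].astype(">i8").tobytes())
                    kept += int(keep.sum())
        tmp_vec.replace(src_vec)                      # same-filesystem rename, effectively atomic
        tmp_id.replace(src_id)
        return kept

Replaying the same chunks into the IDMap-wrapped flat index, after casting to fp32 and L2-normalizing, is the fallback path the store uses before the quantized index has been trained.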
- self._known_hashes.clear() - self._deleted_ids.clear() - self._write_buffer_vecs.clear() - self._write_buffer_ids.clear() - self._init_index() - self._init_fallback_index() - self._is_trained = False - self._bin_count = 0 - - self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir) - self.save(target_dir) - return {"migrated": True, "reason": "ok"} - - def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: - with self._lock: - if not data_dir: data_dir = self.data_dir - data_dir = Path(data_dir) - - npy_path = data_dir / "vectors.npy" - idx_path = data_dir / "vectors.index" - bin_path = data_dir / "vectors.bin" - - if npy_path.exists() and not bin_path.exists(): - raise RuntimeError( - "检测到 legacy vectors.npy,vNext 不再支持运行时自动迁移。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - - meta_path = data_dir / "vectors_metadata.pkl" - if not meta_path.exists(): - logger.warning("No metadata found, initialized empty.") - return - - with open(meta_path, "rb") as f: - meta = pickle.load(f) - - if meta.get("vector_norm") != "l2": - logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...") - self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", [])) - self._deleted_ids = set(meta.get("deleted_ids", [])) - self._init_index() - self._force_train_small_data() - return - - self._is_trained = meta.get("is_trained", False) - self._vector_norm = meta.get("vector_norm", "l2") - self._deleted_ids = set(meta.get("deleted_ids", [])) - self._known_hashes = set(meta.get("known_hashes", [])) - - if self._is_trained: - if idx_path.exists(): - try: - self._index = faiss.read_index(str(idx_path)) - if not isinstance(self._index, faiss.IndexIDMap2): - logger.warning("Loaded index type mismatch. Rebuilding...") - self._init_index() - self._force_train_small_data() - except Exception as e: - logger.error(f"Failed to load index: {e}. Rebuilding...") - self._init_index() - self._force_train_small_data() - else: - logger.warning("Index file missing despite metadata indicating trained. 
Rebuilding from bin...") - self._init_index() - self._force_train_small_data() - - if bin_path.exists(): - self._bin_count = bin_path.stat().st_size // (self.dimension * 2) - - def _migrate_from_npy(self, npy_path, idx_path, data_dir): - with self._lock: - self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir) - - def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir): - try: - arr = np.load(npy_path, mmap_mode="r") - except Exception: - arr = np.load(npy_path) - - meta_path = data_dir / "vectors_metadata.pkl" - old_ids = [] - if meta_path.exists(): - with open(meta_path, "rb") as f: - m = pickle.load(f) - old_ids = m.get("ids", []) - - if len(arr) != len(old_ids): - logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}") - return - - logger.info(f"Migrating {len(arr)} vectors...") - - chunk = 1000 - for i in range(0, len(arr), chunk): - sub_arr = arr[i : i+chunk] - sub_ids = old_ids[i : i+chunk] - self.add(sub_arr, sub_ids) - - if not self._is_trained: - self._force_train_small_data() - - shutil.move(str(npy_path), str(npy_path) + ".bak") - if idx_path.exists(): - shutil.move(str(idx_path), str(idx_path) + ".bak") - - logger.info("Migration complete.") - - def clear(self) -> None: - with self._lock: - self._ids_bin_path.unlink(missing_ok=True) - self._bin_path.unlink(missing_ok=True) - self._init_index() - self._known_hashes.clear() - self._deleted_ids.clear() - self._bin_count = 0 - logger.info("VectorStore cleared.") - - def has_data(self) -> bool: - return (self.data_dir / "vectors_metadata.pkl").exists() - - @property - def num_vectors(self) -> int: - return len(self._known_hashes) - len(self._deleted_ids) - - def __contains__(self, hash_value: str) -> bool: - """Check if a hash exists in the store""" - return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids - diff --git a/plugins/A_memorix/core/strategies/__init__.py b/plugins/A_memorix/core/strategies/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/plugins/A_memorix/core/strategies/base.py b/plugins/A_memorix/core/strategies/base.py deleted file mode 100644 index ff250cdf..00000000 --- a/plugins/A_memorix/core/strategies/base.py +++ /dev/null @@ -1,89 +0,0 @@ -from abc import ABC, abstractmethod -from typing import List, Dict, Any, Optional, Union -from dataclasses import dataclass, field -from enum import Enum -import hashlib - -class KnowledgeType(str, Enum): - NARRATIVE = "narrative" - FACTUAL = "factual" - QUOTE = "quote" - MIXED = "mixed" - -@dataclass -class SourceInfo: - file: str - offset_start: int - offset_end: int - checksum: str = "" - -@dataclass -class ChunkContext: - chunk_id: str - index: int - context: Dict[str, Any] = field(default_factory=dict) - text: str = "" - -@dataclass -class ChunkFlags: - verbatim: bool = False - requires_llm: bool = True - -@dataclass -class ProcessedChunk: - type: KnowledgeType - source: SourceInfo - chunk: ChunkContext - data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities - flags: ChunkFlags = field(default_factory=ChunkFlags) - - def to_dict(self) -> Dict: - return { - "type": self.type.value, - "source": { - "file": self.source.file, - "offset_start": self.source.offset_start, - "offset_end": self.source.offset_end, - "checksum": self.source.checksum - }, - "chunk": { - "text": self.chunk.text, - "chunk_id": self.chunk.chunk_id, - "index": self.chunk.index, - "context": self.chunk.context - }, - "data": self.data, - "flags": { - "verbatim": 
self.flags.verbatim, - "requires_llm": self.flags.requires_llm - } - } - -class BaseStrategy(ABC): - def __init__(self, filename: str): - self.filename = filename - - @abstractmethod - def split(self, text: str) -> List[ProcessedChunk]: - """按策略将文本切分为块。""" - pass - - @abstractmethod - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - """从文本块中抽取结构化信息。""" - pass - - def calculate_checksum(self, text: str) -> str: - return hashlib.sha256(text.encode("utf-8")).hexdigest() - - def build_language_guard(self, text: str) -> str: - """ - 构建统一的输出语言约束。 - 不区分语言类型,仅要求抽取值保持原文语言,不做翻译。 - """ - _ = text # 预留参数,便于后续按需扩展 - return ( - "Focus on the original source language. Keep extracted events, entities, predicates " - "and objects in the same language as the source text, preserve names/terms as-is, " - "and do not translate." - ) diff --git a/plugins/A_memorix/core/strategies/factual.py b/plugins/A_memorix/core/strategies/factual.py deleted file mode 100644 index 4b7d6e56..00000000 --- a/plugins/A_memorix/core/strategies/factual.py +++ /dev/null @@ -1,98 +0,0 @@ -import re -from typing import List, Dict, Any -from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext - -class FactualStrategy(BaseStrategy): - def split(self, text: str) -> List[ProcessedChunk]: - # 结构感知切分 - lines = text.split('\n') - chunks = [] - current_chunk_lines = [] - current_len = 0 - target_size = 600 - - for i, line in enumerate(lines): - # 判断是否应当切分 - # 若当前行为列表项/定义/表格行,则尽量不切分 - is_structure = self._is_structural_line(line) - - current_len += len(line) + 1 - current_chunk_lines.append(line) - - # 达到目标长度且不在紧凑结构块内时切分(过长时强制切分) - if current_len >= target_size and not is_structure: - chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) - current_chunk_lines = [] - current_len = 0 - elif current_len >= target_size * 2: # 超长时强制切分 - chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) - current_chunk_lines = [] - current_len = 0 - - if current_chunk_lines: - chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) - - return chunks - - def _is_structural_line(self, line: str) -> bool: - line = line.strip() - if not line: return False - # 列表项 - if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line): - return True - # 定义项(术语: 定义) - if re.match(r'^[^::]+[::].+', line): - return True - # 表格行(按 markdown 语法假设) - if line.startswith('|') and line.endswith('|'): - return True - return False - - def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk: - text = "\n".join(lines) - return ProcessedChunk( - type=KnowledgeType.FACTUAL, - source=SourceInfo( - file=self.filename, - offset_start=0, # 简化处理:真实偏移跟踪需要额外状态 - offset_end=0, - checksum=self.calculate_checksum(text) - ), - chunk=ChunkContext( - chunk_id=f"{self.filename}_{index}", - index=index, - text=text - ) - ) - - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - if not llm_func: - raise ValueError("LLM function required for Factual extraction") - - language_guard = self.build_language_guard(chunk.chunk.text) - prompt = f"""You are a factual knowledge extraction engine. -Extract factual triples and entities from the text. -Preserve lists and definitions accurately. - -Language constraints: -- {language_guard} -- Preserve original names and domain terms exactly when possible. -- JSON keys must stay exactly as: triples, entities, subject, predicate, object. 
- -Text: -{chunk.chunk.text} - -Return ONLY valid JSON: -{{ - "triples": [ - {{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}} - ], - "entities": ["Entity1", "Entity2"] -}} -""" - result = await llm_func(prompt) - - # 结果保持原样存入 data,后续统一归一化流程会处理 - # vector_store 侧期望关系字段为 subject/predicate/object 映射形式 - chunk.data = result - return chunk diff --git a/plugins/A_memorix/core/strategies/narrative.py b/plugins/A_memorix/core/strategies/narrative.py deleted file mode 100644 index 731414f7..00000000 --- a/plugins/A_memorix/core/strategies/narrative.py +++ /dev/null @@ -1,126 +0,0 @@ -import re -from typing import List, Dict, Any -from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext - -class NarrativeStrategy(BaseStrategy): - def split(self, text: str) -> List[ProcessedChunk]: - scenes = self._split_into_scenes(text) - chunks = [] - - for scene_idx, (scene_text, scene_title) in enumerate(scenes): - scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx) - chunks.extend(scene_chunks) - - return chunks - - def _split_into_scenes(self, text: str) -> List[tuple[str, str]]: - """按标题或分隔符把文本切分为场景。""" - # 简单启发式:按 markdown 标题或特定分隔符切分 - # 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行 - # 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行 - scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)' - - # 保留分隔符,以便识别场景起点 - parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE) - - scenes = [] - current_scene_title = "Start" - current_scene_content = [] - - if parts and parts[0].strip() == "": - parts = parts[1:] - - for part in parts: - if re.match(scene_pattern_str, part, re.MULTILINE): - # 先保存上一段场景 - if current_scene_content: - scenes.append(("".join(current_scene_content), current_scene_title)) - current_scene_content = [] - current_scene_title = part.strip() - else: - current_scene_content.append(part) - - if current_scene_content: - scenes.append(("".join(current_scene_content), current_scene_title)) - - # 若未识别到场景,则把全文视作单一场景 - if not scenes: - scenes = [(text, "Whole Text")] - - return scenes - - def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]: - chunks = [] - if len(text) <= window_size: - chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0)) - return chunks - - stride = window_size - overlap - start = 0 - local_idx = 0 - while start < len(text): - end = min(start + window_size, len(text)) - chunk_text = text[start:end] - - # 尽量对齐到最近换行,避免生硬截断句子 - # 仅在未到文本尾部时进行回退 - if end < len(text): - last_newline = chunk_text.rfind('\n') - if last_newline > window_size // 2: # 仅在回退距离可接受时启用 - end = start + last_newline + 1 - chunk_text = text[start:end] - - chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start)) - - start += len(chunk_text) - overlap if end < len(text) else len(chunk_text) - local_idx += 1 - - return chunks - - def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk: - return ProcessedChunk( - type=KnowledgeType.NARRATIVE, - source=SourceInfo( - file=self.filename, - offset_start=offset, - offset_end=offset + len(text), - checksum=self.calculate_checksum(text) - ), - chunk=ChunkContext( - chunk_id=f"{self.filename}_{scene_idx}_{local_idx}", - index=local_idx, - text=text, - context={"scene_id": scene_id} - ) - ) - - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - if not llm_func: - raise ValueError("LLM function required for 
Narrative extraction") - - language_guard = self.build_language_guard(chunk.chunk.text) - prompt = f"""You are a narrative knowledge extraction engine. -Extract key events and character relations from the scene text. - -Language constraints: -- {language_guard} -- Preserve original names and terms exactly when possible. -- JSON keys must stay exactly as: events, relations, subject, predicate, object. - -Scene: -{chunk.chunk.context.get('scene_id')} - -Text: -{chunk.chunk.text} - -Return ONLY valid JSON: -{{ - "events": ["event description 1", "event description 2"], - "relations": [ - {{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}} - ] -}} -""" - result = await llm_func(prompt) - chunk.data = result - return chunk diff --git a/plugins/A_memorix/core/strategies/quote.py b/plugins/A_memorix/core/strategies/quote.py deleted file mode 100644 index 10733d64..00000000 --- a/plugins/A_memorix/core/strategies/quote.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import List, Dict, Any -from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags - -class QuoteStrategy(BaseStrategy): - def split(self, text: str) -> List[ProcessedChunk]: - # Split by double newlines (stanzas) - stanzas = text.split("\n\n") - chunks = [] - offset = 0 - - for idx, stanza in enumerate(stanzas): - if not stanza.strip(): - offset += len(stanza) + 2 - continue - - chunk = ProcessedChunk( - type=KnowledgeType.QUOTE, - source=SourceInfo( - file=self.filename, - offset_start=offset, - offset_end=offset + len(stanza), - checksum=self.calculate_checksum(stanza) - ), - chunk=ChunkContext( - chunk_id=f"{self.filename}_{idx}", - index=idx, - text=stanza - ), - flags=ChunkFlags( - verbatim=True, - requires_llm=False # Default to no LLM, but can be overridden - ) - ) - chunks.append(chunk) - offset += len(stanza) + 2 # +2 for \n\n - - return chunks - - async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: - # For quotes, the text itself is the entity/knowledge - # We might use LLM to extract headers/metadata if requested, but core logic is pass-through - - # Treat the whole chunk text as a verbatim entity - chunk.data = { - "verbatim_entities": [chunk.chunk.text] - } - - if llm_func and chunk.flags.requires_llm: - # Optional: Extract metadata - pass - - return chunk diff --git a/plugins/A_memorix/core/utils/__init__.py b/plugins/A_memorix/core/utils/__init__.py deleted file mode 100644 index e0d763cf..00000000 --- a/plugins/A_memorix/core/utils/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -"""工具模块 - 哈希、监控等辅助功能""" - -from .hash import compute_hash, normalize_text -from .monitor import MemoryMonitor -from .quantization import quantize_vector, dequantize_vector -from .time_parser import ( - parse_query_datetime_to_timestamp, - parse_query_time_range, - parse_ingest_datetime_to_timestamp, - normalize_time_meta, - format_timestamp, -) -from .relation_write_service import RelationWriteService, RelationWriteResult -from .relation_query import RelationQuerySpec, parse_relation_query_spec -from .plugin_id_policy import PluginIdPolicy - -__all__ = [ - "compute_hash", - "normalize_text", - "MemoryMonitor", - "quantize_vector", - "dequantize_vector", - "parse_query_datetime_to_timestamp", - "parse_query_time_range", - "parse_ingest_datetime_to_timestamp", - "normalize_time_meta", - "format_timestamp", - "RelationWriteService", - "RelationWriteResult", - "RelationQuerySpec", - "parse_relation_query_spec", - "PluginIdPolicy", -] diff --git 
a/plugins/A_memorix/core/utils/aggregate_query_service.py b/plugins/A_memorix/core/utils/aggregate_query_service.py deleted file mode 100644 index dcf64c34..00000000 --- a/plugins/A_memorix/core/utils/aggregate_query_service.py +++ /dev/null @@ -1,360 +0,0 @@ -""" -聚合查询服务: -- 并发执行 search/time/episode 分支 -- 统一分支结果结构 -- 可选混合排序(Weighted RRF) -""" - -from __future__ import annotations - -import asyncio -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.AggregateQueryService") - -BranchRunner = Callable[[], Awaitable[Dict[str, Any]]] - - -class AggregateQueryService: - """聚合查询执行服务(search/time/episode)。""" - - def __init__(self, plugin_config: Optional[Any] = None): - self.plugin_config = plugin_config or {} - - def _cfg(self, key: str, default: Any = None) -> Any: - getter = getattr(self.plugin_config, "get_config", None) - if callable(getter): - return getter(key, default) - - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - @staticmethod - def _as_float(value: Any, default: float = 0.0) -> float: - try: - return float(value) - except Exception: - return float(default) - - @staticmethod - def _as_int(value: Any, default: int = 0) -> int: - try: - return int(value) - except Exception: - return int(default) - - def _rrf_k(self) -> float: - raw = self._cfg("retrieval.aggregate.rrf_k", 60.0) - value = self._as_float(raw, 60.0) - return max(1.0, value) - - def _weights(self) -> Dict[str, float]: - defaults = {"search": 1.0, "time": 1.0, "episode": 1.0} - raw = self._cfg("retrieval.aggregate.weights", {}) - if not isinstance(raw, dict): - return defaults - - out = dict(defaults) - for key in ("search", "time", "episode"): - if key in raw: - out[key] = max(0.0, self._as_float(raw.get(key), defaults[key])) - return out - - @staticmethod - def _normalize_branch_payload( - name: str, - payload: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: - data = payload if isinstance(payload, dict) else {} - results_raw = data.get("results", []) - results = results_raw if isinstance(results_raw, list) else [] - count = data.get("count") - if count is None: - count = len(results) - return { - "name": name, - "success": bool(data.get("success", False)), - "skipped": bool(data.get("skipped", False)), - "skip_reason": str(data.get("skip_reason", "") or "").strip(), - "error": str(data.get("error", "") or "").strip(), - "results": results, - "count": max(0, int(count)), - "elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)), - "content": str(data.get("content", "") or ""), - "query_type": str(data.get("query_type", "") or name), - } - - @staticmethod - def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str: - item_type = str(item.get("type", "") or "").strip().lower() - if item_type == "episode": - episode_id = str(item.get("episode_id", "") or "").strip() - if episode_id: - return f"episode:{episode_id}" - - item_hash = str(item.get("hash", "") or "").strip() - if item_hash: - return f"{item_type}:{item_hash}" - - return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}" - - def _build_mixed_results( - self, - *, - branches: Dict[str, Dict[str, Any]], - top_k: int, - ) -> List[Dict[str, Any]]: - rrf_k = self._rrf_k() - weights = self._weights() - bucket: Dict[str, Dict[str, Any]] = {} - - for branch_name, branch in branches.items(): - if 
not branch.get("success", False): - continue - results = branch.get("results", []) - if not isinstance(results, list): - continue - - weight = max(0.0, float(weights.get(branch_name, 1.0))) - for idx, item in enumerate(results, start=1): - if not isinstance(item, dict): - continue - key = self._mix_key(item, branch_name, idx) - score = weight / (rrf_k + float(idx)) - if key not in bucket: - merged = dict(item) - merged["fusion_score"] = 0.0 - merged["_source_branches"] = set() - bucket[key] = merged - - target = bucket[key] - target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score - target["_source_branches"].add(branch_name) - - mixed = list(bucket.values()) - mixed.sort( - key=lambda x: ( - -float(x.get("fusion_score", 0.0)), - str(x.get("type", "") or ""), - str(x.get("hash", "") or x.get("episode_id", "") or ""), - ) - ) - - out: List[Dict[str, Any]] = [] - for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1): - merged = dict(item) - branches_set = merged.pop("_source_branches", set()) - merged["source_branches"] = sorted(list(branches_set)) - merged["rank"] = rank - out.append(merged) - return out - - @staticmethod - def _status(branch: Dict[str, Any]) -> str: - if branch.get("skipped", False): - return "skipped" - if branch.get("success", False): - return "success" - return "failed" - - def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: - summary: Dict[str, Dict[str, Any]] = {} - for name, branch in branches.items(): - status = self._status(branch) - summary[name] = { - "status": status, - "count": int(branch.get("count", 0) or 0), - } - if status == "skipped": - summary[name]["reason"] = str(branch.get("skip_reason", "") or "") - if status == "failed": - summary[name]["error"] = str(branch.get("error", "") or "") - return summary - - def _build_content( - self, - *, - query: str, - branches: Dict[str, Dict[str, Any]], - errors: List[Dict[str, str]], - mixed_results: Optional[List[Dict[str, Any]]], - ) -> str: - lines: List[str] = [ - f"🔀 聚合查询结果(query='{query or 'N/A'}')", - "", - "分支状态:", - ] - for name in ("search", "time", "episode"): - branch = branches.get(name, {}) - status = self._status(branch) - count = int(branch.get("count", 0) or 0) - line = f"- {name}: {status}, count={count}" - reason = str(branch.get("skip_reason", "") or "").strip() - err = str(branch.get("error", "") or "").strip() - if status == "skipped" and reason: - line += f" ({reason})" - if status == "failed" and err: - line += f" ({err})" - lines.append(line) - - if errors: - lines.append("") - lines.append("错误:") - for item in errors[:6]: - lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}") - - if mixed_results is not None: - lines.append("") - lines.append(f"🧩 混合结果({len(mixed_results)} 条):") - for idx, item in enumerate(mixed_results[:5], start=1): - src = ",".join(item.get("source_branches", []) or []) - if str(item.get("type", "") or "") == "episode": - title = str(item.get("title", "") or "Untitled") - lines.append(f"{idx}. 🧠 {title} [{src}]") - else: - text = str(item.get("content", "") or "") - if len(text) > 80: - text = text[:80] + "..." - lines.append(f"{idx}. 
{text} [{src}]") - - return "\n".join(lines) - - async def execute( - self, - *, - query: str, - top_k: int, - mix: bool, - mix_top_k: Optional[int], - time_from: Optional[str], - time_to: Optional[str], - search_runner: Optional[BranchRunner], - time_runner: Optional[BranchRunner], - episode_runner: Optional[BranchRunner], - ) -> Dict[str, Any]: - clean_query = str(query or "").strip() - safe_top_k = max(1, int(top_k)) - safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k)) - - branches: Dict[str, Dict[str, Any]] = {} - errors: List[Dict[str, str]] = [] - scheduled: List[Tuple[str, asyncio.Task]] = [] - - if clean_query: - if search_runner is not None: - scheduled.append(("search", asyncio.create_task(search_runner()))) - else: - branches["search"] = self._normalize_branch_payload( - "search", - {"success": False, "error": "search runner unavailable", "results": []}, - ) - else: - branches["search"] = self._normalize_branch_payload( - "search", - { - "success": False, - "skipped": True, - "skip_reason": "missing_query", - "results": [], - "count": 0, - }, - ) - - if time_from or time_to: - if time_runner is not None: - scheduled.append(("time", asyncio.create_task(time_runner()))) - else: - branches["time"] = self._normalize_branch_payload( - "time", - {"success": False, "error": "time runner unavailable", "results": []}, - ) - else: - branches["time"] = self._normalize_branch_payload( - "time", - { - "success": False, - "skipped": True, - "skip_reason": "missing_time_window", - "results": [], - "count": 0, - }, - ) - - if episode_runner is not None: - scheduled.append(("episode", asyncio.create_task(episode_runner()))) - else: - branches["episode"] = self._normalize_branch_payload( - "episode", - {"success": False, "error": "episode runner unavailable", "results": []}, - ) - - if scheduled: - done = await asyncio.gather( - *[task for _, task in scheduled], - return_exceptions=True, - ) - for (branch_name, _), payload in zip(scheduled, done): - if isinstance(payload, Exception): - logger.error(f"aggregate branch failed: branch={branch_name} error={payload}") - normalized = self._normalize_branch_payload( - branch_name, - { - "success": False, - "error": str(payload), - "results": [], - }, - ) - else: - normalized = self._normalize_branch_payload(branch_name, payload) - branches[branch_name] = normalized - - for name in ("search", "time", "episode"): - branch = branches.get(name) - if not branch: - continue - if branch.get("skipped", False): - continue - if not branch.get("success", False): - errors.append( - { - "branch": name, - "error": str(branch.get("error", "") or "unknown error"), - } - ) - - success = any( - bool(branches.get(name, {}).get("success", False)) - for name in ("search", "time", "episode") - ) - mixed_results: Optional[List[Dict[str, Any]]] = None - if mix: - mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k) - - payload: Dict[str, Any] = { - "success": success, - "query_type": "aggregate", - "query": clean_query, - "top_k": safe_top_k, - "mix": bool(mix), - "mix_top_k": safe_mix_top_k, - "branches": branches, - "errors": errors, - "summary": self._build_summary(branches), - } - if mixed_results is not None: - payload["mixed_results"] = mixed_results - - payload["content"] = self._build_content( - query=clean_query, - branches=branches, - errors=errors, - mixed_results=mixed_results, - ) - return payload diff --git a/plugins/A_memorix/core/utils/episode_retrieval_service.py 
b/plugins/A_memorix/core/utils/episode_retrieval_service.py deleted file mode 100644 index 44b22854..00000000 --- a/plugins/A_memorix/core/utils/episode_retrieval_service.py +++ /dev/null @@ -1,182 +0,0 @@ -"""Episode hybrid retrieval service.""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from src.common.logger import get_logger - -from ..retrieval import DualPathRetriever, TemporalQueryOptions - -logger = get_logger("A_Memorix.EpisodeRetrievalService") - - -class EpisodeRetrievalService: - """Hybrid episode retrieval backed by lexical rows and evidence projection.""" - - _RRF_K = 60.0 - _BRANCH_WEIGHTS = { - "lexical": 1.0, - "paragraph_evidence": 1.0, - "relation_evidence": 0.85, - } - - def __init__( - self, - *, - metadata_store: Any, - retriever: Optional[DualPathRetriever] = None, - ) -> None: - self.metadata_store = metadata_store - self.retriever = retriever - - async def query( - self, - *, - query: str = "", - top_k: int = 5, - time_from: Optional[float] = None, - time_to: Optional[float] = None, - person: Optional[str] = None, - source: Optional[str] = None, - include_paragraphs: bool = False, - ) -> List[Dict[str, Any]]: - clean_query = str(query or "").strip() - safe_top_k = max(1, int(top_k)) - candidate_k = max(30, safe_top_k * 6) - - branches: Dict[str, List[Dict[str, Any]]] = { - "lexical": self.metadata_store.query_episodes( - query=clean_query, - time_from=time_from, - time_to=time_to, - person=person, - source=source, - limit=(candidate_k if clean_query else safe_top_k), - ) - } - - if clean_query and self.retriever is not None: - try: - temporal = TemporalQueryOptions( - time_from=time_from, - time_to=time_to, - person=person, - source=source, - ) - results = await self.retriever.retrieve( - query=clean_query, - top_k=candidate_k, - temporal=temporal, - ) - except Exception as exc: - logger.warning(f"episode evidence retrieval failed, fallback to lexical only: {exc}") - else: - paragraph_rank_map: Dict[str, int] = {} - relation_rank_map: Dict[str, int] = {} - for rank, item in enumerate(results, start=1): - hash_value = str(getattr(item, "hash_value", "") or "").strip() - result_type = str(getattr(item, "result_type", "") or "").strip().lower() - if not hash_value: - continue - if result_type == "paragraph" and hash_value not in paragraph_rank_map: - paragraph_rank_map[hash_value] = rank - elif result_type == "relation" and hash_value not in relation_rank_map: - relation_rank_map[hash_value] = rank - - if paragraph_rank_map: - paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes( - list(paragraph_rank_map.keys()), - source=source, - ) - if paragraph_rows: - branches["paragraph_evidence"] = self._rank_projected_rows( - paragraph_rows, - rank_map=paragraph_rank_map, - support_key="matched_paragraph_hashes", - ) - - if relation_rank_map: - relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes( - list(relation_rank_map.keys()), - source=source, - ) - if relation_rows: - branches["relation_evidence"] = self._rank_projected_rows( - relation_rows, - rank_map=relation_rank_map, - support_key="matched_relation_hashes", - ) - - fused = self._fuse_branches(branches, top_k=safe_top_k) - if include_paragraphs: - for item in fused: - item["paragraphs"] = self.metadata_store.get_episode_paragraphs( - episode_id=str(item.get("episode_id") or ""), - limit=50, - ) - return fused - - @staticmethod - def _rank_projected_rows( - rows: List[Dict[str, Any]], - *, - rank_map: Dict[str, int], - support_key: 
str, - ) -> List[Dict[str, Any]]: - sentinel = 10**9 - ranked = [dict(item) for item in rows] - - def _first_support_rank(item: Dict[str, Any]) -> int: - support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])] - ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map] - return min(ranks) if ranks else sentinel - - ranked.sort( - key=lambda item: ( - _first_support_rank(item), - -int(item.get("matched_paragraph_count") or 0), - -float(item.get("updated_at") or 0.0), - str(item.get("episode_id") or ""), - ) - ) - return ranked - - def _fuse_branches( - self, - branches: Dict[str, List[Dict[str, Any]]], - *, - top_k: int, - ) -> List[Dict[str, Any]]: - bucket: Dict[str, Dict[str, Any]] = {} - for branch_name, rows in branches.items(): - weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0) - if weight <= 0.0: - continue - for rank, row in enumerate(rows, start=1): - episode_id = str(row.get("episode_id", "") or "").strip() - if not episode_id: - continue - if episode_id not in bucket: - payload = dict(row) - payload.pop("matched_paragraph_hashes", None) - payload.pop("matched_relation_hashes", None) - payload.pop("matched_paragraph_count", None) - payload.pop("matched_relation_count", None) - payload["_fusion_score"] = 0.0 - bucket[episode_id] = payload - bucket[episode_id]["_fusion_score"] = float( - bucket[episode_id].get("_fusion_score", 0.0) - ) + weight / (self._RRF_K + float(rank)) - - out = list(bucket.values()) - out.sort( - key=lambda item: ( - -float(item.get("_fusion_score", 0.0)), - -float(item.get("updated_at") or 0.0), - str(item.get("episode_id") or ""), - ) - ) - for item in out: - item.pop("_fusion_score", None) - return out[: max(1, int(top_k))] diff --git a/plugins/A_memorix/core/utils/episode_segmentation_service.py b/plugins/A_memorix/core/utils/episode_segmentation_service.py deleted file mode 100644 index f42b1456..00000000 --- a/plugins/A_memorix/core/utils/episode_segmentation_service.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -Episode 语义切分服务(LLM 主路径)。 - -职责: -1. 组装语义切分提示词 -2. 调用 LLM 生成结构化 episode JSON -3. 
严格校验输出结构,返回标准化结果 -""" - -from __future__ import annotations - -import json -from typing import Any, Dict, List, Optional, Tuple - -from src.common.logger import get_logger -from src.config.model_configs import TaskConfig -from src.config.config import model_config as host_model_config -from src.services import llm_service as llm_api - -logger = get_logger("A_Memorix.EpisodeSegmentationService") - - -class EpisodeSegmentationService: - """基于 LLM 的 episode 语义切分服务。""" - - SEGMENTATION_VERSION = "episode_mvp_v1" - - def __init__(self, plugin_config: Optional[dict] = None): - self.plugin_config = plugin_config or {} - - def _cfg(self, key: str, default: Any = None) -> Any: - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - @staticmethod - def _is_task_config(obj: Any) -> bool: - return hasattr(obj, "model_list") and bool(getattr(obj, "model_list", [])) - - def _build_single_model_task(self, model_name: str, template: TaskConfig) -> TaskConfig: - return TaskConfig( - model_list=[model_name], - max_tokens=template.max_tokens, - temperature=template.temperature, - slow_threshold=template.slow_threshold, - selection_strategy=template.selection_strategy, - ) - - def _pick_template_task(self, available_tasks: Dict[str, Any]) -> Optional[TaskConfig]: - preferred = ("utils", "replyer", "planner", "tool_use") - for task_name in preferred: - cfg = available_tasks.get(task_name) - if self._is_task_config(cfg): - return cfg - for task_name, cfg in available_tasks.items(): - if task_name != "embedding" and self._is_task_config(cfg): - return cfg - for cfg in available_tasks.values(): - if self._is_task_config(cfg): - return cfg - return None - - def _resolve_model_config(self) -> Tuple[Optional[Any], str]: - available_tasks = llm_api.get_available_models() or {} - if not available_tasks: - return None, "unavailable" - - selector = str(self._cfg("episode.segmentation_model", "auto") or "auto").strip() - model_dict = getattr(host_model_config, "models_dict", {}) or {} - - if selector and selector.lower() != "auto": - direct_task = available_tasks.get(selector) - if self._is_task_config(direct_task): - return direct_task, selector - - if selector in model_dict: - template = self._pick_template_task(available_tasks) - if template is not None: - return self._build_single_model_task(selector, template), selector - - logger.warning(f"episode.segmentation_model='{selector}' 不可用,回退 auto") - - for task_name in ("utils", "replyer", "planner", "tool_use"): - cfg = available_tasks.get(task_name) - if self._is_task_config(cfg): - return cfg, task_name - - fallback = self._pick_template_task(available_tasks) - if fallback is not None: - return fallback, "auto" - return None, "unavailable" - - @staticmethod - def _clamp_score(value: Any, default: float = 0.0) -> float: - try: - num = float(value) - except Exception: - num = default - if num < 0.0: - return 0.0 - if num > 1.0: - return 1.0 - return num - - @staticmethod - def _safe_json_loads(text: str) -> Dict[str, Any]: - raw = str(text or "").strip() - if not raw: - raise ValueError("empty_response") - - if "```" in raw: - raw = raw.replace("```json", "```").replace("```JSON", "```") - parts = raw.split("```") - for part in parts: - part = part.strip() - if part.startswith("{") and part.endswith("}"): - raw = part - break - - try: - data = json.loads(raw) - if isinstance(data, dict): - return data - except Exception: - pass - - start 
= raw.find("{") - end = raw.rfind("}") - if start >= 0 and end > start: - candidate = raw[start : end + 1] - data = json.loads(candidate) - if isinstance(data, dict): - return data - - raise ValueError("invalid_json_response") - - def _build_prompt( - self, - *, - source: str, - window_start: Optional[float], - window_end: Optional[float], - paragraphs: List[Dict[str, Any]], - ) -> str: - rows: List[str] = [] - for idx, item in enumerate(paragraphs, 1): - p_hash = str(item.get("hash", "") or "").strip() - content = str(item.get("content", "") or "").strip().replace("\r\n", "\n") - content = content[:800] - event_start = item.get("event_time_start") - event_end = item.get("event_time_end") - event_time = item.get("event_time") - rows.append( - ( - f"[{idx}] hash={p_hash}\n" - f"event_time={event_time}\n" - f"event_time_start={event_start}\n" - f"event_time_end={event_end}\n" - f"content={content}" - ) - ) - - source_text = str(source or "").strip() or "unknown" - return ( - "You are an episode segmentation engine.\n" - "Group the given paragraphs into one or more coherent episodes.\n" - "Return JSON ONLY. No markdown, no explanation.\n" - "\n" - "Hard JSON schema:\n" - "{\n" - ' "episodes": [\n' - " {\n" - ' "title": "string",\n' - ' "summary": "string",\n' - ' "paragraph_hashes": ["hash1", "hash2"],\n' - ' "participants": ["person1", "person2"],\n' - ' "keywords": ["kw1", "kw2"],\n' - ' "time_confidence": 0.0,\n' - ' "llm_confidence": 0.0\n' - " }\n" - " ]\n" - "}\n" - "\n" - "Rules:\n" - "1) paragraph_hashes must come from input only.\n" - "2) title and summary must be non-empty.\n" - "3) keep participants/keywords concise and deduplicated.\n" - "4) if uncertain, still provide best effort confidence values.\n" - "\n" - f"source={source_text}\n" - f"window_start={window_start}\n" - f"window_end={window_end}\n" - "paragraphs:\n" - + "\n\n".join(rows) - ) - - def _normalize_episodes( - self, - *, - payload: Dict[str, Any], - input_hashes: List[str], - ) -> List[Dict[str, Any]]: - raw_episodes = payload.get("episodes") - if not isinstance(raw_episodes, list): - raise ValueError("episodes_missing_or_not_list") - - valid_hashes = set(input_hashes) - normalized: List[Dict[str, Any]] = [] - for item in raw_episodes: - if not isinstance(item, dict): - continue - - title = str(item.get("title", "") or "").strip() - summary = str(item.get("summary", "") or "").strip() - if not title or not summary: - continue - - raw_hashes = item.get("paragraph_hashes") - if not isinstance(raw_hashes, list): - continue - - dedup_hashes: List[str] = [] - seen_hashes = set() - for h in raw_hashes: - token = str(h or "").strip() - if not token or token in seen_hashes or token not in valid_hashes: - continue - seen_hashes.add(token) - dedup_hashes.append(token) - - if not dedup_hashes: - continue - - participants = [] - for p in item.get("participants", []) or []: - token = str(p or "").strip() - if token: - participants.append(token) - - keywords = [] - for kw in item.get("keywords", []) or []: - token = str(kw or "").strip() - if token: - keywords.append(token) - - normalized.append( - { - "title": title, - "summary": summary, - "paragraph_hashes": dedup_hashes, - "participants": participants[:16], - "keywords": keywords[:20], - "time_confidence": self._clamp_score(item.get("time_confidence"), default=1.0), - "llm_confidence": self._clamp_score(item.get("llm_confidence"), default=0.5), - } - ) - - if not normalized: - raise ValueError("episodes_all_invalid") - return normalized - - async def segment( - self, - *, - 
source: str, - window_start: Optional[float], - window_end: Optional[float], - paragraphs: List[Dict[str, Any]], - ) -> Dict[str, Any]: - if not paragraphs: - raise ValueError("paragraphs_empty") - - model_config, model_label = self._resolve_model_config() - if model_config is None: - raise RuntimeError("episode segmentation model unavailable") - - prompt = self._build_prompt( - source=source, - window_start=window_start, - window_end=window_end, - paragraphs=paragraphs, - ) - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type="A_Memorix.EpisodeSegmentation", - ) - if not success or not response: - raise RuntimeError("llm_generate_failed") - - payload = self._safe_json_loads(str(response)) - input_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs] - episodes = self._normalize_episodes(payload=payload, input_hashes=input_hashes) - - return { - "episodes": episodes, - "segmentation_model": model_label, - "segmentation_version": self.SEGMENTATION_VERSION, - } - diff --git a/plugins/A_memorix/core/utils/episode_service.py b/plugins/A_memorix/core/utils/episode_service.py deleted file mode 100644 index ca94dd96..00000000 --- a/plugins/A_memorix/core/utils/episode_service.py +++ /dev/null @@ -1,558 +0,0 @@ -""" -Episode 聚合与落库服务。 - -流程: -1. 从 pending 队列读取段落并组批 -2. 按 source + 时间窗口切组 -3. 调用 LLM 语义切分 -4. 写入 episodes + episode_paragraphs -5. LLM 失败时使用确定性 fallback -""" - -from __future__ import annotations - -import json -import re -from collections import Counter -from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -from .episode_segmentation_service import EpisodeSegmentationService -from .hash import compute_hash - -logger = get_logger("A_Memorix.EpisodeService") - - -class EpisodeService: - """Episode MVP 后台处理服务。""" - - def __init__( - self, - *, - metadata_store: Any, - plugin_config: Optional[Any] = None, - segmentation_service: Optional[EpisodeSegmentationService] = None, - ): - self.metadata_store = metadata_store - self.plugin_config = plugin_config or {} - self.segmentation_service = segmentation_service or EpisodeSegmentationService( - plugin_config=self._config_dict(), - ) - - def _config_dict(self) -> Dict[str, Any]: - if isinstance(self.plugin_config, dict): - return self.plugin_config - return {} - - def _cfg(self, key: str, default: Any = None) -> Any: - getter = getattr(self.plugin_config, "get_config", None) - if callable(getter): - return getter(key, default) - - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - @staticmethod - def _to_optional_float(value: Any) -> Optional[float]: - if value is None: - return None - try: - return float(value) - except Exception: - return None - - @staticmethod - def _clamp_score(value: Any, default: float = 1.0) -> float: - try: - num = float(value) - except Exception: - num = default - if num < 0.0: - return 0.0 - if num > 1.0: - return 1.0 - return num - - @staticmethod - def _paragraph_anchor(paragraph: Dict[str, Any]) -> float: - for key in ("event_time_end", "event_time_start", "event_time", "created_at"): - value = paragraph.get(key) - try: - if value is not None: - return float(value) - except Exception: - continue - return 0.0 - - @staticmethod - def _paragraph_sort_key(paragraph: Dict[str, Any]) -> Tuple[float, str]: - return ( - 
EpisodeService._paragraph_anchor(paragraph), - str(paragraph.get("hash", "") or ""), - ) - - def load_pending_paragraphs( - self, - pending_rows: List[Dict[str, Any]], - ) -> Tuple[List[Dict[str, Any]], List[str]]: - """ - 将 pending 行展开为段落上下文。 - - Returns: - (loaded_paragraphs, missing_hashes) - """ - loaded: List[Dict[str, Any]] = [] - missing: List[str] = [] - for row in pending_rows or []: - p_hash = str(row.get("paragraph_hash", "") or "").strip() - if not p_hash: - continue - - paragraph = self.metadata_store.get_paragraph(p_hash) - if not paragraph: - missing.append(p_hash) - continue - - loaded.append( - { - "hash": p_hash, - "source": str(row.get("source") or paragraph.get("source") or "").strip(), - "content": str(paragraph.get("content", "") or ""), - "created_at": self._to_optional_float(paragraph.get("created_at")) - or self._to_optional_float(row.get("created_at")) - or 0.0, - "event_time": self._to_optional_float(paragraph.get("event_time")), - "event_time_start": self._to_optional_float(paragraph.get("event_time_start")), - "event_time_end": self._to_optional_float(paragraph.get("event_time_end")), - "time_granularity": str(paragraph.get("time_granularity", "") or "").strip() or None, - "time_confidence": self._clamp_score(paragraph.get("time_confidence"), default=1.0), - } - ) - return loaded, missing - - def group_paragraphs(self, paragraphs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - 按 source + 时间邻近窗口组批,并受段落数/字符数上限约束。 - """ - if not paragraphs: - return [] - - max_paragraphs = max(1, int(self._cfg("episode.max_paragraphs_per_call", 20))) - max_chars = max(200, int(self._cfg("episode.max_chars_per_call", 6000))) - window_seconds = max( - 60.0, - float(self._cfg("episode.source_time_window_hours", 24)) * 3600.0, - ) - - by_source: Dict[str, List[Dict[str, Any]]] = {} - for paragraph in paragraphs: - source = str(paragraph.get("source", "") or "").strip() - by_source.setdefault(source, []).append(paragraph) - - groups: List[Dict[str, Any]] = [] - for source, items in by_source.items(): - ordered = sorted(items, key=self._paragraph_sort_key) - - current: List[Dict[str, Any]] = [] - current_chars = 0 - last_anchor: Optional[float] = None - - def flush() -> None: - nonlocal current, current_chars, last_anchor - if not current: - return - sorted_current = sorted(current, key=self._paragraph_sort_key) - groups.append( - { - "source": source, - "paragraphs": sorted_current, - } - ) - current = [] - current_chars = 0 - last_anchor = None - - for paragraph in ordered: - anchor = self._paragraph_anchor(paragraph) - content_len = len(str(paragraph.get("content", "") or "")) - - need_flush = False - if current: - if len(current) >= max_paragraphs: - need_flush = True - elif current_chars + content_len > max_chars: - need_flush = True - elif last_anchor is not None and abs(anchor - last_anchor) > window_seconds: - need_flush = True - - if need_flush: - flush() - - current.append(paragraph) - current_chars += content_len - last_anchor = anchor - - flush() - - groups.sort( - key=lambda g: self._paragraph_anchor(g["paragraphs"][0]) if g.get("paragraphs") else 0.0 - ) - return groups - - def _compute_time_meta(self, paragraphs: List[Dict[str, Any]]) -> Tuple[Optional[float], Optional[float], Optional[str], float]: - starts: List[float] = [] - ends: List[float] = [] - granularity_priority = { - "minute": 4, - "hour": 3, - "day": 2, - "month": 1, - "year": 0, - } - granularity = None - granularity_rank = -1 - conf_values: List[float] = [] - - for p in paragraphs: - s = 
self._to_optional_float(p.get("event_time_start")) - e = self._to_optional_float(p.get("event_time_end")) - t = self._to_optional_float(p.get("event_time")) - c = self._to_optional_float(p.get("created_at")) - - start_candidate = s if s is not None else (t if t is not None else (e if e is not None else c)) - end_candidate = e if e is not None else (t if t is not None else (s if s is not None else c)) - - if start_candidate is not None: - starts.append(start_candidate) - if end_candidate is not None: - ends.append(end_candidate) - - g = str(p.get("time_granularity", "") or "").strip().lower() - if g in granularity_priority and granularity_priority[g] > granularity_rank: - granularity_rank = granularity_priority[g] - granularity = g - - conf_values.append(self._clamp_score(p.get("time_confidence"), default=1.0)) - - time_start = min(starts) if starts else None - time_end = max(ends) if ends else None - time_conf = sum(conf_values) / len(conf_values) if conf_values else 1.0 - return time_start, time_end, granularity, self._clamp_score(time_conf, default=1.0) - - def _collect_participants(self, paragraph_hashes: List[str], limit: int = 16) -> List[str]: - seen = set() - participants: List[str] = [] - for p_hash in paragraph_hashes: - try: - entities = self.metadata_store.get_paragraph_entities(p_hash) - except Exception: - entities = [] - for item in entities: - name = str(item.get("name", "") or "").strip() - if not name: - continue - key = name.lower() - if key in seen: - continue - seen.add(key) - participants.append(name) - if len(participants) >= limit: - return participants - return participants - - @staticmethod - def _derive_keywords(paragraphs: List[Dict[str, Any]], limit: int = 12) -> List[str]: - token_counter: Counter[str] = Counter() - token_pattern = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}") - stop_words = { - "the", - "and", - "that", - "this", - "with", - "from", - "for", - "have", - "will", - "your", - "you", - "我们", - "你们", - "他们", - "以及", - "一个", - "这个", - "那个", - "然后", - "因为", - "所以", - } - for p in paragraphs: - text = str(p.get("content", "") or "").lower() - for token in token_pattern.findall(text): - if token in stop_words: - continue - token_counter[token] += 1 - - return [token for token, _ in token_counter.most_common(limit)] - - def _build_fallback_episode(self, group: Dict[str, Any]) -> Dict[str, Any]: - paragraphs = group.get("paragraphs", []) or [] - source = str(group.get("source", "") or "").strip() - hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] - snippets = [] - for p in paragraphs[:3]: - text = str(p.get("content", "") or "").strip().replace("\n", " ") - if text: - snippets.append(text[:140]) - summary = ";".join(snippets)[:500] if snippets else "自动回退生成的情景记忆。" - - time_start, time_end, granularity, time_conf = self._compute_time_meta(paragraphs) - participants = self._collect_participants(hashes, limit=12) - keywords = self._derive_keywords(paragraphs, limit=10) - - if time_start is not None: - day_text = datetime.fromtimestamp(time_start).strftime("%Y-%m-%d") - title = f"{source or 'unknown'} {day_text} 情景片段" - else: - title = f"{source or 'unknown'} 情景片段" - - return { - "title": title[:80], - "summary": summary, - "paragraph_hashes": hashes, - "participants": participants, - "keywords": keywords, - "time_confidence": time_conf, - "llm_confidence": 0.0, - "event_time_start": time_start, - "event_time_end": time_end, - "time_granularity": granularity, - "segmentation_model": "fallback_rule", - 
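The group_paragraphs step earlier in this file batches pending paragraphs per source and flushes a batch whenever the paragraph cap, the character cap, or the time-adjacency window is exceeded; each flushed batch then goes through LLM segmentation or, on failure, the rule-based fallback assembled here. A condensed sketch of that flush rule, using the 20-paragraph / 6000-character / 24-hour defaults referenced via the episode.* config keys (the standalone helper name group_by_window is illustrative):

    from typing import Any, Dict, List

    def group_by_window(paragraphs: List[Dict[str, Any]], max_items: int = 20,
                        max_chars: int = 6000,
                        window_seconds: float = 24 * 3600.0) -> List[List[Dict[str, Any]]]:
        """Batch time-sorted paragraphs; open a new batch when any limit would be exceeded."""
        def anchor(p: Dict[str, Any]) -> float:
            # Same precedence as the patch: event end, event start, event time, created_at.
            for key in ("event_time_end", "event_time_start", "event_time", "created_at"):
                if p.get(key) is not None:
                    return float(p[key])
            return 0.0

        groups: List[List[Dict[str, Any]]] = []
        current: List[Dict[str, Any]] = []
        chars, last = 0, None
        for p in sorted(paragraphs, key=anchor):
            length = len(str(p.get("content", "") or ""))
            if current and (len(current) >= max_items or chars + length > max_chars
                            or (last is not None and abs(anchor(p) - last) > window_seconds)):
                groups.append(current)
                current, chars = [], 0
            current.append(p)
            chars += length
            last = anchor(p)
        if current:
            groups.append(current)
        return groups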
"segmentation_version": EpisodeSegmentationService.SEGMENTATION_VERSION, - } - - @staticmethod - def _normalize_episode_hashes(episode_hashes: List[str], group_hashes_ordered: List[str]) -> List[str]: - in_group = set(group_hashes_ordered) - dedup: List[str] = [] - seen = set() - for h in episode_hashes or []: - token = str(h or "").strip() - if not token or token not in in_group or token in seen: - continue - seen.add(token) - dedup.append(token) - return dedup - - async def _build_episode_payloads_for_group(self, group: Dict[str, Any]) -> Dict[str, Any]: - paragraphs = group.get("paragraphs", []) or [] - if not paragraphs: - return { - "payloads": [], - "done_hashes": [], - "episode_count": 0, - "fallback_count": 0, - } - - source = str(group.get("source", "") or "").strip() - group_hashes = [str(p.get("hash", "") or "").strip() for p in paragraphs if str(p.get("hash", "") or "").strip()] - group_start, group_end, _, _ = self._compute_time_meta(paragraphs) - - fallback_used = False - segmentation_model = "fallback_rule" - segmentation_version = EpisodeSegmentationService.SEGMENTATION_VERSION - - try: - llm_result = await self.segmentation_service.segment( - source=source, - window_start=group_start, - window_end=group_end, - paragraphs=paragraphs, - ) - episodes = list(llm_result.get("episodes") or []) - segmentation_model = str(llm_result.get("segmentation_model", "") or "").strip() or "auto" - segmentation_version = str(llm_result.get("segmentation_version", "") or "").strip() or EpisodeSegmentationService.SEGMENTATION_VERSION - if not episodes: - raise ValueError("llm_empty_episodes") - except Exception as e: - logger.warning( - "Episode segmentation fallback: " - f"source={source} " - f"size={len(group_hashes)} " - f"err={e}" - ) - episodes = [self._build_fallback_episode(group)] - fallback_used = True - - stored_payloads: List[Dict[str, Any]] = [] - for episode in episodes: - ordered_hashes = self._normalize_episode_hashes( - episode_hashes=episode.get("paragraph_hashes", []), - group_hashes_ordered=group_hashes, - ) - if not ordered_hashes: - continue - - sub_paragraphs = [p for p in paragraphs if str(p.get("hash", "") or "") in set(ordered_hashes)] - event_start, event_end, granularity, time_conf_default = self._compute_time_meta(sub_paragraphs) - - participants = [str(x).strip() for x in (episode.get("participants", []) or []) if str(x).strip()] - keywords = [str(x).strip() for x in (episode.get("keywords", []) or []) if str(x).strip()] - if not participants: - participants = self._collect_participants(ordered_hashes, limit=16) - if not keywords: - keywords = self._derive_keywords(sub_paragraphs, limit=12) - - title = str(episode.get("title", "") or "").strip()[:120] - summary = str(episode.get("summary", "") or "").strip()[:2000] - if not title or not summary: - continue - - seed = json.dumps( - { - "source": source, - "hashes": ordered_hashes, - "version": segmentation_version, - }, - ensure_ascii=False, - sort_keys=True, - ) - episode_id = compute_hash(seed) - - payload = { - "episode_id": episode_id, - "source": source or None, - "title": title, - "summary": summary, - "event_time_start": episode.get("event_time_start", event_start), - "event_time_end": episode.get("event_time_end", event_end), - "time_granularity": episode.get("time_granularity", granularity), - "time_confidence": self._clamp_score( - episode.get("time_confidence"), - default=time_conf_default, - ), - "participants": participants[:16], - "keywords": keywords[:20], - "evidence_ids": ordered_hashes, - 
"paragraph_count": len(ordered_hashes), - "llm_confidence": self._clamp_score( - episode.get("llm_confidence"), - default=0.0 if fallback_used else 0.6, - ), - "segmentation_model": ( - str(episode.get("segmentation_model", "") or "").strip() - or ("fallback_rule" if fallback_used else segmentation_model) - ), - "segmentation_version": ( - str(episode.get("segmentation_version", "") or "").strip() - or segmentation_version - ), - } - stored_payloads.append(payload) - - return { - "payloads": stored_payloads, - "done_hashes": group_hashes, - "episode_count": len(stored_payloads), - "fallback_count": 1 if fallback_used else 0, - } - - async def process_group(self, group: Dict[str, Any]) -> Dict[str, Any]: - result = await self._build_episode_payloads_for_group(group) - stored_count = 0 - for payload in result.get("payloads") or []: - stored = self.metadata_store.upsert_episode(payload) - final_id = str(stored.get("episode_id") or payload.get("episode_id") or "") - if final_id: - self.metadata_store.bind_episode_paragraphs( - final_id, - list(payload.get("evidence_ids") or []), - ) - stored_count += 1 - - result["episode_count"] = stored_count - return { - "done_hashes": list(result.get("done_hashes") or []), - "episode_count": stored_count, - "fallback_count": int(result.get("fallback_count") or 0), - } - - async def process_pending_rows(self, pending_rows: List[Dict[str, Any]]) -> Dict[str, Any]: - loaded, missing_hashes = self.load_pending_paragraphs(pending_rows) - groups = self.group_paragraphs(loaded) - - done_hashes: List[str] = list(missing_hashes) - failed_hashes: Dict[str, str] = {} - episode_count = 0 - fallback_count = 0 - - for group in groups: - group_hashes = [str(p.get("hash", "") or "").strip() for p in (group.get("paragraphs") or [])] - try: - result = await self.process_group(group) - done_hashes.extend(result.get("done_hashes") or []) - episode_count += int(result.get("episode_count") or 0) - fallback_count += int(result.get("fallback_count") or 0) - except Exception as e: - err = str(e)[:500] - for h in group_hashes: - if h: - failed_hashes[h] = err - - dedup_done = list(dict.fromkeys([h for h in done_hashes if h])) - return { - "done_hashes": dedup_done, - "failed_hashes": failed_hashes, - "episode_count": episode_count, - "fallback_count": fallback_count, - "missing_count": len(missing_hashes), - "group_count": len(groups), - } - - async def rebuild_source(self, source: str) -> Dict[str, Any]: - token = str(source or "").strip() - if not token: - return { - "source": "", - "episode_count": 0, - "fallback_count": 0, - "group_count": 0, - "paragraph_count": 0, - } - - paragraphs = self.metadata_store.get_live_paragraphs_by_source(token) - if not paragraphs: - replace_result = self.metadata_store.replace_episodes_for_source(token, []) - return { - "source": token, - "episode_count": int(replace_result.get("episode_count") or 0), - "fallback_count": 0, - "group_count": 0, - "paragraph_count": 0, - } - - groups = self.group_paragraphs(paragraphs) - payloads: List[Dict[str, Any]] = [] - fallback_count = 0 - - for group in groups: - result = await self._build_episode_payloads_for_group(group) - payloads.extend(list(result.get("payloads") or [])) - fallback_count += int(result.get("fallback_count") or 0) - - replace_result = self.metadata_store.replace_episodes_for_source(token, payloads) - return { - "source": token, - "episode_count": int(replace_result.get("episode_count") or 0), - "fallback_count": fallback_count, - "group_count": len(groups), - "paragraph_count": 
len(paragraphs), - } diff --git a/plugins/A_memorix/core/utils/hash.py b/plugins/A_memorix/core/utils/hash.py deleted file mode 100644 index b6363257..00000000 --- a/plugins/A_memorix/core/utils/hash.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -哈希工具模块 - -提供文本哈希计算功能,用于唯一标识和去重。 -""" - -import hashlib -import re -from typing import Union - - -def compute_hash(text: str, hash_type: str = "sha256") -> str: - """ - 计算文本的哈希值 - - Args: - text: 输入文本 - hash_type: 哈希算法类型(sha256, md5等) - - Returns: - 哈希值字符串 - """ - if hash_type == "sha256": - return hashlib.sha256(text.encode("utf-8")).hexdigest() - elif hash_type == "md5": - return hashlib.md5(text.encode("utf-8")).hexdigest() - else: - raise ValueError(f"不支持的哈希算法: {hash_type}") - - -def normalize_text(text: str) -> str: - """ - 规范化文本用于哈希计算 - - 执行以下操作: - - 去除首尾空白 - - 统一换行符为\\n - - 压缩多个连续空格 - - 去除不可见字符(保留换行和制表符) - - Args: - text: 输入文本 - - Returns: - 规范化后的文本 - """ - # 去除首尾空白 - text = text.strip() - - # 统一换行符 - text = text.replace("\r\n", "\n").replace("\r", "\n") - - # 压缩多个连续空格为一个(但保留换行和制表符) - text = re.sub(r"[^\S\n]+", " ", text) - - return text - - -def compute_paragraph_hash(paragraph: str) -> str: - """ - 计算段落的哈希值 - - Args: - paragraph: 段落文本 - - Returns: - 段落哈希值(用于paragraph-前缀) - """ - normalized = normalize_text(paragraph) - return compute_hash(normalized) - - -def compute_entity_hash(entity: str) -> str: - """ - 计算实体的哈希值 - - Args: - entity: 实体名称 - - Returns: - 实体哈希值(用于entity-前缀) - """ - normalized = entity.strip().lower() - return compute_hash(normalized) - - -def compute_relation_hash(relation: tuple) -> str: - """ - 计算关系的哈希值 - - Args: - relation: 关系元组 (subject, predicate, object) - - Returns: - 关系哈希值(用于relation-前缀) - """ - # 将关系元组转为字符串 - relation_str = str(tuple(relation)) - return compute_hash(relation_str) - - -def format_hash_key(hash_type: str, hash_value: str) -> str: - """ - 格式化哈希键 - - Args: - hash_type: 类型前缀(paragraph, entity, relation) - hash_value: 哈希值 - - Returns: - 格式化的键(如 paragraph-abc123...) - """ - return f"{hash_type}-{hash_value}" - - -def parse_hash_key(key: str) -> tuple[str, str]: - """ - 解析哈希键 - - Args: - key: 格式化的键(如 paragraph-abc123...) 
- - Returns: - (类型, 哈希值) 元组 - """ - parts = key.split("-", 1) - if len(parts) != 2: - raise ValueError(f"无效的哈希键格式: {key}") - return parts[0], parts[1] diff --git a/plugins/A_memorix/core/utils/import_payloads.py b/plugins/A_memorix/core/utils/import_payloads.py deleted file mode 100644 index 6986a4c1..00000000 --- a/plugins/A_memorix/core/utils/import_payloads.py +++ /dev/null @@ -1,110 +0,0 @@ -"""Shared import payload normalization helpers.""" - -from __future__ import annotations - -from typing import Any, Dict, List, Optional - -from ..storage import KnowledgeType, resolve_stored_knowledge_type -from .time_parser import normalize_time_meta - - -def _normalize_entities(raw_entities: Any) -> List[str]: - if not isinstance(raw_entities, list): - return [] - out: List[str] = [] - seen = set() - for item in raw_entities: - name = str(item or "").strip() - if not name: - continue - key = name.lower() - if key in seen: - continue - seen.add(key) - out.append(name) - return out - - -def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]: - if not isinstance(raw_relations, list): - return [] - out: List[Dict[str, str]] = [] - for item in raw_relations: - if not isinstance(item, dict): - continue - subject = str(item.get("subject", "")).strip() - predicate = str(item.get("predicate", "")).strip() - obj = str(item.get("object", "")).strip() - if not (subject and predicate and obj): - continue - out.append( - { - "subject": subject, - "predicate": predicate, - "object": obj, - } - ) - return out - - -def normalize_paragraph_import_item( - item: Any, - *, - default_source: str, -) -> Dict[str, Any]: - """Normalize one paragraph import item from text/json payloads.""" - - if isinstance(item, str): - content = str(item) - knowledge_type = resolve_stored_knowledge_type(None, content=content) - return { - "content": content, - "knowledge_type": knowledge_type.value, - "source": str(default_source or "").strip(), - "time_meta": None, - "entities": [], - "relations": [], - } - - if not isinstance(item, dict) or "content" not in item: - raise ValueError("段落项必须为字符串或包含 content 的对象") - - content = str(item.get("content", "") or "") - if not content.strip(): - raise ValueError("段落 content 不能为空") - - raw_time_meta = { - "event_time": item.get("event_time"), - "event_time_start": item.get("event_time_start"), - "event_time_end": item.get("event_time_end"), - "time_range": item.get("time_range"), - "time_granularity": item.get("time_granularity"), - "time_confidence": item.get("time_confidence"), - } - time_meta_field = item.get("time_meta") - if isinstance(time_meta_field, dict): - raw_time_meta.update(time_meta_field) - - knowledge_type_raw = item.get("knowledge_type") - if knowledge_type_raw is None: - knowledge_type_raw = item.get("type") - knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content) - source = str(item.get("source") or default_source or "").strip() - if not source: - source = str(default_source or "").strip() - - normalized_time_meta = normalize_time_meta(raw_time_meta) - return { - "content": content, - "knowledge_type": knowledge_type.value, - "source": source, - "time_meta": normalized_time_meta if normalized_time_meta else None, - "entities": _normalize_entities(item.get("entities")), - "relations": _normalize_relations(item.get("relations")), - } - - -def normalize_summary_knowledge_type(value: Any) -> KnowledgeType: - """Normalize config-driven summary knowledge type.""" - - return resolve_stored_knowledge_type(value, content="") diff --git 
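A minimal usage sketch of the hash helpers in this module; the import path mirrors the module location and is an assumption about the package layout, and the sample text is illustrative only.

from plugins.A_memorix.core.utils.hash import (
    compute_paragraph_hash,
    format_hash_key,
    parse_hash_key,
)

raw = "Alice  met Bob.\r\nThey talked about   the plan."
# normalize_text() collapses repeated spaces and unifies newlines before hashing,
# so trivially different copies of the same paragraph map to one key.
key = format_hash_key("paragraph", compute_paragraph_hash(raw))
kind, digest = parse_hash_key(key)
assert kind == "paragraph" and len(digest) == 64  # sha256 hex digest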
a/plugins/A_memorix/core/utils/io.py b/plugins/A_memorix/core/utils/io.py deleted file mode 100644 index ed14df43..00000000 --- a/plugins/A_memorix/core/utils/io.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -IO Utilities - -提供原子文件写入等IO辅助功能。 -""" - -import os -import shutil -import contextlib -from pathlib import Path -from typing import Union - -@contextlib.contextmanager -def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs): - """ - 原子文件写入上下文管理器 - - 原理: - 1. 写入 .tmp 临时文件 - 2. 写入成功后,使用 os.replace 原子替换目标文件 - 3. 如果失败,自动删除临时文件 - - Args: - file_path: 目标文件路径 - mode: 打开模式 ('w', 'wb' 等) - encoding: 编码 - **kwargs: 传给 open() 的其他参数 - """ - path = Path(file_path) - # 确保父目录存在 - path.parent.mkdir(parents=True, exist_ok=True) - - # 临时文件路径 - tmp_path = path.with_suffix(path.suffix + ".tmp") - - try: - with open(tmp_path, mode, encoding=encoding, **kwargs) as f: - yield f - - # 确保写入磁盘 - f.flush() - os.fsync(f.fileno()) - - # 原子替换 (Windows下可能需要先删除目标文件,但 os.replace 在 Py3.3+ 尽可能原子) - # 注意: Windows 上如果有其他进程占用文件,os.replace 可能会失败 - os.replace(tmp_path, path) - - except Exception as e: - # 清理临时文件 - if tmp_path.exists(): - try: - os.remove(tmp_path) - except: - pass - raise e - -@contextlib.contextmanager -def atomic_save_path(file_path: Union[str, Path]): - """ - 提供临时路径用于原子保存 (针对只接受路径的API,如Faiss) - - Args: - file_path: 最终目标路径 - - Yields: - tmp_path: 临时文件路径 (str) - """ - path = Path(file_path) - path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(path.suffix + ".tmp") - - try: - yield str(tmp_path) - - if Path(tmp_path).exists(): - os.replace(tmp_path, path) - - except Exception as e: - if Path(tmp_path).exists(): - try: - os.remove(tmp_path) - except: - pass - raise e diff --git a/plugins/A_memorix/core/utils/matcher.py b/plugins/A_memorix/core/utils/matcher.py deleted file mode 100644 index bddff5ee..00000000 --- a/plugins/A_memorix/core/utils/matcher.py +++ /dev/null @@ -1,89 +0,0 @@ -""" -高效文本匹配工具模块 - -实现 Aho-Corasick 算法用于多模式匹配。 -""" - -from typing import List, Dict, Tuple, Set, Any -from collections import deque - - -class AhoCorasick: - """ - Aho-Corasick 自动机实现高效多模式匹配 - """ - - def __init__(self): - # next_states[state][char] = next_state - self.next_states: List[Dict[str, int]] = [{}] - # fail[state] = fail_state - self.fail: List[int] = [0] - # output[state] = set of patterns ending at this state - self.output: List[Set[str]] = [set()] - self.patterns: Set[str] = set() - - def add_pattern(self, pattern: str): - """添加模式""" - if not pattern: - return - self.patterns.add(pattern) - state = 0 - for char in pattern: - if char not in self.next_states[state]: - new_state = len(self.next_states) - self.next_states[state][char] = new_state - self.next_states.append({}) - self.fail.append(0) - self.output.append(set()) - state = self.next_states[state][char] - self.output[state].add(pattern) - - def build(self): - """构建失败指针""" - queue = deque() - # 处理第一层 - for char, state in self.next_states[0].items(): - queue.append(state) - self.fail[state] = 0 - - while queue: - r = queue.popleft() - for char, s in self.next_states[r].items(): - queue.append(s) - # 找到失败路径 - state = self.fail[r] - while char not in self.next_states[state] and state != 0: - state = self.fail[state] - self.fail[s] = self.next_states[state].get(char, 0) - # 合并输出 - self.output[s].update(self.output[self.fail[s]]) - - def search(self, text: str) -> List[Tuple[int, str]]: - """ - 在文本中搜索所有模式 - - Returns: - [(结束索引, 匹配到的模式), ...] 
- """ - state = 0 - results = [] - for i, char in enumerate(text): - while char not in self.next_states[state] and state != 0: - state = self.fail[state] - state = self.next_states[state].get(char, 0) - for pattern in self.output[state]: - results.append((i, pattern)) - return results - - def find_all(self, text: str) -> Dict[str, int]: - """ - 查找并统计所有模式出现次数 - - Returns: - {模式: 出现次数} - """ - results = self.search(text) - stats = {} - for _, pattern in results: - stats[pattern] = stats.get(pattern, 0) + 1 - return stats diff --git a/plugins/A_memorix/core/utils/monitor.py b/plugins/A_memorix/core/utils/monitor.py deleted file mode 100644 index 39c794ab..00000000 --- a/plugins/A_memorix/core/utils/monitor.py +++ /dev/null @@ -1,189 +0,0 @@ -""" -内存监控模块 - -提供内存使用监控和预警功能。 -""" - -import gc -import threading -import time -from typing import Callable, Optional - -try: - import psutil - HAS_PSUTIL = True -except ImportError: - HAS_PSUTIL = False - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.MemoryMonitor") - - -class MemoryMonitor: - """ - 内存监控器 - - 功能: - - 实时监控内存使用 - - 超过阈值时触发警告 - - 支持自动垃圾回收 - """ - - def __init__( - self, - max_memory_mb: int, - warning_threshold: float = 0.9, - check_interval: float = 10.0, - enable_auto_gc: bool = True, - ): - """ - 初始化内存监控器 - - Args: - max_memory_mb: 最大内存限制(MB) - warning_threshold: 警告阈值(0-1之间,默认0.9表示90%) - check_interval: 检查间隔(秒) - enable_auto_gc: 是否启用自动垃圾回收 - """ - self.max_memory_mb = max_memory_mb - self.warning_threshold = warning_threshold - self.check_interval = check_interval - self.enable_auto_gc = enable_auto_gc - - self._running = False - self._thread: Optional[threading.Thread] = None - self._callbacks: list[Callable[[float, float], None]] = [] - - def start(self): - """启动监控""" - if self._running: - logger.warning("内存监控已在运行") - return - - if not HAS_PSUTIL: - logger.warning("psutil 未安装,内存监控功能不可用") - return - - self._running = True - self._thread = threading.Thread(target=self._monitor_loop, daemon=True) - self._thread.start() - logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)") - - def stop(self): - """停止监控""" - if not self._running: - return - - self._running = False - if self._thread: - self._thread.join(timeout=5.0) - logger.info("内存监控已停止") - - def register_callback(self, callback: Callable[[float, float], None]): - """ - 注册内存超限回调函数 - - Args: - callback: 回调函数,接收 (当前使用MB, 限制MB) 参数 - """ - self._callbacks.append(callback) - - def get_current_memory_mb(self) -> float: - """ - 获取当前进程内存使用量(MB) - - Returns: - 内存使用量(MB) - """ - if not HAS_PSUTIL: - # 降级方案:使用内置函数 - import sys - return sys.getsizeof(gc.get_objects()) / 1024 / 1024 - - process = psutil.Process() - return process.memory_info().rss / 1024 / 1024 - - def get_memory_usage_ratio(self) -> float: - """ - 获取内存使用率 - - Returns: - 使用率(0-1之间) - """ - current = self.get_current_memory_mb() - return current / self.max_memory_mb if self.max_memory_mb > 0 else 0 - - def _monitor_loop(self): - """监控循环""" - while self._running: - try: - current_mb = self.get_current_memory_mb() - ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0 - - # 检查是否超过阈值 - if ratio >= self.warning_threshold: - logger.warning( - f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB " - f"({ratio*100:.1f}%)" - ) - - # 触发回调 - for callback in self._callbacks: - try: - callback(current_mb, self.max_memory_mb) - except Exception as e: - logger.error(f"内存回调执行失败: {e}") - - # 自动垃圾回收 - if self.enable_auto_gc: - before = self.get_current_memory_mb() - gc.collect() - after = 
self.get_current_memory_mb() - freed = before - after - if freed > 1: # 释放超过1MB才记录 - logger.info(f"垃圾回收释放: {freed:.1f}MB") - - # 定期垃圾回收(即使未超限) - elif self.enable_auto_gc and int(time.time()) % 60 == 0: - gc.collect() - - except Exception as e: - logger.error(f"内存监控出错: {e}") - - # 等待下次检查 - time.sleep(self.check_interval) - - def __enter__(self): - """上下文管理器入口""" - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """上下文管理器出口""" - self.stop() - - -def get_memory_info() -> dict: - """ - 获取系统内存信息 - - Returns: - 内存信息字典 - """ - if not HAS_PSUTIL: - return {"error": "psutil 未安装"} - - try: - mem = psutil.virtual_memory() - process = psutil.Process() - - return { - "system_total_gb": mem.total / 1024 / 1024 / 1024, - "system_available_gb": mem.available / 1024 / 1024 / 1024, - "system_usage_percent": mem.percent, - "process_mb": process.memory_info().rss / 1024 / 1024, - "process_percent": (process.memory_info().rss / mem.total) * 100, - } - except Exception as e: - return {"error": str(e)} diff --git a/plugins/A_memorix/core/utils/path_fallback_service.py b/plugins/A_memorix/core/utils/path_fallback_service.py deleted file mode 100644 index 7a802743..00000000 --- a/plugins/A_memorix/core/utils/path_fallback_service.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Shared path-fallback helpers for search post-processing.""" - -from __future__ import annotations - -import hashlib -from typing import Any, Dict, List, Optional, Sequence, Tuple - -from ..retrieval.dual_path import RetrievalResult - - -def extract_entities(query: str, graph_store: Any) -> List[str]: - """Extract up to two graph nodes from a query using n-gram matching.""" - if not graph_store: - return [] - - text = str(query or "").strip() - if not text: - return [] - - # Keep the heuristic aligned with previous legacy behavior. 
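A usage sketch for the MemoryMonitor above, assuming psutil is installed; the limit, callback and workload shown here are placeholders.

monitor = MemoryMonitor(max_memory_mb=512, warning_threshold=0.9, check_interval=10.0)
monitor.register_callback(
    lambda used_mb, limit_mb: print(f"memory high: {used_mb:.0f}/{limit_mb}MB")
)
with monitor:  # start() on enter, stop() on exit
    run_ingestion_batch()  # placeholder for the actual workload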
- tokens = ( - text.replace("?", " ") - .replace("!", " ") - .replace(".", " ") - .split() - ) - if not tokens: - return [] - - found_entities = set() - skip_indices = set() - max_n = min(4, len(tokens)) - - for size in range(max_n, 0, -1): - for i in range(len(tokens) - size + 1): - if any(idx in skip_indices for idx in range(i, i + size)): - continue - span = " ".join(tokens[i : i + size]) - matched_node = graph_store.find_node(span, ignore_case=True) - if not matched_node: - continue - found_entities.add(matched_node) - for idx in range(i, i + size): - skip_indices.add(idx) - - return list(found_entities) - - -def find_paths_between_entities( - start_node: str, - end_node: str, - graph_store: Any, - metadata_store: Any, - *, - max_depth: int = 3, - max_paths: int = 5, -) -> List[Dict[str, Any]]: - """Find and enrich indirect paths between two nodes.""" - if not graph_store or not metadata_store: - return [] - - try: - paths = graph_store.find_paths( - start_node, - end_node, - max_depth=max_depth, - max_paths=max_paths, - ) - except Exception: - return [] - - if not paths: - return [] - - edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {} - formatted_paths: List[Dict[str, Any]] = [] - - for path_nodes in paths: - if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2: - continue - - path_desc: List[str] = [] - for i in range(len(path_nodes) - 1): - u = str(path_nodes[i]) - v = str(path_nodes[i + 1]) - - cache_key = tuple(sorted((u, v))) - if cache_key in edge_cache: - pred, direction = edge_cache[cache_key] - else: - pred = "related" - direction = "->" - rels = metadata_store.get_relations(subject=u, object=v) - if not rels: - rels = metadata_store.get_relations(subject=v, object=u) - direction = "<-" - if rels: - best_rel = max(rels, key=lambda x: x.get("confidence", 1.0)) - pred = str(best_rel.get("predicate", "related") or "related") - edge_cache[cache_key] = (pred, direction) - - step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-" - path_desc.append(step_str) - - full_path_str = str(path_nodes[0]) - for i, step in enumerate(path_desc): - full_path_str += f" {step} {path_nodes[i + 1]}" - - formatted_paths.append( - { - "nodes": list(path_nodes), - "description": full_path_str, - } - ) - - return formatted_paths - - -def find_paths_from_query( - query: str, - graph_store: Any, - metadata_store: Any, - *, - max_depth: int = 3, - max_paths: int = 5, -) -> List[Dict[str, Any]]: - """Extract entities from query and resolve indirect paths.""" - entities = extract_entities(query, graph_store) - if len(entities) != 2: - return [] - return find_paths_between_entities( - entities[0], - entities[1], - graph_store, - metadata_store, - max_depth=max_depth, - max_paths=max_paths, - ) - - -def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]: - """Convert path results into retrieval results for the unified pipeline.""" - converted: List[RetrievalResult] = [] - for item in paths: - description = str(item.get("description", "")).strip() - if not description: - continue - hash_seed = description.encode("utf-8") - path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}" - converted.append( - RetrievalResult( - hash_value=path_hash, - content=f"[Indirect Relation] {description}", - score=0.95, - result_type="relation", - source="graph_path", - metadata={ - "source": "graph_path", - "is_indirect": True, - "nodes": list(item.get("nodes", [])), - }, - ) - ) - return converted - diff --git a/plugins/A_memorix/core/utils/person_profile_service.py 
b/plugins/A_memorix/core/utils/person_profile_service.py deleted file mode 100644 index 6460c013..00000000 --- a/plugins/A_memorix/core/utils/person_profile_service.py +++ /dev/null @@ -1,554 +0,0 @@ -""" -人物画像服务 - -主链路: -person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储 -""" - -import json -import time -from typing import Any, Dict, List, Optional, Tuple - -from sqlalchemy import or_ -from sqlmodel import select - -from src.common.logger import get_logger -from src.common.database.database import get_db_session -from src.common.database.database_model import PersonInfo - -from ..embedding import EmbeddingAPIAdapter -from ..retrieval import ( - DualPathRetriever, - RetrievalStrategy, - DualPathRetrieverConfig, - SparseBM25Config, - FusionConfig, - GraphRelationRecallConfig, -) -from ..storage import MetadataStore, GraphStore, VectorStore - -logger = get_logger("A_Memorix.PersonProfileService") - - -class PersonProfileService: - """人物画像聚合/刷新服务。""" - - def __init__( - self, - metadata_store: MetadataStore, - graph_store: Optional[GraphStore] = None, - vector_store: Optional[VectorStore] = None, - embedding_manager: Optional[EmbeddingAPIAdapter] = None, - sparse_index: Any = None, - plugin_config: Optional[dict] = None, - retriever: Optional[DualPathRetriever] = None, - ): - self.metadata_store = metadata_store - self.graph_store = graph_store - self.vector_store = vector_store - self.embedding_manager = embedding_manager - self.sparse_index = sparse_index - self.plugin_config = plugin_config or {} - self.retriever = retriever or self._build_retriever() - - def _cfg(self, key: str, default: Any = None) -> Any: - """读取嵌套配置。""" - if not isinstance(self.plugin_config, dict): - return default - current: Any = self.plugin_config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - def _build_retriever(self) -> Optional[DualPathRetriever]: - """按需构建检索器(无依赖时返回 None)。""" - if not all( - [ - self.vector_store is not None, - self.graph_store is not None, - self.metadata_store is not None, - self.embedding_manager is not None, - ] - ): - return None - try: - sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} - fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {} - graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {} - if not isinstance(sparse_cfg_raw, dict): - sparse_cfg_raw = {} - if not isinstance(fusion_cfg_raw, dict): - fusion_cfg_raw = {} - if not isinstance(graph_recall_cfg_raw, dict): - graph_recall_cfg_raw = {} - - sparse_cfg = SparseBM25Config(**sparse_cfg_raw) - fusion_cfg = FusionConfig(**fusion_cfg_raw) - graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) - config = DualPathRetrieverConfig( - top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)), - top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)), - top_k_final=int(self._cfg("retrieval.top_k_final", 10)), - alpha=float(self._cfg("retrieval.alpha", 0.5)), - enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), - ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), - ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), - enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), - retrieval_strategy=RetrievalStrategy.DUAL_PATH, - debug=bool(self._cfg("advanced.debug", False)), - sparse=sparse_cfg, - fusion=fusion_cfg, - graph_recall=graph_recall_cfg, - ) - return DualPathRetriever( - vector_store=self.vector_store, - 
graph_store=self.graph_store, - metadata_store=self.metadata_store, - embedding_manager=self.embedding_manager, - sparse_index=self.sparse_index, - config=config, - ) - except Exception as e: - logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}") - return None - - @staticmethod - def resolve_person_id(identifier: str) -> str: - """按 person_id 或姓名/别名解析 person_id。""" - if not identifier: - return "" - key = str(identifier).strip() - if not key: - return "" - - try: - with get_db_session(auto_commit=False) as session: - record = session.exec( - select(PersonInfo.person_id).where(PersonInfo.person_id == key).limit(1) - ).first() - if record: - return str(record) - - record = session.exec( - select(PersonInfo.person_id) - .where( - or_( - PersonInfo.person_name == key, - PersonInfo.user_nickname == key, - ) - ) - .limit(1) - ).first() - if record: - return str(record) - - record = session.exec( - select(PersonInfo.person_id) - .where(PersonInfo.group_cardname.contains(key)) - .limit(1) - ).first() - if record: - return str(record) - except Exception as e: - logger.warning(f"按别名解析 person_id 失败: identifier={key}, err={e}") - - if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key): - return key.lower() - - return "" - - def _parse_group_nicks(self, raw_value: Any) -> List[str]: - if not raw_value: - return [] - if isinstance(raw_value, list): - items = raw_value - else: - try: - items = json.loads(raw_value) - except Exception: - return [] - names: List[str] = [] - for item in items: - if isinstance(item, dict): - value = str(item.get("group_cardname") or item.get("group_nick_name") or "").strip() - if value: - names.append(value) - elif isinstance(item, str): - value = item.strip() - if value: - names.append(value) - return names - - def _parse_memory_traits(self, raw_value: Any) -> List[str]: - if not raw_value: - return [] - try: - values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value - except Exception: - return [] - if not isinstance(values, list): - return [] - traits: List[str] = [] - for item in values: - text = str(item).strip() - if not text: - continue - if ":" in text: - parts = text.split(":") - if len(parts) >= 3: - content = ":".join(parts[1:-1]).strip() - if content: - traits.append(content) - continue - traits.append(text) - return traits[:10] - - def _recover_aliases_from_memory(self, person_id: str) -> Tuple[List[str], str]: - """当人物主档案缺失时,从已有记忆证据里回捞可用别名。""" - if not person_id: - return [], "" - - aliases: List[str] = [] - primary_name = "" - seen = set() - - try: - paragraphs = self.metadata_store.get_paragraphs_by_entity(person_id) - except Exception as e: - logger.warning(f"从记忆证据回捞人物别名失败: person_id={person_id}, err={e}") - return [], "" - - for paragraph in paragraphs[:20]: - paragraph_hash = str(paragraph.get("hash", "") or "").strip() - if not paragraph_hash: - continue - try: - paragraph_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) - except Exception: - paragraph_entities = [] - for entity in paragraph_entities: - name = str(entity.get("name", "") or "").strip() - if not name or name == person_id: - continue - key = name.lower() - if key in seen: - continue - seen.add(key) - aliases.append(name) - if not primary_name: - primary_name = name - return aliases, primary_name - - def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]: - """获取人物别名集合、主展示名、记忆特征。""" - aliases: List[str] = [] - primary_name = "" - memory_traits: List[str] = [] - if not person_id: - return aliases, primary_name, 
memory_traits - recovered_aliases, recovered_primary_name = self._recover_aliases_from_memory(person_id) - try: - with get_db_session(auto_commit=False) as session: - record = session.exec( - select(PersonInfo).where(PersonInfo.person_id == person_id).limit(1) - ).first() - if not record: - return recovered_aliases, recovered_primary_name or person_id, memory_traits - person_name = str(getattr(record, "person_name", "") or "").strip() - nickname = str(getattr(record, "user_nickname", "") or "").strip() - group_nicks = self._parse_group_nicks(getattr(record, "group_cardname", None)) - memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None)) - - primary_name = ( - person_name - or nickname - or recovered_primary_name - or str(getattr(record, "user_id", "") or "").strip() - or person_id - ) - - candidates = [person_name, nickname] + group_nicks + recovered_aliases - seen = set() - for item in candidates: - norm = str(item or "").strip() - if not norm or norm in seen: - continue - seen.add(norm) - aliases.append(norm) - except Exception as e: - logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}") - return aliases, primary_name, memory_traits - - def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]: - relation_by_hash: Dict[str, Dict[str, Any]] = {} - for alias in aliases: - for rel in self.metadata_store.get_relations(subject=alias): - h = str(rel.get("hash", "")) - if h: - relation_by_hash[h] = rel - for rel in self.metadata_store.get_relations(object=alias): - h = str(rel.get("hash", "")) - if h: - relation_by_hash[h] = rel - - relations = list(relation_by_hash.values()) - relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True) - relations = relations[: max(1, int(limit))] - - edges: List[Dict[str, Any]] = [] - for rel in relations: - edges.append( - { - "hash": str(rel.get("hash", "")), - "subject": str(rel.get("subject", "")), - "predicate": str(rel.get("predicate", "")), - "object": str(rel.get("object", "")), - "confidence": float(rel.get("confidence", 1.0) or 1.0), - } - ) - return edges - - async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]: - alias_queries = [a for a in aliases if a] - if not alias_queries: - return [] - - if self.retriever is None: - # 回退:无检索器时只做简单内容匹配 - fallback: List[Dict[str, Any]] = [] - seen_hash = set() - for alias in alias_queries: - for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]: - h = str(para.get("hash", "")) - if not h or h in seen_hash: - continue - seen_hash.add(h) - fallback.append( - { - "hash": h, - "type": "paragraph", - "score": 0.0, - "content": str(para.get("content", ""))[:180], - "metadata": {}, - } - ) - return fallback[:top_k] - - per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries)))) - seen_hash = set() - evidence: List[Dict[str, Any]] = [] - for alias in alias_queries: - try: - results = await self.retriever.retrieve(alias, top_k=per_alias_top_k) - except Exception as e: - logger.warning(f"向量证据召回失败: alias={alias}, err={e}") - continue - for item in results: - h = str(getattr(item, "hash_value", "") or "") - if not h or h in seen_hash: - continue - seen_hash.add(h) - evidence.append( - { - "hash": h, - "type": str(getattr(item, "result_type", "")), - "score": float(getattr(item, "score", 0.0) or 0.0), - "content": str(getattr(item, "content", "") or "")[:220], - "metadata": dict(getattr(item, "metadata", {}) or {}), - } - ) - 
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True) - return evidence[:top_k] - - def _build_profile_text( - self, - person_id: str, - primary_name: str, - aliases: List[str], - relation_edges: List[Dict[str, Any]], - vector_evidence: List[Dict[str, Any]], - memory_traits: List[str], - ) -> str: - """基于证据构建画像文本(供 LLM 上下文注入)。""" - lines: List[str] = [] - lines.append(f"人物ID: {person_id}") - if primary_name: - lines.append(f"主称呼: {primary_name}") - if aliases: - lines.append(f"别名: {', '.join(aliases[:8])}") - if memory_traits: - lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}") - - if relation_edges: - lines.append("关系证据:") - for rel in relation_edges[:6]: - s = rel.get("subject", "") - p = rel.get("predicate", "") - o = rel.get("object", "") - conf = float(rel.get("confidence", 0.0)) - lines.append(f"- {s} {p} {o} (conf={conf:.2f})") - - if vector_evidence: - lines.append("向量证据摘要:") - for item in vector_evidence[:4]: - content = str(item.get("content", "")).strip() - if content: - lines.append(f"- {content}") - - if len(lines) <= 2: - lines.append("暂无足够证据形成稳定画像。") - - return "\n".join(lines) - - @staticmethod - def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool: - if not snapshot: - return True - now = time.time() - expires_at = snapshot.get("expires_at") - if expires_at is not None: - try: - return now >= float(expires_at) - except Exception: - return True - updated_at = float(snapshot.get("updated_at") or 0.0) - return (now - updated_at) >= ttl_seconds - - def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]: - """将手工覆盖并入画像结果(覆盖 profile_text,同时保留 auto_profile_text)。""" - payload = dict(profile_payload or {}) - auto_text = str(payload.get("profile_text", "") or "") - payload["auto_profile_text"] = auto_text - payload["has_manual_override"] = False - payload["manual_override_text"] = "" - payload["override_updated_at"] = None - payload["override_updated_by"] = "" - payload["profile_source"] = "auto_snapshot" - - if not person_id or self.metadata_store is None: - return payload - - try: - override = self.metadata_store.get_person_profile_override(person_id) - except Exception as e: - logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}") - return payload - - if not override: - return payload - - manual_text = str(override.get("override_text", "") or "").strip() - if not manual_text: - return payload - - payload["has_manual_override"] = True - payload["manual_override_text"] = manual_text - payload["override_updated_at"] = override.get("updated_at") - payload["override_updated_by"] = str(override.get("updated_by", "") or "") - payload["profile_text"] = manual_text - payload["profile_source"] = "manual_override" - return payload - - async def query_person_profile( - self, - person_id: str = "", - person_keyword: str = "", - top_k: int = 12, - ttl_seconds: float = 6 * 3600, - force_refresh: bool = False, - source_note: str = "", - ) -> Dict[str, Any]: - """查询或刷新人物画像。""" - pid = str(person_id or "").strip() - if not pid and person_keyword: - pid = self.resolve_person_id(person_keyword) - - if not pid: - return { - "success": False, - "error": "person_id 无效,且未能通过别名解析", - } - - latest = self.metadata_store.get_latest_person_profile_snapshot(pid) - if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds): - aliases, primary_name, _ = self.get_person_aliases(pid) - payload = { - "success": True, - "person_id": pid, - "person_name": primary_name, - "from_cache": True, - 
**(latest or {}), - } - if aliases and not payload.get("aliases"): - payload["aliases"] = aliases - return { - **self._apply_manual_override(pid, payload), - } - - aliases, primary_name, memory_traits = self.get_person_aliases(pid) - if not aliases and person_keyword: - aliases = [person_keyword.strip()] - primary_name = person_keyword.strip() - relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2)) - vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k)) - - evidence_ids = [ - str(item.get("hash", "")) - for item in (relation_edges + vector_evidence) - if str(item.get("hash", "")).strip() - ] - dedup_ids: List[str] = [] - seen = set() - for item in evidence_ids: - if item in seen: - continue - seen.add(item) - dedup_ids.append(item) - - profile_text = self._build_profile_text( - person_id=pid, - primary_name=primary_name, - aliases=aliases, - relation_edges=relation_edges, - vector_evidence=vector_evidence, - memory_traits=memory_traits, - ) - - expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None - snapshot = self.metadata_store.upsert_person_profile_snapshot( - person_id=pid, - profile_text=profile_text, - aliases=aliases, - relation_edges=relation_edges, - vector_evidence=vector_evidence, - evidence_ids=dedup_ids, - expires_at=expires_at, - source_note=source_note, - ) - payload = { - "success": True, - "person_id": pid, - "person_name": primary_name, - "from_cache": False, - **snapshot, - } - return { - **self._apply_manual_override(pid, payload), - } - - @staticmethod - def format_persona_profile_block(profile: Dict[str, Any]) -> str: - """格式化给 replyer 的注入块。""" - if not profile or not profile.get("success"): - return "" - text = str(profile.get("profile_text", "") or "").strip() - if not text: - return "" - return ( - "【人物画像-内部参考】\n" - f"{text}\n" - "仅供内部推理,不要向用户逐字复述。" - ) diff --git a/plugins/A_memorix/core/utils/plugin_id_policy.py b/plugins/A_memorix/core/utils/plugin_id_policy.py deleted file mode 100644 index 8e730e12..00000000 --- a/plugins/A_memorix/core/utils/plugin_id_policy.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Plugin ID matching policy for A_Memorix.""" - -from __future__ import annotations - -from typing import Any - - -class PluginIdPolicy: - """Centralized plugin id normalization/matching policy.""" - - CANONICAL_ID = "a_memorix" - - @classmethod - def normalize(cls, plugin_id: Any) -> str: - if not isinstance(plugin_id, str): - return "" - return plugin_id.strip().lower() - - @classmethod - def is_target_plugin_id(cls, plugin_id: Any) -> bool: - normalized = cls.normalize(plugin_id) - if not normalized: - return False - if normalized == cls.CANONICAL_ID: - return True - return normalized.split(".")[-1] == cls.CANONICAL_ID - diff --git a/plugins/A_memorix/core/utils/quantization.py b/plugins/A_memorix/core/utils/quantization.py deleted file mode 100644 index 4e84f977..00000000 --- a/plugins/A_memorix/core/utils/quantization.py +++ /dev/null @@ -1,344 +0,0 @@ -""" -向量量化工具模块 - -提供向量量化与反量化功能,用于压缩存储空间。 -""" - -import numpy as np -from enum import Enum -from typing import Tuple, Union - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.Quantization") - - -class QuantizationType(Enum): - """量化类型枚举""" - FLOAT32 = "float32" # 无量化 - INT8 = "int8" # 标量量化(8位整数) - PQ = "pq" # 乘积量化(Product Quantization) - - -def quantize_vector( - vector: np.ndarray, - quant_type: QuantizationType = QuantizationType.INT8, -) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - 量化向量 - - 
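The matching policy above accepts the canonical id as well as dotted ids whose last segment equals it; a quick sketch (ids are examples):

assert PluginIdPolicy.is_target_plugin_id("A_Memorix")
assert PluginIdPolicy.is_target_plugin_id("plugins.a_memorix")
assert not PluginIdPolicy.is_target_plugin_id("other_plugin")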
Args: - vector: 输入向量(float32) - quant_type: 量化类型 - - Returns: - 量化后的向量: - - INT8: int8向量 - - PQ: (编码向量, 聚类中心) 元组 - """ - if quant_type == QuantizationType.FLOAT32: - return vector.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - return _scalar_quantize_int8(vector) - - elif quant_type == QuantizationType.PQ: - return _product_quantize(vector) - - else: - raise ValueError(f"不支持的量化类型: {quant_type}") - - -def dequantize_vector( - quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], - quant_type: QuantizationType = QuantizationType.INT8, - original_shape: Tuple[int, ...] = None, -) -> np.ndarray: - """ - 反量化向量 - - Args: - quantized_vector: 量化后的向量 - quant_type: 量化类型 - original_shape: 原始向量形状(用于PQ) - - Returns: - 反量化后的向量(float32) - """ - if quant_type == QuantizationType.FLOAT32: - return quantized_vector.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - return _scalar_dequantize_int8(quantized_vector) - - elif quant_type == QuantizationType.PQ: - if not isinstance(quantized_vector, tuple): - raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)") - return _product_dequantize(quantized_vector[0], quantized_vector[1]) - - else: - raise ValueError(f"不支持的量化类型: {quant_type}") - - -def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray: - """ - 标量量化:float32 -> int8 - - 将向量归一化到 [0, 255] 范围,然后映射到 int8 - - Args: - vector: 输入向量 - - Returns: - 量化后的 int8 向量 - """ - # 计算最小最大值 - min_val = np.min(vector) - max_val = np.max(vector) - - # 避免除零 - if max_val == min_val: - return np.zeros_like(vector, dtype=np.int8) - - # 归一化到 [0, 255] - normalized = (vector - min_val) / (max_val - min_val) * 255 - - # 映射到 [-128, 127] 并转换为 int8 - # np.round might return float, minus 128 then cast - quantized = np.round(normalized - 128.0).astype(np.int8) - - # 存储归一化参数(用于反量化) - # 在实际存储中,这些参数需要单独保存 - # 这里为了简单,我们使用一个全局字典来模拟 - if not hasattr(_scalar_quantize_int8, "_params"): - _scalar_quantize_int8._params = {} - - vector_id = id(vector) - _scalar_quantize_int8._params[vector_id] = (min_val, max_val) - - return quantized - - -def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray: - """ - 标量反量化:int8 -> float32 - - Args: - quantized: 量化后的 int8 向量 - - Returns: - 反量化后的 float32 向量 - """ - # 计算归一化参数(如果提供了) - # 在实际应用中,min_val 和 max_val 应该被保存 - if not hasattr(_scalar_dequantize_int8, "_params"): - # 默认假设范围是 [-1, 1] - return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0 - - # 尝试查找参数 (这里只是演示逻辑,实际应从存储中读取) - # return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min - return (quantized.astype(np.float32) + 128.0) / 255.0 - - -def quantize_matrix( - matrix: np.ndarray, - quant_type: QuantizationType = QuantizationType.INT8, -) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: - """ - 量化矩阵(批量量化向量) - - Args: - matrix: 输入矩阵(N x D,每行是一个向量) - quant_type: 量化类型 - - Returns: - 量化后的矩阵 - """ - if quant_type == QuantizationType.FLOAT32: - return matrix.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - # 对整个矩阵进行全局归一化 - min_val = np.min(matrix) - max_val = np.max(matrix) - - if max_val == min_val: - return np.zeros_like(matrix, dtype=np.int8) - - # 归一化到 [0, 255] - normalized = (matrix - min_val) / (max_val - min_val) * 255 - quantized = np.round(normalized).astype(np.int8) - - return quantized - - else: - raise ValueError(f"不支持的量化类型: {quant_type}") - - -def dequantize_matrix( - quantized_matrix: np.ndarray, - quant_type: QuantizationType = QuantizationType.INT8, - min_val: float = None, - max_val: float = None, -) -> np.ndarray: - """ - 反量化矩阵 - - 
Args: - quantized_matrix: 量化后的矩阵 - quant_type: 量化类型 - min_val: 归一化最小值(int8反量化需要) - max_val: 归一化最大值(int8反量化需要) - - Returns: - 反量化后的矩阵 - """ - if quant_type == QuantizationType.FLOAT32: - return quantized_matrix.astype(np.float32) - - elif quant_type == QuantizationType.INT8: - # 使用提供的归一化参数反量化 - if min_val is None or max_val is None: - # 默认假设范围是 [0, 255] -> [-1, 1] - return quantized_matrix.astype(np.float32) / 127.0 - else: - # 恢复到原始范围 - normalized = quantized_matrix.astype(np.float32) / 255.0 - return normalized * (max_val - min_val) + min_val - - else: - raise ValueError(f"不支持的量化类型: {quant_type}") - - -def estimate_memory_reduction( - num_vectors: int, - dimension: int, - from_type: QuantizationType, - to_type: QuantizationType, -) -> Tuple[float, float]: - """ - 估算内存节省量 - - Args: - num_vectors: 向量数量 - dimension: 向量维度 - from_type: 原始量化类型 - to_type: 目标量化类型 - - Returns: - (原始大小MB, 量化后大小MB, 节省比例) - """ - # 计算每个向量占用的字节数 - bytes_per_element = { - QuantizationType.FLOAT32: 4, - QuantizationType.INT8: 1, - QuantizationType.PQ: 0.25, # 假设压缩到1/4 - } - - original_bytes = num_vectors * dimension * bytes_per_element[from_type] - quantized_bytes = num_vectors * dimension * bytes_per_element[to_type] - - original_mb = original_bytes / 1024 / 1024 - quantized_mb = quantized_bytes / 1024 / 1024 - reduction_ratio = (original_bytes - quantized_bytes) / original_bytes - - return original_mb, quantized_mb, reduction_ratio - - -def estimate_compression_stats( - num_vectors: int, - dimension: int, - quant_type: QuantizationType, -) -> dict: - """ - 估算压缩统计信息 - - Args: - num_vectors: 向量数量 - dimension: 向量维度 - quant_type: 量化类型 - - Returns: - 统计信息字典 - """ - original_mb, quantized_mb, ratio = estimate_memory_reduction( - num_vectors, dimension, QuantizationType.FLOAT32, quant_type - ) - - return { - "num_vectors": num_vectors, - "dimension": dimension, - "quantization_type": quant_type.value, - "original_size_mb": round(original_mb, 2), - "quantized_size_mb": round(quantized_mb, 2), - "saved_mb": round(original_mb - quantized_mb, 2), - "compression_ratio": round(ratio * 100, 2), - } - - -def _product_quantize( - vector: np.ndarray, m: int = 8, k: int = 256 -) -> Tuple[np.ndarray, np.ndarray]: - """ - 乘积量化 (PQ) 简化实现 - - Args: - vector: 输入向量 (D,) - m: 子空间数量 - k: 每个子空间的聚类中心数 - - Returns: - (编码后的向量, 聚类中心) - """ - d = vector.shape[0] - if d % m != 0: - raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除") - - ds = d // m # 子空间维度 - codes = np.zeros(m, dtype=np.uint8) - centroids = np.zeros((m, k, ds), dtype=np.float32) - - # 这里采用一种简化的 PQ:不进行 K-means 训练, - # 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练) - # 为了演示,我们直接将子空间切分为 k 份进行量化 - for i in range(m): - sub_vec = vector[i * ds : (i + 1) * ds] - # 简化:假定每个子空间的取值范围并划分 - # 实际 PQ 应使用 k-means 产生的 centroids - # 这里为演示创建一个随机 codebook 并找到最接近的核心 - sub_min, sub_max = np.min(sub_vec), np.max(sub_vec) - if sub_max == sub_min: - linspace = np.zeros(k) - else: - linspace = np.linspace(sub_min, sub_max, k) - - for j in range(k): - centroids[i, j, :] = linspace[j] - - # 编码:这里简化为取子空间均值找最接近的 centroid - sub_mean = np.mean(sub_vec) - code = np.argmin(np.abs(linspace - sub_mean)) - codes[i] = code - - return codes, centroids - - -def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray: - """ - PQ 反量化 - - Args: - codes: 编码向量 (M,) - centroids: 聚类中心 (M, K, DS) - - Returns: - 恢复后的向量 (D,) - """ - m, k, ds = centroids.shape - vector = np.zeros(m * ds, dtype=np.float32) - - for i in range(m): - code = codes[i] - vector[i * ds : (i + 1) * ds] = centroids[i, code, :] - - return vector diff --git 
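A small sketch of the size estimator above; the vector count and dimension are arbitrary example values.

stats = estimate_compression_stats(
    num_vectors=100_000,
    dimension=768,
    quant_type=QuantizationType.INT8,
)
# float32 -> int8 keeps roughly 1/4 of the footprint:
# {"original_size_mb": 292.97, "quantized_size_mb": 73.24,
#  "saved_mb": 219.73, "compression_ratio": 75.0, ...}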
a/plugins/A_memorix/core/utils/relation_query.py b/plugins/A_memorix/core/utils/relation_query.py deleted file mode 100644 index ffde9cac..00000000 --- a/plugins/A_memorix/core/utils/relation_query.py +++ /dev/null @@ -1,121 +0,0 @@ -"""关系查询规格解析工具。""" - -from __future__ import annotations - -import re -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class RelationQuerySpec: - raw: str - is_structured: bool - subject: Optional[str] - predicate: Optional[str] - object: Optional[str] - error: Optional[str] = None - - -_NATURAL_LANGUAGE_PATTERN = re.compile( - r"(^\s*(what|who|which|how|why|when|where)\b|" - r"\?|?|" - r"\b(relation|related|between)\b|" - r"(什么关系|有哪些关系|之间|关联))", - re.IGNORECASE, -) - - -def _looks_like_natural_language(raw: str) -> bool: - text = str(raw or "").strip() - if not text: - return False - return _NATURAL_LANGUAGE_PATTERN.search(text) is not None - - -def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec: - raw = str(relation_spec or "").strip() - if not raw: - return RelationQuerySpec( - raw=raw, - is_structured=False, - subject=None, - predicate=None, - object=None, - error="empty", - ) - - if "|" in raw: - parts = [p.strip() for p in raw.split("|")] - if len(parts) < 2: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=None, - predicate=None, - object=None, - error="invalid_pipe_format", - ) - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0] or None, - predicate=parts[1] or None, - object=parts[2] if len(parts) > 2 and parts[2] else None, - ) - - if "->" in raw: - parts = [p.strip() for p in raw.split("->") if p.strip()] - if len(parts) >= 3: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0], - predicate=parts[1], - object=parts[2], - ) - if len(parts) == 2: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0], - predicate=None, - object=parts[1], - ) - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=None, - predicate=None, - object=None, - error="invalid_arrow_format", - ) - - if _looks_like_natural_language(raw): - return RelationQuerySpec( - raw=raw, - is_structured=False, - subject=None, - predicate=None, - object=None, - ) - - # 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。 - # 两词形式过于模糊,不再视为结构化关系查询。 - parts = raw.split() - if len(parts) == 3: - return RelationQuerySpec( - raw=raw, - is_structured=True, - subject=parts[0], - predicate=parts[1], - object=parts[2], - ) - - return RelationQuerySpec( - raw=raw, - is_structured=False, - subject=None, - predicate=None, - object=None, - ) diff --git a/plugins/A_memorix/core/utils/relation_write_service.py b/plugins/A_memorix/core/utils/relation_write_service.py deleted file mode 100644 index 6fa2e621..00000000 --- a/plugins/A_memorix/core/utils/relation_write_service.py +++ /dev/null @@ -1,166 +0,0 @@ -""" -统一关系写入与关系向量化服务。 - -规则: -1. 元数据是主数据源,向量是从索引。 -2. 关系先写 metadata,再写向量。 -3. 
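The parser above treats pipe and arrow forms as structured relation queries and routes natural-language questions to plain retrieval; a few illustrative calls (names are examples):

spec = parse_relation_query_spec("Alice|likes|Apple")
assert spec.is_structured and spec.predicate == "likes"

spec = parse_relation_query_spec("Alice -> Apple")
assert spec.is_structured and spec.object == "Apple" and spec.predicate is None

spec = parse_relation_query_spec("Alice 和 Apple 之间是什么关系?")
assert not spec.is_structured  # handled as a natural-language query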
向量失败不回滚 metadata,依赖状态机与回填任务修复。 -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from src.common.logger import get_logger - - -logger = get_logger("A_Memorix.RelationWriteService") - - -@dataclass -class RelationWriteResult: - hash_value: str - vector_written: bool - vector_already_exists: bool - vector_state: str - - -class RelationWriteService: - """关系写入收口服务。""" - - ERROR_MAX_LEN = 500 - - def __init__( - self, - metadata_store: Any, - graph_store: Any, - vector_store: Any, - embedding_manager: Any, - ): - self.metadata_store = metadata_store - self.graph_store = graph_store - self.vector_store = vector_store - self.embedding_manager = embedding_manager - - @staticmethod - def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str: - s = str(subject or "").strip() - p = str(predicate or "").strip() - o = str(obj or "").strip() - # 双表达:兼容关键词检索与自然语言问句 - return f"{s} {p} {o}\n{s}和{o}的关系是{p}" - - async def ensure_relation_vector( - self, - hash_value: str, - subject: str, - predicate: str, - obj: str, - *, - max_error_len: int = ERROR_MAX_LEN, - ) -> RelationWriteResult: - """ - 为已有关系确保向量存在并更新状态。 - """ - if hash_value in self.vector_store: - self.metadata_store.set_relation_vector_state(hash_value, "ready") - return RelationWriteResult( - hash_value=hash_value, - vector_written=False, - vector_already_exists=True, - vector_state="ready", - ) - - self.metadata_store.set_relation_vector_state(hash_value, "pending") - try: - vector_text = self.build_relation_vector_text(subject, predicate, obj) - embedding = await self.embedding_manager.encode(vector_text) - self.vector_store.add( - vectors=embedding.reshape(1, -1), - ids=[hash_value], - ) - self.metadata_store.set_relation_vector_state(hash_value, "ready") - logger.info( - "metric.relation_vector_write_success=1 " - "metric.relation_vector_write_success_count=1 " - f"hash={hash_value[:16]}" - ) - return RelationWriteResult( - hash_value=hash_value, - vector_written=True, - vector_already_exists=False, - vector_state="ready", - ) - except ValueError: - # 向量已存在冲突,按成功处理 - self.metadata_store.set_relation_vector_state(hash_value, "ready") - return RelationWriteResult( - hash_value=hash_value, - vector_written=False, - vector_already_exists=True, - vector_state="ready", - ) - except Exception as e: - err = str(e)[:max_error_len] - self.metadata_store.set_relation_vector_state( - hash_value, - "failed", - error=err, - bump_retry=True, - ) - logger.warning( - "metric.relation_vector_write_fail=1 " - "metric.relation_vector_write_fail_count=1 " - f"hash={hash_value[:16]} " - f"err={err}" - ) - return RelationWriteResult( - hash_value=hash_value, - vector_written=False, - vector_already_exists=False, - vector_state="failed", - ) - - async def upsert_relation_with_vector( - self, - subject: str, - predicate: str, - obj: str, - confidence: float = 1.0, - source_paragraph: str = "", - metadata: Optional[Dict[str, Any]] = None, - *, - write_vector: bool = True, - ) -> RelationWriteResult: - """ - 统一关系写入: - 1) 写 metadata relation - 2) 写 graph edge relation_hash - 3) 按需写 relation vector - """ - rel_hash = self.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - confidence=confidence, - source_paragraph=source_paragraph, - metadata=metadata or {}, - ) - self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) - - if not write_vector: - self.metadata_store.set_relation_vector_state(rel_hash, "none") - return 
RelationWriteResult( - hash_value=rel_hash, - vector_written=False, - vector_already_exists=False, - vector_state="none", - ) - - return await self.ensure_relation_vector( - hash_value=rel_hash, - subject=subject, - predicate=predicate, - obj=obj, - ) diff --git a/plugins/A_memorix/core/utils/retrieval_tuning_manager.py b/plugins/A_memorix/core/utils/retrieval_tuning_manager.py deleted file mode 100644 index e0e8ecd6..00000000 --- a/plugins/A_memorix/core/utils/retrieval_tuning_manager.py +++ /dev/null @@ -1,1857 +0,0 @@ -""" -Retrieval tuning manager for WebUI. -""" - -from __future__ import annotations - -import asyncio -import copy -import json -import random -import re -import time -import uuid -from collections import Counter, deque -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -from ..runtime.search_runtime_initializer import build_search_runtime -from .search_execution_service import SearchExecutionRequest, SearchExecutionService - -try: - from src.services import llm_service as llm_api -except Exception: # pragma: no cover - llm_api = None - -logger = get_logger("A_Memorix.RetrievalTuningManager") - - -OBJECTIVES = {"precision_priority", "balanced", "recall_priority"} -INTENSITIES = {"quick": 8, "standard": 20, "deep": 32} -CATEGORIES = {"query_nl", "query_kw", "spo_relation", "spo_search"} -_RUNTIME_CONFIG_INSTANCE_KEYS = { - "vector_store", - "graph_store", - "metadata_store", - "embedding_manager", - "sparse_index", - "relation_write_service", - "plugin_instance", -} - - -def _now() -> float: - return time.time() - - -def _clamp_int(value: Any, default: int, min_value: int, max_value: int) -> int: - try: - parsed = int(value) - except Exception: - parsed = int(default) - return max(min_value, min(max_value, parsed)) - - -def _clamp_float(value: Any, default: float, min_value: float, max_value: float) -> float: - try: - parsed = float(value) - except Exception: - parsed = float(default) - return max(min_value, min(max_value, parsed)) - - -def _coerce_bool(value: Any, default: bool) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - text = str(value).strip().lower() - if text in {"1", "true", "yes", "y", "on"}: - return True - if text in {"0", "false", "no", "n", "off"}: - return False - return default - - -def _nested_get(data: Dict[str, Any], key: str, default: Any = None) -> Any: - cur: Any = data - for part in key.split("."): - if isinstance(cur, dict) and part in cur: - cur = cur[part] - else: - return default - return cur - - -def _nested_set(data: Dict[str, Any], key: str, value: Any) -> None: - parts = key.split(".") - cur = data - for part in parts[:-1]: - if part not in cur or not isinstance(cur[part], dict): - cur[part] = {} - cur = cur[part] - cur[parts[-1]] = value - - -def _deep_merge(base: Dict[str, Any], patch: Dict[str, Any]) -> Dict[str, Any]: - out = copy.deepcopy(base) - for key, value in (patch or {}).items(): - if isinstance(value, dict) and isinstance(out.get(key), dict): - out[key] = _deep_merge(out[key], value) - else: - out[key] = copy.deepcopy(value) - return out - - -def _safe_json_loads(text: str) -> Optional[Any]: - raw = str(text or "").strip() - if not raw: - return None - if "```" in raw: - raw = raw.replace("```json", "```") - for seg in raw.split("```"): - seg = seg.strip() - if seg.startswith("{") or 
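The dual-expression text used for relation embeddings (keyword triple plus a natural-language restatement) can be previewed directly via the static helper above; the names here are examples:

text = RelationWriteService.build_relation_vector_text("Alice", "likes", "Apple")
# "Alice likes Apple\nAlice和Apple的关系是likes"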
seg.startswith("["): - raw = seg - break - try: - return json.loads(raw) - except Exception: - pass - s = raw.find("{") - e = raw.rfind("}") - if s >= 0 and e > s: - try: - return json.loads(raw[s : e + 1]) - except Exception: - return None - return None - - -@dataclass -class RetrievalQueryCase: - case_id: str - category: str - query: str - expected_hashes: List[str] = field(default_factory=list) - expected_spo: Dict[str, str] = field(default_factory=dict) - metadata: Dict[str, Any] = field(default_factory=dict) - - def to_dict(self) -> Dict[str, Any]: - return { - "case_id": self.case_id, - "category": self.category, - "query": self.query, - "expected_hashes": list(self.expected_hashes), - "expected_spo": dict(self.expected_spo), - "metadata": dict(self.metadata), - } - - -@dataclass -class RetrievalTuningRoundRecord: - round_index: int - candidate_profile: Dict[str, Any] - metrics: Dict[str, Any] - score: float - latency_ms: float - failure_summary: Dict[str, Any] = field(default_factory=dict) - created_at: float = field(default_factory=_now) - - def to_dict(self) -> Dict[str, Any]: - return { - "round_index": self.round_index, - "candidate_profile": copy.deepcopy(self.candidate_profile), - "metrics": copy.deepcopy(self.metrics), - "score": float(self.score), - "latency_ms": float(self.latency_ms), - "failure_summary": copy.deepcopy(self.failure_summary), - "created_at": float(self.created_at), - } - - -@dataclass -class RetrievalTuningTaskRecord: - task_id: str - status: str - progress: float - objective: str - intensity: str - rounds_total: int - rounds_done: int = 0 - best_profile: Dict[str, Any] = field(default_factory=dict) - best_metrics: Dict[str, Any] = field(default_factory=dict) - best_score: float = -1.0 - baseline_profile: Dict[str, Any] = field(default_factory=dict) - baseline_metrics: Dict[str, Any] = field(default_factory=dict) - error: str = "" - params: Dict[str, Any] = field(default_factory=dict) - query_set_stats: Dict[str, Any] = field(default_factory=dict) - artifact_paths: Dict[str, str] = field(default_factory=dict) - rounds: List[RetrievalTuningRoundRecord] = field(default_factory=list) - cancel_requested: bool = False - created_at: float = field(default_factory=_now) - started_at: Optional[float] = None - finished_at: Optional[float] = None - updated_at: float = field(default_factory=_now) - apply_log: List[Dict[str, Any]] = field(default_factory=list) - - def to_summary(self) -> Dict[str, Any]: - return { - "task_id": self.task_id, - "status": self.status, - "progress": self.progress, - "objective": self.objective, - "intensity": self.intensity, - "rounds_total": self.rounds_total, - "rounds_done": self.rounds_done, - "best_score": self.best_score, - "error": self.error, - "query_set_stats": dict(self.query_set_stats), - "artifact_paths": dict(self.artifact_paths), - "created_at": self.created_at, - "started_at": self.started_at, - "finished_at": self.finished_at, - "updated_at": self.updated_at, - } - - def to_detail(self, include_rounds: bool = False) -> Dict[str, Any]: - payload = self.to_summary() - payload.update( - { - "params": copy.deepcopy(self.params), - "best_profile": copy.deepcopy(self.best_profile), - "best_metrics": copy.deepcopy(self.best_metrics), - "baseline_profile": copy.deepcopy(self.baseline_profile), - "baseline_metrics": copy.deepcopy(self.baseline_metrics), - "apply_log": copy.deepcopy(self.apply_log), - } - ) - if include_rounds: - payload["rounds"] = [x.to_dict() for x in self.rounds] - return payload - - -class 
RetrievalTuningManager: - def __init__( - self, - plugin: Any, - *, - import_write_blocked_provider: Optional[Callable[[], bool]] = None, - ): - self.plugin = plugin - self._import_write_blocked_provider = import_write_blocked_provider - - self._lock = asyncio.Lock() - self._tasks: Dict[str, RetrievalTuningTaskRecord] = {} - self._task_order: deque[str] = deque() - self._queue: deque[str] = deque() - self._active_task_id: Optional[str] = None - self._worker_task: Optional[asyncio.Task] = None - self._stopping = False - - self._rollback_snapshot: Optional[Dict[str, Any]] = None - - self._artifacts_root = Path(__file__).resolve().parents[2] / "artifacts" / "retrieval_tuning" - self._artifacts_root.mkdir(parents=True, exist_ok=True) - - def _cfg(self, key: str, default: Any = None) -> Any: - getter = getattr(self.plugin, "get_config", None) - if callable(getter): - return getter(key, default) - return default - - def _is_enabled(self) -> bool: - return bool(self._cfg("web.tuning.enabled", True)) - - def _queue_limit(self) -> int: - return _clamp_int(self._cfg("web.tuning.max_queue_size", 8), 8, 1, 100) - - def _poll_interval_s(self) -> float: - ms = _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000) - return max(0.2, ms / 1000.0) - - def _llm_retry_cfg(self) -> Dict[str, Any]: - return { - "max_attempts": _clamp_int(self._cfg("web.tuning.llm_retry.max_attempts", 3), 3, 1, 10), - "min_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.min_wait_seconds", 2), 2.0, 0.1, 60.0), - "max_wait_seconds": _clamp_float(self._cfg("web.tuning.llm_retry.max_wait_seconds", 20), 20.0, 0.2, 120.0), - "backoff_multiplier": _clamp_float(self._cfg("web.tuning.llm_retry.backoff_multiplier", 2), 2.0, 1.0, 10.0), - } - - def _eval_query_timeout_s(self) -> float: - return _clamp_float( - self._cfg("web.tuning.eval_query_timeout_seconds", 10.0), - 10.0, - 0.01, - 120.0, - ) - - def get_runtime_settings(self) -> Dict[str, Any]: - intensity = str(self._cfg("web.tuning.default_intensity", "standard") or "standard") - if intensity not in INTENSITIES: - intensity = "standard" - objective = str(self._cfg("web.tuning.default_objective", "precision_priority") or "precision_priority") - if objective not in OBJECTIVES: - objective = "precision_priority" - return { - "enabled": self._is_enabled(), - "poll_interval_ms": _clamp_int(self._cfg("web.tuning.poll_interval_ms", 1200), 1200, 200, 60000), - "max_queue_size": self._queue_limit(), - "default_objective": objective, - "default_intensity": intensity, - "default_rounds": INTENSITIES[intensity], - "default_top_k_eval": _clamp_int(self._cfg("web.tuning.default_top_k_eval", 20), 20, 5, 100), - "default_sample_size": _clamp_int(self._cfg("web.tuning.default_sample_size", 24), 24, 4, 200), - "eval_query_timeout_seconds": self._eval_query_timeout_s(), - "llm_retry": self._llm_retry_cfg(), - } - - def _ensure_ready(self) -> None: - required = ("metadata_store", "vector_store", "graph_store", "embedding_manager") - missing = [x for x in required if getattr(self.plugin, x, None) is None] - if missing: - raise ValueError(f"调优依赖未初始化: {', '.join(missing)}") - checker = getattr(self.plugin, "is_runtime_ready", None) - if callable(checker) and not checker(): - raise ValueError("插件运行时未就绪") - provider = self._import_write_blocked_provider - if provider is not None and bool(provider()): - raise ValueError("导入任务运行中,当前禁止启动检索调优") - - def get_profile_snapshot(self) -> Dict[str, Any]: - cfg = getattr(self.plugin, "config", {}) or {} - profile = { - "retrieval": { - 
"top_k_paragraphs": _nested_get(cfg, "retrieval.top_k_paragraphs", 20), - "top_k_relations": _nested_get(cfg, "retrieval.top_k_relations", 10), - "top_k_final": _nested_get(cfg, "retrieval.top_k_final", 10), - "alpha": _nested_get(cfg, "retrieval.alpha", 0.5), - "enable_ppr": _nested_get(cfg, "retrieval.enable_ppr", True), - "search": {"smart_fallback": {"enabled": _nested_get(cfg, "retrieval.search.smart_fallback.enabled", True)}}, - "sparse": { - "enabled": _nested_get(cfg, "retrieval.sparse.enabled", True), - "mode": _nested_get(cfg, "retrieval.sparse.mode", "auto"), - "candidate_k": _nested_get(cfg, "retrieval.sparse.candidate_k", 80), - "relation_candidate_k": _nested_get(cfg, "retrieval.sparse.relation_candidate_k", 60), - }, - "fusion": { - "method": _nested_get(cfg, "retrieval.fusion.method", "weighted_rrf"), - "rrf_k": _nested_get(cfg, "retrieval.fusion.rrf_k", 60), - "vector_weight": _nested_get(cfg, "retrieval.fusion.vector_weight", 0.7), - "bm25_weight": _nested_get(cfg, "retrieval.fusion.bm25_weight", 0.3), - }, - }, - "threshold": { - "percentile": _nested_get(cfg, "threshold.percentile", 75.0), - "min_results": _nested_get(cfg, "threshold.min_results", 3), - }, - } - return self._normalize_profile(profile, fallback=profile) - - def _normalize_profile(self, profile: Optional[Dict[str, Any]], *, fallback: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - raw = copy.deepcopy(profile or {}) - base = copy.deepcopy(fallback or self.get_profile_snapshot()) - - def pick(path: str, default: Any) -> Any: - if _nested_get(raw, path, None) is not None: - return _nested_get(raw, path, default) - if path in raw: - return raw.get(path, default) - return _nested_get(base, path, default) - - fusion_method = str(pick("retrieval.fusion.method", "weighted_rrf") or "weighted_rrf").strip().lower() - if fusion_method not in {"weighted_rrf", "alpha_legacy"}: - fusion_method = "weighted_rrf" - - sparse_mode = str(pick("retrieval.sparse.mode", "auto") or "auto").strip().lower() - if sparse_mode not in {"auto", "hybrid", "fallback_only"}: - sparse_mode = "auto" - - vec_w = _clamp_float(pick("retrieval.fusion.vector_weight", 0.7), 0.7, 0.0, 1.0) - bm_w = _clamp_float(pick("retrieval.fusion.bm25_weight", 0.3), 0.3, 0.0, 1.0) - s = vec_w + bm_w - if s <= 1e-9: - vec_w, bm_w = 0.7, 0.3 - else: - vec_w, bm_w = vec_w / s, bm_w / s - - return { - "retrieval": { - "top_k_paragraphs": _clamp_int(pick("retrieval.top_k_paragraphs", 20), 20, 10, 1200), - "top_k_relations": _clamp_int(pick("retrieval.top_k_relations", 10), 10, 4, 512), - "top_k_final": _clamp_int(pick("retrieval.top_k_final", 10), 10, 4, 512), - "alpha": _clamp_float(pick("retrieval.alpha", 0.5), 0.5, 0.0, 1.0), - "enable_ppr": _coerce_bool(pick("retrieval.enable_ppr", True), True), - "search": {"smart_fallback": {"enabled": _coerce_bool(pick("retrieval.search.smart_fallback.enabled", True), True)}}, - "sparse": { - "enabled": _coerce_bool(pick("retrieval.sparse.enabled", True), True), - "mode": sparse_mode, - "candidate_k": _clamp_int(pick("retrieval.sparse.candidate_k", 80), 80, 20, 2000), - "relation_candidate_k": _clamp_int(pick("retrieval.sparse.relation_candidate_k", 60), 60, 20, 2000), - }, - "fusion": { - "method": fusion_method, - "rrf_k": _clamp_int(pick("retrieval.fusion.rrf_k", 60), 60, 1, 500), - "vector_weight": float(vec_w), - "bm25_weight": float(bm_w), - }, - }, - "threshold": { - "percentile": _clamp_float(pick("threshold.percentile", 75.0), 75.0, 1.0, 99.0), - "min_results": _clamp_int(pick("threshold.min_results", 3), 3, 1, 
100), - }, - } - - def _apply_profile_to_runtime(self, normalized: Dict[str, Any]) -> None: - if not isinstance(getattr(self.plugin, "config", None), dict): - raise RuntimeError("插件 config 不可写") - for key, value in normalized.items(): - _nested_set(self.plugin.config, key, value) - plugin_cfg = getattr(self.plugin, "_plugin_config", None) - if isinstance(plugin_cfg, dict): - for key, value in normalized.items(): - _nested_set(plugin_cfg, key, value) - - async def apply_profile(self, profile: Dict[str, Any], *, reason: str = "manual") -> Dict[str, Any]: - normalized = self._normalize_profile(profile) - current = self.get_profile_snapshot() - self._rollback_snapshot = current - self._apply_profile_to_runtime(normalized) - return { - "applied": normalized, - "rollback_snapshot": current, - "reason": reason, - "applied_at": _now(), - } - - async def rollback_profile(self) -> Dict[str, Any]: - if not self._rollback_snapshot: - raise ValueError("暂无可回滚的参数快照") - target = self._normalize_profile(self._rollback_snapshot, fallback=self._rollback_snapshot) - self._apply_profile_to_runtime(target) - return {"rolled_back_to": target, "rolled_back_at": _now()} - - def export_toml_snippet(self, profile: Optional[Dict[str, Any]] = None) -> str: - p = self._normalize_profile(profile or self.get_profile_snapshot()) - r = p["retrieval"] - t = p["threshold"] - lines = [ - "[retrieval]", - f"top_k_paragraphs = {int(r['top_k_paragraphs'])}", - f"top_k_relations = {int(r['top_k_relations'])}", - f"top_k_final = {int(r['top_k_final'])}", - f"alpha = {float(r['alpha']):.4f}", - f"enable_ppr = {str(bool(r['enable_ppr'])).lower()}", - "", - "[retrieval.search.smart_fallback]", - f"enabled = {str(bool(r['search']['smart_fallback']['enabled'])).lower()}", - "", - "[retrieval.sparse]", - f"enabled = {str(bool(r['sparse']['enabled'])).lower()}", - f"mode = \"{r['sparse']['mode']}\"", - f"candidate_k = {int(r['sparse']['candidate_k'])}", - f"relation_candidate_k = {int(r['sparse']['relation_candidate_k'])}", - "", - "[retrieval.fusion]", - f"method = \"{r['fusion']['method']}\"", - f"rrf_k = {int(r['fusion']['rrf_k'])}", - f"vector_weight = {float(r['fusion']['vector_weight']):.4f}", - f"bm25_weight = {float(r['fusion']['bm25_weight']):.4f}", - "", - "[threshold]", - f"percentile = {float(t['percentile']):.4f}", - f"min_results = {int(t['min_results'])}", - ] - return "\n".join(lines).strip() + "\n" - - def _pending_task_count(self) -> int: - return sum(1 for t in self._tasks.values() if t.status in {"queued", "running", "cancel_requested"}) - - def _sample_triples_for_query_set( - self, - *, - triples: List[Tuple[Any, Any, Any, Any]], - sample_size: int, - seed: int, - ) -> Tuple[List[Tuple[str, str, str, str]], Dict[str, Any]]: - normalized: List[Tuple[str, str, str, str]] = [] - for row in triples: - try: - subject, predicate, obj, rel_hash = row - except Exception: - continue - relation_hash = str(rel_hash or "").strip() - if not relation_hash: - continue - normalized.append((str(subject or ""), str(predicate or ""), str(obj or ""), relation_hash)) - - if not normalized: - return [], {"error": "no_relations"} - - target = min(max(4, int(sample_size)), len(normalized)) - predicate_counter = Counter([str(x[1] or "").strip() or "__empty__" for x in normalized]) - entity_counter = Counter() - for subj, _, obj, _ in normalized: - entity_counter.update([str(subj or "").strip().lower() or "__empty__"]) - entity_counter.update([str(obj or "").strip().lower() or "__empty__"]) - - if target >= len(normalized): - return 
list(normalized), { - "strategy": "all", - "sample_size": int(target), - "total_triples": int(len(normalized)), - "predicate_total": int(len(predicate_counter)), - "predicate_sampled": int(len(predicate_counter)), - } - - rng = random.Random(f"{seed}:triple_sample") - by_predicate: Dict[str, List[int]] = {} - for idx, (_, predicate, _, _) in enumerate(normalized): - key = str(predicate or "").strip() or "__empty__" - by_predicate.setdefault(key, []).append(idx) - for pool in by_predicate.values(): - rng.shuffle(pool) - - predicate_order = sorted(by_predicate.keys()) - rng.shuffle(predicate_order) - - selected: List[int] = [] - selected_set = set() - - # First pass: predicate round-robin to avoid head predicate dominating query set. - while len(selected) < target: - progressed = False - for key in predicate_order: - pool = by_predicate.get(key, []) - if not pool: - continue - idx = int(pool.pop()) - if idx in selected_set: - continue - selected.append(idx) - selected_set.add(idx) - progressed = True - if len(selected) >= target: - break - if not progressed: - break - - if len(selected) < target: - remain = [idx for idx in range(len(normalized)) if idx not in selected_set] - rng.shuffle(remain) - - # Second pass: prefer lower-frequency entities and predicates for better diversity. - def _remain_score(idx: int) -> Tuple[int, int]: - subj, predicate, obj, _ = normalized[idx] - subject_freq = int(entity_counter.get(str(subj or "").strip().lower() or "__empty__", 0)) - object_freq = int(entity_counter.get(str(obj or "").strip().lower() or "__empty__", 0)) - pred_freq = int(predicate_counter.get(str(predicate or "").strip() or "__empty__", 0)) - return (subject_freq + object_freq, pred_freq) - - remain = sorted(remain, key=_remain_score) - need = target - len(selected) - for idx in remain[:need]: - selected.append(idx) - selected_set.add(idx) - - selected = selected[:target] - sampled = [normalized[idx] for idx in selected] - sampled_predicates = {str(x[1] or "").strip() or "__empty__" for x in sampled} - - return sampled, { - "strategy": "predicate_round_robin_entity_diversity", - "sample_size": int(target), - "total_triples": int(len(normalized)), - "predicate_total": int(len(predicate_counter)), - "predicate_sampled": int(len(sampled_predicates)), - } - - def _select_round_eval_cases( - self, - *, - cases: List[RetrievalQueryCase], - intensity: str, - round_index: int, - seed: int, - ) -> List[RetrievalQueryCase]: - if not cases: - return [] - mode = str(intensity or "standard").strip().lower() - if mode not in INTENSITIES: - mode = "standard" - if mode == "deep": - return list(cases) - - if mode == "quick": - ratio = 0.45 - min_total = 16 - else: - ratio = 0.70 - min_total = 24 - - total = len(cases) - target = max(min_total, int(total * ratio)) - if target >= total: - return list(cases) - - rng = random.Random(f"{seed}:{round_index}:subset") - by_cat: Dict[str, List[RetrievalQueryCase]] = {} - for item in cases: - by_cat.setdefault(str(item.category), []).append(item) - - selected: List[RetrievalQueryCase] = [] - selected_ids = set() - cat_names = sorted([x for x in by_cat.keys() if x in CATEGORIES]) - if not cat_names: - cat_names = sorted(by_cat.keys()) - per_cat = max(1, target // max(1, len(cat_names))) - - for cat in cat_names: - pool = by_cat.get(cat, []) - if not pool: - continue - picked = list(pool) if len(pool) <= per_cat else rng.sample(pool, per_cat) - for item in picked: - if item.case_id in selected_ids: - continue - selected.append(item) - selected_ids.add(item.case_id) - - if 
len(selected) < target: - remain = [x for x in cases if x.case_id not in selected_ids] - if len(remain) > (target - len(selected)): - remain = rng.sample(remain, target - len(selected)) - for item in remain: - selected.append(item) - selected_ids.add(item.case_id) - - return selected[:target] - - async def _ensure_worker(self) -> None: - async with self._lock: - if self._worker_task and not self._worker_task.done(): - return - self._stopping = False - self._worker_task = asyncio.create_task(self._worker_loop()) - - async def shutdown(self) -> None: - self._stopping = True - worker = self._worker_task - if worker is None or worker.done(): - return - worker.cancel() - try: - await worker - except asyncio.CancelledError: - pass - except Exception as e: - logger.warning(f"Retrieval tuning worker shutdown failed: {e}") - - async def create_task(self, payload: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("检索调优中心已禁用") - self._ensure_ready() - - data = payload or {} - objective = str(data.get("objective") or self._cfg("web.tuning.default_objective", "precision_priority")) - if objective not in OBJECTIVES: - raise ValueError(f"objective 非法: {objective}") - - intensity = str(data.get("intensity") or self._cfg("web.tuning.default_intensity", "standard")) - if intensity not in INTENSITIES: - raise ValueError(f"intensity 非法: {intensity}") - - rounds_total = _clamp_int(data.get("rounds", INTENSITIES[intensity]), INTENSITIES[intensity], 1, 200) - sample_size = _clamp_int(data.get("sample_size", self._cfg("web.tuning.default_sample_size", 24)), 24, 4, 500) - top_k_eval = _clamp_int(data.get("top_k_eval", self._cfg("web.tuning.default_top_k_eval", 20)), 20, 5, 100) - eval_query_timeout_seconds = _clamp_float( - data.get("eval_query_timeout_seconds", self._eval_query_timeout_s()), - self._eval_query_timeout_s(), - 0.01, - 120.0, - ) - llm_enabled = _coerce_bool(data.get("llm_enabled", True), True) - seed = data.get("seed") - try: - seed = int(seed) - except Exception: - seed = int(time.time()) % 1000003 - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("调优任务队列已满,请稍后重试") - task = RetrievalTuningTaskRecord( - task_id=uuid.uuid4().hex, - status="queued", - progress=0.0, - objective=objective, - intensity=intensity, - rounds_total=rounds_total, - params={ - "sample_size": sample_size, - "top_k_eval": top_k_eval, - "eval_query_timeout_seconds": float(eval_query_timeout_seconds), - "llm_enabled": llm_enabled, - "seed": seed, - }, - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - task.updated_at = _now() - - await self._ensure_worker() - return task.to_summary() - - async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: - limit = _clamp_int(limit, 50, 1, 500) - async with self._lock: - items: List[Dict[str, Any]] = [] - for task_id in list(self._task_order)[:limit]: - task = self._tasks.get(task_id) - if task: - items.append(task.to_summary()) - return items - - async def get_task(self, task_id: str, include_rounds: bool = False) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - return task.to_detail(include_rounds=include_rounds) - - async def get_rounds(self, task_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: - offset = max(0, int(offset)) - limit = _clamp_int(limit, 50, 1, 500) - async with self._lock: - task = 
self._tasks.get(task_id) - if not task: - return None - total = len(task.rounds) - sliced = task.rounds[offset : offset + limit] - return { - "total": total, - "offset": offset, - "limit": limit, - "items": [item.to_dict() for item in sliced], - } - - async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - if task.status in {"completed", "failed", "cancelled"}: - return task.to_summary() - if task.status == "queued": - task.status = "cancelled" - task.cancel_requested = True - task.finished_at = _now() - task.updated_at = task.finished_at - self._queue = deque([x for x in self._queue if x != task_id]) - return task.to_summary() - task.status = "cancel_requested" - task.cancel_requested = True - task.updated_at = _now() - return task.to_summary() - - async def apply_best(self, task_id: str) -> Dict[str, Any]: - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - raise ValueError("任务不存在") - if task.status != "completed": - raise ValueError("任务未完成,无法应用最优参数") - if not task.best_profile: - raise ValueError("任务没有可应用的最优参数") - best = copy.deepcopy(task.best_profile) - applied = await self.apply_profile(best, reason=f"task:{task_id}:apply_best") - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.apply_log.append({"applied_at": _now(), "reason": "apply_best", "profile": best}) - task.updated_at = _now() - return applied - - async def get_report(self, task_id: str, fmt: str = "md") -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return None - artifacts = dict(task.artifact_paths) - fmt = str(fmt or "md").strip().lower() - if fmt not in {"md", "json"}: - fmt = "md" - path_key = "report_md" if fmt == "md" else "report_json" - path = artifacts.get(path_key) - if not path: - return {"format": fmt, "content": "", "path": ""} - p = Path(path) - if not p.exists(): - return {"format": fmt, "content": "", "path": str(p)} - try: - content = p.read_text(encoding="utf-8") - except Exception: - content = "" - return {"format": fmt, "content": content, "path": str(p)} - - async def _worker_loop(self) -> None: - while not self._stopping: - task_id: Optional[str] = None - async with self._lock: - while self._queue: - candidate = self._queue.popleft() - task = self._tasks.get(candidate) - if task is None: - continue - if task.status != "queued": - continue - task_id = candidate - self._active_task_id = candidate - break - - if not task_id: - await asyncio.sleep(self._poll_interval_s()) - continue - - try: - await self._run_task(task_id) - except asyncio.CancelledError: - raise - except Exception as e: - logger.error(f"Retrieval tuning task crashed: task_id={task_id}, err={e}") - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.status = "failed" - task.error = str(e) - task.finished_at = _now() - task.updated_at = task.finished_at - finally: - async with self._lock: - if self._active_task_id == task_id: - self._active_task_id = None - - async def _run_task(self, task_id: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - task.status = "running" - task.started_at = _now() - task.updated_at = task.started_at - - artifacts_dir = self._artifacts_root / task_id - artifacts_dir.mkdir(parents=True, exist_ok=True) - query_set_path = artifacts_dir / "query_set.json" - rounds_path = artifacts_dir / 
"round_metrics.jsonl" - best_profile_path = artifacts_dir / "best_profile.json" - report_json_path = artifacts_dir / "report.json" - report_md_path = artifacts_dir / "report.md" - - try: - params = dict(task.params) - cases, stats = await self._build_query_set( - sample_size=int(params["sample_size"]), - seed=int(params["seed"]), - llm_enabled=bool(params.get("llm_enabled", True)), - ) - if not cases: - raise ValueError("当前知识库样本不足,无法构建调优测试集") - - query_set_path.write_text( - json.dumps( - { - "task_id": task_id, - "created_at": _now(), - "stats": stats, - "items": [c.to_dict() for c in cases], - }, - ensure_ascii=False, - indent=2, - ), - encoding="utf-8", - ) - - baseline_profile = self.get_profile_snapshot() - top_k_eval = int(params["top_k_eval"]) - baseline_eval = await self._evaluate_profile( - profile=baseline_profile, - cases=cases, - objective=task.objective, - top_k_eval=top_k_eval, - query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), - ) - baseline_round = RetrievalTuningRoundRecord( - round_index=0, - candidate_profile=baseline_profile, - metrics=baseline_eval["metrics"], - score=float(baseline_eval["score"]), - latency_ms=float(baseline_eval["avg_elapsed_ms"]), - failure_summary=baseline_eval["failure_summary"], - ) - rounds_path.write_text(json.dumps(baseline_round.to_dict(), ensure_ascii=False) + "\n", encoding="utf-8") - - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - task.query_set_stats = stats - task.baseline_profile = copy.deepcopy(baseline_profile) - task.baseline_metrics = copy.deepcopy(baseline_eval["metrics"]) - task.rounds.append(baseline_round) - task.best_profile = copy.deepcopy(baseline_profile) - task.best_metrics = copy.deepcopy(baseline_eval["metrics"]) - task.best_score = float(baseline_eval["score"]) - task.progress = 0.0 - task.updated_at = _now() - - best_profile = copy.deepcopy(baseline_profile) - best_metrics = copy.deepcopy(baseline_eval["metrics"]) - best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) - best_score = float(baseline_eval["score"]) - llm_suggestions: List[Dict[str, Any]] = [] - task_cancelled = False - - for round_idx in range(1, int(task.rounds_total) + 1): - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - if task.cancel_requested or task.status == "cancel_requested": - task.status = "cancelled" - task.finished_at = _now() - task.updated_at = task.finished_at - task_cancelled = True - break - - if round_idx == 1 or (round_idx % 5 == 0 and not llm_suggestions): - llm_suggestions = await self._suggest_profiles_with_llm( - base_profile=best_profile, - failure_summary=best_failure_summary, - objective=task.objective, - max_count=3, - enabled=bool(params.get("llm_enabled", True)), - ) - - candidate_profile = self._generate_candidate_profile( - task_id=task_id, - round_index=round_idx, - objective=task.objective, - baseline_profile=baseline_profile, - best_profile=best_profile, - llm_suggestions=llm_suggestions, - ) - eval_cases = self._select_round_eval_cases( - cases=cases, - intensity=task.intensity, - round_index=round_idx, - seed=int(params.get("seed", 0)), - ) - eval_result = await self._evaluate_profile( - profile=candidate_profile, - cases=eval_cases, - objective=task.objective, - top_k_eval=top_k_eval, - query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), - ) - round_record = RetrievalTuningRoundRecord( - round_index=round_idx, - 
candidate_profile=candidate_profile, - metrics=eval_result["metrics"], - score=float(eval_result["score"]), - latency_ms=float(eval_result["avg_elapsed_ms"]), - failure_summary=eval_result["failure_summary"], - ) - with rounds_path.open("a", encoding="utf-8") as fp: - fp.write(json.dumps(round_record.to_dict(), ensure_ascii=False) + "\n") - - if float(eval_result["score"]) > float(best_score): - best_score = float(eval_result["score"]) - best_profile = copy.deepcopy(candidate_profile) - best_metrics = copy.deepcopy(eval_result["metrics"]) - best_failure_summary = copy.deepcopy(eval_result["failure_summary"]) - - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - task.rounds_done = round_idx - task.rounds.append(round_record) - task.best_profile = copy.deepcopy(best_profile) - task.best_metrics = copy.deepcopy(best_metrics) - task.best_score = float(best_score) - task.progress = min(1.0, float(round_idx) / float(task.rounds_total)) - task.updated_at = _now() - - if best_profile and (not task_cancelled): - # 候选轮可能基于子样本评估,收官时用全量样本复核,确保最终指标可解释。 - best_full = await self._evaluate_profile( - profile=best_profile, - cases=cases, - objective=task.objective, - top_k_eval=top_k_eval, - query_timeout_s=float(params.get("eval_query_timeout_seconds") or self._eval_query_timeout_s()), - ) - best_profile = copy.deepcopy(best_profile) - best_metrics = copy.deepcopy(best_full["metrics"]) - best_failure_summary = copy.deepcopy(best_full["failure_summary"]) - best_score = float(best_full["score"]) - if best_score < float(baseline_eval["score"]): - best_profile = copy.deepcopy(baseline_profile) - best_metrics = copy.deepcopy(baseline_eval["metrics"]) - best_failure_summary = copy.deepcopy(baseline_eval["failure_summary"]) - best_score = float(baseline_eval["score"]) - - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.best_profile = copy.deepcopy(best_profile) - task.best_metrics = copy.deepcopy(best_metrics) - task.best_score = float(best_score) - task.updated_at = _now() - - async with self._lock: - task = self._tasks.get(task_id) - if task is None: - return - if task.status not in {"cancelled", "failed"}: - task.status = "completed" - task.progress = 1.0 - task.finished_at = _now() - task.updated_at = task.finished_at - final_task = copy.deepcopy(task) - - if final_task.status == "completed": - best_profile_path.write_text(json.dumps(final_task.best_profile, ensure_ascii=False, indent=2), encoding="utf-8") - report_payload = self._build_report_payload(final_task) - report_json_path.write_text(json.dumps(report_payload, ensure_ascii=False, indent=2), encoding="utf-8") - report_md_path.write_text(self._build_report_markdown(final_task, report_payload), encoding="utf-8") - - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.artifact_paths = { - "query_set": str(query_set_path), - "round_metrics_jsonl": str(rounds_path), - "best_profile": str(best_profile_path), - "report_json": str(report_json_path), - "report_md": str(report_md_path), - } - task.updated_at = _now() - except Exception as e: - logger.error(f"Retrieval tuning task failed: task_id={task_id}, err={e}") - async with self._lock: - task = self._tasks.get(task_id) - if task is not None: - task.status = "failed" - task.error = str(e) - task.finished_at = _now() - task.updated_at = task.finished_at - - async def _build_query_set(self, *, sample_size: int, seed: int, llm_enabled: bool) -> Tuple[List[RetrievalQueryCase], Dict[str, Any]]: - store 
= getattr(self.plugin, "metadata_store", None) - if store is None: - return [], {"error": "metadata_store_unavailable"} - - triples = list(store.get_all_triples() or []) - if not triples: - return [], {"error": "no_relations"} - - sampled, sample_info = self._sample_triples_for_query_set( - triples=triples, - sample_size=sample_size, - seed=seed, - ) - if not sampled: - return [], {"error": "no_relations"} - - anchors: List[Dict[str, Any]] = [] - for idx, row in enumerate(sampled): - subject, predicate, obj, relation_hash = row - paragraphs = store.get_paragraphs_by_relation(relation_hash) - para_hash = "" - para_content = "" - if paragraphs: - para_hash = str(paragraphs[0].get("hash") or "").strip() - para_content = str(paragraphs[0].get("content") or "") - anchors.append( - { - "anchor_id": f"a{idx+1:04d}", - "subject": str(subject or ""), - "predicate": str(predicate or ""), - "object": str(obj or ""), - "relation_hash": relation_hash, - "paragraph_hash": para_hash, - "paragraph_excerpt": para_content[:300], - } - ) - - if not anchors: - return [], {"error": "no_anchors"} - - predicate_groups: Dict[str, List[Dict[str, Any]]] = {} - for anchor in anchors: - predicate_groups.setdefault(str(anchor.get("predicate") or ""), []).append(anchor) - - nl_queries = await self._generate_nl_queries_with_llm(anchors, enabled=llm_enabled) - cases: List[RetrievalQueryCase] = [] - - seq = 0 - for anchor in anchors: - seq += 1 - subject = anchor["subject"] - predicate = anchor["predicate"] - obj = anchor["object"] - rel_hash = anchor["relation_hash"] - para_hash = anchor["paragraph_hash"] - expected = [rel_hash] - if para_hash: - expected.append(para_hash) - aid = anchor["anchor_id"] - - common_meta = { - "anchor_id": aid, - "relation_hash": rel_hash, - "paragraph_hash": para_hash, - "subject": subject, - "predicate": predicate, - "object": obj, - } - cases.append( - RetrievalQueryCase( - case_id=f"spo_relation_{seq:04d}", - category="spo_relation", - query=f"{subject}|{predicate}|{obj}", - expected_hashes=[rel_hash], - expected_spo={"subject": subject, "predicate": predicate, "object": obj}, - metadata=dict(common_meta), - ) - ) - cases.append( - RetrievalQueryCase( - case_id=f"spo_search_{seq:04d}", - category="spo_search", - query=self._build_spo_search_query( - anchor=anchor, - seq=seq, - predicate_groups=predicate_groups, - ), - expected_hashes=list(expected), - metadata=dict(common_meta), - ) - ) - cases.append( - RetrievalQueryCase( - case_id=f"query_kw_{seq:04d}", - category="query_kw", - query=self._build_keyword_query( - anchor=anchor, - seq=seq, - predicate_groups=predicate_groups, - ), - expected_hashes=list(expected), - metadata=dict(common_meta), - ) - ) - nl_query = nl_queries.get(aid) or self._build_nl_template( - anchor=anchor, - seq=seq, - predicate_groups=predicate_groups, - ) - cases.append( - RetrievalQueryCase( - case_id=f"query_nl_{seq:04d}", - category="query_nl", - query=nl_query, - expected_hashes=list(expected), - metadata=dict(common_meta), - ) - ) - - counts = Counter([c.category for c in cases]) - stats = { - "anchors": len(anchors), - "case_total": len(cases), - "category_counts": {k: int(v) for k, v in counts.items()}, - "seed": int(seed), - "sample_size": int(sample_info.get("sample_size", len(anchors))), - "sampling": dict(sample_info), - "llm_nl_enabled": bool(llm_enabled), - "llm_nl_generated": int(len(nl_queries)), - } - return cases, stats - - def _pick_contrast_anchor( - self, - *, - anchor: Dict[str, Any], - predicate_groups: Dict[str, List[Dict[str, Any]]], - seq: 
int, - ) -> Optional[Dict[str, Any]]: - predicate = str(anchor.get("predicate") or "") - pool = predicate_groups.get(predicate, []) - if not pool: - return None - candidates = [x for x in pool if x is not anchor and str(x.get("object") or "") != str(anchor.get("object") or "")] - if not candidates: - return None - return candidates[seq % len(candidates)] - - def _build_spo_search_query( - self, - *, - anchor: Dict[str, Any], - seq: int, - predicate_groups: Dict[str, List[Dict[str, Any]]], - ) -> str: - subject = str(anchor.get("subject") or "") - predicate = str(anchor.get("predicate") or "") - obj = str(anchor.get("object") or "") - contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) - contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" - - variants = [ - f"{subject} {predicate} {obj}", - f"{subject} {obj} relation {predicate}", - f"{predicate} {subject} {obj} evidence", - f"{subject} {predicate} {obj} not {contrast_obj}".strip(), - ] - return variants[seq % len(variants)].strip() - - def _build_keyword_query( - self, - *, - anchor: Dict[str, Any], - seq: int, - predicate_groups: Dict[str, List[Dict[str, Any]]], - ) -> str: - subject = str(anchor.get("subject") or "") - predicate = str(anchor.get("predicate") or "") - obj = str(anchor.get("object") or "") - excerpt = str(anchor.get("paragraph_excerpt") or "") - tokens = re.findall(r"[A-Za-z0-9_\u4e00-\u9fff]{2,}", excerpt) - extras: List[str] = [] - seen = set() - for token in tokens: - key = token.lower() - if key in seen: - continue - if key in {subject.lower(), predicate.lower(), obj.lower()}: - continue - seen.add(key) - extras.append(token) - if len(extras) >= 2: - break - contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) - contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" - - variants = [ - [subject, obj] + extras[:2], - [predicate, obj] + extras[:2], - [subject, predicate] + extras[:2], - [subject, obj, predicate, contrast_obj] + extras[:1], - ] - parts = variants[seq % len(variants)] - return " ".join([x for x in parts if x]).strip() - - def _build_nl_template( - self, - *, - anchor: Dict[str, Any], - seq: int, - predicate_groups: Dict[str, List[Dict[str, Any]]], - ) -> str: - subject = str(anchor.get("subject") or "") - predicate = str(anchor.get("predicate") or "") - obj = str(anchor.get("object") or "") - contrast = self._pick_contrast_anchor(anchor=anchor, predicate_groups=predicate_groups, seq=seq) - contrast_obj = str(contrast.get("object") or "").strip() if contrast else "" - templates = [ - f"请问 {subject} 与 {obj} 的关系是什么,是否是“{predicate}”?", - f"在当前知识库中,哪条信息说明 {subject} 对应的是 {obj},关系词接近“{predicate}”?", - f"我想确认:{subject} 和 {obj} 之间是不是“{predicate}”这层关系,而不是 {contrast_obj}?", - f"帮我查一下关于 {subject} 与 {obj} 的证据,重点看 {predicate} 相关描述。", - ] - return templates[seq % len(templates)] - - async def _select_llm_model(self) -> Optional[Any]: - if llm_api is None: - return None - try: - models = llm_api.get_available_models() or {} - except Exception: - return None - if not models: - return None - - cfg_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() - if cfg_model.lower() != "auto" and cfg_model in models: - return models[cfg_model] - for task_name in ["utils", "planner", "tool_use", "replyer", "embedding"]: - if task_name in models: - return models[task_name] - return models[next(iter(models))] - - async def _llm_call_text(self, prompt: str, *, request_type: str) -> 
str: - if llm_api is None: - raise RuntimeError("llm_api unavailable") - model_cfg = await self._select_llm_model() - if model_cfg is None: - raise RuntimeError("no_llm_model") - - retry = self._llm_retry_cfg() - max_attempts = int(retry["max_attempts"]) - min_wait = float(retry["min_wait_seconds"]) - max_wait = float(retry["max_wait_seconds"]) - backoff = float(retry["backoff_multiplier"]) - - last_error: Optional[Exception] = None - for idx in range(max_attempts): - try: - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_cfg, - request_type=request_type, - ) - if not success: - raise RuntimeError("llm_generation_failed") - text = str(response or "").strip() - if text: - return text - raise RuntimeError("empty_llm_response") - except Exception as e: - last_error = e - if idx >= max_attempts - 1: - break - delay = min(max_wait, min_wait * (backoff ** idx)) - await asyncio.sleep(max(0.05, delay)) - raise RuntimeError(f"LLM call failed: {last_error}") - - async def _generate_nl_queries_with_llm(self, anchors: List[Dict[str, Any]], *, enabled: bool) -> Dict[str, str]: - if not enabled or llm_api is None or not anchors: - return {} - payload = [ - { - "anchor_id": x["anchor_id"], - "subject": x["subject"], - "predicate": x["predicate"], - "object": x["object"], - "paragraph_excerpt": x["paragraph_excerpt"][:180], - } - for x in anchors[:60] - ] - prompt = ( - "你是检索评估问题生成器。" - "请基于给定 SPO 与简短上下文,为每条样本生成 1 条自然语言检索问题,返回 JSON:" - "{\"items\":[{\"anchor_id\":\"...\",\"query\":\"...\"}]}。\n" - f"样本:\n{json.dumps(payload, ensure_ascii=False)}" - ) - try: - raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.NLCaseGen") - obj = _safe_json_loads(raw) - if not isinstance(obj, dict): - return {} - items = obj.get("items") - if not isinstance(items, list): - return {} - out: Dict[str, str] = {} - for row in items: - if not isinstance(row, dict): - continue - anchor_id = str(row.get("anchor_id") or "").strip() - query = str(row.get("query") or "").strip() - if anchor_id and query: - out[anchor_id] = query - return out - except Exception: - return {} - - async def _suggest_profiles_with_llm( - self, - *, - base_profile: Dict[str, Any], - failure_summary: Dict[str, Any], - objective: str, - max_count: int, - enabled: bool, - ) -> List[Dict[str, Any]]: - if not enabled or llm_api is None or max_count <= 0: - return [] - prompt = ( - "你是检索调参专家。" - "请基于基础参数与失败摘要,给出最多 " - f"{int(max_count)} 组候选参数,返回 JSON: {{\"profiles\": [ ... 
]}}。\n" - "字段仅可包含:retrieval.top_k_paragraphs, retrieval.top_k_relations, retrieval.top_k_final, " - "retrieval.alpha, retrieval.enable_ppr, retrieval.search.smart_fallback.enabled, " - "retrieval.sparse.enabled, retrieval.sparse.mode, retrieval.sparse.candidate_k, retrieval.sparse.relation_candidate_k, " - "retrieval.fusion.method, retrieval.fusion.rrf_k, retrieval.fusion.vector_weight, retrieval.fusion.bm25_weight, " - "threshold.percentile, threshold.min_results。\n" - f"objective={objective}\n" - f"base={json.dumps(base_profile, ensure_ascii=False)}\n" - f"failure_summary={json.dumps(failure_summary, ensure_ascii=False)}" - ) - try: - raw = await self._llm_call_text(prompt, request_type="A_Memorix.RetrievalTuning.ProfileSuggest") - obj = _safe_json_loads(raw) - if not isinstance(obj, dict): - return [] - profiles = obj.get("profiles") - if not isinstance(profiles, list): - return [] - out = [] - for item in profiles[:max_count]: - if isinstance(item, dict): - out.append(self._normalize_profile(item, fallback=base_profile)) - return out - except Exception: - return [] - - def _generate_candidate_profile( - self, - *, - task_id: str, - round_index: int, - objective: str, - baseline_profile: Dict[str, Any], - best_profile: Dict[str, Any], - llm_suggestions: List[Dict[str, Any]], - ) -> Dict[str, Any]: - if llm_suggestions: - return self._normalize_profile(llm_suggestions.pop(0), fallback=best_profile) - - rng = random.Random(f"{task_id}:{round_index}") - base = baseline_profile if round_index % 4 == 1 else best_profile - candidate = copy.deepcopy(base) - - if objective == "precision_priority": - para_choices = [40, 80, 120, 180, 240, 320] - rel_choices = [4, 8, 12, 16, 24] - final_choices = [4, 8, 12, 16, 20, 32, 48, 64] - alpha_choices = [0.0, 0.35, 0.50, 0.62, 0.72, 0.82, 0.90] - pct_choices = [55, 60, 65, 72, 80] - min_results_choices = [1, 2] - elif objective == "recall_priority": - para_choices = [120, 220, 300, 420, 560, 720] - rel_choices = [8, 12, 16, 24, 32] - final_choices = [8, 16, 32, 48, 64, 96, 128] - alpha_choices = [0.20, 0.35, 0.45, 0.55, 0.65, 0.75] - pct_choices = [40, 48, 55, 62] - min_results_choices = [1, 2, 3] - else: - para_choices = [80, 160, 240, 320, 420, 520] - rel_choices = [6, 10, 14, 18, 24, 30] - final_choices = [6, 12, 20, 32, 48, 64, 80] - alpha_choices = [0.25, 0.45, 0.55, 0.65, 0.75, 0.85] - pct_choices = [48, 55, 62, 70] - min_results_choices = [1, 2, 3] - - _nested_set(candidate, "retrieval.top_k_paragraphs", rng.choice(para_choices)) - _nested_set(candidate, "retrieval.top_k_relations", rng.choice(rel_choices)) - _nested_set(candidate, "retrieval.top_k_final", rng.choice(final_choices)) - _nested_set(candidate, "retrieval.alpha", rng.choice(alpha_choices)) - # PPR 在 TestClient/异步评估场景下存在偶发长时阻塞风险,调优评估链路固定关闭。 - _nested_set(candidate, "retrieval.enable_ppr", False) - _nested_set(candidate, "retrieval.search.smart_fallback.enabled", bool(rng.choice([True, True, False]))) - _nested_set(candidate, "retrieval.sparse.enabled", bool(rng.choice([True, True, False]))) - _nested_set(candidate, "retrieval.sparse.mode", rng.choice(["auto", "hybrid", "fallback_only"])) - _nested_set(candidate, "retrieval.sparse.candidate_k", rng.choice([60, 80, 120, 160, 220, 320])) - _nested_set(candidate, "retrieval.sparse.relation_candidate_k", rng.choice([40, 60, 90, 120, 180, 260])) - _nested_set(candidate, "retrieval.fusion.method", rng.choice(["weighted_rrf", "weighted_rrf", "alpha_legacy"])) - _nested_set(candidate, "retrieval.fusion.rrf_k", rng.choice([30, 45, 60, 75, 90])) - 
vec_w = float(rng.choice([0.55, 0.65, 0.72, 0.80, 0.88])) - _nested_set(candidate, "retrieval.fusion.vector_weight", vec_w) - _nested_set(candidate, "retrieval.fusion.bm25_weight", 1.0 - vec_w) - _nested_set(candidate, "threshold.percentile", rng.choice(pct_choices)) - _nested_set(candidate, "threshold.min_results", rng.choice(min_results_choices)) - - return self._normalize_profile(candidate, fallback=base) - - def _build_runtime_config(self, normalized_profile: Dict[str, Any]) -> Dict[str, Any]: - raw_base = getattr(self.plugin, "config", {}) or {} - if isinstance(raw_base, dict): - base = { - key: value - for key, value in raw_base.items() - if key not in _RUNTIME_CONFIG_INSTANCE_KEYS - } - else: - base = {} - merged = _deep_merge(base, normalized_profile) - # 调优评估场景优先稳定性,避免并发访问共享 SQLite/Faiss 导致长时阻塞。 - _nested_set(merged, "retrieval.enable_parallel", False) - # 调优评估阶段关闭 PPR,规避 PageRank 线程计算偶发阻塞导致整轮卡死。 - _nested_set(merged, "retrieval.enable_ppr", False) - merged["vector_store"] = getattr(self.plugin, "vector_store", None) - merged["graph_store"] = getattr(self.plugin, "graph_store", None) - merged["metadata_store"] = getattr(self.plugin, "metadata_store", None) - merged["embedding_manager"] = getattr(self.plugin, "embedding_manager", None) - merged["sparse_index"] = getattr(self.plugin, "sparse_index", None) - merged["plugin_instance"] = self.plugin - return merged - - async def _evaluate_profile( - self, - *, - profile: Dict[str, Any], - cases: List[RetrievalQueryCase], - objective: str, - top_k_eval: int, - query_timeout_s: float, - ) -> Dict[str, Any]: - normalized = self._normalize_profile(profile) - eval_top_k = _clamp_int(top_k_eval, 20, 1, 1000) - # 评估时让 top_k_final 参与有效召回深度,避免该参数对评分无影响。 - request_top_k = min( - int(eval_top_k), - _clamp_int(_nested_get(normalized, "retrieval.top_k_final", eval_top_k), eval_top_k, 1, 512), - ) - eval_timeout_s = _clamp_float( - query_timeout_s, - self._eval_query_timeout_s(), - 0.01, - 120.0, - ) - runtime_cfg = self._build_runtime_config(normalized) - runtime = build_search_runtime( - plugin_config=runtime_cfg, - logger_obj=logger, - owner_tag="retrieval_tuning", - log_prefix="[RetrievalTuning]", - ) - if not runtime.ready: - metrics = { - "total_text_cases": 0, - "precision_at_1": 0.0, - "precision_at_3": 0.0, - "mrr": 0.0, - "recall_at_k": 0.0, - "spo_relation_hit_rate": 0.0, - "empty_rate": 1.0, - "avg_elapsed_ms": 0.0, - "category": {}, - "error": runtime.error or "runtime_not_ready", - } - return {"metrics": metrics, "score": -1.0, "avg_elapsed_ms": 0.0, "failure_summary": {"reason": metrics["error"]}} - - text_total = 0 - hit1 = 0 - hit3 = 0 - hitk = 0 - mrr_sum = 0.0 - empty_count = 0 - timeout_count = 0 - elapsed_total = 0.0 - text_failed: List[str] = [] - - spo_total = 0 - spo_hit = 0 - spo_failed: List[str] = [] - - category_stats: Dict[str, Dict[str, Any]] = {} - failed_predicates = Counter() - - for case in cases: - cat = str(case.category) - if cat not in CATEGORIES: - continue - if cat not in category_stats: - category_stats[cat] = { - "total": 0, - "hit": 0, - "hit_at_1": 0, - "hit_at_3": 0, - "empty": 0, - } - category_stats[cat]["total"] += 1 - - if cat == "spo_relation": - spo_total += 1 - spo = case.expected_spo or {} - rows = runtime.metadata_store.get_relations( - subject=str(spo.get("subject") or ""), - predicate=str(spo.get("predicate") or ""), - object=str(spo.get("object") or ""), - ) - expected_hash = str(case.expected_hashes[0]) if case.expected_hashes else "" - ok = False - for row in rows: - if not isinstance(row, 
dict): - continue - if expected_hash and str(row.get("hash") or "") == expected_hash: - ok = True - break - if not expected_hash: - ok = True - break - if ok: - spo_hit += 1 - category_stats[cat]["hit"] += 1 - category_stats[cat]["hit_at_1"] += 1 - category_stats[cat]["hit_at_3"] += 1 - else: - spo_failed.append(case.case_id) - failed_predicates.update([str(spo.get("predicate") or "").strip() or "__empty__"]) - continue - - text_total += 1 - req = SearchExecutionRequest( - caller="retrieval_tuning", - query_type="search", - query=str(case.query or "").strip(), - top_k=int(request_top_k), - use_threshold=True, - # 调优评估固定关闭 PPR,避免该链路阻塞拖挂整轮任务。 - enable_ppr=False, - ) - try: - execution = await asyncio.wait_for( - SearchExecutionService.execute( - retriever=runtime.retriever, - threshold_filter=runtime.threshold_filter, - plugin_config=runtime_cfg, - request=req, - enforce_chat_filter=False, - reinforce_access=False, - ), - timeout=float(eval_timeout_s), - ) - except asyncio.TimeoutError: - timeout_count += 1 - empty_count += 1 - category_stats[cat]["empty"] += 1 - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - continue - - if execution is None: - empty_count += 1 - category_stats[cat]["empty"] += 1 - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - continue - - elapsed_total += float(getattr(execution, "elapsed_ms", 0.0) or 0.0) - - if not bool(getattr(execution, "success", False)): - empty_count += 1 - category_stats[cat]["empty"] += 1 - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - continue - - hashes = [str(getattr(x, "hash_value", "") or "") for x in (getattr(execution, "results", None) or [])] - if not hashes: - empty_count += 1 - category_stats[cat]["empty"] += 1 - - expected_set = set(case.expected_hashes or []) - rank = 0 - for idx, hv in enumerate(hashes, start=1): - if hv and hv in expected_set: - rank = idx - break - - if rank > 0: - category_stats[cat]["hit"] += 1 - hitk += 1 - if rank <= 1: - hit1 += 1 - category_stats[cat]["hit_at_1"] += 1 - if rank <= 3: - hit3 += 1 - category_stats[cat]["hit_at_3"] += 1 - mrr_sum += 1.0 / float(rank) - else: - text_failed.append(case.case_id) - failed_predicates.update([str(case.metadata.get("predicate") or "__unknown__")]) - - p1 = (hit1 / text_total) if text_total else 0.0 - p3 = (hit3 / text_total) if text_total else 0.0 - recall = (hitk / text_total) if text_total else 0.0 - mrr = (mrr_sum / text_total) if text_total else 0.0 - spo_rate = (spo_hit / spo_total) if spo_total else 0.0 - empty_rate = (empty_count / text_total) if text_total else 1.0 - avg_elapsed = (elapsed_total / text_total) if text_total else 0.0 - - metrics = { - "total_text_cases": int(text_total), - "precision_at_1": float(round(p1, 6)), - "precision_at_3": float(round(p3, 6)), - "mrr": float(round(mrr, 6)), - "recall_at_k": float(round(recall, 6)), - "spo_relation_hit_rate": float(round(spo_rate, 6)), - "empty_rate": float(round(empty_rate, 6)), - "timeout_count": int(timeout_count), - "avg_elapsed_ms": float(round(avg_elapsed, 3)), - "category": category_stats, - } - metrics["category_floor_penalty"] = float(round(self._category_floor_penalty(metrics, objective=objective), 6)) - - score = self._score_metrics(metrics, objective=objective) - failure_summary = { - "text_failed_count": len(text_failed), - "spo_failed_count": len(spo_failed), - "failed_case_ids": 
text_failed[:50] + spo_failed[:50], - "failed_by_category": {k: int(v["total"] - v["hit"]) for k, v in category_stats.items()}, - "top_failed_predicates": [ - {"predicate": key, "count": int(cnt)} - for key, cnt in failed_predicates.most_common(5) - if key - ], - "query_timeout_seconds": float(eval_timeout_s), - "timeout_count": int(timeout_count), - "effective_top_k": int(request_top_k), - "ppr_forced_disabled": True, - } - return { - "metrics": metrics, - "score": float(round(score, 6)), - "avg_elapsed_ms": float(avg_elapsed), - "failure_summary": failure_summary, - } - - def _score_metrics(self, metrics: Dict[str, Any], *, objective: str) -> float: - p1 = float(metrics.get("precision_at_1", 0.0) or 0.0) - p3 = float(metrics.get("precision_at_3", 0.0) or 0.0) - mrr = float(metrics.get("mrr", 0.0) or 0.0) - recall = float(metrics.get("recall_at_k", 0.0) or 0.0) - spo = float(metrics.get("spo_relation_hit_rate", 0.0) or 0.0) - empty_rate = float(metrics.get("empty_rate", 1.0) or 1.0) - category_penalty = metrics.get("category_floor_penalty", None) - if category_penalty is None: - category_penalty = self._category_floor_penalty(metrics, objective=objective) - category_penalty = float(max(0.0, category_penalty)) - - if objective == "recall_priority": - raw = 0.15 * p1 + 0.15 * p3 + 0.15 * mrr + 0.40 * recall + 0.15 * spo - penalty = 0.05 * empty_rate - elif objective == "balanced": - raw = 0.25 * p1 + 0.20 * p3 + 0.15 * mrr + 0.25 * recall + 0.15 * spo - penalty = 0.10 * empty_rate - else: - raw = 0.40 * p1 + 0.20 * p3 + 0.15 * mrr + 0.15 * recall + 0.10 * spo - penalty = 0.15 * empty_rate - return float(raw - penalty - category_penalty) - - def _category_floor_penalty(self, metrics: Dict[str, Any], *, objective: str) -> float: - category = metrics.get("category") - if not isinstance(category, dict) or not category: - return 0.0 - - if objective == "recall_priority": - floors = {"query_nl": 0.60, "query_kw": 0.48, "spo_search": 0.52, "spo_relation": 0.88} - scale = 0.12 - elif objective == "balanced": - floors = {"query_nl": 0.65, "query_kw": 0.52, "spo_search": 0.55, "spo_relation": 0.90} - scale = 0.18 - else: - floors = {"query_nl": 0.70, "query_kw": 0.55, "spo_search": 0.58, "spo_relation": 0.92} - scale = 0.25 - - weights = {"query_nl": 1.0, "query_kw": 1.1, "spo_search": 1.0, "spo_relation": 1.2} - weighted_shortfall = 0.0 - weight_total = 0.0 - - for cat, floor in floors.items(): - row = category.get(cat) - if not isinstance(row, dict): - continue - total = int(row.get("total", 0) or 0) - if total <= 0: - continue - hit = float(row.get("hit", 0.0) or 0.0) - hit_rate = max(0.0, min(1.0, hit / float(max(1, total)))) - shortfall = max(0.0, float(floor) - hit_rate) - w = float(weights.get(cat, 1.0)) - weighted_shortfall += w * shortfall - weight_total += w - - if weight_total <= 1e-9: - return 0.0 - return float(scale * (weighted_shortfall / weight_total)) - - def _build_report_payload(self, task: RetrievalTuningTaskRecord) -> Dict[str, Any]: - baseline = task.baseline_metrics or {} - best = task.best_metrics or {} - - def delta(name: str) -> float: - return float(best.get(name, 0.0) or 0.0) - float(baseline.get(name, 0.0) or 0.0) - - return { - "task_id": task.task_id, - "objective": task.objective, - "intensity": task.intensity, - "status": task.status, - "created_at": task.created_at, - "started_at": task.started_at, - "finished_at": task.finished_at, - "rounds_total": task.rounds_total, - "rounds_done": task.rounds_done, - "best_score": task.best_score, - "baseline_score": 
self._score_metrics(baseline, objective=task.objective), - "query_set_stats": task.query_set_stats, - "baseline_metrics": baseline, - "best_metrics": best, - "deltas": { - "precision_at_1": delta("precision_at_1"), - "precision_at_3": delta("precision_at_3"), - "mrr": delta("mrr"), - "recall_at_k": delta("recall_at_k"), - "spo_relation_hit_rate": delta("spo_relation_hit_rate"), - "empty_rate": delta("empty_rate"), - "timeout_count": delta("timeout_count"), - "avg_elapsed_ms": delta("avg_elapsed_ms"), - }, - "best_profile": task.best_profile, - "baseline_profile": task.baseline_profile, - "apply_log": task.apply_log, - } - - def _build_report_markdown(self, task: RetrievalTuningTaskRecord, payload: Dict[str, Any]) -> str: - baseline = payload.get("baseline_metrics", {}) or {} - best = payload.get("best_metrics", {}) or {} - d = payload.get("deltas", {}) or {} - lines = [ - f"# 检索调优报告({task.task_id})", - "", - "## 1. 任务信息", - f"- 状态: {task.status}", - f"- 目标函数: {task.objective}", - f"- 强度: {task.intensity}", - f"- 轮次: baseline + {task.rounds_total}", - f"- 创建时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.created_at))}", - f"- 开始时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.started_at)) if task.started_at else '-'}", - f"- 完成时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(task.finished_at)) if task.finished_at else '-'}", - "", - "## 2. 基线 vs 最优", - f"- baseline score: {payload.get('baseline_score', 0.0):.6f}", - f"- best score: {task.best_score:.6f}", - f"- P@1: {baseline.get('precision_at_1', 0.0):.4f} -> {best.get('precision_at_1', 0.0):.4f} (Δ {d.get('precision_at_1', 0.0):+.4f})", - f"- P@3: {baseline.get('precision_at_3', 0.0):.4f} -> {best.get('precision_at_3', 0.0):.4f} (Δ {d.get('precision_at_3', 0.0):+.4f})", - f"- MRR: {baseline.get('mrr', 0.0):.4f} -> {best.get('mrr', 0.0):.4f} (Δ {d.get('mrr', 0.0):+.4f})", - f"- Recall@K: {baseline.get('recall_at_k', 0.0):.4f} -> {best.get('recall_at_k', 0.0):.4f} (Δ {d.get('recall_at_k', 0.0):+.4f})", - f"- SPO relation hit: {baseline.get('spo_relation_hit_rate', 0.0):.4f} -> {best.get('spo_relation_hit_rate', 0.0):.4f} (Δ {d.get('spo_relation_hit_rate', 0.0):+.4f})", - f"- 空结果率: {baseline.get('empty_rate', 0.0):.4f} -> {best.get('empty_rate', 0.0):.4f} (Δ {d.get('empty_rate', 0.0):+.4f})", - f"- 超时数: {int(baseline.get('timeout_count', 0) or 0)} -> {int(best.get('timeout_count', 0) or 0)} (Δ {int(d.get('timeout_count', 0) or 0):+d})", - f"- 平均耗时(ms): {baseline.get('avg_elapsed_ms', 0.0):.2f} -> {best.get('avg_elapsed_ms', 0.0):.2f} (Δ {d.get('avg_elapsed_ms', 0.0):+.2f})", - "", - "## 3. 最优参数", - "```json", - json.dumps(task.best_profile, ensure_ascii=False, indent=2), - "```", - "", - "## 4. 测试集规模", - f"- {json.dumps(task.query_set_stats, ensure_ascii=False)}", - "", - "## 5. 
说明", - "- 本报告仅对当前已存储图谱与向量状态有效。", - "- 参数应用策略:运行时生效,不自动写入 config.toml。", - ] - return "\n".join(lines).strip() + "\n" diff --git a/plugins/A_memorix/core/utils/runtime_self_check.py b/plugins/A_memorix/core/utils/runtime_self_check.py deleted file mode 100644 index 131ab32a..00000000 --- a/plugins/A_memorix/core/utils/runtime_self_check.py +++ /dev/null @@ -1,218 +0,0 @@ -"""Runtime self-check helpers for A_Memorix.""" - -from __future__ import annotations - -import time -from typing import Any, Dict, Optional - -import numpy as np - -from src.common.logger import get_logger - -logger = get_logger("A_Memorix.RuntimeSelfCheck") - -_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check" - - -def _safe_int(value: Any, default: int = 0) -> int: - try: - return int(value) - except Exception: - return int(default) - - -def _get_config_value(config: Any, key: str, default: Any = None) -> Any: - getter = getattr(config, "get_config", None) - if callable(getter): - return getter(key, default) - - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -def _build_report( - *, - ok: bool, - code: str, - message: str, - configured_dimension: int, - vector_store_dimension: int, - detected_dimension: int, - encoded_dimension: int, - elapsed_ms: float, - sample_text: str, -) -> Dict[str, Any]: - return { - "ok": bool(ok), - "code": str(code or "").strip(), - "message": str(message or "").strip(), - "configured_dimension": int(configured_dimension), - "vector_store_dimension": int(vector_store_dimension), - "detected_dimension": int(detected_dimension), - "encoded_dimension": int(encoded_dimension), - "elapsed_ms": float(elapsed_ms), - "sample_text": str(sample_text or ""), - "checked_at": time.time(), - } - - -def _normalize_encoded_vector(encoded: Any) -> np.ndarray: - if encoded is None: - raise ValueError("embedding encode returned None") - - if isinstance(encoded, np.ndarray): - array = encoded - else: - array = np.asarray(encoded, dtype=np.float32) - - if array.ndim == 2: - if array.shape[0] != 1: - raise ValueError(f"embedding encode returned batched output: shape={tuple(array.shape)}") - array = array[0] - - if array.ndim != 1: - raise ValueError(f"embedding encode returned invalid ndim={array.ndim}") - if array.size <= 0: - raise ValueError("embedding encode returned empty vector") - if not np.all(np.isfinite(array)): - raise ValueError("embedding encode returned non-finite values") - return array.astype(np.float32, copy=False) - - -async def run_embedding_runtime_self_check( - *, - config: Any, - vector_store: Optional[Any], - embedding_manager: Optional[Any], - sample_text: str = _DEFAULT_SAMPLE_TEXT, -) -> Dict[str, Any]: - """Probe the real embedding path and compare dimensions with runtime storage.""" - configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0) - vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0) - - if vector_store is None or embedding_manager is None: - return _build_report( - ok=False, - code="runtime_components_missing", - message="vector_store 或 embedding_manager 未初始化", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=0, - encoded_dimension=0, - elapsed_ms=0.0, - sample_text=sample_text, - ) - - start = time.perf_counter() - detected_dimension = 0 - encoded_dimension = 0 - try: - detected_dimension = _safe_int(await 
embedding_manager._detect_dimension(), 0) - encoded = await embedding_manager.encode(sample_text) - encoded_array = _normalize_encoded_vector(encoded) - encoded_dimension = int(encoded_array.shape[0]) - except Exception as exc: - elapsed_ms = (time.perf_counter() - start) * 1000.0 - logger.warning(f"embedding runtime self-check failed: {exc}") - return _build_report( - ok=False, - code="embedding_probe_failed", - message=f"embedding probe failed: {exc}", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - elapsed_ms = (time.perf_counter() - start) * 1000.0 - expected_dimension = vector_store_dimension or configured_dimension or detected_dimension - if expected_dimension <= 0: - return _build_report( - ok=False, - code="invalid_expected_dimension", - message="无法确定期望 embedding 维度", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - if encoded_dimension != expected_dimension: - msg = ( - "embedding 真实输出维度与当前向量存储不一致: " - f"expected={expected_dimension}, encoded={encoded_dimension}" - ) - logger.error(msg) - return _build_report( - ok=False, - code="embedding_dimension_mismatch", - message=msg, - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - return _build_report( - ok=True, - code="ok", - message="embedding runtime self-check passed", - configured_dimension=configured_dimension, - vector_store_dimension=vector_store_dimension, - detected_dimension=detected_dimension, - encoded_dimension=encoded_dimension, - elapsed_ms=elapsed_ms, - sample_text=sample_text, - ) - - -async def ensure_runtime_self_check( - plugin_or_config: Any, - *, - force: bool = False, - sample_text: str = _DEFAULT_SAMPLE_TEXT, -) -> Dict[str, Any]: - """Run or reuse cached runtime self-check report.""" - if plugin_or_config is None: - return _build_report( - ok=False, - code="missing_plugin_or_config", - message="plugin/config unavailable", - configured_dimension=0, - vector_store_dimension=0, - detected_dimension=0, - encoded_dimension=0, - elapsed_ms=0.0, - sample_text=sample_text, - ) - - cache = getattr(plugin_or_config, "_runtime_self_check_report", None) - if isinstance(cache, dict) and cache and not force: - return cache - - report = await run_embedding_runtime_self_check( - config=getattr(plugin_or_config, "config", plugin_or_config), - vector_store=getattr(plugin_or_config, "vector_store", None) - if not isinstance(plugin_or_config, dict) - else plugin_or_config.get("vector_store"), - embedding_manager=getattr(plugin_or_config, "embedding_manager", None) - if not isinstance(plugin_or_config, dict) - else plugin_or_config.get("embedding_manager"), - sample_text=sample_text, - ) - try: - setattr(plugin_or_config, "_runtime_self_check_report", report) - except Exception: - pass - return report diff --git a/plugins/A_memorix/core/utils/search_execution_service.py b/plugins/A_memorix/core/utils/search_execution_service.py deleted file mode 100644 index 7df243af..00000000 --- a/plugins/A_memorix/core/utils/search_execution_service.py +++ /dev/null @@ -1,439 +0,0 @@ -""" -统一检索执行服务。 - 
-用于收敛 Action/Tool 在 search/time 上的核心执行流程,避免重复实现。 -""" - -from __future__ import annotations - -import hashlib -import json -import time -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -from src.common.logger import get_logger - -from ..retrieval import TemporalQueryOptions -from .search_postprocess import ( - apply_safe_content_dedup, - maybe_apply_smart_path_fallback, -) -from .time_parser import parse_query_time_range - -logger = get_logger("A_Memorix.SearchExecutionService") - - -def _get_config_value(config: Optional[dict], key: str, default: Any = None) -> Any: - if not isinstance(config, dict): - return default - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -def _sanitize_text(value: Any) -> str: - if value is None: - return "" - return str(value).strip() - - -@dataclass -class SearchExecutionRequest: - caller: str - stream_id: Optional[str] = None - group_id: Optional[str] = None - user_id: Optional[str] = None - query_type: str = "search" # search|time|hybrid - query: str = "" - top_k: Optional[int] = None - time_from: Optional[str] = None - time_to: Optional[str] = None - person: Optional[str] = None - source: Optional[str] = None - use_threshold: bool = True - enable_ppr: bool = True - - -@dataclass -class SearchExecutionResult: - success: bool - error: str = "" - query_type: str = "search" - query: str = "" - top_k: int = 10 - time_from: Optional[str] = None - time_to: Optional[str] = None - person: Optional[str] = None - source: Optional[str] = None - temporal: Optional[TemporalQueryOptions] = None - results: List[Any] = field(default_factory=list) - elapsed_ms: float = 0.0 - chat_filtered: bool = False - dedup_hit: bool = False - - @property - def count(self) -> int: - return len(self.results) - - -class SearchExecutionService: - """统一检索执行服务。""" - - @staticmethod - def _resolve_plugin_instance(plugin_config: Optional[dict]) -> Optional[Any]: - if isinstance(plugin_config, dict): - plugin_instance = plugin_config.get("plugin_instance") - if plugin_instance is not None: - return plugin_instance - - try: - from ...plugin import AMemorixPlugin - - return getattr(AMemorixPlugin, "get_global_instance", lambda: None)() - except Exception: - return None - - @staticmethod - def _normalize_query_type(raw_query_type: str) -> str: - return _sanitize_text(raw_query_type).lower() or "search" - - @staticmethod - def _resolve_runtime_component( - plugin_config: Optional[dict], - plugin_instance: Optional[Any], - key: str, - ) -> Optional[Any]: - if isinstance(plugin_config, dict): - value = plugin_config.get(key) - if value is not None: - return value - if plugin_instance is not None: - value = getattr(plugin_instance, key, None) - if value is not None: - return value - return None - - @staticmethod - def _resolve_top_k( - plugin_config: Optional[dict], - query_type: str, - top_k_raw: Optional[Any], - ) -> Tuple[bool, int, str]: - temporal_default_top_k = int( - _get_config_value(plugin_config, "retrieval.temporal.default_top_k", 10) - ) - default_top_k = temporal_default_top_k if query_type in {"time", "hybrid"} else 10 - if top_k_raw is None: - return True, max(1, min(50, default_top_k)), "" - try: - top_k = int(top_k_raw) - except (TypeError, ValueError): - return False, 0, "top_k 参数必须为整数" - return True, max(1, min(50, top_k)), "" - - @staticmethod - def _build_temporal( - plugin_config: Optional[dict], - 
query_type: str, - time_from_raw: Optional[str], - time_to_raw: Optional[str], - person: Optional[str], - source: Optional[str], - ) -> Tuple[bool, Optional[TemporalQueryOptions], str]: - if query_type not in {"time", "hybrid"}: - return True, None, "" - - temporal_enabled = bool(_get_config_value(plugin_config, "retrieval.temporal.enabled", True)) - if not temporal_enabled: - return False, None, "时序检索已禁用(retrieval.temporal.enabled=false)" - - if not time_from_raw and not time_to_raw: - return False, None, "time/hybrid 模式至少需要 time_from 或 time_to" - - try: - ts_from, ts_to = parse_query_time_range( - str(time_from_raw) if time_from_raw is not None else None, - str(time_to_raw) if time_to_raw is not None else None, - ) - except ValueError as e: - return False, None, f"时间参数错误: {e}" - - temporal = TemporalQueryOptions( - time_from=ts_from, - time_to=ts_to, - person=_sanitize_text(person) or None, - source=_sanitize_text(source) or None, - allow_created_fallback=bool( - _get_config_value(plugin_config, "retrieval.temporal.allow_created_fallback", True) - ), - candidate_multiplier=int( - _get_config_value(plugin_config, "retrieval.temporal.candidate_multiplier", 8) - ), - max_scan=int(_get_config_value(plugin_config, "retrieval.temporal.max_scan", 1000)), - ) - return True, temporal, "" - - @staticmethod - def _build_request_key( - request: SearchExecutionRequest, - query_type: str, - top_k: int, - temporal: Optional[TemporalQueryOptions], - ) -> str: - payload = { - "stream_id": _sanitize_text(request.stream_id), - "query_type": query_type, - "query": _sanitize_text(request.query), - "time_from": _sanitize_text(request.time_from), - "time_to": _sanitize_text(request.time_to), - "time_from_ts": temporal.time_from if temporal else None, - "time_to_ts": temporal.time_to if temporal else None, - "person": _sanitize_text(request.person), - "source": _sanitize_text(request.source), - "top_k": int(top_k), - "use_threshold": bool(request.use_threshold), - "enable_ppr": bool(request.enable_ppr), - } - payload_json = json.dumps(payload, ensure_ascii=False, sort_keys=True) - return hashlib.sha1(payload_json.encode("utf-8")).hexdigest() - - @staticmethod - async def execute( - *, - retriever: Any, - threshold_filter: Optional[Any], - plugin_config: Optional[dict], - request: SearchExecutionRequest, - enforce_chat_filter: bool = True, - reinforce_access: bool = True, - ) -> SearchExecutionResult: - if retriever is None: - return SearchExecutionResult(success=False, error="知识检索器未初始化") - - query_type = SearchExecutionService._normalize_query_type(request.query_type) - query = _sanitize_text(request.query) - if query_type not in {"search", "time", "hybrid"}: - return SearchExecutionResult( - success=False, - error=f"query_type 无效: {query_type}(仅支持 search/time/hybrid)", - ) - - if query_type in {"search", "hybrid"} and not query: - return SearchExecutionResult( - success=False, - error="search/hybrid 模式必须提供 query", - ) - - top_k_ok, top_k, top_k_error = SearchExecutionService._resolve_top_k( - plugin_config, query_type, request.top_k - ) - if not top_k_ok: - return SearchExecutionResult(success=False, error=top_k_error) - - temporal_ok, temporal, temporal_error = SearchExecutionService._build_temporal( - plugin_config=plugin_config, - query_type=query_type, - time_from_raw=request.time_from, - time_to_raw=request.time_to, - person=request.person, - source=request.source, - ) - if not temporal_ok: - return SearchExecutionResult(success=False, error=temporal_error) - - plugin_instance = 
SearchExecutionService._resolve_plugin_instance(plugin_config) - if ( - enforce_chat_filter - and plugin_instance is not None - and hasattr(plugin_instance, "is_chat_enabled") - ): - if not plugin_instance.is_chat_enabled( - stream_id=request.stream_id, - group_id=request.group_id, - user_id=request.user_id, - ): - logger.info( - "检索请求被聊天过滤拦截: " - f"caller={request.caller}, " - f"stream_id={request.stream_id}" - ) - return SearchExecutionResult( - success=True, - query_type=query_type, - query=query, - top_k=top_k, - time_from=request.time_from, - time_to=request.time_to, - person=request.person, - source=request.source, - temporal=temporal, - results=[], - elapsed_ms=0.0, - chat_filtered=True, - dedup_hit=False, - ) - - request_key = SearchExecutionService._build_request_key( - request=request, - query_type=query_type, - top_k=top_k, - temporal=temporal, - ) - - async def _executor() -> Dict[str, Any]: - original_ppr = bool(getattr(retriever.config, "enable_ppr", True)) - setattr(retriever.config, "enable_ppr", bool(request.enable_ppr)) - started_at = time.time() - try: - retrieved = await retriever.retrieve( - query=query, - top_k=top_k, - temporal=temporal, - ) - - should_apply_threshold = bool(request.use_threshold) and threshold_filter is not None - if ( - query_type == "time" - and not query - and bool( - _get_config_value( - plugin_config, - "retrieval.time.skip_threshold_when_query_empty", - True, - ) - ) - ): - should_apply_threshold = False - - if should_apply_threshold: - retrieved = threshold_filter.filter(retrieved) - - if ( - reinforce_access - and plugin_instance is not None - and hasattr(plugin_instance, "reinforce_access") - ): - relation_hashes = [ - item.hash_value - for item in retrieved - if getattr(item, "result_type", "") == "relation" - ] - if relation_hashes: - await plugin_instance.reinforce_access(relation_hashes) - - if query_type == "search": - graph_store = SearchExecutionService._resolve_runtime_component( - plugin_config, plugin_instance, "graph_store" - ) - metadata_store = SearchExecutionService._resolve_runtime_component( - plugin_config, plugin_instance, "metadata_store" - ) - fallback_enabled = bool( - _get_config_value( - plugin_config, - "retrieval.search.smart_fallback.enabled", - True, - ) - ) - fallback_threshold = float( - _get_config_value( - plugin_config, - "retrieval.search.smart_fallback.threshold", - 0.6, - ) - ) - retrieved, fallback_triggered, fallback_added = maybe_apply_smart_path_fallback( - query=query, - results=list(retrieved), - graph_store=graph_store, - metadata_store=metadata_store, - enabled=fallback_enabled, - threshold=fallback_threshold, - ) - if fallback_triggered: - logger.info( - "metric.smart_fallback_triggered_count=1 " - f"caller={request.caller} " - f"added={fallback_added}" - ) - - dedup_enabled = bool( - _get_config_value( - plugin_config, - "retrieval.search.safe_content_dedup.enabled", - True, - ) - ) - if dedup_enabled: - retrieved, removed_count = apply_safe_content_dedup(list(retrieved)) - if removed_count > 0: - logger.info( - f"metric.safe_dedup_removed_count={removed_count} " - f"caller={request.caller}" - ) - - elapsed_ms = (time.time() - started_at) * 1000.0 - return {"results": retrieved, "elapsed_ms": elapsed_ms} - finally: - setattr(retriever.config, "enable_ppr", original_ppr) - - dedup_hit = False - try: - # 调优评估需要逐轮真实执行,且应避免额外 dedup 锁竞争。 - bypass_request_dedup = str(request.caller or "").strip().lower() == "retrieval_tuning" - if ( - not bypass_request_dedup - and - plugin_instance is not None - and 
hasattr(plugin_instance, "execute_request_with_dedup") - ): - dedup_hit, payload = await plugin_instance.execute_request_with_dedup( - request_key, - _executor, - ) - else: - payload = await _executor() - except Exception as e: - return SearchExecutionResult(success=False, error=f"知识检索失败: {e}") - - if dedup_hit: - logger.info(f"metric.search_execution_dedup_hit_count=1 caller={request.caller}") - - return SearchExecutionResult( - success=True, - query_type=query_type, - query=query, - top_k=top_k, - time_from=request.time_from, - time_to=request.time_to, - person=request.person, - source=request.source, - temporal=temporal, - results=payload.get("results", []), - elapsed_ms=float(payload.get("elapsed_ms", 0.0)), - chat_filtered=False, - dedup_hit=bool(dedup_hit), - ) - - @staticmethod - def to_serializable_results(results: List[Any]) -> List[Dict[str, Any]]: - serialized: List[Dict[str, Any]] = [] - for item in results: - metadata = dict(getattr(item, "metadata", {}) or {}) - if "time_meta" not in metadata: - metadata["time_meta"] = {} - serialized.append( - { - "hash": getattr(item, "hash_value", ""), - "type": getattr(item, "result_type", ""), - "score": float(getattr(item, "score", 0.0)), - "content": getattr(item, "content", ""), - "metadata": metadata, - } - ) - return serialized diff --git a/plugins/A_memorix/core/utils/search_postprocess.py b/plugins/A_memorix/core/utils/search_postprocess.py deleted file mode 100644 index 52688e08..00000000 --- a/plugins/A_memorix/core/utils/search_postprocess.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Post-processing helpers for unified search execution.""" - -from __future__ import annotations - -from typing import Any, List, Tuple - -from .path_fallback_service import find_paths_from_query, to_retrieval_results - - -def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]: - """Deduplicate results by hash/content while preserving at least one entry.""" - if not results: - return [], 0 - - unique_results: List[Any] = [] - seen_hashes = set() - seen_contents = set() - - for item in results: - content = str(getattr(item, "content", "") or "").strip() - if not content: - continue - - hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content)) - if hash_value in seen_hashes: - continue - - is_dup = False - for seen in seen_contents: - if content in seen or seen in content: - is_dup = True - break - if is_dup: - continue - - seen_hashes.add(hash_value) - seen_contents.add(content) - unique_results.append(item) - - if not unique_results: - unique_results.append(results[0]) - - removed_count = max(0, len(results) - len(unique_results)) - return unique_results, removed_count - - -def maybe_apply_smart_path_fallback( - *, - query: str, - results: List[Any], - graph_store: Any, - metadata_store: Any, - enabled: bool, - threshold: float, - max_depth: int = 3, - max_paths: int = 5, -) -> Tuple[List[Any], bool, int]: - """Append indirect relation paths when semantic results are weak.""" - if not enabled or not str(query or "").strip(): - return results, False, 0 - if graph_store is None or metadata_store is None: - return results, False, 0 - - max_score = 0.0 - if results: - try: - max_score = float(getattr(results[0], "score", 0.0) or 0.0) - except Exception: - max_score = 0.0 - - if max_score >= float(threshold): - return results, False, 0 - - paths = find_paths_from_query( - query=query, - graph_store=graph_store, - metadata_store=metadata_store, - max_depth=max_depth, - max_paths=max_paths, - ) - if not paths: - 
return results, False, 0 - - path_results = to_retrieval_results(paths) - if not path_results: - return results, False, 0 - - merged = list(path_results) + list(results) - return merged, True, len(path_results) - diff --git a/plugins/A_memorix/core/utils/summary_importer.py b/plugins/A_memorix/core/utils/summary_importer.py deleted file mode 100644 index b6271db4..00000000 --- a/plugins/A_memorix/core/utils/summary_importer.py +++ /dev/null @@ -1,425 +0,0 @@ -""" -聊天总结与知识导入工具 - -该模块负责从聊天记录中提取信息,生成总结,并将总结内容及提取的实体/关系 -导入到 A_memorix 的存储组件中。 -""" - -import time -import json -import re -import traceback -from typing import List, Dict, Any, Tuple, Optional -from pathlib import Path - -from src.common.logger import get_logger -from src.services import llm_service as llm_api -from src.services import message_service as message_api -from src.config.config import global_config, model_config as host_model_config -from src.config.model_configs import TaskConfig - -from ..storage import ( - KnowledgeType, - VectorStore, - GraphStore, - MetadataStore, - resolve_stored_knowledge_type, -) -from ..embedding import EmbeddingAPIAdapter -from .relation_write_service import RelationWriteService -from .runtime_self_check import ensure_runtime_self_check, run_embedding_runtime_self_check - -logger = get_logger("A_Memorix.SummaryImporter") - -# 默认总结提示词模版 -SUMMARY_PROMPT_TEMPLATE = """ -你是 {bot_name}。{personality_context} -现在你需要对以下一段聊天记录进行总结,并提取其中的重要知识。 - -聊天记录内容: -{chat_history} - -请完成以下任务: -1. **生成总结**:以第三人称或机器人的视角,简洁明了地总结这段对话的主要内容、发生的事件或讨论的主题。 -2. **提取实体与关系**:识别并提取对话中提到的重要实体以及它们之间的关系。 - -请严格以 JSON 格式输出,格式如下: -{{ - "summary": "总结文本内容", - "entities": ["张三", "李四"], - "relations": [ - {{"subject": "张三", "predicate": "认识", "object": "李四"}} - ] -}} - -注意:总结应具有叙事性,能够作为长程记忆的一部分。直接使用实体的实际名称,不要使用 e1/e2 等代号。 -""" - -class SummaryImporter: - """总结并导入知识的工具类""" - - def __init__( - self, - vector_store: VectorStore, - graph_store: GraphStore, - metadata_store: MetadataStore, - embedding_manager: EmbeddingAPIAdapter, - plugin_config: dict - ): - self.vector_store = vector_store - self.graph_store = graph_store - self.metadata_store = metadata_store - self.embedding_manager = embedding_manager - self.plugin_config = plugin_config - self.relation_write_service: Optional[RelationWriteService] = ( - plugin_config.get("relation_write_service") - if isinstance(plugin_config, dict) - else None - ) - - def _normalize_summary_model_selectors(self, raw_value: Any) -> List[str]: - """标准化 summarization.model_name 配置(vNext 仅接受字符串数组)。""" - if raw_value is None: - return ["auto"] - if isinstance(raw_value, list): - selectors = [str(x).strip() for x in raw_value if str(x).strip()] - return selectors or ["auto"] - raise ValueError( - "summarization.model_name 在 vNext 必须为 List[str]。" - " 请执行 scripts/release_vnext_migrate.py migrate。" - ) - - def _pick_default_summary_task(self, available_tasks: Dict[str, TaskConfig]) -> Tuple[Optional[str], Optional[TaskConfig]]: - """ - 选择总结默认任务,避免错误落到 embedding 任务。 - 优先级:replyer > utils > planner > tool_use > 其他非 embedding。 - """ - preferred = ("replyer", "utils", "planner", "tool_use") - for name in preferred: - cfg = available_tasks.get(name) - if cfg and cfg.model_list: - return name, cfg - - for name, cfg in available_tasks.items(): - if name != "embedding" and cfg.model_list: - return name, cfg - - for name, cfg in available_tasks.items(): - if cfg.model_list: - return name, cfg - - return None, None - - def _resolve_summary_model_config(self) -> Optional[TaskConfig]: - """ - 解析 summarization.model_name 
为 TaskConfig。 - 支持: - - "auto" - - "replyer"(任务名) - - "some-model-name"(具体模型名) - - ["utils:model1", "utils:model2", "replyer"](数组混合语法) - """ - available_tasks = llm_api.get_available_models() - if not available_tasks: - return None - - raw_cfg = self.plugin_config.get("summarization", {}).get("model_name", "auto") - selectors = self._normalize_summary_model_selectors(raw_cfg) - default_task_name, default_task_cfg = self._pick_default_summary_task(available_tasks) - - selected_models: List[str] = [] - base_cfg: Optional[TaskConfig] = None - model_dict = getattr(host_model_config, "models_dict", {}) - - def _append_models(models: List[str]): - for model_name in models: - if model_name and model_name not in selected_models: - selected_models.append(model_name) - - for raw_selector in selectors: - selector = raw_selector.strip() - if not selector: - continue - - if selector.lower() == "auto": - if default_task_cfg: - _append_models(default_task_cfg.model_list) - if base_cfg is None: - base_cfg = default_task_cfg - continue - - if ":" in selector: - task_name, model_name = selector.split(":", 1) - task_name = task_name.strip() - model_name = model_name.strip() - task_cfg = available_tasks.get(task_name) - if not task_cfg: - logger.warning(f"总结模型选择器 '{selector}' 的任务 '{task_name}' 不存在,已跳过") - continue - - if base_cfg is None: - base_cfg = task_cfg - - if not model_name or model_name.lower() == "auto": - _append_models(task_cfg.model_list) - continue - - if model_name in model_dict or model_name in task_cfg.model_list: - _append_models([model_name]) - else: - logger.warning(f"总结模型选择器 '{selector}' 的模型 '{model_name}' 不存在,已跳过") - continue - - task_cfg = available_tasks.get(selector) - if task_cfg: - _append_models(task_cfg.model_list) - if base_cfg is None: - base_cfg = task_cfg - continue - - if selector in model_dict: - _append_models([selector]) - continue - - logger.warning(f"总结模型选择器 '{selector}' 无法识别,已跳过") - - if not selected_models: - if default_task_cfg: - _append_models(default_task_cfg.model_list) - if base_cfg is None: - base_cfg = default_task_cfg - else: - first_cfg = next(iter(available_tasks.values())) - _append_models(first_cfg.model_list) - if base_cfg is None: - base_cfg = first_cfg - - if not selected_models: - return None - - template_cfg = base_cfg or default_task_cfg or next(iter(available_tasks.values())) - return TaskConfig( - model_list=selected_models, - max_tokens=template_cfg.max_tokens, - temperature=template_cfg.temperature, - slow_threshold=template_cfg.slow_threshold, - selection_strategy=template_cfg.selection_strategy, - ) - - async def import_from_stream( - self, - stream_id: str, - context_length: Optional[int] = None, - include_personality: Optional[bool] = None - ) -> Tuple[bool, str]: - """ - 从指定的聊天流中提取记录并执行总结导入 - - Args: - stream_id: 聊天流 ID - context_length: 总结的历史消息条数 - include_personality: 是否包含人设 - - Returns: - Tuple[bool, str]: (是否成功, 结果消息) - """ - try: - self_check_ok, self_check_msg = await self._ensure_runtime_self_check() - if not self_check_ok: - return False, f"导入前自检失败: {self_check_msg}" - - # 1. 获取配置 - if context_length is None: - context_length = self.plugin_config.get("summarization", {}).get("context_length", 50) - - if include_personality is None: - include_personality = self.plugin_config.get("summarization", {}).get("include_personality", True) - - # 2. 
获取历史消息 - # 获取当前时间之前的消息 - now = time.time() - messages = message_api.get_messages_before_time_in_chat( - chat_id=stream_id, - timestamp=now, - limit=context_length - ) - - if not messages: - return False, "未找到有效的聊天记录进行总结" - - # 转换为可读文本 - chat_history_text = message_api.build_readable_messages(messages) - - # 3. 准备提示词内容 - bot_name = global_config.bot.nickname or "机器人" - personality_context = "" - if include_personality: - personality = getattr(global_config.bot, "personality", "") - if personality: - personality_context = f"你的性格设定是:{personality}" - - # 4. 调用 LLM - prompt = SUMMARY_PROMPT_TEMPLATE.format( - bot_name=bot_name, - personality_context=personality_context, - chat_history=chat_history_text - ) - - model_config_to_use = self._resolve_summary_model_config() - if model_config_to_use is None: - return False, "未找到可用的总结模型配置" - - logger.info(f"正在为流 {stream_id} 执行总结,消息条数: {len(messages)}") - logger.info(f"总结模型候选列表: {model_config_to_use.model_list}") - - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config_to_use, - request_type="A_Memorix.ChatSummarization" - ) - - if not success or not response: - return False, "LLM 生成总结失败" - - # 5. 解析结果 - data = self._parse_llm_response(response) - if not data or "summary" not in data: - return False, "解析 LLM 响应失败或总结为空" - - summary_text = data["summary"] - entities = data.get("entities", []) - relations = data.get("relations", []) - msg_times = [ - float(getattr(getattr(msg, "timestamp", None), "timestamp", lambda: 0.0)()) - for msg in messages - if getattr(msg, "time", None) is not None - ] - time_meta = {} - if msg_times: - time_meta = { - "event_time_start": min(msg_times), - "event_time_end": max(msg_times), - "time_granularity": "minute", - "time_confidence": 0.95, - } - - # 6. 执行导入 - await self._execute_import(summary_text, entities, relations, stream_id, time_meta=time_meta) - - # 7. 
持久化 - self.vector_store.save() - self.graph_store.save() - - result_msg = ( - f"✅ 总结导入成功\n" - f"📝 总结长度: {len(summary_text)}\n" - f"📌 提取实体: {len(entities)}\n" - f"🔗 提取关系: {len(relations)}" - ) - return True, result_msg - - except Exception as e: - logger.error(f"总结导入过程中出错: {e}\n{traceback.format_exc()}") - return False, f"错误: {str(e)}" - - async def _ensure_runtime_self_check(self) -> Tuple[bool, str]: - plugin_instance = self.plugin_config.get("plugin_instance") if isinstance(self.plugin_config, dict) else None - if plugin_instance is not None: - report = await ensure_runtime_self_check(plugin_instance) - else: - report = await run_embedding_runtime_self_check( - config=self.plugin_config, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - if bool(report.get("ok", False)): - return True, "" - return ( - False, - f"{report.get('message', 'unknown')} " - f"(configured={report.get('configured_dimension', 0)}, " - f"store={report.get('vector_store_dimension', 0)}, " - f"encoded={report.get('encoded_dimension', 0)})", - ) - - def _parse_llm_response(self, response: str) -> Dict[str, Any]: - """解析 LLM 返回的 JSON""" - try: - # 尝试查找 JSON - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - return json.loads(json_match.group()) - return {} - except Exception as e: - logger.warning(f"解析总结 JSON 失败: {e}") - return {} - - async def _execute_import( - self, - summary: str, - entities: List[str], - relations: List[Dict[str, str]], - stream_id: str, - time_meta: Optional[Dict[str, Any]] = None, - ): - """将数据写入存储""" - # 获取默认知识类型 - type_str = self.plugin_config.get("summarization", {}).get("default_knowledge_type", "narrative") - try: - knowledge_type = resolve_stored_knowledge_type(type_str, content=summary) - except ValueError: - logger.warning(f"非法 summarization.default_knowledge_type={type_str},回退 narrative") - knowledge_type = KnowledgeType.NARRATIVE - - # 导入总结文本 - hash_value = self.metadata_store.add_paragraph( - content=summary, - source=f"chat_summary:{stream_id}", - knowledge_type=knowledge_type.value, - time_meta=time_meta, - ) - - embedding = await self.embedding_manager.encode(summary) - self.vector_store.add( - vectors=embedding.reshape(1, -1), - ids=[hash_value] - ) - - # 导入实体 - if entities: - self.graph_store.add_nodes(entities) - - # 导入关系 - rv_cfg = self.plugin_config.get("retrieval", {}).get("relation_vectorization", {}) - if not isinstance(rv_cfg, dict): - rv_cfg = {} - write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) - for rel in relations: - s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object") - if all([s, p, o]): - if self.relation_write_service is not None: - await self.relation_write_service.upsert_relation_with_vector( - subject=s, - predicate=p, - obj=o, - confidence=1.0, - source_paragraph=summary, - write_vector=write_vector, - ) - else: - # 写入元数据 - rel_hash = self.metadata_store.add_relation( - subject=s, - predicate=p, - obj=o, - confidence=1.0, - source_paragraph=summary - ) - # 写入图数据库(写入 relation_hashes,确保后续可按关系精确修剪) - self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash]) - try: - self.metadata_store.set_relation_vector_state(rel_hash, "none") - except Exception: - pass - - logger.info(f"总结导入完成: hash={hash_value[:8]}") diff --git a/plugins/A_memorix/core/utils/time_parser.py b/plugins/A_memorix/core/utils/time_parser.py deleted file mode 100644 index 8e577974..00000000 --- a/plugins/A_memorix/core/utils/time_parser.py +++ /dev/null @@ -1,170 +0,0 @@ 
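The time_parser module removed below pinned the query-side time contract: Action/Command/Tool parameters accept only absolute YYYY/MM/DD or YYYY/MM/DD HH:mm values, and a date-only upper bound is padded to 23:59. As a quick reference, a minimal usage sketch of that contract; the import path assumes the plugin package is importable under its repository layout, and the sample dates are hypothetical:

    from plugins.A_memorix.core.utils.time_parser import parse_query_time_range

    # Date-only bounds: the start expands to 00:00, the end expands to 23:59 of that day.
    ts_from, ts_to = parse_query_time_range("2026/03/01", "2026/03/18")

    # Reversed windows are rejected up front rather than silently swapped.
    parse_query_time_range("2026/03/18", "2026/03/01")  # raises ValueError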
-""" -时间解析工具。 - -约束: -1. 查询参数(Action/Command/Tool)仅接受结构化绝对时间: - - YYYY/MM/DD - - YYYY/MM/DD HH:mm -2. 入库时允许更宽松格式(含时间戳、YYYY-MM-DD 等)。 -""" - -from __future__ import annotations - -import re -from datetime import datetime -from typing import Any, Dict, Optional, Tuple - - -_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$") -_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$") -_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$") - -_INGEST_FORMATS = [ - "%Y/%m/%d %H:%M:%S", - "%Y/%m/%d %H:%M", - "%Y-%m-%d %H:%M:%S", - "%Y-%m-%d %H:%M", - "%Y-%m-%dT%H:%M:%S", - "%Y-%m-%dT%H:%M", - "%Y/%m/%d", - "%Y-%m-%d", -] - -_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"} - - -def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float: - """解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。""" - text = str(value).strip() - if not text: - raise ValueError("时间不能为空") - - if _QUERY_DATE_RE.fullmatch(text): - dt = datetime.strptime(text, "%Y/%m/%d") - if is_end: - dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) - return dt.timestamp() - - if _QUERY_MINUTE_RE.fullmatch(text): - dt = datetime.strptime(text, "%Y/%m/%d %H:%M") - return dt.timestamp() - - raise ValueError( - f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm" - ) - - -def parse_query_time_range( - time_from: Optional[str], - time_to: Optional[str], -) -> Tuple[Optional[float], Optional[float]]: - """解析查询窗口并验证区间。""" - ts_from = ( - parse_query_datetime_to_timestamp(time_from, is_end=False) - if time_from - else None - ) - ts_to = ( - parse_query_datetime_to_timestamp(time_to, is_end=True) - if time_to - else None - ) - - if ts_from is not None and ts_to is not None and ts_from > ts_to: - raise ValueError("time_from 不能晚于 time_to") - - return ts_from, ts_to - - -def parse_ingest_datetime_to_timestamp( - value: Any, - is_end: bool = False, -) -> Optional[float]: - """解析入库时间,允许 timestamp/常见字符串格式。""" - if value is None: - return None - - if isinstance(value, (int, float)): - return float(value) - - text = str(value).strip() - if not text: - return None - - if _NUMERIC_RE.fullmatch(text): - return float(text) - - for fmt in _INGEST_FORMATS: - try: - dt = datetime.strptime(text, fmt) - if fmt in _INGEST_DATE_FORMATS and is_end: - dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) - return dt.timestamp() - except ValueError: - continue - - raise ValueError(f"无法解析时间: {text}") - - -def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]: - """归一化 time_meta 到存储层字段。""" - if not time_meta: - return {} - - normalized: Dict[str, Any] = {} - - event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time")) - event_start = parse_ingest_datetime_to_timestamp( - time_meta.get("event_time_start"), - is_end=False, - ) - event_end = parse_ingest_datetime_to_timestamp( - time_meta.get("event_time_end"), - is_end=True, - ) - - time_range = time_meta.get("time_range") - if ( - isinstance(time_range, (list, tuple)) - and len(time_range) == 2 - ): - if event_start is None: - event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False) - if event_end is None: - event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True) - - if event_start is not None and event_end is not None and event_start > event_end: - raise ValueError("event_time_start 不能晚于 event_time_end") - - if event_time is not None: - normalized["event_time"] = event_time - if event_start is not None: - normalized["event_time_start"] = event_start - if event_end is not None: - normalized["event_time_end"] = 
event_end - - granularity = time_meta.get("time_granularity") - if granularity: - normalized["time_granularity"] = str(granularity) - else: - raw_time_values = [ - time_meta.get("event_time"), - time_meta.get("event_time_start"), - time_meta.get("event_time_end"), - ] - has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None) - normalized["time_granularity"] = "minute" if has_minute else "day" - - confidence = time_meta.get("time_confidence") - if confidence is not None: - normalized["time_confidence"] = float(confidence) - - return normalized - - -def format_timestamp(ts: Optional[float]) -> Optional[str]: - """将 timestamp 格式化为 YYYY/MM/DD HH:mm。""" - if ts is None: - return None - return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M") - diff --git a/plugins/A_memorix/core/utils/web_import_manager.py b/plugins/A_memorix/core/utils/web_import_manager.py deleted file mode 100644 index b088be1f..00000000 --- a/plugins/A_memorix/core/utils/web_import_manager.py +++ /dev/null @@ -1,3522 +0,0 @@ -""" -Web Import Task Manager - -为 A_Memorix WebUI 提供导入任务队列、状态管理、并发调度与取消/重试能力。 -""" - -from __future__ import annotations - -import asyncio -import hashlib -import json -import os -import shutil -import sys -import time -import traceback -import uuid -from collections import deque -from dataclasses import dataclass, field -from datetime import datetime -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple - -from src.common.logger import get_logger -from src.services import llm_service as llm_api - -from ..storage import ( - parse_import_strategy, - resolve_stored_knowledge_type, - select_import_strategy, - KnowledgeType, - MetadataStore, -) -from ..storage.type_detection import looks_like_quote_text -from ..utils.import_payloads import normalize_paragraph_import_item -from ..utils.runtime_self_check import ensure_runtime_self_check -from ..utils.time_parser import normalize_time_meta -from ..storage.knowledge_types import ImportStrategy -from ..strategies.base import ProcessedChunk, KnowledgeType as StrategyKnowledgeType -from ..strategies.narrative import NarrativeStrategy -from ..strategies.factual import FactualStrategy -from ..strategies.quote import QuoteStrategy - -logger = get_logger("A_Memorix.WebImportManager") - - -TASK_STATUS = { - "queued", - "preparing", - "running", - "cancel_requested", - "cancelled", - "completed", - "completed_with_errors", - "failed", -} - -FILE_STATUS = { - "queued", - "preparing", - "splitting", - "extracting", - "writing", - "saving", - "completed", - "failed", - "cancelled", -} - -CHUNK_STATUS = { - "queued", - "extracting", - "writing", - "completed", - "failed", - "cancelled", -} - - -def _now() -> float: - return time.time() - - -def _coerce_int(value: Any, default: int) -> int: - try: - return int(value) - except Exception: - return default - - -def _coerce_bool(value: Any, default: bool) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - text = str(value).strip().lower() - if text in {"1", "true", "yes", "y", "on"}: - return True - if text in {"0", "false", "no", "n", "off", ""}: - return False - return default - - -def _clamp(value: int, min_value: int, max_value: int) -> int: - return max(min_value, min(max_value, value)) - - -def _coerce_list(value: Any) -> List[str]: - if value is None: - return [] - if isinstance(value, list): - raw_items = value - else: - text = str(value or 
"").replace("\r", "\n") - raw_items = [] - for seg in text.split("\n"): - raw_items.extend(seg.split(",")) - - out: List[str] = [] - seen = set() - for item in raw_items: - v = str(item or "").strip() - if not v: - continue - key = v.lower() - if key in seen: - continue - seen.add(key) - out.append(v) - return out - - -def _parse_optional_positive_int(value: Any, field_name: str) -> Optional[int]: - if value is None: - return None - text = str(value).strip() - if text == "": - return None - try: - parsed = int(text) - except Exception: - raise ValueError(f"{field_name} 必须为整数") - if parsed <= 0: - raise ValueError(f"{field_name} 必须 > 0") - return parsed - - -def _safe_filename(name: str) -> str: - base = os.path.basename(str(name or "").strip()) - if not base: - return f"unnamed_{uuid.uuid4().hex[:8]}.txt" - return base - - -def _storage_type_from_strategy(strategy_type: StrategyKnowledgeType) -> str: - if strategy_type == StrategyKnowledgeType.NARRATIVE: - return KnowledgeType.NARRATIVE.value - if strategy_type == StrategyKnowledgeType.FACTUAL: - return KnowledgeType.FACTUAL.value - if strategy_type == StrategyKnowledgeType.QUOTE: - return KnowledgeType.QUOTE.value - return KnowledgeType.MIXED.value - - -@dataclass -class ImportChunkRecord: - chunk_id: str - index: int - chunk_type: str - status: str = "queued" - step: str = "queued" - failed_at: str = "" - retryable: bool = False - error: str = "" - progress: float = 0.0 - content_preview: str = "" - updated_at: float = field(default_factory=_now) - - def to_dict(self) -> Dict[str, Any]: - return { - "chunk_id": self.chunk_id, - "index": self.index, - "chunk_type": self.chunk_type, - "status": self.status, - "step": self.step, - "failed_at": self.failed_at, - "retryable": self.retryable, - "error": self.error, - "progress": self.progress, - "content_preview": self.content_preview, - "updated_at": self.updated_at, - } - - -@dataclass -class ImportFileRecord: - file_id: str - name: str - source_kind: str - input_mode: str - status: str = "queued" - current_step: str = "queued" - detected_strategy_type: str = "unknown" - total_chunks: int = 0 - done_chunks: int = 0 - failed_chunks: int = 0 - cancelled_chunks: int = 0 - progress: float = 0.0 - error: str = "" - chunks: List[ImportChunkRecord] = field(default_factory=list) - created_at: float = field(default_factory=_now) - updated_at: float = field(default_factory=_now) - temp_path: Optional[str] = None - source_path: Optional[str] = None - inline_content: Optional[str] = None - content_hash: str = "" - retry_chunk_indexes: List[int] = field(default_factory=list) - retry_mode: str = "" - - def to_dict(self, include_chunks: bool = False) -> Dict[str, Any]: - payload = { - "file_id": self.file_id, - "name": self.name, - "source_kind": self.source_kind, - "input_mode": self.input_mode, - "status": self.status, - "current_step": self.current_step, - "detected_strategy_type": self.detected_strategy_type, - "total_chunks": self.total_chunks, - "done_chunks": self.done_chunks, - "failed_chunks": self.failed_chunks, - "cancelled_chunks": self.cancelled_chunks, - "progress": self.progress, - "error": self.error, - "created_at": self.created_at, - "updated_at": self.updated_at, - "source_path": self.source_path or "", - "content_hash": self.content_hash or "", - "retry_chunk_indexes": list(self.retry_chunk_indexes or []), - "retry_mode": self.retry_mode or "", - } - if include_chunks: - payload["chunks"] = [chunk.to_dict() for chunk in self.chunks] - return payload - - -@dataclass -class 
ImportTaskRecord: - task_id: str - source: str - params: Dict[str, Any] - status: str = "queued" - current_step: str = "queued" - total_chunks: int = 0 - done_chunks: int = 0 - failed_chunks: int = 0 - cancelled_chunks: int = 0 - progress: float = 0.0 - error: str = "" - files: List[ImportFileRecord] = field(default_factory=list) - created_at: float = field(default_factory=_now) - started_at: Optional[float] = None - finished_at: Optional[float] = None - updated_at: float = field(default_factory=_now) - schema_detected: str = "" - artifact_paths: Dict[str, str] = field(default_factory=dict) - rollback_info: Dict[str, Any] = field(default_factory=dict) - retry_parent_task_id: str = "" - retry_summary: Dict[str, Any] = field(default_factory=dict) - - def to_summary(self) -> Dict[str, Any]: - return { - "task_id": self.task_id, - "source": self.source, - "status": self.status, - "current_step": self.current_step, - "total_chunks": self.total_chunks, - "done_chunks": self.done_chunks, - "failed_chunks": self.failed_chunks, - "cancelled_chunks": self.cancelled_chunks, - "progress": self.progress, - "error": self.error, - "file_count": len(self.files), - "created_at": self.created_at, - "started_at": self.started_at, - "finished_at": self.finished_at, - "updated_at": self.updated_at, - "task_kind": str(self.params.get("task_kind") or self.source), - "schema_detected": self.schema_detected, - "artifact_paths": dict(self.artifact_paths), - "rollback_info": dict(self.rollback_info), - "retry_parent_task_id": self.retry_parent_task_id or "", - "retry_summary": dict(self.retry_summary), - } - - def to_detail(self, include_chunks: bool = False) -> Dict[str, Any]: - payload = self.to_summary() - payload["params"] = self.params - payload["files"] = [f.to_dict(include_chunks=include_chunks) for f in self.files] - return payload - - -class ImportTaskManager: - def __init__(self, plugin: Any): - self.plugin = plugin - self._lock = asyncio.Lock() - self._storage_lock = asyncio.Lock() - - self._tasks: Dict[str, ImportTaskRecord] = {} - self._task_order: deque[str] = deque() - self._queue: deque[str] = deque() - self._active_task_id: Optional[str] = None - - self._worker_task: Optional[asyncio.Task] = None - self._stopping = False - - self._temp_root = self._resolve_temp_root() - self._temp_root.mkdir(parents=True, exist_ok=True) - self._reports_root = self._resolve_reports_root() - self._reports_root.mkdir(parents=True, exist_ok=True) - self._manifest_path = self._resolve_manifest_path() - self._manifest_cache: Optional[Dict[str, Any]] = None - self._write_changed_callback: Optional[Callable[[Dict[str, Any]], Any]] = None - - def set_write_changed_callback(self, callback: Optional[Callable[[Dict[str, Any]], Any]]) -> None: - self._write_changed_callback = callback - - async def _notify_write_changed(self, payload: Dict[str, Any]) -> None: - callback = self._write_changed_callback - if callback is None: - return - try: - maybe_awaitable = callback(payload) - if asyncio.iscoroutine(maybe_awaitable): - await maybe_awaitable - except Exception as e: - logger.warning(f"写入变更回调执行失败: {e}") - - def _resolve_temp_root(self) -> Path: - data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) - if str(data_dir).startswith("."): - plugin_dir = Path(__file__).resolve().parents[2] - data_dir = (plugin_dir / data_dir).resolve() - return data_dir / "web_import_tmp" - - def _resolve_reports_root(self) -> Path: - return self._resolve_data_dir() / "web_import_reports" - - def _resolve_manifest_path(self) -> Path: - 
return self._resolve_data_dir() / "import_manifest.json" - - def _resolve_staging_root(self) -> Path: - return self._resolve_data_dir() / "import_staging" - - def _resolve_backup_root(self) -> Path: - return self._resolve_data_dir() / "import_backup" - - def _resolve_repo_root(self) -> Path: - return Path(__file__).resolve().parents[3] - - def _resolve_data_dir(self) -> Path: - data_dir = Path(self.plugin.get_config("storage.data_dir", "./data")) - if str(data_dir).startswith("."): - plugin_dir = Path(__file__).resolve().parents[2] - data_dir = (plugin_dir / data_dir).resolve() - return data_dir.resolve() - - def _resolve_migration_script(self) -> Path: - return Path(__file__).resolve().parents[2] / "scripts" / "migrate_maibot_memory.py" - - def _default_maibot_source_db(self) -> Path: - # A_memorix/core/utils -> workspace root - return self._resolve_repo_root() / "MaiBot" / "data" / "MaiBot.db" - - def _cfg(self, key: str, default: Any) -> Any: - return self.plugin.get_config(key, default) - - def _cfg_int(self, key: str, default: int) -> int: - return _coerce_int(self._cfg(key, default), default) - - def _is_enabled(self) -> bool: - return bool(self._cfg("web.import.enabled", True)) - - def _queue_limit(self) -> int: - return max(1, self._cfg_int("web.import.max_queue_size", 20)) - - def _max_files_per_task(self) -> int: - return max(1, self._cfg_int("web.import.max_files_per_task", 200)) - - def _max_file_size_bytes(self) -> int: - mb = max(1, self._cfg_int("web.import.max_file_size_mb", 20)) - return mb * 1024 * 1024 - - def _max_paste_chars(self) -> int: - return max(1000, self._cfg_int("web.import.max_paste_chars", 200000)) - - def _default_file_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.default_file_concurrency", 2)) - - def _default_chunk_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.default_chunk_concurrency", 4)) - - def _max_file_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.max_file_concurrency", 6)) - - def _max_chunk_concurrency(self) -> int: - return max(1, self._cfg_int("web.import.max_chunk_concurrency", 12)) - - def _llm_retry_config(self) -> Dict[str, float]: - retries = max(0, self._cfg_int("web.import.llm_retry.max_attempts", 4)) - min_wait = max(0.1, float(self._cfg("web.import.llm_retry.min_wait_seconds", 3) or 3)) - max_wait = max(min_wait, float(self._cfg("web.import.llm_retry.max_wait_seconds", 40) or 40)) - mult = max(1.0, float(self._cfg("web.import.llm_retry.backoff_multiplier", 3) or 3)) - return { - "retries": retries, - "min_wait": min_wait, - "max_wait": max_wait, - "multiplier": mult, - } - - def _default_path_aliases(self) -> Dict[str, str]: - plugin_dir = Path(__file__).resolve().parents[2] - repo_root = self._resolve_repo_root() - return { - "raw": str((plugin_dir / "data" / "raw").resolve()), - "lpmm": str((repo_root / "data" / "lpmm_storage").resolve()), - "plugin_data": str((plugin_dir / "data").resolve()), - } - - def get_path_aliases(self) -> Dict[str, str]: - configured = self._cfg("web.import.path_aliases", self._default_path_aliases()) - if not isinstance(configured, dict): - configured = self._default_path_aliases() - - repo_root = self._resolve_repo_root() - result: Dict[str, str] = {} - for alias, raw_path in configured.items(): - key = str(alias or "").strip() - if not key: - continue - text = str(raw_path or "").strip() - if not text: - continue - if text.startswith("\\\\"): - continue - p = Path(text) - if not p.is_absolute(): - p = (repo_root / p).resolve() - else: - p = 
p.resolve() - result[key] = str(p) - - defaults = self._default_path_aliases() - for key, path in defaults.items(): - result.setdefault(key, path) - return result - - def resolve_path_alias( - self, - alias: str, - relative_path: str = "", - *, - must_exist: bool = False, - ) -> Path: - alias_key = str(alias or "").strip() - aliases = self.get_path_aliases() - if alias_key not in aliases: - raise ValueError(f"未知路径别名: {alias_key}") - - root = Path(aliases[alias_key]).resolve() - rel = str(relative_path or "").strip().replace("\\", "/") - if rel.startswith("/") or rel.startswith("\\") or rel.startswith("//"): - raise ValueError("relative_path 不能为绝对路径") - if ":" in rel: - raise ValueError("relative_path 不允许包含盘符") - - candidate = (root / rel).resolve() if rel else root - try: - candidate.relative_to(root) - except ValueError: - raise ValueError("路径越界:relative_path 超出白名单目录") - if must_exist and not candidate.exists(): - raise ValueError(f"路径不存在: {candidate}") - return candidate - - async def resolve_path_request(self, payload: Dict[str, Any]) -> Dict[str, Any]: - alias = str(payload.get("alias") or "").strip() - relative_path = str(payload.get("relative_path") or "").strip() - must_exist = _coerce_bool(payload.get("must_exist"), True) - resolved = self.resolve_path_alias(alias, relative_path, must_exist=must_exist) - return { - "alias": alias, - "relative_path": relative_path, - "resolved_path": str(resolved), - "exists": resolved.exists(), - "is_file": resolved.is_file(), - "is_dir": resolved.is_dir(), - } - - def _load_manifest(self) -> Dict[str, Any]: - if self._manifest_cache is not None: - return self._manifest_cache - path = self._manifest_path - if not path.exists(): - self._manifest_cache = {} - return self._manifest_cache - try: - payload = json.loads(path.read_text(encoding="utf-8")) - if isinstance(payload, dict): - self._manifest_cache = payload - else: - self._manifest_cache = {} - except Exception: - self._manifest_cache = {} - return self._manifest_cache - - def _save_manifest(self, payload: Dict[str, Any]) -> None: - path = self._manifest_path - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - self._manifest_cache = payload - - def _clear_manifest(self) -> None: - self._save_manifest({}) - - def _normalize_manifest_path(self, raw_path: str) -> str: - text = str(raw_path or "").strip() - if not text: - return "" - return text.replace("\\", "/").strip().lower() - - def _match_manifest_item_for_source(self, source: str, item: Dict[str, Any]) -> bool: - source_text = str(source or "").strip() - if not source_text or ":" not in source_text: - return False - prefix, tail = source_text.split(":", 1) - source_kind = prefix.strip().lower() - source_value = tail.strip() - if not source_value: - return False - - item_kind = str(item.get("source_kind") or "").strip().lower() - item_name = str(item.get("name") or "").strip() - item_path_norm = self._normalize_manifest_path(item.get("source_path") or "") - - if source_kind in {"raw_scan", "lpmm_openie"}: - source_path_norm = self._normalize_manifest_path(source_value) - if source_path_norm and item_path_norm and source_path_norm == item_path_norm and item_kind == source_kind: - return True - - if source_kind == "web_import": - return item_kind in {"upload", "paste"} and item_name == source_value - - if source_kind == "lpmm_openie": - source_name = Path(source_value).name - return item_kind == "lpmm_openie" and item_name == source_name - - return False - - 
async def invalidate_manifest_for_sources(self, sources: List[str]) -> Dict[str, Any]: - requested_sources: List[str] = [] - seen_sources = set() - for raw in sources or []: - source = str(raw or "").strip() - if not source: - continue - key = source.lower() - if key in seen_sources: - continue - seen_sources.add(key) - requested_sources.append(source) - - result: Dict[str, Any] = { - "requested_sources": requested_sources, - "removed_count": 0, - "removed_keys": [], - "remaining_count": 0, - "unmatched_sources": [], - "warnings": [], - } - - async with self._lock: - manifest = self._load_manifest() - if not isinstance(manifest, dict): - manifest = {} - - valid_items: List[Tuple[str, Dict[str, Any]]] = [] - malformed_keys: List[str] = [] - for key, item in manifest.items(): - if isinstance(item, dict): - valid_items.append((str(key), item)) - else: - malformed_keys.append(str(key)) - - keys_to_remove = set() - for source in requested_sources: - matched = False - for key, item in valid_items: - if self._match_manifest_item_for_source(source, item): - keys_to_remove.add(key) - matched = True - if not matched: - result["unmatched_sources"].append(source) - - if keys_to_remove: - for key in keys_to_remove: - manifest.pop(key, None) - self._save_manifest(manifest) - - result["removed_keys"] = sorted(keys_to_remove) - result["removed_count"] = len(keys_to_remove) - result["remaining_count"] = len(manifest) - - if malformed_keys: - preview = ", ".join(malformed_keys[:5]) - extra = "" if len(malformed_keys) <= 5 else f" ... (+{len(malformed_keys) - 5})" - result["warnings"].append( - f"manifest 条目结构异常,已跳过 {len(malformed_keys)} 项: {preview}{extra}" - ) - - return result - - def _manifest_key_for_file(self, file_record: ImportFileRecord, content_hash: str, dedupe_policy: str) -> str: - if dedupe_policy == "content_hash": - return f"hash:{content_hash}" - if file_record.source_path: - return f"path:{Path(file_record.source_path).as_posix().lower()}" - return f"hash:{content_hash}" - - def _is_manifest_hit( - self, - file_record: ImportFileRecord, - content_hash: str, - dedupe_policy: str, - ) -> bool: - key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) - manifest = self._load_manifest() - item = manifest.get(key) - if not isinstance(item, dict): - return False - return str(item.get("hash") or "") == content_hash and bool(item.get("imported")) - - def _record_manifest_import( - self, - file_record: ImportFileRecord, - content_hash: str, - dedupe_policy: str, - task_id: str, - ) -> None: - key = self._manifest_key_for_file(file_record, content_hash, dedupe_policy) - manifest = self._load_manifest() - manifest[key] = { - "hash": content_hash, - "imported": True, - "timestamp": _now(), - "task_id": task_id, - "name": file_record.name, - "source_path": file_record.source_path or "", - "source_kind": file_record.source_kind, - } - self._save_manifest(manifest) - - def _normalize_common_import_params(self, payload: Dict[str, Any], *, default_dedupe: str) -> Dict[str, Any]: - input_mode = str(payload.get("input_mode", "text") or "text").strip().lower() - if input_mode not in {"text", "json"}: - raise ValueError("input_mode 必须为 text 或 json") - - file_concurrency = _coerce_int( - payload.get("file_concurrency", self._default_file_concurrency()), - self._default_file_concurrency(), - ) - chunk_concurrency = _coerce_int( - payload.get("chunk_concurrency", self._default_chunk_concurrency()), - self._default_chunk_concurrency(), - ) - file_concurrency = _clamp(file_concurrency, 1, 
self._max_file_concurrency()) - chunk_concurrency = _clamp(chunk_concurrency, 1, self._max_chunk_concurrency()) - - llm_enabled = _coerce_bool(payload.get("llm_enabled", True), True) - strategy_override = parse_import_strategy( - payload.get("strategy_override", "auto"), - default=ImportStrategy.AUTO, - ).value - - dedupe_policy = str(payload.get("dedupe_policy", default_dedupe) or default_dedupe).strip().lower() - if dedupe_policy not in {"content_hash", "manifest", "none"}: - raise ValueError("dedupe_policy 必须为 content_hash/manifest/none") - - chat_log = _coerce_bool(payload.get("chat_log"), False) - chat_reference_time = str(payload.get("chat_reference_time") or "").strip() or None - force = _coerce_bool(payload.get("force"), False) - clear_manifest = _coerce_bool(payload.get("clear_manifest"), False) - - return { - "input_mode": input_mode, - "file_concurrency": file_concurrency, - "chunk_concurrency": chunk_concurrency, - "llm_enabled": llm_enabled, - "strategy_override": strategy_override, - "chat_log": chat_log, - "chat_reference_time": chat_reference_time, - "force": force, - "clear_manifest": clear_manifest, - "dedupe_policy": dedupe_policy, - } - - def _normalize_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - params = self._normalize_common_import_params(payload, default_dedupe="content_hash") - params["task_kind"] = "upload" - return params - - def _normalize_raw_scan_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - params = self._normalize_common_import_params(payload, default_dedupe="manifest") - alias = str(payload.get("alias") or "raw").strip() - relative_path = str(payload.get("relative_path") or "").strip() - glob_pattern = str(payload.get("glob") or "*").strip() or "*" - recursive = _coerce_bool(payload.get("recursive"), True) - if ".." 
in relative_path.replace("\\", "/").split("/"): - raise ValueError("relative_path 不允许包含 ..") - params.update( - { - "task_kind": "raw_scan", - "alias": alias, - "relative_path": relative_path, - "glob": glob_pattern, - "recursive": recursive, - } - ) - return params - - def _normalize_lpmm_openie_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - params = self._normalize_common_import_params(payload, default_dedupe="manifest") - alias = str(payload.get("alias") or "lpmm").strip() - relative_path = str(payload.get("relative_path") or "").strip() - include_all_json = _coerce_bool(payload.get("include_all_json"), False) - params.update( - { - "task_kind": "lpmm_openie", - "alias": alias, - "relative_path": relative_path, - "include_all_json": include_all_json, - "input_mode": "json", - } - ) - return params - - def _normalize_temporal_backfill_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - alias = str(payload.get("alias") or "plugin_data").strip() - relative_path = str(payload.get("relative_path") or "").strip() - dry_run = _coerce_bool(payload.get("dry_run"), False) - no_created_fallback = _coerce_bool(payload.get("no_created_fallback"), False) - limit = _parse_optional_positive_int(payload.get("limit"), "limit") or 100000 - return { - "task_kind": "temporal_backfill", - "alias": alias, - "relative_path": relative_path, - "dry_run": dry_run, - "no_created_fallback": no_created_fallback, - "limit": limit, - } - - def _normalize_lpmm_convert_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - alias = str(payload.get("alias") or "lpmm").strip() - relative_path = str(payload.get("relative_path") or "").strip() - target_alias = str(payload.get("target_alias") or "plugin_data").strip() - target_relative_path = str(payload.get("target_relative_path") or "").strip() - dimension = _parse_optional_positive_int(payload.get("dimension"), "dimension") or _coerce_int( - self._cfg("embedding.dimension", 384), - 384, - ) - batch_size = _parse_optional_positive_int(payload.get("batch_size"), "batch_size") or 1024 - return { - "task_kind": "lpmm_convert", - "alias": alias, - "relative_path": relative_path, - "target_alias": target_alias, - "target_relative_path": target_relative_path, - "dimension": dimension, - "batch_size": batch_size, - } - - def _normalize_by_task_kind(self, task_kind: str, payload: Dict[str, Any]) -> Dict[str, Any]: - kind = str(task_kind or "").strip().lower() - if kind in {"upload", "paste"}: - params = self._normalize_params(payload) - params["task_kind"] = kind - return params - if kind == "maibot_migration": - return self._normalize_migration_params(payload) - if kind == "raw_scan": - return self._normalize_raw_scan_params(payload) - if kind == "lpmm_openie": - return self._normalize_lpmm_openie_params(payload) - if kind == "temporal_backfill": - return self._normalize_temporal_backfill_params(payload) - if kind == "lpmm_convert": - return self._normalize_lpmm_convert_params(payload) - # upload/paste 默认走通用文本导入参数 - return self._normalize_params(payload) - - def _normalize_migration_params(self, payload: Dict[str, Any]) -> Dict[str, Any]: - source_db = str(payload.get("source_db") or "").strip() - if not source_db: - source_db = str(self._default_maibot_source_db()) - - time_from = str(payload.get("time_from") or "").strip() or None - time_to = str(payload.get("time_to") or "").strip() or None - - stream_ids = _coerce_list(payload.get("stream_ids")) - group_ids = _coerce_list(payload.get("group_ids")) - user_ids = _coerce_list(payload.get("user_ids")) - - 
start_id = _parse_optional_positive_int(payload.get("start_id"), "start_id") - end_id = _parse_optional_positive_int(payload.get("end_id"), "end_id") - if start_id is not None and end_id is not None and start_id > end_id: - raise ValueError("start_id 不能大于 end_id") - - read_batch_size = _parse_optional_positive_int(payload.get("read_batch_size"), "read_batch_size") or 2000 - commit_window_rows = _parse_optional_positive_int(payload.get("commit_window_rows"), "commit_window_rows") or 20000 - embed_batch_size = _parse_optional_positive_int(payload.get("embed_batch_size"), "embed_batch_size") or 256 - entity_embed_batch_size = ( - _parse_optional_positive_int(payload.get("entity_embed_batch_size"), "entity_embed_batch_size") or 512 - ) - embed_workers = _parse_optional_positive_int(payload.get("embed_workers"), "embed_workers") - max_errors = _parse_optional_positive_int(payload.get("max_errors"), "max_errors") or 500 - log_every = _parse_optional_positive_int(payload.get("log_every"), "log_every") or 5000 - preview_limit = _parse_optional_positive_int(payload.get("preview_limit"), "preview_limit") or 20 - - no_resume = _coerce_bool(payload.get("no_resume"), False) - reset_state = _coerce_bool(payload.get("reset_state"), False) - dry_run = _coerce_bool(payload.get("dry_run"), False) - verify_only = _coerce_bool(payload.get("verify_only"), False) - - return { - "task_kind": "maibot_migration", - "source_db": source_db, - "target_data_dir": str(self._resolve_data_dir()), - "time_from": time_from, - "time_to": time_to, - "stream_ids": stream_ids, - "group_ids": group_ids, - "user_ids": user_ids, - "start_id": start_id, - "end_id": end_id, - "read_batch_size": read_batch_size, - "commit_window_rows": commit_window_rows, - "embed_batch_size": embed_batch_size, - "entity_embed_batch_size": entity_embed_batch_size, - "embed_workers": embed_workers, - "max_errors": max_errors, - "log_every": log_every, - "preview_limit": preview_limit, - "no_resume": no_resume, - "reset_state": reset_state, - "dry_run": dry_run, - "verify_only": verify_only, - } - - def _pending_task_count(self) -> int: - pending = 0 - for task in self._tasks.values(): - if task.status in {"queued", "preparing", "running", "cancel_requested"}: - pending += 1 - return pending - - async def _ensure_worker(self) -> None: - async with self._lock: - if self._worker_task and not self._worker_task.done(): - return - self._stopping = False - self._worker_task = asyncio.create_task(self._worker_loop()) - - async def get_runtime_settings(self) -> Dict[str, Any]: - llm_retry = self._llm_retry_config() - return { - "max_queue_size": self._queue_limit(), - "max_files_per_task": self._max_files_per_task(), - "max_file_size_mb": self._cfg_int("web.import.max_file_size_mb", 20), - "max_paste_chars": self._max_paste_chars(), - "default_file_concurrency": self._default_file_concurrency(), - "default_chunk_concurrency": self._default_chunk_concurrency(), - "max_file_concurrency": self._max_file_concurrency(), - "max_chunk_concurrency": self._max_chunk_concurrency(), - "poll_interval_ms": max(200, self._cfg_int("web.import.poll_interval_ms", 1000)), - "maibot_source_db_default": str(self._default_maibot_source_db()), - "maibot_target_data_dir": str(self._resolve_data_dir()), - "path_aliases": self.get_path_aliases(), - "llm_retry": llm_retry, - "convert_enable_staging_switch": _coerce_bool( - self._cfg("web.import.convert.enable_staging_switch", True), True - ), - "convert_keep_backup_count": max(0, self._cfg_int("web.import.convert.keep_backup_count", 
3)), - } - - def is_write_blocked(self) -> bool: - task_id = self._active_task_id - if not task_id: - return False - task = self._tasks.get(task_id) - if not task: - return False - return task.status in {"preparing", "running", "cancel_requested"} - - def _ensure_ready(self) -> None: - required_attrs = ("metadata_store", "vector_store", "graph_store", "embedding_manager") - - def _collect_missing() -> List[str]: - missing_local: List[str] = [] - for attr in required_attrs: - if getattr(self.plugin, attr, None) is None: - missing_local.append(attr) - return missing_local - - missing = _collect_missing() - if missing: - raise ValueError(f"导入依赖未初始化: {', '.join(missing)}") - ready_checker = getattr(self.plugin, "is_runtime_ready", None) - if callable(ready_checker) and not ready_checker(): - raise ValueError("插件运行时未就绪,请先完成 on_enable 初始化") - - def _scan_files( - self, - base_path: Path, - *, - recursive: bool, - glob_pattern: str, - allowed_exts: Optional[set[str]] = None, - ) -> List[Path]: - if base_path.is_file(): - candidates = [base_path] - else: - if recursive: - candidates = list(base_path.rglob(glob_pattern)) - else: - candidates = list(base_path.glob(glob_pattern)) - out: List[Path] = [] - for p in candidates: - if not p.is_file(): - continue - ext = p.suffix.lower() - if allowed_exts and ext not in allowed_exts: - continue - out.append(p.resolve()) - out.sort(key=lambda x: x.as_posix().lower()) - return out - - async def create_upload_task(self, files: List[Any], payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - if not files: - raise ValueError("至少需要上传一个文件") - - params = self._normalize_params(payload) - max_files = self._max_files_per_task() - if len(files) > max_files: - raise ValueError(f"单任务文件数超过上限: {max_files}") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="upload", - params=params, - status="queued", - current_step="queued", - ) - task_dir = self._temp_root / task.task_id - task_dir.mkdir(parents=True, exist_ok=True) - - max_size = self._max_file_size_bytes() - for idx, uploaded in enumerate(files): - file_id = uuid.uuid4().hex - if isinstance(uploaded, dict): - staged_path_raw = uploaded.get("staged_path") or uploaded.get("path") or "" - staged_path = Path(str(staged_path_raw or "")).expanduser().resolve() - if not staged_path.is_file(): - raise ValueError(f"上传暂存文件不存在: {staged_path}") - name = _safe_filename(uploaded.get("filename") or uploaded.get("name") or staged_path.name) - ext = Path(name).suffix.lower() - if ext not in {".txt", ".md", ".json"}: - raise ValueError(f"不支持的文件类型: {name}") - if staged_path.stat().st_size > max_size: - raise ValueError(f"文件超过大小限制: {name}") - temp_path = task_dir / f"{file_id}_{name}" - shutil.copy2(staged_path, temp_path) - else: - name = _safe_filename(getattr(uploaded, "filename", f"file_{idx}.txt")) - ext = Path(name).suffix.lower() - if ext not in {".txt", ".md", ".json"}: - raise ValueError(f"不支持的文件类型: {name}") - content = await uploaded.read() - if len(content) > max_size: - raise ValueError(f"文件超过大小限制: {name}") - temp_path = task_dir / f"{file_id}_{name}" - temp_path.write_bytes(content) - file_mode = "json" if ext == ".json" else params["input_mode"] - task.files.append( - ImportFileRecord( - file_id=file_id, - name=name, - source_kind="upload", - input_mode=file_mode, - temp_path=str(temp_path), - ) - ) - - 
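        # The loop above accepts two input shapes: a dict pointing at an already
        # staged file ({"staged_path": ..., "filename": ...}) and a file-like
        # upload object exposing `filename` plus an async `read()`. Both paths
        # end with a copy under the per-task temp directory, so later processing
        # never touches the caller's original file.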
self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_paste_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - - params = self._normalize_params(payload) - params["task_kind"] = "paste" - content = str(payload.get("content", "") or "") - if not content.strip(): - raise ValueError("content 不能为空") - if len(content) > self._max_paste_chars(): - raise ValueError(f"粘贴内容超过限制: {self._max_paste_chars()} 字符") - - name = _safe_filename(payload.get("name") or f"paste_{int(_now())}.txt") - if params["input_mode"] == "json" and Path(name).suffix.lower() != ".json": - name = f"{Path(name).stem}.json" - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="paste", - params=params, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=name, - source_kind="paste", - input_mode=params["input_mode"], - inline_content=content, - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_raw_scan_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - params = self._normalize_raw_scan_params(payload) - source_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - files = self._scan_files( - source_path, - recursive=bool(params["recursive"]), - glob_pattern=str(params["glob"] or "*"), - allowed_exts={".txt", ".md", ".json"}, - ) - if not files: - raise ValueError("未找到可导入文件") - if len(files) > self._max_files_per_task(): - raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="raw_scan", - params=params, - status="queued", - current_step="queued", - ) - for path in files: - mode = "json" if path.suffix.lower() == ".json" else params["input_mode"] - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=path.name, - source_kind="raw_scan", - input_mode=mode, - source_path=str(path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_lpmm_openie_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - params = self._normalize_lpmm_openie_params(payload) - source_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - files: List[Path] = [] - if source_path.is_file(): - files = [source_path] - else: - files = self._scan_files( - source_path, - recursive=True, - glob_pattern="*-openie.json", - allowed_exts={".json"}, - ) - if not files and params.get("include_all_json"): - files = self._scan_files( - source_path, - recursive=True, - glob_pattern="*.json", - allowed_exts={".json"}, - ) - if not files: - raise ValueError("未找到 LPMM OpenIE JSON 文件") - if 
len(files) > self._max_files_per_task(): - raise ValueError(f"单任务文件数超过上限: {self._max_files_per_task()}") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="lpmm_openie", - params=params, - status="queued", - current_step="queued", - schema_detected="lpmm_openie", - ) - for path in files: - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=path.name, - source_kind="lpmm_openie", - input_mode="json", - source_path=str(path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_temporal_backfill_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - params = self._normalize_temporal_backfill_params(payload) - target_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - if not target_path.is_dir(): - raise ValueError("temporal_backfill 目标路径必须为目录") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="temporal_backfill", - params=params, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=f"temporal_backfill_{int(_now())}", - source_kind="temporal_backfill", - input_mode="json", - source_path=str(target_path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_lpmm_convert_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - params = self._normalize_lpmm_convert_params(payload) - source_path = self.resolve_path_alias( - params["alias"], - params["relative_path"], - must_exist=True, - ) - if not source_path.is_dir(): - raise ValueError("lpmm_convert 输入路径必须为目录") - target_path = self.resolve_path_alias( - params["target_alias"], - params["target_relative_path"], - must_exist=False, - ) - target_path.mkdir(parents=True, exist_ok=True) - if not target_path.is_dir(): - raise ValueError("lpmm_convert 目标路径必须为目录") - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="lpmm_convert", - params={**params, "source_path": str(source_path), "target_path": str(target_path)}, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=f"lpmm_convert_{int(_now())}", - source_kind="lpmm_convert", - input_mode="json", - source_path=str(source_path), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def create_maibot_migration_task(self, payload: Dict[str, Any]) -> Dict[str, Any]: - if not self._is_enabled(): - raise ValueError("导入功能已禁用") - self._ensure_ready() - - params = self._normalize_migration_params(payload) - script_path = self._resolve_migration_script() - if not script_path.exists(): - raise ValueError(f"迁移脚本不存在: {script_path}") - - async with self._lock: - if 
self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - - task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source="maibot_migration", - params=params, - status="queued", - current_step="queued", - ) - task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=f"maibot_migration_{int(_now())}", - source_kind="maibot_migration", - input_mode="text", - inline_content=json.dumps(params, ensure_ascii=False), - ) - ) - self._tasks[task.task_id] = task - self._task_order.appendleft(task.task_id) - self._queue.append(task.task_id) - - await self._ensure_worker() - return task.to_summary() - - async def list_tasks(self, limit: int = 50) -> List[Dict[str, Any]]: - async with self._lock: - task_ids = list(self._task_order)[: max(1, int(limit))] - return [self._tasks[task_id].to_summary() for task_id in task_ids if task_id in self._tasks] - - async def get_task(self, task_id: str, include_chunks: bool = False) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - return task.to_detail(include_chunks=include_chunks) - - async def get_chunks(self, task_id: str, file_id: str, offset: int = 0, limit: int = 50) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - file_obj = self._find_file(task, file_id) - if not file_obj: - return None - start = max(0, int(offset)) - size = max(1, min(500, int(limit))) - items = file_obj.chunks[start : start + size] - return { - "task_id": task_id, - "file_id": file_id, - "offset": start, - "limit": size, - "total": len(file_obj.chunks), - "items": [x.to_dict() for x in items], - } - - async def cancel_task(self, task_id: str) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - if task.status == "queued": - self._mark_task_cancelled_locked(task, "任务已取消") - self._queue = deque([x for x in self._queue if x != task_id]) - elif task.status in {"preparing", "running"}: - task.status = "cancel_requested" - task.current_step = "cancel_requested" - task.updated_at = _now() - return task.to_summary() - - def _build_retry_plan(self, task: ImportTaskRecord) -> Dict[str, Any]: - chunk_retry_candidates: List[Tuple[ImportFileRecord, List[int]]] = [] - file_fallback_candidates: List[ImportFileRecord] = [] - skipped: List[Dict[str, str]] = [] - - for file_obj in task.files: - if file_obj.status == "cancelled": - continue - - failed_chunks = [c for c in file_obj.chunks if c.status == "failed"] - has_file_level_failure = file_obj.status == "failed" and not failed_chunks - if has_file_level_failure: - file_fallback_candidates.append(file_obj) - continue - - if not failed_chunks: - continue - - retry_indexes: List[int] = [] - has_non_retryable = False - for chunk in failed_chunks: - failed_at = str(chunk.failed_at or "").strip().lower() - retryable = bool(chunk.retryable) or ( - file_obj.input_mode == "text" and failed_at == "extracting" - ) - if retryable: - try: - retry_indexes.append(int(chunk.index)) - except Exception: - has_non_retryable = True - else: - has_non_retryable = True - - if has_non_retryable: - file_fallback_candidates.append(file_obj) - continue - - retry_indexes = sorted(set(retry_indexes)) - if retry_indexes: - chunk_retry_candidates.append((file_obj, retry_indexes)) - else: - skipped.append( - { - "file_name": file_obj.name, - "source_kind": file_obj.source_kind, - "reason": "no_retryable_failed_chunks", - } - ) - - 
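        # Retry planning in the loop above is chunk-first: a file is re-queued as
        # individual chunk retries only when every failed chunk is retryable (or
        # failed during text extraction); any file-level failure or non-retryable
        # chunk demotes the whole file to a full re-run ("file_fallback").
        # Example: a file whose failed chunks 3 and 7 are both retryable yields
        # retry_indexes [3, 7]; the same file with one non-retryable chunk is
        # re-imported in full instead.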
unique_fallback: List[ImportFileRecord] = [] - fallback_seen = set() - for file_obj in file_fallback_candidates: - if file_obj.file_id in fallback_seen: - continue - fallback_seen.add(file_obj.file_id) - unique_fallback.append(file_obj) - - return { - "chunk_retry_candidates": chunk_retry_candidates, - "file_fallback_candidates": unique_fallback, - "skipped": skipped, - } - - def _clone_failed_file_for_retry( - self, - retry_task: ImportTaskRecord, - failed_file: ImportFileRecord, - task_dir: Path, - *, - retry_mode: str, - retry_chunk_indexes: Optional[List[int]] = None, - ) -> Tuple[bool, str]: - source_kind = str(failed_file.source_kind or "").strip().lower() - retry_chunk_indexes = list(retry_chunk_indexes or []) - - if source_kind == "upload": - candidate_paths: List[Path] = [] - if failed_file.temp_path: - candidate_paths.append(Path(failed_file.temp_path)) - if failed_file.source_path: - candidate_paths.append(Path(failed_file.source_path)) - src_path = next((p for p in candidate_paths if p.exists() and p.is_file()), None) - if src_path is None: - return False, "upload_source_missing" - data = src_path.read_bytes() - file_id = uuid.uuid4().hex - name = _safe_filename(failed_file.name) - dst = task_dir / f"{file_id}_{name}" - dst.write_bytes(data) - retry_task.files.append( - ImportFileRecord( - file_id=file_id, - name=name, - source_kind="upload", - input_mode=failed_file.input_mode, - temp_path=str(dst), - retry_mode=retry_mode, - retry_chunk_indexes=retry_chunk_indexes, - ) - ) - return True, "" - - if source_kind == "paste": - if failed_file.inline_content is None: - return False, "paste_content_missing" - retry_task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=_safe_filename(failed_file.name), - source_kind="paste", - input_mode=failed_file.input_mode, - inline_content=failed_file.inline_content, - retry_mode=retry_mode, - retry_chunk_indexes=retry_chunk_indexes, - ) - ) - return True, "" - - if source_kind == "maibot_migration": - retry_task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=_safe_filename(failed_file.name), - source_kind="maibot_migration", - input_mode="text", - inline_content=failed_file.inline_content, - retry_mode="file_fallback", - retry_chunk_indexes=[], - ) - ) - return True, "" - - if source_kind in {"raw_scan", "lpmm_openie", "lpmm_convert", "temporal_backfill"}: - retry_task.files.append( - ImportFileRecord( - file_id=uuid.uuid4().hex, - name=_safe_filename(failed_file.name), - source_kind=source_kind, - input_mode=failed_file.input_mode, - source_path=failed_file.source_path, - inline_content=failed_file.inline_content, - retry_mode=retry_mode, - retry_chunk_indexes=retry_chunk_indexes, - ) - ) - return True, "" - - return False, f"unsupported_source_kind:{source_kind or 'unknown'}" - - async def retry_failed(self, task_id: str, overrides: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return None - retry_plan = self._build_retry_plan(task) - chunk_retry_candidates = list(retry_plan["chunk_retry_candidates"]) - file_fallback_candidates = list(retry_plan["file_fallback_candidates"]) - skipped_candidates = list(retry_plan["skipped"]) - if not chunk_retry_candidates and not file_fallback_candidates: - raise ValueError("当前任务没有可重试失败项") - base_params = dict(task.params) - task_kind = str(task.params.get("task_kind") or "").strip().lower() - - if overrides: - base_params.update(overrides) - params = 
self._normalize_by_task_kind(task_kind, base_params) - params["retry_parent_task_id"] = task_id - params["retry_strategy"] = "chunk_first_auto_file_fallback" - - async with self._lock: - if self._pending_task_count() >= self._queue_limit(): - raise ValueError("任务队列已满,请稍后重试") - retry_task = ImportTaskRecord( - task_id=uuid.uuid4().hex, - source=task.source, - params=params, - status="queued", - current_step="queued", - schema_detected=task.schema_detected, - retry_parent_task_id=task_id, - ) - - task_dir = self._temp_root / retry_task.task_id - task_dir.mkdir(parents=True, exist_ok=True) - - retry_summary = { - "chunk_retry_files": 0, - "chunk_retry_chunks": 0, - "file_fallback_files": 0, - "skipped_files": 0, - "parent_task_id": task_id, - } - skipped_details = list(skipped_candidates) - - for file_obj, chunk_indexes in chunk_retry_candidates: - ok, reason = self._clone_failed_file_for_retry( - retry_task, - file_obj, - task_dir, - retry_mode="chunk", - retry_chunk_indexes=chunk_indexes, - ) - if ok: - retry_summary["chunk_retry_files"] += 1 - retry_summary["chunk_retry_chunks"] += len(chunk_indexes) - else: - skipped_details.append( - { - "file_name": file_obj.name, - "source_kind": file_obj.source_kind, - "reason": reason, - } - ) - - for file_obj in file_fallback_candidates: - ok, reason = self._clone_failed_file_for_retry( - retry_task, - file_obj, - task_dir, - retry_mode="file_fallback", - retry_chunk_indexes=[], - ) - if ok: - retry_summary["file_fallback_files"] += 1 - else: - skipped_details.append( - { - "file_name": file_obj.name, - "source_kind": file_obj.source_kind, - "reason": reason, - } - ) - - retry_summary["skipped_files"] = len(skipped_details) - if skipped_details: - retry_summary["skipped_details"] = skipped_details - retry_task.retry_summary = retry_summary - - if not retry_task.files: - raise ValueError("无可执行的重试输入:失败项均无法构建重试任务") - - self._tasks[retry_task.task_id] = retry_task - self._task_order.appendleft(retry_task.task_id) - self._queue.append(retry_task.task_id) - logger.info( - "重试任务已创建 " - f"parent={task_id} retry={retry_task.task_id} " - f"chunk_files={retry_summary['chunk_retry_files']} " - f"chunk_chunks={retry_summary['chunk_retry_chunks']} " - f"file_fallback={retry_summary['file_fallback_files']} " - f"skipped={retry_summary['skipped_files']}" - ) - - await self._ensure_worker() - return retry_task.to_summary() - - async def shutdown(self) -> None: - async with self._lock: - self._stopping = True - for task in self._tasks.values(): - if task.status in {"queued", "preparing", "running", "cancel_requested"}: - self._mark_task_cancelled_locked(task, "服务关闭") - self._queue.clear() - worker = self._worker_task - self._worker_task = None - - if worker: - worker.cancel() - try: - await worker - except asyncio.CancelledError: - pass - except Exception: - pass - - self._cleanup_temp_root() - - def _cleanup_temp_root(self) -> None: - try: - if not self._temp_root.exists(): - return - for child in self._temp_root.rglob("*"): - if child.is_file(): - child.unlink(missing_ok=True) - for child in sorted(self._temp_root.rglob("*"), reverse=True): - if child.is_dir(): - child.rmdir() - self._temp_root.rmdir() - except Exception as e: - logger.warning(f"清理临时导入目录失败: {e}") - - async def _worker_loop(self) -> None: - logger.info("Web 导入任务 worker 已启动") - while True: - if self._stopping: - break - - task_id: Optional[str] = None - async with self._lock: - while self._queue: - candidate = self._queue.popleft() - t = self._tasks.get(candidate) - if not t: - continue - if t.status 
== "cancelled": - continue - task_id = candidate - self._active_task_id = candidate - break - - if not task_id: - await asyncio.sleep(0.2) - continue - - try: - await self._run_task(task_id) - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"导入任务执行失败 task={task_id}: {e}\n{traceback.format_exc()}") - async with self._lock: - task = self._tasks.get(task_id) - if task and task.status not in {"cancelled", "completed", "completed_with_errors"}: - task.status = "failed" - task.current_step = "failed" - task.error = str(e) - task.finished_at = _now() - task.updated_at = _now() - finally: - should_cleanup = await self._should_cleanup_task_temp(task_id) - async with self._lock: - if self._active_task_id == task_id: - self._active_task_id = None - if should_cleanup: - await self._cleanup_task_temp_files(task_id) - - logger.info("Web 导入任务 worker 已停止") - - async def _cleanup_task_temp_files(self, task_id: str) -> None: - task_dir = self._temp_root / task_id - if not task_dir.exists(): - return - try: - for child in task_dir.rglob("*"): - if child.is_file(): - child.unlink(missing_ok=True) - for child in sorted(task_dir.rglob("*"), reverse=True): - if child.is_dir(): - child.rmdir() - task_dir.rmdir() - except Exception as e: - logger.warning(f"清理任务临时文件失败 task={task_id}: {e}") - - def _task_report_path(self, task_id: str) -> Path: - self._reports_root.mkdir(parents=True, exist_ok=True) - return self._reports_root / f"{task_id}_summary.json" - - def _write_task_report(self, task: ImportTaskRecord) -> None: - path = self._task_report_path(task.task_id) - payload = task.to_detail(include_chunks=False) - payload["generated_at"] = _now() - path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - task.artifact_paths["summary"] = str(path) - - async def _run_task(self, task_id: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - task.status = "preparing" - task.current_step = "preparing" - task.started_at = _now() - task.updated_at = _now() - if task.params.get("clear_manifest"): - self._clear_manifest() - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - if task.status == "cancel_requested": - task.status = "cancelled" - task.current_step = "cancelled" - task.finished_at = _now() - task.updated_at = _now() - return - task.status = "running" - task.current_step = "running" - task.updated_at = _now() - - task_kind = str(task.params.get("task_kind") or task.source).strip().lower() - if task_kind == "maibot_migration": - if not task.files: - raise RuntimeError("迁移任务缺少文件记录") - await self._process_maibot_migration(task_id, task.files[0]) - elif task_kind == "temporal_backfill": - if not task.files: - raise RuntimeError("回填任务缺少文件记录") - await self._process_temporal_backfill(task_id, task.files[0]) - elif task_kind == "lpmm_convert": - if not task.files: - raise RuntimeError("转换任务缺少文件记录") - await self._process_lpmm_convert(task_id, task.files[0]) - else: - file_semaphore = asyncio.Semaphore(task.params["file_concurrency"]) - chunk_semaphore = asyncio.Semaphore(task.params["chunk_concurrency"]) - jobs = [ - asyncio.create_task(self._process_file(task_id, f, file_semaphore, chunk_semaphore)) - for f in task.files - ] - await asyncio.gather(*jobs, return_exceptions=True) - - write_changed_payload: Optional[Dict[str, Any]] = None - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - self._recompute_task_progress(task) - has_failed = any( - (f.status 
== "failed") - or (f.failed_chunks > 0) - or bool(str(f.error or "").strip()) - for f in task.files - ) - has_cancelled = any(f.status == "cancelled" for f in task.files) - has_completed = any(f.status == "completed" for f in task.files) - - # 统一按文件真实终态收敛任务状态,避免出现“任务已取消但文件已完成”的矛盾结果。 - if has_failed and not has_cancelled: - task.status = "completed_with_errors" - task.current_step = "completed_with_errors" - elif has_cancelled and not has_completed: - task.status = "cancelled" - task.current_step = "cancelled" - elif has_cancelled and has_completed: - task.status = "cancelled" - task.current_step = "cancelled" - else: - task.status = "completed" - task.current_step = "completed" - task.finished_at = _now() - task.updated_at = _now() - try: - self._write_task_report(task) - except Exception as report_err: - logger.warning(f"写入任务报告失败 task={task_id}: {report_err}") - task_kind = str(task.params.get("task_kind") or task.source).strip().lower() - write_task_kinds = {"upload", "paste", "raw_scan", "lpmm_openie", "maibot_migration", "lpmm_convert"} - has_written_chunks = (task.done_chunks > 0) or any(f.done_chunks > 0 for f in task.files) - if task_kind in write_task_kinds and has_written_chunks: - write_changed_payload = { - "task_id": task.task_id, - "task_kind": task_kind, - "status": task.status, - "done_chunks": task.done_chunks, - "finished_at": task.finished_at, - } - - if write_changed_payload: - await self._notify_write_changed(write_changed_payload) - - def _build_maibot_migration_command(self, params: Dict[str, Any]) -> List[str]: - script_path = self._resolve_migration_script() - if not script_path.exists(): - raise RuntimeError(f"迁移脚本不存在: {script_path}") - - cmd = [ - sys.executable, - str(script_path), - "--source-db", - str(params["source_db"]), - "--target-data-dir", - str(params["target_data_dir"]), - "--read-batch-size", - str(params["read_batch_size"]), - "--commit-window-rows", - str(params["commit_window_rows"]), - "--embed-batch-size", - str(params["embed_batch_size"]), - "--entity-embed-batch-size", - str(params["entity_embed_batch_size"]), - "--max-errors", - str(params["max_errors"]), - "--log-every", - str(params["log_every"]), - "--preview-limit", - str(params["preview_limit"]), - "--yes", - ] - - if params.get("embed_workers") is not None: - cmd.extend(["--embed-workers", str(params["embed_workers"])]) - if params.get("start_id") is not None: - cmd.extend(["--start-id", str(params["start_id"])]) - if params.get("end_id") is not None: - cmd.extend(["--end-id", str(params["end_id"])]) - if params.get("time_from"): - cmd.extend(["--time-from", str(params["time_from"])]) - if params.get("time_to"): - cmd.extend(["--time-to", str(params["time_to"])]) - - for sid in params.get("stream_ids") or []: - cmd.extend(["--stream-id", str(sid)]) - for gid in params.get("group_ids") or []: - cmd.extend(["--group-id", str(gid)]) - for uid in params.get("user_ids") or []: - cmd.extend(["--user-id", str(uid)]) - - if params.get("reset_state"): - cmd.append("--reset-state") - if params.get("no_resume"): - cmd.append("--no-resume") - if params.get("dry_run"): - cmd.append("--dry-run") - if params.get("verify_only"): - cmd.append("--verify-only") - - return cmd - - async def _ensure_maibot_migration_chunk( - self, - task_id: str, - file_id: str, - *, - chunk_type: str = "maibot_migration", - preview: str = "MaiBot chat_history 迁移任务", - ) -> str: - chunk_id = f"{file_id}_{chunk_type}" - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return chunk_id - f = 
self._find_file(task, file_id) - if not f: - return chunk_id - if not f.chunks: - f.chunks = [ - ImportChunkRecord( - chunk_id=chunk_id, - index=0, - chunk_type=chunk_type, - status="queued", - step="queued", - progress=0.0, - content_preview=preview, - ) - ] - f.total_chunks = 1 - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 0.0 - f.updated_at = _now() - self._recompute_task_progress(task) - else: - chunk_id = f.chunks[0].chunk_id - return chunk_id - - async def _refresh_maibot_progress_from_state( - self, - task_id: str, - file_id: str, - chunk_id: str, - state_path: Path, - ) -> None: - if not state_path.exists(): - return - try: - payload = json.loads(state_path.read_text(encoding="utf-8")) - except Exception: - return - - stats = payload.get("stats", {}) if isinstance(payload, dict) else {} - if not isinstance(stats, dict): - stats = {} - - total = max(0, _coerce_int(stats.get("source_matched_total", 0), 0)) - scanned = max(0, _coerce_int(stats.get("scanned_rows", 0), 0)) - bad = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) - done = max(0, scanned - bad) - migrated = max(0, _coerce_int(stats.get("migrated_rows", 0), 0)) - last_id = max(0, _coerce_int(stats.get("last_committed_id", 0), 0)) - - if total <= 0: - total = max(1, scanned) - - progress = max(0.0, min(1.0, float(scanned) / float(total))) if total > 0 else 0.0 - preview = f"scanned={scanned}/{total}, migrated={migrated}, bad={bad}, last_id={last_id}" - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if c: - if c.status not in {"completed", "failed", "cancelled"}: - c.status = "writing" - c.step = "migrating" - c.progress = progress - c.content_preview = preview - c.updated_at = _now() - f.total_chunks = total - f.done_chunks = done - f.failed_chunks = bad - f.cancelled_chunks = 0 - f.progress = progress - if f.status not in {"failed", "cancelled"}: - f.status = "writing" - f.current_step = "migrating" - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _terminate_process(self, process: asyncio.subprocess.Process) -> None: - if process.returncode is not None: - return - try: - process.terminate() - await asyncio.wait_for(process.wait(), timeout=5.0) - except Exception: - try: - process.kill() - await asyncio.wait_for(process.wait(), timeout=3.0) - except Exception: - pass - - async def _reload_stores_after_external_migration(self) -> None: - async with self._storage_lock: - try: - if self.plugin.vector_store and self.plugin.vector_store.has_data(): - self.plugin.vector_store.load() - except Exception as e: - logger.warning(f"迁移后重载 VectorStore 失败: {e}") - try: - if self.plugin.graph_store and self.plugin.graph_store.has_data(): - self.plugin.graph_store.load() - except Exception as e: - logger.warning(f"迁移后重载 GraphStore 失败: {e}") - - async def _process_maibot_migration(self, task_id: str, file_record: ImportFileRecord) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "maibot_migration") - await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") - chunk_id = await self._ensure_maibot_migration_chunk( - task_id, - file_record.file_id, - chunk_type="maibot_migration", - preview="MaiBot chat_history 迁移任务", - ) - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "migrating", 0.0) - - task = self._tasks.get(task_id) - if not task: - await self._set_file_failed(task_id, 
file_record.file_id, "任务不存在") - return - params = dict(task.params) - - command = self._build_maibot_migration_command(params) - project_root = self._resolve_repo_root() - state_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_resume.json" - report_path = Path(params["target_data_dir"]) / "migration_state" / "chat_history_report.json" - - logger.info(f"开始执行 MaiBot 迁移任务: {' '.join(command)}") - process = await asyncio.create_subprocess_exec( - *command, - cwd=str(project_root), - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - stdout_lines: List[str] = [] - stderr_lines: List[str] = [] - - async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None: - if stream is None: - return - while True: - line = await stream.readline() - if not line: - break - text = line.decode("utf-8", errors="replace").strip() - if not text: - continue - target.append(text) - if len(target) > 120: - del target[:-120] - - drain_tasks = [ - asyncio.create_task(_drain(process.stdout, stdout_lines)), - asyncio.create_task(_drain(process.stderr, stderr_lines)), - ] - - cancelled = False - return_code: Optional[int] = None - try: - while True: - if await self._is_cancel_requested(task_id): - cancelled = True - await self._terminate_process(process) - break - - await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) - try: - return_code = await asyncio.wait_for(process.wait(), timeout=1.0) - break - except asyncio.TimeoutError: - continue - finally: - await asyncio.gather(*drain_tasks, return_exceptions=True) - - if cancelled: - await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - await self._refresh_maibot_progress_from_state(task_id, file_record.file_id, chunk_id, state_path) - - report: Dict[str, Any] = {} - if report_path.exists(): - try: - report = json.loads(report_path.read_text(encoding="utf-8")) - except Exception: - report = {} - - stats = report.get("stats", {}) if isinstance(report, dict) else {} - if not isinstance(stats, dict): - stats = {} - bad_rows = max(0, _coerce_int(stats.get("bad_rows", 0), 0)) - - if return_code in {0, 2}: - await self._set_file_state(task_id, file_record.file_id, "saving", "saving") - await self._reload_stores_after_external_migration() - - async with self._lock: - task2 = self._tasks.get(task_id) - if not task2: - return - f = self._find_file(task2, file_record.file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if c and c.status not in {"cancelled", "failed"}: - c.status = "completed" - c.step = "completed" - c.progress = 1.0 - c.updated_at = _now() - if f.total_chunks <= 0: - f.total_chunks = 1 - if f.done_chunks + f.failed_chunks <= 0: - f.done_chunks = f.total_chunks - bad_rows - f.failed_chunks = bad_rows - f.done_chunks = max(0, min(f.done_chunks, f.total_chunks)) - f.failed_chunks = max(0, min(f.failed_chunks, f.total_chunks)) - f.cancelled_chunks = 0 - f.progress = 1.0 - f.status = "completed" - f.current_step = "completed" - if bad_rows > 0 and not f.error: - f.error = f"迁移完成,但存在坏行: {bad_rows}" - f.updated_at = _now() - self._recompute_task_progress(task2) - return - - fail_reason = "" - if isinstance(report, dict): - fail_reason = str(report.get("fail_reason") or "").strip() - tail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "") - detail = fail_reason or tail or f"迁移进程退出码: 
{return_code}" - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail) - await self._set_file_failed(task_id, file_record.file_id, detail) - - def _resolve_convert_script(self) -> Path: - return Path(__file__).resolve().parents[2] / "scripts" / "convert_lpmm.py" - - def _cleanup_old_backups(self) -> None: - keep = max(0, self._cfg_int("web.import.convert.keep_backup_count", 3)) - backup_root = self._resolve_backup_root() - if not backup_root.exists() or keep <= 0: - return - dirs = [p for p in backup_root.iterdir() if p.is_dir() and p.name.startswith("lpmm_convert_")] - dirs.sort(key=lambda p: p.stat().st_mtime, reverse=True) - for old in dirs[keep:]: - try: - shutil.rmtree(old, ignore_errors=True) - except Exception: - pass - - def _verify_convert_output(self, output_dir: Path) -> Dict[str, Any]: - vectors = output_dir / "vectors" - graph = output_dir / "graph" - metadata = output_dir / "metadata" - checks = { - "vectors_exists": vectors.exists(), - "graph_exists": graph.exists(), - "metadata_exists": metadata.exists(), - "vectors_nonempty": vectors.exists() and any(vectors.iterdir()), - "graph_nonempty": graph.exists() and any(graph.iterdir()), - "metadata_nonempty": metadata.exists() and any(metadata.iterdir()), - } - checks["ok"] = checks["vectors_exists"] and checks["graph_exists"] and checks["metadata_exists"] - return checks - - async def _preflight_convert_runtime(self) -> Tuple[bool, str]: - """使用当前服务解释器做 convert 依赖预检,避免子进程报错信息不透明。""" - probe_code = ( - "import importlib\n" - "mods=['networkx','scipy','pyarrow']\n" - "failed=[]\n" - "for m in mods:\n" - " try:\n" - " importlib.import_module(m)\n" - " except Exception as e:\n" - " failed.append(f'{m}:{e.__class__.__name__}:{e}')\n" - "print('OK' if not failed else ';'.join(failed))\n" - ) - try: - probe = await asyncio.create_subprocess_exec( - sys.executable, - "-c", - probe_code, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - stdout, stderr = await asyncio.wait_for(probe.communicate(), timeout=20.0) - except Exception as e: - return False, f"依赖预检执行失败: {e}" - - out = (stdout or b"").decode("utf-8", errors="replace").strip() - err = (stderr or b"").decode("utf-8", errors="replace").strip() - if probe.returncode != 0: - detail = err or out or f"return_code={probe.returncode}" - return False, f"依赖预检失败 (python={sys.executable}): {detail}" - if out != "OK": - return False, f"依赖预检失败 (python={sys.executable}): {out}" - return True, "" - - async def _process_lpmm_convert(self, task_id: str, file_record: ImportFileRecord) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "lpmm_convert") - await self._set_file_state(task_id, file_record.file_id, "preparing", "preflight") - chunk_id = await self._ensure_maibot_migration_chunk( - task_id, - file_record.file_id, - chunk_type="lpmm_convert", - preview="LPMM 二进制转换任务", - ) - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "converting", 0.05) - - task = self._tasks.get(task_id) - if not task: - await self._set_file_failed(task_id, file_record.file_id, "任务不存在") - return - params = dict(task.params) - source_dir = Path(params.get("source_path") or "") - target_dir = Path(params.get("target_path") or "") - if not source_dir.exists() or not source_dir.is_dir(): - await self._set_file_failed(task_id, file_record.file_id, f"输入目录无效: {source_dir}") - return - if not target_dir.exists() or not target_dir.is_dir(): - await self._set_file_failed(task_id, file_record.file_id, f"目标目录无效: {target_dir}") - return 
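        # The rest of this method runs as a staged pipeline: dependency preflight
        # in the current interpreter, conversion via the convert_lpmm.py script in
        # a subprocess writing into a staging directory, verification of the
        # staging output, and finally a switch step that backs up the live
        # vectors/graph/metadata directories before promoting the staging result
        # (rolling back from the backup if the switch fails).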

        script_path = self._resolve_convert_script()
        if not script_path.exists():
            await self._set_file_failed(task_id, file_record.file_id, f"转换脚本不存在: {script_path}")
            return

        runtime_ok, runtime_detail = await self._preflight_convert_runtime()
        if not runtime_ok:
            await self._set_file_failed(task_id, file_record.file_id, runtime_detail)
            await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, runtime_detail)
            return

        required_inputs = ["paragraph.parquet", "entity.parquet"]
        if not any((source_dir / name).exists() for name in required_inputs):
            await self._set_file_failed(
                task_id,
                file_record.file_id,
                f"输入目录缺少必要文件,至少需要其一: {', '.join(required_inputs)}",
            )
            return

        staging_root = self._resolve_staging_root()
        staging_root.mkdir(parents=True, exist_ok=True)
        staging_dir = staging_root / f"lpmm_convert_{task_id}"
        if staging_dir.exists():
            shutil.rmtree(staging_dir, ignore_errors=True)
        staging_dir.mkdir(parents=True, exist_ok=True)

        # 简单空间预检:至少保留 512MB
        usage = shutil.disk_usage(str(target_dir))
        if usage.free < 512 * 1024 * 1024:
            await self._set_file_failed(task_id, file_record.file_id, "磁盘剩余空间不足(<512MB)")
            return

        cmd = [
            sys.executable,
            str(script_path),
            "--input",
            str(source_dir),
            "--output",
            str(staging_dir),
            "--dim",
            str(params.get("dimension", 384)),
            "--batch-size",
            str(params.get("batch_size", 1024)),
        ]
        process = await asyncio.create_subprocess_exec(
            *cmd,
            cwd=str(self._resolve_repo_root()),
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout_lines: List[str] = []
        stderr_lines: List[str] = []

        async def _drain(stream: Optional[asyncio.StreamReader], target: List[str]) -> None:
            if stream is None:
                return
            while True:
                line = await stream.readline()
                if not line:
                    break
                text = line.decode("utf-8", errors="replace").strip()
                if text:
                    target.append(text)
                if len(target) > 120:
                    del target[:-120]

        drain_tasks = [
            asyncio.create_task(_drain(process.stdout, stdout_lines)),
            asyncio.create_task(_drain(process.stderr, stderr_lines)),
        ]

        cancelled = False
        return_code: Optional[int] = None
        try:
            while True:
                if await self._is_cancel_requested(task_id):
                    cancelled = True
                    await self._terminate_process(process)
                    break
                try:
                    return_code = await asyncio.wait_for(process.wait(), timeout=1.0)
                    break
                except asyncio.TimeoutError:
                    continue
        finally:
            await asyncio.gather(*drain_tasks, return_exceptions=True)

        if cancelled:
            await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消")
            await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消")
            return
        if return_code != 0:
            detail = (stderr_lines[-1] if stderr_lines else "") or (stdout_lines[-1] if stdout_lines else "")
            await self._set_file_failed(task_id, file_record.file_id, detail or f"转换失败,退出码: {return_code}")
            await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, detail or f"退出码: {return_code}")
            return

        await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "verifying", 0.65)
        verify = self._verify_convert_output(staging_dir)
        async with self._lock:
            t = self._tasks.get(task_id)
            if t:
                t.artifact_paths["staging_dir"] = str(staging_dir)
                t.artifact_paths["verify"] = json.dumps(verify, ensure_ascii=False)
        if not verify.get("ok"):
            await self._set_file_failed(task_id, file_record.file_id, f"校验失败: {verify}")
            await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"校验失败: {verify}")
return - - enable_switch = _coerce_bool(self._cfg("web.import.convert.enable_staging_switch", True), True) - if not enable_switch: - await self._set_file_failed(task_id, file_record.file_id, "未启用 staging 切换") - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, "未启用 staging 切换") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "switching", 0.85) - backup_root = self._resolve_backup_root() - backup_root.mkdir(parents=True, exist_ok=True) - backup_dir = backup_root / f"lpmm_convert_{task_id}_{int(_now())}" - backup_dir.mkdir(parents=True, exist_ok=True) - - switched = False - rollback_info: Dict[str, Any] = {"attempted": True, "restored": False, "error": ""} - moved_items: List[Tuple[Path, Path]] = [] - try: - for name in ("vectors", "graph", "metadata"): - src_current = target_dir / name - src_new = staging_dir / name - if not src_new.exists(): - raise RuntimeError(f"staging 缺少目录: {src_new}") - if src_current.exists(): - dst_backup = backup_dir / name - shutil.move(str(src_current), str(dst_backup)) - moved_items.append((dst_backup, src_current)) - shutil.move(str(src_new), str(src_current)) - switched = True - except Exception as switch_err: - rollback_info["error"] = str(switch_err) - # 尝试回滚 - for src_backup, dst_original in moved_items: - if src_backup.exists() and not dst_original.exists(): - try: - shutil.move(str(src_backup), str(dst_original)) - except Exception: - pass - rollback_info["restored"] = True - async with self._lock: - t = self._tasks.get(task_id) - if t: - t.rollback_info = rollback_info - await self._set_file_failed(task_id, file_record.file_id, f"切换失败并回滚: {switch_err}") - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"switch failed: {switch_err}") - return - - if switched: - async with self._lock: - t = self._tasks.get(task_id) - if t: - t.rollback_info = rollback_info - t.artifact_paths["backup_dir"] = str(backup_dir) - self._cleanup_old_backups() - try: - await self._reload_stores_after_external_migration() - except Exception as reload_err: - logger.warning(f"转换后重载存储失败: {reload_err}") - - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - async with self._lock: - t = self._tasks.get(task_id) - if not t: - return - f = self._find_file(t, file_record.file_id) - if not f: - return - f.total_chunks = 1 - f.done_chunks = 1 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 1.0 - f.status = "completed" - f.current_step = "completed" - f.updated_at = _now() - self._recompute_task_progress(t) - - async def _process_temporal_backfill(self, task_id: str, file_record: ImportFileRecord) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "temporal_backfill") - await self._set_file_state(task_id, file_record.file_id, "preparing", "backfilling") - chunk_id = await self._ensure_maibot_migration_chunk( - task_id, - file_record.file_id, - chunk_type="temporal_backfill", - preview="时序字段回填任务", - ) - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "backfilling", 0.2) - - task = self._tasks.get(task_id) - if not task: - await self._set_file_failed(task_id, file_record.file_id, "任务不存在") - return - params = dict(task.params) - target_dir = Path(file_record.source_path or "") - metadata_dir = target_dir / "metadata" - if not metadata_dir.exists(): - await self._set_file_failed(task_id, file_record.file_id, f"metadata 目录不存在: {metadata_dir}") - return - - dry_run = bool(params.get("dry_run")) - no_created_fallback = 
bool(params.get("no_created_fallback")) - limit = max(1, _coerce_int(params.get("limit"), 100000)) - - store = MetadataStore(data_dir=metadata_dir) - updated = 0 - candidates = 0 - try: - store.connect() - summary = store.backfill_temporal_metadata_from_created_at( - limit=limit, - dry_run=dry_run, - no_created_fallback=no_created_fallback, - ) - candidates = int(summary.get("candidates", 0)) - updated = int(summary.get("updated", 0)) - finally: - try: - store.close() - except Exception: - pass - - async with self._lock: - t = self._tasks.get(task_id) - if t: - t.artifact_paths["temporal_backfill"] = json.dumps( - { - "target_dir": str(target_dir), - "dry_run": dry_run, - "no_created_fallback": no_created_fallback, - "limit": limit, - "candidates": candidates, - "updated": updated, - }, - ensure_ascii=False, - ) - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - async with self._lock: - t = self._tasks.get(task_id) - if not t: - return - f = self._find_file(t, file_record.file_id) - if not f: - return - f.total_chunks = 1 - f.done_chunks = 1 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 1.0 - f.status = "completed" - f.current_step = "completed" - f.updated_at = _now() - self._recompute_task_progress(t) - - async def _process_file( - self, - task_id: str, - file_record: ImportFileRecord, - file_semaphore: asyncio.Semaphore, - chunk_semaphore: asyncio.Semaphore, - ) -> None: - async with file_semaphore: - await self._set_file_state(task_id, file_record.file_id, "preparing", "preparing") - if await self._is_cancel_requested(task_id): - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - try: - content = await self._read_file_content(file_record) - content_hash = hashlib.md5(content.encode("utf-8", errors="ignore")).hexdigest() - file_record.content_hash = content_hash - task = self._tasks.get(task_id) - if task: - dedupe_policy = str(task.params.get("dedupe_policy") or "none") - force = bool(task.params.get("force")) - if dedupe_policy != "none" and not force: - async with self._lock: - if self._is_manifest_hit(file_record, content_hash, dedupe_policy): - task2 = self._tasks.get(task_id) - if task2: - f = self._find_file(task2, file_record.file_id) - if f: - f.status = "completed" - f.current_step = "skipped" - f.progress = 1.0 - f.total_chunks = 0 - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.detected_strategy_type = "skipped" - f.error = "" - f.updated_at = _now() - self._recompute_task_progress(task2) - return - if file_record.input_mode == "json": - await self._process_json_file(task_id, file_record, content, chunk_semaphore) - else: - await self._process_text_file(task_id, file_record, content, chunk_semaphore) - task3 = self._tasks.get(task_id) - if task3: - dedupe_policy = str(task3.params.get("dedupe_policy") or "none") - f3 = self._find_file(task3, file_record.file_id) - if dedupe_policy != "none" and f3 and f3.status == "completed": - async with self._lock: - self._record_manifest_import(file_record, content_hash, dedupe_policy, task_id) - except Exception as e: - await self._set_file_failed(task_id, file_record.file_id, str(e)) - - async def _read_file_content(self, file_record: ImportFileRecord) -> str: - if file_record.inline_content is not None: - return file_record.inline_content - if file_record.source_path and Path(file_record.source_path).exists(): - data = Path(file_record.source_path).read_bytes() - try: - return data.decode("utf-8") - except UnicodeDecodeError: - return 
data.decode("utf-8", errors="replace") - if file_record.temp_path and Path(file_record.temp_path).exists(): - data = Path(file_record.temp_path).read_bytes() - try: - return data.decode("utf-8") - except UnicodeDecodeError: - return data.decode("utf-8", errors="replace") - raise RuntimeError("读取文件失败:输入内容缺失") - - async def _process_text_file( - self, - task_id: str, - file_record: ImportFileRecord, - content: str, - chunk_semaphore: asyncio.Semaphore, - ) -> None: - task = self._tasks[task_id] - async with self._lock: - t = self._tasks.get(task_id) - if t and not t.schema_detected: - t.schema_detected = "plain_text" - strategy = self._determine_strategy( - file_record.name, - content, - task.params["strategy_override"], - chat_log=bool(task.params.get("chat_log")), - ) - await self._set_file_strategy(task_id, file_record.file_id, strategy) - await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") - await self._ensure_embedding_runtime_ready() - - chunks = strategy.split(content) - selected_chunks = list(chunks) - if file_record.retry_mode == "chunk": - retry_index_set = set() - for idx in file_record.retry_chunk_indexes: - try: - retry_index_set.add(int(idx)) - except Exception: - continue - selected_chunks = [chunk for chunk in chunks if int(chunk.chunk.index) in retry_index_set] - if not selected_chunks: - raise RuntimeError("失败分块重试索引无效,未匹配到可执行分块") - logger.info( - "重试任务按失败分块执行: " - f"file={file_record.name} " - f"selected={len(selected_chunks)} " - f"total={len(chunks)}" - ) - - await self._register_chunks(task_id, file_record.file_id, selected_chunks) - - await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") - model_cfg = None - if task.params["llm_enabled"]: - model_cfg = await self._select_model() - - jobs = [] - for chunk in selected_chunks: - jobs.append( - asyncio.create_task( - self._process_text_chunk( - task_id=task_id, - file_record=file_record, - chunk=chunk, - strategy=strategy, - llm_enabled=task.params["llm_enabled"], - model_cfg=model_cfg, - chunk_semaphore=chunk_semaphore, - chat_log=bool(task.params.get("chat_log")), - chat_reference_time=str(task.params.get("chat_reference_time") or "").strip() or None, - ) - ) - ) - await asyncio.gather(*jobs, return_exceptions=True) - - if await self._is_cancel_requested(task_id): - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - await self._set_file_state(task_id, file_record.file_id, "saving", "saving") - async with self._storage_lock: - self.plugin.vector_store.save() - self.plugin.graph_store.save() - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_record.file_id) - if not f: - return - if f.failed_chunks > 0: - f.status = "failed" - f.current_step = "failed" - if not f.error: - f.error = f"存在失败分块: {f.failed_chunks}" - elif task.status == "cancel_requested": - f.status = "cancelled" - f.current_step = "cancelled" - else: - f.status = "completed" - f.current_step = "completed" - f.progress = 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - async def _process_text_chunk( - self, - task_id: str, - file_record: ImportFileRecord, - chunk: ProcessedChunk, - strategy: Any, - llm_enabled: bool, - model_cfg: Any, - chunk_semaphore: asyncio.Semaphore, - chat_log: bool = False, - chat_reference_time: Optional[str] = None, - ) -> None: - async with chunk_semaphore: - chunk_id = chunk.chunk.chunk_id - if await self._is_cancel_requested(task_id): - await 
self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "extracting", "extracting", 0.25) - - processed = chunk - rescue_strategy = self._chunk_rescue(chunk, file_record.name) - current_strategy = strategy - if rescue_strategy: - chunk.type = StrategyKnowledgeType.QUOTE - chunk.flags.verbatim = True - chunk.flags.requires_llm = False - current_strategy = rescue_strategy - try: - if llm_enabled and chunk.flags.requires_llm: - processed = await current_strategy.extract( - chunk, - lambda prompt: self._llm_call(prompt, model_cfg), - ) - elif chunk.type == StrategyKnowledgeType.QUOTE: - processed = await current_strategy.extract(chunk) - except Exception as e: - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"抽取失败: {e}") - return - - if await self._is_cancel_requested(task_id): - await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) - try: - time_meta = None - if chat_log and llm_enabled and model_cfg is not None: - time_meta = await self._extract_chat_time_meta_with_llm( - processed.chunk.text, - model_cfg, - reference_time=chat_reference_time, - ) - async with self._storage_lock: - await self._persist_processed_chunk(file_record, processed, time_meta=time_meta) - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - except Exception as e: - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") - - async def _process_json_file( - self, - task_id: str, - file_record: ImportFileRecord, - content: str, - chunk_semaphore: asyncio.Semaphore, - ) -> None: - await self._set_file_strategy(task_id, file_record.file_id, "json") - await self._set_file_state(task_id, file_record.file_id, "splitting", "splitting") - await self._ensure_embedding_runtime_ready() - - try: - data = json.loads(content) - except Exception as e: - raise RuntimeError(f"JSON 解析失败: {e}") - - schema = self._detect_json_schema(data) - async with self._lock: - task = self._tasks.get(task_id) - if task: - task.schema_detected = schema - task.updated_at = _now() - units = self._build_json_units(data, file_record.file_id, file_record.name, schema) - await self._register_json_units(task_id, file_record.file_id, units) - - await self._set_file_state(task_id, file_record.file_id, "extracting", "extracting") - jobs = [ - asyncio.create_task(self._process_json_unit(task_id, file_record, unit, chunk_semaphore)) - for unit in units - ] - await asyncio.gather(*jobs, return_exceptions=True) - - if await self._is_cancel_requested(task_id): - await self._set_file_cancelled(task_id, file_record.file_id, "任务已取消") - return - - await self._set_file_state(task_id, file_record.file_id, "saving", "saving") - async with self._storage_lock: - self.plugin.vector_store.save() - self.plugin.graph_store.save() - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_record.file_id) - if not f: - return - if f.failed_chunks > 0: - f.status = "failed" - f.current_step = "failed" - if not f.error: - f.error = f"存在失败分块: {f.failed_chunks}" - elif task.status == "cancel_requested": - f.status = "cancelled" - f.current_step = "cancelled" - else: - f.status = "completed" - f.current_step = "completed" - f.progress = 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - - def 
_detect_json_schema(self, data: Any) -> str: - if isinstance(data, dict) and isinstance(data.get("docs"), list): - return "lpmm_openie" - if isinstance(data, dict) and isinstance(data.get("paragraphs"), list): - paragraphs = data.get("paragraphs", []) - for p in paragraphs: - if isinstance(p, dict) and any( - key in p for key in ("entities", "relations", "time_meta", "source", "type", "knowledge_type") - ): - return "script_json" - return "web_json" - raise RuntimeError("不支持的 JSON 格式:需要 paragraphs 或 docs") - - def _build_json_units(self, data: Any, file_id: str, filename: str, schema: str) -> List[Dict[str, Any]]: - units: List[Dict[str, Any]] = [] - paragraphs: List[Any] = [] - entities: List[Any] = [] - relations: List[Any] = [] - - if schema in {"web_json", "script_json"}: - paragraphs = data.get("paragraphs", []) - entities = data.get("entities", []) - relations = data.get("relations", []) - elif schema == "lpmm_openie": - docs = data.get("docs", []) - for d in docs: - if not isinstance(d, dict): - continue - content = str(d.get("passage", "") or "").strip() - if not content: - continue - triples = d.get("extracted_triples", []) or [] - rels = [] - for t in triples: - if isinstance(t, list) and len(t) == 3: - rels.append( - { - "subject": str(t[0]), - "predicate": str(t[1]), - "object": str(t[2]), - } - ) - para_item = { - "content": content, - "source": f"lpmm_openie:{filename}", - "entities": d.get("extracted_entities", []) or [], - "relations": rels, - "knowledge_type": "factual", - } - paragraphs.append(para_item) - - for p in paragraphs: - paragraph = normalize_paragraph_import_item( - p, - default_source=f"web_import:{filename}", - ) - units.append( - { - "chunk_id": f"{file_id}_json_{len(units)}", - "kind": "paragraph", - "content": paragraph["content"], - "time_meta": paragraph["time_meta"], - "knowledge_type": paragraph["knowledge_type"], - "chunk_type": paragraph["knowledge_type"], - "source": paragraph["source"], - "entities": paragraph["entities"], - "relations": paragraph["relations"], - "preview": paragraph["content"][:120], - } - ) - - for e in entities: - name = str(e or "").strip() - if name: - units.append( - { - "chunk_id": f"{file_id}_json_{len(units)}", - "kind": "entity", - "name": name, - "chunk_type": "entity", - "preview": name[:120], - } - ) - - for r in relations: - if not isinstance(r, dict): - continue - s = str(r.get("subject", "")).strip() - p = str(r.get("predicate", "")).strip() - o = str(r.get("object", "")).strip() - if s and p and o: - units.append( - { - "chunk_id": f"{file_id}_json_{len(units)}", - "kind": "relation", - "subject": s, - "predicate": p, - "object": o, - "chunk_type": "relation", - "preview": f"{s} {p} {o}"[:120], - } - ) - return units - - async def _register_json_units(self, task_id: str, file_id: str, units: List[Dict[str, Any]]) -> None: - records = [ - ImportChunkRecord( - chunk_id=u["chunk_id"], - index=i, - chunk_type=u.get("chunk_type", "json"), - status="queued", - step="queued", - progress=0.0, - content_preview=str(u.get("preview", "")), - ) - for i, u in enumerate(units) - ] - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.chunks = records - f.total_chunks = len(records) - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 0.0 if records else 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _process_json_unit( - self, - task_id: str, - file_record: ImportFileRecord, - 
unit: Dict[str, Any], - chunk_semaphore: asyncio.Semaphore, - ) -> None: - chunk_id = unit["chunk_id"] - async with chunk_semaphore: - if await self._is_cancel_requested(task_id): - await self._set_chunk_cancelled(task_id, file_record.file_id, chunk_id, "任务已取消") - return - - await self._set_chunk_state(task_id, file_record.file_id, chunk_id, "writing", "writing", 0.7) - try: - async with self._storage_lock: - kind = unit["kind"] - if kind == "paragraph": - content = str(unit.get("content", "")) - k_type = resolve_stored_knowledge_type( - unit.get("knowledge_type"), - content=content, - ).value - source = str(unit.get("source") or f"web_import:{file_record.name}") - para_hash = self.plugin.metadata_store.add_paragraph( - content=content, - source=source, - knowledge_type=k_type, - time_meta=unit.get("time_meta"), - ) - emb = await self.plugin.embedding_manager.encode(content) - try: - self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) - except ValueError: - pass - for name in unit.get("entities", []) or []: - n = str(name or "").strip() - if n: - await self._add_entity_with_vector(n, source_paragraph=para_hash) - for rel in unit.get("relations", []) or []: - if not isinstance(rel, dict): - continue - s = str(rel.get("subject", "")).strip() - p = str(rel.get("predicate", "")).strip() - o = str(rel.get("object", "")).strip() - if s and p and o: - await self._add_relation(s, p, o, source_paragraph=para_hash) - elif kind == "entity": - await self._add_entity_with_vector(unit["name"]) - elif kind == "relation": - await self._add_relation(unit["subject"], unit["predicate"], unit["object"]) - else: - raise RuntimeError(f"未知 JSON 导入单元类型: {kind}") - await self._set_chunk_completed(task_id, file_record.file_id, chunk_id) - except Exception as e: - await self._set_chunk_failed(task_id, file_record.file_id, chunk_id, f"写入失败: {e}") - - def _source_label(self, file_record: ImportFileRecord) -> str: - if file_record.source_path: - return f"{file_record.source_kind}:{file_record.source_path}" - return f"web_import:{file_record.name}" - - async def _ensure_embedding_runtime_ready(self) -> None: - report = await ensure_runtime_self_check(self.plugin) - if bool(report.get("ok", False)): - return - raise RuntimeError( - "embedding runtime self-check failed: " - f"{report.get('message', 'unknown')} " - f"(configured={report.get('configured_dimension', 0)}, " - f"store={report.get('vector_store_dimension', 0)}, " - f"encoded={report.get('encoded_dimension', 0)})" - ) - - async def _persist_processed_chunk( - self, - file_record: ImportFileRecord, - processed: ProcessedChunk, - *, - time_meta: Optional[Dict[str, Any]] = None, - ) -> None: - content = processed.chunk.text - para_hash = self.plugin.metadata_store.add_paragraph( - content=content, - source=self._source_label(file_record), - knowledge_type=_storage_type_from_strategy(processed.type), - time_meta=time_meta, - ) - - emb = await self.plugin.embedding_manager.encode(content) - try: - self.plugin.vector_store.add(emb.reshape(1, -1), [para_hash]) - except ValueError: - pass - - data = processed.data or {} - entities: List[str] = [] - relations: List[Tuple[str, str, str]] = [] - - for triple in data.get("triples", []): - s = str(triple.get("subject", "")).strip() - p = str(triple.get("predicate", "")).strip() - o = str(triple.get("object", "")).strip() - if s and p and o: - relations.append((s, p, o)) - entities.extend([s, o]) - - for rel in data.get("relations", []): - s = str(rel.get("subject", "")).strip() - p = str(rel.get("predicate", 
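Editor's note: paragraph ingestion above treats the metadata store as the source of truth for the hash and tolerates a ValueError from the vector store as "already indexed", which keeps re-imports idempotent. The sketch below shows that duplicate-tolerant write in isolation; the tiny store and encoder objects are assumed stand-ins, not the plugin's VectorStore.

import numpy as np

class TinyVectorStore:
    def __init__(self) -> None:
        self._vectors = {}

    def add(self, vectors: np.ndarray, ids: list) -> int:
        for vec, vid in zip(vectors, ids):
            if vid in self._vectors:
                raise ValueError(f"duplicate id: {vid}")
            self._vectors[vid] = vec
        return len(ids)

def index_paragraph(store: TinyVectorStore, para_hash: str, embedding: np.ndarray) -> None:
    try:
        store.add(embedding.reshape(1, -1), [para_hash])
    except ValueError:
        pass  # already present: second import of the same paragraph is a no-op

store = TinyVectorStore()
emb = np.zeros(8, dtype=np.float32)
index_paragraph(store, "p1", emb)
index_paragraph(store, "p1", emb)  # tolerated duplicate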
"")).strip() - o = str(rel.get("object", "")).strip() - if s and p and o: - relations.append((s, p, o)) - entities.extend([s, o]) - - for k in ("entities", "events", "verbatim_entities"): - for e in data.get(k, []): - name = str(e or "").strip() - if name: - entities.append(name) - - uniq_entities = list({x.strip().lower(): x.strip() for x in entities if str(x).strip()}.values()) - for name in uniq_entities: - await self._add_entity_with_vector(name, source_paragraph=para_hash) - - for s, p, o in relations: - await self._add_relation(s, p, o, source_paragraph=para_hash) - - async def _add_entity_with_vector(self, name: str, source_paragraph: str = "") -> str: - hash_value = self.plugin.metadata_store.add_entity(name=name, source_paragraph=source_paragraph) - self.plugin.graph_store.add_nodes([name]) - if hash_value not in self.plugin.vector_store: - emb = await self.plugin.embedding_manager.encode(name) - try: - self.plugin.vector_store.add(emb.reshape(1, -1), [hash_value]) - except ValueError: - pass - return hash_value - - async def _add_relation(self, subject: str, predicate: str, obj: str, source_paragraph: str = "") -> str: - await self._add_entity_with_vector(subject, source_paragraph=source_paragraph) - await self._add_entity_with_vector(obj, source_paragraph=source_paragraph) - rv_cfg = self.plugin.get_config("retrieval.relation_vectorization", {}) or {} - if not isinstance(rv_cfg, dict): - rv_cfg = {} - write_vector = bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) - - relation_service = getattr(self.plugin, "relation_write_service", None) - if relation_service is not None: - result = await relation_service.upsert_relation_with_vector( - subject=subject, - predicate=predicate, - obj=obj, - confidence=1.0, - source_paragraph=source_paragraph, - write_vector=write_vector, - ) - return result.hash_value - - rel_hash = self.plugin.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - source_paragraph=source_paragraph, - confidence=1.0, - ) - self.plugin.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) - try: - self.plugin.metadata_store.set_relation_vector_state(rel_hash, "none") - except Exception: - pass - return rel_hash - async def _select_model(self) -> Any: - models = llm_api.get_available_models() - if not models: - raise RuntimeError("没有可用 LLM 模型") - - config_model = str(self._cfg("advanced.extraction_model", "auto") or "auto").strip() - if config_model.lower() != "auto" and config_model in models: - return models[config_model] - - for task_name in [ - "lpmm_entity_extract", - "lpmm_rdf_build", - "embedding", - "replyer", - "utils", - "planner", - "tool_use", - ]: - if task_name in models: - return models[task_name] - - return models[next(iter(models))] - - async def _llm_call(self, prompt: str, model_config: Any) -> Dict[str, Any]: - cfg = self._llm_retry_config() - retries = int(cfg["retries"]) - last_error: Optional[Exception] = None - for attempt in range(retries + 1): - try: - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type="A_Memorix.WebImport", - ) - if not success or not response: - raise RuntimeError("LLM 生成失败") - - txt = str(response or "").strip() - if "```" in txt: - txt = txt.split("```json")[-1].split("```")[0].strip() - if txt.startswith("json"): - txt = txt[4:].strip() - - try: - return json.loads(txt) - except Exception: - s = txt.find("{") - e = txt.rfind("}") - if s >= 0 and e > s: - return 
json.loads(txt[s : e + 1]) - raise - except Exception as err: - last_error = err - if attempt >= retries: - break - delay = min(cfg["max_wait"], cfg["min_wait"] * (cfg["multiplier"] ** attempt)) - await asyncio.sleep(max(0.0, float(delay))) - raise RuntimeError(f"LLM 抽取失败: {last_error}") - - def _parse_reference_time(self, value: Optional[str]) -> datetime: - if not value: - return datetime.now() - text = str(value).strip() - formats = [ - "%Y/%m/%d %H:%M:%S", - "%Y/%m/%d %H:%M", - "%Y-%m-%d %H:%M:%S", - "%Y-%m-%d %H:%M", - "%Y/%m/%d", - "%Y-%m-%d", - ] - for fmt in formats: - try: - return datetime.strptime(text, fmt) - except ValueError: - continue - return datetime.now() - - async def _extract_chat_time_meta_with_llm( - self, - text: str, - model_config: Any, - *, - reference_time: Optional[str] = None, - ) -> Optional[Dict[str, Any]]: - if not str(text or "").strip(): - return None - ref_dt = self._parse_reference_time(reference_time) - reference_now = ref_dt.strftime("%Y/%m/%d %H:%M") - prompt = f"""You are a time extraction engine for chat logs. -Extract temporal information from the following chat paragraph. - -Rules: -1. Use semantic understanding, not regex matching. -2. Convert relative expressions to absolute local datetime using reference_now. -3. If a time span exists, return event_time_start/event_time_end. -4. If only one point in time exists, return event_time. -5. If no reliable time info exists, keep all event_time fields null. -6. Return JSON only. - -reference_now: {reference_now} -text: -{text} - -JSON schema: -{{ - "event_time": null, - "event_time_start": null, - "event_time_end": null, - "time_range": null, - "time_granularity": null, - "time_confidence": 0.0 -}} -""" - try: - result = await self._llm_call(prompt, model_config) - except Exception as e: - logger.warning(f"chat_log 时间语义抽取失败: {e}") - return None - - raw_time_meta = { - "event_time": result.get("event_time"), - "event_time_start": result.get("event_time_start"), - "event_time_end": result.get("event_time_end"), - "time_range": result.get("time_range"), - "time_granularity": result.get("time_granularity"), - "time_confidence": result.get("time_confidence"), - } - try: - normalized = normalize_time_meta(raw_time_meta) - except Exception: - return None - has_effective = any(k in normalized for k in ("event_time", "event_time_start", "event_time_end")) - if not has_effective: - return None - return normalized - - def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[Any]: - if chunk.type == StrategyKnowledgeType.QUOTE: - return None - if looks_like_quote_text(chunk.chunk.text): - return QuoteStrategy(filename) - return None - - def _instantiate_strategy(self, filename: str, strategy: ImportStrategy) -> Any: - if strategy == ImportStrategy.FACTUAL: - return FactualStrategy(filename) - if strategy == ImportStrategy.QUOTE: - return QuoteStrategy(filename) - return NarrativeStrategy(filename) - - def _determine_strategy(self, filename: str, content: str, override: str, *, chat_log: bool = False) -> Any: - strategy = select_import_strategy( - content, - override=override, - chat_log=chat_log, - ) - return self._instantiate_strategy(filename, strategy) - - async def _set_file_strategy(self, task_id: str, file_id: str, strategy: Any) -> None: - if isinstance(strategy, str): - strategy_type = strategy - elif isinstance(strategy, NarrativeStrategy): - strategy_type = "narrative" - elif isinstance(strategy, FactualStrategy): - strategy_type = "factual" - elif isinstance(strategy, QuoteStrategy): - 
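Editor's note: the retry loop above sleeps min_wait * multiplier**attempt seconds between attempts, capped at max_wait. A worked example of that delay schedule, using assumed defaults (min_wait=1, multiplier=2, max_wait=10) rather than the plugin's configured values:

def backoff_delay(attempt: int, min_wait: float = 1.0, multiplier: float = 2.0, max_wait: float = 10.0) -> float:
    return max(0.0, min(max_wait, min_wait * (multiplier ** attempt)))

print([backoff_delay(a) for a in range(6)])
# -> [1.0, 2.0, 4.0, 8.0, 10.0, 10.0]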
strategy_type = "quote" - else: - strategy_type = "unknown" - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.detected_strategy_type = strategy_type - f.updated_at = _now() - task.updated_at = _now() - - async def _register_chunks(self, task_id: str, file_id: str, chunks: List[ProcessedChunk]) -> None: - records = [ - ImportChunkRecord( - chunk_id=chunk.chunk.chunk_id, - index=index, - chunk_type=chunk.type.value, - status="queued", - step="queued", - progress=0.0, - content_preview=str(chunk.chunk.text or "")[:120], - ) - for index, chunk in enumerate(chunks) - ] - - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.chunks = records - f.total_chunks = len(records) - f.done_chunks = 0 - f.failed_chunks = 0 - f.cancelled_chunks = 0 - f.progress = 0.0 if records else 1.0 - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_file_state(self, task_id: str, file_id: str, status: str, step: str) -> None: - if status not in FILE_STATUS: - return - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.status = status - f.current_step = step - f.updated_at = _now() - task.updated_at = _now() - if step in {"preparing", "splitting", "extracting", "writing", "saving"} and task.status in {"queued", "preparing"}: - task.status = "running" - task.current_step = "running" - - async def _set_file_failed(self, task_id: str, file_id: str, error: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.status = "failed" - f.current_step = "failed" - f.error = str(error) - f.updated_at = _now() - task.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_file_cancelled(self, task_id: str, file_id: str, reason: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - f.status = "cancelled" - f.current_step = "cancelled" - f.error = reason - additional_cancelled = 0 - for chunk in f.chunks: - if chunk.status in {"completed", "failed", "cancelled"}: - continue - chunk.status = "cancelled" - chunk.step = "cancelled" - chunk.retryable = False - chunk.error = reason - chunk.progress = 1.0 - chunk.updated_at = _now() - additional_cancelled += 1 - if additional_cancelled > 0: - f.cancelled_chunks += additional_cancelled - f.progress = self._compute_ratio( - f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks - ) - f.updated_at = _now() - task.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_chunk_state( - self, - task_id: str, - file_id: str, - chunk_id: str, - status: str, - step: str, - progress: float, - ) -> None: - if status not in CHUNK_STATUS: - return - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c: - return - c.status = status - c.step = step - if status in {"queued", "extracting", "writing"}: - c.error = "" - c.failed_at = "" - c.retryable = False - c.progress = max(0.0, min(1.0, float(progress))) - c.updated_at = _now() - if f.status not in {"failed", "cancelled"}: - f.status = "extracting" if status == 
"extracting" else "writing" - f.current_step = step - f.updated_at = _now() - task.updated_at = _now() - - async def _set_chunk_completed(self, task_id: str, file_id: str, chunk_id: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c or c.status == "completed": - return - c.status = "completed" - c.step = "completed" - c.failed_at = "" - c.retryable = False - c.progress = 1.0 - c.updated_at = _now() - f.done_chunks += 1 - f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_chunk_failed(self, task_id: str, file_id: str, chunk_id: str, error: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c or c.status == "failed": - return - failed_stage = str(c.step or "").strip().lower() - if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: - failed_stage = str(f.current_step or "").strip().lower() - if failed_stage in {"", "queued", "failed", "completed", "cancelled"}: - failed_stage = "unknown" - c.status = "failed" - c.step = "failed" - c.failed_at = failed_stage - c.retryable = bool(f.input_mode == "text" and failed_stage == "extracting") - c.error = str(error) - c.progress = 1.0 - c.updated_at = _now() - f.failed_chunks += 1 - f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) - if not f.error: - f.error = str(error) - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _set_chunk_cancelled(self, task_id: str, file_id: str, chunk_id: str, reason: str) -> None: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return - f = self._find_file(task, file_id) - if not f: - return - c = self._find_chunk(f, chunk_id) - if not c or c.status == "cancelled": - return - c.status = "cancelled" - c.step = "cancelled" - c.retryable = False - c.error = reason - c.progress = 1.0 - c.updated_at = _now() - f.cancelled_chunks += 1 - f.progress = self._compute_ratio(f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks) - f.updated_at = _now() - self._recompute_task_progress(task) - - async def _is_cancel_requested(self, task_id: str) -> bool: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return True - return task.status == "cancel_requested" - - def _find_file(self, task: ImportTaskRecord, file_id: str) -> Optional[ImportFileRecord]: - for f in task.files: - if f.file_id == file_id: - return f - return None - - def _find_chunk(self, file_record: ImportFileRecord, chunk_id: str) -> Optional[ImportChunkRecord]: - for c in file_record.chunks: - if c.chunk_id == chunk_id: - return c - return None - - def _compute_ratio(self, done: int, total: int) -> float: - if total <= 0: - return 1.0 - return max(0.0, min(1.0, float(done) / float(total))) - - def _recompute_task_progress(self, task: ImportTaskRecord) -> None: - total = 0 - done = 0 - failed = 0 - cancelled = 0 - for f in task.files: - total += f.total_chunks - done += f.done_chunks - failed += f.failed_chunks - cancelled += f.cancelled_chunks - task.total_chunks = total - task.done_chunks = done - task.failed_chunks = failed - task.cancelled_chunks = cancelled - task.progress = 
self._compute_ratio(done + failed + cancelled, total) - task.updated_at = _now() - - async def _should_cleanup_task_temp(self, task_id: str) -> bool: - async with self._lock: - task = self._tasks.get(task_id) - if not task: - return True - for f in task.files: - if f.status == "failed": - return False - return True - - def _mark_task_cancelled_locked(self, task: ImportTaskRecord, reason: str) -> None: - for f in task.files: - if f.status in {"completed", "failed", "cancelled"}: - continue - f.status = "cancelled" - f.current_step = "cancelled" - f.error = reason - additional_cancelled = 0 - for c in f.chunks: - if c.status in {"completed", "failed", "cancelled"}: - continue - c.status = "cancelled" - c.step = "cancelled" - c.retryable = False - c.error = reason - c.progress = 1.0 - c.updated_at = _now() - additional_cancelled += 1 - if additional_cancelled > 0: - f.cancelled_chunks += additional_cancelled - f.progress = self._compute_ratio( - f.done_chunks + f.failed_chunks + f.cancelled_chunks, f.total_chunks - ) - f.updated_at = _now() - task.status = "cancelled" - task.current_step = "cancelled" - task.finished_at = _now() - task.updated_at = _now() - self._recompute_task_progress(task) diff --git a/plugins/A_memorix/plugin.py b/plugins/A_memorix/plugin.py deleted file mode 100644 index 841106a4..00000000 --- a/plugins/A_memorix/plugin.py +++ /dev/null @@ -1,273 +0,0 @@ -"""A_Memorix SDK plugin entry.""" - -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, List, Optional - -from maibot_sdk import MaiBotPlugin, Tool -from maibot_sdk.types import ToolParameterInfo, ToolParamType - -from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel - - -def _tool_param(name: str, param_type: ToolParamType, description: str, required: bool) -> ToolParameterInfo: - return ToolParameterInfo(name=name, param_type=param_type, description=description, required=required) - - -_ADMIN_TOOL_PARAMS = [ - _tool_param("action", ToolParamType.STRING, "管理动作", True), - _tool_param("target", ToolParamType.STRING, "可选目标标识", False), -] - - -class AMemorixPlugin(MaiBotPlugin): - def __init__(self) -> None: - super().__init__() - self._plugin_root = Path(__file__).resolve().parent - self._plugin_config: Dict[str, Any] = {} - self._kernel: Optional[SDKMemoryKernel] = None - - def set_plugin_config(self, config: Dict[str, Any]) -> None: - self._plugin_config = config or {} - if self._kernel is not None: - self._kernel.close() - self._kernel = None - - async def on_load(self): - await self._get_kernel() - - async def on_unload(self): - if self._kernel is not None: - shutdown = getattr(self._kernel, "shutdown", None) - if callable(shutdown): - await shutdown() - else: - self._kernel.close() - self._kernel = None - - async def _get_kernel(self) -> SDKMemoryKernel: - if self._kernel is None: - self._kernel = SDKMemoryKernel(plugin_root=self._plugin_root, config=self._plugin_config) - await self._kernel.initialize() - return self._kernel - - async def _dispatch_admin_tool(self, method_name: str, action: str, **kwargs): - kernel = await self._get_kernel() - handler = getattr(kernel, method_name) - return await handler(action=action, **kwargs) - - @Tool( - "search_memory", - description="搜索长期记忆", - parameters=[ - _tool_param("query", ToolParamType.STRING, "查询文本", False), - _tool_param("limit", ToolParamType.INTEGER, "返回条数", False), - _tool_param("mode", ToolParamType.STRING, "search/time/hybrid/episode/aggregate", False), - _tool_param("chat_id", 
ToolParamType.STRING, "聊天流 ID", False), - _tool_param("person_id", ToolParamType.STRING, "人物 ID", False), - _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), - _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), - _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), - ], - ) - async def handle_search_memory( - self, - query: str = "", - limit: int = 5, - mode: str = "search", - chat_id: str = "", - person_id: str = "", - time_start: str | float | None = None, - time_end: str | float | None = None, - respect_filter: bool = True, - **kwargs, - ): - kernel = await self._get_kernel() - return await kernel.search_memory( - KernelSearchRequest( - query=query, - limit=limit, - mode=mode, - chat_id=chat_id, - person_id=person_id, - time_start=time_start, - time_end=time_end, - respect_filter=respect_filter, - user_id=str(kwargs.get("user_id", "") or "").strip(), - group_id=str(kwargs.get("group_id", "") or "").strip(), - ) - ) - - @Tool( - "ingest_summary", - description="写入聊天摘要到长期记忆", - parameters=[ - _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), - _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", True), - _tool_param("text", ToolParamType.STRING, "摘要文本", True), - _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), - _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), - _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), - ], - ) - async def handle_ingest_summary( - self, - external_id: str, - chat_id: str, - text: str, - participants: Optional[List[str]] = None, - time_start: float | None = None, - time_end: float | None = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - respect_filter: bool = True, - **kwargs, - ): - kernel = await self._get_kernel() - return await kernel.ingest_summary( - external_id=external_id, - chat_id=chat_id, - text=text, - participants=participants, - time_start=time_start, - time_end=time_end, - tags=tags, - metadata=metadata, - respect_filter=respect_filter, - user_id=str(kwargs.get("user_id", "") or "").strip(), - group_id=str(kwargs.get("group_id", "") or "").strip(), - ) - - @Tool( - "ingest_text", - description="写入普通长期记忆文本", - parameters=[ - _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), - _tool_param("source_type", ToolParamType.STRING, "来源类型", True), - _tool_param("text", ToolParamType.STRING, "原始文本", True), - _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), - _tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False), - _tool_param("respect_filter", ToolParamType.BOOLEAN, "是否应用聊天过滤配置", False), - ], - ) - async def handle_ingest_text( - self, - external_id: str, - source_type: str, - text: str, - chat_id: str = "", - person_ids: Optional[List[str]] = None, - participants: Optional[List[str]] = None, - timestamp: float | None = None, - time_start: float | None = None, - time_end: float | None = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - respect_filter: bool = True, - **kwargs, - ): - relations = kwargs.get("relations") - entities = kwargs.get("entities") - kernel = await self._get_kernel() - return await kernel.ingest_text( - external_id=external_id, - source_type=source_type, - text=text, - chat_id=chat_id, - person_ids=person_ids, - participants=participants, - timestamp=timestamp, - time_start=time_start, - time_end=time_end, - tags=tags, - metadata=metadata, - entities=entities, - relations=relations, - 
respect_filter=respect_filter, - user_id=str(kwargs.get("user_id", "") or "").strip(), - group_id=str(kwargs.get("group_id", "") or "").strip(), - ) - - @Tool( - "get_person_profile", - description="获取人物画像", - parameters=[ - _tool_param("person_id", ToolParamType.STRING, "人物 ID", True), - _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), - _tool_param("limit", ToolParamType.INTEGER, "证据条数", False), - ], - ) - async def handle_get_person_profile(self, person_id: str, chat_id: str = "", limit: int = 10, **kwargs): - _ = kwargs - kernel = await self._get_kernel() - return await kernel.get_person_profile(person_id=person_id, chat_id=chat_id, limit=limit) - - @Tool( - "maintain_memory", - description="维护长期记忆关系状态", - parameters=[ - _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore/freeze/recycle_bin", True), - _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", False), - _tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False), - _tool_param("limit", ToolParamType.INTEGER, "查询条数(用于 recycle_bin)", False), - ], - ) - async def handle_maintain_memory( - self, - action: str, - target: str = "", - hours: float | None = None, - reason: str = "", - limit: int = 50, - **kwargs, - ): - _ = kwargs - kernel = await self._get_kernel() - return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason, limit=limit) - - @Tool("memory_stats", description="获取长期记忆统计", parameters=[]) - async def handle_memory_stats(self, **kwargs): - _ = kwargs - kernel = await self._get_kernel() - return kernel.memory_stats() - - @Tool("memory_graph_admin", description="长期记忆图谱管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_graph_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_graph_admin", action=action, **kwargs) - - @Tool("memory_source_admin", description="长期记忆来源管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_source_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_source_admin", action=action, **kwargs) - - @Tool("memory_episode_admin", description="Episode 管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_episode_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_episode_admin", action=action, **kwargs) - - @Tool("memory_profile_admin", description="人物画像管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_profile_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_profile_admin", action=action, **kwargs) - - @Tool("memory_runtime_admin", description="长期记忆运行时管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_runtime_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_runtime_admin", action=action, **kwargs) - - @Tool("memory_import_admin", description="长期记忆导入管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_import_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_import_admin", action=action, **kwargs) - - @Tool("memory_tuning_admin", description="长期记忆调优管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_tuning_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_tuning_admin", action=action, **kwargs) - - @Tool("memory_v5_admin", description="长期记忆 V5 管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_v5_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_v5_admin", action=action, 
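Editor's note: the *_admin tools above all funnel through one dispatcher — the kernel method name is resolved with getattr and the action plus any extra kwargs are forwarded untouched. A stripped-down, purely illustrative version (FakeKernel and its method are assumptions, not the plugin's SDKMemoryKernel):

import asyncio

class FakeKernel:
    async def memory_graph_admin(self, action: str, **kwargs) -> dict:
        return {"handled_by": "memory_graph_admin", "action": action, "extra": kwargs}

async def dispatch_admin_tool(kernel: FakeKernel, method_name: str, action: str, **kwargs) -> dict:
    handler = getattr(kernel, method_name)
    return await handler(action=action, **kwargs)

print(asyncio.run(dispatch_admin_tool(FakeKernel(), "memory_graph_admin", "stats", target="node-1")))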
**kwargs) - - @Tool("memory_delete_admin", description="长期记忆删除管理接口", parameters=_ADMIN_TOOL_PARAMS) - async def handle_memory_delete_admin(self, action: str, **kwargs): - return await self._dispatch_admin_tool("memory_delete_admin", action=action, **kwargs) - - -def create_plugin(): - return AMemorixPlugin() diff --git a/plugins/A_memorix/requirements.txt b/plugins/A_memorix/requirements.txt deleted file mode 100644 index f737fdf4..00000000 --- a/plugins/A_memorix/requirements.txt +++ /dev/null @@ -1,52 +0,0 @@ -# A_Memorix 插件依赖 -# -# 核心依赖 (必需) -# ================== - -# 数值计算 - 用于向量操作、矩阵计算 -numpy>=1.20.0 - -# 稀疏矩阵 - 用于图存储的邻接矩阵 -scipy>=1.7.0 - -# 图结构处理(LPMM 转换) -networkx>=3.0.0 - -# Parquet 读取(LPMM 转换) -pyarrow>=10.0.0 - -# DataFrame 处理(LPMM 转换) -pandas>=1.5.0 - -# 异步事件循环嵌套 - 用于插件初始化时的异步操作 -nest-asyncio>=1.5.0 - -# 向量索引 - 用于向量存储和检索 -faiss-cpu>=1.7.0 - -# Web 服务器依赖 (可视化功能需要) -# ================== - -# ASGI 服务器 -uvicorn>=0.20.0 - -# Web 框架 -fastapi>=0.100.0 - -# 数据验证 -pydantic>=2.0.0 -python-multipart>=0.0.9 - -# 注意事项 -# ================== -# -# 1. sqlite3 是 Python 标准库,无需安装 -# 2. json, re, time, pathlib 等都是标准库 -# 3. sentence-transformers 不需要(使用主程序 Embedding API) - -# UI 交互 -rich>=14.0.0 -tenacity>=8.0.0 - -# 稀疏检索中文分词(可选,未安装时自动回退 char n-gram) -jieba>=0.42.1 diff --git a/plugins/A_memorix/scripts/audit_vector_consistency.py b/plugins/A_memorix/scripts/audit_vector_consistency.py deleted file mode 100644 index c97806dc..00000000 --- a/plugins/A_memorix/scripts/audit_vector_consistency.py +++ /dev/null @@ -1,213 +0,0 @@ -#!/usr/bin/env python3 -""" -A_Memorix 一致性审计脚本。 - -输出内容: -1. paragraph/entity/relation 向量覆盖率 -2. relation vector_state 分布 -3. 孤儿向量数量(向量存在但 metadata 不存在) -4. 状态与向量文件不一致统计 -""" - -from __future__ import annotations - -import argparse -import json -import pickle -import sys -from pathlib import Path -from typing import Any, Dict, Set - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="审计 A_Memorix 向量一致性") - parser.add_argument( - "--data-dir", - default=str(PLUGIN_ROOT / "data"), - help="A_Memorix 数据目录(默认: plugins/A_memorix/data)", - ) - parser.add_argument("--json-out", default="", help="可选:输出 JSON 文件路径") - parser.add_argument( - "--strict", - action="store_true", - help="若发现一致性异常则返回非 0 退出码", - ) - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - sys.exit(0) - -try: - from core.storage.vector_store import VectorStore - from core.storage.metadata_store import MetadataStore - from core.storage import QuantizationType -except Exception as e: # pragma: no cover - print(f"❌ 导入核心模块失败: {e}") - sys.exit(1) - - -def _safe_ratio(numerator: int, denominator: int) -> float: - if denominator <= 0: - return 0.0 - return float(numerator) / float(denominator) - - -def _load_vector_store(data_dir: Path) -> VectorStore: - meta_path = data_dir / "vectors" / "vectors_metadata.pkl" - if not meta_path.exists(): - raise FileNotFoundError(f"未找到向量元数据文件: {meta_path}") - - with open(meta_path, "rb") as f: - meta = pickle.load(f) - dimension = int(meta.get("dimension", 1024)) - - store = VectorStore( - dimension=max(1, dimension), - quantization_type=QuantizationType.INT8, - data_dir=data_dir / "vectors", - ) - if store.has_data(): - 
store.load() - return store - - -def _load_metadata_store(data_dir: Path) -> MetadataStore: - store = MetadataStore(data_dir=data_dir / "metadata") - store.connect() - return store - - -def _hash_set(metadata_store: MetadataStore, table: str) -> Set[str]: - return {str(h) for h in metadata_store.list_hashes(table)} - - -def _relation_state_stats(metadata_store: MetadataStore) -> Dict[str, int]: - return metadata_store.count_relations_by_vector_state() - - -def run_audit(data_dir: Path) -> Dict[str, Any]: - vector_store = _load_vector_store(data_dir) - metadata_store = _load_metadata_store(data_dir) - try: - paragraph_hashes = _hash_set(metadata_store, "paragraphs") - entity_hashes = _hash_set(metadata_store, "entities") - relation_hashes = _hash_set(metadata_store, "relations") - - known_hashes = set(getattr(vector_store, "_known_hashes", set())) - live_vector_hashes = {h for h in known_hashes if h in vector_store} - - para_vector_hits = len(paragraph_hashes & live_vector_hashes) - ent_vector_hits = len(entity_hashes & live_vector_hashes) - rel_vector_hits = len(relation_hashes & live_vector_hashes) - - orphan_vector_hashes = sorted( - live_vector_hashes - paragraph_hashes - entity_hashes - relation_hashes - ) - - relation_rows = metadata_store.get_relations() - ready_but_missing = 0 - not_ready_but_present = 0 - for row in relation_rows: - h = str(row.get("hash") or "") - state = str(row.get("vector_state") or "none").lower() - in_vector = h in live_vector_hashes - if state == "ready" and not in_vector: - ready_but_missing += 1 - if state != "ready" and in_vector: - not_ready_but_present += 1 - - relation_states = _relation_state_stats(metadata_store) - rel_total = max(0, int(relation_states.get("total", len(relation_hashes)))) - ready_count = max(0, int(relation_states.get("ready", 0))) - - result = { - "counts": { - "paragraphs": len(paragraph_hashes), - "entities": len(entity_hashes), - "relations": len(relation_hashes), - "vectors_live": len(live_vector_hashes), - }, - "coverage": { - "paragraph_vector_coverage": _safe_ratio(para_vector_hits, len(paragraph_hashes)), - "entity_vector_coverage": _safe_ratio(ent_vector_hits, len(entity_hashes)), - "relation_vector_coverage": _safe_ratio(rel_vector_hits, len(relation_hashes)), - "relation_ready_coverage": _safe_ratio(ready_count, rel_total), - }, - "relation_states": relation_states, - "orphans": { - "vector_only_count": len(orphan_vector_hashes), - "vector_only_sample": orphan_vector_hashes[:30], - }, - "consistency_checks": { - "ready_but_missing_vector": ready_but_missing, - "not_ready_but_vector_present": not_ready_but_present, - }, - } - return result - finally: - metadata_store.close() - - -def main() -> int: - parser = _build_arg_parser() - args = parser.parse_args() - - data_dir = Path(args.data_dir).resolve() - if not data_dir.exists(): - print(f"❌ 数据目录不存在: {data_dir}") - return 2 - - try: - result = run_audit(data_dir) - except Exception as e: - print(f"❌ 审计失败: {e}") - return 2 - - print("=== A_Memorix Vector Consistency Audit ===") - print(f"data_dir: {data_dir}") - print(f"paragraphs: {result['counts']['paragraphs']}") - print(f"entities: {result['counts']['entities']}") - print(f"relations: {result['counts']['relations']}") - print(f"vectors_live: {result['counts']['vectors_live']}") - print( - "coverage: " - f"paragraph={result['coverage']['paragraph_vector_coverage']:.3f}, " - f"entity={result['coverage']['entity_vector_coverage']:.3f}, " - f"relation={result['coverage']['relation_vector_coverage']:.3f}, " - 
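Editor's note: the audit above is mostly set arithmetic — vectors whose ids match no paragraph/entity/relation hash are "orphans", and coverage is the share of metadata hashes that have a live vector. A toy example with invented hashes:

def safe_ratio(numerator: int, denominator: int) -> float:
    return 0.0 if denominator <= 0 else numerator / denominator

paragraphs = {"p1", "p2", "p3"}
entities = {"e1"}
relations = {"r1"}
live_vectors = {"p1", "p2", "e1", "zzz"}  # "zzz" has no metadata row

orphans = live_vectors - paragraphs - entities - relations
paragraph_coverage = safe_ratio(len(paragraphs & live_vectors), len(paragraphs))

print(sorted(orphans))                 # -> ['zzz']
print(round(paragraph_coverage, 3))    # -> 0.667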
f"relation_ready={result['coverage']['relation_ready_coverage']:.3f}" - ) - print(f"relation_states: {result['relation_states']}") - print( - "consistency_checks: " - f"ready_but_missing_vector={result['consistency_checks']['ready_but_missing_vector']}, " - f"not_ready_but_vector_present={result['consistency_checks']['not_ready_but_vector_present']}" - ) - print(f"orphan_vectors: {result['orphans']['vector_only_count']}") - - if args.json_out: - out_path = Path(args.json_out).resolve() - out_path.parent.mkdir(parents=True, exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"json_out: {out_path}") - - has_anomaly = ( - result["orphans"]["vector_only_count"] > 0 - or result["consistency_checks"]["ready_but_missing_vector"] > 0 - ) - if args.strict and has_anomaly: - return 1 - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/backfill_relation_vectors.py b/plugins/A_memorix/scripts/backfill_relation_vectors.py deleted file mode 100644 index 7ba0ade0..00000000 --- a/plugins/A_memorix/scripts/backfill_relation_vectors.py +++ /dev/null @@ -1,270 +0,0 @@ -#!/usr/bin/env python3 -""" -关系向量一次性回填脚本(灰度/离线执行)。 - -用途: -1. 对 relations 中 vector_state in (none, failed, pending) 的记录补齐向量。 -2. 支持并发控制,降低总耗时。 -3. 可作为灰度阶段验证工具,与 audit_vector_consistency.py 配合使用。 -4. 可选自动纳入“ready 但向量缺失”的漂移记录进行修复。 -""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import sys -import time -from pathlib import Path -from typing import Any, Dict, List - -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="关系向量一次性回填") - parser.add_argument( - "--config", - default=str(PLUGIN_ROOT / "config.toml"), - help="配置文件路径(默认 plugins/A_memorix/config.toml)", - ) - parser.add_argument( - "--data-dir", - default=str(PLUGIN_ROOT / "data"), - help="数据目录(默认 plugins/A_memorix/data)", - ) - parser.add_argument( - "--states", - default="none,failed,pending", - help="待处理状态列表,逗号分隔", - ) - parser.add_argument("--limit", type=int, default=50000, help="最大处理数量") - parser.add_argument("--concurrency", type=int, default=8, help="并发数") - parser.add_argument("--max-retry", type=int, default=None, help="最大重试次数过滤") - parser.add_argument( - "--include-ready-missing", - action="store_true", - help="额外纳入 vector_state=ready 但向量缺失的关系", - ) - parser.add_argument("--dry-run", action="store_true", help="仅统计候选,不写入") - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - -from core.storage import ( - VectorStore, - GraphStore, - MetadataStore, - QuantizationType, - SparseMatrixFormat, -) -from core.embedding import create_embedding_api_adapter -from core.utils.relation_write_service import RelationWriteService - - -def _load_config(config_path: Path) -> Dict[str, Any]: - with open(config_path, "r", encoding="utf-8") as f: - raw = tomlkit.load(f) - return dict(raw) if isinstance(raw, dict) else {} - - -def _build_vector_store(data_dir: Path, emb_cfg: Dict[str, Any]) -> VectorStore: - q_type = str(emb_cfg.get("quantization_type", "int8")).lower() - if q_type != "int8": - raise ValueError( - 
"embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - dim = int(emb_cfg.get("dimension", 1024)) - store = VectorStore( - dimension=max(1, dim), - quantization_type=QuantizationType.INT8, - data_dir=data_dir / "vectors", - ) - if store.has_data(): - store.load() - return store - - -def _build_graph_store(data_dir: Path, graph_cfg: Dict[str, Any]) -> GraphStore: - fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() - fmt_map = { - "csr": SparseMatrixFormat.CSR, - "csc": SparseMatrixFormat.CSC, - } - store = GraphStore( - matrix_format=fmt_map.get(fmt, SparseMatrixFormat.CSR), - data_dir=data_dir / "graph", - ) - if store.has_data(): - store.load() - return store - - -def _build_metadata_store(data_dir: Path) -> MetadataStore: - store = MetadataStore(data_dir=data_dir / "metadata") - store.connect() - return store - - -def _build_embedding_manager(emb_cfg: Dict[str, Any]): - retry_cfg = emb_cfg.get("retry", {}) - if not isinstance(retry_cfg, dict): - retry_cfg = {} - return create_embedding_api_adapter( - batch_size=int(emb_cfg.get("batch_size", 32)), - max_concurrent=int(emb_cfg.get("max_concurrent", 5)), - default_dimension=int(emb_cfg.get("dimension", 1024)), - model_name=str(emb_cfg.get("model_name", "auto")), - retry_config=retry_cfg, - ) - - -async def _process_rows( - service: RelationWriteService, - rows: List[Dict[str, Any]], - concurrency: int, -) -> Dict[str, int]: - semaphore = asyncio.Semaphore(max(1, int(concurrency))) - stat = {"success": 0, "failed": 0, "skipped": 0} - - async def _worker(row: Dict[str, Any]) -> None: - async with semaphore: - result = await service.ensure_relation_vector( - hash_value=str(row["hash"]), - subject=str(row.get("subject", "")), - predicate=str(row.get("predicate", "")), - obj=str(row.get("object", "")), - ) - if result.vector_state == "ready": - if result.vector_written: - stat["success"] += 1 - else: - stat["skipped"] += 1 - else: - stat["failed"] += 1 - - await asyncio.gather(*[_worker(row) for row in rows]) - return stat - - -async def main_async(args: argparse.Namespace) -> int: - config_path = Path(args.config).resolve() - if not config_path.exists(): - print(f"❌ 配置文件不存在: {config_path}") - return 2 - - cfg = _load_config(config_path) - emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} - graph_cfg = cfg.get("graph", {}) if isinstance(cfg, dict) else {} - retrieval_cfg = cfg.get("retrieval", {}) if isinstance(cfg, dict) else {} - rv_cfg = retrieval_cfg.get("relation_vectorization", {}) if isinstance(retrieval_cfg, dict) else {} - if not isinstance(emb_cfg, dict): - emb_cfg = {} - if not isinstance(graph_cfg, dict): - graph_cfg = {} - if not isinstance(rv_cfg, dict): - rv_cfg = {} - - data_dir = Path(args.data_dir).resolve() - if not data_dir.exists(): - print(f"❌ 数据目录不存在: {data_dir}") - return 2 - - print(f"data_dir: {data_dir}") - print(f"config: {config_path}") - - vector_store = _build_vector_store(data_dir, emb_cfg) - graph_store = _build_graph_store(data_dir, graph_cfg) - metadata_store = _build_metadata_store(data_dir) - embedding_manager = _build_embedding_manager(emb_cfg) - service = RelationWriteService( - metadata_store=metadata_store, - graph_store=graph_store, - vector_store=vector_store, - embedding_manager=embedding_manager, - ) - - try: - states = [s.strip() for s in str(args.states).split(",") if s.strip()] - if not states: - states = ["none", "failed", "pending"] - max_retry = int(args.max_retry) if args.max_retry is not None else 
int(rv_cfg.get("max_retry", 3)) - limit = int(args.limit) - - rows = metadata_store.list_relations_by_vector_state( - states=states, - limit=max(1, limit), - max_retry=max(1, max_retry), - ) - added_ready_missing = 0 - if args.include_ready_missing: - ready_rows = metadata_store.list_relations_by_vector_state( - states=["ready"], - limit=max(1, limit), - max_retry=max(1, max_retry), - ) - ready_missing_rows = [ - row for row in ready_rows if str(row.get("hash", "")) not in vector_store - ] - added_ready_missing = len(ready_missing_rows) - if ready_missing_rows: - dedup: Dict[str, Dict[str, Any]] = {} - for row in rows: - dedup[str(row.get("hash", ""))] = row - for row in ready_missing_rows: - dedup.setdefault(str(row.get("hash", "")), row) - rows = list(dedup.values())[: max(1, limit)] - print(f"candidates: {len(rows)} (states={states}, max_retry={max_retry})") - if args.include_ready_missing: - print(f"ready_missing_candidates_added: {added_ready_missing}") - if not rows: - return 0 - - if args.dry_run: - print("dry_run=true,未执行写入。") - return 0 - - started = time.time() - stat = await _process_rows( - service=service, - rows=rows, - concurrency=int(args.concurrency), - ) - elapsed = (time.time() - started) * 1000.0 - - vector_store.save() - graph_store.save() - state_stats = metadata_store.count_relations_by_vector_state() - output = { - "processed": len(rows), - "success": int(stat["success"]), - "failed": int(stat["failed"]), - "skipped": int(stat["skipped"]), - "elapsed_ms": elapsed, - "state_stats": state_stats, - } - print(json.dumps(output, ensure_ascii=False, indent=2)) - return 0 if stat["failed"] == 0 else 1 - finally: - metadata_store.close() - - -def parse_args() -> argparse.Namespace: - return _build_arg_parser().parse_args() - - -if __name__ == "__main__": - arguments = parse_args() - raise SystemExit(asyncio.run(main_async(arguments))) diff --git a/plugins/A_memorix/scripts/backfill_temporal_metadata.py b/plugins/A_memorix/scripts/backfill_temporal_metadata.py deleted file mode 100644 index b68820cd..00000000 --- a/plugins/A_memorix/scripts/backfill_temporal_metadata.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -""" -回填段落时序字段。 - -默认策略: -1. 若段落缺失 event_time/event_time_start/event_time_end -2. 且存在 created_at -3. 
写入 event_time=created_at, time_granularity=day, time_confidence=0.2 -""" - -from __future__ import annotations - -import argparse -from pathlib import Path -import sys - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) - -from plugins.A_memorix.core.storage import MetadataStore # noqa: E402 - - -def backfill( - data_dir: Path, - dry_run: bool, - limit: int, - no_created_fallback: bool, -) -> int: - store = MetadataStore(data_dir=data_dir) - store.connect() - summary = store.backfill_temporal_metadata_from_created_at( - limit=limit, - dry_run=dry_run, - no_created_fallback=no_created_fallback, - ) - store.close() - if dry_run: - print(f"[dry-run] candidates={summary['candidates']}") - return int(summary["candidates"]) - if no_created_fallback: - print(f"skip update (no-created-fallback), candidates={summary['candidates']}") - return 0 - print(f"updated={summary['updated']}") - return int(summary["updated"]) - - -def main() -> int: - parser = argparse.ArgumentParser(description="Backfill temporal metadata for A_Memorix paragraphs") - parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="数据目录") - parser.add_argument("--dry-run", action="store_true", help="仅统计,不写入") - parser.add_argument("--limit", type=int, default=100000, help="最大处理条数") - parser.add_argument( - "--no-created-fallback", - action="store_true", - help="不使用 created_at 回填,仅输出候选数量", - ) - args = parser.parse_args() - - backfill( - data_dir=Path(args.data_dir), - dry_run=args.dry_run, - limit=max(1, int(args.limit)), - no_created_fallback=args.no_created_fallback, - ) - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) - diff --git a/plugins/A_memorix/scripts/convert_lpmm.py b/plugins/A_memorix/scripts/convert_lpmm.py deleted file mode 100644 index 2ef0b396..00000000 --- a/plugins/A_memorix/scripts/convert_lpmm.py +++ /dev/null @@ -1,540 +0,0 @@ -#!/usr/bin/env python3 -""" -LPMM 到 A_memorix 存储转换器 - -功能: -1. 读取 LPMM parquet 文件 (paragraph.parquet, entity.parquet, relation.parquet) -2. 读取 LPMM 图文件 (graph.graphml 或 graph_structure.pkl) -3. 直接写入 A_memorix 二进制 VectorStore 和稀疏 GraphStore -4. 
绕过 Embedding 生成以节省 Token -""" - -import sys -import os -import json -import argparse -import asyncio -import pickle -import logging -from pathlib import Path -from typing import Dict, Any, List, Tuple -import numpy as np -import tomlkit - -# 设置路径 -current_dir = Path(__file__).resolve().parent -plugin_root = current_dir.parent -project_root = plugin_root.parent.parent -sys.path.insert(0, str(project_root)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式") - parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)") - parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录") - parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)") - parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)") - parser.add_argument( - "--skip-relation-vector-rebuild", - action="store_true", - help="跳过按关系元数据重建关系向量(默认开启)", - ) - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - sys.exit(0) - -# 设置日志:优先复用 MaiBot 统一日志体系,失败时回退到标准 logging。 -try: - from src.common.logger import get_logger - - logger = get_logger("A_Memorix.LPMMConverter") -except Exception: - logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") - logger = logging.getLogger("A_Memorix.LPMMConverter") - -try: - import networkx as nx - from scipy import sparse - import pyarrow.parquet as pq -except ImportError as e: - logger.error(f"缺少依赖: {e}") - logger.error("请安装: pip install pandas pyarrow networkx scipy") - sys.exit(1) - -try: - # 优先采取相对导入 (将插件根目录加入路径) - # 这样可以避免硬编码插件名称 (plugins.A_memorix) - if str(plugin_root) not in sys.path: - sys.path.insert(0, str(plugin_root)) - - from core.storage.vector_store import VectorStore - from core.storage.graph_store import GraphStore - from core.storage.metadata_store import MetadataStore - from core.storage import QuantizationType, SparseMatrixFormat - from core.embedding import create_embedding_api_adapter - from core.utils.relation_write_service import RelationWriteService - -except ImportError as e: - logger.error(f"无法导入 A_memorix 核心模块: {e}") - logger.error("请确保在正确的环境中运行,且已安装所有依赖。") - sys.exit(1) - - -class LPMMConverter: - def __init__( - self, - lpmm_data_dir: Path, - output_dir: Path, - dimension: int = 384, - batch_size: int = 1024, - rebuild_relation_vectors: bool = True, - ): - self.lpmm_dir = lpmm_data_dir - self.output_dir = output_dir - self.dimension = dimension - self.batch_size = max(1, int(batch_size)) - self.rebuild_relation_vectors = bool(rebuild_relation_vectors) - - self.vector_dir = output_dir / "vectors" - self.graph_dir = output_dir / "graph" - self.metadata_dir = output_dir / "metadata" - - self.vector_store = None - self.graph_store = None - self.metadata_store = None - self.embedding_manager = None - self.relation_write_service = None - # LPMM 原 ID -> A_memorix ID 映射(用于图重写) - self.id_mapping: Dict[str, str] = {} - - def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None: - """记录 ID 映射,兼容带/不带类型前缀两种格式。""" - if raw_id is None: - return - - raw = str(raw_id).strip() - if not raw: - return - - self.id_mapping[raw] = mapped_id - - prefix = f"{p_type}-" - if raw.startswith(prefix): - self.id_mapping[raw[len(prefix):]] = mapped_id - else: - self.id_mapping[prefix + raw] = mapped_id - - def 
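Editor's note: LPMM graphs sometimes reference a node as "entity-<hash>" and sometimes as the bare hash, so the converter registers both spellings for every mapped id. A self-contained illustration of that lookup table (ids are made up):

def register_id_mapping(mapping: dict, raw_id: str, mapped_id: str, p_type: str) -> None:
    raw = str(raw_id).strip()
    if not raw:
        return
    mapping[raw] = mapped_id
    prefix = f"{p_type}-"
    if raw.startswith(prefix):
        mapping[raw[len(prefix):]] = mapped_id
    else:
        mapping[prefix + raw] = mapped_id

mapping = {}
register_id_mapping(mapping, "abc123", "entity-NEW", "entity")
print(mapping["abc123"], mapping["entity-abc123"])  # both spellings resolve to entity-NEW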
_map_node_id(self, node: Any) -> str: - """将图节点 ID 映射到转换后的 A_memorix ID。""" - node_key = str(node) - return self.id_mapping.get(node_key, node_key) - - def initialize_stores(self): - """初始化空的 A_memorix 存储""" - logger.info(f"正在初始化存储于 {self.output_dir}...") - - # 初始化 VectorStore (A_memorix 默认使用 INT8 量化) - self.vector_store = VectorStore( - dimension=self.dimension, - quantization_type=QuantizationType.INT8, - data_dir=self.vector_dir - ) - self.vector_store.clear() # 清空旧数据 - - # 初始化 GraphStore (使用 CSR 格式) - self.graph_store = GraphStore( - matrix_format=SparseMatrixFormat.CSR, - data_dir=self.graph_dir - ) - self.graph_store.clear() - - # 初始化 MetadataStore - self.metadata_store = MetadataStore(data_dir=self.metadata_dir) - self.metadata_store.connect() - # 清空元数据表?理想情况下是的,但要小心。 - # 对于转换,我们假设是全新的开始或覆盖。 - # A_memorix 中的 MetadataStore 通常使用 SQLite。 - # 如果目录是新的,我们会依赖它创建新文件。 - if self.rebuild_relation_vectors: - self._init_relation_vector_service() - - def _load_plugin_config(self) -> Dict[str, Any]: - config_path = plugin_root / "config.toml" - if not config_path.exists(): - return {} - try: - with open(config_path, "r", encoding="utf-8") as f: - parsed = tomlkit.load(f) - return dict(parsed) if isinstance(parsed, dict) else {} - except Exception as e: - logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}") - return {} - - def _init_relation_vector_service(self) -> None: - if not self.rebuild_relation_vectors: - return - cfg = self._load_plugin_config() - emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} - if not isinstance(emb_cfg, dict): - emb_cfg = {} - try: - self.embedding_manager = create_embedding_api_adapter( - batch_size=int(emb_cfg.get("batch_size", 32)), - max_concurrent=int(emb_cfg.get("max_concurrent", 5)), - default_dimension=int(emb_cfg.get("dimension", self.dimension)), - model_name=str(emb_cfg.get("model_name", "auto")), - retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {}, - ) - self.relation_write_service = RelationWriteService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - except Exception as e: - self.embedding_manager = None - self.relation_write_service = None - logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}") - - async def _rebuild_relation_vectors(self) -> None: - if not self.rebuild_relation_vectors: - return - if self.relation_write_service is None: - logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。") - return - - rows = self.metadata_store.get_relations() - if not rows: - logger.info("未发现关系元数据,无需重建关系向量。") - return - - success = 0 - failed = 0 - skipped = 0 - for row in rows: - result = await self.relation_write_service.ensure_relation_vector( - hash_value=str(row["hash"]), - subject=str(row.get("subject", "")), - predicate=str(row.get("predicate", "")), - obj=str(row.get("object", "")), - ) - if result.vector_state == "ready": - if result.vector_written: - success += 1 - else: - skipped += 1 - else: - failed += 1 - - logger.info( - "关系向量重建完成: " - f"total={len(rows)} " - f"success={success} " - f"skipped={skipped} " - f"failed={failed}" - ) - - @staticmethod - def _parse_relation_text(text: str) -> Tuple[str, str, str]: - raw = str(text or "").strip() - if not raw: - return "", "", "" - if "|" in raw: - parts = [p.strip() for p in raw.split("|") if p.strip()] - if len(parts) >= 3: - return parts[0], parts[1], parts[2] - if "->" in raw: - parts = [p.strip() for p in raw.split("->") if p.strip()] - if len(parts) 
>= 3: - return parts[0], parts[1], parts[2] - pieces = raw.split() - if len(pieces) >= 3: - return pieces[0], pieces[1], " ".join(pieces[2:]) - return "", "", "" - - def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int: - if not relation_path.exists(): - return 0 - - try: - parquet_file = pq.ParquetFile(relation_path) - except Exception as e: - logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}") - return 0 - - cols = set(parquet_file.schema_arrow.names) - has_triple_cols = {"subject", "predicate", "object"}.issubset(cols) - content_col = "str" if "str" in cols else ("content" if "content" in cols else "") - - imported_hashes = set() - imported = 0 - for record_batch in parquet_file.iter_batches(batch_size=self.batch_size): - df_batch = record_batch.to_pandas() - for _, row in df_batch.iterrows(): - subject = "" - predicate = "" - obj = "" - if has_triple_cols: - subject = str(row.get("subject", "") or "").strip() - predicate = str(row.get("predicate", "") or "").strip() - obj = str(row.get("object", "") or "").strip() - elif content_col: - subject, predicate, obj = self._parse_relation_text(row.get(content_col, "")) - - if not (subject and predicate and obj): - continue - - rel_hash = self.metadata_store.add_relation( - subject=subject, - predicate=predicate, - obj=obj, - source_paragraph=None, - ) - if rel_hash in imported_hashes: - continue - imported_hashes.add(rel_hash) - self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) - try: - self.metadata_store.set_relation_vector_state(rel_hash, "none") - except Exception: - pass - imported += 1 - - return imported - - def convert_vectors(self): - """将 Parquet 向量转换为 VectorStore""" - # LPMM 默认文件名 - parquet_files = { - "paragraph": self.lpmm_dir / "paragraph.parquet", - "entity": self.lpmm_dir / "entity.parquet", - "relation": self.lpmm_dir / "relation.parquet" - } - - total_vectors = 0 - - for p_type, p_path in parquet_files.items(): - # 关系向量在当前脚本中无法保证与 MetadataStore 的关系记录一一对应, - # 直接导入会污染召回结果(命中后无法反查 relation 元数据)。 - if p_type == "relation": - relation_count = self._import_relation_metadata_from_parquet(p_path) - logger.warning( - "跳过 relation.parquet 向量导入(保持一致性);" - f"已导入关系元数据: {relation_count}" - ) - continue - - if not p_path.exists(): - logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。") - continue - - logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...") - try: - parquet_file = pq.ParquetFile(p_path) - total_rows = parquet_file.metadata.num_rows - if total_rows == 0: - logger.info(f"{p_path} 为空,跳过。") - continue - - # LPMM Schema: 'hash', 'embedding', 'str' - cols = parquet_file.schema_arrow.names - # 兼容性检查 - content_col = 'str' if 'str' in cols else 'content' - emb_col = 'embedding' - hash_col = 'hash' - - if content_col not in cols or emb_col not in cols: - logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}") - continue - - batch_columns = [content_col, emb_col] - if hash_col in cols: - batch_columns.append(hash_col) - - processed_rows = 0 - added_for_type = 0 - batch_idx = 0 - - for record_batch in parquet_file.iter_batches( - batch_size=self.batch_size, - columns=batch_columns, - ): - batch_idx += 1 - df_batch = record_batch.to_pandas() - - embeddings_list = [] - ids_list = [] - - # 同时处理元数据映射 - for _, row in df_batch.iterrows(): - processed_rows += 1 - content = row[content_col] - emb = row[emb_col] - - if content is None or (isinstance(content, float) and np.isnan(content)): - continue - content = str(content).strip() - if not content: - continue - - if emb is None or 
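Editor's note: _parse_relation_text above accepts three loose encodings of a triple — "s | p | o", "s -> p -> o", or a plain whitespace-separated string — and returns empty strings when no triple can be recovered. A stand-alone sketch of the same rules with toy inputs:

from typing import Tuple

def parse_relation_text(text: str) -> Tuple[str, str, str]:
    raw = str(text or "").strip()
    if not raw:
        return "", "", ""
    for sep in ("|", "->"):
        if sep in raw:
            parts = [p.strip() for p in raw.split(sep) if p.strip()]
            if len(parts) >= 3:
                return parts[0], parts[1], parts[2]
    pieces = raw.split()
    if len(pieces) >= 3:
        return pieces[0], pieces[1], " ".join(pieces[2:])
    return "", "", ""

print(parse_relation_text("Alice | likes | Bob"))     # -> ('Alice', 'likes', 'Bob')
print(parse_relation_text("Alice -> knows -> Bob"))    # -> ('Alice', 'knows', 'Bob')
print(parse_relation_text("Alice founded Acme Corp"))  # -> ('Alice', 'founded', 'Acme Corp')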
len(emb) == 0: - continue - - # 先写 MetadataStore,并使用其返回的真实 hash 作为向量 ID - # 保证检索返回 ID 可以直接反查元数据。 - store_id = None - if p_type == "paragraph": - store_id = self.metadata_store.add_paragraph( - content=content, - source="lpmm_import", - knowledge_type="factual", - ) - elif p_type == "entity": - store_id = self.metadata_store.add_entity(name=content) - else: - continue - - raw_hash = row[hash_col] if hash_col in df_batch.columns else None - if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)): - self._register_id_mapping(raw_hash, store_id, p_type) - - # 确保 embedding 是 numpy 数组 - emb_np = np.array(emb, dtype=np.float32) - if emb_np.shape[0] != self.dimension: - logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}") - continue - - embeddings_list.append(emb_np) - ids_list.append(store_id) - - if embeddings_list: - # 分批添加到向量存储 - vectors_np = np.stack(embeddings_list) - count = self.vector_store.add(vectors_np, ids_list) - added_for_type += count - total_vectors += count - - if batch_idx == 1 or batch_idx % 10 == 0: - logger.info( - f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}" - ) - - logger.info( - f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}" - ) - - except Exception as e: - logger.error(f"处理 {p_path} 时出错: {e}") - - # 提交向量存储 - self.vector_store.save() - logger.info(f"向量转换完成。总向量数: {total_vectors}") - - def convert_graph(self): - """将 LPMM 图转换为 GraphStore""" - # LPMM 默认文件名是 rag-graph.graphml - graph_files = [ - self.lpmm_dir / "rag-graph.graphml", - self.lpmm_dir / "graph.graphml", - self.lpmm_dir / "graph_structure.pkl" - ] - - nx_graph = None - - for g_path in graph_files: - if g_path.exists(): - logger.info(f"发现图文件: {g_path}") - try: - if g_path.suffix == ".graphml": - nx_graph = nx.read_graphml(g_path) - elif g_path.suffix == ".pkl": - with open(g_path, "rb") as f: - data = pickle.load(f) - # LPMM 可能会将图存储在包装类中 - if hasattr(data, "graph") and isinstance(data.graph, nx.Graph): - nx_graph = data.graph - elif isinstance(data, nx.Graph): - nx_graph = data - break - except Exception as e: - logger.error(f"加载 {g_path} 失败: {e}") - - if nx_graph is None: - logger.warning("未找到有效的图文件。跳过图转换。") - return - - logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。") - - # 1. 添加节点 - # LPMM 节点通常是哈希或带前缀的字符串。 - # 我们需要将它们映射到 A_memorix 格式。 - # 如果 LPMM 使用 "entity-HASH",则与 A_memorix 匹配。 - - nodes_to_add = [] - node_attrs = {} - - for node, attrs in nx_graph.nodes(data=True): - # 假设 LPMM 使用一致的命名 "entity-..." 或 "paragraph-..." - mapped_node = self._map_node_id(node) - nodes_to_add.append(mapped_node) - if attrs: - node_attrs[mapped_node] = attrs - - self.graph_store.add_nodes(nodes_to_add, node_attrs) - - # 2. 
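# Editorial note: a toy sketch (in-memory stand-ins, not the real
# MetadataStore/VectorStore API) of the "write metadata first, reuse its hash
# as the vector ID" rule followed in the import loop above: the ID returned by
# the metadata layer is the only ID the vector index ever sees, so every
# retrieval hit can be resolved back to its metadata row.
import hashlib
import numpy as np

metadata = {}   # hash -> metadata row
vectors = {}    # hash -> embedding

def add_paragraph(content: str, source: str) -> str:
    h = hashlib.sha256(content.strip().encode("utf-8")).hexdigest()
    metadata.setdefault(h, {"content": content, "source": source})
    return h

def add_vector(store_id: str, embedding: np.ndarray, dimension: int) -> bool:
    # Skip on dimension mismatch or duplicate ID, as the converter does.
    if embedding.shape != (dimension,) or store_id in vectors:
        return False
    vectors[store_id] = embedding.astype(np.float32)
    return True

store_id = add_paragraph("小明喜欢古典音乐", source="lpmm_import")
add_vector(store_id, np.random.rand(8).astype(np.float32), dimension=8)
assert store_id in metadata and store_id in vectors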
添加边 - edges_to_add = [] - weights = [] - - for u, v, data in nx_graph.edges(data=True): - weight = data.get("weight", 1.0) - edges_to_add.append((self._map_node_id(u), self._map_node_id(v))) - weights.append(float(weight)) - - # 如果可能,将关系同步到 MetadataStore - # 但图的边并不总是包含关系谓词 - # 如果 LPMM 边数据有 'predicate',我们可以添加到元数据 - # 通常 LPMM 边是加权和,谓词信息可能在简单图中丢失 - - if edges_to_add: - self.graph_store.add_edges(edges_to_add, weights) - - self.graph_store.save() - logger.info("图转换完成。") - - def run(self): - self.initialize_stores() - self.convert_vectors() - self.convert_graph() - asyncio.run(self._rebuild_relation_vectors()) - self.vector_store.save() - self.graph_store.save() - self.metadata_store.close() - logger.info("所有转换成功完成。") - - -def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - input_path = Path(args.input) - output_path = Path(args.output) - - if not input_path.exists(): - logger.error(f"输入目录不存在: {input_path}") - sys.exit(1) - - converter = LPMMConverter( - input_path, - output_path, - dimension=args.dim, - batch_size=args.batch_size, - rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild), - ) - converter.run() - -if __name__ == "__main__": - main() diff --git a/plugins/A_memorix/scripts/import_lpmm_json.py b/plugins/A_memorix/scripts/import_lpmm_json.py deleted file mode 100644 index 2e458e16..00000000 --- a/plugins/A_memorix/scripts/import_lpmm_json.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python3 -""" -LPMM OpenIE JSON 导入工具。 - -功能: -1. 读取符合 LPMM 规范的 OpenIE JSON 文件 -2. 转换为 A_Memorix 的统一导入格式 -3. 复用 `process_knowledge.py` 中的 `AutoImporter` 直接入库 -""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import sys -import traceback -from pathlib import Path -from typing import Any, Dict, List - -from rich.console import Console -from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn - -console = Console() - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -for path in (CURRENT_DIR, WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="将 LPMM OpenIE JSON 导入 A_Memorix") - parser.add_argument("path", help="LPMM JSON 文件路径或目录") - parser.add_argument("--force", action="store_true", help="强制重新导入") - parser.add_argument("--concurrency", "-c", type=int, default=5, help="并发数") - return parser - - -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - - -try: - from process_knowledge import AutoImporter - from A_memorix.core.utils.hash import compute_paragraph_hash - from src.common.logger import get_logger -except ImportError as exc: # pragma: no cover - script bootstrap - print(f"导入模块失败,请确认 PYTHONPATH 与工作区结构: {exc}") - raise SystemExit(1) - - -logger = get_logger("A_Memorix.LPMMImport") - - -class LPMMConverter: - def convert_lpmm_to_memorix(self, lpmm_data: Dict[str, Any], filename: str) -> Dict[str, Any]: - memorix_data = {"paragraphs": [], "entities": []} - docs = lpmm_data.get("docs", []) or [] - if not docs: - logger.warning(f"文件中未找到 docs 字段: {filename}") - return memorix_data - - all_entities = set() - for doc in docs: - content = str(doc.get("passage", "") or "").strip() - if not content: - continue - - relations: List[Dict[str, str]] = 
[] - for triple in doc.get("extracted_triples", []) or []: - if isinstance(triple, list) and len(triple) == 3: - relations.append( - { - "subject": str(triple[0] or "").strip(), - "predicate": str(triple[1] or "").strip(), - "object": str(triple[2] or "").strip(), - } - ) - - entities = [str(item or "").strip() for item in doc.get("extracted_entities", []) or [] if str(item or "").strip()] - all_entities.update(entities) - for relation in relations: - if relation["subject"]: - all_entities.add(relation["subject"]) - if relation["object"]: - all_entities.add(relation["object"]) - - memorix_data["paragraphs"].append( - { - "hash": compute_paragraph_hash(content), - "content": content, - "source": filename, - "entities": entities, - "relations": relations, - } - ) - - memorix_data["entities"] = sorted(all_entities) - return memorix_data - - -async def main() -> None: - parser = _build_arg_parser() - args = parser.parse_args() - - target_path = Path(args.path) - if not target_path.exists(): - logger.error(f"路径不存在: {target_path}") - return - - if target_path.is_dir(): - files_to_process = list(target_path.glob("*-openie.json")) or list(target_path.glob("*.json")) - else: - files_to_process = [target_path] - - if not files_to_process: - logger.error("未找到可处理的 JSON 文件") - return - - importer = AutoImporter(force=bool(args.force), concurrency=int(args.concurrency)) - if not await importer.initialize(): - logger.error("初始化存储失败") - return - - converter = LPMMConverter() - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), - TimeElapsedColumn(), - console=console, - transient=False, - ) as progress: - for json_file in files_to_process: - logger.info(f"正在转换并导入: {json_file.name}") - try: - with open(json_file, "r", encoding="utf-8") as handle: - lpmm_data = json.load(handle) - memorix_data = converter.convert_lpmm_to_memorix(lpmm_data, json_file.name) - total_items = len(memorix_data.get("paragraphs", [])) - if total_items <= 0: - logger.warning(f"转换结果为空: {json_file.name}") - continue - - task_id = progress.add_task(f"Importing {json_file.name}", total=total_items) - - def update_progress(step: int = 1) -> None: - progress.advance(task_id, advance=step) - - await importer.import_json_data( - memorix_data, - filename=f"lpmm_{json_file.name}", - progress_callback=update_progress, - ) - except Exception as exc: - logger.error(f"处理文件 {json_file.name} 失败: {exc}\n{traceback.format_exc()}") - - await importer.close() - logger.info("全部处理完成") - - -if __name__ == "__main__": - if sys.platform == "win32": # pragma: no cover - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - asyncio.run(main()) diff --git a/plugins/A_memorix/scripts/migrate_chat_history.py b/plugins/A_memorix/scripts/migrate_chat_history.py deleted file mode 100644 index 0fb0bfe1..00000000 --- a/plugins/A_memorix/scripts/migrate_chat_history.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import asyncio -import json -import sqlite3 -import sys -from datetime import datetime -from pathlib import Path -from typing import Any, Dict - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" - -if str(WORKSPACE_ROOT) not in sys.path: - sys.path.insert(0, str(WORKSPACE_ROOT)) -if 
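# Editorial note: a condensed sketch of the OpenIE-to-import-payload mapping
# performed by convert_lpmm_to_memorix above. The doc below is a made-up
# sample, and compute_paragraph_hash is replaced by sha256 so the snippet is
# self-contained.
import hashlib

doc = {
    "passage": "小明在北京学习钢琴。",
    "extracted_triples": [["小明", "学习", "钢琴"], ["小明", "位于", "北京"]],
    "extracted_entities": ["小明", "钢琴", "北京"],
}

relations = [
    {"subject": s, "predicate": p, "object": o}
    for s, p, o in (t for t in doc["extracted_triples"] if len(t) == 3)
]
entities = set(doc["extracted_entities"])
entities.update(r["subject"] for r in relations)
entities.update(r["object"] for r in relations)

payload = {
    "paragraphs": [{
        "hash": hashlib.sha256(doc["passage"].encode("utf-8")).hexdigest(),
        "content": doc["passage"],
        "source": "sample-openie.json",
        "entities": doc["extracted_entities"],
        "relations": relations,
    }],
    "entities": sorted(entities),
}
print(payload["entities"])  # union of doc entities and triple endpoints, sorted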
str(MAIBOT_ROOT) not in sys.path: - sys.path.insert(0, str(MAIBOT_ROOT)) - -from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="迁移 MaiBot chat_history 到 A_Memorix") - parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") - parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") - parser.add_argument("--limit", type=int, default=0, help="限制迁移条数,0 表示全部") - parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") - return parser.parse_args() - - -def _to_timestamp(value: Any) -> float | None: - if value is None: - return None - if isinstance(value, (int, float)): - return float(value) - text = str(value).strip() - if not text: - return None - try: - return datetime.fromisoformat(text).timestamp() - except ValueError: - return None - - -async def _main() -> int: - args = _parse_args() - db_path = Path(args.db_path).resolve() - if not db_path.exists(): - print(f"数据库不存在: {db_path}") - return 1 - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - sql = """ - SELECT id, session_id, start_timestamp, end_timestamp, participants, theme, keywords, summary - FROM chat_history - ORDER BY id ASC - """ - if int(args.limit or 0) > 0: - sql += " LIMIT ?" - rows = conn.execute(sql, (int(args.limit),)).fetchall() - else: - rows = conn.execute(sql).fetchall() - conn.close() - - print(f"chat_history 待处理: {len(rows)}") - if args.dry_run: - for row in rows[:5]: - print(f"- id={row['id']} session={row['session_id']} theme={row['theme']}") - return 0 - - kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) - await kernel.initialize() - migrated = 0 - skipped = 0 - for row in rows: - participants = json.loads(row["participants"]) if row["participants"] else [] - keywords = json.loads(row["keywords"]) if row["keywords"] else [] - theme = str(row["theme"] or "").strip() - summary = str(row["summary"] or "").strip() - text = f"主题:{theme}\n概括:{summary}".strip() - result: Dict[str, Any] = await kernel.ingest_summary( - external_id=f"chat_history:{row['id']}", - chat_id=str(row["session_id"] or ""), - text=text, - participants=participants, - time_start=_to_timestamp(row["start_timestamp"]), - time_end=_to_timestamp(row["end_timestamp"]), - tags=keywords, - metadata={"theme": theme, "source_row_id": int(row["id"])}, - ) - if result.get("stored_ids"): - migrated += 1 - else: - skipped += 1 - - print(f"迁移完成: migrated={migrated} skipped={skipped}") - print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) - kernel.close() - return 0 - - -if __name__ == "__main__": - raise SystemExit(asyncio.run(_main())) diff --git a/plugins/A_memorix/scripts/migrate_maibot_memory.py b/plugins/A_memorix/scripts/migrate_maibot_memory.py deleted file mode 100644 index 0b26a9cd..00000000 --- a/plugins/A_memorix/scripts/migrate_maibot_memory.py +++ /dev/null @@ -1,1714 +0,0 @@ -#!/usr/bin/env python3 -""" -MaiBot 记忆迁移脚本(chat_history -> A_memorix) - -特性: -1. 高性能:分页读取 + 批量 embedding + 批量写入 -2. 断点续传:基于 last_committed_id 的窗口提交 -3. 精确一次语义:稳定哈希 + 幂等写入 + 向量存在性检查 -4. 
可确认筛选:支持时间区间、聊天流(stream/group/user)筛选,并先预览后确认 -""" - -from __future__ import annotations - -import argparse -import asyncio -import hashlib -import importlib -import json -import logging -import os -import pickle -import sqlite3 -import sys -import time -import traceback -import types -from collections import defaultdict -from dataclasses import dataclass -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, Tuple - -import numpy as np -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -RUNTIME_CORE_PACKAGE = "_a_memorix_runtime_core" - -VectorStore = None -GraphStore = None -MetadataStore = None -create_embedding_api_adapter = None -KnowledgeType = None -QuantizationType = None -SparseMatrixFormat = None -compute_hash = None -normalize_text = None -atomic_write = None -model_config = None -RelationWriteService = None - - -def _create_bootstrap_logger(): - fallback = logging.getLogger("A_Memorix.MaiBotMigration") - if not fallback.handlers: - fallback.addHandler(logging.NullHandler()) - try: - for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - from src.common.logger import get_logger - - return get_logger("A_Memorix.MaiBotMigration") - except Exception: - return fallback - - -logger = _create_bootstrap_logger() - - -def _ensure_import_paths() -> None: - for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - - -def _ensure_runtime_core_package() -> str: - existing = sys.modules.get(RUNTIME_CORE_PACKAGE) - if existing is not None and hasattr(existing, "__path__"): - return RUNTIME_CORE_PACKAGE - - pkg = types.ModuleType(RUNTIME_CORE_PACKAGE) - pkg.__path__ = [str(PLUGIN_ROOT / "core")] - pkg.__package__ = RUNTIME_CORE_PACKAGE - sys.modules[RUNTIME_CORE_PACKAGE] = pkg - return RUNTIME_CORE_PACKAGE - - -def _disable_unavailable_gemini_provider() -> None: - global model_config - try: - from google import genai # type: ignore # noqa: F401 - return - except Exception: - pass - - from src.config.config import model_config as loaded_model_config - - providers = list(getattr(loaded_model_config, "api_providers", [])) - if not providers: - model_config = loaded_model_config - return - - kept_providers = [p for p in providers if str(getattr(p, "client_type", "")).lower() != "gemini"] - if len(kept_providers) == len(providers): - model_config = loaded_model_config - return - - loaded_model_config.api_providers = kept_providers - loaded_model_config.api_providers_dict = {p.name: p for p in kept_providers} - - models = list(getattr(loaded_model_config, "models", [])) - kept_models = [m for m in models if m.api_provider in loaded_model_config.api_providers_dict] - loaded_model_config.models = kept_models - loaded_model_config.models_dict = {m.name: m for m in kept_models} - - task_cfg = loaded_model_config.model_task_config - for field_name in task_cfg.__dataclass_fields__.keys(): - task = getattr(task_cfg, field_name, None) - if task is None or not hasattr(task, "model_list"): - continue - task.model_list = [m for m in list(task.model_list) if m in loaded_model_config.models_dict] - - model_config = loaded_model_config - logger.warning("检测到缺少 google.genai,已临时禁用 gemini provider 以保证脚本可运行。") - - -def 
_bootstrap_runtime_symbols() -> None: - global VectorStore - global GraphStore - global MetadataStore - global KnowledgeType - global QuantizationType - global SparseMatrixFormat - global compute_hash - global normalize_text - global atomic_write - global RelationWriteService - global logger - - if VectorStore is not None and compute_hash is not None and atomic_write is not None: - return - - _ensure_import_paths() - - import src # noqa: F401 - from src.common.logger import get_logger - - logger = get_logger("A_Memorix.MaiBotMigration") - - pkg = _ensure_runtime_core_package() - - vector_store_module = importlib.import_module(f"{pkg}.storage.vector_store") - graph_store_module = importlib.import_module(f"{pkg}.storage.graph_store") - metadata_store_module = importlib.import_module(f"{pkg}.storage.metadata_store") - knowledge_types_module = importlib.import_module(f"{pkg}.storage.knowledge_types") - hash_module = importlib.import_module(f"{pkg}.utils.hash") - io_module = importlib.import_module(f"{pkg}.utils.io") - relation_write_service_module = importlib.import_module(f"{pkg}.utils.relation_write_service") - - VectorStore = vector_store_module.VectorStore - GraphStore = graph_store_module.GraphStore - MetadataStore = metadata_store_module.MetadataStore - KnowledgeType = knowledge_types_module.KnowledgeType - QuantizationType = vector_store_module.QuantizationType - SparseMatrixFormat = graph_store_module.SparseMatrixFormat - compute_hash = hash_module.compute_hash - normalize_text = hash_module.normalize_text - atomic_write = io_module.atomic_write - RelationWriteService = relation_write_service_module.RelationWriteService - - -def _load_embedding_adapter_factory() -> None: - global create_embedding_api_adapter - global model_config - - if create_embedding_api_adapter is not None: - return - - _ensure_import_paths() - - from src.config.config import model_config as loaded_model_config - - model_config = loaded_model_config - _disable_unavailable_gemini_provider() - - pkg = _ensure_runtime_core_package() - api_adapter_module = importlib.import_module(f"{pkg}.embedding.api_adapter") - create_embedding_api_adapter = api_adapter_module.create_embedding_api_adapter - - -DEFAULT_SOURCE_DB = MAIBOT_ROOT / "data" / "MaiBot.db" -DEFAULT_TARGET_DATA_DIR = PLUGIN_ROOT / "data" -DEFAULT_CONFIG_PATH = PLUGIN_ROOT / "config.toml" - -MIGRATION_STATE_DIRNAME = "migration_state" -STATE_FILENAME = "chat_history_resume.json" -BAD_ROWS_FILENAME = "chat_history_bad_rows.jsonl" -REPORT_FILENAME = "chat_history_report.json" - - -class MigrationError(Exception): - """迁移流程错误。""" - - -@dataclass -class SelectionFilter: - time_from_ts: Optional[float] - time_to_ts: Optional[float] - stream_ids: List[str] - stream_filter_requested: bool - start_id: Optional[int] - end_id: Optional[int] - time_from_raw: Optional[str] - time_to_raw: Optional[str] - - def fingerprint_payload(self) -> Dict[str, Any]: - return { - "time_from_ts": self.time_from_ts, - "time_to_ts": self.time_to_ts, - "time_from_raw": self.time_from_raw, - "time_to_raw": self.time_to_raw, - "stream_ids": sorted(self.stream_ids), - "stream_filter_requested": self.stream_filter_requested, - "start_id": self.start_id, - "end_id": self.end_id, - } - - -@dataclass -class PreviewResult: - total: int - distribution: List[Tuple[str, int]] - samples: List[Dict[str, Any]] - - -@dataclass -class MappedRow: - row_id: int - chat_id: str - paragraph_hash: str - content: str - source: str - time_meta: Dict[str, Any] - entities: List[str] - relations: List[Tuple[str, str, 
str]] - existing_paragraph_vector: bool - - -def _safe_int(value: Any, default: int) -> int: - try: - return int(value) - except Exception: - return default - - -def _safe_float(value: Any, default: float) -> float: - try: - return float(value) - except Exception: - return default - - -def _normalize_name(value: Any) -> str: - return str(value or "").strip() - - -def _canonical_name(value: Any) -> str: - return _normalize_name(value).lower() - - -def _dedup_keep_order(items: Iterable[str]) -> List[str]: - out: List[str] = [] - seen: set[str] = set() - for raw in items: - v = _normalize_name(raw) - if not v: - continue - k = v.lower() - if k in seen: - continue - seen.add(k) - out.append(v) - return out - - -def _format_ts(ts: Optional[float]) -> str: - if ts is None: - return "-" - try: - return datetime.fromtimestamp(float(ts)).strftime("%Y-%m-%d %H:%M:%S") - except Exception: - return str(ts) - - -def _parse_cli_datetime(text: str, is_end: bool = False) -> float: - value = str(text or "").strip() - if not value: - raise ValueError("时间不能为空") - - formats = [ - ("%Y-%m-%d %H:%M:%S", False), - ("%Y/%m/%d %H:%M:%S", False), - ("%Y-%m-%d %H:%M", False), - ("%Y/%m/%d %H:%M", False), - ("%Y-%m-%d", True), - ("%Y/%m/%d", True), - ] - - for fmt, is_date_only in formats: - try: - dt = datetime.strptime(value, fmt) - if is_date_only and is_end: - dt = dt.replace(hour=23, minute=59, second=59, microsecond=0) - return dt.timestamp() - except ValueError: - continue - - raise ValueError( - f"时间格式错误: {value},仅支持 YYYY-MM-DD、YYYY/MM/DD、YYYY-MM-DD HH:mm[:ss]、YYYY/MM/DD HH:mm[:ss]" - ) - - -def _json_hash(payload: Dict[str, Any]) -> str: - data = json.dumps(payload, ensure_ascii=False, sort_keys=True) - return hashlib.sha1(data.encode("utf-8")).hexdigest() - - -def _deep_merge_dict(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]: - out = dict(base) - for key, value in override.items(): - if isinstance(value, dict) and isinstance(out.get(key), dict): - out[key] = _deep_merge_dict(out[key], value) - else: - out[key] = value - return out - - -def _extract_schema_defaults(schema_obj: Dict[str, Any]) -> Dict[str, Any]: - defaults: Dict[str, Any] = {} - if not isinstance(schema_obj, dict): - return defaults - - for key, spec in schema_obj.items(): - if not isinstance(spec, dict): - continue - if "default" in spec: - defaults[key] = spec.get("default") - continue - props = spec.get("properties") - if isinstance(props, dict): - defaults[key] = _extract_schema_defaults(props) - return defaults - - -def _load_manifest_defaults() -> Dict[str, Any]: - manifest_path = PLUGIN_ROOT / "_manifest.json" - if not manifest_path.exists(): - return {} - try: - with open(manifest_path, "r", encoding="utf-8") as f: - payload = json.load(f) - schema = payload.get("config_schema") - if isinstance(schema, dict): - return _extract_schema_defaults(schema) - except Exception as e: - logger.warning(f"读取 manifest 默认配置失败,已回退空配置: {e}") - return {} - - -def _build_source_db_fingerprint(db_path: Path) -> Dict[str, Any]: - stat = db_path.stat() - payload = { - "path": str(db_path.resolve()), - "size": stat.st_size, - "mtime": stat.st_mtime, - } - payload["sha1"] = _json_hash(payload) - return payload - - -def _state_path(target_data_dir: Path) -> Path: - return target_data_dir / MIGRATION_STATE_DIRNAME / STATE_FILENAME - - -def _bad_rows_path(target_data_dir: Path) -> Path: - return target_data_dir / MIGRATION_STATE_DIRNAME / BAD_ROWS_FILENAME - - -def _report_path(target_data_dir: Path) -> Path: - return target_data_dir / 
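# Editorial note: the resume guard compares SHA-1 fingerprints of the selection
# filter and of the source-DB stat (see _json_hash above). The useful property
# is that the digest ignores dict key order but not value order, which is why
# fingerprint_payload() sorts stream_ids before hashing. Payloads here are
# hypothetical.
import hashlib
import json

def json_hash(payload: dict) -> str:
    data = json.dumps(payload, ensure_ascii=False, sort_keys=True)
    return hashlib.sha1(data.encode("utf-8")).hexdigest()

a = {"stream_ids": ["g1", "g2"], "time_from_ts": None}
b = {"time_from_ts": None, "stream_ids": ["g1", "g2"]}
assert json_hash(a) == json_hash(b)                                   # key order is irrelevant
assert json_hash(a) != json_hash({**a, "stream_ids": ["g2", "g1"]})   # value order is not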
MIGRATION_STATE_DIRNAME / REPORT_FILENAME - - -def _dump_json_atomic(path: Path, payload: Dict[str, Any]) -> None: - if atomic_write is None: - path.parent.mkdir(parents=True, exist_ok=True) - tmp = path.with_suffix(path.suffix + ".tmp") - with open(tmp, "w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - f.write("\n") - f.flush() - os.fsync(f.fileno()) - os.replace(tmp, path) - return - - with atomic_write(path, mode="w", encoding="utf-8") as f: - json.dump(payload, f, ensure_ascii=False, indent=2) - f.write("\n") - - -class SourceDB: - def __init__(self, db_path: Path): - self.db_path = db_path - self.conn: Optional[sqlite3.Connection] = None - - def connect(self) -> None: - if not self.db_path.exists(): - raise MigrationError(f"源数据库不存在: {self.db_path}") - - uri = f"file:{self.db_path.resolve().as_posix()}?mode=ro" - try: - self.conn = sqlite3.connect(uri, uri=True, check_same_thread=False) - except sqlite3.OperationalError: - self.conn = sqlite3.connect(str(self.db_path.resolve()), check_same_thread=False) - - self.conn.row_factory = sqlite3.Row - pragmas = [ - "PRAGMA query_only = ON", - "PRAGMA cache_size = -128000", - "PRAGMA temp_store = MEMORY", - "PRAGMA synchronous = OFF", - "PRAGMA journal_mode = WAL", - ] - for sql in pragmas: - try: - self.conn.execute(sql) - except sqlite3.OperationalError: - # 部分 PRAGMA 在 mode=ro 下会失败,不影响只读扫描能力 - continue - - def close(self) -> None: - if self.conn is not None: - self.conn.close() - self.conn = None - - def _require_conn(self) -> sqlite3.Connection: - if self.conn is None: - raise MigrationError("源数据库尚未连接") - return self.conn - - def resolve_stream_ids( - self, - stream_ids: Sequence[str], - group_ids: Sequence[str], - user_ids: Sequence[str], - ) -> List[str]: - conn = self._require_conn() - resolved: set[str] = set(_normalize_name(x) for x in stream_ids if _normalize_name(x)) - has_group_or_user = any(_normalize_name(x) for x in group_ids) or any(_normalize_name(x) for x in user_ids) - if not has_group_or_user: - return sorted(resolved) - - table_exists = conn.execute( - "SELECT 1 FROM sqlite_master WHERE type='table' AND name='chat_streams' LIMIT 1" - ).fetchone() - if table_exists is None: - raise MigrationError("源库缺少 chat_streams 表,无法根据 --group-id/--user-id 映射 stream_id") - - def _select_by_field(field: str, values: Sequence[str]) -> None: - values_norm = [_normalize_name(v) for v in values if _normalize_name(v)] - if not values_norm: - return - placeholders = ",".join("?" for _ in values_norm) - sql = f"SELECT DISTINCT stream_id FROM chat_streams WHERE {field} IN ({placeholders})" - cur = conn.execute(sql, tuple(values_norm)) - for row in cur.fetchall(): - sid = _normalize_name(row["stream_id"]) - if sid: - resolved.add(sid) - - _select_by_field("group_id", group_ids) - _select_by_field("user_id", user_ids) - return sorted(resolved) - - @staticmethod - def _build_where( - selection: SelectionFilter, - start_after_id: Optional[int] = None, - ) -> Tuple[str, List[Any]]: - conditions: List[str] = [] - params: List[Any] = [] - - if selection.start_id is not None: - conditions.append("id >= ?") - params.append(selection.start_id) - if selection.end_id is not None: - conditions.append("id <= ?") - params.append(selection.end_id) - if start_after_id is not None: - conditions.append("id > ?") - params.append(start_after_id) - - if selection.stream_ids: - placeholders = ",".join("?" 
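# Editorial note: a minimal sketch of the write-temp-then-rename pattern the
# state file relies on (mirroring the fallback branch of _dump_json_atomic):
# an interrupted run observes either the old state or the new one, never a
# truncated JSON file. The path below is hypothetical.
import json
import os
from pathlib import Path

def dump_json_atomic(path: Path, payload: dict) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)
        f.flush()
        os.fsync(f.fileno())   # data on disk before the rename
    os.replace(tmp, path)      # atomic replace on the same filesystem

dump_json_atomic(Path("./tmp_state/chat_history_resume.json"),
                 {"last_committed_id": 1200})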
for _ in selection.stream_ids) - conditions.append(f"chat_id IN ({placeholders})") - params.extend(selection.stream_ids) - elif selection.stream_filter_requested: - conditions.append("1=0") - - if selection.time_from_ts is not None and selection.time_to_ts is not None: - conditions.append("(end_time >= ? AND start_time <= ?)") - params.extend([selection.time_from_ts, selection.time_to_ts]) - elif selection.time_from_ts is not None: - conditions.append("(end_time >= ?)") - params.append(selection.time_from_ts) - elif selection.time_to_ts is not None: - conditions.append("(start_time <= ?)") - params.append(selection.time_to_ts) - - where_sql = "WHERE " + " AND ".join(conditions) if conditions else "" - return where_sql, params - - def count_candidates(self, selection: SelectionFilter) -> int: - conn = self._require_conn() - where_sql, params = self._build_where(selection, start_after_id=None) - sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" - cur = conn.execute(sql, tuple(params)) - return int(cur.fetchone()["c"]) - - def preview(self, selection: SelectionFilter, preview_limit: int) -> PreviewResult: - conn = self._require_conn() - where_sql, params = self._build_where(selection, start_after_id=None) - - total_sql = f"SELECT COUNT(*) AS c FROM chat_history {where_sql}" - total = int(conn.execute(total_sql, tuple(params)).fetchone()["c"]) - - dist_sql = ( - f"SELECT chat_id, COUNT(*) AS c FROM chat_history {where_sql} " - "GROUP BY chat_id ORDER BY c DESC LIMIT 30" - ) - distribution = [ - (_normalize_name(row["chat_id"]), int(row["c"])) - for row in conn.execute(dist_sql, tuple(params)).fetchall() - ] - - sample_sql = ( - "SELECT id, chat_id, start_time, end_time, theme, summary " - f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" - ) - sample_params = list(params) - sample_params.append(max(1, int(preview_limit))) - samples = [dict(row) for row in conn.execute(sample_sql, tuple(sample_params)).fetchall()] - - return PreviewResult(total=total, distribution=distribution, samples=samples) - - def iter_rows( - self, - selection: SelectionFilter, - batch_size: int, - start_after_id: int, - ) -> Generator[List[sqlite3.Row], None, None]: - conn = self._require_conn() - cursor = int(start_after_id) - while True: - where_sql, params = self._build_where(selection, start_after_id=cursor) - sql = ( - "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " - f"FROM chat_history {where_sql} ORDER BY id ASC LIMIT ?" - ) - bind = list(params) - bind.append(max(1, int(batch_size))) - rows = conn.execute(sql, tuple(bind)).fetchall() - if not rows: - break - yield rows - cursor = int(rows[-1]["id"]) - - def sample_rows_for_verify( - self, - selection: SelectionFilter, - sample_size: int, - ) -> List[sqlite3.Row]: - conn = self._require_conn() - where_sql, params = self._build_where(selection, start_after_id=None) - sql = ( - "SELECT id, chat_id, start_time, end_time, participants, theme, keywords, summary " - f"FROM chat_history {where_sql} ORDER BY RANDOM() LIMIT ?" 
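# Editorial note: iter_rows above pages through chat_history with a keyset
# cursor ("id > ?" plus ORDER BY id and LIMIT) rather than OFFSET, so each
# batch is an index seek no matter how deep the scan is. A self-contained toy
# run of the same pattern:
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE chat_history (id INTEGER PRIMARY KEY, summary TEXT)")
conn.executemany("INSERT INTO chat_history (summary) VALUES (?)",
                 [(f"row {i}",) for i in range(10)])

def iter_rows(conn, batch_size=4):
    cursor = 0
    while True:
        rows = conn.execute(
            "SELECT id, summary FROM chat_history WHERE id > ? "
            "ORDER BY id ASC LIMIT ?", (cursor, batch_size)).fetchall()
        if not rows:
            break
        yield rows
        cursor = rows[-1][0]   # resume strictly after the last seen id

for batch in iter_rows(conn):
    print([r[0] for r in batch])   # [1, 2, 3, 4] [5, 6, 7, 8] [9, 10]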
- ) - bind = list(params) - bind.append(max(1, int(sample_size))) - return conn.execute(sql, tuple(bind)).fetchall() - - -class MigrationRunner: - def __init__(self, args: argparse.Namespace): - self.args = args - self.source_db_path = Path(args.source_db).resolve() - self.target_data_dir = Path(args.target_data_dir).resolve() - self.state_file = _state_path(self.target_data_dir) - self.bad_rows_file = _bad_rows_path(self.target_data_dir) - self.report_file = _report_path(self.target_data_dir) - - self.source_db = SourceDB(self.source_db_path) - - self.vector_store = None - self.graph_store = None - self.metadata_store = None - self.embedding_manager = None - self.relation_write_service = None - self.plugin_config: Dict[str, Any] = {} - self.embed_workers: int = 5 - - self.selection: Optional[SelectionFilter] = None - self.filter_fingerprint: str = "" - self.source_db_fingerprint: Dict[str, Any] = {} - self.source_db_fingerprint_hash: str = "" - self.state: Dict[str, Any] = {} - - self.started_at = time.time() - self.exit_code = 0 - self.failed = False - self.fail_reason: Optional[str] = None - - self.stats: Dict[str, Any] = { - "source_matched_total": 0, - "scanned_rows": 0, - "valid_rows": 0, - "migrated_rows": 0, - "skipped_existing_rows": 0, - "bad_rows": 0, - "paragraph_vectors_added": 0, - "entity_vectors_added": 0, - "relations_written": 0, - "relation_vectors_written": 0, - "relation_vectors_failed": 0, - "relation_vectors_skipped": 0, - "graph_edges_written": 0, - "windows_committed": 0, - "last_committed_id": 0, - "verify_sample_size": 0, - "verify_paragraph_missing": 0, - "verify_vector_missing": 0, - "verify_relation_missing": 0, - "verify_edge_missing": 0, - "verify_passed": False, - } - - async def run(self) -> int: - try: - _bootstrap_runtime_symbols() - self._prepare_paths() - - self.source_db.connect() - self.selection = self._build_selection_filter() - self.filter_fingerprint = _json_hash(self.selection.fingerprint_payload()) - - self.source_db_fingerprint = _build_source_db_fingerprint(self.source_db_path) - self.source_db_fingerprint_hash = str(self.source_db_fingerprint.get("sha1", "")) - - preview = self.source_db.preview(self.selection, preview_limit=self.args.preview_limit) - self.stats["source_matched_total"] = int(preview.total) - self._print_preview(preview) - - if preview.total <= 0: - logger.info("筛选后无数据,退出。") - self.stats["verify_passed"] = True - if self.args.verify_only: - self._load_plugin_config() - await self._init_target_stores(require_embedding=False) - await self._verify(strict=True) - return self._finalize() - - if self.args.verify_only: - self._load_plugin_config() - await self._init_target_stores(require_embedding=False) - await self._verify(strict=True) - return self._finalize() - - if self.args.dry_run: - logger.info("dry-run 模式:仅预览,不写入。") - return self._finalize() - - if not self.args.yes: - if not self._confirm(): - logger.info("用户取消执行。") - return self._finalize() - - self._load_plugin_config() - await self._init_target_stores(require_embedding=True) - self._load_or_init_state() - - start_after_id = self._resolve_start_after_id() - await self._migrate(start_after_id=start_after_id) - await self._verify(strict=True) - return self._finalize() - except Exception as e: - self.failed = True - self.fail_reason = str(e) - logger.error(f"迁移失败: {e}\n{traceback.format_exc()}") - return self._finalize() - finally: - self._close() - - def _prepare_paths(self) -> None: - (self.target_data_dir / MIGRATION_STATE_DIRNAME).mkdir(parents=True, exist_ok=True) - if 
self.args.reset_state and self.state_file.exists(): - self.state_file.unlink() - if self.args.reset_state and self.bad_rows_file.exists(): - self.bad_rows_file.unlink() - - def _load_plugin_config(self) -> None: - merged = _load_manifest_defaults() - - config_path = DEFAULT_CONFIG_PATH - if config_path.exists(): - try: - with open(config_path, "r", encoding="utf-8") as f: - raw = tomlkit.load(f) - if isinstance(raw, dict): - merged = _deep_merge_dict(merged, dict(raw)) - except Exception as e: - logger.warning(f"读取插件配置失败,继续使用默认配置: {e}") - - self.plugin_config = merged - - def _read_existing_vector_dimension(self, fallback_dimension: int) -> int: - meta_path = self.target_data_dir / "vectors" / "vectors_metadata.pkl" - if not meta_path.exists(): - return fallback_dimension - try: - with open(meta_path, "rb") as f: - payload = pickle.load(f) - value = _safe_int(payload.get("dimension"), fallback_dimension) - return max(1, value) - except Exception: - return fallback_dimension - - async def _init_target_stores(self, require_embedding: bool) -> None: - if VectorStore is None or GraphStore is None or MetadataStore is None: - raise MigrationError("运行时初始化失败:存储组件不可用") - - emb_cfg = self.plugin_config.get("embedding", {}) if isinstance(self.plugin_config, dict) else {} - graph_cfg = self.plugin_config.get("graph", {}) if isinstance(self.plugin_config, dict) else {} - - self.embed_workers = max(1, _safe_int(self.args.embed_workers, _safe_int(emb_cfg.get("max_concurrent"), 5))) - emb_batch_size = max(1, _safe_int(emb_cfg.get("batch_size"), 32)) - emb_default_dim = max(1, _safe_int(emb_cfg.get("dimension"), 1024)) - emb_model_name = str(emb_cfg.get("model_name", "auto")) - emb_retry = emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {} - - if require_embedding: - _load_embedding_adapter_factory() - if create_embedding_api_adapter is None: - raise MigrationError("运行时初始化失败:embedding 适配器不可用") - - if model_config is not None: - embedding_task = getattr(getattr(model_config, "model_task_config", None), "embedding", None) - if embedding_task is not None and hasattr(embedding_task, "model_list"): - if not list(embedding_task.model_list): - raise MigrationError( - "当前配置没有可用 embedding 模型。若你使用 gemini provider,请先安装 `google-genai` " - "或切换到可用的 embedding provider。" - ) - - self.embedding_manager = create_embedding_api_adapter( - batch_size=emb_batch_size, - max_concurrent=self.embed_workers, - default_dimension=emb_default_dim, - model_name=emb_model_name, - retry_config=emb_retry, - ) - - try: - detected_dim = self._read_existing_vector_dimension(emb_default_dim) - has_existing_vectors = (self.target_data_dir / "vectors" / "vectors_metadata.pkl").exists() - if not has_existing_vectors: - detected_dim = await self.embedding_manager._detect_dimension() - except Exception as e: - logger.warning(f"嵌入维度探测失败,回退配置维度: {e}") - detected_dim = self._read_existing_vector_dimension(emb_default_dim) - else: - detected_dim = self._read_existing_vector_dimension(emb_default_dim) - self.embedding_manager = None - - q_type = str(emb_cfg.get("quantization_type", "int8")).lower() - if q_type != "int8": - raise MigrationError( - "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - quantization = QuantizationType.INT8 - - matrix_fmt = str(graph_cfg.get("sparse_matrix_format", "csr")).lower() - fmt_map = { - "csr": SparseMatrixFormat.CSR, - "csc": SparseMatrixFormat.CSC, - } - sparse_fmt = fmt_map.get(matrix_fmt, SparseMatrixFormat.CSR) - - 
self.vector_store = VectorStore( - dimension=detected_dim, - quantization_type=quantization, - data_dir=self.target_data_dir / "vectors", - ) - self.graph_store = GraphStore( - matrix_format=sparse_fmt, - data_dir=self.target_data_dir / "graph", - ) - self.metadata_store = MetadataStore(data_dir=self.target_data_dir / "metadata") - self.metadata_store.connect() - - if self.vector_store.has_data(): - self.vector_store.load() - if self.graph_store.has_data(): - self.graph_store.load() - - self.relation_write_service = None - if require_embedding and RelationWriteService is not None and self.embedding_manager is not None: - self.relation_write_service = RelationWriteService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - - logger.info( - f"目标存储初始化完成: dim={self.vector_store.dimension}, quant={q_type}, graph_fmt={matrix_fmt}, " - f"embed_workers={self.embed_workers}" - ) - - def _should_write_relation_vectors(self) -> bool: - retrieval_cfg = self.plugin_config.get("retrieval", {}) if isinstance(self.plugin_config, dict) else {} - if not isinstance(retrieval_cfg, dict): - return False - rv_cfg = retrieval_cfg.get("relation_vectorization", {}) - if not isinstance(rv_cfg, dict): - return False - return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) - - async def _ensure_relation_vectors_for_records( - self, - relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]], - ) -> None: - if not relation_records: - return - if self.relation_write_service is None: - return - - success = 0 - failed = 0 - skipped = 0 - for relation_hash, rel in relation_records.items(): - result = await self.relation_write_service.ensure_relation_vector( - hash_value=relation_hash, - subject=str(rel[0]), - predicate=str(rel[1]), - obj=str(rel[2]), - ) - if result.vector_state == "ready": - if result.vector_written: - success += 1 - else: - skipped += 1 - else: - failed += 1 - - self.stats["relation_vectors_written"] += success - self.stats["relation_vectors_failed"] += failed - self.stats["relation_vectors_skipped"] += skipped - - def _build_selection_filter(self) -> SelectionFilter: - if self.args.start_id is not None and self.args.start_id <= 0: - raise MigrationError("--start-id 必须 > 0") - if self.args.end_id is not None and self.args.end_id <= 0: - raise MigrationError("--end-id 必须 > 0") - if self.args.start_id is not None and self.args.end_id is not None and self.args.start_id > self.args.end_id: - raise MigrationError("--start-id 不能大于 --end-id") - - time_from_ts = _parse_cli_datetime(self.args.time_from, is_end=False) if self.args.time_from else None - time_to_ts = _parse_cli_datetime(self.args.time_to, is_end=True) if self.args.time_to else None - if time_from_ts is not None and time_to_ts is not None and time_from_ts > time_to_ts: - raise MigrationError("--time-from 不能晚于 --time-to") - - stream_filter_requested = bool( - (self.args.stream_id or []) or (self.args.group_id or []) or (self.args.user_id or []) - ) - stream_ids = self.source_db.resolve_stream_ids( - stream_ids=self.args.stream_id or [], - group_ids=self.args.group_id or [], - user_ids=self.args.user_id or [], - ) - if stream_filter_requested and not stream_ids: - logger.warning("已指定 stream/group/user 筛选,但未解析到任何 stream_id,结果将为空。") - - logger.info( - f"筛选条件: time_from={self.args.time_from or '-'}, time_to={self.args.time_to or '-'}, " - f"stream_ids={len(stream_ids)}, 
stream_filter_requested={stream_filter_requested}" - ) - - return SelectionFilter( - time_from_ts=time_from_ts, - time_to_ts=time_to_ts, - stream_ids=stream_ids, - stream_filter_requested=stream_filter_requested, - start_id=self.args.start_id, - end_id=self.args.end_id, - time_from_raw=self.args.time_from, - time_to_raw=self.args.time_to, - ) - - def _load_or_init_state(self) -> None: - if self.args.start_id is not None: - logger.info("检测到 --start-id,已按用户指定起点覆盖断点状态。") - self.state = self._new_state(last_committed_id=int(self.args.start_id) - 1) - return - - if self.args.no_resume: - self.state = self._new_state(last_committed_id=0) - return - - if not self.state_file.exists(): - self.state = self._new_state(last_committed_id=0) - return - - with open(self.state_file, "r", encoding="utf-8") as f: - loaded = json.load(f) - - loaded_filter_fp = str(loaded.get("filter_fingerprint", "")) - loaded_source_fp = str(loaded.get("source_db_fingerprint", "")) - - if loaded_filter_fp != self.filter_fingerprint or loaded_source_fp != self.source_db_fingerprint_hash: - if self.args.dry_run or self.args.verify_only: - logger.info("检测到断点与当前筛选不一致;当前为只读模式,将忽略旧断点。") - self.state = self._new_state(last_committed_id=0) - return - raise MigrationError( - "检测到筛选条件或源库指纹变化,已拒绝继续续传。请使用 --reset-state 或调整参数后重试。" - ) - - self.state = loaded - stored_stats = loaded.get("stats", {}) - if isinstance(stored_stats, dict): - for k, v in stored_stats.items(): - if k in self.stats and isinstance(v, (int, float, bool)): - self.stats[k] = v - - def _new_state(self, last_committed_id: int) -> Dict[str, Any]: - return { - "version": 1, - "updated_at": time.time(), - "last_committed_id": int(last_committed_id), - "filter_fingerprint": self.filter_fingerprint, - "source_db_fingerprint": self.source_db_fingerprint_hash, - "source_db_meta": self.source_db_fingerprint, - "stats": dict(self.stats), - } - - def _flush_state(self, last_committed_id: int) -> None: - self.stats["last_committed_id"] = int(last_committed_id) - self.state = { - "version": 1, - "updated_at": time.time(), - "last_committed_id": int(last_committed_id), - "filter_fingerprint": self.filter_fingerprint, - "source_db_fingerprint": self.source_db_fingerprint_hash, - "source_db_meta": self.source_db_fingerprint, - "stats": dict(self.stats), - } - _dump_json_atomic(self.state_file, self.state) - - def _resolve_start_after_id(self) -> int: - if self.selection is None: - raise MigrationError("selection 未初始化") - - if self.args.start_id is not None: - return int(self.args.start_id) - 1 - - if self.args.no_resume: - return 0 - - state_last = _safe_int(self.state.get("last_committed_id"), 0) if self.state else 0 - return max(0, state_last) - - def _print_preview(self, preview: PreviewResult) -> None: - print("\n=== Migration Preview ===") - print(f"source_db: {self.source_db_path}") - print(f"target_data_dir: {self.target_data_dir}") - if self.selection: - print( - f"time_window: [{self.selection.time_from_raw or '-'} ~ {self.selection.time_to_raw or '-'}] " - f"(ts: {_format_ts(self.selection.time_from_ts)} ~ {_format_ts(self.selection.time_to_ts)})" - ) - print( - f"id_window: [{self.selection.start_id or '-'} ~ {self.selection.end_id or '-'}], " - f"selected_streams={len(self.selection.stream_ids)}" - ) - print(f"matched_rows: {preview.total}") - - if preview.distribution: - print("top_chat_distribution:") - for cid, cnt in preview.distribution[:10]: - print(f" - {cid}: {cnt}") - else: - print("top_chat_distribution: (none)") - - if preview.samples: - print(f"samples 
(first {len(preview.samples)}):") - for row in preview.samples: - summary_preview = _normalize_name(row.get("summary", ""))[:60] - theme_preview = _normalize_name(row.get("theme", ""))[:30] - print( - f" - id={row.get('id')} chat_id={row.get('chat_id')} " - f"[{_format_ts(row.get('start_time'))} ~ {_format_ts(row.get('end_time'))}] " - f"theme={theme_preview!r} summary={summary_preview!r}" - ) - print("=========================\n") - - def _confirm(self) -> bool: - answer = input("确认按以上筛选执行迁移?输入 y 继续 [y/N]: ").strip().lower() - return answer in {"y", "yes"} - - def _parse_json_list_field(self, raw: Any, field_name: str, row_id: int) -> List[str]: - if raw is None: - return [] - if isinstance(raw, list): - data = raw - elif isinstance(raw, str): - try: - parsed = json.loads(raw) - except Exception as e: - raise ValueError(f"{field_name} JSON 解析失败: {e}") from e - if not isinstance(parsed, list): - raise ValueError(f"{field_name} JSON 必须是 list,当前为 {type(parsed).__name__}") - data = parsed - else: - raise ValueError(f"{field_name} 字段类型不支持: {type(raw).__name__}") - return _dedup_keep_order(str(x) for x in data if _normalize_name(x)) - - def _map_row(self, row: sqlite3.Row) -> MappedRow: - row_id = int(row["id"]) - chat_id = _normalize_name(row["chat_id"]) - theme = _normalize_name(row["theme"]) - summary = _normalize_name(row["summary"]) - - participants = self._parse_json_list_field(row["participants"], "participants", row_id) - keywords = self._parse_json_list_field(row["keywords"], "keywords", row_id) - keywords_top = keywords[:8] - - participants_text = "、".join(participants) if participants else "" - keywords_text = "、".join(keywords_top) if keywords_top else "" - - content = ( - f"话题:{theme}\n" - f"概括:{summary}\n" - f"参与者:{participants_text}\n" - f"关键词:{keywords_text}" - ).strip() - - paragraph_hash = compute_hash(normalize_text(content)) - source = f"maibot.chat_history:{chat_id}" - - start_time = _safe_float(row["start_time"], 0.0) - end_time = _safe_float(row["end_time"], start_time) - time_meta = { - "event_time_start": start_time, - "event_time_end": end_time, - "time_granularity": "minute", - "time_confidence": 0.95, - } - - entities = _dedup_keep_order([*participants, theme, *keywords_top]) - relations: List[Tuple[str, str, str]] = [] - if theme: - for participant in participants: - relations.append((participant, "参与话题", theme)) - for keyword in keywords_top: - relations.append((theme, "关键词", keyword)) - - existing_vector = paragraph_hash in self.vector_store - return MappedRow( - row_id=row_id, - chat_id=chat_id, - paragraph_hash=paragraph_hash, - content=content, - source=source, - time_meta=time_meta, - entities=entities, - relations=relations, - existing_paragraph_vector=existing_vector, - ) - - def _append_bad_row(self, row: sqlite3.Row, reason: str) -> None: - payload = { - "id": int(row["id"]), - "chat_id": _normalize_name(row["chat_id"]), - "start_time": row["start_time"], - "end_time": row["end_time"], - "participants": row["participants"], - "theme": _normalize_name(row["theme"]), - "keywords": row["keywords"], - "summary": row["summary"], - "error": reason, - "timestamp": time.time(), - } - self.bad_rows_file.parent.mkdir(parents=True, exist_ok=True) - with open(self.bad_rows_file, "a", encoding="utf-8") as f: - f.write(json.dumps(payload, ensure_ascii=False)) - f.write("\n") - - async def _migrate(self, start_after_id: int) -> None: - if self.selection is None: - raise MigrationError("selection 未初始化") - - read_batch_size = max(1, int(self.args.read_batch_size)) - 
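# Editorial note: a sketch of the row mapping in _map_row above, with stand-in
# hashing (sha256 over whitespace-normalized text instead of the plugin's
# compute_hash/normalize_text). The point is that the paragraph ID is derived
# purely from the rendered content, so re-running the migration maps the same
# row to the same ID and can skip it. Sample row values are invented.
import hashlib

def stable_id(text: str) -> str:
    normalized = " ".join(text.split())
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()

theme, summary = "周末出游", "讨论了周六去郊外徒步的计划"
participants, keywords = ["小明", "小红"], ["徒步", "郊外"]

content = (f"话题:{theme}\n概括:{summary}\n"
           f"参与者:{'、'.join(participants)}\n关键词:{'、'.join(keywords)}")
paragraph_hash = stable_id(content)

relations = [(p, "参与话题", theme) for p in participants]
relations += [(theme, "关键词", k) for k in keywords]

assert stable_id(content) == paragraph_hash   # deterministic across reruns
print(len(relations), "derived triples")      # 4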
commit_window_rows = max(1, int(self.args.commit_window_rows)) - log_every = max(1, int(self.args.log_every)) - - window_rows: List[MappedRow] = [] - window_scanned = 0 - last_seen_id = start_after_id - - logger.info( - f"开始迁移: start_after_id={start_after_id}, read_batch_size={read_batch_size}, " - f"commit_window_rows={commit_window_rows}" - ) - - for batch in self.source_db.iter_rows(self.selection, read_batch_size, start_after_id): - for row in batch: - row_id = int(row["id"]) - last_seen_id = row_id - self.stats["scanned_rows"] += 1 - window_scanned += 1 - - try: - mapped = self._map_row(row) - except Exception as e: - self.stats["bad_rows"] += 1 - self._append_bad_row(row, str(e)) - if self.stats["bad_rows"] > int(self.args.max_errors): - raise MigrationError( - f"坏行数量超过上限 max_errors={self.args.max_errors},已中止。" - ) - continue - - self.stats["valid_rows"] += 1 - if mapped.existing_paragraph_vector: - self.stats["skipped_existing_rows"] += 1 - else: - self.stats["migrated_rows"] += 1 - window_rows.append(mapped) - - if window_scanned >= commit_window_rows: - await self._commit_window(window_rows, last_seen_id) - window_rows = [] - window_scanned = 0 - - if self.stats["scanned_rows"] % log_every == 0: - logger.info( - f"迁移进度: scanned={self.stats['scanned_rows']}/{self.stats['source_matched_total']}, " - f"valid={self.stats['valid_rows']}, bad={self.stats['bad_rows']}, " - f"last_id={last_seen_id}" - ) - - if window_scanned > 0 or window_rows: - await self._commit_window(window_rows, last_seen_id) - - logger.info( - f"迁移主流程完成: scanned={self.stats['scanned_rows']}, valid={self.stats['valid_rows']}, " - f"bad={self.stats['bad_rows']}, last_committed_id={self.stats['last_committed_id']}" - ) - - async def _commit_window(self, rows: List[MappedRow], last_seen_id: int) -> None: - if not rows: - self._flush_state(last_seen_id) - self.stats["windows_committed"] += 1 - return - - now_ts = time.time() - empty_meta_blob = pickle.dumps({}) - - conn = self.metadata_store.get_connection() - - cursor = conn.cursor() - - # 批量查询本窗口内已存在的段落,保证重跑时 entity/mention 不重复累计 - existing_paragraph_hashes: set[str] = set() - all_hashes = [item.paragraph_hash for item in rows] - for i in range(0, len(all_hashes), 800): - batch_hashes = all_hashes[i : i + 800] - if not batch_hashes: - continue - placeholders = ",".join("?" 
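# Editorial note: the existence check above looks up paragraph hashes 800 at a
# time, presumably to stay under SQLite's bound-parameter limit per statement
# (999 in many builds). A self-contained illustration of the chunked IN(...)
# lookup with invented data:
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE paragraphs (hash TEXT PRIMARY KEY)")
conn.executemany("INSERT INTO paragraphs VALUES (?)",
                 [(f"hash-{i}",) for i in range(0, 3000, 2)])

wanted = [f"hash-{i}" for i in range(3000)]
existing = set()
for i in range(0, len(wanted), 800):
    chunk = wanted[i:i + 800]
    placeholders = ",".join("?" for _ in chunk)
    rows = conn.execute(
        f"SELECT hash FROM paragraphs WHERE hash IN ({placeholders})",
        chunk).fetchall()
    existing.update(r[0] for r in rows)

print(len(existing))   # 1500 of the 3000 candidates already exist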
for _ in batch_hashes) - existing_rows = cursor.execute( - f"SELECT hash FROM paragraphs WHERE hash IN ({placeholders})", - tuple(batch_hashes), - ).fetchall() - for row in existing_rows: - existing_paragraph_hashes.add(str(row["hash"])) - - paragraph_records: List[Tuple[Any, ...]] = [] - paragraph_embed_map: Dict[str, str] = {} - - entity_display: Dict[str, str] = {} - entity_counts: Dict[str, int] = defaultdict(int) - paragraph_entity_mentions: Dict[Tuple[str, str], int] = defaultdict(int) - entity_embed_map: Dict[str, str] = {} - - relation_records: Dict[str, Tuple[str, str, str, float, Optional[str], bytes]] = {} - paragraph_relation_links: set[Tuple[str, str]] = set() - - for item in rows: - is_new_paragraph = item.paragraph_hash not in existing_paragraph_hashes - - start_ts = _safe_float(item.time_meta.get("event_time_start"), 0.0) - end_ts = _safe_float(item.time_meta.get("event_time_end"), start_ts) - confidence = _safe_float(item.time_meta.get("time_confidence"), 0.95) - granularity = _normalize_name(item.time_meta.get("time_granularity")) or "minute" - - if is_new_paragraph: - paragraph_records.append( - ( - item.paragraph_hash, - item.content, - None, - now_ts, - now_ts, - empty_meta_blob, - item.source, - len(normalize_text(item.content).split()), - None, - start_ts, - end_ts, - granularity, - confidence, - KnowledgeType.NARRATIVE.value, - ) - ) - - if item.paragraph_hash not in self.vector_store: - paragraph_embed_map[item.paragraph_hash] = item.content - - for entity in item.entities: - name = _normalize_name(entity) - if not name: - continue - canon = _canonical_name(name) - if not canon: - continue - entity_hash = compute_hash(canon) - entity_display.setdefault(entity_hash, name) - if is_new_paragraph: - entity_counts[entity_hash] += 1 - paragraph_entity_mentions[(item.paragraph_hash, entity_hash)] += 1 - if entity_hash not in self.vector_store: - entity_embed_map.setdefault(entity_hash, name) - - for subject, predicate, obj in item.relations: - s = _normalize_name(subject) - p = _normalize_name(predicate) - o = _normalize_name(obj) - if not (s and p and o): - continue - - s_canon = _canonical_name(s) - p_canon = _canonical_name(p) - o_canon = _canonical_name(o) - relation_hash = compute_hash(f"{s_canon}|{p_canon}|{o_canon}") - - if is_new_paragraph: - relation_records.setdefault( - relation_hash, - (s, p, o, 1.0, item.paragraph_hash, empty_meta_blob), - ) - paragraph_relation_links.add((item.paragraph_hash, relation_hash)) - - for relation_entity in (s, o): - e_canon = _canonical_name(relation_entity) - if not e_canon: - continue - e_hash = compute_hash(e_canon) - entity_display.setdefault(e_hash, relation_entity) - if is_new_paragraph: - entity_counts[e_hash] += 1 - paragraph_entity_mentions[(item.paragraph_hash, e_hash)] += 1 - if e_hash not in self.vector_store: - entity_embed_map.setdefault(e_hash, relation_entity) - - try: - cursor.execute("BEGIN") - - if paragraph_records: - cursor.executemany( - """ - INSERT OR IGNORE INTO paragraphs - ( - hash, content, vector_index, created_at, updated_at, metadata, source, word_count, - event_time, event_time_start, event_time_end, time_granularity, time_confidence, knowledge_type - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, - paragraph_records, - ) - - if entity_counts: - entity_rows = [ - ( - entity_hash, - entity_display[entity_hash], - None, - int(count), - now_ts, - empty_meta_blob, - ) - for entity_hash, count in entity_counts.items() - ] - try: - cursor.executemany( - """ - INSERT INTO entities - (hash, name, vector_index, appearance_count, created_at, metadata) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(hash) DO UPDATE SET - appearance_count = entities.appearance_count + excluded.appearance_count - """, - entity_rows, - ) - except sqlite3.OperationalError: - cursor.executemany( - """ - INSERT OR IGNORE INTO entities - (hash, name, vector_index, appearance_count, created_at, metadata) - VALUES (?, ?, ?, ?, ?, ?) - """, - entity_rows, - ) - cursor.executemany( - "UPDATE entities SET appearance_count = appearance_count + ? WHERE hash = ?", - [(int(count), entity_hash) for entity_hash, count in entity_counts.items()], - ) - - if paragraph_entity_mentions: - pe_rows = [ - (paragraph_hash, entity_hash, int(mentions)) - for (paragraph_hash, entity_hash), mentions in paragraph_entity_mentions.items() - ] - try: - cursor.executemany( - """ - INSERT INTO paragraph_entities - (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - ON CONFLICT(paragraph_hash, entity_hash) DO UPDATE SET - mention_count = paragraph_entities.mention_count + excluded.mention_count - """, - pe_rows, - ) - except sqlite3.OperationalError: - cursor.executemany( - """ - INSERT OR IGNORE INTO paragraph_entities - (paragraph_hash, entity_hash, mention_count) - VALUES (?, ?, ?) - """, - pe_rows, - ) - cursor.executemany( - """ - UPDATE paragraph_entities - SET mention_count = mention_count + ? - WHERE paragraph_hash = ? AND entity_hash = ? - """, - [(m, p, e) for (p, e, m) in pe_rows], - ) - - if relation_records: - relation_rows = [ - ( - relation_hash, - rel[0], - rel[1], - rel[2], - None, - rel[3], - now_ts, - rel[4], - rel[5], - ) - for relation_hash, rel in relation_records.items() - ] - cursor.executemany( - """ - INSERT OR IGNORE INTO relations - (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, - relation_rows, - ) - - if paragraph_relation_links: - pr_rows = [(p_hash, r_hash) for p_hash, r_hash in paragraph_relation_links] - cursor.executemany( - """ - INSERT OR IGNORE INTO paragraph_relations - (paragraph_hash, relation_hash) - VALUES (?, ?) 
- """, - pr_rows, - ) - - conn.commit() - except Exception: - conn.rollback() - raise - - self.stats["relations_written"] += len(relation_records) - - if relation_records: - edge_pairs = [] - relation_hashes = [] - for relation_hash, rel in relation_records.items(): - edge_pairs.append((rel[0], rel[2])) - relation_hashes.append(relation_hash) - - with self.graph_store.batch_update(): - self.graph_store.add_edges(edge_pairs, relation_hashes=relation_hashes) - self.stats["graph_edges_written"] += len(edge_pairs) - - if self._should_write_relation_vectors(): - await self._ensure_relation_vectors_for_records(relation_records) - - para_added = await self._embed_and_add_vectors( - id_to_text=paragraph_embed_map, - batch_size=max(1, int(self.args.embed_batch_size)), - workers=self.embed_workers, - ) - ent_added = await self._embed_and_add_vectors( - id_to_text=entity_embed_map, - batch_size=max(1, int(self.args.entity_embed_batch_size)), - workers=self.embed_workers, - ) - self.stats["paragraph_vectors_added"] += para_added - self.stats["entity_vectors_added"] += ent_added - - self.vector_store.save() - self.graph_store.save() - - self.stats["windows_committed"] += 1 - self._flush_state(last_seen_id) - - async def _embed_and_add_vectors( - self, - id_to_text: Dict[str, str], - batch_size: int, - workers: int, - ) -> int: - if not id_to_text: - return 0 - if self.embedding_manager is None: - raise MigrationError("embedding_manager 未初始化,无法写入向量") - - ids = [] - texts = [] - for hash_id, text in id_to_text.items(): - if hash_id in self.vector_store: - continue - ids.append(hash_id) - texts.append(text) - - if not ids: - return 0 - - total_added = 0 - chunk_size = max(1, int(batch_size)) - for i in range(0, len(ids), chunk_size): - chunk_ids = ids[i : i + chunk_size] - chunk_texts = texts[i : i + chunk_size] - - embeddings = await self.embedding_manager.encode_batch( - chunk_texts, - batch_size=chunk_size, - num_workers=max(1, int(workers)), - ) - - emb_arr = np.asarray(embeddings, dtype=np.float32) - if emb_arr.ndim == 1: - emb_arr = emb_arr.reshape(1, -1) - if emb_arr.shape[0] != len(chunk_ids): - logger.warning( - f"embedding 返回数量异常: expected={len(chunk_ids)}, got={emb_arr.shape[0]},跳过该批次" - ) - continue - - valid_vectors = [] - valid_ids = [] - for idx, vec in enumerate(emb_arr): - if vec.ndim != 1: - continue - if vec.shape[0] != self.vector_store.dimension: - logger.warning( - f"向量维度不匹配,跳过: id={chunk_ids[idx]}, got={vec.shape[0]}, expected={self.vector_store.dimension}" - ) - continue - if not np.all(np.isfinite(vec)): - logger.warning(f"向量含 NaN/Inf,跳过: id={chunk_ids[idx]}") - continue - if chunk_ids[idx] in self.vector_store: - continue - valid_vectors.append(vec) - valid_ids.append(chunk_ids[idx]) - - if valid_vectors: - batch_vectors = np.stack(valid_vectors).astype(np.float32, copy=False) - added = self.vector_store.add(batch_vectors, valid_ids) - total_added += int(added) - - return total_added - - async def _verify(self, strict: bool) -> None: - if self.selection is None: - raise MigrationError("selection 未初始化") - - sample_size = min(2000, max(0, int(self.stats.get("source_matched_total", 0)))) - self.stats["verify_sample_size"] = sample_size - - if sample_size <= 0: - self.stats["verify_passed"] = True - return - - sample_rows = self.source_db.sample_rows_for_verify(self.selection, sample_size) - para_missing = 0 - vec_missing = 0 - rel_missing = 0 - edge_missing = 0 - - for row in sample_rows: - try: - mapped = self._map_row(row) - except Exception: - continue - - paragraph = 
self.metadata_store.get_paragraph(mapped.paragraph_hash) - if paragraph is None: - para_missing += 1 - if mapped.paragraph_hash not in self.vector_store: - vec_missing += 1 - - for s, p, o in mapped.relations: - relation_hash = compute_hash(f"{_canonical_name(s)}|{_canonical_name(p)}|{_canonical_name(o)}") - relation = self.metadata_store.get_relation(relation_hash) - if relation is None: - rel_missing += 1 - if self.graph_store.get_edge_weight(s, o) <= 0.0: - edge_missing += 1 - - self.stats["verify_paragraph_missing"] = para_missing - self.stats["verify_vector_missing"] = vec_missing - self.stats["verify_relation_missing"] = rel_missing - self.stats["verify_edge_missing"] = edge_missing - - verify_passed = all(x == 0 for x in [para_missing, vec_missing, rel_missing, edge_missing]) - if strict and not verify_passed: - self.failed = True - self.fail_reason = ( - "严格校验失败: " - f"paragraph_missing={para_missing}, vector_missing={vec_missing}, " - f"relation_missing={rel_missing}, edge_missing={edge_missing}" - ) - - self.stats["verify_passed"] = verify_passed - - def _finalize(self) -> int: - elapsed = time.time() - self.started_at - self.stats["elapsed_seconds"] = elapsed - - report = { - "success": not self.failed, - "fail_reason": self.fail_reason, - "args": vars(self.args), - "source_db": str(self.source_db_path), - "target_data_dir": str(self.target_data_dir), - "selection": self.selection.fingerprint_payload() if self.selection else {}, - "filter_fingerprint": self.filter_fingerprint, - "source_db_fingerprint": self.source_db_fingerprint, - "state_file": str(self.state_file), - "bad_rows_file": str(self.bad_rows_file), - "stats": dict(self.stats), - "timestamp": time.time(), - } - - _dump_json_atomic(self.report_file, report) - - if self.failed: - self.exit_code = 1 - elif self.stats.get("bad_rows", 0) > 0: - self.exit_code = 2 - else: - self.exit_code = 0 - - print("\n=== Migration Report ===") - print(f"success: {not self.failed}") - if self.fail_reason: - print(f"fail_reason: {self.fail_reason}") - print(f"elapsed: {elapsed:.2f}s") - print(f"source_matched_total: {self.stats['source_matched_total']}") - print(f"scanned_rows: {self.stats['scanned_rows']}") - print(f"valid_rows: {self.stats['valid_rows']}") - print(f"migrated_rows: {self.stats['migrated_rows']}") - print(f"skipped_existing_rows: {self.stats['skipped_existing_rows']}") - print(f"bad_rows: {self.stats['bad_rows']}") - print(f"paragraph_vectors_added: {self.stats['paragraph_vectors_added']}") - print(f"entity_vectors_added: {self.stats['entity_vectors_added']}") - print(f"relations_written: {self.stats['relations_written']}") - print( - "relation_vectors: " - f"written={self.stats['relation_vectors_written']}, " - f"failed={self.stats['relation_vectors_failed']}, " - f"skipped={self.stats['relation_vectors_skipped']}" - ) - print(f"graph_edges_written: {self.stats['graph_edges_written']}") - print(f"windows_committed: {self.stats['windows_committed']}") - print(f"last_committed_id: {self.stats['last_committed_id']}") - print( - "verify: " - f"sample={self.stats['verify_sample_size']}, " - f"paragraph_missing={self.stats['verify_paragraph_missing']}, " - f"vector_missing={self.stats['verify_vector_missing']}, " - f"relation_missing={self.stats['verify_relation_missing']}, " - f"edge_missing={self.stats['verify_edge_missing']}, " - f"passed={self.stats['verify_passed']}" - ) - print(f"report_file: {self.report_file}") - print("========================\n") - - return self.exit_code - - def _close(self) -> None: - try: - if 
self.metadata_store is not None: - self.metadata_store.close() - except Exception: - pass - self.source_db.close() - - -def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser( - description="迁移 MaiBot chat_history 到 A_memorix(高性能 + 可断点续传 + 可确认筛选)" - ) - - parser.add_argument("--source-db", default=str(DEFAULT_SOURCE_DB), help="源数据库路径(默认 data/MaiBot.db)") - parser.add_argument( - "--target-data-dir", - default=str(DEFAULT_TARGET_DATA_DIR), - help="A_memorix 数据目录(默认 plugins/A_memorix/data)", - ) - - resume_group = parser.add_mutually_exclusive_group() - resume_group.add_argument("--resume", dest="no_resume", action="store_false", help="启用断点续传(默认)") - resume_group.add_argument("--no-resume", dest="no_resume", action="store_true", help="禁用断点续传") - parser.set_defaults(no_resume=False) - - parser.add_argument("--reset-state", action="store_true", help="清空迁移状态文件后执行") - parser.add_argument("--start-id", type=int, default=None, help="从指定 chat_history.id 开始迁移(覆盖断点)") - parser.add_argument("--end-id", type=int, default=None, help="迁移到指定 chat_history.id") - - parser.add_argument("--read-batch-size", type=int, default=2000, help="源库分页读取大小(默认 2000)") - parser.add_argument("--commit-window-rows", type=int, default=20000, help="每窗口提交行数(默认 20000)") - parser.add_argument("--embed-batch-size", type=int, default=256, help="段落 embedding 批次大小(默认 256)") - parser.add_argument( - "--entity-embed-batch-size", - type=int, - default=512, - help="实体 embedding 批次大小(默认 512)", - ) - parser.add_argument("--embed-workers", type=int, default=None, help="embedding 并发数(默认读取配置)") - parser.add_argument("--max-errors", type=int, default=500, help="坏行上限(默认 500)") - parser.add_argument("--log-every", type=int, default=5000, help="日志输出步长(默认 5000)") - - parser.add_argument("--dry-run", action="store_true", help="仅预览不写入") - parser.add_argument("--verify-only", action="store_true", help="仅执行严格校验") - - parser.add_argument("--time-from", default=None, help="开始时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") - parser.add_argument("--time-to", default=None, help="结束时间:YYYY-MM-DD / YYYY/MM/DD / YYYY-MM-DD HH:mm[:ss]") - parser.add_argument("--stream-id", action="append", default=[], help="聊天流 stream_id(可重复)") - parser.add_argument("--group-id", action="append", default=[], help="群号(可重复,自动映射 stream_id)") - parser.add_argument("--user-id", action="append", default=[], help="用户号(可重复,自动映射 stream_id)") - parser.add_argument("--yes", action="store_true", help="跳过交互确认") - parser.add_argument("--preview-limit", type=int, default=20, help="预览样本条数(默认 20)") - - return parser - - -async def async_main() -> int: - parser = build_parser() - args = parser.parse_args() - - runner = MigrationRunner(args) - return await runner.run() - - -def main() -> int: - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - return asyncio.run(async_main()) - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/migrate_person_memory_points.py b/plugins/A_memorix/scripts/migrate_person_memory_points.py deleted file mode 100644 index a03a8914..00000000 --- a/plugins/A_memorix/scripts/migrate_person_memory_points.py +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import asyncio -import json -import sqlite3 -import sys -from pathlib import Path -from typing import Any, Dict, List - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = 
PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" - -if str(WORKSPACE_ROOT) not in sys.path: - sys.path.insert(0, str(WORKSPACE_ROOT)) -if str(MAIBOT_ROOT) not in sys.path: - sys.path.insert(0, str(MAIBOT_ROOT)) - -from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 - - -def _parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="迁移 MaiBot person_info.memory_points 到 A_Memorix") - parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") - parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") - parser.add_argument("--limit", type=int, default=0, help="限制迁移人数,0 表示全部") - parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") - return parser.parse_args() - - -def _parse_memory_points(raw_value: Any) -> List[Dict[str, Any]]: - try: - values = json.loads(raw_value) if raw_value else [] - except Exception: - values = [] - items: List[Dict[str, Any]] = [] - for index, item in enumerate(values): - text = str(item or "").strip() - if not text: - continue - parts = text.split(":") - if len(parts) >= 3: - category = parts[0].strip() - content = ":".join(parts[1:-1]).strip() - weight = parts[-1].strip() - else: - category = "其他" - content = text - weight = "1.0" - if content: - items.append({"index": index, "category": category or "其他", "content": content, "weight": weight or "1.0"}) - return items - - -async def _main() -> int: - args = _parse_args() - db_path = Path(args.db_path).resolve() - if not db_path.exists(): - print(f"数据库不存在: {db_path}") - return 1 - - conn = sqlite3.connect(str(db_path)) - conn.row_factory = sqlite3.Row - sql = """ - SELECT person_id, person_name, user_nickname, memory_points - FROM person_info - WHERE memory_points IS NOT NULL AND memory_points != '' - ORDER BY id ASC - """ - if int(args.limit or 0) > 0: - sql += " LIMIT ?" 
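# Editorial sketch (not part of the original script): each memory_points entry is expected to be
# a JSON array of "category:content:weight" strings joined with the full-width colon, which
# _parse_memory_points() splits apart; entries with fewer than three parts fall back to the
# "其他" category with weight "1.0". Hypothetical values:
#
#     raw = '["hobby:late-night coding:0.9", "catchphrase:nice"]'
#     _parse_memory_points(raw)
#     # -> [{"index": 0, "category": "hobby", "content": "late-night coding", "weight": "0.9"},
#     #     {"index": 1, "category": "其他", "content": "catchphrase:nice", "weight": "1.0"}]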
- rows = conn.execute(sql, (int(args.limit),)).fetchall() - else: - rows = conn.execute(sql).fetchall() - conn.close() - - preview_total = sum(len(_parse_memory_points(row["memory_points"])) for row in rows) - print(f"person_info 待迁移人物: {len(rows)} 记忆点: {preview_total}") - if args.dry_run: - for row in rows[:5]: - print(f"- person_id={row['person_id']} person_name={row['person_name'] or row['user_nickname']}") - return 0 - - kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) - await kernel.initialize() - migrated = 0 - skipped = 0 - for row in rows: - person_id = str(row["person_id"] or "").strip() - if not person_id: - continue - display_name = str(row["person_name"] or row["user_nickname"] or "").strip() - for item in _parse_memory_points(row["memory_points"]): - result: Dict[str, Any] = await kernel.ingest_text( - external_id=f"person_memory:{person_id}:{item['index']}", - source_type="person_fact", - text=f"[{item['category']}] {item['content']}", - person_ids=[person_id], - tags=[item["category"]], - entities=[person_id, display_name] if display_name else [person_id], - metadata={"category": item["category"], "weight": item["weight"], "display_name": display_name}, - ) - if result.get("stored_ids"): - migrated += 1 - else: - skipped += 1 - - print(f"迁移完成: migrated={migrated} skipped={skipped}") - print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) - kernel.close() - return 0 - - -if __name__ == "__main__": - raise SystemExit(asyncio.run(_main())) diff --git a/plugins/A_memorix/scripts/process_knowledge.py b/plugins/A_memorix/scripts/process_knowledge.py deleted file mode 100644 index d9e6fe32..00000000 --- a/plugins/A_memorix/scripts/process_knowledge.py +++ /dev/null @@ -1,728 +0,0 @@ -#!/usr/bin/env python3 -""" -知识库自动导入脚本 (Strategy-Aware Version) - -功能: -1. 扫描 plugins/A_memorix/data/raw 下的 .txt 文件 -2. 检查 data/import_manifest.json 确认是否已导入 -3. 使用 Strategy 模式处理文件 (Narrative/Factual/Quote) -4. 将生成的数据直接存入 VectorStore/GraphStore/MetadataStore -5. 
更新 manifest -""" - -import sys -import os -import json -import asyncio -import time -import random -import hashlib -import tomlkit -import argparse -from pathlib import Path -from datetime import datetime -from typing import List, Dict, Any, Optional -from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeElapsedColumn -from rich.console import Console -from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type - -console = Console() - -class LLMGenerationError(Exception): - pass - -# 路径设置 -current_dir = Path(__file__).resolve().parent -plugin_root = current_dir.parent -workspace_root = plugin_root.parent -maibot_root = workspace_root / "MaiBot" -for path in (workspace_root, maibot_root, plugin_root): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - -# 数据目录 -DATA_DIR = plugin_root / "data" -RAW_DIR = DATA_DIR / "raw" -PROCESSED_DIR = DATA_DIR / "processed" -MANIFEST_PATH = DATA_DIR / "import_manifest.json" - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="A_Memorix Knowledge Importer (Strategy-Aware)") - parser.add_argument("--force", action="store_true", help="Force re-import") - parser.add_argument("--clear-manifest", action="store_true", help="Clear manifest") - parser.add_argument( - "--type", - "-t", - default="auto", - help="Target import strategy override (auto/narrative/factual/quote)", - ) - parser.add_argument("--concurrency", "-c", type=int, default=5) - parser.add_argument( - "--chat-log", - action="store_true", - help="聊天记录导入模式:强制 narrative 策略,并使用 LLM 语义抽取 event_time/event_time_range", - ) - parser.add_argument( - "--chat-reference-time", - default=None, - help="chat_log 模式的相对时间参考点(如 2026/02/12 10:30);不传则使用当前本地时间", - ) - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - sys.exit(0) - - -try: - import A_memorix.core as core_module - import A_memorix.core.storage as storage_module - from src.common.logger import get_logger - from src.services import llm_service as llm_api - from src.config.config import global_config, model_config - - VectorStore = core_module.VectorStore - GraphStore = core_module.GraphStore - MetadataStore = core_module.MetadataStore - ImportStrategy = core_module.ImportStrategy - create_embedding_api_adapter = core_module.create_embedding_api_adapter - RelationWriteService = getattr(core_module, "RelationWriteService", None) - - looks_like_quote_text = storage_module.looks_like_quote_text - parse_import_strategy = storage_module.parse_import_strategy - resolve_stored_knowledge_type = storage_module.resolve_stored_knowledge_type - select_import_strategy = storage_module.select_import_strategy - - from A_memorix.core.utils.time_parser import normalize_time_meta - from A_memorix.core.utils.import_payloads import normalize_paragraph_import_item - from A_memorix.core.strategies.base import BaseStrategy, ProcessedChunk, KnowledgeType as StratKnowledgeType - from A_memorix.core.strategies.narrative import NarrativeStrategy - from A_memorix.core.strategies.factual import FactualStrategy - from A_memorix.core.strategies.quote import QuoteStrategy - -except ImportError as e: - print(f"❌ 无法导入模块: {e}") - import traceback - traceback.print_exc() - sys.exit(1) - -logger = get_logger("A_Memorix.AutoImport") - - -def _log_before_retry(retry_state) -> None: - """使用项目统一日志风格记录重试信息。""" - exc = None - if 
getattr(retry_state, "outcome", None) is not None and retry_state.outcome.failed: - exc = retry_state.outcome.exception() - next_sleep = getattr(getattr(retry_state, "next_action", None), "sleep", None) - logger.warning( - "LLM 调用即将重试: " - f"attempt={getattr(retry_state, 'attempt_number', '?')} " - f"next_sleep={next_sleep} " - f"error={exc}" - ) - -class AutoImporter: - def __init__( - self, - force: bool = False, - clear_manifest: bool = False, - target_type: str = "auto", - concurrency: int = 5, - chat_log: bool = False, - chat_reference_time: Optional[str] = None, - ): - self.vector_store: Optional[VectorStore] = None - self.graph_store: Optional[GraphStore] = None - self.metadata_store: Optional[MetadataStore] = None - self.embedding_manager = None - self.relation_write_service = None - self.plugin_config = {} - self.manifest = {} - self.force = force - self.clear_manifest = clear_manifest - self.chat_log = chat_log - parsed_target_type = parse_import_strategy(target_type, default=ImportStrategy.AUTO) - self.target_type = ImportStrategy.NARRATIVE.value if chat_log else parsed_target_type.value - self.chat_reference_dt = self._parse_reference_time(chat_reference_time) - if self.chat_log and parsed_target_type not in {ImportStrategy.AUTO, ImportStrategy.NARRATIVE}: - logger.warning( - f"chat_log 模式已启用,target_type={target_type} 将被覆盖为 narrative" - ) - self.concurrency_limit = concurrency - self.semaphore = None - self.storage_lock = None - - async def initialize(self): - logger.info(f"正在初始化... (并发数: {self.concurrency_limit})") - self.semaphore = asyncio.Semaphore(self.concurrency_limit) - self.storage_lock = asyncio.Lock() - - RAW_DIR.mkdir(parents=True, exist_ok=True) - PROCESSED_DIR.mkdir(parents=True, exist_ok=True) - - if self.clear_manifest: - logger.info("🧹 清理 Mainfest") - self.manifest = {} - self._save_manifest() - elif MANIFEST_PATH.exists(): - try: - with open(MANIFEST_PATH, "r", encoding="utf-8") as f: - self.manifest = json.load(f) - except Exception: - self.manifest = {} - - config_path = plugin_root / "config.toml" - try: - with open(config_path, "r", encoding="utf-8") as f: - self.plugin_config = tomlkit.load(f) - except Exception as e: - logger.error(f"加载插件配置失败: {e}") - return False - - try: - await self._init_stores() - except Exception as e: - logger.error(f"初始化存储失败: {e}") - return False - - return True - - async def _init_stores(self): - # ... 
(Same as original) - self.embedding_manager = create_embedding_api_adapter( - batch_size=self.plugin_config.get("embedding", {}).get("batch_size", 32), - default_dimension=self.plugin_config.get("embedding", {}).get("dimension", 384), - model_name=self.plugin_config.get("embedding", {}).get("model_name", "auto"), - retry_config=self.plugin_config.get("embedding", {}).get("retry", {}), - ) - try: - dim = await self.embedding_manager._detect_dimension() - except: - dim = self.embedding_manager.default_dimension - - q_type_str = str(self.plugin_config.get("embedding", {}).get("quantization_type", "int8") or "int8").lower() - # Need to access QuantizationType from storage_module if not imported globally - QuantizationType = storage_module.QuantizationType - if q_type_str != "int8": - raise ValueError( - "embedding.quantization_type 在 vNext 仅允许 int8(SQ8)。" - " 请先执行 scripts/release_vnext_migrate.py migrate。" - ) - - self.vector_store = VectorStore( - dimension=dim, - quantization_type=QuantizationType.INT8, - data_dir=DATA_DIR / "vectors" - ) - - SparseMatrixFormat = storage_module.SparseMatrixFormat - m_fmt_str = self.plugin_config.get("graph", {}).get("sparse_matrix_format", "csr") - m_map = {"csr": SparseMatrixFormat.CSR, "csc": SparseMatrixFormat.CSC} - - self.graph_store = GraphStore( - matrix_format=m_map.get(m_fmt_str, SparseMatrixFormat.CSR), - data_dir=DATA_DIR / "graph" - ) - - self.metadata_store = MetadataStore(data_dir=DATA_DIR / "metadata") - self.metadata_store.connect() - - if RelationWriteService is not None: - self.relation_write_service = RelationWriteService( - metadata_store=self.metadata_store, - graph_store=self.graph_store, - vector_store=self.vector_store, - embedding_manager=self.embedding_manager, - ) - - if self.vector_store.has_data(): self.vector_store.load() - if self.graph_store.has_data(): self.graph_store.load() - - def _should_write_relation_vectors(self) -> bool: - retrieval_cfg = self.plugin_config.get("retrieval", {}) - if not isinstance(retrieval_cfg, dict): - return False - rv_cfg = retrieval_cfg.get("relation_vectorization", {}) - if not isinstance(rv_cfg, dict): - return False - return bool(rv_cfg.get("enabled", False)) and bool(rv_cfg.get("write_on_import", True)) - - def load_file(self, file_path: Path) -> str: - with open(file_path, "r", encoding="utf-8") as f: - return f.read() - - def get_file_hash(self, content: str) -> str: - return hashlib.md5(content.encode("utf-8")).hexdigest() - - def _parse_reference_time(self, value: Optional[str]) -> datetime: - """解析 chat_log 模式的参考时间(用于相对时间语义解析)。""" - if not value: - return datetime.now() - formats = [ - "%Y/%m/%d %H:%M:%S", - "%Y/%m/%d %H:%M", - "%Y-%m-%d %H:%M:%S", - "%Y-%m-%d %H:%M", - "%Y/%m/%d", - "%Y-%m-%d", - ] - text = str(value).strip() - for fmt in formats: - try: - return datetime.strptime(text, fmt) - except ValueError: - continue - logger.warning( - f"无法解析 chat_reference_time={value},将回退为当前本地时间" - ) - return datetime.now() - - async def _extract_chat_time_meta_with_llm( - self, - text: str, - model_config: Any, - ) -> Optional[Dict[str, Any]]: - """ - 使用 LLM 从聊天文本语义中抽取时间信息。 - 支持将相对时间表达转换为绝对时间。 - """ - if not text.strip(): - return None - - reference_now = self.chat_reference_dt.strftime("%Y/%m/%d %H:%M") - prompt = f"""You are a time extraction engine for chat logs. -Extract temporal information from the following chat paragraph. - -Rules: -1. Use semantic understanding, not regex matching. -2. 
Convert relative expressions (e.g., yesterday evening, last Friday morning) to absolute local datetime using reference_now. -3. If a time span exists, return event_time_start/event_time_end. -4. If only one point in time exists, return event_time. -5. If no reliable time can be inferred, return all time fields as null. -6. Output ONLY valid JSON. No markdown, no explanation. - -reference_now: {reference_now} -timezone: local system timezone - -Allowed output formats for time values: -- "YYYY/MM/DD" -- "YYYY/MM/DD HH:mm" - -JSON schema: -{{ - "event_time": null, - "event_time_start": null, - "event_time_end": null, - "time_range": null, - "time_granularity": "day", - "time_confidence": 0.0 -}} - -Chat paragraph: -\"\"\"{text}\"\"\" -""" - try: - result = await self._llm_call(prompt, model_config) - except Exception as e: - logger.warning(f"chat_log 时间语义抽取失败: {e}") - return None - - if not isinstance(result, dict): - return None - - raw_time_meta = { - "event_time": result.get("event_time"), - "event_time_start": result.get("event_time_start"), - "event_time_end": result.get("event_time_end"), - "time_range": result.get("time_range"), - "time_granularity": result.get("time_granularity"), - "time_confidence": result.get("time_confidence"), - } - try: - normalized = normalize_time_meta(raw_time_meta) - except Exception as e: - logger.warning(f"chat_log 时间语义抽取结果不可用,已忽略: {e}") - return None - - has_effective_time = any( - key in normalized - for key in ("event_time", "event_time_start", "event_time_end") - ) - if not has_effective_time: - return None - - return normalized - - def _determine_strategy(self, filename: str, content: str) -> BaseStrategy: - """Layer 1: Global Strategy Routing""" - strategy = select_import_strategy( - content, - override=self.target_type, - chat_log=self.chat_log, - ) - if self.chat_log: - logger.info(f"chat_log 模式: {filename} 强制使用 NarrativeStrategy") - elif strategy == ImportStrategy.QUOTE: - logger.info(f"Auto-detected Quote/Lyric type for {filename}") - - if strategy == ImportStrategy.FACTUAL: - return FactualStrategy(filename) - if strategy == ImportStrategy.QUOTE: - return QuoteStrategy(filename) - return NarrativeStrategy(filename) - - def _chunk_rescue(self, chunk: ProcessedChunk, filename: str) -> Optional[BaseStrategy]: - """Layer 2: Chunk-level rescue strategies""" - # If we are already in Quote strategy, no need to rescue - if chunk.type == StratKnowledgeType.QUOTE: - return None - - if looks_like_quote_text(chunk.chunk.text): - logger.info(f" > Rescuing chunk {chunk.chunk.index} as Quote") - return QuoteStrategy(filename) - - return None - - async def process_and_import(self): - if not await self.initialize(): return - - files = list(RAW_DIR.glob("*.txt")) - logger.info(f"扫描到 {len(files)} 个文件 in {RAW_DIR}") - - if not files: return - - tasks = [] - for file_path in files: - tasks.append(asyncio.create_task(self._process_single_file(file_path))) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - success_count = sum(1 for r in results if r is True) - logger.info(f"本次主处理完成,共成功处理 {success_count}/{len(files)} 个文件") - - if self.vector_store: self.vector_store.save() - if self.graph_store: self.graph_store.save() - - async def _process_single_file(self, file_path: Path) -> bool: - filename = file_path.name - async with self.semaphore: - try: - content = self.load_file(file_path) - file_hash = self.get_file_hash(content) - - if not self.force and filename in self.manifest: - record = self.manifest[filename] - if record.get("hash") == file_hash 
and record.get("imported"): - logger.info(f"跳过已导入文件: {filename}") - return False - - logger.info(f">>> 开始处理: {filename}") - - # 1. Strategy Selection - strategy = self._determine_strategy(filename, content) - logger.info(f" 策略: {strategy.__class__.__name__}") - - # 2. Split (Strategy-Aware) - initial_chunks = strategy.split(content) - logger.info(f" 初步分块: {len(initial_chunks)}") - - processed_data = {"paragraphs": [], "entities": [], "relations": []} - - # 3. Extract Loop - model_config = await self._select_model() - - for i, chunk in enumerate(initial_chunks): - current_strategy = strategy - # Layer 2: Chunk Rescue - rescue_strategy = self._chunk_rescue(chunk, filename) - if rescue_strategy: - # Re-split? No, just re-process this text as a single chunk using the rescue strategy - # But rescue strategy might want to split it further? - # Simplification: Treat the whole chunk text as one block for the rescue strategy - # OR create a single chunk object for it. - # Creating a new chunk using rescue strategy logic might be complex if split behavior differs. - # Let's just instantiate a chunk of the new type manually - chunk.type = StratKnowledgeType.QUOTE - chunk.flags.verbatim = True - chunk.flags.requires_llm = False # Quotes don't usually need LLM - current_strategy = rescue_strategy - - # Extraction - if chunk.flags.requires_llm: - result_chunk = await current_strategy.extract(chunk, lambda p: self._llm_call(p, model_config)) - else: - # For quotes, extract might be just pass through or regex - result_chunk = await current_strategy.extract(chunk) - - time_meta = None - if self.chat_log: - time_meta = await self._extract_chat_time_meta_with_llm( - result_chunk.chunk.text, - model_config, - ) - - # Normalize Data - self._normalize_and_aggregate( - result_chunk, - processed_data, - time_meta=time_meta, - ) - - logger.info(f" 已处理块 {i+1}/{len(initial_chunks)}") - - # 4. Save Json - json_path = PROCESSED_DIR / f"{file_path.stem}.json" - with open(json_path, "w", encoding="utf-8") as f: - json.dump(processed_data, f, ensure_ascii=False, indent=2) - - # 5. Import to DB - async with self.storage_lock: - await self._import_to_db(processed_data) - - self.manifest[filename] = { - "hash": file_hash, - "timestamp": time.time(), - "imported": True - } - self._save_manifest() - self.vector_store.save() - self.graph_store.save() - logger.info(f"✅ 文件 {filename} 处理并导入完成") - return True - - except Exception as e: - logger.error(f"❌ 处理失败 {filename}: {e}") - import traceback - traceback.print_exc() - return False - - def _normalize_and_aggregate( - self, - chunk: ProcessedChunk, - all_data: Dict, - time_meta: Optional[Dict[str, Any]] = None, - ): - """Convert strategy-specific data to unified generic format for storage.""" - # Generic fields - para_item = { - "content": chunk.chunk.text, - "source": chunk.source.file, - "knowledge_type": resolve_stored_knowledge_type( - chunk.type.value, - content=chunk.chunk.text, - ).value, - "entities": [], - "relations": [] - } - - data = chunk.data - - # 1. Triples (Factual) - if "triples" in data: - for t in data["triples"]: - para_item["relations"].append({ - "subject": t.get("subject"), - "predicate": t.get("predicate"), - "object": t.get("object") - }) - # Auto-add entities from triples - para_item["entities"].extend([t.get("subject"), t.get("object")]) - - # 2. Events & Relations (Narrative) - if "events" in data: - # Store events as content/metadata? Or entities? - # For now maybe just keep them in logic, or add as 'Event' entities? 
- # Creating entities for events is good. - para_item["entities"].extend(data["events"]) - - if "relations" in data: # Narrative also outputs relations list - para_item["relations"].extend(data["relations"]) - for r in data["relations"]: - para_item["entities"].extend([r.get("subject"), r.get("object")]) - - # 3. Verbatim Entities (Quote) - if "verbatim_entities" in data: - para_item["entities"].extend(data["verbatim_entities"]) - - # Dedupe per paragraph - para_item["entities"] = list(set([e for e in para_item["entities"] if e])) - - if time_meta: - para_item["time_meta"] = time_meta - - all_data["paragraphs"].append(para_item) - all_data["entities"].extend(para_item["entities"]) - if "relations" in para_item: - all_data["relations"].extend(para_item["relations"]) - - @retry( - retry=retry_if_exception_type((LLMGenerationError, json.JSONDecodeError)), - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=2, max=10), - before_sleep=_log_before_retry - ) - async def _llm_call(self, prompt: str, model_config: Any) -> Dict: - """Generic LLM Caller""" - success, response, _, _ = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type="Script.ProcessKnowledge" - ) - if success: - txt = response.strip() - if "```" in txt: - txt = txt.split("```json")[-1].split("```")[0].strip() - try: - return json.loads(txt) - except json.JSONDecodeError: - # Fallback: try to find first { and last } - start = txt.find('{') - end = txt.rfind('}') - if start != -1 and end != -1: - return json.loads(txt[start:end+1]) - raise - else: - raise LLMGenerationError("LLM generation failed") - - async def _select_model(self) -> Any: - models = llm_api.get_available_models() - if not models: raise ValueError("No LLM models") - - config_model = self.plugin_config.get("advanced", {}).get("extraction_model", "auto") - if config_model != "auto" and config_model in models: - return models[config_model] - - for task_key in ["lpmm_entity_extract", "lpmm_rdf_build", "embedding"]: - if task_key in models: return models[task_key] - - return models[list(models.keys())[0]] - - # Re-use existing methods - async def _add_entity_with_vector(self, name: str, source_paragraph: Optional[str] = None) -> str: - # Same as before - hash_value = self.metadata_store.add_entity(name, source_paragraph=source_paragraph) - self.graph_store.add_nodes([name]) - try: - emb = await self.embedding_manager.encode(name) - try: - self.vector_store.add(emb.reshape(1, -1), [hash_value]) - except ValueError: pass - except Exception: pass - return hash_value - - async def import_json_data(self, data: Dict, filename: str = "script_import", progress_callback=None): - """Public import entrypoint for pre-processed JSON payloads.""" - if not self.storage_lock: - raise RuntimeError("Importer is not initialized. 
Call initialize() first.") - - async with self.storage_lock: - await self._import_to_db(data, progress_callback=progress_callback) - self.manifest[filename] = { - "hash": self.get_file_hash(json.dumps(data, ensure_ascii=False, sort_keys=True)), - "timestamp": time.time(), - "imported": True, - } - self._save_manifest() - self.vector_store.save() - self.graph_store.save() - - async def _import_to_db(self, data: Dict, progress_callback=None): - # Same logic, but ensure robust - with self.graph_store.batch_update(): - for item in data.get("paragraphs", []): - paragraph = normalize_paragraph_import_item( - item, - default_source="script", - ) - content = paragraph["content"] - source = paragraph["source"] - k_type_val = paragraph["knowledge_type"] - - h_val = self.metadata_store.add_paragraph( - content=content, - source=source, - knowledge_type=k_type_val, - time_meta=paragraph["time_meta"], - ) - - if h_val not in self.vector_store: - try: - emb = await self.embedding_manager.encode(content) - self.vector_store.add(emb.reshape(1, -1), [h_val]) - except Exception as e: - logger.error(f" Vector fail: {e}") - - para_entities = paragraph["entities"] - for entity in para_entities: - if entity: - await self._add_entity_with_vector(entity, source_paragraph=h_val) - - para_relations = paragraph["relations"] - for rel in para_relations: - s, p, o = rel.get("subject"), rel.get("predicate"), rel.get("object") - if s and p and o: - await self._add_entity_with_vector(s, source_paragraph=h_val) - await self._add_entity_with_vector(o, source_paragraph=h_val) - confidence = float(rel.get("confidence", 1.0) or 1.0) - rel_meta = rel.get("metadata", {}) - write_vector = self._should_write_relation_vectors() - if self.relation_write_service is not None: - await self.relation_write_service.upsert_relation_with_vector( - subject=s, - predicate=p, - obj=o, - confidence=confidence, - source_paragraph=h_val, - metadata=rel_meta if isinstance(rel_meta, dict) else {}, - write_vector=write_vector, - ) - else: - rel_hash = self.metadata_store.add_relation( - s, - p, - o, - confidence=confidence, - source_paragraph=h_val, - metadata=rel_meta if isinstance(rel_meta, dict) else {}, - ) - self.graph_store.add_edges([(s, o)], relation_hashes=[rel_hash]) - try: - self.metadata_store.set_relation_vector_state(rel_hash, "none") - except Exception: - pass - - if progress_callback: progress_callback(1) - - async def close(self): - if self.metadata_store: self.metadata_store.close() - - def _save_manifest(self): - with open(MANIFEST_PATH, "w", encoding="utf-8") as f: - json.dump(self.manifest, f, ensure_ascii=False, indent=2) - -async def main(): - parser = _build_arg_parser() - args = parser.parse_args() - - if not global_config: return - - importer = AutoImporter( - force=args.force, - clear_manifest=args.clear_manifest, - target_type=args.type, - concurrency=args.concurrency, - chat_log=args.chat_log, - chat_reference_time=args.chat_reference_time, - ) - await importer.process_and_import() - await importer.close() - -if __name__ == "__main__": - if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - asyncio.run(main()) diff --git a/plugins/A_memorix/scripts/rebuild_episodes.py b/plugins/A_memorix/scripts/rebuild_episodes.py deleted file mode 100644 index b6adaa21..00000000 --- a/plugins/A_memorix/scripts/rebuild_episodes.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -"""Episode source 级重建工具。""" - -from __future__ import annotations - -import argparse -import asyncio 
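# Editorial usage sketch (not part of the original file): the script either enqueues rebuild jobs
# per source or, with --wait, runs them synchronously via EpisodeService.rebuild_source().
# Example invocations, assuming the default plugin data layout (source name hypothetical):
#
#     python plugins/A_memorix/scripts/rebuild_episodes.py --all
#     python plugins/A_memorix/scripts/rebuild_episodes.py --source chat_import --wait
#
# Without --wait the selected sources are only enqueued (enqueue_episode_source_rebuild), the
# enqueued count is printed, and the script exits; exactly one of --source/--all must be given.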
-import sys -from pathlib import Path -from typing import Any, Dict, List - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -WORKSPACE_ROOT = PLUGIN_ROOT.parent -MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" -for path in (WORKSPACE_ROOT, MAIBOT_ROOT, PLUGIN_ROOT): - path_str = str(path) - if path_str not in sys.path: - sys.path.insert(0, path_str) - -try: - import tomlkit # type: ignore -except Exception: # pragma: no cover - tomlkit = None - -from A_memorix.core.storage import MetadataStore -from A_memorix.core.utils.episode_service import EpisodeService - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Rebuild A_Memorix episodes by source") - parser.add_argument("--data-dir", default=str(PLUGIN_ROOT / "data"), help="插件数据目录") - parser.add_argument("--source", type=str, help="指定单个 source 入队/重建") - parser.add_argument("--all", action="store_true", help="对所有 source 入队/重建") - parser.add_argument("--wait", action="store_true", help="在脚本内同步执行重建") - return parser - - -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - - -def _load_plugin_config() -> Dict[str, Any]: - config_path = PLUGIN_ROOT / "config.toml" - if tomlkit is None or not config_path.exists(): - return {} - try: - with open(config_path, "r", encoding="utf-8") as handle: - parsed = tomlkit.load(handle) - return dict(parsed) if isinstance(parsed, dict) else {} - except Exception: - return {} - - -def _resolve_sources(store: MetadataStore, *, source: str | None, rebuild_all: bool) -> List[str]: - if rebuild_all: - return list(store.list_episode_sources_for_rebuild()) - token = str(source or "").strip() - if not token: - raise ValueError("必须提供 --source 或 --all") - return [token] - - -async def _run_rebuilds(store: MetadataStore, plugin_config: Dict[str, Any], sources: List[str]) -> int: - service = EpisodeService(metadata_store=store, plugin_config=plugin_config) - failures: List[str] = [] - for source in sources: - started = store.mark_episode_source_running(source) - if not started: - failures.append(f"{source}: unable_to_mark_running") - continue - try: - result = await service.rebuild_source(source) - store.mark_episode_source_done(source) - print( - "rebuilt" - f" source={source}" - f" paragraphs={int(result.get('paragraph_count') or 0)}" - f" groups={int(result.get('group_count') or 0)}" - f" episodes={int(result.get('episode_count') or 0)}" - f" fallback={int(result.get('fallback_count') or 0)}" - ) - except Exception as exc: - err = str(exc)[:500] - store.mark_episode_source_failed(source, err) - failures.append(f"{source}: {err}") - print(f"failed source={source} error={err}") - - if failures: - for item in failures: - print(item) - return 1 - return 0 - - -def main() -> int: - parser = _build_arg_parser() - args = parser.parse_args() - if bool(args.all) == bool(args.source): - parser.error("必须且只能选择一个:--source 或 --all") - - store = MetadataStore(data_dir=Path(args.data_dir) / "metadata") - store.connect() - try: - sources = _resolve_sources(store, source=args.source, rebuild_all=bool(args.all)) - if not sources: - print("no sources to rebuild") - return 0 - - enqueued = 0 - reason = "script_rebuild_all" if args.all else "script_rebuild_source" - for source in sources: - enqueued += int(store.enqueue_episode_source_rebuild(source, reason=reason)) - print(f"enqueued={enqueued} sources={len(sources)}") - - if not args.wait: - return 0 - - plugin_config = _load_plugin_config() - return 
asyncio.run(_run_rebuilds(store, plugin_config, sources)) - finally: - store.close() - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/release_vnext_migrate.py b/plugins/A_memorix/scripts/release_vnext_migrate.py deleted file mode 100644 index 0922fd0b..00000000 --- a/plugins/A_memorix/scripts/release_vnext_migrate.py +++ /dev/null @@ -1,731 +0,0 @@ -#!/usr/bin/env python3 -""" -vNext release migration entrypoint for A_Memorix. - -Subcommands: -- preflight: detect legacy config/data/schema risks -- migrate: offline migrate config + vectors + metadata schema + graph edge hash map -- verify: strict post-migration consistency checks -""" - -from __future__ import annotations - -import argparse -import json -import pickle -import sqlite3 -import sys -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple - -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="A_Memorix vNext release migration tool") - parser.add_argument( - "--config", - default=str(PLUGIN_ROOT / "config.toml"), - help="config.toml path (default: plugins/A_memorix/config.toml)", - ) - parser.add_argument( - "--data-dir", - default="", - help="optional data dir override; default resolved from config.storage.data_dir", - ) - parser.add_argument("--json-out", default="", help="optional JSON report output path") - - sub = parser.add_subparsers(dest="command", required=True) - - p_preflight = sub.add_parser("preflight", help="scan legacy risks") - p_preflight.add_argument("--strict", action="store_true", help="return 1 if any error check exists") - - p_migrate = sub.add_parser("migrate", help="run offline migration") - p_migrate.add_argument("--dry-run", action="store_true", help="only print planned changes") - p_migrate.add_argument( - "--verify-after", - action="store_true", - help="run verify automatically after migrate", - ) - - p_verify = sub.add_parser("verify", help="post-migration verification") - p_verify.add_argument("--strict", action="store_true", help="return 1 if any error check exists") - return parser - - -# --help/-h fast path: avoid heavy host/plugin bootstrap -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - -try: - from core.storage import GraphStore, KnowledgeType, MetadataStore, QuantizationType, VectorStore - from core.storage.metadata_store import SCHEMA_VERSION -except Exception as e: # pragma: no cover - print(f"❌ failed to import storage modules: {e}") - raise SystemExit(2) - - -@dataclass -class CheckItem: - code: str - level: str - message: str - details: Optional[Dict[str, Any]] = None - - def to_dict(self) -> Dict[str, Any]: - out = { - "code": self.code, - "level": self.level, - "message": self.message, - } - if self.details: - out["details"] = self.details - return out - - -def _read_toml(path: Path) -> Dict[str, Any]: - text = path.read_text(encoding="utf-8") - return tomlkit.parse(text) - - -def _write_toml(path: Path, data: Dict[str, Any]) -> None: - path.write_text(tomlkit.dumps(data), encoding="utf-8") - - -def _get_nested(obj: Dict[str, Any], keys: Sequence[str], default: Any = None) -> Any: - cur: Any = obj - for k in keys: - if not 
isinstance(cur, dict) or k not in cur: - return default - cur = cur[k] - return cur - - -def _ensure_table(obj: Dict[str, Any], key: str) -> Dict[str, Any]: - if key not in obj or not isinstance(obj[key], dict): - obj[key] = tomlkit.table() - return obj[key] - - -def _resolve_data_dir(config_doc: Dict[str, Any], explicit_data_dir: Optional[str]) -> Path: - if explicit_data_dir: - return Path(explicit_data_dir).expanduser().resolve() - raw = str(_get_nested(config_doc, ("storage", "data_dir"), "./data") or "./data").strip() - if raw.startswith("."): - return (PLUGIN_ROOT / raw).resolve() - return Path(raw).expanduser().resolve() - - -def _sqlite_table_exists(conn: sqlite3.Connection, table: str) -> bool: - row = conn.execute( - "SELECT 1 FROM sqlite_master WHERE type='table' AND name=? LIMIT 1", - (table,), - ).fetchone() - return row is not None - - -def _collect_hash_alias_conflicts(conn: sqlite3.Connection) -> Dict[str, List[str]]: - hashes: List[str] = [] - if _sqlite_table_exists(conn, "relations"): - rows = conn.execute("SELECT hash FROM relations").fetchall() - hashes.extend(str(r[0]) for r in rows if r and r[0]) - if _sqlite_table_exists(conn, "deleted_relations"): - rows = conn.execute("SELECT hash FROM deleted_relations").fetchall() - hashes.extend(str(r[0]) for r in rows if r and r[0]) - - alias_map: Dict[str, str] = {} - conflicts: Dict[str, set[str]] = {} - for h in hashes: - if len(h) != 64: - continue - alias = h[:32] - old = alias_map.get(alias) - if old is None: - alias_map[alias] = h - continue - if old != h: - conflicts.setdefault(alias, set()).update({old, h}) - return {k: sorted(v) for k, v in conflicts.items()} - - -def _collect_invalid_knowledge_types(conn: sqlite3.Connection) -> List[str]: - if not _sqlite_table_exists(conn, "paragraphs"): - return [] - - allowed = {item.value for item in KnowledgeType} - rows = conn.execute("SELECT DISTINCT knowledge_type FROM paragraphs").fetchall() - invalid: List[str] = [] - for row in rows: - raw = row[0] - value = str(raw).strip().lower() if raw is not None else "" - if value not in allowed: - invalid.append(str(raw) if raw is not None else "") - return sorted(set(invalid)) - - -def _guess_vector_dimension(config_doc: Dict[str, Any], vectors_dir: Path) -> int: - meta_path = vectors_dir / "vectors_metadata.pkl" - if meta_path.exists(): - try: - with open(meta_path, "rb") as f: - meta = pickle.load(f) - dim = int(meta.get("dimension", 0)) - if dim > 0: - return dim - except Exception: - pass - try: - dim_cfg = int(_get_nested(config_doc, ("embedding", "dimension"), 1024)) - if dim_cfg > 0: - return dim_cfg - except Exception: - pass - return 1024 - - -def _preflight_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: - checks: List[CheckItem] = [] - facts: Dict[str, Any] = { - "config_path": str(config_path), - "data_dir": str(data_dir), - } - - if not config_path.exists(): - checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}")) - return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts} - - config_doc = _read_toml(config_path) - tool_mode = str(_get_nested(config_doc, ("routing", "tool_search_mode"), "forward") or "").strip().lower() - summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"]) - summary_knowledge_type = str( - _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative" - ).strip().lower() - quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower() - 
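# Editorial shape sketch (hypothetical values): _preflight_impl() returns a plain dict report, e.g.
#
#     {
#         "ok": False,
#         "facts": {"routing.tool_search_mode": "legacy", "embedding.quantization_type": "fp16"},
#         "checks": [{"code": "CP-04", "level": "error",
#                     "message": "routing.tool_search_mode=legacy is no longer accepted at runtime"}],
#     }
#
# main() renders it with _print_report(), and under `preflight --strict` any error-level check
# turns into exit code 1.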
- facts["routing.tool_search_mode"] = tool_mode - facts["summarization.model_name_type"] = type(summary_model).__name__ - facts["summarization.default_knowledge_type"] = summary_knowledge_type - facts["embedding.quantization_type"] = quantization - - if tool_mode == "legacy": - checks.append( - CheckItem( - "CP-04", - "error", - "routing.tool_search_mode=legacy is no longer accepted at runtime", - ) - ) - elif tool_mode not in {"forward", "disabled"}: - checks.append( - CheckItem( - "CP-04", - "error", - f"routing.tool_search_mode invalid value: {tool_mode}", - ) - ) - - if isinstance(summary_model, str): - checks.append( - CheckItem( - "CP-11", - "error", - "summarization.model_name must be List[str], string legacy format detected", - ) - ) - elif not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model): - checks.append( - CheckItem( - "CP-11", - "error", - "summarization.model_name must be List[str]", - ) - ) - - if summary_knowledge_type not in {item.value for item in KnowledgeType}: - checks.append( - CheckItem( - "CP-13", - "error", - f"invalid summarization.default_knowledge_type: {summary_knowledge_type}", - ) - ) - - if quantization != "int8": - checks.append( - CheckItem( - "UG-07", - "error", - "embedding.quantization_type must be int8 in vNext", - ) - ) - - vectors_dir = data_dir / "vectors" - npy_path = vectors_dir / "vectors.npy" - bin_path = vectors_dir / "vectors.bin" - ids_bin_path = vectors_dir / "vectors_ids.bin" - facts["vectors.npy_exists"] = npy_path.exists() - facts["vectors.bin_exists"] = bin_path.exists() - facts["vectors_ids.bin_exists"] = ids_bin_path.exists() - - if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): - checks.append( - CheckItem( - "CP-07", - "error", - "legacy vectors.npy detected; offline migrate required", - {"npy_path": str(npy_path)}, - ) - ) - - metadata_db = data_dir / "metadata" / "metadata.db" - facts["metadata_db_exists"] = metadata_db.exists() - relation_count = 0 - if metadata_db.exists(): - conn = sqlite3.connect(str(metadata_db)) - try: - has_schema_table = _sqlite_table_exists(conn, "schema_migrations") - facts["schema_migrations_exists"] = has_schema_table - if not has_schema_table: - checks.append( - CheckItem( - "CP-08", - "error", - "schema_migrations table missing (legacy metadata schema)", - ) - ) - else: - row = conn.execute("SELECT MAX(version) FROM schema_migrations").fetchone() - version = int(row[0]) if row and row[0] is not None else 0 - facts["schema_version"] = version - if version != SCHEMA_VERSION: - checks.append( - CheckItem( - "CP-08", - "error", - f"schema version mismatch: current={version}, expected={SCHEMA_VERSION}", - ) - ) - - if _sqlite_table_exists(conn, "relations"): - row = conn.execute("SELECT COUNT(*) FROM relations").fetchone() - relation_count = int(row[0]) if row and row[0] is not None else 0 - facts["relations_count"] = relation_count - - conflicts = _collect_hash_alias_conflicts(conn) - facts["alias_conflict_count"] = len(conflicts) - if conflicts: - checks.append( - CheckItem( - "CP-05", - "error", - "32-bit relation hash alias conflict detected", - {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)}, - ) - ) - - invalid_knowledge_types = _collect_invalid_knowledge_types(conn) - facts["invalid_knowledge_type_values"] = invalid_knowledge_types - if invalid_knowledge_types: - checks.append( - CheckItem( - "CP-12", - "error", - "invalid paragraph knowledge_type values detected", - {"values": invalid_knowledge_types[:20], "total": 
len(invalid_knowledge_types)}, - ) - ) - finally: - conn.close() - else: - checks.append( - CheckItem( - "META-00", - "warning", - "metadata.db not found, schema checks skipped", - ) - ) - - graph_meta_path = data_dir / "graph" / "graph_metadata.pkl" - facts["graph_metadata_exists"] = graph_meta_path.exists() - if relation_count > 0: - if not graph_meta_path.exists(): - checks.append( - CheckItem( - "CP-06", - "error", - "relations exist but graph metadata missing", - ) - ) - else: - try: - with open(graph_meta_path, "rb") as f: - graph_meta = pickle.load(f) - edge_hash_map = graph_meta.get("edge_hash_map", {}) - edge_hash_map_size = len(edge_hash_map) if isinstance(edge_hash_map, dict) else 0 - facts["edge_hash_map_size"] = edge_hash_map_size - if edge_hash_map_size <= 0: - checks.append( - CheckItem( - "CP-06", - "error", - "edge_hash_map missing/empty while relations exist", - ) - ) - except Exception as e: - checks.append( - CheckItem( - "CP-06", - "error", - f"failed to read graph metadata: {e}", - ) - ) - - has_error = any(c.level == "error" for c in checks) - return { - "ok": not has_error, - "checks": [c.to_dict() for c in checks], - "facts": facts, - } - - -def _migrate_config(config_doc: Dict[str, Any]) -> Dict[str, Any]: - changes: Dict[str, Any] = {} - - routing = _ensure_table(config_doc, "routing") - mode_raw = str(routing.get("tool_search_mode", "forward") or "").strip().lower() - mode_new = mode_raw - if mode_raw == "legacy" or mode_raw not in {"forward", "disabled"}: - mode_new = "forward" - if mode_new != mode_raw: - routing["tool_search_mode"] = mode_new - changes["routing.tool_search_mode"] = {"old": mode_raw, "new": mode_new} - - summary = _ensure_table(config_doc, "summarization") - summary_model = summary.get("model_name", ["auto"]) - if isinstance(summary_model, str): - normalized = [summary_model.strip() or "auto"] - summary["model_name"] = normalized - changes["summarization.model_name"] = {"old": summary_model, "new": normalized} - elif not isinstance(summary_model, list): - normalized = ["auto"] - summary["model_name"] = normalized - changes["summarization.model_name"] = {"old": str(type(summary_model)), "new": normalized} - elif any(not isinstance(x, str) for x in summary_model): - normalized = [str(x).strip() for x in summary_model if str(x).strip()] - if not normalized: - normalized = ["auto"] - summary["model_name"] = normalized - changes["summarization.model_name"] = {"old": summary_model, "new": normalized} - - default_knowledge_type = str(summary.get("default_knowledge_type", "narrative") or "").strip().lower() - allowed_knowledge_types = {item.value for item in KnowledgeType} - if default_knowledge_type not in allowed_knowledge_types: - summary["default_knowledge_type"] = "narrative" - changes["summarization.default_knowledge_type"] = { - "old": default_knowledge_type, - "new": "narrative", - } - - embedding = _ensure_table(config_doc, "embedding") - quantization = str(embedding.get("quantization_type", "int8") or "").strip().lower() - if quantization != "int8": - embedding["quantization_type"] = "int8" - changes["embedding.quantization_type"] = {"old": quantization, "new": "int8"} - - return changes - - -def _migrate_impl(config_path: Path, data_dir: Path, dry_run: bool) -> Dict[str, Any]: - config_doc = _read_toml(config_path) - result: Dict[str, Any] = { - "config_path": str(config_path), - "data_dir": str(data_dir), - "dry_run": bool(dry_run), - "steps": {}, - } - - config_changes = _migrate_config(config_doc) - result["steps"]["config"] = 
{"changed": bool(config_changes), "changes": config_changes} - if config_changes and not dry_run: - _write_toml(config_path, config_doc) - - vectors_dir = data_dir / "vectors" - vectors_dir.mkdir(parents=True, exist_ok=True) - npy_path = vectors_dir / "vectors.npy" - bin_path = vectors_dir / "vectors.bin" - ids_bin_path = vectors_dir / "vectors_ids.bin" - if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): - if dry_run: - result["steps"]["vector"] = {"migrated": False, "reason": "dry_run"} - else: - dim = _guess_vector_dimension(config_doc, vectors_dir) - store = VectorStore( - dimension=max(1, int(dim)), - quantization_type=QuantizationType.INT8, - data_dir=vectors_dir, - ) - result["steps"]["vector"] = store.migrate_legacy_npy(vectors_dir) - else: - result["steps"]["vector"] = {"migrated": False, "reason": "not_required"} - - metadata_dir = data_dir / "metadata" - metadata_dir.mkdir(parents=True, exist_ok=True) - metadata_db = metadata_dir / "metadata.db" - triples: List[Tuple[str, str, str, str]] = [] - relation_count = 0 - - metadata_result: Dict[str, Any] = {"migrated": False, "reason": "not_required"} - if metadata_db.exists(): - store = MetadataStore(data_dir=metadata_dir) - store.connect(enforce_schema=False) - try: - if dry_run: - metadata_result = {"migrated": False, "reason": "dry_run"} - else: - metadata_result = store.run_legacy_migration_for_vnext() - relation_count = int(store.count_relations()) - if relation_count > 0: - triples = [(str(s), str(p), str(o), str(h)) for s, p, o, h in store.get_all_triples()] - finally: - store.close() - result["steps"]["metadata"] = metadata_result - - graph_dir = data_dir / "graph" - graph_dir.mkdir(parents=True, exist_ok=True) - graph_matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr") - graph_store = GraphStore(matrix_format=graph_matrix_format, data_dir=graph_dir) - graph_step: Dict[str, Any] = { - "rebuilt": False, - "mapped_hashes": 0, - "relation_count": relation_count, - "topology_rebuilt_from_relations": False, - } - if relation_count > 0: - if dry_run: - graph_step["reason"] = "dry_run" - else: - if graph_store.has_data(): - graph_store.load() - - mapped = graph_store.rebuild_edge_hash_map(triples) - - # 兜底:历史数据里 graph 节点/边与 relations 脱节时,直接从 relations 重建图。 - if mapped <= 0 or not graph_store.has_edge_hash_map(): - nodes = sorted({s for s, _, o, _ in triples} | {o for _, _, o, _ in triples}) - edges = [(s, o) for s, _, o, _ in triples] - hashes = [h for _, _, _, h in triples] - - graph_store.clear() - if nodes: - graph_store.add_nodes(nodes) - if edges: - mapped = graph_store.add_edges(edges, relation_hashes=hashes) - else: - mapped = 0 - graph_step.update( - { - "topology_rebuilt_from_relations": True, - "rebuilt_nodes": len(nodes), - "rebuilt_edges": int(graph_store.num_edges), - } - ) - - graph_store.save() - graph_step.update({"rebuilt": True, "mapped_hashes": int(mapped)}) - else: - graph_step["reason"] = "no_relations" - result["steps"]["graph"] = graph_step - - return result - - -def _verify_impl(config_path: Path, data_dir: Path) -> Dict[str, Any]: - checks: List[CheckItem] = [] - facts: Dict[str, Any] = { - "config_path": str(config_path), - "data_dir": str(data_dir), - } - - if not config_path.exists(): - checks.append(CheckItem("CFG-00", "error", f"config not found: {config_path}")) - return {"ok": False, "checks": [c.to_dict() for c in checks], "facts": facts} - - config_doc = _read_toml(config_path) - mode = str(_get_nested(config_doc, ("routing", 
"tool_search_mode"), "forward") or "").strip().lower() - if mode not in {"forward", "disabled"}: - checks.append(CheckItem("CP-04", "error", f"invalid routing.tool_search_mode: {mode}")) - - summary_model = _get_nested(config_doc, ("summarization", "model_name"), ["auto"]) - if not isinstance(summary_model, list) or any(not isinstance(x, str) for x in summary_model): - checks.append(CheckItem("CP-11", "error", "summarization.model_name must be List[str]")) - summary_knowledge_type = str( - _get_nested(config_doc, ("summarization", "default_knowledge_type"), "narrative") or "narrative" - ).strip().lower() - if summary_knowledge_type not in {item.value for item in KnowledgeType}: - checks.append( - CheckItem("CP-13", "error", f"invalid summarization.default_knowledge_type: {summary_knowledge_type}") - ) - - quantization = str(_get_nested(config_doc, ("embedding", "quantization_type"), "int8") or "").strip().lower() - if quantization != "int8": - checks.append(CheckItem("UG-07", "error", "embedding.quantization_type must be int8")) - - vectors_dir = data_dir / "vectors" - npy_path = vectors_dir / "vectors.npy" - bin_path = vectors_dir / "vectors.bin" - ids_bin_path = vectors_dir / "vectors_ids.bin" - if npy_path.exists() and not (bin_path.exists() and ids_bin_path.exists()): - checks.append(CheckItem("CP-07", "error", "legacy vectors.npy still exists without bin migration")) - - metadata_dir = data_dir / "metadata" - store = MetadataStore(data_dir=metadata_dir) - try: - store.connect(enforce_schema=True) - schema_version = store.get_schema_version() - facts["schema_version"] = schema_version - if schema_version != SCHEMA_VERSION: - checks.append(CheckItem("CP-08", "error", f"schema version mismatch: {schema_version}")) - - relation_count = int(store.count_relations()) - facts["relations_count"] = relation_count - - conflicts = {} - invalid_knowledge_types: List[str] = [] - db_path = metadata_dir / "metadata.db" - if db_path.exists(): - conn = sqlite3.connect(str(db_path)) - try: - conflicts = _collect_hash_alias_conflicts(conn) - invalid_knowledge_types = _collect_invalid_knowledge_types(conn) - finally: - conn.close() - if conflicts: - checks.append( - CheckItem( - "CP-05", - "error", - "alias conflicts still exist after migration", - {"aliases": sorted(conflicts.keys())[:20], "total": len(conflicts)}, - ) - ) - if invalid_knowledge_types: - checks.append( - CheckItem( - "CP-12", - "error", - "invalid paragraph knowledge_type values remain after migration", - {"values": invalid_knowledge_types[:20], "total": len(invalid_knowledge_types)}, - ) - ) - - if relation_count > 0: - graph_dir = data_dir / "graph" - if not (graph_dir / "graph_metadata.pkl").exists(): - checks.append(CheckItem("CP-06", "error", "graph metadata missing while relations exist")) - else: - matrix_format = str(_get_nested(config_doc, ("graph", "sparse_matrix_format"), "csr") or "csr") - graph_store = GraphStore(matrix_format=matrix_format, data_dir=graph_dir) - graph_store.load() - if not graph_store.has_edge_hash_map(): - checks.append(CheckItem("CP-06", "error", "edge_hash_map is empty")) - except Exception as e: - checks.append(CheckItem("CP-08", "error", f"metadata strict connect failed: {e}")) - finally: - try: - store.close() - except Exception: - pass - - has_error = any(c.level == "error" for c in checks) - return { - "ok": not has_error, - "checks": [c.to_dict() for c in checks], - "facts": facts, - } - - -def _print_report(title: str, report: Dict[str, Any]) -> None: - print(f"=== {title} ===") - print(f"ok: 
{bool(report.get('ok', True))}") - facts = report.get("facts", {}) - if facts: - print("facts:") - for k in sorted(facts.keys()): - print(f" - {k}: {facts[k]}") - checks = report.get("checks", []) - if checks: - print("checks:") - for item in checks: - print(f" - [{item.get('level')}] {item.get('code')}: {item.get('message')}") - else: - print("checks: none") - - -def _write_json_if_needed(path: str, payload: Dict[str, Any]) -> None: - if not path: - return - out = Path(path).expanduser().resolve() - out.parent.mkdir(parents=True, exist_ok=True) - out.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") - print(f"json_out: {out}") - - -def main() -> int: - parser = _build_arg_parser() - args = parser.parse_args() - config_path = Path(args.config).expanduser().resolve() - if not config_path.exists(): - print(f"❌ config not found: {config_path}") - return 2 - config_doc = _read_toml(config_path) - data_dir = _resolve_data_dir(config_doc, args.data_dir) - - if args.command == "preflight": - report = _preflight_impl(config_path, data_dir) - _print_report("vNext Preflight", report) - _write_json_if_needed(args.json_out, report) - has_error = any(item.get("level") == "error" for item in report.get("checks", [])) - if args.strict and has_error: - return 1 - return 0 - - if args.command == "migrate": - payload = _migrate_impl(config_path, data_dir, dry_run=bool(args.dry_run)) - print("=== vNext Migrate ===") - print(json.dumps(payload, ensure_ascii=False, indent=2)) - - verify_report = None - if args.verify_after and not args.dry_run: - verify_report = _verify_impl(config_path, data_dir) - _print_report("vNext Verify (after migrate)", verify_report) - payload["verify_after"] = verify_report - - _write_json_if_needed(args.json_out, payload) - if verify_report is not None: - has_error = any(item.get("level") == "error" for item in verify_report.get("checks", [])) - if has_error: - return 1 - return 0 - - if args.command == "verify": - report = _verify_impl(config_path, data_dir) - _print_report("vNext Verify", report) - _write_json_if_needed(args.json_out, report) - has_error = any(item.get("level") == "error" for item in report.get("checks", [])) - if args.strict and has_error: - return 1 - return 0 - - return 2 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/plugins/A_memorix/scripts/runtime_self_check.py b/plugins/A_memorix/scripts/runtime_self_check.py deleted file mode 100644 index 70c423ac..00000000 --- a/plugins/A_memorix/scripts/runtime_self_check.py +++ /dev/null @@ -1,152 +0,0 @@ -#!/usr/bin/env python3 -"""Run A_Memorix runtime self-check against real embedding/runtime configuration.""" - -from __future__ import annotations - -import argparse -import asyncio -import json -import sys -import tempfile -from pathlib import Path -from typing import Any - -import tomlkit - - -CURRENT_DIR = Path(__file__).resolve().parent -PLUGIN_ROOT = CURRENT_DIR.parent -PROJECT_ROOT = PLUGIN_ROOT.parent.parent -sys.path.insert(0, str(PROJECT_ROOT)) -sys.path.insert(0, str(PLUGIN_ROOT)) - - -def _build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="A_Memorix runtime self-check") - parser.add_argument( - "--config", - default=str(PLUGIN_ROOT / "config.toml"), - help="config.toml path (default: plugins/A_memorix/config.toml)", - ) - parser.add_argument( - "--data-dir", - default="", - help="optional data dir override; default resolved from config.storage.data_dir", - ) - parser.add_argument( - "--use-config-data-dir", - 
action="store_true", - help="use config.storage.data_dir directly instead of an isolated temp dir", - ) - parser.add_argument( - "--sample-text", - default="A_Memorix runtime self check", - help="sample text used for real embedding probe", - ) - parser.add_argument("--json", action="store_true", help="print JSON report") - return parser - - -if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): - _build_arg_parser().print_help() - raise SystemExit(0) - -from core.runtime.lifecycle_orchestrator import initialize_storage_async -from core.utils.runtime_self_check import run_embedding_runtime_self_check - - -def _load_config(path: Path) -> dict[str, Any]: - with open(path, "r", encoding="utf-8") as f: - raw = tomlkit.load(f) - return dict(raw) if isinstance(raw, dict) else {} - - -def _nested_get(config: dict[str, Any], key: str, default: Any = None) -> Any: - current: Any = config - for part in key.split("."): - if isinstance(current, dict) and part in current: - current = current[part] - else: - return default - return current - - -class _PluginStub: - def __init__(self, config: dict[str, Any]): - self.config = config - self.vector_store = None - self.graph_store = None - self.metadata_store = None - self.embedding_manager = None - self.sparse_index = None - self.relation_write_service = None - - def get_config(self, key: str, default: Any = None) -> Any: - return _nested_get(self.config, key, default) - - -async def _main_async(args: argparse.Namespace) -> int: - config_path = Path(args.config).resolve() - if not config_path.exists(): - print(f"❌ 配置文件不存在: {config_path}") - return 2 - - config = _load_config(config_path) - temp_dir_ctx = None - if args.data_dir: - storage_dir = str(Path(args.data_dir).resolve()) - elif args.use_config_data_dir: - raw_data_dir = str(_nested_get(config, "storage.data_dir", "./data") or "./data").strip() - if raw_data_dir.startswith("."): - storage_dir = str((config_path.parent / raw_data_dir).resolve()) - else: - storage_dir = str(Path(raw_data_dir).resolve()) - else: - temp_dir_ctx = tempfile.TemporaryDirectory(prefix="memorix-runtime-self-check-") - storage_dir = temp_dir_ctx.name - - storage_cfg = config.setdefault("storage", {}) - storage_cfg["data_dir"] = storage_dir - - plugin = _PluginStub(config) - try: - await initialize_storage_async(plugin) - report = await run_embedding_runtime_self_check( - config=config, - vector_store=plugin.vector_store, - embedding_manager=plugin.embedding_manager, - sample_text=str(args.sample_text or "A_Memorix runtime self check"), - ) - report["data_dir"] = storage_dir - report["isolated_data_dir"] = temp_dir_ctx is not None - if args.json: - print(json.dumps(report, ensure_ascii=False, indent=2)) - else: - print("A_Memorix Runtime Self-Check") - print(f"ok: {report.get('ok')}") - print(f"code: {report.get('code')}") - print(f"message: {report.get('message')}") - print(f"configured_dimension: {report.get('configured_dimension')}") - print(f"vector_store_dimension: {report.get('vector_store_dimension')}") - print(f"detected_dimension: {report.get('detected_dimension')}") - print(f"encoded_dimension: {report.get('encoded_dimension')}") - print(f"elapsed_ms: {float(report.get('elapsed_ms', 0.0)):.2f}") - return 0 if bool(report.get("ok")) else 1 - finally: - if plugin.metadata_store is not None: - try: - plugin.metadata_store.close() - except Exception: - pass - if temp_dir_ctx is not None: - temp_dir_ctx.cleanup() - - -def main() -> int: - parser = _build_arg_parser() - args = parser.parse_args() - return 
asyncio.run(_main_async(args)) - - -if __name__ == "__main__": - raise SystemExit(main()) From a004c59e1e01064cf8fab9dc15c34737521672e6 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 31 Mar 2026 15:49:49 +0800 Subject: [PATCH 07/14] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E8=AE=B0?= =?UTF-8?q?=E5=BF=86=E8=87=AA=E5=8A=A8=E5=8C=96=E9=92=A9=E5=AD=90=E4=B8=8E?= =?UTF-8?q?=E5=9B=9E=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在接收和发送消息时注册记忆自动化,并重构人物记忆回写逻辑以使用 memory_service.ingest_text。主要改动如下: 在接收消息时调用 memory_automation_service.on_incoming_message(bot 侧),在发送消息时调用 on_message_sent(send_service 侧),并加入安全的错误处理。 在 person_info 中,用 memory_service.ingest_text 替换手动操作 person.memory_points 的方式;新增 resolve_person_id_for_memory 辅助方法,并为回写计算一个 external_id 指纹。 扩展插件运行时的记忆搜索能力,使其支持 mode、chat_id、person_id、user_id、group_id、时间范围以及 respect_filter 选项。 改进 find_messages 的数据库会话处理,改为使用单一 session,并修复排序和过滤逻辑。 从 KnowledgeFetcher 中移除未使用的 LLMRequest 导入和初始化。 更新术语解释器(jargon explainer)的导入路径,使用新的模块位置。 更新 .gitignore 例外规则,允许特定的 pytest 数据文件被纳入版本控制。 文档小调整:明确人物事实提取规则(将直接使用的 “you” 改写为第三人称)。 --- .gitignore | 6 + .../brain_chat/PFC/pfc_KnowledgeFetcher.py | 5 - src/chat/message_receive/bot.py | 7 + src/common/message_repository.py | 54 ++++--- src/learners/jargon_explainer_old.py | 2 +- src/person_info/person_info.py | 145 +++++++++++------- src/plugin_runtime/capabilities/data.py | 22 ++- src/services/memory_flow_service.py | 1 + src/services/send_service.py | 16 ++ 9 files changed, 170 insertions(+), 88 deletions(-) diff --git a/.gitignore b/.gitignore index 6d8b249f..c5a687ca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,10 @@ data/ +!pytests/A_memorix_test/data/ +!pytests/A_memorix_test/data/benchmarks/ +!pytests/A_memorix_test/data/benchmarks/long_novel_memory_benchmark.json +!pytests/A_memorix_test/data/real_dialogues/ +!pytests/A_memorix_test/data/real_dialogues/private_alice_weekend.json +pytests/A_memorix_test/data/benchmarks/results/ data1/ mongodb/ NapCat.Framework.Windows.Once/ diff --git a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py index 3136f8be..fe875540 100644 --- a/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py +++ b/src/chat/brain_chat/PFC/pfc_KnowledgeFetcher.py @@ -3,10 +3,6 @@ from typing import Any, Dict, List, Tuple from src.chat.message_receive.chat_manager import chat_manager as _chat_manager from src.common.logger import get_logger -# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned -# from src.plugins.memory_system.Hippocampus import HippocampusManager -from src.config.config import model_config -from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import resolve_person_id_for_memory from src.services.memory_service import memory_service @@ -17,7 +13,6 @@ class KnowledgeFetcher: """知识调取器""" def __init__(self, private_name: str, stream_id: str): - self.llm = LLMRequest(model_set=model_config.model_task_config.utils) self.private_name = private_name self.stream_id = stream_id diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 33a66ffc..8bc06ed0 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -325,6 +325,13 @@ class ChatBot: scope=scope, ) # 确保会话存在 + try: + from src.services.memory_flow_service import memory_automation_service + + await memory_automation_service.on_incoming_message(message) + except Exception as exc: + 
logger.warning(f"[{session_id}] 长期记忆自动摘要注册失败: {exc}") + # message.update_chat_stream(chat) # 命令处理 - 使用新插件系统检查并处理命令。 diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 94d7bfea..98799738 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -189,39 +189,37 @@ def find_messages( conditions.append(Messages.is_command == False) # noqa: E712 statement = select(Messages).where(*conditions) - if limit > 0: - if limit_mode == "earliest": - statement = statement.order_by(col(Messages.timestamp)).limit(limit) - with get_db_session() as session: + with get_db_session(auto_commit=False) as session: + if limit > 0: + if limit_mode == "earliest": + statement = statement.order_by(col(Messages.timestamp)).limit(limit) results = list(session.exec(statement).all()) + else: + statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit) + results = list(session.exec(statement).all()) + results = list(reversed(results)) else: - statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit) - with get_db_session() as session: - results = list(session.exec(statement).all()) - results = list(reversed(results)) - else: - if sort: - order_terms: list[Any] = [] - for field_name, direction in sort: - sort_field = _resolve_field(field_name) - if sort_field is None: - logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") - continue - order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc()) - if order_terms: - statement = statement.order_by(*order_terms) - with get_db_session() as session: + if sort: + order_terms: list[Any] = [] + for field_name, direction in sort: + sort_field = _resolve_field(field_name) + if sort_field is None: + logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") + continue + order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc()) + if order_terms: + statement = statement.order_by(*order_terms) results = list(session.exec(statement).all()) - if filter_intercept_message_level is not None: - filtered_results = [] - for msg in results: - config = _parse_additional_config(msg) - if config.get("intercept_message_level", 0) <= filter_intercept_message_level: - filtered_results.append(msg) - results = filtered_results + if filter_intercept_message_level is not None: + filtered_results = [] + for msg in results: + config = _parse_additional_config(msg) + if config.get("intercept_message_level", 0) <= filter_intercept_message_level: + filtered_results.append(msg) + results = filtered_results - return [_message_to_instance(msg) for msg in results] + return [_message_to_instance(msg) for msg in results] except Exception as e: log_message = ( "使用 SQLModel 查找消息失败 " diff --git a/src/learners/jargon_explainer_old.py b/src/learners/jargon_explainer_old.py index 876b4539..330da8cb 100644 --- a/src/learners/jargon_explainer_old.py +++ b/src/learners/jargon_explainer_old.py @@ -8,7 +8,7 @@ from src.common.data_models.llm_service_data_models import LLMGenerationOptions from src.services.llm_service import LLMServiceClient from src.config.config import global_config from src.prompt.prompt_manager import prompt_manager -from src.learners.jargon_miner_old import search_jargon +from src.learners.jargon_explainer import search_jargon from src.learners.learner_utils_old import ( is_bot_message, contains_bot_self_name, diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index c603f4b7..0838156a 100644 --- 
a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -18,6 +18,7 @@ from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from src.common.logger import get_logger from src.config.config import global_config +from src.services.memory_service import memory_service from src.services.llm_service import LLMServiceClient @@ -66,15 +67,45 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id_by_person_name(person_name: str) -> str: """根据用户名获取用户ID""" try: - with get_db_session() as session: - statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1) - record = session.exec(statement).first() - return record.person_id if record else "" + with get_db_session(auto_commit=False) as session: + statement = select(PersonInfo.person_id).where(col(PersonInfo.person_name) == person_name).limit(1) + person_id = session.exec(statement).first() + return str(person_id) if person_id else "" except Exception as e: logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") return "" +def resolve_person_id_for_memory( + *, + person_name: str = "", + platform: str = "", + user_id: Union[int, str, None] = None, + strict_known: bool = False, +) -> str: + """解析长期记忆检索/写入使用的人物 ID。 + + 解析顺序: + 1. 优先按 `person_name` 映射数据库中的 `person_id` + 2. 回退到 `platform + user_id` 生成稳定 `person_id` + 3. 若 `strict_known=True`,则要求该 `person_id` 已被认识 + """ + clean_name = str(person_name or "").strip() + if clean_name: + if by_name := get_person_id_by_person_name(clean_name): + return by_name + + clean_platform = str(platform or "").strip() + clean_user_id = str(user_id or "").strip() + if clean_platform and clean_user_id: + candidate = get_person_id(clean_platform, clean_user_id) + if strict_known and not is_person_known(person_id=candidate): + return "" + return candidate + + return "" + + def is_person_known( person_id: Optional[str] = None, user_id: Optional[str] = None, @@ -800,75 +831,83 @@ person_info_manager = PersonInfoManager() async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None: - """将人物信息存入person_info的memory_points + """将人物事实写入长期记忆系统。 Args: person_name: 人物名称 memory_content: 记忆内容 chat_id: 聊天ID """ + clean_content = str(memory_content or "").strip() + if not clean_content: + logger.debug("人物事实写回跳过:memory_content 为空") + return + + clean_chat_id = str(chat_id or "").strip() + if not clean_chat_id: + logger.warning("人物事实写回失败:chat_id 为空") + return + + clean_person_name = str(person_name or "").strip() try: # 从 chat_id 获取 session - session = _chat_manager.get_session_by_session_id(chat_id) + session = _chat_manager.get_session_by_session_id(clean_chat_id) if not session: - logger.warning(f"无法获取session for chat_id: {chat_id}") + logger.warning(f"无法获取session for chat_id: {clean_chat_id}") return - platform = session.platform - - # 尝试从person_name查找person_id - # 首先尝试通过person_name查找 - person_id = get_person_id_by_person_name(person_name) + session_platform = str(getattr(session, "platform", "") or "").strip() + session_user_id = str(getattr(session, "user_id", "") or "").strip() + session_group_id = str(getattr(session, "group_id", "") or "").strip() + person_id = resolve_person_id_for_memory( + person_name=clean_person_name, + platform=session_platform, + user_id=session_user_id, + ) if not person_id: - # 如果通过person_name找不到,尝试从 session 获取 user_id - if platform and session.user_id: - user_id = session.user_id - person_id = get_person_id(platform, user_id) - 
else: - logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}") - return - - # 创建或获取Person对象 - person = Person(person_id=person_id) - - if not person.is_known: - logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆") + logger.warning(f"无法确定person_id for person_name: {clean_person_name}, chat_id: {clean_chat_id}") return - # 确定记忆分类(可以根据memory_content判断,这里使用通用分类) - category = "其他" # 默认分类,可以根据需要调整 + person = Person(person_id=person_id) + if not person.is_known: + logger.warning(f"用户 {clean_person_name or person_id} (person_id: {person_id}) 尚未认识,跳过写回") + return - # 记忆点格式:category:content:weight - weight = "1.0" # 默认权重 - memory_point = f"{category}:{memory_content}:{weight}" + participant_name = str(getattr(person, "person_name", "") or getattr(person, "nickname", "") or "").strip() + if not participant_name: + participant_name = clean_person_name or person_id - # 添加到memory_points - if not person.memory_points: - person.memory_points = [] + payload_fingerprint = hashlib.md5(f"{person_id}|{clean_chat_id}|{clean_content}".encode()).hexdigest() + external_id = f"person_fact:{person_id}:{payload_fingerprint}" - # 检查是否已存在相似的记忆点(避免重复) - is_duplicate = False - for existing_point in person.memory_points: - if existing_point and isinstance(existing_point, str): - parts = existing_point.split(":", 2) - if len(parts) >= 2: - existing_content = parts[1].strip() - # 简单相似度检查(如果内容相同或非常相似,则跳过) - if ( - existing_content == memory_content - or memory_content in existing_content - or existing_content in memory_content - ): - is_duplicate = True - break + result = await memory_service.ingest_text( + external_id=external_id, + source_type="person_fact", + text=clean_content, + chat_id=clean_chat_id, + person_ids=[person_id], + participants=[participant_name], + tags=["person_fact"], + metadata={ + "person_id": person_id, + "person_name": participant_name, + "writeback_source": "memory_flow_service", + }, + respect_filter=True, + user_id=session_user_id, + group_id=session_group_id, + ) - if not is_duplicate: - person.memory_points.append(memory_point) - person.sync_to_database() - logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}") + if getattr(result, "success", False): + logger.info( + f"成功写回人物事实到长期记忆: person={participant_name} person_id={person_id} chat_id={clean_chat_id}" + ) else: - logger.debug(f"记忆点已存在,跳过: {memory_point}") + logger.warning( + f"人物事实写回长期记忆失败: person={participant_name} person_id={person_id} " + f"chat_id={clean_chat_id} detail={getattr(result, 'detail', '')}" + ) except Exception as e: logger.error(f"存储人物记忆失败: {e}") diff --git a/src/plugin_runtime/capabilities/data.py b/src/plugin_runtime/capabilities/data.py index 32843d09..1acd33d3 100644 --- a/src/plugin_runtime/capabilities/data.py +++ b/src/plugin_runtime/capabilities/data.py @@ -671,10 +671,30 @@ class RuntimeDataCapabilityMixin: except (TypeError, ValueError): limit_value = 5 + mode = str(args.get("mode", "search") or "search").strip() or "search" + chat_id = str(args.get("chat_id", "") or "").strip() + person_id = str(args.get("person_id", "") or "").strip() + user_id = str(args.get("user_id", "") or "").strip() + group_id = str(args.get("group_id", "") or "").strip() + respect_filter = bool(args.get("respect_filter", True)) + time_start = args.get("time_start") + time_end = args.get("time_end") + try: from src.services.memory_service import memory_service - result = await memory_service.search(query, limit=limit_value) + result = await memory_service.search( + 
query, + limit=limit_value, + mode=mode, + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + respect_filter=respect_filter, + user_id=user_id, + group_id=group_id, + ) if not result.success: return {"success": False, "error": result.error or "长期记忆检索失败"} knowledge_info = result.to_text(limit=limit_value) diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py index 96062eb6..c95bcc69 100644 --- a/src/services/memory_flow_service.py +++ b/src/services/memory_flow_service.py @@ -178,6 +178,7 @@ class PersonFactWritebackService: 1. 明确是关于目标人物本人的信息。 2. 具有相对稳定性,可以作为长期记忆保存。 3. 用简洁中文陈述句表达。 +4. 如果回复是在直接对目标人物说话,出现“你/你的/你自己”时,默认都指目标人物,请先改写成关于目标人物的第三人称事实再输出。 不要提取: - 机器人的情绪、计划、临时动作、客套话 diff --git a/src/services/send_service.py b/src/services/send_service.py index d7f17563..5b0d60f8 100644 --- a/src/services/send_service.py +++ b/src/services/send_service.py @@ -434,6 +434,21 @@ def _store_sent_message(message: SessionMessage) -> None: MessageUtils.store_message_to_db(message) +async def _notify_memory_automation_on_message_sent(message: SessionMessage) -> None: + """在发送成功后通知长期记忆自动化服务。 + + Args: + message: 已成功发送的内部消息对象。 + """ + try: + from src.services.memory_flow_service import memory_automation_service + + await memory_automation_service.on_message_sent(message) + except Exception as exc: + session_id = message.session_id or "unknown-session" + logger.warning(f"[{session_id}] 长期记忆人物事实写回注册失败: {exc}") + + def _log_platform_io_failures(delivery_batch: DeliveryBatch) -> None: """输出 Platform IO 批量发送失败详情。 @@ -503,6 +518,7 @@ async def _send_via_platform_io( if delivery_batch.has_success: if storage_message: _store_sent_message(message) + await _notify_memory_automation_on_message_sent(message) if show_log: successful_driver_ids = [ receipt.driver_id or "unknown" From bcf2bfcd6363ab5eafc97f294d359d97f71bb466 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 31 Mar 2026 16:40:47 +0800 Subject: [PATCH 08/14] =?UTF-8?q?fix:=E5=AF=B9=E9=BD=90=E6=8F=92=E4=BB=B6I?= =?UTF-8?q?D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/services/memory_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/services/memory_service.py b/src/services/memory_service.py index 04f08ff6..81dfb960 100644 --- a/src/services/memory_service.py +++ b/src/services/memory_service.py @@ -9,7 +9,7 @@ from src.plugin_runtime.integration import get_plugin_runtime_manager logger = get_logger("memory_service") -PLUGIN_ID = "A_Memorix" +PLUGIN_ID = "a-dawn.a-memorix" @dataclass From 683cdf4a13faf8e49afd395a8a5e549237773ae5 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 31 Mar 2026 19:22:10 +0800 Subject: [PATCH 09/14] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20A=5Fmemorix=20?= =?UTF-8?q?=E5=AD=90=E6=A8=A1=E5=9D=97=E6=8C=87=E9=92=88=E8=87=B3=20f03092?= =?UTF-8?q?a?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/A_memorix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/A_memorix b/plugins/A_memorix index 5fc5026a..f03092a6 160000 --- a/plugins/A_memorix +++ b/plugins/A_memorix @@ -1 +1 @@ -Subproject commit 5fc5026a540c1cfd55a7b824b43aaeef867e3228 +Subproject commit f03092a63a08cbfd261f22b9c0c6b4f4a5aaeabe From d56a8ba030dff6f97b664f33f725a62442966807 Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Tue, 31 Mar 
2026 20:37:04 +0800 Subject: [PATCH 10/14] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20A=5Fmemorix=20?= =?UTF-8?q?=E5=AD=90=E6=A8=A1=E5=9D=97=E6=8C=87=E9=92=88=E8=87=B3=2060224c?= =?UTF-8?q?e?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plugins/A_memorix | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/A_memorix b/plugins/A_memorix index f03092a6..60224ce2 160000 --- a/plugins/A_memorix +++ b/plugins/A_memorix @@ -1 +1 @@ -Subproject commit f03092a63a08cbfd261f22b9c0c6b4f4a5aaeabe +Subproject commit 60224ce27552065841c4acc9f5927712e8e3deb3 From 02262e21126c295e9273540846ecc48df4aa8bff Mon Sep 17 00:00:00 2001 From: DawnARC Date: Tue, 31 Mar 2026 21:31:27 +0800 Subject: [PATCH 11/14] =?UTF-8?q?fix:=E4=BD=BF=E7=94=A8=20LLMServiceClient?= =?UTF-8?q?=20=E8=BF=9B=E8=A1=8C=E4=BA=BA=E7=89=A9=E4=BA=8B=E5=AE=9E?= =?UTF-8?q?=E6=8F=90=E5=8F=96(=E6=8D=A2=E8=AF=B7=E6=B1=82)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/services/memory_flow_service.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/services/memory_flow_service.py b/src/services/memory_flow_service.py index c95bcc69..75ff0ca9 100644 --- a/src/services/memory_flow_service.py +++ b/src/services/memory_flow_service.py @@ -9,10 +9,10 @@ from json_repair import repair_json from src.chat.utils.utils import is_bot_self from src.common.message_repository import find_messages from src.common.logger import get_logger -from src.config.config import global_config, model_config -from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config from src.memory_system.chat_history_summarizer import ChatHistorySummarizer from src.person_info.person_info import Person, get_person_id, store_person_memory_from_answer +from src.services.llm_service import LLMServiceClient logger = get_logger("memory_flow_service") @@ -55,10 +55,7 @@ class PersonFactWritebackService: self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=256) self._worker_task: Optional[asyncio.Task] = None self._stopping = False - self._extractor = LLMRequest( - model_set=model_config.model_task_config.utils, - request_type="person_fact_writeback", - ) + self._extractor = LLMServiceClient(task_name="utils", request_type="person_fact_writeback") async def start(self) -> None: if self._worker_task is not None and not self._worker_task.done(): @@ -190,11 +187,11 @@ class PersonFactWritebackService: ["他喜欢深夜打游戏", "他养了一只猫"] 如果没有可写入的事实,输出 []""" try: - response, _ = await self._extractor.generate_response_async(prompt) + response_result = await self._extractor.generate_response(prompt) except Exception as exc: logger.debug("人物事实提取模型调用失败: %s", exc) return [] - return self._parse_fact_list(response) + return self._parse_fact_list(response_result.response) @staticmethod def _parse_fact_list(raw: str) -> List[str]: From 15d436b3a1219126afad46d7af6f5f32bd4ba43a Mon Sep 17 00:00:00 2001 From: A-Dawn <67786671+A-Dawn@users.noreply.github.com> Date: Fri, 3 Apr 2026 08:08:24 +0800 Subject: [PATCH 12/14] =?UTF-8?q?refactor:=20=E5=B0=86=20A=5FMemorix=20?= =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=B8=BA=E4=B8=BB=E7=BA=BF=E9=95=BF=E6=9C=9F?= =?UTF-8?q?=E8=AE=B0=E5=BF=86=E5=AD=90=E7=B3=BB=E7=BB=9F=E5=B9=B6=E9=87=8D?= =?UTF-8?q?=E5=BB=BA=E7=AE=A1=E7=90=86=E7=95=8C=E9=9D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 A_Memorix 从旧 submodule / 插件形态迁入主线源码,主体落到 src/A_memorix - 调整主程序接入方式,使 
A_Memorix 作为源码内长期记忆子系统运行 - 回收父项目插件体系中针对 A_Memorix 的特判,减少对 plugin 通用层的侵入 - 将长期记忆配置、运行时、自检、导入、调优等能力收口到 memory 路由与主线服务层 - 重做长期记忆控制台与图谱页面,按 MaiBot 现有 dashboard 风格接入 - 补充实体关系图与证据视图双视图能力,支持查看节点、关系、段落及其证据链路 - 新增长期记忆配置编辑器与 memory-api,支持主线内配置管理 - 补齐删除管理能力:删除预览、混合删除、来源批量删除、删除操作恢复 - 优化删除预览与删除操作详情的前端展示,支持分页、检索,并以实体名/关系内容/段落摘要替代单纯 hash 展示 - 修复图谱与控制台相关前端问题,包括证据视图切换、查询触发时机、删除弹层空值保护等 - 新增或更新 A_Memorix 相关测试、WebUI 路由测试、前端 vitest 测试与辅助验证脚本 - 移除旧 plugins/A_memorix、.gitmodules 及相关历史维护文档 --- .gitmodules | 4 - .../scripts/a_memorix_electron_validate.cjs | 406 ++ dashboard/src/components/CodeEditor.tsx | 3 +- .../components/memory/MemoryConfigEditor.tsx | 311 + .../components/memory/MemoryDeleteDialog.tsx | 281 + dashboard/src/i18n/locales/en.json | 4 +- dashboard/src/i18n/locales/ja.json | 4 +- dashboard/src/i18n/locales/ko.json | 4 +- dashboard/src/i18n/locales/zh.json | 4 +- dashboard/src/lib/memory-api.ts | 509 ++ .../routes/__tests__/plugin-config.test.tsx | 80 + dashboard/src/routes/plugin-config.tsx | 62 +- .../__tests__/knowledge-base.test.tsx | 270 + .../__tests__/knowledge-graph.test.tsx | 341 + .../src/routes/resource/knowledge-base.tsx | 1538 ++++- .../resource/knowledge-graph/GraphDialogs.tsx | 463 +- .../knowledge-graph/GraphVisualization.tsx | 110 +- .../routes/resource/knowledge-graph/index.tsx | 1018 ++- .../routes/resource/knowledge-graph/types.ts | 11 +- docs-src/MAINTAIN_A_MEMORIX_SUBMODULE.md | 38 - docs/a_memorix_sync.md | 45 + plugins/A_memorix | 1 - .../A_memorix_test/data/benchmarks/README.md | 65 + .../group_chat_stream_memory_benchmark.json | 728 ++ ...oup_chat_stream_memory_benchmark_hard.json | 862 +++ .../long_novel_memory_benchmark.json | 120 + .../real_dialogues/private_alice_weekend.json | 90 + .../test_embedding_dimension_control.py | 166 + .../test_group_chat_stream_fixture_schema.py | 86 + ...test_group_chat_stream_memory_benchmark.py | 1187 ++++ .../A_memorix_test/test_knowledge_fetcher.py | 3 - .../test_long_novel_memory_benchmark.py | 26 +- .../test_long_novel_memory_benchmark_live.py | 5 +- pytests/A_memorix_test/test_memory_service.py | 2 +- .../test_query_long_term_memory_tool.py | 4 +- ...real_dialogue_business_flow_integration.py | 17 +- .../test_real_dialogue_business_flow_live.py | 17 +- pytests/conftest.py | 6 +- pytests/webui/test_memory_routes.py | 290 +- .../webui/test_plugin_management_routes.py | 49 + scripts/run.sh | 9 +- scripts/run_a_memorix_webui_backend.py | 25 + scripts/sync_a_memorix_subtree.sh | 21 + scripts/verify_a_memorix_webui.sh | 83 + src/A_memorix/.gitattributes | 2 + src/A_memorix/.gitignore | 245 + src/A_memorix/CHANGELOG.md | 718 ++ src/A_memorix/CONFIG_REFERENCE.md | 359 + src/A_memorix/IMPORT_GUIDE.md | 335 + src/A_memorix/LICENSE | 661 ++ src/A_memorix/LICENSE-MAIBOT-GPL.md | 22 + src/A_memorix/MODIFICATION_POLICY.md | 97 + src/A_memorix/QUICK_START.md | 313 + src/A_memorix/README.md | 271 + src/A_memorix/RELEASE_SUMMARY_1.0.0.md | 46 + src/A_memorix/__init__.py | 5 + src/A_memorix/config_schema.json | 1384 ++++ src/A_memorix/core/__init__.py | 84 + src/A_memorix/core/embedding/__init__.py | 18 + src/A_memorix/core/embedding/api_adapter.py | 434 ++ src/A_memorix/core/embedding/manager.py | 510 ++ src/A_memorix/core/embedding/presets.py | 72 + src/A_memorix/core/retrieval/__init__.py | 54 + src/A_memorix/core/retrieval/dual_path.py | 1871 ++++++ .../core/retrieval/graph_relation_recall.py | 272 + src/A_memorix/core/retrieval/pagerank.py | 482 ++ src/A_memorix/core/retrieval/sparse_bm25.py | 401 ++ src/A_memorix/core/retrieval/threshold.py | 
450 ++ src/A_memorix/core/runtime/__init__.py | 16 + .../core/runtime/lifecycle_orchestrator.py | 265 + .../core/runtime/sdk_memory_kernel.py | 4243 ++++++++++++ .../runtime/search_runtime_initializer.py | 240 + src/A_memorix/core/storage/__init__.py | 53 + src/A_memorix/core/storage/graph_store.py | 1448 ++++ src/A_memorix/core/storage/knowledge_types.py | 183 + src/A_memorix/core/storage/metadata_store.py | 5959 +++++++++++++++++ src/A_memorix/core/storage/type_detection.py | 137 + src/A_memorix/core/storage/vector_store.py | 776 +++ src/A_memorix/core/strategies/__init__.py | 0 src/A_memorix/core/strategies/base.py | 89 + src/A_memorix/core/strategies/factual.py | 98 + src/A_memorix/core/strategies/narrative.py | 126 + src/A_memorix/core/strategies/quote.py | 52 + src/A_memorix/core/utils/__init__.py | 33 + .../core/utils/aggregate_query_service.py | 360 + .../core/utils/episode_retrieval_service.py | 182 + .../utils/episode_segmentation_service.py | 304 + src/A_memorix/core/utils/episode_service.py | 558 ++ src/A_memorix/core/utils/hash.py | 129 + src/A_memorix/core/utils/import_payloads.py | 110 + src/A_memorix/core/utils/io.py | 84 + src/A_memorix/core/utils/matcher.py | 89 + src/A_memorix/core/utils/monitor.py | 189 + .../core/utils/path_fallback_service.py | 165 + .../core/utils/person_profile_service.py | 599 ++ src/A_memorix/core/utils/plugin_id_policy.py | 27 + src/A_memorix/core/utils/quantization.py | 344 + src/A_memorix/core/utils/relation_query.py | 121 + .../core/utils/relation_write_service.py | 166 + .../core/utils/retrieval_tuning_manager.py | 1858 +++++ .../core/utils/runtime_self_check.py | 240 + .../core/utils/search_execution_service.py | 439 ++ .../core/utils/search_postprocess.py | 90 + src/A_memorix/core/utils/summary_importer.py | 463 ++ src/A_memorix/core/utils/time_parser.py | 170 + .../core/utils/web_import_manager.py | 3606 ++++++++++ src/A_memorix/host_service.py | 260 + src/A_memorix/paths.py | 56 + src/A_memorix/plugin.py | 290 + src/A_memorix/requirements.txt | 52 + src/A_memorix/runtime_registry.py | 27 + src/A_memorix/scripts/_bootstrap.py | 22 + .../scripts/audit_vector_consistency.py | 208 + .../scripts/backfill_relation_vectors.py | 265 + .../scripts/backfill_temporal_metadata.py | 65 + src/A_memorix/scripts/convert_lpmm.py | 530 ++ src/A_memorix/scripts/import_lpmm_json.py | 165 + src/A_memorix/scripts/migrate_chat_history.py | 99 + .../scripts/migrate_maibot_memory.py | 1743 +++++ .../scripts/migrate_person_memory_points.py | 109 + src/A_memorix/scripts/process_knowledge.py | 720 ++ src/A_memorix/scripts/rebuild_episodes.py | 119 + .../scripts/release_vnext_migrate.py | 744 ++ src/A_memorix/scripts/runtime_self_check.py | 144 + src/A_memorix/web/import.html | 913 +++ src/A_memorix/web/index.html | 3136 +++++++++ src/A_memorix/web/index.html.scratch | 13 + src/A_memorix/web/tuning.html | 722 ++ src/main.py | 3 + src/plugin_runtime/integration.py | 11 +- src/plugin_runtime/runner/plugin_loader.py | 58 +- src/plugin_runtime/runner/runner_main.py | 5 +- src/services/memory_service.py | 18 +- src/webui/routers/memory.py | 143 +- src/webui/routers/plugin/config_routes.py | 34 +- src/webui/routers/plugin/support.py | 3 + 136 files changed, 52533 insertions(+), 629 deletions(-) delete mode 100644 .gitmodules create mode 100644 dashboard/scripts/a_memorix_electron_validate.cjs create mode 100644 dashboard/src/components/memory/MemoryConfigEditor.tsx create mode 100644 dashboard/src/components/memory/MemoryDeleteDialog.tsx create mode 100644 
dashboard/src/lib/memory-api.ts create mode 100644 dashboard/src/routes/__tests__/plugin-config.test.tsx create mode 100644 dashboard/src/routes/resource/__tests__/knowledge-base.test.tsx create mode 100644 dashboard/src/routes/resource/__tests__/knowledge-graph.test.tsx delete mode 100644 docs-src/MAINTAIN_A_MEMORIX_SUBMODULE.md create mode 100644 docs/a_memorix_sync.md delete mode 160000 plugins/A_memorix create mode 100644 pytests/A_memorix_test/data/benchmarks/README.md create mode 100644 pytests/A_memorix_test/data/benchmarks/group_chat_stream_memory_benchmark.json create mode 100644 pytests/A_memorix_test/data/benchmarks/group_chat_stream_memory_benchmark_hard.json create mode 100644 pytests/A_memorix_test/data/benchmarks/long_novel_memory_benchmark.json create mode 100644 pytests/A_memorix_test/data/real_dialogues/private_alice_weekend.json create mode 100644 pytests/A_memorix_test/test_embedding_dimension_control.py create mode 100644 pytests/A_memorix_test/test_group_chat_stream_fixture_schema.py create mode 100644 pytests/A_memorix_test/test_group_chat_stream_memory_benchmark.py create mode 100644 pytests/webui/test_plugin_management_routes.py create mode 100644 scripts/run_a_memorix_webui_backend.py create mode 100755 scripts/sync_a_memorix_subtree.sh create mode 100755 scripts/verify_a_memorix_webui.sh create mode 100644 src/A_memorix/.gitattributes create mode 100644 src/A_memorix/.gitignore create mode 100644 src/A_memorix/CHANGELOG.md create mode 100644 src/A_memorix/CONFIG_REFERENCE.md create mode 100644 src/A_memorix/IMPORT_GUIDE.md create mode 100644 src/A_memorix/LICENSE create mode 100644 src/A_memorix/LICENSE-MAIBOT-GPL.md create mode 100644 src/A_memorix/MODIFICATION_POLICY.md create mode 100644 src/A_memorix/QUICK_START.md create mode 100644 src/A_memorix/README.md create mode 100644 src/A_memorix/RELEASE_SUMMARY_1.0.0.md create mode 100644 src/A_memorix/__init__.py create mode 100644 src/A_memorix/config_schema.json create mode 100644 src/A_memorix/core/__init__.py create mode 100644 src/A_memorix/core/embedding/__init__.py create mode 100644 src/A_memorix/core/embedding/api_adapter.py create mode 100644 src/A_memorix/core/embedding/manager.py create mode 100644 src/A_memorix/core/embedding/presets.py create mode 100644 src/A_memorix/core/retrieval/__init__.py create mode 100644 src/A_memorix/core/retrieval/dual_path.py create mode 100644 src/A_memorix/core/retrieval/graph_relation_recall.py create mode 100644 src/A_memorix/core/retrieval/pagerank.py create mode 100644 src/A_memorix/core/retrieval/sparse_bm25.py create mode 100644 src/A_memorix/core/retrieval/threshold.py create mode 100644 src/A_memorix/core/runtime/__init__.py create mode 100644 src/A_memorix/core/runtime/lifecycle_orchestrator.py create mode 100644 src/A_memorix/core/runtime/sdk_memory_kernel.py create mode 100644 src/A_memorix/core/runtime/search_runtime_initializer.py create mode 100644 src/A_memorix/core/storage/__init__.py create mode 100644 src/A_memorix/core/storage/graph_store.py create mode 100644 src/A_memorix/core/storage/knowledge_types.py create mode 100644 src/A_memorix/core/storage/metadata_store.py create mode 100644 src/A_memorix/core/storage/type_detection.py create mode 100644 src/A_memorix/core/storage/vector_store.py create mode 100644 src/A_memorix/core/strategies/__init__.py create mode 100644 src/A_memorix/core/strategies/base.py create mode 100644 src/A_memorix/core/strategies/factual.py create mode 100644 src/A_memorix/core/strategies/narrative.py create mode 100644 
src/A_memorix/core/strategies/quote.py create mode 100644 src/A_memorix/core/utils/__init__.py create mode 100644 src/A_memorix/core/utils/aggregate_query_service.py create mode 100644 src/A_memorix/core/utils/episode_retrieval_service.py create mode 100644 src/A_memorix/core/utils/episode_segmentation_service.py create mode 100644 src/A_memorix/core/utils/episode_service.py create mode 100644 src/A_memorix/core/utils/hash.py create mode 100644 src/A_memorix/core/utils/import_payloads.py create mode 100644 src/A_memorix/core/utils/io.py create mode 100644 src/A_memorix/core/utils/matcher.py create mode 100644 src/A_memorix/core/utils/monitor.py create mode 100644 src/A_memorix/core/utils/path_fallback_service.py create mode 100644 src/A_memorix/core/utils/person_profile_service.py create mode 100644 src/A_memorix/core/utils/plugin_id_policy.py create mode 100644 src/A_memorix/core/utils/quantization.py create mode 100644 src/A_memorix/core/utils/relation_query.py create mode 100644 src/A_memorix/core/utils/relation_write_service.py create mode 100644 src/A_memorix/core/utils/retrieval_tuning_manager.py create mode 100644 src/A_memorix/core/utils/runtime_self_check.py create mode 100644 src/A_memorix/core/utils/search_execution_service.py create mode 100644 src/A_memorix/core/utils/search_postprocess.py create mode 100644 src/A_memorix/core/utils/summary_importer.py create mode 100644 src/A_memorix/core/utils/time_parser.py create mode 100644 src/A_memorix/core/utils/web_import_manager.py create mode 100644 src/A_memorix/host_service.py create mode 100644 src/A_memorix/paths.py create mode 100644 src/A_memorix/plugin.py create mode 100644 src/A_memorix/requirements.txt create mode 100644 src/A_memorix/runtime_registry.py create mode 100644 src/A_memorix/scripts/_bootstrap.py create mode 100644 src/A_memorix/scripts/audit_vector_consistency.py create mode 100644 src/A_memorix/scripts/backfill_relation_vectors.py create mode 100644 src/A_memorix/scripts/backfill_temporal_metadata.py create mode 100644 src/A_memorix/scripts/convert_lpmm.py create mode 100644 src/A_memorix/scripts/import_lpmm_json.py create mode 100644 src/A_memorix/scripts/migrate_chat_history.py create mode 100644 src/A_memorix/scripts/migrate_maibot_memory.py create mode 100644 src/A_memorix/scripts/migrate_person_memory_points.py create mode 100644 src/A_memorix/scripts/process_knowledge.py create mode 100644 src/A_memorix/scripts/rebuild_episodes.py create mode 100644 src/A_memorix/scripts/release_vnext_migrate.py create mode 100644 src/A_memorix/scripts/runtime_self_check.py create mode 100644 src/A_memorix/web/import.html create mode 100644 src/A_memorix/web/index.html create mode 100644 src/A_memorix/web/index.html.scratch create mode 100644 src/A_memorix/web/tuning.html diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 3ddc6a14..00000000 --- a/.gitmodules +++ /dev/null @@ -1,4 +0,0 @@ -[submodule "plugins/A_memorix"] - path = plugins/A_memorix - url = https://github.com/A-Dawn/A_memorix.git - branch = MaiBot_branch diff --git a/dashboard/scripts/a_memorix_electron_validate.cjs b/dashboard/scripts/a_memorix_electron_validate.cjs new file mode 100644 index 00000000..1d5de90e --- /dev/null +++ b/dashboard/scripts/a_memorix_electron_validate.cjs @@ -0,0 +1,406 @@ +const { app, BrowserWindow } = require('electron') +const fs = require('fs') +const path = require('path') + +const DASHBOARD_URL = process.env.MAIBOT_DASHBOARD_URL || 'http://127.0.0.1:7999' +const OUTPUT_DIR = 
process.env.MAIBOT_UI_SNAPSHOT_DIR + || path.resolve(__dirname, '..', '..', 'tmp', 'ui-snapshots', 'a_memorix-electron') +const TOKEN_PATH = process.env.MAIBOT_WEBUI_TOKEN_PATH + || path.resolve(__dirname, '..', '..', 'data', 'webui.json') +const sampleStamp = String(Date.now()) +const sampleSource = process.env.MAIBOT_UI_SAMPLE_SOURCE || `webui-demo:a_memorix-json-${sampleStamp}` +const sampleName = process.env.MAIBOT_UI_SAMPLE_NAME || `webui-json-validation-${sampleStamp}.json` + +const DEFAULT_SAMPLE = { + paragraphs: [ + { + content: 'Alice 在杭州西湖与 Bob 讨论 A_Memorix 的前端接入与 embedding 调优方案。', + source: sampleSource, + entities: ['Alice', 'Bob', '杭州西湖', 'A_Memorix'], + relations: [ + { subject: 'Alice', predicate: '在', object: '杭州西湖' }, + { subject: 'Alice', predicate: '讨论', object: 'A_Memorix' }, + { subject: 'Bob', predicate: '讨论', object: 'A_Memorix' }, + { subject: 'Bob', predicate: '负责', object: 'embedding 调优' }, + ], + knowledge_type: 'factual', + }, + ], + entities: ['Alice', 'Bob', '杭州西湖', 'A_Memorix', 'embedding 调优'], + relations: [{ subject: 'Alice', predicate: '认识', object: 'Bob' }], +} + +function loadSampleJson() { + const customPath = String(process.env.MAIBOT_UI_IMPORT_JSON_PATH || '').trim() + if (!customPath) { + return JSON.stringify(DEFAULT_SAMPLE, null, 2) + } + return fs.readFileSync(customPath, 'utf8') +} + +const sampleJson = loadSampleJson() + +fs.mkdirSync(OUTPUT_DIR, { recursive: true }) + +function wait(ms) { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + +async function exec(win, code) { + return win.webContents.executeJavaScript(code, true) +} + +async function waitFor(win, predicateCode, label, timeout = 30000, interval = 300) { + const start = Date.now() + while (Date.now() - start < timeout) { + try { + const ok = await exec(win, predicateCode) + if (ok) { + return ok + } + } catch { + // keep polling + } + await wait(interval) + } + throw new Error(`Timeout waiting for ${label}`) +} + +async function sendClick(win, x, y) { + win.webContents.sendInputEvent({ type: 'mouseMove', x, y, movementX: 0, movementY: 0 }) + win.webContents.sendInputEvent({ type: 'mouseDown', x, y, button: 'left', clickCount: 1 }) + win.webContents.sendInputEvent({ type: 'mouseUp', x, y, button: 'left', clickCount: 1 }) +} + +async function capture(win, name) { + const image = await win.webContents.capturePage() + fs.writeFileSync(path.join(OUTPUT_DIR, name), image.toPNG()) + const text = await exec(win, 'document.body ? 
document.body.innerText : ""') + fs.writeFileSync(path.join(OUTPUT_DIR, name.replace(/\.png$/, '.txt')), text || '') +} + +async function getJson(win, relativePath) { + return exec( + win, + `fetch(${JSON.stringify(relativePath)}, { credentials: 'include' }).then((r) => r.json())`, + ) +} + +async function setSessionCookie(win) { + const raw = fs.readFileSync(TOKEN_PATH, 'utf8') + const config = JSON.parse(raw) + const token = String(config.access_token || '').trim() + if (!token) { + throw new Error(`No access token found in ${TOKEN_PATH}`) + } + const payload = await exec( + win, + `fetch('/api/webui/auth/verify', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + credentials: 'include', + body: JSON.stringify({ token: ${JSON.stringify(token)} }), + }).then(async (response) => ({ + ok: response.ok, + status: response.status, + body: await response.json(), + }))`, + ) + if (!payload?.ok || !payload?.body?.valid) { + throw new Error(`Failed to authenticate WebUI token via /auth/verify: ${JSON.stringify(payload)}`) + } +} + +async function openImportTab(win) { + await exec(win, `(() => { + const tab = Array.from(document.querySelectorAll('[role="tab"]')).find((el) => (el.textContent || '').trim() === '导入') + if (!tab) return false + tab.dispatchEvent(new PointerEvent('pointerdown', { bubbles: true, cancelable: true, pointerId: 1, button: 0, pointerType: 'mouse', isPrimary: true })) + tab.dispatchEvent(new MouseEvent('mousedown', { bubbles: true, cancelable: true, button: 0 })) + tab.dispatchEvent(new MouseEvent('mouseup', { bubbles: true, cancelable: true, button: 0 })) + tab.dispatchEvent(new MouseEvent('click', { bubbles: true, cancelable: true, button: 0 })) + return true + })()`) + await waitFor( + win, + `document.body && document.body.innerText.includes('粘贴导入') && document.body.innerText.includes('创建导入任务')`, + 'import panel', + ) +} + +async function setJsonMode(win) { + const trigger = await exec(win, `(() => { + const label = Array.from(document.querySelectorAll('label')).find((node) => (node.textContent || '').includes('输入模式')) + const root = label?.closest('div')?.parentElement || label?.parentElement + const button = root?.querySelector('button') + if (!button) return null + const rect = button.getBoundingClientRect() + return { x: Math.round(rect.left + rect.width / 2), y: Math.round(rect.top + rect.height / 2) } + })()`) + if (!trigger) { + throw new Error('select trigger not found') + } + await sendClick(win, trigger.x, trigger.y) + await waitFor(win, `document.querySelectorAll('[role="option"]').length > 0`, 'select options', 5000, 200) + + const option = await exec(win, `(() => { + const item = Array.from(document.querySelectorAll('[role="option"]')).find((el) => (el.textContent || '').trim() === 'json') + if (!item) return null + const rect = item.getBoundingClientRect() + return { x: Math.round(rect.left + rect.width / 2), y: Math.round(rect.top + rect.height / 2) } + })()`) + if (!option) { + throw new Error('json option not found') + } + await sendClick(win, option.x, option.y) + await waitFor( + win, + `(() => { + const label = Array.from(document.querySelectorAll('label')).find((node) => (node.textContent || '').includes('输入模式')) + const root = label?.closest('div')?.parentElement || label?.parentElement + const button = root?.querySelector('button') + return (button?.textContent || '').trim() === 'json' + })()`, + 'json mode selected', + 8000, + 300, + ) +} + +async function typeIntoLabeled(win, labelText, selector, text) { + const rect = await 
exec(win, `(() => { + const label = Array.from(document.querySelectorAll('label')).find((node) => (node.textContent || '').includes(${JSON.stringify(labelText)})) + const root = label?.closest('div')?.parentElement || label?.parentElement + const el = root?.querySelector(${JSON.stringify(selector)}) + if (!el) return null + const r = el.getBoundingClientRect() + return { x: Math.round(r.left + 20), y: Math.round(r.top + 20) } + })()`) + if (!rect) { + throw new Error(`field not found: ${labelText}`) + } + await sendClick(win, rect.x, rect.y) + await wait(150) + await win.webContents.insertText(text) + await wait(250) +} + +async function clickButton(win, text) { + const ok = await exec(win, `(() => { + const target = Array.from(document.querySelectorAll('button')).find((el) => (el.textContent || '').includes(${JSON.stringify(text)})) + if (!target) return false + target.scrollIntoView({ block: 'center' }) + target.dispatchEvent(new PointerEvent('pointerdown', { bubbles: true, cancelable: true, pointerId: 1, button: 0, pointerType: 'mouse', isPrimary: true })) + target.dispatchEvent(new MouseEvent('mousedown', { bubbles: true, cancelable: true, button: 0 })) + target.dispatchEvent(new MouseEvent('mouseup', { bubbles: true, cancelable: true, button: 0 })) + target.dispatchEvent(new MouseEvent('click', { bubbles: true, cancelable: true, button: 0 })) + return true + })()`) + if (!ok) { + throw new Error(`button not found: ${text}`) + } +} + +async function clickTab(win, text) { + const ok = await exec(win, `(() => { + const target = Array.from(document.querySelectorAll('[role="tab"]')).find((el) => (el.textContent || '').includes(${JSON.stringify(text)})) + if (!target) return false + target.scrollIntoView({ block: 'center' }) + target.dispatchEvent(new PointerEvent('pointerdown', { bubbles: true, cancelable: true, pointerId: 1, button: 0, pointerType: 'mouse', isPrimary: true })) + target.dispatchEvent(new MouseEvent('mousedown', { bubbles: true, cancelable: true, button: 0 })) + target.dispatchEvent(new MouseEvent('mouseup', { bubbles: true, cancelable: true, button: 0 })) + target.dispatchEvent(new MouseEvent('click', { bubbles: true, cancelable: true, button: 0 })) + return true + })()`) + if (!ok) { + throw new Error(`tab not found: ${text}`) + } +} + +async function clickGraphElement(win, selector, index = 0) { + const rect = await exec(win, `(() => { + const targets = Array.from(document.querySelectorAll(${JSON.stringify(selector)})) + const target = targets[${index}] + if (!target) return null + target.scrollIntoView({ block: 'center', inline: 'center' }) + const r = target.getBoundingClientRect() + return { x: Math.round(r.left + r.width / 2), y: Math.round(r.top + r.height / 2) } + })()`) + if (!rect) { + throw new Error(`graph element not found: ${selector}[${index}]`) + } + await sendClick(win, rect.x, rect.y) +} + +async function capturePluginFilterState(win) { + await win.loadURL(`${DASHBOARD_URL}/plugin-config`) + await waitFor( + win, + `document.body && document.body.innerText.includes('插件配置') && document.querySelector('input[placeholder="搜索插件..."]')`, + 'plugin config page', + 30000, + 400, + ) + await exec(win, `(() => { + const input = document.querySelector('input[placeholder="搜索插件..."]') + if (!input) return false + const setter = Object.getOwnPropertyDescriptor(HTMLInputElement.prototype, 'value')?.set + setter?.call(input, 'memorix') + input.dispatchEvent(new Event('input', { bubbles: true })) + input.dispatchEvent(new Event('change', { bubbles: true })) + return true 
+ })()`) + await wait(500) + await capture(win, '01-plugin-config-filtered.png') +} + +app.whenReady().then(async () => { + const win = new BrowserWindow({ + width: 1600, + height: 1200, + show: false, + webPreferences: { + contextIsolation: true, + nodeIntegration: false, + }, + }) + + await win.loadURL(`${DASHBOARD_URL}/auth`) + await waitFor(win, `document.readyState === 'complete'`, 'auth page') + await capture(win, '00-auth-login.png') + await setSessionCookie(win) + + await capturePluginFilterState(win) + + await win.loadURL(`${DASHBOARD_URL}/resource/knowledge-base`) + await waitFor( + win, + `document.body && document.body.innerText.includes('运行时自检') && document.body.innerText.includes('刷新数据')`, + 'memory console ready', + 30000, + 500, + ) + await capture(win, '02-memory-console-before-import.png') + + const beforeGraph = await getJson(win, '/api/webui/memory/graph?limit=120') + const beforeTasks = await getJson(win, '/api/webui/memory/import/tasks?limit=20') + const knownTaskIds = new Set( + Array.isArray(beforeTasks.items) + ? beforeTasks.items.map((item) => String(item.task_id || item.taskId || '')) + : [], + ) + + await openImportTab(win) + await setJsonMode(win) + await typeIntoLabeled(win, '名称', 'input', sampleName) + await typeIntoLabeled(win, '粘贴内容', 'textarea', sampleJson) + await capture(win, '03-memory-import-json-filled.png') + + await clickButton(win, '创建导入任务') + + let taskId = null + let taskStatus = null + const start = Date.now() + while (Date.now() - start < 120000) { + const payload = await getJson(win, '/api/webui/memory/import/tasks?limit=20') + fs.writeFileSync(path.join(OUTPUT_DIR, 'tasks-last.json'), JSON.stringify(payload, null, 2)) + const items = Array.isArray(payload.items) ? payload.items : [] + const task = items.find((item) => !knownTaskIds.has(String(item.task_id || item.taskId || ''))) + if (task) { + taskId = task.task_id || task.taskId || null + taskStatus = task.status || null + if (['completed', 'failed', 'cancelled'].includes(String(taskStatus))) { + break + } + } + await wait(1500) + } + + if (!taskId) { + throw new Error('new json import task not observed') + } + + const detail = await getJson( + win, + `/api/webui/memory/import/tasks/${encodeURIComponent(taskId)}?include_chunks=true`, + ) + fs.writeFileSync(path.join(OUTPUT_DIR, 'task-detail.json'), JSON.stringify(detail, null, 2)) + fs.writeFileSync( + path.join(OUTPUT_DIR, 'task-status.txt'), + `taskId=${taskId}\nstatus=${taskStatus}\nsource=${sampleSource}\n`, + ) + + await clickButton(win, '刷新数据') + await wait(2000) + await capture(win, '04-memory-console-after-import.png') + + await win.loadURL(`${DASHBOARD_URL}/resource/knowledge-graph`) + await waitFor( + win, + `document.body && document.body.innerText.includes('长期记忆图谱') && document.body.innerText.includes('实体关系图') && document.body.innerText.includes('证据视图')`, + 'graph page ready', + 30000, + 400, + ) + await wait(3000) + const afterGraph = await getJson(win, '/api/webui/memory/graph?limit=120') + fs.writeFileSync(path.join(OUTPUT_DIR, 'graph-after.json'), JSON.stringify(afterGraph, null, 2)) + await capture(win, '05-memory-graph-after-import.png') + + if (Array.isArray(afterGraph.nodes) && afterGraph.nodes.length > 0) { + await clickGraphElement(win, '.react-flow__node', 0) + await waitFor(win, `document.body && document.body.innerText.includes('实体详情')`, 'node detail dialog', 10000, 250) + await capture(win, '06-memory-node-detail.png') + try { + await clickButton(win, '切到证据视图') + await waitFor( + win, + `document.body && 
document.body.innerText.includes('证据视图') && document.querySelectorAll('.react-flow__node').length > 0`, + 'evidence graph after node click', + 10000, + 250, + ) + await capture(win, '07-memory-evidence-view.png') + } catch (error) { + fs.writeFileSync(path.join(OUTPUT_DIR, '07-memory-evidence-view-error.txt'), String(error?.stack || error)) + } + } + + if (Array.isArray(afterGraph.edges) && afterGraph.edges.length > 0) { + try { + await clickTab(win, '实体关系图') + await wait(800) + await clickGraphElement(win, '.react-flow__edge', 0) + await waitFor(win, `document.body && document.body.innerText.includes('关系详情')`, 'edge detail dialog', 10000, 250) + await capture(win, '08-memory-edge-detail.png') + } catch (error) { + fs.writeFileSync(path.join(OUTPUT_DIR, '08-memory-edge-detail-error.txt'), String(error?.stack || error)) + } + } + + const summary = { + before: { + nodes: beforeGraph.total_nodes, + edges: beforeGraph.total_edges, + }, + after: { + nodes: afterGraph.total_nodes, + edges: afterGraph.total_edges, + }, + taskId, + taskStatus, + source: sampleSource, + inputMode: detail?.task?.files?.[0]?.input_mode || null, + strategyType: detail?.task?.files?.[0]?.detected_strategy_type || null, + fileStatus: detail?.task?.files?.[0]?.status || null, + outputDir: OUTPUT_DIR, + } + fs.writeFileSync(path.join(OUTPUT_DIR, 'validation-summary.json'), JSON.stringify(summary, null, 2)) + console.log(JSON.stringify(summary, null, 2)) + + await win.close() + app.quit() +}).catch((error) => { + console.error(error) + app.exit(1) +}) diff --git a/dashboard/src/components/CodeEditor.tsx b/dashboard/src/components/CodeEditor.tsx index 00438655..ae928f90 100644 --- a/dashboard/src/components/CodeEditor.tsx +++ b/dashboard/src/components/CodeEditor.tsx @@ -2,6 +2,7 @@ import { useEffect, useState } from 'react' import CodeMirror from '@uiw/react-codemirror' import { css } from '@codemirror/lang-css' import { json, jsonParseLinter } from '@codemirror/lang-json' +import { linter } from '@codemirror/lint' import { python } from '@codemirror/lang-python' import { oneDark } from '@codemirror/theme-one-dark' import { EditorView } from '@codemirror/view' @@ -29,7 +30,7 @@ interface CodeEditorProps { // eslint-disable-next-line @typescript-eslint/no-explicit-any const languageExtensions: Record = { python: [python()], - json: [json(), jsonParseLinter()], + json: [json(), linter(jsonParseLinter())], toml: [StreamLanguage.define(tomlMode)], css: [css()], text: [], diff --git a/dashboard/src/components/memory/MemoryConfigEditor.tsx b/dashboard/src/components/memory/MemoryConfigEditor.tsx new file mode 100644 index 00000000..d9cec1c2 --- /dev/null +++ b/dashboard/src/components/memory/MemoryConfigEditor.tsx @@ -0,0 +1,311 @@ +import { useMemo, useState } from 'react' + +import { ListFieldEditor } from '@/components' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Input } from '@/components/ui/input' +import { Label } from '@/components/ui/label' +import { Switch } from '@/components/ui/switch' +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs' +import { Textarea } from '@/components/ui/textarea' +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select' +import type { ConfigFieldSchema, PluginConfigSchema } from '@/lib/plugin-api' + +interface MemoryConfigEditorProps { + schema: PluginConfigSchema + config: Record + onChange: (nextConfig: Record) => void + disabled?: boolean 
+} + +function getNestedRecord(config: Record, path: string): Record | undefined { + const parts = path.split('.').filter(Boolean) + let current: unknown = config + + for (const part of parts) { + if (!current || typeof current !== 'object' || Array.isArray(current)) { + return undefined + } + current = (current as Record)[part] + } + + if (!current || typeof current !== 'object' || Array.isArray(current)) { + return undefined + } + + return current as Record +} + +function setNestedField( + config: Record, + path: string, + fieldName: string, + value: unknown, +): Record { + const parts = path.split('.').filter(Boolean) + const nextConfig: Record = { ...config } + let target = nextConfig + let source: Record | undefined = config + + for (const part of parts) { + const sourceValue: unknown = source?.[part] + const nextValue = + sourceValue && typeof sourceValue === 'object' && !Array.isArray(sourceValue) + ? { ...(sourceValue as Record) } + : {} + target[part] = nextValue + target = nextValue + source = + sourceValue && typeof sourceValue === 'object' && !Array.isArray(sourceValue) + ? (sourceValue as Record) + : undefined + } + + target[fieldName] = value + return nextConfig +} + +function FieldRenderer({ + field, + value, + onChange, + disabled, +}: { + field: ConfigFieldSchema + value: unknown + onChange: (value: unknown) => void + disabled?: boolean +}) { + const [jsonDraft, setJsonDraft] = useState( + typeof value === 'string' ? String(value) : JSON.stringify(value ?? field.default ?? {}, null, 2), + ) + + switch (field.ui_type) { + case 'switch': + return ( +
[PATCH 12 continues with the remaining hunk bodies for the new dashboard memory components and console markup: the rest of the MemoryConfigEditor.tsx FieldRenderer (per-ui_type branches for switch, number, select and textarea fields, each wiring label, hint, min/max/step, placeholder and disabled state), a staging confirmation hint ("将执行 staging 转换与切换,请确认输入目录正确。"), and an import task console with task-queue polling at 1000ms, task groups for 运行中 / 准备中, 排队中 and 最近完成, a 任务详情 panel ("请选择任务查看详情"), a 文件级状态 table (文件 / 类型 / 状态 / 步骤 / 进度 / 统计), a 分块级状态 table (# / 类型 / 状态 / 步骤 / 预览 / 错误) and a "0 / 0" pagination counter.]
diff --git a/src/A_memorix/web/index.html b/src/A_memorix/web/index.html
new file mode 100644
index 00000000..36402286
--- /dev/null
+++ b/src/A_memorix/web/index.html
@@ -0,0 +1,3136 @@
+    A_Memorix | 知识全景图
[The new index.html defines the A_Memorix 知识全景图 console: a loading overlay ("正在同步全景知识图谱...", with the hint "初次使用?请在加载完成后点击操作指南了解基础操作"), a toolbar offering 同步状态, 重排布局, 暂停模拟, 内容字典, 回收站, 新增节点, 记忆溯源, 导入中心, 检索调优, 人物画像, 持久化, 视图配置 and 操作指南, and a 属性信息 side panel. The patch then moves on to dashboard/src/routes/resource/knowledge-base.tsx (hunk @@ -790,13 +1759,23 @@ export function KnowledgeBasePage()), rebuilding the 概览 / 配置 / 导入 / 调优 / 删除 tab bar around the runtimeBadges header and rewording the 运行时自检 card description to "用于确认 embedding、向量库与运行时状态是否一致".]