diff --git a/.gitignore b/.gitignore index ace02c8b..36853b8a 100644 --- a/.gitignore +++ b/.gitignore @@ -339,6 +339,7 @@ run_pet.bat /plugins/* !/plugins +!/plugins/A_memorix !/plugins/hello_world_plugin !/plugins/emoji_manage_plugin !/plugins/take_picture_plugin diff --git a/plugins/A_memorix/__init__.py b/plugins/A_memorix/__init__.py new file mode 100644 index 00000000..d23a5bd5 --- /dev/null +++ b/plugins/A_memorix/__init__.py @@ -0,0 +1,12 @@ +""" +A_Memorix - 轻量级知识库插件 + +完全独立的记忆增强系统,优化低资源环境下的知识存储与检索。 +""" + +__version__ = "2.0.0" +__author__ = "A_Dawn" + +from .plugin import AMemorixPlugin + +__all__ = ["AMemorixPlugin"] diff --git a/plugins/A_memorix/_manifest.json b/plugins/A_memorix/_manifest.json new file mode 100644 index 00000000..a45b2f73 --- /dev/null +++ b/plugins/A_memorix/_manifest.json @@ -0,0 +1,62 @@ +{ + "manifest_version": 1, + "name": "A_Memorix", + "version": "2.0.0", + "description": "MaiBot SDK 长期记忆插件,负责统一检索、写入、画像与记忆维护。", + "author": { + "name": "A_Dawn" + }, + "license": "AGPL-3.0", + "repository_url": "https://github.com/A-Dawn/A_memorix/", + "host_application": { + "min_version": "1.0.0" + }, + "keywords": [ + "memory", + "knowledge", + "retrieval", + "profile", + "episode" + ], + "categories": [ + "Memory", + "Data" + ], + "plugin_info": { + "is_built_in": false, + "plugin_type": "memory_provider", + "components": [ + { + "type": "tool", + "name": "search_memory", + "description": "搜索长期记忆" + }, + { + "type": "tool", + "name": "ingest_summary", + "description": "写入聊天摘要" + }, + { + "type": "tool", + "name": "ingest_text", + "description": "写入普通长期记忆文本" + }, + { + "type": "tool", + "name": "get_person_profile", + "description": "查询人物画像" + }, + { + "type": "tool", + "name": "maintain_memory", + "description": "维护记忆关系" + }, + { + "type": "tool", + "name": "memory_stats", + "description": "查询记忆统计" + } + ] + }, + "capabilities": [] +} diff --git a/plugins/A_memorix/core/__init__.py b/plugins/A_memorix/core/__init__.py new file mode 100644 index 00000000..3f87929c --- /dev/null +++ b/plugins/A_memorix/core/__init__.py @@ -0,0 +1,84 @@ +"""核心模块 - 存储、嵌入、检索引擎""" + +# 存储模块(已实现) +from .storage import ( + VectorStore, + GraphStore, + MetadataStore, + ImportStrategy, + KnowledgeType, + parse_import_strategy, + resolve_stored_knowledge_type, + detect_knowledge_type, + select_import_strategy, + should_extract_relations, + get_type_display_name, +) + +# 嵌入模块(使用主程序 API) +from .embedding import ( + EmbeddingAPIAdapter, + create_embedding_api_adapter, +) + +# 检索模块(已实现) +from .retrieval import ( + DualPathRetriever, + RetrievalStrategy, + RetrievalResult, + DualPathRetrieverConfig, + TemporalQueryOptions, + FusionConfig, + GraphRelationRecallConfig, + RelationIntentConfig, + PersonalizedPageRank, + PageRankConfig, + create_ppr_from_graph, + DynamicThresholdFilter, + ThresholdMethod, + ThresholdConfig, + SparseBM25Index, + SparseBM25Config, +) +from .utils import ( + RelationWriteService, + RelationWriteResult, +) + +__all__ = [ + # Storage + "VectorStore", + "GraphStore", + "MetadataStore", + "ImportStrategy", + "KnowledgeType", + "parse_import_strategy", + "resolve_stored_knowledge_type", + "detect_knowledge_type", + "select_import_strategy", + "should_extract_relations", + "get_type_display_name", + # Embedding + "EmbeddingAPIAdapter", + "create_embedding_api_adapter", + # Retrieval + "DualPathRetriever", + "RetrievalStrategy", + "RetrievalResult", + "DualPathRetrieverConfig", + "TemporalQueryOptions", + "FusionConfig", + "GraphRelationRecallConfig", + "RelationIntentConfig", + 
"PersonalizedPageRank", + "PageRankConfig", + "create_ppr_from_graph", + "DynamicThresholdFilter", + "ThresholdMethod", + "ThresholdConfig", + "SparseBM25Index", + "SparseBM25Config", + "RelationWriteService", + "RelationWriteResult", +] + diff --git a/plugins/A_memorix/core/embedding/__init__.py b/plugins/A_memorix/core/embedding/__init__.py new file mode 100644 index 00000000..11a52db9 --- /dev/null +++ b/plugins/A_memorix/core/embedding/__init__.py @@ -0,0 +1,18 @@ +"""嵌入模块 - 向量生成与量化""" + +# 新的 API 适配器(主程序嵌入 API) +from .api_adapter import ( + EmbeddingAPIAdapter, + create_embedding_api_adapter, +) + +from ..utils.quantization import QuantizationType + +__all__ = [ + # 新的 API 适配器(推荐使用) + "EmbeddingAPIAdapter", + "create_embedding_api_adapter", + # 量化 + "QuantizationType", +] + diff --git a/plugins/A_memorix/core/embedding/api_adapter.py b/plugins/A_memorix/core/embedding/api_adapter.py new file mode 100644 index 00000000..4262ddb9 --- /dev/null +++ b/plugins/A_memorix/core/embedding/api_adapter.py @@ -0,0 +1,174 @@ +""" +Hash-based embedding adapter used by the SDK runtime. + +The plugin runtime cannot import MaiBot host embedding internals from ``src.chat`` +or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in +Runner by generating deterministic dense vectors locally. +""" + +from __future__ import annotations + +import hashlib +import re +import time +from typing import List, Optional, Union + +import numpy as np + +from src.common.logger import get_logger + + +logger = get_logger("A_Memorix.EmbeddingAPIAdapter") + +_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}") + + +class EmbeddingAPIAdapter: + """Deterministic local embedding adapter.""" + + def __init__( + self, + batch_size: int = 32, + max_concurrent: int = 5, + default_dimension: int = 256, + enable_cache: bool = False, + model_name: str = "hash-v1", + retry_config: Optional[dict] = None, + ) -> None: + self.batch_size = max(1, int(batch_size)) + self.max_concurrent = max(1, int(max_concurrent)) + self.default_dimension = max(32, int(default_dimension)) + self.enable_cache = bool(enable_cache) + self.model_name = str(model_name or "hash-v1") + self.retry_config = retry_config or {} + + self._dimension: Optional[int] = None + self._dimension_detected = False + self._total_encoded = 0 + self._total_errors = 0 + self._total_time = 0.0 + + logger.info( + "EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s", + self.model_name, + self.batch_size, + self.default_dimension, + ) + + async def _detect_dimension(self) -> int: + if self._dimension_detected and self._dimension is not None: + return self._dimension + self._dimension = self.default_dimension + self._dimension_detected = True + return self._dimension + + @staticmethod + def _tokenize(text: str) -> List[str]: + clean = str(text or "").strip().lower() + if not clean: + return [] + return _TOKEN_PATTERN.findall(clean) + + @staticmethod + def _feature_weight(token: str) -> float: + digest = hashlib.sha256(token.encode("utf-8")).digest() + return 1.0 + (digest[10] / 255.0) * 0.5 + + def _encode_single(self, text: str, dimension: int) -> np.ndarray: + vector = np.zeros(dimension, dtype=np.float32) + content = str(text or "").strip() + tokens = self._tokenize(content) + if not tokens and content: + tokens = [content.lower()] + if not tokens: + vector[0] = 1.0 + return vector + + for token in tokens: + digest = hashlib.sha256(token.encode("utf-8")).digest() + bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % 
dimension + sign = 1.0 if digest[8] % 2 == 0 else -1.0 + vector[bucket] += sign * self._feature_weight(token) + + second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension + if second_bucket != bucket: + vector[second_bucket] += (sign * 0.35) + + norm = float(np.linalg.norm(vector)) + if norm > 1e-8: + vector /= norm + else: + vector[0] = 1.0 + return vector + + async def encode( + self, + texts: Union[str, List[str]], + batch_size: Optional[int] = None, + show_progress: bool = False, + normalize: bool = True, + dimensions: Optional[int] = None, + ) -> np.ndarray: + _ = batch_size + _ = show_progress + _ = normalize + + started_at = time.time() + target_dimension = max(32, int(dimensions or await self._detect_dimension())) + + if isinstance(texts, str): + single_input = True + normalized_texts = [texts] + else: + single_input = False + normalized_texts = list(texts or []) + + if not normalized_texts: + empty = np.zeros((0, target_dimension), dtype=np.float32) + return empty[0] if single_input else empty + + try: + matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts]) + self._total_encoded += len(normalized_texts) + self._total_time += time.time() - started_at + except Exception: + self._total_errors += 1 + raise + + return matrix[0] if single_input else matrix + + def get_statistics(self) -> dict: + avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0 + return { + "model_name": self.model_name, + "dimension": self._dimension or self.default_dimension, + "total_encoded": self._total_encoded, + "total_errors": self._total_errors, + "total_time": self._total_time, + "avg_time_per_text": avg_time, + } + + def __repr__(self) -> str: + return ( + f"EmbeddingAPIAdapter(model_name={self.model_name}, " + f"dimension={self._dimension or self.default_dimension}, " + f"total_encoded={self._total_encoded})" + ) + + +def create_embedding_api_adapter( + batch_size: int = 32, + max_concurrent: int = 5, + default_dimension: int = 256, + enable_cache: bool = False, + model_name: str = "hash-v1", + retry_config: Optional[dict] = None, +) -> EmbeddingAPIAdapter: + return EmbeddingAPIAdapter( + batch_size=batch_size, + max_concurrent=max_concurrent, + default_dimension=default_dimension, + enable_cache=enable_cache, + model_name=model_name, + retry_config=retry_config, + ) diff --git a/plugins/A_memorix/core/embedding/manager.py b/plugins/A_memorix/core/embedding/manager.py new file mode 100644 index 00000000..d161e23b --- /dev/null +++ b/plugins/A_memorix/core/embedding/manager.py @@ -0,0 +1,510 @@ +""" +嵌入管理器 + +负责嵌入模型的加载、缓存和批量生成。 +""" + +import hashlib +import pickle +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from pathlib import Path +from typing import Optional, Union, List, Dict, Any, Tuple + +import numpy as np + +try: + from sentence_transformers import SentenceTransformer + HAS_SENTENCE_TRANSFORMERS = True +except ImportError: + HAS_SENTENCE_TRANSFORMERS = False + +from src.common.logger import get_logger +from .presets import ( + EmbeddingModelConfig, + get_custom_config, + validate_config_compatibility, + are_models_compatible, +) +from ..utils.quantization import QuantizationType + +logger = get_logger("A_Memorix.EmbeddingManager") + + +class EmbeddingManager: + """ + 嵌入管理器 + + 功能: + - 模型加载与缓存 + - 批量生成嵌入 + - 多线程/多进程支持 + - 模型一致性检查 + - 智能分批 + + 参数: + config: 模型配置 + cache_dir: 缓存目录 + enable_cache: 是否启用缓存 + num_workers: 工作线程数 + """ + + def __init__( + self, + 
config: EmbeddingModelConfig, + cache_dir: Optional[Union[str, Path]] = None, + enable_cache: bool = True, + num_workers: int = 1, + ): + """ + 初始化嵌入管理器 + + Args: + config: 模型配置 + cache_dir: 缓存目录 + enable_cache: 是否启用缓存 + num_workers: 工作线程数 + """ + if not HAS_SENTENCE_TRANSFORMERS: + raise ImportError( + "sentence-transformers 未安装,请安装: " + "pip install sentence-transformers" + ) + + self.config = config + self.cache_dir = Path(cache_dir) if cache_dir else None + self.enable_cache = enable_cache + self.num_workers = max(1, num_workers) + + # 模型实例 + self._model: Optional[SentenceTransformer] = None + self._model_lock = threading.Lock() + + # 缓存 + self._embedding_cache: Dict[str, np.ndarray] = {} + self._cache_lock = threading.Lock() + + # 统计 + self._total_encoded = 0 + self._cache_hits = 0 + self._cache_misses = 0 + + logger.info( + f"EmbeddingManager 初始化: model={config.model_name}, " + f"dim={config.dimension}, workers={num_workers}" + ) + + def load_model(self) -> None: + """加载模型(懒加载)""" + if self._model is not None: + return + + with self._model_lock: + # 双重检查 + if self._model is not None: + return + + logger.info(f"正在加载模型: {self.config.model_name}") + + try: + # 构建模型参数 + model_kwargs = {} + if self.config.cache_dir: + model_kwargs["cache_folder"] = self.config.cache_dir + + # 加载模型 + self._model = SentenceTransformer( + self.config.model_path, + **model_kwargs, + ) + + logger.info(f"模型加载成功: {self.config.model_name}") + + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise + + def encode( + self, + texts: Union[str, List[str]], + batch_size: Optional[int] = None, + show_progress: bool = False, + normalize: bool = True, + ) -> np.ndarray: + """ + 生成文本嵌入 + + Args: + texts: 文本或文本列表 + batch_size: 批次大小(默认使用配置值) + show_progress: 是否显示进度条 + normalize: 是否归一化 + + Returns: + 嵌入向量 (N x D) + """ + # 确保模型已加载 + self.load_model() + + # 标准化输入 + if isinstance(texts, str): + texts = [texts] + single_input = True + else: + single_input = False + + if not texts: + return np.zeros((0, self.config.dimension), dtype=np.float32) + + # 使用配置的批次大小 + if batch_size is None: + batch_size = self.config.batch_size + + # 生成嵌入 + try: + embeddings = self._model.encode( + texts, + batch_size=batch_size, + show_progress_bar=show_progress, + normalize_embeddings=normalize and self.config.normalization, + convert_to_numpy=True, + ) + + # 确保是2D数组 + if embeddings.ndim == 1: + embeddings = embeddings.reshape(1, -1) + + self._total_encoded += len(texts) + + # 如果是单个输入,返回1D数组 + if single_input: + return embeddings[0] + + return embeddings + + except Exception as e: + logger.error(f"生成嵌入失败: {e}") + raise + + def encode_batch( + self, + texts: List[str], + batch_size: Optional[int] = None, + num_workers: Optional[int] = None, + show_progress: bool = False, + ) -> np.ndarray: + """ + 批量生成嵌入(多线程优化) + + Args: + texts: 文本列表 + batch_size: 批次大小 + num_workers: 工作线程数(默认使用初始化时的值) + show_progress: 是否显示进度条 + + Returns: + 嵌入向量 (N x D) + """ + if not texts: + return np.zeros((0, self.config.dimension), dtype=np.float32) + + # 单线程模式 + num_workers = num_workers if num_workers is not None else self.num_workers + if num_workers == 1: + return self.encode(texts, batch_size=batch_size, show_progress=show_progress) + + # 多线程模式 + logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入") + + # 分批 + batch_size = batch_size or self.config.batch_size + batches = [ + texts[i:i + batch_size] + for i in range(0, len(texts), batch_size) + ] + + # 多线程生成 + all_embeddings = [] + with ThreadPoolExecutor(max_workers=num_workers) as executor: + # 提交任务 + 
future_to_batch = { + executor.submit( + self.encode, + batch, + batch_size, + False, # 不显示进度条(多线程时会混乱) + ): i + for i, batch in enumerate(batches) + } + + # 收集结果 + for future in as_completed(future_to_batch): + batch_idx = future_to_batch[future] + try: + embeddings = future.result() + all_embeddings.append((batch_idx, embeddings)) + except Exception as e: + logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}") + raise + + # 按顺序合并 + all_embeddings.sort(key=lambda x: x[0]) + final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0) + + return final_embeddings + + def encode_with_cache( + self, + texts: List[str], + batch_size: Optional[int] = None, + show_progress: bool = False, + ) -> np.ndarray: + """ + 生成嵌入(带缓存) + + Args: + texts: 文本列表 + batch_size: 批次大小 + show_progress: 是否显示进度条 + + Returns: + 嵌入向量 (N x D) + """ + if not self.enable_cache: + return self.encode(texts, batch_size, show_progress) + + # 分离缓存命中和未命中的文本 + cached_embeddings = [] + uncached_texts = [] + uncached_indices = [] + + for i, text in enumerate(texts): + cache_key = self._get_cache_key(text) + + with self._cache_lock: + if cache_key in self._embedding_cache: + cached_embeddings.append((i, self._embedding_cache[cache_key])) + self._cache_hits += 1 + else: + uncached_texts.append(text) + uncached_indices.append(i) + self._cache_misses += 1 + + # 生成未缓存的嵌入 + if uncached_texts: + new_embeddings = self.encode( + uncached_texts, + batch_size, + show_progress, + ) + + # 更新缓存 + with self._cache_lock: + for text, embedding in zip(uncached_texts, new_embeddings): + cache_key = self._get_cache_key(text) + self._embedding_cache[cache_key] = embedding.copy() + + # 合并结果 + for idx, embedding in zip(uncached_indices, new_embeddings): + cached_embeddings.append((idx, embedding)) + + # 按原始顺序排序 + cached_embeddings.sort(key=lambda x: x[0]) + final_embeddings = np.array([emb for _, emb in cached_embeddings]) + + return final_embeddings + + def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: + """ + 保存缓存到磁盘 + + Args: + cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) + """ + if cache_path is None: + if self.cache_dir is None: + raise ValueError("未指定缓存目录") + cache_path = self.cache_dir / "embeddings_cache.pkl" + + cache_path = Path(cache_path) + cache_path.parent.mkdir(parents=True, exist_ok=True) + + with self._cache_lock: + with open(cache_path, "wb") as f: + pickle.dump(self._embedding_cache, f) + + logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)") + + def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None: + """ + 从磁盘加载缓存 + + Args: + cache_path: 缓存文件路径(默认使用cache_dir/embeddings_cache.pkl) + """ + if cache_path is None: + if self.cache_dir is None: + raise ValueError("未指定缓存目录") + cache_path = self.cache_dir / "embeddings_cache.pkl" + + cache_path = Path(cache_path) + if not cache_path.exists(): + logger.warning(f"缓存文件不存在: {cache_path}") + return + + with self._cache_lock: + with open(cache_path, "rb") as f: + self._embedding_cache = pickle.load(f) + + logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)") + + def clear_cache(self) -> None: + """清空缓存""" + with self._cache_lock: + count = len(self._embedding_cache) + self._embedding_cache.clear() + logger.info(f"已清空缓存: {count} 条") + + def check_model_consistency( + self, + stored_embeddings: np.ndarray, + sample_texts: List[str] = None, + ) -> Tuple[bool, str]: + """ + 检查模型一致性 + + Args: + stored_embeddings: 存储的嵌入向量 + sample_texts: 样本文本(用于重新生成对比) + + Returns: + (是否一致, 详细信息) + """ + # 检查维度 + if 
stored_embeddings.shape[1] != self.config.dimension: + return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}" + + # 如果提供了样本文本,重新生成并比较 + if sample_texts: + try: + new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个 + + # 计算相似度 + similarities = np.dot( + stored_embeddings[:5], + new_embeddings.T, + ).diagonal() + + # 检查相似度 + if np.mean(similarities) < 0.95: + return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}" + + return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}" + + except Exception as e: + return False, f"一致性检查失败: {e}" + + return True, "维度匹配" + + def get_model_info(self) -> Dict[str, Any]: + """ + 获取模型信息 + + Returns: + 模型信息字典 + """ + return { + "model_name": self.config.model_name, + "dimension": self.config.dimension, + "max_seq_length": self.config.max_seq_length, + "batch_size": self.config.batch_size, + "normalization": self.config.normalization, + "pooling": self.config.pooling, + "model_loaded": self._model is not None, + "cache_enabled": self.enable_cache, + "cache_size": len(self._embedding_cache), + "total_encoded": self._total_encoded, + "cache_hits": self._cache_hits, + "cache_misses": self._cache_misses, + } + + def get_embedding_dimension(self) -> int: + """获取嵌入维度""" + return self.config.dimension + + def _get_cache_key(self, text: str) -> str: + """ + 生成缓存键 + + Args: + text: 文本内容 + + Returns: + 缓存键(SHA256哈希) + """ + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + @property + def is_model_loaded(self) -> bool: + """模型是否已加载""" + return self._model is not None + + @property + def cache_hit_rate(self) -> float: + """缓存命中率""" + total = self._cache_hits + self._cache_misses + if total == 0: + return 0.0 + return self._cache_hits / total + + def __repr__(self) -> str: + return ( + f"EmbeddingManager(model={self.config.model_name}, " + f"dim={self.config.dimension}, " + f"loaded={self.is_model_loaded}, " + f"cache={len(self._embedding_cache)})" + ) + + + + +def create_embedding_manager_from_config( + model_name: str, + model_path: str, + dimension: int, + cache_dir: Optional[Union[str, Path]] = None, + enable_cache: bool = True, + num_workers: int = 1, + **config_kwargs, +) -> EmbeddingManager: + """ + 从自定义配置创建嵌入管理器 + + Args: + model_name: 模型名称 + model_path: HuggingFace模型路径 + dimension: 输出维度 + cache_dir: 缓存目录 + enable_cache: 是否启用缓存 + num_workers: 工作线程数 + **config_kwargs: 其他配置参数 + + Returns: + 嵌入管理器实例 + """ + # 创建自定义配置 + config = get_custom_config( + model_name=model_name, + model_path=model_path, + dimension=dimension, + cache_dir=cache_dir, + **config_kwargs, + ) + + # 创建管理器 + return EmbeddingManager( + config=config, + cache_dir=cache_dir, + enable_cache=enable_cache, + num_workers=num_workers, + ) diff --git a/plugins/A_memorix/core/embedding/presets.py b/plugins/A_memorix/core/embedding/presets.py new file mode 100644 index 00000000..54e6f8b4 --- /dev/null +++ b/plugins/A_memorix/core/embedding/presets.py @@ -0,0 +1,72 @@ +""" +嵌入模型配置模块 +""" + +from dataclasses import dataclass +from typing import Optional, Dict, Any, Union +from pathlib import Path + + +@dataclass +class EmbeddingModelConfig: + """ + 嵌入模型配置 + + 属性: + model_name: 模型描述名称 + model_path: 实际加载路径(Local or HF) + dimension: 嵌入向量维度 + max_seq_length: 最大序列长度 + batch_size: 编码批次大小 + model_size_mb: 估计显存占用 + description: 模型说明 + normalization: 是否自动归一化 + pooling: 池化策略 (mean, cls, max) + cache_dir: 模型缓存目录 + """ + + model_name: str + model_path: str + dimension: int + max_seq_length: int = 512 + batch_size: int = 32 + model_size_mb: int = 100 + description: str = "" + 
normalization: bool = True + pooling: str = "mean" + cache_dir: Optional[Union[str, Path]] = None + + +def validate_config_compatibility( + config1: EmbeddingModelConfig, config2: EmbeddingModelConfig +) -> bool: + """检查两个配置是否兼容(主要看维度)""" + return config1.dimension == config2.dimension + + +def are_models_compatible( + config1: EmbeddingModelConfig, config2: EmbeddingModelConfig +) -> bool: + """检查模型是否完全相同(用于热切换判断)""" + return ( + config1.model_path == config2.model_path + and config1.dimension == config2.dimension + and config1.pooling == config2.pooling + ) + + +def get_custom_config( + model_name: str, + model_path: str, + dimension: int, + cache_dir: Optional[Union[str, Path]] = None, + **kwargs, +) -> EmbeddingModelConfig: + """创建自定义模型配置""" + return EmbeddingModelConfig( + model_name=model_name, + model_path=model_path, + dimension=dimension, + cache_dir=cache_dir, + **kwargs, + ) diff --git a/plugins/A_memorix/core/retrieval/__init__.py b/plugins/A_memorix/core/retrieval/__init__.py new file mode 100644 index 00000000..6efce7f6 --- /dev/null +++ b/plugins/A_memorix/core/retrieval/__init__.py @@ -0,0 +1,54 @@ +"""检索模块 - 双路检索与排序""" + +from .dual_path import ( + DualPathRetriever, + RetrievalStrategy, + RetrievalResult, + DualPathRetrieverConfig, + TemporalQueryOptions, + FusionConfig, + RelationIntentConfig, +) +from .pagerank import ( + PersonalizedPageRank, + PageRankConfig, + create_ppr_from_graph, +) +from .threshold import ( + DynamicThresholdFilter, + ThresholdMethod, + ThresholdConfig, +) +from .sparse_bm25 import ( + SparseBM25Index, + SparseBM25Config, +) +from .graph_relation_recall import ( + GraphRelationRecallConfig, + GraphRelationRecallService, +) + +__all__ = [ + # DualPathRetriever + "DualPathRetriever", + "RetrievalStrategy", + "RetrievalResult", + "DualPathRetrieverConfig", + "TemporalQueryOptions", + "FusionConfig", + "RelationIntentConfig", + # PersonalizedPageRank + "PersonalizedPageRank", + "PageRankConfig", + "create_ppr_from_graph", + # DynamicThresholdFilter + "DynamicThresholdFilter", + "ThresholdMethod", + "ThresholdConfig", + # Sparse BM25 + "SparseBM25Index", + "SparseBM25Config", + # Graph relation recall + "GraphRelationRecallConfig", + "GraphRelationRecallService", +] diff --git a/plugins/A_memorix/core/retrieval/dual_path.py b/plugins/A_memorix/core/retrieval/dual_path.py new file mode 100644 index 00000000..cfeb343c --- /dev/null +++ b/plugins/A_memorix/core/retrieval/dual_path.py @@ -0,0 +1,1796 @@ +""" +双路检索器 + +同时检索关系和段落,实现知识图谱增强的检索。 +""" + +import asyncio +import re +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any, Tuple, Union +from enum import Enum + +import numpy as np + +from src.common.logger import get_logger +from ..storage import VectorStore, GraphStore, MetadataStore +from ..embedding import EmbeddingAPIAdapter +from ..utils.matcher import AhoCorasick +from ..utils.time_parser import format_timestamp +from .graph_relation_recall import GraphRelationRecallConfig, GraphRelationRecallService +from .pagerank import PersonalizedPageRank, PageRankConfig +from .sparse_bm25 import SparseBM25Config, SparseBM25Index + +logger = get_logger("A_Memorix.DualPathRetriever") + + +class RetrievalStrategy(Enum): + """检索策略""" + + PARA_ONLY = "paragraph_only" # 仅段落检索 + REL_ONLY = "relation_only" # 仅关系检索 + DUAL_PATH = "dual_path" # 双路检索(推荐) + + +@dataclass +class RetrievalResult: + """ + 检索结果 + + 属性: + hash_value: 哈希值 + content: 内容(段落或关系) + score: 相似度分数 + result_type: 结果类型(paragraph/relation) + source: 
来源(paragraph_search/relation_search/fusion) + metadata: 额外元数据 + """ + + hash_value: str + content: str + score: float + result_type: str # "paragraph" or "relation" + source: str # "paragraph_search", "relation_search", "fusion" + metadata: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "hash": self.hash_value, + "content": self.content, + "score": self.score, + "type": self.result_type, + "source": self.source, + "metadata": self.metadata, + } + + +@dataclass +class DualPathRetrieverConfig: + """ + 双路检索器配置 + + 属性: + top_k_paragraphs: 段落检索数量 + top_k_relations: 关系检索数量 + top_k_final: 最终返回数量 + alpha: 段落和关系的融合权重(0-1) + - 0: 仅使用关系分数 + - 1: 仅使用段落分数 + - 0.5: 平均融合 + enable_ppr: 是否启用PageRank重排序 + ppr_alpha: PageRank的alpha参数 + ppr_concurrency_limit: PPR计算的最大并发数 + enable_parallel: 是否并行检索 + retrieval_strategy: 检索策略 + debug: 是否启用调试模式(打印搜索结果原文) + """ + + top_k_paragraphs: int = 20 + top_k_relations: int = 10 + top_k_final: int = 10 + alpha: float = 0.5 # 融合权重 + enable_ppr: bool = True + ppr_alpha: float = 0.85 + ppr_timeout_seconds: float = 1.5 + ppr_concurrency_limit: int = 4 + enable_parallel: bool = True + retrieval_strategy: RetrievalStrategy = RetrievalStrategy.DUAL_PATH + debug: bool = False + sparse: SparseBM25Config = field(default_factory=SparseBM25Config) + fusion: "FusionConfig" = field(default_factory=lambda: FusionConfig()) + relation_intent: "RelationIntentConfig" = field(default_factory=lambda: RelationIntentConfig()) + graph_recall: GraphRelationRecallConfig = field(default_factory=GraphRelationRecallConfig) + + def __post_init__(self): + """验证配置""" + if isinstance(self.sparse, dict): + self.sparse = SparseBM25Config(**self.sparse) + if isinstance(self.fusion, dict): + self.fusion = FusionConfig(**self.fusion) + if isinstance(self.relation_intent, dict): + self.relation_intent = RelationIntentConfig(**self.relation_intent) + if isinstance(self.graph_recall, dict): + self.graph_recall = GraphRelationRecallConfig(**self.graph_recall) + + if not 0 <= self.alpha <= 1: + raise ValueError(f"alpha必须在[0, 1]之间: {self.alpha}") + + if self.top_k_paragraphs <= 0: + raise ValueError(f"top_k_paragraphs必须大于0: {self.top_k_paragraphs}") + + if self.top_k_relations <= 0: + raise ValueError(f"top_k_relations必须大于0: {self.top_k_relations}") + + if self.top_k_final <= 0: + raise ValueError(f"top_k_final必须大于0: {self.top_k_final}") + if self.ppr_timeout_seconds <= 0: + raise ValueError(f"ppr_timeout_seconds必须大于0: {self.ppr_timeout_seconds}") + + +@dataclass +class TemporalQueryOptions: + """时序查询选项。""" + + time_from: Optional[float] = None + time_to: Optional[float] = None + person: Optional[str] = None + source: Optional[str] = None + allow_created_fallback: bool = True + candidate_multiplier: int = 8 + max_scan: int = 1000 + + +@dataclass +class RelationIntentConfig: + """关系意图增强配置。""" + + enabled: bool = True + alpha_override: float = 0.35 + relation_candidate_multiplier: int = 4 + preserve_top_relations: int = 3 + force_relation_sparse: bool = True + pair_predicate_rerank_enabled: bool = True + pair_predicate_limit: int = 3 + + def __post_init__(self): + self.alpha_override = min(1.0, max(0.0, float(self.alpha_override))) + self.relation_candidate_multiplier = max(1, int(self.relation_candidate_multiplier)) + self.preserve_top_relations = max(0, int(self.preserve_top_relations)) + self.force_relation_sparse = bool(self.force_relation_sparse) + self.pair_predicate_rerank_enabled = bool(self.pair_predicate_rerank_enabled) + self.pair_predicate_limit = max(1, 
int(self.pair_predicate_limit)) + + +@dataclass +class FusionConfig: + """融合配置。""" + + method: str = "weighted_rrf" # weighted_rrf | alpha_legacy + rrf_k: int = 60 + vector_weight: float = 0.7 + bm25_weight: float = 0.3 + normalize_score: bool = True + normalize_method: str = "minmax" + + def __post_init__(self): + self.method = str(self.method or "weighted_rrf").strip().lower() + self.normalize_method = str(self.normalize_method or "minmax").strip().lower() + self.rrf_k = max(1, int(self.rrf_k)) + self.vector_weight = max(0.0, float(self.vector_weight)) + self.bm25_weight = max(0.0, float(self.bm25_weight)) + s = self.vector_weight + self.bm25_weight + if s <= 0: + self.vector_weight = 0.7 + self.bm25_weight = 0.3 + elif abs(s - 1.0) > 1e-8: + self.vector_weight /= s + self.bm25_weight /= s + + +class DualPathRetriever: + """ + 双路检索器 + + 功能: + - 并行检索段落和关系 + - 结果融合与排序 + - PageRank重排序 + - 实体识别与加权 + + 参数: + vector_store: 向量存储 + graph_store: 图存储 + metadata_store: 元数据存储 + embedding_manager: 嵌入管理器 + config: 检索配置 + """ + + def __init__( + self, + vector_store: VectorStore, + graph_store: GraphStore, + metadata_store: MetadataStore, + embedding_manager: EmbeddingAPIAdapter, + sparse_index: Optional[SparseBM25Index] = None, + config: Optional[DualPathRetrieverConfig] = None, + ): + """ + 初始化双路检索器 + + Args: + vector_store: 向量存储 + graph_store: 图存储 + metadata_store: 元数据存储 + embedding_manager: 嵌入管理器 + config: 检索配置 + """ + self.vector_store = vector_store + self.graph_store = graph_store + self.metadata_store = metadata_store + self.embedding_manager = embedding_manager + self.config = config or DualPathRetrieverConfig() + self.sparse_index = sparse_index + + # PageRank计算器 + ppr_config = PageRankConfig(alpha=self.config.ppr_alpha) + self._ppr = PersonalizedPageRank( + graph_store=graph_store, + config=ppr_config, + ) + self._ppr_semaphore = asyncio.Semaphore(self.config.ppr_concurrency_limit) + self._graph_relation_recall = GraphRelationRecallService( + graph_store=graph_store, + metadata_store=metadata_store, + config=self.config.graph_recall, + ) + + logger.info( + f"DualPathRetriever 初始化: " + f"strategy={self.config.retrieval_strategy.value}, " + f"top_k_para={self.config.top_k_paragraphs}, " + f"top_k_rel={self.config.top_k_relations}" + ) + + # 缓存 Aho-Corasick 匹配器 + self._ac_matcher: Optional[AhoCorasick] = None + self._ac_nodes_count = 0 + self._relation_intent_pattern = re.compile( + r"(什么关系|有哪些关系|和.+关系|关联|关系网|subject|predicate|object|" + r"relation|related|between.+and)", + re.IGNORECASE, + ) + + async def retrieve( + self, + query: str, + top_k: Optional[int] = None, + strategy: Optional[RetrievalStrategy] = None, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 执行检索(异步方法) + + Args: + query: 查询文本 + top_k: 返回结果数量(默认使用配置值) + strategy: 检索策略(默认使用配置值) + temporal: 时序查询选项(可选) + + Returns: + 检索结果列表 + """ + top_k = top_k or self.config.top_k_final + strategy = strategy or self.config.retrieval_strategy + relation_intent_ctx = self._build_relation_intent_context(query=query, top_k=top_k) + + logger.info( + "执行检索: query='%s...', strategy=%s, relation_intent=%s", + query[:50], + strategy.value, + relation_intent_ctx.get("enabled", False), + ) + + if temporal and not (query or "").strip(): + return self._retrieve_temporal_only(temporal, top_k) + + # 根据策略执行检索 + if strategy == RetrievalStrategy.PARA_ONLY: + results = await self._retrieve_paragraphs_only(query, top_k, temporal=temporal) + elif strategy == RetrievalStrategy.REL_ONLY: + results = await 
self._retrieve_relations_only(query, top_k, temporal=temporal) + else: # DUAL_PATH + results = await self._retrieve_dual_path( + query, + top_k, + temporal=temporal, + relation_intent=relation_intent_ctx, + ) + + logger.info(f"检索完成: 返回 {len(results)} 条结果") + + # 调试模式:打印结果原文 + if self.config.debug: + logger.info(f"[DEBUG] 检索结果内容原文:") + for i, res in enumerate(results): + logger.info(f" {i+1}. [{res.result_type}] (Score: {res.score:.4f}) {res.content}") + + return results + + def _is_relation_intent_query(self, query: str) -> bool: + q = str(query or "").strip() + if not q: + return False + if "|" in q or "->" in q: + return True + return self._relation_intent_pattern.search(q) is not None + + def _build_relation_intent_context(self, query: str, top_k: int) -> Dict[str, Any]: + cfg = self.config.relation_intent + enabled = bool(cfg.enabled) and self._is_relation_intent_query(query) + base_relation_k = max(1, int(self.config.top_k_relations)) + relation_top_k = max(base_relation_k, int(top_k)) + if enabled: + relation_top_k = max( + relation_top_k, + relation_top_k * int(cfg.relation_candidate_multiplier), + ) + return { + "enabled": enabled, + "alpha_override": float(cfg.alpha_override) if enabled else None, + "relation_top_k": int(relation_top_k), + "preserve_top_relations": int(cfg.preserve_top_relations) if enabled else 0, + "force_relation_sparse": bool(cfg.force_relation_sparse) if enabled else False, + "pair_predicate_rerank_enabled": bool(cfg.pair_predicate_rerank_enabled) if enabled else False, + "pair_predicate_limit": int(cfg.pair_predicate_limit) if enabled else 0, + } + + def _cap_temporal_scan_k( + self, + candidate_k: int, + temporal: Optional[TemporalQueryOptions], + ) -> int: + """对 temporal 模式候选召回数应用 max_scan 上限。""" + k = max(1, int(candidate_k)) + if temporal and temporal.max_scan and temporal.max_scan > 0: + k = min(k, int(temporal.max_scan)) + return max(1, k) + + def _is_valid_embedding(self, emb: Optional[np.ndarray]) -> bool: + if emb is None: + return False + arr = np.asarray(emb, dtype=np.float32) + if arr.ndim == 0 or arr.size == 0: + return False + return bool(np.all(np.isfinite(arr))) + + def _get_embedding_dim(self, emb: Optional[np.ndarray]) -> Optional[int]: + if emb is None: + return None + arr = np.asarray(emb) + if arr.ndim == 1: + return int(arr.shape[0]) if arr.size > 0 else None + if arr.ndim == 2: + if arr.shape[0] == 0: + return None + return int(arr.shape[1]) + return None + + def _is_embedding_dimension_compatible(self, emb: Optional[np.ndarray]) -> bool: + got_dim = self._get_embedding_dim(emb) + expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) + if got_dim is None or expected_dim <= 0: + return False + return got_dim == expected_dim + + def _is_embedding_ready_for_vector_search( + self, + emb: Optional[np.ndarray], + *, + stage: str, + ) -> bool: + if not self._is_valid_embedding(emb): + return False + if self._is_embedding_dimension_compatible(emb): + return True + + expected_dim = int(getattr(self.vector_store, "dimension", 0) or 0) + got_dim = self._get_embedding_dim(emb) + logger.warning( + "metric.embedding_dim_mismatch_fallback_count=1 " + f"stage={stage} expected_dim={expected_dim} got_dim={got_dim}" + ) + return False + + def _should_use_sparse( + self, + embedding_ok: bool, + vector_results: Optional[List[RetrievalResult]] = None, + ) -> bool: + if not self.config.sparse.enabled or self.sparse_index is None: + return False + + mode = self.config.sparse.mode + if mode == "hybrid": + return True + if mode == 
"fallback_only": + return not embedding_ok + # auto + if not embedding_ok: + return True + if not vector_results: + return True + best = max((float(r.score) for r in vector_results), default=0.0) + return best < 0.45 + + def _should_use_sparse_relations( + self, + embedding_ok: bool, + relation_results: Optional[List[RetrievalResult]] = None, + force_enable: bool = False, + ) -> bool: + if force_enable and self.config.sparse.enabled and self.sparse_index is not None: + return True + if not self.config.sparse.enable_relation_sparse_fallback: + return False + return self._should_use_sparse(embedding_ok, relation_results) + + def _normalize_scores_minmax(self, results: List[RetrievalResult]) -> None: + if not results: + return + vals = [float(r.score) for r in results] + lo = min(vals) + hi = max(vals) + if hi - lo < 1e-12: + for r in results: + r.score = 1.0 + return + for r in results: + r.score = (float(r.score) - lo) / (hi - lo) + + def _build_minmax_score_map(self, results: List[RetrievalResult]) -> Dict[str, float]: + if not results: + return {} + vals = [float(r.score) for r in results] + lo = min(vals) + hi = max(vals) + if hi - lo < 1e-12: + return {r.hash_value: 1.0 for r in results} + return { + r.hash_value: (float(r.score) - lo) / (hi - lo) + for r in results + } + + @staticmethod + def _clone_retrieval_result(item: RetrievalResult) -> RetrievalResult: + return RetrievalResult( + hash_value=item.hash_value, + content=item.content, + score=float(item.score), + result_type=item.result_type, + source=item.source, + metadata=dict(item.metadata or {}), + ) + + def _extract_graph_seed_entities(self, query: str, limit: int = 2) -> List[str]: + entities = self._extract_entities(query) + if not entities: + return [] + ranked = sorted( + entities.items(), + key=lambda x: (-float(x[1]), -len(str(x[0])), str(x[0]).lower()), + ) + return [str(name) for name, _ in ranked[: max(0, int(limit))]] + + def _search_relations_graph( + self, + query: str, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + service = getattr(self, "_graph_relation_recall", None) + if service is None or not bool(getattr(self.config.graph_recall, "enabled", True)): + return [] + + seed_entities = self._extract_graph_seed_entities(query, limit=2) + if not seed_entities: + return [] + + payloads = service.recall(seed_entities=seed_entities) + results: List[RetrievalResult] = [] + for payload in payloads: + meta = payload.to_payload() + results.append( + RetrievalResult( + hash_value=str(meta["hash"]), + content=str(meta["content"]), + score=0.0, + result_type="relation", + source="graph_relation_recall", + metadata={ + "subject": meta["subject"], + "predicate": meta["predicate"], + "object": meta["object"], + "confidence": float(meta["confidence"]), + "graph_seed_entities": list(meta["graph_seed_entities"]), + "graph_hops": int(meta["graph_hops"]), + "graph_candidate_type": str(meta["graph_candidate_type"]), + "supporting_paragraph_count": int(meta["supporting_paragraph_count"]), + }, + ) + ) + return self._apply_temporal_filter_to_relations(results, temporal) + + def _fuse_ranked_lists_weighted_rrf( + self, + vector_results: List[RetrievalResult], + sparse_results: List[RetrievalResult], + ) -> List[RetrievalResult]: + """按 weighted RRF 融合两路段落召回。""" + if not vector_results: + out = sparse_results[:] + if self.config.fusion.normalize_score: + self._normalize_scores_minmax(out) + return out + if not sparse_results: + out = vector_results[:] + if self.config.fusion.normalize_score: + 
self._normalize_scores_minmax(out) + return out + + k = self.config.fusion.rrf_k + w_vec = self.config.fusion.vector_weight + w_sparse = self.config.fusion.bm25_weight + merged: Dict[str, RetrievalResult] = {} + score_map: Dict[str, float] = {} + + for rank, item in enumerate(vector_results, start=1): + h = item.hash_value + if h not in merged: + merged[h] = item + merged[h].source = "fusion_rrf" + score_map[h] = score_map.get(h, 0.0) + w_vec * (1.0 / (k + rank)) + + for rank, item in enumerate(sparse_results, start=1): + h = item.hash_value + if h not in merged: + merged[h] = item + merged[h].source = "fusion_rrf" + score_map[h] = score_map.get(h, 0.0) + w_sparse * (1.0 / (k + rank)) + + out = list(merged.values()) + for item in out: + item.score = float(score_map.get(item.hash_value, 0.0)) + + out.sort(key=lambda x: x.score, reverse=True) + if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": + self._normalize_scores_minmax(out) + return out + + def _search_paragraphs_sparse( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """BM25 段落召回。""" + if not self.sparse_index or not self.config.sparse.enabled: + return [] + + candidate_k = max(top_k, self.config.sparse.candidate_k) + candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) + sparse_rows = self.sparse_index.search(query=query, k=candidate_k) + results: List[RetrievalResult] = [] + for row in sparse_rows: + hash_value = row["hash"] + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + continue + time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) + results.append( + RetrievalResult( + hash_value=hash_value, + content=paragraph["content"], + score=float(row.get("score", 0.0)), + result_type="paragraph", + source="sparse_bm25", + metadata={ + "word_count": paragraph.get("word_count", 0), + "time_meta": time_meta, + "bm25_score": float(row.get("bm25_score", 0.0)), + }, + ) + ) + results = self._apply_temporal_filter_to_paragraphs(results, temporal) + if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": + self._normalize_scores_minmax(results) + return results + + def _search_relations_sparse( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """关系 BM25 召回。""" + if not self.sparse_index or not self.config.sparse.enabled: + return [] + if not self.config.sparse.enable_relation_sparse_fallback: + return [] + + candidate_k = max(top_k, self.config.sparse.relation_candidate_k) + candidate_k = self._cap_temporal_scan_k(candidate_k, temporal) + rows = self.sparse_index.search_relations(query=query, k=candidate_k) + results: List[RetrievalResult] = [] + for row in rows: + hash_value = row["hash"] + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + continue + + relation_time_meta = None + if temporal: + relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) + if relation_time_meta is None: + continue + + content = f"{relation['subject']} {relation['predicate']} {relation['object']}" + results.append( + RetrievalResult( + hash_value=hash_value, + content=content, + score=float(row.get("score", 0.0)), + result_type="relation", + source="sparse_relation_bm25", + metadata={ + "subject": relation["subject"], + "predicate": relation["predicate"], + "object": relation["object"], + "confidence": relation.get("confidence", 
1.0), + "time_meta": relation_time_meta, + "bm25_score": float(row.get("bm25_score", 0.0)), + }, + ) + ) + + if self.config.fusion.normalize_score and self.config.fusion.normalize_method == "minmax": + self._normalize_scores_minmax(results) + return self._apply_temporal_filter_to_relations(results, temporal) + + def _merge_relation_results( + self, + vector_results: List[RetrievalResult], + sparse_results: List[RetrievalResult], + ) -> List[RetrievalResult]: + """合并关系候选,按 hash 去重并保留更高分。""" + merged: Dict[str, RetrievalResult] = {} + for item in vector_results: + merged[item.hash_value] = item + for item in sparse_results: + old = merged.get(item.hash_value) + if old is None or float(item.score) > float(old.score): + merged[item.hash_value] = item + elif old is not None and old.source != item.source: + old.source = "relation_fusion" + out = list(merged.values()) + out.sort(key=lambda x: x.score, reverse=True) + return out + + def _merge_relation_results_graph_enhanced( + self, + vector_results: List[RetrievalResult], + sparse_results: List[RetrievalResult], + graph_results: List[RetrievalResult], + ) -> List[RetrievalResult]: + """Graph-aware relation fusion with semantic + graph + evidence scoring.""" + vector_norm = self._build_minmax_score_map(vector_results) + sparse_norm = self._build_minmax_score_map(sparse_results) + graph_score_map = { + "direct_pair": 1.0, + "one_hop_seed": 0.75, + "two_hop_pair": 0.55, + } + + merged: Dict[str, RetrievalResult] = {} + source_sets: Dict[str, set[str]] = {} + support_cache: Dict[str, int] = {} + + for group in (vector_results, sparse_results, graph_results): + for item in group: + existing = merged.get(item.hash_value) + if existing is None: + existing = self._clone_retrieval_result(item) + merged[item.hash_value] = existing + else: + for key, value in dict(item.metadata or {}).items(): + if key not in existing.metadata or existing.metadata.get(key) in (None, "", []): + existing.metadata[key] = value + source_sets.setdefault(item.hash_value, set()).add(str(item.source or "").strip() or "relation_search") + + out = list(merged.values()) + for item in out: + meta = item.metadata if isinstance(item.metadata, dict) else {} + semantic_norm = max( + float(vector_norm.get(item.hash_value, 0.0)), + float(sparse_norm.get(item.hash_value, 0.0)), + ) + graph_candidate_type = str(meta.get("graph_candidate_type", "") or "") + graph_score = float(graph_score_map.get(graph_candidate_type, 0.0)) + + if item.hash_value not in support_cache: + cached = meta.get("supporting_paragraph_count") + if cached is None: + support_cache[item.hash_value] = len( + self.metadata_store.get_paragraphs_by_relation(item.hash_value) + ) + else: + support_cache[item.hash_value] = max(0, int(cached)) + supporting_paragraph_count = support_cache[item.hash_value] + evidence_score = min(1.0, supporting_paragraph_count / 3.0) + + meta["supporting_paragraph_count"] = supporting_paragraph_count + meta["graph_seed_entities"] = list(meta.get("graph_seed_entities") or []) + if "graph_hops" in meta: + meta["graph_hops"] = int(meta.get("graph_hops") or 0) + item.score = 0.60 * semantic_norm + 0.30 * graph_score + 0.10 * evidence_score + + sources = source_sets.get(item.hash_value, set()) + if len(sources) > 1: + item.source = "relation_fusion" + elif sources: + item.source = next(iter(sources)) + + out.sort(key=lambda x: x.score, reverse=True) + return out + + async def _retrieve_paragraphs_only( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> 
List[RetrievalResult]: + """ + 仅检索段落(异步方法) + + Args: + query: 查询文本 + top_k: 返回数量 + + Returns: + 检索结果列表 + """ + query_emb = None + embedding_ok = False + vector_results: List[RetrievalResult] = [] + + try: + query_emb = await self.embedding_manager.encode(query) + embedding_ok = self._is_embedding_ready_for_vector_search( + query_emb, + stage="paragraph_only", + ) + except Exception as e: + logger.warning(f"段落检索 embedding 生成失败,将尝试 sparse 回退: {e}") + + if embedding_ok: + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * 2 * multiplier, temporal) + para_ids, para_scores = self.vector_store.search( + query_emb, # type: ignore[arg-type] + k=candidate_k, + ) + + for hash_value, score in zip(para_ids, para_scores): + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + continue + time_meta = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) + vector_results.append( + RetrievalResult( + hash_value=hash_value, + content=paragraph["content"], + score=float(score), + result_type="paragraph", + source="paragraph_search", + metadata={ + "word_count": paragraph.get("word_count", 0), + "time_meta": time_meta, + }, + ) + ) + vector_results = self._apply_temporal_filter_to_paragraphs(vector_results, temporal) + + sparse_results: List[RetrievalResult] = [] + if self._should_use_sparse(embedding_ok, vector_results): + sparse_results = self._search_paragraphs_sparse(query, top_k, temporal=temporal) + + if self.config.fusion.method == "weighted_rrf" and (vector_results and sparse_results): + results = self._fuse_ranked_lists_weighted_rrf(vector_results, sparse_results) + elif vector_results and sparse_results: + results = vector_results + sparse_results + results.sort(key=lambda x: x.score, reverse=True) + else: + results = vector_results if vector_results else sparse_results + + return results[:top_k] + + async def _retrieve_relations_only( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 仅检索关系 (通过实体枢纽 Entity-Pivot) + + 策略: + 1. 检索向量库中的 Top-K 实体 (Entity) + 2. 通过图结构/元数据扩展出与实体关联的关系 (Relation) + 3. 以实体相似度作为基础分返回关系 + + Args: + query: 查询文本 + top_k: 返回数量 + + Returns: + 检索结果列表 + """ + query_emb = None + embedding_ok = False + vector_results: List[RetrievalResult] = [] + try: + query_emb = await self.embedding_manager.encode(query) + embedding_ok = self._is_embedding_ready_for_vector_search( + query_emb, + stage="relation_only", + ) + except Exception as e: + logger.warning(f"关系检索 embedding 生成失败,将尝试 sparse 回退: {e}") + + if embedding_ok: + # 1. 
检索向量 (混合了段落和实体,所以扩大检索范围以召回足够多实体) + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * 3 * multiplier, temporal) + ids, scores = self.vector_store.search( + query_emb, # type: ignore[arg-type] + k=candidate_k, + ) + + seen_relations = set() + for hash_value, score in zip(ids, scores): + entity = self.metadata_store.get_entity(hash_value) + if not entity: + continue + entity_name = entity["name"] + + related_rels = [] + related_rels.extend(self.metadata_store.get_relations(subject=entity_name)) + related_rels.extend(self.metadata_store.get_relations(object=entity_name)) + + for rel in related_rels: + if rel["hash"] in seen_relations: + continue + seen_relations.add(rel["hash"]) + + relation_time_meta = None + if temporal: + relation_time_meta = self._best_supporting_time_meta(rel["hash"], temporal) + if relation_time_meta is None: + continue + + content = f"{rel['subject']} {rel['predicate']} {rel['object']}" + vector_results.append( + RetrievalResult( + hash_value=rel["hash"], + content=content, + score=float(score), + result_type="relation", + source="relation_search (via entity)", + metadata={ + "subject": rel["subject"], + "predicate": rel["predicate"], + "object": rel["object"], + "confidence": rel.get("confidence", 1.0), + "pivot_entity": entity_name, + "time_meta": relation_time_meta, + }, + ) + ) + + vector_results = self._apply_temporal_filter_to_relations(vector_results, temporal) + + sparse_results: List[RetrievalResult] = [] + if self._should_use_sparse_relations(embedding_ok, vector_results): + sparse_results = self._search_relations_sparse(query=query, top_k=top_k, temporal=temporal) + + graph_results = self._search_relations_graph(query=query, temporal=temporal) + if graph_results: + results = self._merge_relation_results_graph_enhanced( + vector_results, + sparse_results, + graph_results, + ) + elif vector_results and sparse_results: + results = self._merge_relation_results(vector_results, sparse_results) + else: + results = vector_results if vector_results else sparse_results + + return results[:top_k] + + async def _retrieve_dual_path( + self, + query: str, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + relation_intent: Optional[Dict[str, Any]] = None, + ) -> List[RetrievalResult]: + """ + 双路检索(段落+关系)(异步方法) + + Args: + query: 查询文本 + top_k: 返回数量 + + Returns: + 融合后的检索结果列表 + """ + query_emb = None + embedding_ok = False + relation_intent = relation_intent or {} + relation_top_k = max( + 1, + int(relation_intent.get("relation_top_k", self.config.top_k_relations)), + ) + force_relation_sparse = bool(relation_intent.get("force_relation_sparse", False)) + preserve_top_relations = max( + 0, + int(relation_intent.get("preserve_top_relations", 0)), + ) + pair_predicate_rerank_enabled = bool( + relation_intent.get("pair_predicate_rerank_enabled", False) + ) + pair_predicate_limit = max( + 1, + int( + relation_intent.get( + "pair_predicate_limit", + self.config.relation_intent.pair_predicate_limit, + ) + ), + ) + alpha_override = relation_intent.get("alpha_override") + try: + query_emb = await self.embedding_manager.encode(query) + embedding_ok = self._is_embedding_ready_for_vector_search( + query_emb, + stage="dual_path", + ) + except Exception as e: + logger.warning(f"双路检索 embedding 生成失败,将尝试 sparse 回退: {e}") + + para_results: List[RetrievalResult] = [] + rel_results: List[RetrievalResult] = [] + if embedding_ok: + # 并行检索(使用 asyncio) + if self.config.enable_parallel: + para_results, rel_results = 
await self._parallel_retrieve( + query_emb, + temporal=temporal, + relation_top_k=relation_top_k, + ) # type: ignore[arg-type] + else: + para_results, rel_results = self._sequential_retrieve( + query_emb, + temporal=temporal, + relation_top_k=relation_top_k, + ) # type: ignore[arg-type] + else: + logger.warning("embedding 不可用,跳过向量段落/关系召回") + + sparse_para_results: List[RetrievalResult] = [] + if self._should_use_sparse(embedding_ok, para_results): + sparse_para_results = self._search_paragraphs_sparse( + query=query, + top_k=max(top_k * 2, self.config.sparse.candidate_k), + temporal=temporal, + ) + sparse_rel_results: List[RetrievalResult] = [] + if self._should_use_sparse_relations( + embedding_ok, + rel_results, + force_enable=force_relation_sparse, + ): + sparse_rel_results = self._search_relations_sparse( + query=query, + top_k=max( + top_k, + self.config.sparse.relation_candidate_k, + relation_top_k, + ), + temporal=temporal, + ) + + graph_rel_results: List[RetrievalResult] = [] + if bool(relation_intent.get("enabled", False)): + graph_rel_results = self._search_relations_graph(query=query, temporal=temporal) + + if self.config.fusion.method == "weighted_rrf" and para_results and sparse_para_results: + para_results = self._fuse_ranked_lists_weighted_rrf(para_results, sparse_para_results) + elif para_results and sparse_para_results: + para_results = para_results + sparse_para_results + para_results.sort(key=lambda x: x.score, reverse=True) + elif sparse_para_results and (not para_results or not embedding_ok): + para_results = sparse_para_results + + if graph_rel_results: + rel_results = self._merge_relation_results_graph_enhanced( + rel_results, + sparse_rel_results, + graph_rel_results, + ) + elif rel_results and sparse_rel_results: + rel_results = self._merge_relation_results(rel_results, sparse_rel_results) + elif sparse_rel_results and (not rel_results or not embedding_ok): + rel_results = sparse_rel_results + + # 融合结果 + fused_results = self._fuse_results( + para_results, + rel_results, + query_emb, + alpha_override=alpha_override, + preserve_top_relations=preserve_top_relations, + ) + + # PageRank重排序 + if self.config.enable_ppr: + fused_results = await self._rerank_with_ppr( + fused_results, + query, + ) + + if temporal: + fused_results = self._sort_results_with_temporal(fused_results, temporal) + + fused_results = self._apply_relation_intent_pair_rerank( + fused_results, + enabled=bool(relation_intent.get("enabled", False)), + pair_rerank_enabled=pair_predicate_rerank_enabled, + pair_limit=pair_predicate_limit, + ) + + return fused_results[:top_k] + + async def _parallel_retrieve( + self, + query_emb: np.ndarray, + temporal: Optional[TemporalQueryOptions] = None, + relation_top_k: Optional[int] = None, + ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: + """ + 并行检索段落和关系(异步方法) + + Args: + query_emb: 查询嵌入 + + Returns: + (段落结果, 关系结果) + """ + # 使用 asyncio.gather 并发执行两个搜索任务 + # 由于 _search_paragraphs 和 _search_relations 是 CPU 密集型同步函数, + # 使用 asyncio.to_thread 在线程池中执行 + try: + para_task = asyncio.to_thread( + self._search_paragraphs, + query_emb, + self.config.top_k_paragraphs, + temporal, + ) + rel_task = asyncio.to_thread( + self._search_relations, + query_emb, + relation_top_k if relation_top_k is not None else self.config.top_k_relations, + temporal, + ) + + para_results, rel_results = await asyncio.gather( + para_task, rel_task, return_exceptions=True + ) + + # 处理异常 + if isinstance(para_results, Exception): + logger.error(f"段落检索失败: {para_results}") + para_results = [] + if 
isinstance(rel_results, Exception): + logger.error(f"关系检索失败: {rel_results}") + rel_results = [] + + return para_results, rel_results + + except Exception as e: + logger.error(f"并行检索失败: {e}") + return [], [] + + def _sequential_retrieve( + self, + query_emb: np.ndarray, + temporal: Optional[TemporalQueryOptions] = None, + relation_top_k: Optional[int] = None, + ) -> Tuple[List[RetrievalResult], List[RetrievalResult]]: + """ + 顺序检索段落和关系 + + Args: + query_emb: 查询嵌入 + + Returns: + (段落结果, 关系结果) + """ + para_results = self._search_paragraphs( + query_emb, + self.config.top_k_paragraphs, + temporal, + ) + + rel_results = self._search_relations( + query_emb, + relation_top_k if relation_top_k is not None else self.config.top_k_relations, + temporal, + ) + + return para_results, rel_results + + def _search_paragraphs( + self, + query_emb: np.ndarray, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 搜索段落 + + Args: + query_emb: 查询嵌入 + top_k: 返回数量 + + Returns: + 段落结果列表 + """ + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) + para_ids, para_scores = self.vector_store.search(query_emb, k=candidate_k) + + results = [] + for hash_value, score in zip(para_ids, para_scores): + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is None: + continue + + time_meta = self._build_time_meta_from_paragraph( + paragraph, + temporal=temporal, + ) + results.append(RetrievalResult( + hash_value=hash_value, + content=paragraph["content"], + score=float(score), + result_type="paragraph", + source="paragraph_search", + metadata={ + "word_count": paragraph.get("word_count", 0), + "time_meta": time_meta, + }, + )) + + return self._apply_temporal_filter_to_paragraphs(results, temporal) + + def _search_relations( + self, + query_emb: np.ndarray, + top_k: int, + temporal: Optional[TemporalQueryOptions] = None, + ) -> List[RetrievalResult]: + """ + 搜索关系 + + Args: + query_emb: 查询嵌入 + top_k: 返回数量 + + Returns: + 关系结果列表 + """ + multiplier = max(1, temporal.candidate_multiplier) if temporal else 1 + candidate_k = self._cap_temporal_scan_k(top_k * multiplier, temporal) + rel_ids, rel_scores = self.vector_store.search(query_emb, k=candidate_k) + + results = [] + for hash_value, score in zip(rel_ids, rel_scores): + relation = self.metadata_store.get_relation(hash_value) + if relation is None: + continue + + relation_time_meta = None + if temporal: + relation_time_meta = self._best_supporting_time_meta(hash_value, temporal) + if relation_time_meta is None: + continue + + content = f"{relation['subject']} {relation['predicate']} {relation['object']}" + + results.append(RetrievalResult( + hash_value=hash_value, + content=content, + score=float(score), + result_type="relation", + source="relation_search", + metadata={ + "subject": relation["subject"], + "predicate": relation["predicate"], + "object": relation["object"], + "confidence": relation.get("confidence", 1.0), + "time_meta": relation_time_meta, + }, + )) + + return self._apply_temporal_filter_to_relations(results, temporal) + + def _fuse_results( + self, + para_results: List[RetrievalResult], + rel_results: List[RetrievalResult], + query_emb: Optional[np.ndarray] = None, + alpha_override: Optional[float] = None, + preserve_top_relations: int = 0, + ) -> List[RetrievalResult]: + """ + 融合段落和关系结果 + + 融合策略: + 1. 计算加权分数 + 2. 去重(基于段落和关系的关联) + 3. 
排序 + + Args: + para_results: 段落结果 + rel_results: 关系结果 + query_emb: 查询嵌入(兼容参数,当前未使用) + + Returns: + 融合后的结果列表 + """ + del query_emb # 参数保留用于兼容 + alpha = float(alpha_override) if alpha_override is not None else self.config.alpha + + # 为段落结果计算加权分数 + for result in para_results: + result.score = result.score * alpha + result.source = "fusion" + + # 为关系结果计算加权分数 + for result in rel_results: + result.score = result.score * (1 - alpha) + result.source = "fusion" + + preserve_top_relations = max(0, int(preserve_top_relations)) + preserved_relation_hashes = set() + if preserve_top_relations > 0 and rel_results: + rel_ranked = sorted(rel_results, key=lambda x: x.score, reverse=True) + preserved_relation_hashes = { + item.hash_value for item in rel_ranked[:preserve_top_relations] + } + + # 合并结果 + all_results = para_results + rel_results + all_results.sort(key=lambda x: x.score, reverse=True) + + # 去重:如果段落有关联的关系,只保留分数更高的 + seen_paragraphs = set() + seen_items = set() + deduplicated_results = [] + + for result in all_results: + if result.hash_value in seen_items: + continue + if result.result_type == "paragraph": + hash_val = result.hash_value + if hash_val not in seen_paragraphs: + seen_paragraphs.add(hash_val) + seen_items.add(hash_val) + deduplicated_results.append(result) + else: # relation + if result.hash_value in preserved_relation_hashes: + seen_items.add(result.hash_value) + deduplicated_results.append(result) + continue + # 检查关系关联的段落是否已存在 + relation = self.metadata_store.get_relation(result.hash_value) + if relation: + # 获取关联的段落 + para_rels = self.metadata_store.query(""" + SELECT paragraph_hash FROM paragraph_relations + WHERE relation_hash = ? + """, (result.hash_value,)) + + if para_rels: + # 检查段落是否已在结果中 + for para_rel in para_rels: + if para_rel["paragraph_hash"] in seen_paragraphs: + # 段落已存在,跳过此关系 + break + else: + # 所有段落都不存在,添加关系 + seen_items.add(result.hash_value) + deduplicated_results.append(result) + else: + # 没有关联段落,直接添加 + seen_items.add(result.hash_value) + deduplicated_results.append(result) + else: + seen_items.add(result.hash_value) + deduplicated_results.append(result) + + # 按分数排序 + deduplicated_results.sort(key=lambda x: x.score, reverse=True) + + return deduplicated_results + + def _apply_relation_intent_pair_rerank( + self, + results: List[RetrievalResult], + *, + enabled: bool, + pair_rerank_enabled: bool, + pair_limit: int, + ) -> List[RetrievalResult]: + """仅在 relation-intent 下对关系项执行同主客体多谓词重排。""" + if not enabled or not pair_rerank_enabled: + return results + return self._rerank_relation_items_by_pair(results, pair_limit=pair_limit) + + def _rerank_relation_items_by_pair( + self, + results: List[RetrievalResult], + pair_limit: int, + ) -> List[RetrievalResult]: + """ + 同主客体多谓词重排: + 1. 关系项按 (subject, object) 分组 + 2. 组内按分数降序 + 原始位置升序 + 3. 组间按组最高分降序 + 组最早位置升序 + 4. 先拼接每组前 N 条,再拼接每组 overflow 条目 + 5. 
回填到原关系槽位,段落槽位不变 + """ + if len(results) <= 1: + return results + + relation_positions: List[int] = [] + relation_items: List[Tuple[int, RetrievalResult]] = [] + for idx, item in enumerate(results): + if item.result_type == "relation": + relation_positions.append(idx) + relation_items.append((idx, item)) + + if len(relation_items) <= 1: + return results + + pair_limit = max(1, int(pair_limit)) + + grouped: Dict[Tuple[str, str], List[Tuple[int, RetrievalResult]]] = {} + for original_idx, item in relation_items: + metadata = item.metadata if isinstance(item.metadata, dict) else {} + subject = str(metadata.get("subject", "")).strip().lower() + obj = str(metadata.get("object", "")).strip().lower() + if subject and obj: + key = (subject, obj) + else: + key = ("__missing__", item.hash_value) + grouped.setdefault(key, []).append((original_idx, item)) + + for grouped_items in grouped.values(): + grouped_items.sort(key=lambda x: (-float(x[1].score), x[0])) + + ordered_groups = sorted( + grouped.values(), + key=lambda grouped_items: ( + -float(grouped_items[0][1].score), + grouped_items[0][0], + ), + ) + + prioritized: List[RetrievalResult] = [] + overflow: List[RetrievalResult] = [] + for grouped_items in ordered_groups: + prioritized.extend([item for _, item in grouped_items[:pair_limit]]) + overflow.extend([item for _, item in grouped_items[pair_limit:]]) + + reordered_relations = prioritized + overflow + if len(reordered_relations) != len(relation_items): + return results + + logger.debug( + "relation_rerank_applied=1 relation_pair_groups=%s relation_pair_overflow_count=%s relation_pair_limit=%s", + len(ordered_groups), + len(overflow), + pair_limit, + ) + + rebuilt = list(results) + for slot_idx, relation_item in zip(relation_positions, reordered_relations): + rebuilt[slot_idx] = relation_item + return rebuilt + + async def _rerank_with_ppr( + self, + results: List[RetrievalResult], + query: str, + ) -> List[RetrievalResult]: + """ + 使用PageRank重排序结果 (异步 + 线程池) + + Args: + results: 检索结果 + query: 查询文本 + + Returns: + 重排序后的结果 + """ + # 从查询中提取实体 + entities = self._extract_entities(query) + + if not entities: + logger.debug("未识别到实体,跳过PPR重排序") + return results + + # 计算PPR分数 (放入线程池运行,避免阻塞主循环) + ppr_timeout_s = max(0.1, float(getattr(self.config, "ppr_timeout_seconds", 1.5) or 1.5)) + try: + async with self._ppr_semaphore: + ppr_scores = await asyncio.wait_for( + asyncio.to_thread( + self._ppr.compute, + personalization=entities, + normalize=True, + ), + timeout=ppr_timeout_s, + ) + except asyncio.TimeoutError: + logger.warning( + "metric.ppr_timeout_skip_count=1 timeout_s=%s entities=%s", + ppr_timeout_s, + len(entities), + ) + return results + except Exception as e: + logger.warning(f"PPR 重排序失败,回退原排序: {e}") + return results + + # 调整结果分数 + ppr_scores_by_name = { + str(name).strip().lower(): float(score) + for name, score in ppr_scores.items() + } + for result in results: + if result.result_type == "paragraph": + # 获取段落的实体 + para_entities = self.metadata_store.get_paragraph_entities( + result.hash_value + ) + + # 计算实体的平均PPR分数 + if para_entities: + entity_scores = [] + for ent in para_entities: + ent_name = str(ent.get("name", "")).strip().lower() + if ent_name in ppr_scores_by_name: + entity_scores.append(ppr_scores_by_name[ent_name]) + + if entity_scores: + avg_ppr = np.mean(entity_scores) + # 融合原始分数和PPR分数 + result.score = result.score * 0.7 + avg_ppr * 0.3 + + # 重新排序 + results.sort(key=lambda x: x.score, reverse=True) + + return results + + def _retrieve_temporal_only( + self, + temporal: 
TemporalQueryOptions, + top_k: int, + ) -> List[RetrievalResult]: + """无语义 query 时,直接走时序索引查询。""" + limit = self._cap_temporal_scan_k( + top_k * max(1, temporal.candidate_multiplier), + temporal, + ) + paragraphs = self.metadata_store.query_paragraphs_temporal( + start_ts=temporal.time_from, + end_ts=temporal.time_to, + person=temporal.person, + source=temporal.source, + limit=limit, + allow_created_fallback=temporal.allow_created_fallback, + ) + results: List[RetrievalResult] = [] + for para in paragraphs: + time_meta = self._build_time_meta_from_paragraph(para, temporal=temporal) + results.append( + RetrievalResult( + hash_value=para["hash"], + content=para["content"], + score=1.0, + result_type="paragraph", + source="temporal_scan", + metadata={ + "word_count": para.get("word_count", 0), + "time_meta": time_meta, + }, + ) + ) + + results = self._sort_results_with_temporal(results, temporal) + return results[:top_k] + + def _extract_effective_time( + self, + paragraph: Dict[str, Any], + temporal: Optional[TemporalQueryOptions] = None, + ) -> Tuple[Optional[float], Optional[float], Optional[str]]: + """提取段落有效时间区间与命中依据。""" + event_time = paragraph.get("event_time") + event_start = paragraph.get("event_time_start") + event_end = paragraph.get("event_time_end") + + if event_start is not None or event_end is not None: + effective_start = event_start if event_start is not None else ( + event_time if event_time is not None else event_end + ) + effective_end = event_end if event_end is not None else ( + event_time if event_time is not None else event_start + ) + return effective_start, effective_end, "event_time_range" + + if event_time is not None: + return event_time, event_time, "event_time" + + allow_fallback = True + if temporal is not None: + allow_fallback = temporal.allow_created_fallback + + created_at = paragraph.get("created_at") + if allow_fallback and created_at is not None: + return created_at, created_at, "created_at_fallback" + + return None, None, None + + def _build_time_meta_from_paragraph( + self, + paragraph: Dict[str, Any], + temporal: Optional[TemporalQueryOptions] = None, + ) -> Dict[str, Any]: + """构建统一 time_meta 结构。""" + effective_start, effective_end, match_basis = self._extract_effective_time( + paragraph, + temporal=temporal, + ) + return { + "event_time": paragraph.get("event_time"), + "event_time_start": paragraph.get("event_time_start"), + "event_time_end": paragraph.get("event_time_end"), + "ingest_time": paragraph.get("created_at"), + "time_granularity": paragraph.get("time_granularity"), + "time_confidence": paragraph.get("time_confidence", 1.0), + "effective_start": effective_start, + "effective_end": effective_end, + "effective_start_text": format_timestamp(effective_start), + "effective_end_text": format_timestamp(effective_end), + "match_basis": match_basis or "none", + } + + def _matches_person_filter(self, paragraph_hash: str, person: Optional[str]) -> bool: + if not person: + return True + target = person.strip().lower() + if not target: + return True + para_entities = self.metadata_store.get_paragraph_entities(paragraph_hash) + for ent in para_entities: + name = str(ent.get("name", "")).strip().lower() + if target in name: + return True + return False + + def _is_temporal_match( + self, + paragraph: Dict[str, Any], + temporal: TemporalQueryOptions, + ) -> bool: + """判断段落是否命中时序筛选。""" + if temporal.source and paragraph.get("source") != temporal.source: + return False + + if not self._matches_person_filter(paragraph.get("hash", ""), temporal.person): + 
return False + + effective_start, effective_end, _ = self._extract_effective_time(paragraph, temporal=temporal) + if effective_start is None or effective_end is None: + return False + + if temporal.time_from is not None and temporal.time_to is not None: + return effective_end >= temporal.time_from and effective_start <= temporal.time_to + if temporal.time_from is not None: + return effective_end >= temporal.time_from + if temporal.time_to is not None: + return effective_start <= temporal.time_to + return True + + def _apply_temporal_filter_to_paragraphs( + self, + results: List[RetrievalResult], + temporal: Optional[TemporalQueryOptions], + ) -> List[RetrievalResult]: + if not temporal: + return results + + filtered: List[RetrievalResult] = [] + for result in results: + paragraph = self.metadata_store.get_paragraph(result.hash_value) + if not paragraph: + continue + if not self._is_temporal_match(paragraph, temporal): + continue + result.metadata["time_meta"] = self._build_time_meta_from_paragraph(paragraph, temporal=temporal) + filtered.append(result) + + return self._sort_results_with_temporal(filtered, temporal) + + def _best_supporting_time_meta( + self, + relation_hash: str, + temporal: TemporalQueryOptions, + ) -> Optional[Dict[str, Any]]: + """获取关系在时序窗口内最优支撑段落的 time_meta。""" + supports = self.metadata_store.get_paragraphs_by_relation(relation_hash) + if not supports: + return None + + best_meta: Optional[Dict[str, Any]] = None + best_time = float("-inf") + for para in supports: + if not self._is_temporal_match(para, temporal): + continue + meta = self._build_time_meta_from_paragraph(para, temporal=temporal) + eff = meta.get("effective_end") + score = float(eff) if eff is not None else float("-inf") + if score >= best_time: + best_time = score + best_meta = meta + + return best_meta + + def _apply_temporal_filter_to_relations( + self, + results: List[RetrievalResult], + temporal: Optional[TemporalQueryOptions], + ) -> List[RetrievalResult]: + if not temporal: + return results + + filtered: List[RetrievalResult] = [] + for result in results: + meta = result.metadata.get("time_meta") + if meta is None: + meta = self._best_supporting_time_meta(result.hash_value, temporal) + if meta is None: + continue + result.metadata["time_meta"] = meta + filtered.append(result) + + return self._sort_results_with_temporal(filtered, temporal) + + def _sort_results_with_temporal( + self, + results: List[RetrievalResult], + temporal: TemporalQueryOptions, + ) -> List[RetrievalResult]: + """语义优先,时间次排序(新到旧)。""" + del temporal # temporal 保留给未来扩展,目前只使用结果内 time_meta + + def _temporal_key(item: RetrievalResult) -> float: + time_meta = item.metadata.get("time_meta", {}) + effective = time_meta.get("effective_end") + if effective is None: + effective = time_meta.get("effective_start") + if effective is None: + return float("-inf") + return float(effective) + + results.sort(key=lambda x: (x.score, _temporal_key(x)), reverse=True) + return results + + def _extract_entities(self, text: str) -> Dict[str, float]: + """ + 从文本中提取实体(简化版本) + + Args: + text: 输入文本 + + Returns: + 实体字典 {实体名: 权重} + """ + # 获取所有实体 + all_entities = self.graph_store.get_nodes() + if not all_entities: + return {} + + # 检查是否需要更新 Aho-Corasick 匹配器 + if self._ac_matcher is None or self._ac_nodes_count != len(all_entities): + self._ac_matcher = AhoCorasick() + for entity in all_entities: + self._ac_matcher.add_pattern(entity.lower()) + self._ac_matcher.build() + self._ac_nodes_count = len(all_entities) + + # 执行匹配 + text_lower = text.lower() + stats = 
self._ac_matcher.find_all(text_lower) + + # 映射回原始名称并使用出现次数作为权重 + node_map = {node.lower(): node for node in all_entities} + entities = {node_map[low_name]: float(count) for low_name, count in stats.items()} + + return entities + + def get_statistics(self) -> Dict[str, Any]: + """ + 获取检索统计信息 + + Returns: + 统计信息字典 + """ + vector_size = getattr(self.vector_store, "size", None) + if vector_size is None: + vector_size = getattr(self.vector_store, "num_vectors", 0) + + return { + "config": { + "top_k_paragraphs": self.config.top_k_paragraphs, + "top_k_relations": self.config.top_k_relations, + "top_k_final": self.config.top_k_final, + "alpha": self.config.alpha, + "enable_ppr": self.config.enable_ppr, + "enable_parallel": self.config.enable_parallel, + "strategy": self.config.retrieval_strategy.value, + "sparse_mode": self.config.sparse.mode, + "fusion_method": self.config.fusion.method, + "relation_intent_enabled": self.config.relation_intent.enabled, + "relation_intent_alpha_override": self.config.relation_intent.alpha_override, + "relation_intent_candidate_multiplier": self.config.relation_intent.relation_candidate_multiplier, + "relation_intent_preserve_top_relations": self.config.relation_intent.preserve_top_relations, + "relation_intent_force_sparse": self.config.relation_intent.force_relation_sparse, + "relation_intent_pair_rerank_enabled": self.config.relation_intent.pair_predicate_rerank_enabled, + "relation_intent_pair_predicate_limit": self.config.relation_intent.pair_predicate_limit, + "graph_recall_enabled": self.config.graph_recall.enabled, + "graph_recall_candidate_k": self.config.graph_recall.candidate_k, + "graph_recall_allow_two_hop_pair": self.config.graph_recall.allow_two_hop_pair, + "graph_recall_max_paths": self.config.graph_recall.max_paths, + }, + "vector_store": { + "size": int(vector_size), + }, + "graph_store": { + "num_nodes": self.graph_store.num_nodes, + "num_edges": self.graph_store.num_edges, + }, + "metadata_store": self.metadata_store.get_statistics(), + "sparse": self.sparse_index.stats() if self.sparse_index else None, + } + + def __repr__(self) -> str: + return ( + f"DualPathRetriever(" + f"strategy={self.config.retrieval_strategy.value}, " + f"para_k={self.config.top_k_paragraphs}, " + f"rel_k={self.config.top_k_relations})" + ) diff --git a/plugins/A_memorix/core/retrieval/graph_relation_recall.py b/plugins/A_memorix/core/retrieval/graph_relation_recall.py new file mode 100644 index 00000000..3ce03b14 --- /dev/null +++ b/plugins/A_memorix/core/retrieval/graph_relation_recall.py @@ -0,0 +1,272 @@ +"""Graph-assisted relation candidate recall for relation-oriented queries.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Set + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.GraphRelationRecall") + + +@dataclass +class GraphRelationRecallConfig: + """Configuration for controlled graph relation recall.""" + + enabled: bool = True + candidate_k: int = 24 + max_hop: int = 1 + allow_two_hop_pair: bool = True + max_paths: int = 4 + + def __post_init__(self) -> None: + self.enabled = bool(self.enabled) + self.candidate_k = max(1, int(self.candidate_k)) + self.max_hop = max(1, int(self.max_hop)) + self.allow_two_hop_pair = bool(self.allow_two_hop_pair) + self.max_paths = max(1, int(self.max_paths)) + + +@dataclass +class GraphRelationCandidate: + """A graph-derived relation candidate before retriever-side fusion.""" + + hash_value: str + subject: str + 
predicate: str + object: str + confidence: float + graph_seed_entities: List[str] + graph_hops: int + graph_candidate_type: str + supporting_paragraph_count: int + + def to_payload(self) -> Dict[str, Any]: + content = f"{self.subject} {self.predicate} {self.object}" + return { + "hash": self.hash_value, + "content": content, + "subject": self.subject, + "predicate": self.predicate, + "object": self.object, + "confidence": self.confidence, + "graph_seed_entities": list(self.graph_seed_entities), + "graph_hops": int(self.graph_hops), + "graph_candidate_type": self.graph_candidate_type, + "supporting_paragraph_count": int(self.supporting_paragraph_count), + } + + +class GraphRelationRecallService: + """Collect relation candidates from the entity graph in a controlled way.""" + + def __init__( + self, + *, + graph_store: Any, + metadata_store: Any, + config: Optional[GraphRelationRecallConfig] = None, + ) -> None: + self.graph_store = graph_store + self.metadata_store = metadata_store + self.config = config or GraphRelationRecallConfig() + + def recall( + self, + *, + seed_entities: Sequence[str], + ) -> List[GraphRelationCandidate]: + if not self.config.enabled: + return [] + if self.graph_store is None or self.metadata_store is None: + return [] + + seeds = self._normalize_seed_entities(seed_entities) + if not seeds: + return [] + + seen_hashes: Set[str] = set() + candidates: List[GraphRelationCandidate] = [] + + if len(seeds) >= 2: + self._collect_direct_pair_candidates( + seed_a=seeds[0], + seed_b=seeds[1], + seen_hashes=seen_hashes, + out=candidates, + ) + if ( + len(candidates) < 3 + and self.config.allow_two_hop_pair + and len(candidates) < self.config.candidate_k + ): + self._collect_two_hop_pair_candidates( + seed_a=seeds[0], + seed_b=seeds[1], + seen_hashes=seen_hashes, + out=candidates, + ) + else: + self._collect_one_hop_seed_candidates( + seed=seeds[0], + seen_hashes=seen_hashes, + out=candidates, + ) + + return candidates[: self.config.candidate_k] + + def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]: + out: List[str] = [] + seen = set() + for raw in list(seed_entities)[:2]: + resolved = None + try: + resolved = self.graph_store.find_node(str(raw), ignore_case=True) + except Exception: + resolved = None + if not resolved: + continue + canon = str(resolved).strip().lower() + if not canon or canon in seen: + continue + seen.add(canon) + out.append(str(resolved)) + return out + + def _collect_direct_pair_candidates( + self, + *, + seed_a: str, + seed_b: str, + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + ) -> None: + relation_hashes = [] + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b)) + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a)) + self._append_relation_hashes( + relation_hashes=relation_hashes, + seen_hashes=seen_hashes, + out=out, + candidate_type="direct_pair", + graph_hops=1, + graph_seed_entities=[seed_a, seed_b], + ) + + def _collect_two_hop_pair_candidates( + self, + *, + seed_a: str, + seed_b: str, + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + ) -> None: + try: + paths = self.graph_store.find_paths( + seed_a, + seed_b, + max_depth=2, + max_paths=self.config.max_paths, + ) + except Exception as e: + logger.debug("graph two-hop recall skipped: %s", e) + return + + for path_nodes in paths: + if len(out) >= self.config.candidate_k: + break + if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3: + continue + if len(path_nodes) 
!= 3: + continue + for idx in range(len(path_nodes) - 1): + if len(out) >= self.config.candidate_k: + break + u = str(path_nodes[idx]) + v = str(path_nodes[idx + 1]) + relation_hashes = [] + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v)) + relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u)) + self._append_relation_hashes( + relation_hashes=relation_hashes, + seen_hashes=seen_hashes, + out=out, + candidate_type="two_hop_pair", + graph_hops=2, + graph_seed_entities=[seed_a, seed_b], + ) + + def _collect_one_hop_seed_candidates( + self, + *, + seed: str, + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + ) -> None: + try: + relation_hashes = self.graph_store.get_incident_relation_hashes( + seed, + limit=self.config.candidate_k, + ) + except Exception as e: + logger.debug("graph one-hop recall skipped: %s", e) + return + self._append_relation_hashes( + relation_hashes=relation_hashes, + seen_hashes=seen_hashes, + out=out, + candidate_type="one_hop_seed", + graph_hops=min(1, self.config.max_hop), + graph_seed_entities=[seed], + ) + + def _append_relation_hashes( + self, + *, + relation_hashes: Sequence[str], + seen_hashes: Set[str], + out: List[GraphRelationCandidate], + candidate_type: str, + graph_hops: int, + graph_seed_entities: Sequence[str], + ) -> None: + for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}): + if len(out) >= self.config.candidate_k: + break + if relation_hash in seen_hashes: + continue + candidate = self._build_candidate( + relation_hash=relation_hash, + candidate_type=candidate_type, + graph_hops=graph_hops, + graph_seed_entities=graph_seed_entities, + ) + if candidate is None: + continue + seen_hashes.add(relation_hash) + out.append(candidate) + + def _build_candidate( + self, + *, + relation_hash: str, + candidate_type: str, + graph_hops: int, + graph_seed_entities: Sequence[str], + ) -> Optional[GraphRelationCandidate]: + relation = self.metadata_store.get_relation(relation_hash) + if relation is None: + return None + supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash) + return GraphRelationCandidate( + hash_value=relation_hash, + subject=str(relation.get("subject", "")), + predicate=str(relation.get("predicate", "")), + object=str(relation.get("object", "")), + confidence=float(relation.get("confidence", 1.0) or 1.0), + graph_seed_entities=[str(x) for x in graph_seed_entities], + graph_hops=int(graph_hops), + graph_candidate_type=str(candidate_type), + supporting_paragraph_count=len(supporting_paragraphs), + ) diff --git a/plugins/A_memorix/core/retrieval/pagerank.py b/plugins/A_memorix/core/retrieval/pagerank.py new file mode 100644 index 00000000..c8ee48bb --- /dev/null +++ b/plugins/A_memorix/core/retrieval/pagerank.py @@ -0,0 +1,482 @@ +""" +Personalized PageRank实现 + +提供个性化的图节点排序功能。 +""" + +from typing import Dict, List, Optional, Tuple, Union, Any +from dataclasses import dataclass +import numpy as np + +from src.common.logger import get_logger +from ..storage import GraphStore +from ..utils.matcher import AhoCorasick + +logger = get_logger("A_Memorix.PersonalizedPageRank") + + +@dataclass +class PageRankConfig: + """ + PageRank配置 + + 属性: + alpha: 阻尼系数(0-1之间) + max_iter: 最大迭代次数 + tol: 收敛阈值 + normalize: 是否归一化结果 + min_iterations: 最小迭代次数 + """ + + alpha: float = 0.85 + max_iter: int = 100 + tol: float = 1e-6 + normalize: bool = True + min_iterations: int = 20 + + def __post_init__(self): + """验证配置""" + if not 0 <= self.alpha < 1: + raise 
ValueError(f"alpha必须在[0, 1)之间: {self.alpha}") + + if self.max_iter <= 0: + raise ValueError(f"max_iter必须大于0: {self.max_iter}") + + if self.tol <= 0: + raise ValueError(f"tol必须大于0: {self.tol}") + + if self.min_iterations < 0: + raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}") + + if self.min_iterations >= self.max_iter: + raise ValueError(f"min_iterations必须小于max_iter") + + +class PersonalizedPageRank: + """ + Personalized PageRank计算器 + + 功能: + - 个性化向量支持 + - 快速收敛检测 + - 结果归一化 + - 批量计算 + - 统计信息 + + 参数: + graph_store: 图存储 + config: PageRank配置 + """ + + def __init__( + self, + graph_store: GraphStore, + config: Optional[PageRankConfig] = None, + ): + """ + 初始化PPR计算器 + + Args: + graph_store: 图存储 + config: PageRank配置 + """ + self.graph_store = graph_store + self.config = config or PageRankConfig() + + # 统计信息 + self._total_computations = 0 + self._total_iterations = 0 + self._convergence_history: List[int] = [] + + logger.info( + f"PersonalizedPageRank 初始化: " + f"alpha={self.config.alpha}, " + f"max_iter={self.config.max_iter}" + ) + + # 缓存 Aho-Corasick 匹配器 + self._ac_matcher: Optional[AhoCorasick] = None + self._ac_nodes_count = 0 + + def compute( + self, + personalization: Optional[Dict[str, float]] = None, + alpha: Optional[float] = None, + max_iter: Optional[int] = None, + normalize: Optional[bool] = None, + ) -> Dict[str, float]: + """ + 计算Personalized PageRank + + Args: + personalization: 个性化向量 {节点名: 权重} + alpha: 阻尼系数(覆盖配置值) + max_iter: 最大迭代次数(覆盖配置值) + normalize: 是否归一化(覆盖配置值) + + Returns: + 节点PageRank值字典 {节点名: 分数} + """ + # 使用覆盖值或配置值 + alpha = alpha if alpha is not None else self.config.alpha + max_iter = max_iter if max_iter is not None else self.config.max_iter + normalize = normalize if normalize is not None else self.config.normalize + + # 调用GraphStore的compute_pagerank + scores = self.graph_store.compute_pagerank( + personalization=personalization, + alpha=alpha, + max_iter=max_iter, + tol=self.config.tol, + ) + + # 归一化(如果需要) + if normalize and scores: + total = sum(scores.values()) + if total > 0: + scores = {node: score / total for node, score in scores.items()} + + # 更新统计 + self._total_computations += 1 + + logger.debug( + f"PPR计算完成: {len(scores)} 个节点, " + f"personalization_nodes={len(personalization) if personalization else 0}" + ) + + return scores + + def compute_batch( + self, + personalization_list: List[Dict[str, float]], + normalize: bool = True, + ) -> List[Dict[str, float]]: + """ + 批量计算PPR + + Args: + personalization_list: 个性化向量列表 + normalize: 是否归一化 + + Returns: + PageRank值字典列表 + """ + results = [] + + for i, personalization in enumerate(personalization_list): + logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR") + scores = self.compute(personalization=personalization, normalize=normalize) + results.append(scores) + + return results + + def compute_for_entities( + self, + entities: List[str], + weights: Optional[List[float]] = None, + normalize: bool = True, + ) -> Dict[str, float]: + """ + 为实体列表计算PPR + + Args: + entities: 实体列表 + weights: 权重列表(默认均匀权重) + normalize: 是否归一化 + + Returns: + PageRank值字典 + """ + if not entities: + logger.warning("实体列表为空,返回均匀PPR") + return self.compute(personalization=None, normalize=normalize) + + # 构建个性化向量 + if weights is None: + weights = [1.0] * len(entities) + + if len(weights) != len(entities): + raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}") + + personalization = {entity: weight for entity, weight in zip(entities, weights)} + + return self.compute(personalization=personalization, 
normalize=normalize) + + def compute_for_query( + self, + query: str, + entity_extractor: Optional[callable] = None, + normalize: bool = True, + ) -> Dict[str, float]: + """ + 为查询计算PPR + + Args: + query: 查询文本 + entity_extractor: 实体提取函数(可选) + normalize: 是否归一化 + + Returns: + PageRank值字典 + """ + # 提取实体 + if entity_extractor is not None: + entities = entity_extractor(query) + else: + # 简单实现:基于图中的节点匹配 + entities = self._extract_entities_from_query(query) + + if not entities: + logger.debug(f"未从查询中提取到实体: '{query}'") + return self.compute(personalization=None, normalize=normalize) + + # 计算PPR + return self.compute_for_entities(entities, normalize=normalize) + + def rank_nodes( + self, + scores: Dict[str, float], + top_k: Optional[int] = None, + min_score: float = 0.0, + ) -> List[Tuple[str, float]]: + """ + 对节点排序 + + Args: + scores: PageRank分数字典 + top_k: 返回前k个节点(None表示全部) + min_score: 最小分数阈值 + + Returns: + 排序后的节点列表 [(节点名, 分数), ...] + """ + # 过滤低分节点 + filtered = [(node, score) for node, score in scores.items() if score >= min_score] + + # 按分数降序排序 + sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True) + + # 返回top_k + if top_k is not None: + sorted_nodes = sorted_nodes[:top_k] + + return sorted_nodes + + def get_personalization_vector( + self, + nodes: List[str], + method: str = "uniform", + ) -> Dict[str, float]: + """ + 生成个性化向量 + + Args: + nodes: 节点列表 + method: 生成方法 + - "uniform": 均匀权重 + - "degree": 按度数加权 + - "inverse_degree": 按度数反比加权 + + Returns: + 个性化向量 {节点名: 权重} + """ + if not nodes: + return {} + + if method == "uniform": + # 均匀权重 + weight = 1.0 / len(nodes) + return {node: weight for node in nodes} + + elif method == "degree": + # 按度数加权 + node_degrees = {} + for node in nodes: + neighbors = self.graph_store.get_neighbors(node) + node_degrees[node] = len(neighbors) + + total_degree = sum(node_degrees.values()) + if total_degree > 0: + return {node: degree / total_degree for node, degree in node_degrees.items()} + else: + return {node: 1.0 / len(nodes) for node in nodes} + + elif method == "inverse_degree": + # 按度数反比加权 + node_degrees = {} + for node in nodes: + neighbors = self.graph_store.get_neighbors(node) + node_degrees[node] = len(neighbors) + + # 反度数 + inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()} + total_inv = sum(inv_degrees.values()) + + if total_inv > 0: + return {node: inv / total_inv for node, inv in inv_degrees.items()} + else: + return {node: 1.0 / len(nodes) for node in nodes} + + else: + raise ValueError(f"不支持的个性化向量生成方法: {method}") + + def compare_scores( + self, + scores1: Dict[str, float], + scores2: Dict[str, float], + ) -> Dict[str, Dict[str, float]]: + """ + 比较两组PPR分数 + + Args: + scores1: 第一组分数 + scores2: 第二组分数 + + Returns: + 比较结果 { + "common_nodes": {节点: (score1, score2)}, + "only_in_1": {节点: score1}, + "only_in_2": {节点: score2}, + } + """ + common_nodes = {} + only_in_1 = {} + only_in_2 = {} + + all_nodes = set(scores1.keys()) | set(scores2.keys()) + + for node in all_nodes: + if node in scores1 and node in scores2: + common_nodes[node] = (scores1[node], scores2[node]) + elif node in scores1: + only_in_1[node] = scores1[node] + else: + only_in_2[node] = scores2[node] + + return { + "common_nodes": common_nodes, + "only_in_1": only_in_1, + "only_in_2": only_in_2, + } + + def get_statistics(self) -> Dict[str, Any]: + """ + 获取统计信息 + + Returns: + 统计信息字典 + """ + avg_iterations = ( + self._total_iterations / self._total_computations + if self._total_computations > 0 + else 0 + ) + + return { + "config": { + "alpha": 
self.config.alpha, + "max_iter": self.config.max_iter, + "tol": self.config.tol, + "normalize": self.config.normalize, + "min_iterations": self.config.min_iterations, + }, + "statistics": { + "total_computations": self._total_computations, + "total_iterations": self._total_iterations, + "avg_iterations": avg_iterations, + "convergence_history": self._convergence_history.copy(), + }, + "graph": { + "num_nodes": self.graph_store.num_nodes, + "num_edges": self.graph_store.num_edges, + }, + } + + def reset_statistics(self) -> None: + """重置统计信息""" + self._total_computations = 0 + self._total_iterations = 0 + self._convergence_history.clear() + logger.info("统计信息已重置") + + def _extract_entities_from_query(self, query: str) -> List[str]: + """ + 从查询中提取实体(简化实现) + + Args: + query: 查询文本 + + Returns: + 实体列表 + """ + # 获取所有节点 + all_nodes = self.graph_store.get_nodes() + if not all_nodes: + return [] + + # 检查是否需要更新 Aho-Corasick 匹配器 + if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes): + self._ac_matcher = AhoCorasick() + for node in all_nodes: + # 统一转为小写进行不区分大小写匹配 + self._ac_matcher.add_pattern(node.lower()) + self._ac_matcher.build() + self._ac_nodes_count = len(all_nodes) + + # 执行匹配 + query_lower = query.lower() + stats = self._ac_matcher.find_all(query_lower) + + # 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值) + # 为了简单,AC 中 add_pattern 存的是小写 + # 我们需要一个映射:小写 -> 原始 + node_map = {node.lower(): node for node in all_nodes} + entities = [node_map[low_name] for low_name in stats.keys()] + + return entities + + @property + def num_computations(self) -> int: + """计算次数""" + return self._total_computations + + @property + def avg_iterations(self) -> float: + """平均迭代次数""" + if self._total_computations == 0: + return 0.0 + return self._total_iterations / self._total_computations + + def __repr__(self) -> str: + return ( + f"PersonalizedPageRank(" + f"alpha={self.config.alpha}, " + f"computations={self._total_computations})" + ) + + +def create_ppr_from_graph( + graph_store: GraphStore, + alpha: float = 0.85, + max_iter: int = 100, +) -> PersonalizedPageRank: + """ + 从图存储创建PPR计算器 + + Args: + graph_store: 图存储 + alpha: 阻尼系数 + max_iter: 最大迭代次数 + + Returns: + PPR计算器实例 + """ + config = PageRankConfig( + alpha=alpha, + max_iter=max_iter, + ) + + return PersonalizedPageRank( + graph_store=graph_store, + config=config, + ) diff --git a/plugins/A_memorix/core/retrieval/sparse_bm25.py b/plugins/A_memorix/core/retrieval/sparse_bm25.py new file mode 100644 index 00000000..3b6f075d --- /dev/null +++ b/plugins/A_memorix/core/retrieval/sparse_bm25.py @@ -0,0 +1,402 @@ +""" +稀疏检索组件(FTS5 + BM25) + +支持: +- 懒加载索引连接 +- jieba / char n-gram 分词 +- 可卸载并收缩 SQLite 内存缓存 +""" + +from __future__ import annotations + +import re +import sqlite3 +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger +from ..storage import MetadataStore + +logger = get_logger("A_Memorix.SparseBM25") + +try: + import jieba # type: ignore + + HAS_JIEBA = True +except Exception: + HAS_JIEBA = False + jieba = None + + +@dataclass +class SparseBM25Config: + """BM25 稀疏检索配置。""" + + enabled: bool = True + backend: str = "fts5" + lazy_load: bool = True + mode: str = "auto" # auto | fallback_only | hybrid + tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram + jieba_user_dict: str = "" + char_ngram_n: int = 2 + candidate_k: int = 80 + max_doc_len: int = 2000 + enable_ngram_fallback_index: bool = True + enable_like_fallback: bool = False + enable_relation_sparse_fallback: bool = True + 
relation_candidate_k: int = 60 + relation_max_doc_len: int = 512 + unload_on_disable: bool = True + shrink_memory_on_unload: bool = True + + def __post_init__(self) -> None: + self.backend = str(self.backend or "fts5").strip().lower() + self.mode = str(self.mode or "auto").strip().lower() + self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower() + self.char_ngram_n = max(1, int(self.char_ngram_n)) + self.candidate_k = max(1, int(self.candidate_k)) + self.max_doc_len = max(0, int(self.max_doc_len)) + self.relation_candidate_k = max(1, int(self.relation_candidate_k)) + self.relation_max_doc_len = max(0, int(self.relation_max_doc_len)) + if self.backend != "fts5": + raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}") + if self.mode not in {"auto", "fallback_only", "hybrid"}: + raise ValueError(f"sparse.mode 非法: {self.mode}") + if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}: + raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}") + + +class SparseBM25Index: + """ + 基于 SQLite FTS5 的 BM25 检索适配层。 + """ + + def __init__( + self, + metadata_store: MetadataStore, + config: Optional[SparseBM25Config] = None, + ): + self.metadata_store = metadata_store + self.config = config or SparseBM25Config() + self._conn: Optional[sqlite3.Connection] = None + self._loaded: bool = False + self._jieba_dict_loaded: bool = False + + @property + def loaded(self) -> bool: + return self._loaded and self._conn is not None + + def ensure_loaded(self) -> bool: + """按需加载 FTS 连接与索引。""" + if not self.config.enabled: + return False + if self.loaded: + return True + + db_path = self.metadata_store.get_db_path() + conn = sqlite3.connect( + str(db_path), + check_same_thread=False, + timeout=30.0, + ) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA synchronous=NORMAL") + conn.execute("PRAGMA temp_store=MEMORY") + + if not self.metadata_store.ensure_fts_schema(conn=conn): + conn.close() + return False + self.metadata_store.ensure_fts_backfilled(conn=conn) + # 关系稀疏检索按独立开关加载,避免不必要的初始化开销。 + if self.config.enable_relation_sparse_fallback: + self.metadata_store.ensure_relations_fts_schema(conn=conn) + self.metadata_store.ensure_relations_fts_backfilled(conn=conn) + if self.config.enable_ngram_fallback_index: + self.metadata_store.ensure_paragraph_ngram_schema(conn=conn) + self.metadata_store.ensure_paragraph_ngram_backfilled( + n=self.config.char_ngram_n, + conn=conn, + ) + + self._conn = conn + self._loaded = True + self._prepare_tokenizer() + logger.info( + "SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s", + self.config.tokenizer_mode, + self.config.mode, + ) + return True + + def _prepare_tokenizer(self) -> None: + if self._jieba_dict_loaded: + return + if self.config.tokenizer_mode not in {"jieba", "mixed"}: + return + if not HAS_JIEBA: + logger.warning("jieba 不可用,tokenizer 将退化为 char n-gram") + return + user_dict = str(self.config.jieba_user_dict or "").strip() + if user_dict: + try: + jieba.load_userdict(user_dict) # type: ignore[union-attr] + logger.info("已加载 jieba 用户词典: %s", user_dict) + except Exception as e: + logger.warning("加载 jieba 用户词典失败: %s", e) + self._jieba_dict_loaded = True + + def _tokenize_jieba(self, text: str) -> List[str]: + if not HAS_JIEBA: + return [] + try: + tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr] + return [t.strip().lower() for t in tokens if t and t.strip()] + except Exception: + return [] + + def _tokenize_char_ngram(self, text: str, n: int) -> List[str]: + 
compact = re.sub(r"\s+", "", text.lower()) + if not compact: + return [] + if len(compact) < n: + return [compact] + return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] + + def _tokenize(self, text: str) -> List[str]: + text = str(text or "").strip() + if not text: + return [] + + mode = self.config.tokenizer_mode + if mode == "jieba": + tokens = self._tokenize_jieba(text) + if tokens: + return list(dict.fromkeys(tokens)) + return self._tokenize_char_ngram(text, self.config.char_ngram_n) + + if mode == "mixed": + toks = self._tokenize_jieba(text) + toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n)) + return list(dict.fromkeys([t for t in toks if t])) + + return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n))) + + def _build_match_query(self, tokens: List[str]) -> str: + safe_tokens: List[str] = [] + for token in tokens: + t = token.replace('"', '""').strip() + if not t: + continue + safe_tokens.append(f'"{t}"') + if not safe_tokens: + return "" + # 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。 + return " OR ".join(safe_tokens[:64]) + + def _fallback_substring_search( + self, + tokens: List[str], + limit: int, + ) -> List[Dict[str, Any]]: + """ + 当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。 + + 说明: + - FTS 索引当前采用 unicode61 tokenizer。 + - 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。 + - 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。 + """ + if not tokens: + return [] + + # 去重并裁剪 token 数量,避免生成超长 SQL。 + uniq_tokens = [t for t in dict.fromkeys(tokens) if t] + uniq_tokens = uniq_tokens[:32] + if not uniq_tokens: + return [] + + if self.config.enable_ngram_fallback_index: + try: + # 允许运行时切换开关后按需补齐 schema/回填。 + self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn) + self.metadata_store.ensure_paragraph_ngram_backfilled( + n=self.config.char_ngram_n, + conn=self._conn, + ) + rows = self.metadata_store.ngram_search_paragraphs( + tokens=uniq_tokens, + limit=limit, + max_doc_len=self.config.max_doc_len, + conn=self._conn, + ) + if rows: + return rows + except Exception as e: + logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}") + + if not self.config.enable_like_fallback: + return [] + + conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens)) + params: List[Any] = [f"%{tok}%" for tok in uniq_tokens] + scan_limit = max(int(limit) * 8, 200) + params.append(scan_limit) + + sql = f""" + SELECT p.hash, p.content + FROM paragraphs p + WHERE (p.is_deleted IS NULL OR p.is_deleted = 0) + AND ({conditions}) + LIMIT ? 
+ """ + rows = self.metadata_store.query(sql, tuple(params)) + if not rows: + return [] + + scored: List[Dict[str, Any]] = [] + token_count = max(1, len(uniq_tokens)) + for row in rows: + content = str(row.get("content") or "") + content_low = content.lower() + matched = [tok for tok in uniq_tokens if tok in content_low] + if not matched: + continue + coverage = len(matched) / token_count + length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low)) + # 兜底路径使用相对分,保持与上层接口兼容。 + fallback_score = coverage * 0.8 + length_bonus * 0.2 + scored.append( + { + "hash": row["hash"], + "content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content, + "bm25_score": -float(fallback_score), + "fallback_score": float(fallback_score), + } + ) + + scored.sort(key=lambda x: x["fallback_score"], reverse=True) + return scored[:limit] + + def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]: + """执行 BM25 检索。""" + if not self.config.enabled: + return [] + if self.config.lazy_load and not self.loaded: + if not self.ensure_loaded(): + return [] + if not self.loaded: + return [] + # 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。 + self.metadata_store.ensure_relations_fts_schema(conn=self._conn) + self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn) + + tokens = self._tokenize(query) + match_query = self._build_match_query(tokens) + if not match_query: + return [] + + limit = max(1, int(k)) + rows = self.metadata_store.fts_search_bm25( + match_query=match_query, + limit=limit, + max_doc_len=self.config.max_doc_len, + conn=self._conn, + ) + if not rows: + rows = self._fallback_substring_search(tokens=tokens, limit=limit) + + results: List[Dict[str, Any]] = [] + for rank, row in enumerate(rows, start=1): + bm25_score = float(row.get("bm25_score", 0.0)) + results.append( + { + "hash": row["hash"], + "content": row["content"], + "rank": rank, + "bm25_score": bm25_score, + "score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数 + } + ) + return results + + def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]: + """执行关系稀疏检索(FTS5 + BM25)。""" + if not self.config.enabled or not self.config.enable_relation_sparse_fallback: + return [] + if self.config.lazy_load and not self.loaded: + if not self.ensure_loaded(): + return [] + if not self.loaded: + return [] + + tokens = self._tokenize(query) + match_query = self._build_match_query(tokens) + if not match_query: + return [] + + rows = self.metadata_store.fts_search_relations_bm25( + match_query=match_query, + limit=max(1, int(k)), + max_doc_len=self.config.relation_max_doc_len, + conn=self._conn, + ) + out: List[Dict[str, Any]] = [] + for rank, row in enumerate(rows, start=1): + bm25_score = float(row.get("bm25_score", 0.0)) + out.append( + { + "hash": row["hash"], + "subject": row["subject"], + "predicate": row["predicate"], + "object": row["object"], + "content": row["content"], + "rank": rank, + "bm25_score": bm25_score, + "score": -bm25_score, + } + ) + return out + + def upsert_paragraph(self, paragraph_hash: str) -> bool: + if not self.loaded: + return False + return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn) + + def delete_paragraph(self, paragraph_hash: str) -> bool: + if not self.loaded: + return False + return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn) + + def unload(self) -> None: + """卸载 BM25 连接并尽量释放内存。""" + if self._conn is not None: + try: + if self.config.shrink_memory_on_unload: + 
self.metadata_store.shrink_memory(conn=self._conn) + except Exception: + pass + try: + self._conn.close() + except Exception: + pass + self._conn = None + self._loaded = False + logger.info("SparseBM25Index unloaded") + + def stats(self) -> Dict[str, Any]: + doc_count = 0 + if self.loaded: + doc_count = self.metadata_store.fts_doc_count(conn=self._conn) + return { + "enabled": self.config.enabled, + "backend": self.config.backend, + "mode": self.config.mode, + "tokenizer_mode": self.config.tokenizer_mode, + "enable_ngram_fallback_index": self.config.enable_ngram_fallback_index, + "enable_like_fallback": self.config.enable_like_fallback, + "enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback, + "loaded": self.loaded, + "has_jieba": HAS_JIEBA, + "doc_count": doc_count, + } diff --git a/plugins/A_memorix/core/retrieval/threshold.py b/plugins/A_memorix/core/retrieval/threshold.py new file mode 100644 index 00000000..87a0094b --- /dev/null +++ b/plugins/A_memorix/core/retrieval/threshold.py @@ -0,0 +1,450 @@ +""" +动态阈值过滤器 + +根据检索结果的分布特征自适应调整过滤阈值。 +""" + +import numpy as np +from typing import List, Dict, Any, Optional, Tuple, Union +from dataclasses import dataclass +from enum import Enum + +from src.common.logger import get_logger +from .dual_path import RetrievalResult + +logger = get_logger("A_Memorix.DynamicThresholdFilter") + + +class ThresholdMethod(Enum): + """阈值计算方法""" + + PERCENTILE = "percentile" # 百分位数 + STD_DEV = "std_dev" # 标准差 + GAP_DETECTION = "gap_detection" # 跳变检测 + ADAPTIVE = "adaptive" # 自适应(综合多种方法) + + +@dataclass +class ThresholdConfig: + """ + 阈值配置 + + 属性: + method: 阈值计算方法 + min_threshold: 最小阈值(绝对值) + max_threshold: 最大阈值(绝对值) + percentile: 百分位数(用于percentile方法) + std_multiplier: 标准差倍数(用于std_dev方法) + min_results: 最少保留结果数 + enable_auto_adjust: 是否自动调整参数 + """ + + method: ThresholdMethod = ThresholdMethod.ADAPTIVE + min_threshold: float = 0.3 + max_threshold: float = 0.95 + percentile: float = 75.0 # 百分位数 + std_multiplier: float = 1.5 # 标准差倍数 + min_results: int = 3 # 最少保留结果数 + enable_auto_adjust: bool = True + + def __post_init__(self): + """验证配置""" + if not 0 <= self.min_threshold <= 1: + raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}") + + if not 0 <= self.max_threshold <= 1: + raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}") + + if self.min_threshold >= self.max_threshold: + raise ValueError(f"min_threshold必须小于max_threshold") + + if not 0 <= self.percentile <= 100: + raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}") + + if self.std_multiplier <= 0: + raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}") + + if self.min_results < 0: + raise ValueError(f"min_results必须大于等于0: {self.min_results}") + + +class DynamicThresholdFilter: + """ + 动态阈值过滤器 + + 功能: + - 基于结果分布自适应计算阈值 + - 多种阈值计算方法 + - 自动参数调整 + - 统计信息收集 + + 参数: + config: 阈值配置 + """ + + def __init__( + self, + config: Optional[ThresholdConfig] = None, + ): + """ + 初始化动态阈值过滤器 + + Args: + config: 阈值配置 + """ + self.config = config or ThresholdConfig() + + # 统计信息 + self._total_filtered = 0 + self._total_processed = 0 + self._threshold_history: List[float] = [] + + logger.info( + f"DynamicThresholdFilter 初始化: " + f"method={self.config.method.value}, " + f"min_threshold={self.config.min_threshold}" + ) + + def filter( + self, + results: List[RetrievalResult], + return_threshold: bool = False, + ) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]: + """ + 过滤检索结果 + + Args: + results: 检索结果列表 + return_threshold: 是否返回使用的阈值 + + 
Returns: + 过滤后的结果列表,或 (结果列表, 阈值) 元组 + """ + if not results: + logger.debug("结果列表为空,无需过滤") + return ([], 0.0) if return_threshold else [] + + self._total_processed += len(results) + + # 提取分数 + scores = np.array([r.score for r in results]) + + # 计算阈值 + threshold = self._compute_threshold(scores, results) + + # 记录阈值 + self._threshold_history.append(threshold) + + # 应用阈值过滤 + filtered_results = [ + r for r in results + if r.score >= threshold + ] + + # 确保至少保留min_results个结果 + if len(filtered_results) < self.config.min_results: + # 按分数排序,取前min_results个 + sorted_results = sorted(results, key=lambda x: x.score, reverse=True) + filtered_results = sorted_results[:self.config.min_results] + threshold = filtered_results[-1].score if filtered_results else 0.0 + + self._total_filtered += len(results) - len(filtered_results) + + logger.info( + f"过滤完成: {len(results)} -> {len(filtered_results)} " + f"(threshold={threshold:.3f})" + ) + + if return_threshold: + return filtered_results, threshold + return filtered_results + + def _compute_threshold( + self, + scores: np.ndarray, + results: List[RetrievalResult], + ) -> float: + """ + 计算阈值 + + Args: + scores: 分数数组 + results: 检索结果列表 + + Returns: + 阈值 + """ + if self.config.method == ThresholdMethod.PERCENTILE: + threshold = self._percentile_threshold(scores) + elif self.config.method == ThresholdMethod.STD_DEV: + threshold = self._std_dev_threshold(scores) + elif self.config.method == ThresholdMethod.GAP_DETECTION: + threshold = self._gap_detection_threshold(scores) + else: # ADAPTIVE + # 自适应方法:综合多种方法 + thresholds = [ + self._percentile_threshold(scores), + self._std_dev_threshold(scores), + self._gap_detection_threshold(scores), + ] + # 使用中位数作为最终阈值 + threshold = float(np.median(thresholds)) + + # 限制在[min_threshold, max_threshold]范围内 + threshold = np.clip( + threshold, + self.config.min_threshold, + self.config.max_threshold, + ) + + # 自动调整 + if self.config.enable_auto_adjust: + threshold = self._auto_adjust_threshold(threshold, scores) + + return float(threshold) + + def _percentile_threshold(self, scores: np.ndarray) -> float: + """ + 基于百分位数计算阈值 + + Args: + scores: 分数数组 + + Returns: + 阈值 + """ + percentile = self.config.percentile + threshold = float(np.percentile(scores, percentile)) + + logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})") + return threshold + + def _std_dev_threshold(self, scores: np.ndarray) -> float: + """ + 基于标准差计算阈值 + + threshold = mean - std_multiplier * std + + Args: + scores: 分数数组 + + Returns: + 阈值 + """ + mean = float(np.mean(scores)) + std = float(np.std(scores)) + multiplier = self.config.std_multiplier + + threshold = mean - multiplier * std + + logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})") + return threshold + + def _gap_detection_threshold(self, scores: np.ndarray) -> float: + """ + 基于跳变检测计算阈值 + + 找到分数分布中最大的"跳变"位置,以此为阈值 + + Args: + scores: 分数数组(降序排列) + + Returns: + 阈值 + """ + # 降序排列 + sorted_scores = np.sort(scores)[::-1] + + if len(sorted_scores) < 2: + return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0 + + # 计算相邻分数的差值 + gaps = np.diff(sorted_scores) + + # 找到最大的跳变位置 + max_gap_idx = int(np.argmax(gaps)) + + # 阈值为跳变后的分数 + threshold = float(sorted_scores[max_gap_idx + 1]) + + logger.debug( + f"跳变检测阈值: {threshold:.3f} " + f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})" + ) + return threshold + + def _auto_adjust_threshold( + self, + threshold: float, + scores: np.ndarray, + ) -> float: + """ + 自动调整阈值 + + 基于历史阈值和当前分数分布调整 + + Args: + threshold: 当前阈值 + scores: 分数数组 + + 
Returns: + 调整后的阈值 + """ + if not self._threshold_history: + return threshold + + # 计算历史阈值的移动平均 + recent_thresholds = self._threshold_history[-10:] # 最近10次 + avg_threshold = float(np.mean(recent_thresholds)) + + # 当前阈值与历史平均的差异 + diff = threshold - avg_threshold + + # 如果差异过大(>0.2),向历史平均靠拢 + if abs(diff) > 0.2: + adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50% + logger.debug( + f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} " + f"(历史平均={avg_threshold:.3f})" + ) + return adjusted_threshold + + return threshold + + def filter_by_confidence( + self, + results: List[RetrievalResult], + min_confidence: float = 0.5, + ) -> List[RetrievalResult]: + """ + 基于置信度过滤结果 + + Args: + results: 检索结果列表 + min_confidence: 最小置信度 + + Returns: + 过滤后的结果列表 + """ + filtered = [] + for result in results: + # 对于关系结果,使用confidence字段 + if result.result_type == "relation": + confidence = result.metadata.get("confidence", 1.0) + if confidence >= min_confidence: + filtered.append(result) + else: + # 对于段落结果,直接使用分数 + if result.score >= min_confidence: + filtered.append(result) + + logger.info( + f"置信度过滤: {len(results)} -> {len(filtered)} " + f"(min_confidence={min_confidence})" + ) + + return filtered + + def filter_by_diversity( + self, + results: List[RetrievalResult], + similarity_threshold: float = 0.9, + top_k: int = 10, + ) -> List[RetrievalResult]: + """ + 基于多样性过滤结果(去除重复) + + Args: + results: 检索结果列表 + similarity_threshold: 相似度阈值(高于此值视为重复) + top_k: 最多保留结果数 + + Returns: + 过滤后的结果列表 + """ + if not results: + return [] + + # 按分数排序 + sorted_results = sorted(results, key=lambda x: x.score, reverse=True) + + # 贪心选择:选择与已选结果相似度低的结果 + selected = [] + selected_hashes = [] + + for result in sorted_results: + if len(selected) >= top_k: + break + + # 检查与已选结果的相似度 + is_duplicate = False + for selected_hash in selected_hashes: + # 简单判断:基于hash的前缀 + if result.hash_value[:8] == selected_hash[:8]: + is_duplicate = True + break + + if not is_duplicate: + selected.append(result) + selected_hashes.append(result.hash_value) + + logger.info( + f"多样性过滤: {len(results)} -> {len(selected)} " + f"(similarity_threshold={similarity_threshold})" + ) + + return selected + + def get_statistics(self) -> Dict[str, Any]: + """ + 获取统计信息 + + Returns: + 统计信息字典 + """ + filter_rate = ( + self._total_filtered / self._total_processed + if self._total_processed > 0 + else 0.0 + ) + + stats = { + "config": { + "method": self.config.method.value, + "min_threshold": self.config.min_threshold, + "max_threshold": self.config.max_threshold, + "percentile": self.config.percentile, + "std_multiplier": self.config.std_multiplier, + "min_results": self.config.min_results, + "enable_auto_adjust": self.config.enable_auto_adjust, + }, + "statistics": { + "total_processed": self._total_processed, + "total_filtered": self._total_filtered, + "filter_rate": filter_rate, + "avg_threshold": float(np.mean(self._threshold_history)) + if self._threshold_history else 0.0, + "threshold_count": len(self._threshold_history), + }, + } + + if self._threshold_history: + stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history)) + stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history)) + + return stats + + def reset_statistics(self) -> None: + """重置统计信息""" + self._total_filtered = 0 + self._total_processed = 0 + self._threshold_history.clear() + logger.info("统计信息已重置") + + def __repr__(self) -> str: + return ( + f"DynamicThresholdFilter(" + f"method={self.config.method.value}, " + f"min_threshold={self.config.min_threshold}, " + 
f"filtered={self._total_filtered}/{self._total_processed})" + ) diff --git a/plugins/A_memorix/core/runtime/__init__.py b/plugins/A_memorix/core/runtime/__init__.py new file mode 100644 index 00000000..fa222715 --- /dev/null +++ b/plugins/A_memorix/core/runtime/__init__.py @@ -0,0 +1,8 @@ +"""SDK runtime exports for A_Memorix.""" + +from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel + +__all__ = [ + "KernelSearchRequest", + "SDKMemoryKernel", +] diff --git a/plugins/A_memorix/core/runtime/sdk_memory_kernel.py b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py new file mode 100644 index 00000000..7c8f9213 --- /dev/null +++ b/plugins/A_memorix/core/runtime/sdk_memory_kernel.py @@ -0,0 +1,579 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from src.common.logger import get_logger + +from ..embedding import create_embedding_api_adapter +from ..retrieval import ( + DualPathRetriever, + DualPathRetrieverConfig, + RetrievalResult, + SparseBM25Config, + SparseBM25Index, + TemporalQueryOptions, +) +from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore +from ..utils.aggregate_query_service import AggregateQueryService +from ..utils.episode_retrieval_service import EpisodeRetrievalService +from ..utils.hash import normalize_text +from ..utils.relation_write_service import RelationWriteService + +logger = get_logger("A_Memorix.SDKMemoryKernel") + + +@dataclass +class KernelSearchRequest: + query: str = "" + limit: int = 5 + mode: str = "hybrid" + chat_id: str = "" + person_id: str = "" + time_start: Optional[float] = None + time_end: Optional[float] = None + + +class SDKMemoryKernel: + def __init__(self, *, plugin_root: Path, config: Optional[Dict[str, Any]] = None) -> None: + self.plugin_root = Path(plugin_root).resolve() + self.config = config or {} + storage_cfg = self._cfg("storage", {}) or {} + data_dir = str(storage_cfg.get("data_dir", "./data") or "./data") + self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir) + self.embedding_dimension = max(32, int(self._cfg("embedding.dimension", 256))) + self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False)) + + self.embedding_manager = None + self.vector_store: Optional[VectorStore] = None + self.graph_store: Optional[GraphStore] = None + self.metadata_store: Optional[MetadataStore] = None + self.relation_write_service: Optional[RelationWriteService] = None + self.sparse_index = None + self.retriever: Optional[DualPathRetriever] = None + self.episode_retriever: Optional[EpisodeRetrievalService] = None + self.aggregate_query_service: Optional[AggregateQueryService] = None + self._initialized = False + self._last_maintenance_at: Optional[float] = None + + def _cfg(self, key: str, default: Any = None) -> Any: + current: Any = self.config + if key in {"storage", "embedding", "retrieval"} and isinstance(current, dict): + return current.get(key, default) + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + async def initialize(self) -> None: + if self._initialized: + return + self.data_dir.mkdir(parents=True, exist_ok=True) + self.embedding_manager = create_embedding_api_adapter( + batch_size=int(self._cfg("embedding.batch_size", 32)), + 
max_concurrent=int(self._cfg("embedding.max_concurrent", 5)), + default_dimension=self.embedding_dimension, + model_name=str(self._cfg("embedding.model_name", "hash-v1")), + retry_config=self._cfg("embedding.retry", {}) or {}, + ) + self.embedding_dimension = int(await self.embedding_manager._detect_dimension()) + self.vector_store = VectorStore( + dimension=self.embedding_dimension, + quantization_type=QuantizationType.INT8, + data_dir=self.data_dir / "vectors", + ) + self.graph_store = GraphStore(matrix_format=SparseMatrixFormat.CSR, data_dir=self.data_dir / "graph") + self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata") + self.metadata_store.connect() + if self.vector_store.has_data(): + self.vector_store.load() + self.vector_store.warmup_index(force_train=True) + if self.graph_store.has_data(): + self.graph_store.load() + + sparse_cfg = self._cfg("retrieval.sparse", {}) or {} + self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=SparseBM25Config(**sparse_cfg)) + if getattr(self.sparse_index.config, "enabled", False): + self.sparse_index.ensure_loaded() + + self.relation_write_service = RelationWriteService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + self.retriever = DualPathRetriever( + vector_store=self.vector_store, + graph_store=self.graph_store, + metadata_store=self.metadata_store, + embedding_manager=self.embedding_manager, + sparse_index=self.sparse_index, + config=DualPathRetrieverConfig( + top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 24)), + top_k_relations=int(self._cfg("retrieval.top_k_relations", 12)), + top_k_final=int(self._cfg("retrieval.top_k_final", 10)), + alpha=float(self._cfg("retrieval.alpha", 0.5)), + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), + ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), + enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), + sparse=sparse_cfg, + fusion=self._cfg("retrieval.fusion", {}) or {}, + graph_recall=self._cfg("retrieval.search.graph_recall", {}) or {}, + relation_intent=self._cfg("retrieval.search.relation_intent", {}) or {}, + ), + ) + self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever) + self.aggregate_query_service = AggregateQueryService(plugin_config=self.config) + self._initialized = True + + def close(self) -> None: + if self.vector_store is not None: + self.vector_store.save() + if self.graph_store is not None: + self.graph_store.save() + if self.metadata_store is not None: + self.metadata_store.close() + self._initialized = False + + async def ingest_summary( + self, + *, + external_id: str, + chat_id: str, + text: str, + participants: Optional[Sequence[str]] = None, + time_start: Optional[float] = None, + time_end: Optional[float] = None, + tags: Optional[Sequence[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + summary_meta = dict(metadata or {}) + summary_meta.setdefault("kind", "chat_summary") + return await self.ingest_text( + external_id=external_id, + source_type="chat_summary", + text=text, + chat_id=chat_id, + participants=participants, + time_start=time_start, + time_end=time_end, + tags=tags, + metadata=summary_meta, + ) + + async def ingest_text( + self, + *, + external_id: str, + source_type: str, + text: str, + chat_id: str = "", + 
person_ids: Optional[Sequence[str]] = None, + participants: Optional[Sequence[str]] = None, + timestamp: Optional[float] = None, + time_start: Optional[float] = None, + time_end: Optional[float] = None, + tags: Optional[Sequence[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + entities: Optional[Sequence[str]] = None, + relations: Optional[Sequence[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store and self.vector_store and self.graph_store and self.embedding_manager + assert self.relation_write_service + content = normalize_text(text) + if not content: + return {"stored_ids": [], "skipped_ids": [external_id], "reason": "empty_text"} + if ref := self.metadata_store.get_external_memory_ref(external_id): + return {"stored_ids": [], "skipped_ids": [str(ref.get("paragraph_hash", ""))], "reason": "exists"} + + person_tokens = self._tokens(person_ids) + participant_tokens = self._tokens(participants) + entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens) + source = self._build_source(source_type, chat_id, person_tokens) + paragraph_meta = dict(metadata or {}) + paragraph_meta.update( + { + "external_id": external_id, + "source_type": str(source_type or "").strip(), + "chat_id": str(chat_id or "").strip(), + "person_ids": person_tokens, + "participants": participant_tokens, + "tags": self._tokens(tags), + } + ) + paragraph_hash = self.metadata_store.add_paragraph( + content=content, + source=source, + metadata=paragraph_meta, + knowledge_type="factual" if source_type == "person_fact" else "narrative" if source_type == "chat_summary" else "mixed", + time_meta=self._time_meta(timestamp, time_start, time_end), + ) + embedding = await self.embedding_manager.encode(content) + self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash]) + for name in entity_tokens: + self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash) + + stored_relations: List[str] = [] + for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]: + s = str(row.get("subject", "") or "").strip() + p = str(row.get("predicate", "") or "").strip() + o = str(row.get("object", "") or "").strip() + if not (s and p and o): + continue + result = await self.relation_write_service.upsert_relation_with_vector( + subject=s, + predicate=p, + obj=o, + confidence=float(row.get("confidence", 1.0) or 1.0), + source_paragraph=paragraph_hash, + metadata={"external_id": external_id, "source_type": source_type}, + write_vector=self.relation_vectors_enabled, + ) + self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value) + stored_relations.append(result.hash_value) + + self.metadata_store.upsert_external_memory_ref( + external_id=external_id, + paragraph_hash=paragraph_hash, + source_type=source_type, + metadata={"chat_id": chat_id, "person_ids": person_tokens}, + ) + self._persist() + self.rebuild_episodes_for_sources([source]) + for person_id in person_tokens: + await self.refresh_person_profile(person_id) + return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []} + + async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]: + await self.initialize() + assert self.retriever and self.episode_retriever and self.aggregate_query_service + mode = str(request.mode or "hybrid").strip().lower() or "hybrid" + clean_query = str(request.query or "").strip() + limit = max(1, int(request.limit or 5)) + temporal = self._temporal(request) + if mode == 
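# A minimal write-path sketch for ingest_summary / ingest_text above; `kernel` is an
# initialized SDKMemoryKernel and every identifier below is illustrative.
async def write_examples(kernel):
    await kernel.ingest_summary(
        external_id="summary-group123-20240501",
        chat_id="group_123",
        text="The group discussed the release plan and assigned follow-ups.",
        participants=["alice", "bob"],
        tags=["release"],
    )
    await kernel.ingest_text(
        external_id="fact-alice-tea",
        source_type="person_fact",
        text="Alice prefers oolong tea.",
        person_ids=["alice"],
        relations=[{"subject": "Alice", "predicate": "likes", "object": "oolong tea", "confidence": 0.9}],
    )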
"episode": + rows = await self.episode_retriever.query( + query=clean_query, + top_k=limit, + time_from=request.time_start, + time_to=request.time_end, + source=self._chat_source(request.chat_id), + ) + hits = [self._episode_hit(row) for row in rows] + return {"summary": self._summary(hits), "hits": hits} + if mode == "aggregate": + payload = await self.aggregate_query_service.execute( + query=clean_query, + top_k=limit, + mix=True, + mix_top_k=limit, + time_from=str(request.time_start) if request.time_start is not None else None, + time_to=str(request.time_end) if request.time_end is not None else None, + search_runner=lambda: self._aggregate_search(clean_query, limit, temporal), + time_runner=lambda: self._aggregate_time(clean_query, limit, temporal), + episode_runner=lambda: self._aggregate_episode(clean_query, limit, request), + ) + hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)] + for item in hits: + item.setdefault("metadata", {}) + return {"summary": self._summary(hits), "hits": hits} + results = await self.retriever.retrieve(query=clean_query, top_k=limit, temporal=temporal) + hits = [self._retrieval_hit(item) for item in results] + return {"summary": self._summary(self._filter_hits(hits, request.person_id)), "hits": self._filter_hits(hits, request.person_id)} + + async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]: + _ = chat_id + await self.initialize() + assert self.metadata_store + snapshot = self.metadata_store.get_latest_person_profile_snapshot(person_id) or await self.refresh_person_profile(person_id, limit=limit) + evidence = [] + for hash_value in snapshot.get("evidence_ids", [])[: max(1, int(limit))]: + paragraph = self.metadata_store.get_paragraph(hash_value) + if paragraph is not None: + evidence.append({"hash": hash_value, "content": str(paragraph.get("content", "") or "")[:220], "metadata": paragraph.get("metadata", {}) or {}}) + text = str(snapshot.get("profile_text", "") or "").strip() + traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8] + return {"summary": text, "traits": traits, "evidence": evidence} + + async def refresh_person_profile(self, person_id: str, limit: int = 10) -> Dict[str, Any]: + await self.initialize() + assert self.metadata_store + rows = self.metadata_store.query( + """ + SELECT DISTINCT p.* + FROM paragraphs p + JOIN paragraph_entities pe ON pe.paragraph_hash = p.hash + JOIN entities e ON e.hash = pe.entity_hash + WHERE e.name = ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + ORDER BY COALESCE(p.event_time_end, p.event_time_start, p.event_time, p.updated_at, p.created_at) DESC + LIMIT ? 
+ """, + (person_id, max(1, int(limit)) * 3), + ) + evidence_ids = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] + vector_evidence = [{"hash": str(row.get("hash", "") or ""), "type": "paragraph", "score": 0.0, "content": str(row.get("content", "") or "")[:220], "metadata": row.get("metadata", {}) or {}} for row in rows[: max(1, int(limit))]] + relation_edges = [{"hash": str(row.get("hash", "") or ""), "subject": str(row.get("subject", "") or ""), "predicate": str(row.get("predicate", "") or ""), "object": str(row.get("object", "") or ""), "confidence": float(row.get("confidence", 1.0) or 1.0)} for row in self.metadata_store.get_relations(subject=person_id)[:limit]] + if relation_edges: + profile_text = "\n".join(f"{item['subject']} {item['predicate']} {item['object']}" for item in relation_edges[:6]) + elif vector_evidence: + profile_text = "\n".join(f"- {item['content']}" for item in vector_evidence[:6]) + else: + profile_text = "暂无稳定画像证据。" + return self.metadata_store.upsert_person_profile_snapshot( + person_id=person_id, + profile_text=profile_text, + aliases=[person_id], + relation_edges=relation_edges, + vector_evidence=vector_evidence, + evidence_ids=evidence_ids[: max(1, int(limit))], + expires_at=time.time() + 6 * 3600, + source_note="sdk_memory_kernel", + ) + + async def maintain_memory(self, *, action: str, target: str, hours: Optional[float] = None, reason: str = "") -> Dict[str, Any]: + _ = reason + await self.initialize() + assert self.metadata_store + hashes = self._resolve_relation_hashes(target) + if not hashes: + return {"success": False, "detail": "未命中可维护关系"} + act = str(action or "").strip().lower() + if act == "reinforce": + self.metadata_store.reinforce_relations(hashes) + elif act == "protect": + ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0 + self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0) + elif act == "restore": + restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value)) + if restored <= 0: + return {"success": False, "detail": "未恢复任何关系"} + else: + return {"success": False, "detail": f"不支持的维护动作: {act}"} + self._last_maintenance_at = time.time() + self._persist() + return {"success": True, "detail": f"{act} {len(hashes)} 条关系"} + + def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> int: + assert self.metadata_store + rebuilt = 0 + for source in self._tokens(sources): + rows = self.metadata_store.query( + """ + SELECT * FROM paragraphs + WHERE source = ? 
+ AND (is_deleted IS NULL OR is_deleted = 0) + ORDER BY COALESCE(event_time_start, event_time, created_at) ASC, hash ASC + """, + (source,), + ) + if not rows: + continue + paragraph_hashes = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()] + payload = self.metadata_store.upsert_episode( + { + "source": source, + "title": str((rows[0].get("metadata", {}) or {}).get("theme", "") or f"{source} 情景记忆")[:80], + "summary": ";".join(str(row.get("content", "") or "").strip().replace("\n", " ")[:120] for row in rows[:3] if str(row.get("content", "") or "").strip())[:500] or "自动构建的情景记忆。", + "participants": self._episode_participants(rows), + "keywords": self._episode_keywords(rows), + "evidence_ids": paragraph_hashes, + "paragraph_count": len(paragraph_hashes), + "event_time_start": self._time_bound(rows, "event_time_start", "event_time", reverse=False), + "event_time_end": self._time_bound(rows, "event_time_end", "event_time", reverse=True), + "time_granularity": "day", + "time_confidence": 0.7, + "llm_confidence": 0.0, + "segmentation_model": "rule_based_sdk", + "segmentation_version": "1", + } + ) + self.metadata_store.bind_episode_paragraphs(payload["episode_id"], paragraph_hashes) + rebuilt += 1 + return rebuilt + + def memory_stats(self) -> Dict[str, Any]: + assert self.metadata_store + stats = self.metadata_store.get_statistics() + episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"] + profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"] + return {"paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "last_maintenance_at": self._last_maintenance_at} + + async def _aggregate_search(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: + assert self.retriever + hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] + return {"success": True, "results": hits, "count": len(hits), "query_type": "search"} + + async def _aggregate_time(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]: + if temporal is None: + return {"success": False, "error": "missing temporal window", "results": []} + assert self.retriever + hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)] + return {"success": True, "results": hits, "count": len(hits), "query_type": "time"} + + async def _aggregate_episode(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]: + assert self.episode_retriever + rows = await self.episode_retriever.query(query=query, top_k=limit, time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id)) + hits = [self._episode_hit(row) for row in rows] + return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"} + + def _persist(self) -> None: + if self.vector_store is not None: + self.vector_store.save() + if self.graph_store is not None: + self.graph_store.save() + if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False): + self.sparse_index.ensure_loaded() + + @staticmethod + def _tokens(values: Optional[Iterable[Any]]) -> List[str]: + result: List[str] = [] + seen = set() + for item in values or []: + token = str(item or "").strip() + if not token 
or token in seen: + continue + seen.add(token) + result.append(token) + return result + + @classmethod + def _merge_tokens(cls, *groups: Optional[Iterable[Any]]) -> List[str]: + merged: List[str] = [] + seen = set() + for group in groups: + for item in cls._tokens(group): + if item in seen: + continue + seen.add(item) + merged.append(item) + return merged + + @staticmethod + def _build_source(source_type: str, chat_id: str, person_ids: Sequence[str]) -> str: + clean_type = str(source_type or "").strip() or "memory" + if clean_type == "chat_summary" and chat_id: + return f"chat_summary:{chat_id}" + if clean_type == "person_fact" and person_ids: + return f"person_fact:{person_ids[0]}" + return f"{clean_type}:{chat_id}" if chat_id else clean_type + + @staticmethod + def _chat_source(chat_id: str) -> Optional[str]: + clean = str(chat_id or "").strip() + return f"chat_summary:{clean}" if clean else None + + @staticmethod + def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + if timestamp is not None: + payload["event_time"] = float(timestamp) + if time_start is not None: + payload["event_time_start"] = float(time_start) + if time_end is not None: + payload["event_time_end"] = float(time_end) + if payload: + payload["time_granularity"] = "minute" + payload["time_confidence"] = 0.95 + return payload + + def _temporal(self, request: KernelSearchRequest) -> Optional[TemporalQueryOptions]: + if request.time_start is None and request.time_end is None and not request.chat_id: + return None + return TemporalQueryOptions(time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id)) + + @staticmethod + def _retrieval_hit(item: RetrievalResult) -> Dict[str, Any]: + payload = item.to_dict() + return {"hash": payload.get("hash", ""), "content": payload.get("content", ""), "score": payload.get("score", 0.0), "type": payload.get("type", ""), "source": payload.get("source", ""), "metadata": payload.get("metadata", {}) or {}} + + @staticmethod + def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]: + return {"type": "episode", "episode_id": str(row.get("episode_id", "") or ""), "title": str(row.get("title", "") or ""), "content": str(row.get("summary", "") or ""), "score": float(row.get("lexical_score", 0.0) or 0.0), "source": "episode", "metadata": {"participants": row.get("participants", []) or [], "keywords": row.get("keywords", []) or [], "source": row.get("source"), "event_time_start": row.get("event_time_start"), "event_time_end": row.get("event_time_end")}} + + @staticmethod + def _summary(hits: Sequence[Dict[str, Any]]) -> str: + if not hits: + return "" + lines = [] + for index, item in enumerate(hits[:5], start=1): + content = str(item.get("content", "") or "").strip().replace("\n", " ") + lines.append(f"{index}. 
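# A worked illustration of the _build_source branches above (ids are made up);
# _build_source is a staticmethod, so the expected outputs can be exercised directly.
from plugins.A_memorix.core.runtime import SDKMemoryKernel  # assumed import path

assert SDKMemoryKernel._build_source("chat_summary", "g1", []) == "chat_summary:g1"
assert SDKMemoryKernel._build_source("person_fact", "", ["alice"]) == "person_fact:alice"
assert SDKMemoryKernel._build_source("note", "g1", []) == "note:g1"
assert SDKMemoryKernel._build_source("", "", []) == "memory"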
{(content[:120] + '...') if len(content) > 120 else content}") + return "\n".join(lines) + + @staticmethod + def _filter_hits(hits: List[Dict[str, Any]], person_id: str) -> List[Dict[str, Any]]: + if not person_id: + return hits + filtered = [] + for item in hits: + metadata = item.get("metadata", {}) or {} + if person_id in (metadata.get("person_ids", []) or []): + filtered.append(item) + continue + if person_id and person_id in str(item.get("content", "") or ""): + filtered.append(item) + return filtered or hits + + @staticmethod + def _episode_participants(rows: Sequence[Dict[str, Any]]) -> List[str]: + seen = set() + result: List[str] = [] + for row in rows: + meta = row.get("metadata", {}) or {} + for key in ("participants", "person_ids"): + for item in meta.get(key, []) or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + result.append(token) + return result[:16] + + @staticmethod + def _episode_keywords(rows: Sequence[Dict[str, Any]]) -> List[str]: + seen = set() + result: List[str] = [] + for row in rows: + meta = row.get("metadata", {}) or {} + for item in meta.get("tags", []) or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + result.append(token) + return result[:12] + + @staticmethod + def _time_bound(rows: Sequence[Dict[str, Any]], primary: str, fallback: str, reverse: bool) -> Optional[float]: + values: List[float] = [] + for row in rows: + for key in (primary, fallback): + value = row.get(key) + try: + if value is not None: + values.append(float(value)) + break + except Exception: + continue + if not values: + return None + return max(values) if reverse else min(values) + + def _resolve_relation_hashes(self, target: str) -> List[str]: + assert self.metadata_store + token = str(target or "").strip() + if not token: + return [] + if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()): + return [token] + hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10) + if hashes: + return hashes + return [str(row.get("hash", "") or "") for row in self.metadata_store.get_relations(subject=token)[:10] if str(row.get("hash", "")).strip()] diff --git a/plugins/A_memorix/core/storage/__init__.py b/plugins/A_memorix/core/storage/__init__.py new file mode 100644 index 00000000..d878b8e7 --- /dev/null +++ b/plugins/A_memorix/core/storage/__init__.py @@ -0,0 +1,53 @@ +"""存储层""" + +from .vector_store import VectorStore, QuantizationType +from .graph_store import GraphStore, SparseMatrixFormat +from .metadata_store import MetadataStore +from .knowledge_types import ( + ImportStrategy, + KnowledgeType, + allowed_import_strategy_values, + allowed_knowledge_type_values, + get_knowledge_type_from_string, + get_import_strategy_from_string, + parse_import_strategy, + resolve_stored_knowledge_type, + should_extract_relations, + get_default_chunk_size, + get_type_display_name, + validate_stored_knowledge_type, +) +from .type_detection import ( + detect_knowledge_type, + get_type_from_user_input, + looks_like_factual_text, + looks_like_quote_text, + looks_like_structured_text, + select_import_strategy, +) + +__all__ = [ + "VectorStore", + "GraphStore", + "MetadataStore", + "QuantizationType", + "SparseMatrixFormat", + "ImportStrategy", + "KnowledgeType", + "allowed_import_strategy_values", + "allowed_knowledge_type_values", + "get_knowledge_type_from_string", + "get_import_strategy_from_string", + "parse_import_strategy", + "resolve_stored_knowledge_type", + 
"should_extract_relations", + "get_default_chunk_size", + "get_type_display_name", + "validate_stored_knowledge_type", + "detect_knowledge_type", + "get_type_from_user_input", + "looks_like_factual_text", + "looks_like_quote_text", + "looks_like_structured_text", + "select_import_strategy", +] diff --git a/plugins/A_memorix/core/storage/graph_store.py b/plugins/A_memorix/core/storage/graph_store.py new file mode 100644 index 00000000..8a075864 --- /dev/null +++ b/plugins/A_memorix/core/storage/graph_store.py @@ -0,0 +1,1434 @@ +""" +图存储模块 + +基于SciPy稀疏矩阵的知识图谱存储与计算。 +""" + +import pickle +from enum import Enum +from pathlib import Path +from typing import Optional, Union, Tuple, List, Dict, Set, Any +from collections import defaultdict +import threading +import asyncio + +import numpy as np + +class SparseMatrixFormat(Enum): + """稀疏矩阵格式""" + CSR = "csr" + CSC = "csc" + +try: + from scipy.sparse import csr_matrix, csc_matrix, triu, save_npz, load_npz, bmat, lil_matrix + from scipy.sparse.linalg import norm + HAS_SCIPY = True +except ImportError: + HAS_SCIPY = False + +import contextlib +from src.common.logger import get_logger +from ..utils.hash import compute_hash +from ..utils.io import atomic_write + +logger = get_logger("A_Memorix.GraphStore") + + +class GraphModificationMode(Enum): + """图修改模式""" + BATCH = "batch" # 批量模式 (默认, 适合一次性加载) + INCREMENTAL = "incremental" # 增量模式 (适合频繁随机写入, 使用LIL) + READ_ONLY = "read_only" # 只读模式 (适合计算, CSR/CSC) + + +class GraphStore: + """ + 图存储类 + + 功能: + - CSR稀疏矩阵存储图结构 + - 节点和边的CRUD操作 + - Personalized PageRank计算 + - 同义词自动连接 + - 图持久化 + + 参数: + matrix_format: 稀疏矩阵格式(csr/csc) + data_dir: 数据目录 + """ + + def __init__( + self, + matrix_format: str = "csr", + data_dir: Optional[Union[str, Path]] = None, + ): + """ + 初始化图存储 + + Args: + matrix_format: 稀疏矩阵格式(csr/csc) + data_dir: 数据目录 + """ + if not HAS_SCIPY: + raise ImportError("SciPy 未安装,请安装: pip install scipy") + + if isinstance(matrix_format, SparseMatrixFormat): + self.matrix_format = matrix_format.value + else: + self.matrix_format = str(matrix_format).lower() + self.data_dir = Path(data_dir) if data_dir else None + + # 节点管理 + self._nodes: List[str] = [] # 节点列表 + self._node_to_idx: Dict[str, int] = {} # 节点名到索引的映射 + self._node_attrs: Dict[str, Dict[str, Any]] = {} # 节点属性 + + # 边管理(邻接矩阵) + self._adjacency: Optional[Union[csr_matrix, csc_matrix]] = None + + # 统计信息 + self._total_nodes_added = 0 + self._total_edges_added = 0 + self._total_nodes_deleted = 0 + self._total_edges_deleted = 0 + + # 状态管理 + self._modification_mode = GraphModificationMode.BATCH + + # 状态管理 + self._adjacency_T: Optional[Union[csr_matrix, csc_matrix]] = None + self._adjacency_dirty: bool = True + self._saliency_cache: Optional[Dict[str, float]] = None + + # V5: 多关系映射 (src_idx, dst_idx) -> Set[relation_hash] + self._edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) + # V5: 简单的异步锁 (实际上 asyncio 环境下单线程主循环可能不需要,但为了安全保留) + self._lock = asyncio.Lock() + + logger.info(f"GraphStore 初始化: format={matrix_format}") + + def _canonicalize(self, node: str) -> str: + """规范化节点名称 (用于去重和内部索引)""" + if not node: + return "" + return str(node).strip().lower() + + @contextlib.contextmanager + def batch_update(self): + """ + 批量更新上下文管理器 + + 进入时切换到 LIL 格式以优化随机/增量更新 + 退出时恢复到 CSR/CSC 格式以优化存储和计算 + """ + original_mode = self._modification_mode + self._switch_mode(GraphModificationMode.INCREMENTAL) + try: + yield + finally: + self._switch_mode(original_mode) + + def _switch_mode(self, new_mode: GraphModificationMode): + """切换修改模式并转换矩阵格式""" + if new_mode == 
self._modification_mode: + return + + if self._adjacency is None: + self._modification_mode = new_mode + return + + logger.debug(f"切换图模式: {self._modification_mode.value} -> {new_mode.value}") + + # 转换逻辑 + if new_mode == GraphModificationMode.INCREMENTAL: + # 转换为 LIL 格式 + if not isinstance(self._adjacency, lil_matrix): # 粗略检查是否非 lil + try: + self._adjacency = self._adjacency.tolil() + logger.debug("已转换为 LIL 格式") + except Exception as e: + logger.warning(f"转换为 LIL 失败: {e}") + + elif new_mode in [GraphModificationMode.BATCH, GraphModificationMode.READ_ONLY]: + # 转换回配置的格式 (CSR/CSC) + if self.matrix_format == "csr": + self._adjacency = self._adjacency.tocsr() + elif self.matrix_format == "csc": + self._adjacency = self._adjacency.tocsc() + logger.debug(f"已恢复为 {self.matrix_format.upper()} 格式") + + self._modification_mode = new_mode + + def add_nodes( + self, + nodes: List[str], + attributes: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> int: + """ + 添加节点 + + Args: + nodes: 节点名称列表 + attributes: 节点属性字典 {node_name: {attr: value}} + + Returns: + 成功添加的节点数量 + """ + added = 0 + for node in nodes: + canon = self._canonicalize(node) + if canon in self._node_to_idx: + logger.debug(f"节点已存在,跳过: {node}") + continue + + # 添加到节点列表 + idx = len(self._nodes) + self._nodes.append(node) # 存储原始节点名 + self._node_to_idx[canon] = idx # 映射规范化节点名到索引 + self._adjacency_dirty = True + self._saliency_cache = None + + # 添加属性 + if attributes and node in attributes: + self._node_attrs[canon] = attributes[node] + else: + self._node_attrs[canon] = {} + + added += 1 + self._total_nodes_added += 1 + + # 扩展邻接矩阵 + if added > 0: + self._expand_adjacency_matrix(added) + + logger.debug(f"添加 {added} 个节点") + return added + + def add_edges( + self, + edges: List[Tuple[str, str]], + weights: Optional[List[float]] = None, + relation_hashes: Optional[List[str]] = None, # V5: 支持关系哈希映射 (Relation Hash Mapping) + ) -> int: + """ + 添加边 + + Args: + edges: 边列表 [(source, target), ...] 
+ weights: 边权重列表(默认为1.0) + + Returns: + 成功添加的边数量 + """ + if not edges: + return 0 + + # 确保所有节点存在 + nodes_to_add = set() + for src, tgt in edges: + src_canon = self._canonicalize(src) + tgt_canon = self._canonicalize(tgt) + if src_canon not in self._node_to_idx: + nodes_to_add.add(src) + if tgt_canon not in self._node_to_idx: + nodes_to_add.add(tgt) + + if nodes_to_add: + self.add_nodes(list(nodes_to_add)) + + # 处理权重 + if weights is None: + weights = [1.0] * len(edges) + + if len(weights) != len(edges): + raise ValueError(f"边数量与权重数量不匹配: {len(edges)} vs {len(weights)}") + + # 如果仅仅是添加边且处于增量模式 (LIL),直接更新 + if self._modification_mode == GraphModificationMode.INCREMENTAL: + if self._adjacency is None: + # 初始化为空 LIL + n = len(self._nodes) + from scipy.sparse import lil_matrix + self._adjacency = lil_matrix((n, n), dtype=np.float32) + + # 尝试直接使用 LIL 索引更新 + try: + # 批量获取索引 + rows = [self._node_to_idx[self._canonicalize(src)] for src, _ in edges] + cols = [self._node_to_idx[self._canonicalize(tgt)] for _, tgt in edges] + + # 确保矩阵足够大 (如果 add_nodes 没有扩展它) - 通常 add_nodes 会处理 + # 这里直接赋值 + self._adjacency[rows, cols] = weights + + self._total_edges_added += len(edges) + + # V5: Update edge hash map + if relation_hashes: + for (src, tgt), r_hash in zip(edges, relation_hashes): + if r_hash: + s_idx = self._node_to_idx[self._canonicalize(src)] + t_idx = self._node_to_idx[self._canonicalize(tgt)] + self._edge_hash_map[(s_idx, t_idx)].add(r_hash) + + logger.debug(f"增量添加 {len(edges)} 条边 (LIL)") + return len(edges) + except Exception as e: + logger.warning(f"LIL 增量更新失败,回退到通用方法: {e}") + # Fallback to general method below + + # 通用方法 (构建 COO 然后合并) + # 构建边的三元组 + row_indices = [] + col_indices = [] + data_values = [] + + for (src, tgt), weight in zip(edges, weights): + src_idx = self._node_to_idx[self._canonicalize(src)] + tgt_idx = self._node_to_idx[self._canonicalize(tgt)] + + row_indices.append(src_idx) + col_indices.append(tgt_idx) + data_values.append(weight) + + # 创建新的边的矩阵 + n = len(self._nodes) + new_edges = csr_matrix( + (data_values, (row_indices, col_indices)), + shape=(n, n), + ) + + # 合并到邻接矩阵 + if self._adjacency is None: + self._adjacency = new_edges + else: + self._adjacency = self._adjacency + new_edges + + # 转换为指定格式 + if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): + self._adjacency = self._adjacency.tocsc() + elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): + self._adjacency = self._adjacency.tocsr() + + self._total_edges_added += len(edges) + self._adjacency_dirty = True # 标记脏位 + + # V5: 更新边哈希映射 (Edge Hash Map) + if relation_hashes: + for (src, tgt), r_hash in zip(edges, relation_hashes): + if r_hash: + try: + s_idx = self._node_to_idx[self._canonicalize(src)] + t_idx = self._node_to_idx[self._canonicalize(tgt)] + self._edge_hash_map[(s_idx, t_idx)].add(r_hash) + except KeyError: + pass # 正常情况下节点已在上方添加,此处仅作防错处理 + + logger.debug(f"添加 {len(edges)} 条边") + return len(edges) + + def update_edge_weight( + self, + source: str, + target: str, + delta: float, + min_weight: float = 0.1, + max_weight: float = 10.0, + ) -> float: + """ + 更新边权重 (增量/强化/弱化) + + Args: + source: 源节点 + target: 目标节点 + delta: 权重变化量 (+/-) + min_weight: 最小权重限制 + max_weight: 最大权重限制 + + Returns: + 更新后的权重 + """ + src_canon = self._canonicalize(source) + tgt_canon = self._canonicalize(target) + + if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: + logger.warning(f"节点不存在,无法更新权重: {source} -> {target}") + return 0.0 + + current_weight = 
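# A minimal sketch of building a graph with the node/edge APIs above; the import path,
# node names, and relation hashes are illustrative.
from plugins.A_memorix.core.storage import GraphStore  # assumed import path

store = GraphStore(matrix_format="csr")
store.add_nodes(["Alice", "Bob", "oolong tea"], attributes={"Alice": {"kind": "person"}})
with store.batch_update():  # keeps the matrix in LIL form while edges are written
    store.add_edges(
        edges=[("Alice", "Bob"), ("Alice", "oolong tea")],
        weights=[1.0, 0.8],
        relation_hashes=["relhash-1", "relhash-2"],  # hypothetical hashes
    )
store.update_edge_weight("Alice", "Bob", delta=0.5)  # reinforce an existing edge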
self.get_edge_weight(source, target) + if current_weight == 0.0 and delta <= 0: + # 边不存在且试图减少权重,忽略 + return 0.0 + + # 如果边不存在但 delta > 0,相当于添加新边 (默认基础权重0 + delta) + # 但为了逻辑清晰,我们假设只更新存在的边,或者确实添加 + + new_weight = current_weight + delta + new_weight = max(min_weight, min(max_weight, new_weight)) + + # 使用 batch_update 上下文自动处理格式转换 + # 这里我们临时切换到 incremental 模式进行单次更新 + with self.batch_update(): + # add_edges 会覆盖或添加,我们需要覆盖 + self.add_edges([(source, target)], [new_weight]) + + logger.debug(f"更新权重 {source}->{target}: {current_weight:.2f} -> {new_weight:.2f}") + return new_weight + + def delete_nodes(self, nodes: List[str]) -> int: + """ + 删除节点(及相关的边) + + Args: + nodes: 要删除的节点列表 + + Returns: + 成功删除的节点数量 + """ + if not nodes: + return 0 + + # 检查哪些节点存在 + existing_nodes = [node for node in nodes if self._canonicalize(node) in self._node_to_idx] + if not existing_nodes: + logger.warning("所有节点都不存在,无法删除") + return 0 + + # 获取要删除的索引 + indices_to_delete = {self._node_to_idx[self._canonicalize(node)] for node in existing_nodes} + indices_to_keep = [ + i for i in range(len(self._nodes)) + if i not in indices_to_delete + ] + + # 创建索引映射 + old_to_new = {old_idx: new_idx for new_idx, old_idx in enumerate(indices_to_keep)} + + # 重建节点列表 (存储原始节点名) + self._nodes = [self._nodes[i] for i in indices_to_keep] + # 重建规范化节点名到索引的映射 + self._node_to_idx = {self._canonicalize(self._nodes[new_idx]): new_idx for new_idx in range(len(self._nodes))} + + # 删除并重构节点属性 + new_node_attrs = {} + for idx, node_name in enumerate(self._nodes): + canon = self._canonicalize(node_name) + if canon in self._node_attrs: + new_node_attrs[canon] = self._node_attrs[canon] + self._node_attrs = new_node_attrs + + # 重建邻接矩阵 + if self._adjacency is not None: + # 转换为COO格式以进行切片,然后转换回原始格式 + self._adjacency = self._adjacency.tocoo() + mask_rows = np.isin(self._adjacency.row, list(indices_to_keep)) + mask_cols = np.isin(self._adjacency.col, list(indices_to_keep)) + + # 筛选出保留的行和列 + new_rows = self._adjacency.row[mask_rows & mask_cols] + new_cols = self._adjacency.col[mask_rows & mask_cols] + new_data = self._adjacency.data[mask_rows & mask_cols] + + # 更新索引 + new_rows = np.array([old_to_new[r] for r in new_rows]) + new_cols = np.array([old_to_new[c] for c in new_cols]) + + n = len(self._nodes) + if self.matrix_format == "csr": + self._adjacency = csr_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) + else: # csc + self._adjacency = csc_matrix((new_data, (new_rows, new_cols)), shape=(n, n)) + + # 重建关系哈希映射,移除涉及已删除节点的记录并重映射索引。 + if self._edge_hash_map: + new_edge_hash_map: Dict[Tuple[int, int], Set[str]] = defaultdict(set) + for (old_src, old_tgt), hashes in self._edge_hash_map.items(): + if old_src in indices_to_delete or old_tgt in indices_to_delete: + continue + if old_src in old_to_new and old_tgt in old_to_new and hashes: + new_edge_hash_map[(old_to_new[old_src], old_to_new[old_tgt])] = set(hashes) + self._edge_hash_map = new_edge_hash_map + + deleted_count = len(existing_nodes) + self._total_nodes_deleted += deleted_count + self._adjacency_dirty = True + self._saliency_cache = None + + logger.info(f"删除 {deleted_count} 个节点") + return deleted_count + + def remove_nodes(self, nodes: List[str]) -> int: + """兼容性别名:删除节点""" + return self.delete_nodes(nodes) + + def delete_edges( + self, + edges: List[Tuple[str, str]], + ) -> int: + """ + 删除边 + + Args: + edges: 要删除的边列表 [(source, target), ...] 
+ + Returns: + 成功删除的边数量 + """ + if not edges: + return 0 + + deleted = 0 + # 构建要删除的边的索引集合 + edges_to_delete = set() + for src, tgt in edges: + src_canon = self._canonicalize(src) + tgt_canon = self._canonicalize(tgt) + if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: + src_idx = self._node_to_idx[src_canon] + tgt_idx = self._node_to_idx[tgt_canon] + edges_to_delete.add((src_idx, tgt_idx)) + + if self._adjacency is not None and edges_to_delete: + # 转换为COO格式便于修改 + adj_coo = self._adjacency.tocoo() + + # 过滤要删除的边 + new_row = [] + new_col = [] + new_data = [] + + for i, j, val in zip(adj_coo.row, adj_coo.col, adj_coo.data): + if (i, j) not in edges_to_delete: + new_row.append(i) + new_col.append(j) + new_data.append(val) + else: + deleted += 1 + + # 重建邻接矩阵 + n = len(self._nodes) + self._adjacency = csr_matrix((new_data, (new_row, new_col)), shape=(n, n)) + + # 转换回指定格式 + if self.matrix_format == "csc": + self._adjacency = self._adjacency.tocsc() + + # delete_edges 是“物理删除”语义,必须同步清理 edge_hash_map。 + if edges_to_delete and self._edge_hash_map: + for key in edges_to_delete: + self._edge_hash_map.pop(key, None) + + self._total_edges_deleted += deleted + self._adjacency_dirty = True + self._saliency_cache = None + logger.info(f"删除 {deleted} 条边") + return deleted + + def remove_edges(self, edges: List[Tuple[str, str]]) -> int: + """兼容性别名:删除边""" + return self.delete_edges(edges) + + def get_nodes(self) -> List[str]: + """ + 获取所有节点 + + Returns: + 节点列表 + """ + return self._nodes.copy() + + def has_node(self, node: str) -> bool: + """ + 检查节点是否存在 + + Args: + node: 节点名称 + """ + return self._canonicalize(node) in self._node_to_idx + + def find_node(self, node: str, ignore_case: bool = False) -> Optional[str]: + """ + 查找节点 (由于底层已统一规范化,ignore_case 始终有效) + + Args: + node: 节点名称 + ignore_case: 是否忽略大小写 (已默认忽略) + + Returns: + 真实节点名称 (如果存在),否则 None + """ + canon = self._canonicalize(node) + if canon in self._node_to_idx: + return self._nodes[self._node_to_idx[canon]] + return None + + def get_node_attributes(self, node: str) -> Optional[Dict[str, Any]]: + """ + 获取节点属性 + + Args: + node: 节点名称 + + Returns: + 节点属性字典,不存在则返回None + """ + canon = self._canonicalize(node) + return self._node_attrs.get(canon) + + def get_neighbors(self, node: str) -> List[str]: + """ + 获取节点的出邻居 + + Args: + node: 节点名称 + + Returns: + 出邻居节点列表 + """ + canon = self._canonicalize(node) + if canon not in self._node_to_idx or self._adjacency is None: + return [] + + idx = self._node_to_idx[canon] + neighbor_indices = self._row_neighbor_indices(self._adjacency, idx) + return [self._nodes[int(i)] for i in neighbor_indices] + + def get_in_neighbors(self, node: str) -> List[str]: + """ + 获取节点的入邻居 + + Args: + node: 节点名称 + + Returns: + 入邻居节点列表 + """ + canon = self._canonicalize(node) + if canon not in self._node_to_idx or self._adjacency is None: + return [] + + self._ensure_adjacency_T() + if self._adjacency_T is None: + return [] + + idx = self._node_to_idx[canon] + neighbor_indices = self._row_neighbor_indices(self._adjacency_T, idx) + return [self._nodes[int(i)] for i in neighbor_indices] + + def get_edge_weight(self, source: str, target: str) -> float: + """ + 获取边的权重 + """ + src_canon = self._canonicalize(source) + tgt_canon = self._canonicalize(target) + + if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: + return 0.0 + + if self._adjacency is None: + return 0.0 + + src_idx = self._node_to_idx[src_canon] + tgt_idx = self._node_to_idx[tgt_canon] + + return float(self._adjacency[src_idx, tgt_idx]) + + def 
canonicalize_node(self, node: str) -> str: + """公开节点规范化接口,避免外部访问私有方法。""" + return self._canonicalize(node) + + def has_edge_hash_map(self) -> bool: + """是否存在 relation-hash 映射。""" + return bool(self._edge_hash_map) + + def get_relation_hashes_for_edge(self, source: str, target: str) -> Set[str]: + """获取边 (source -> target) 关联的关系哈希集合。""" + src_canon = self._canonicalize(source) + tgt_canon = self._canonicalize(target) + if src_canon not in self._node_to_idx or tgt_canon not in self._node_to_idx: + return set() + src_idx = self._node_to_idx[src_canon] + tgt_idx = self._node_to_idx[tgt_canon] + return set(self._edge_hash_map.get((src_idx, tgt_idx), set())) + + def get_incident_relation_hashes(self, node: str, limit: Optional[int] = None) -> List[str]: + """获取与指定节点关联的关系哈希列表(入边 + 出边)。""" + canon = self._canonicalize(node) + if canon not in self._node_to_idx or not self._edge_hash_map: + return [] + + idx = self._node_to_idx[canon] + limit_val = max(1, int(limit)) if limit is not None else None + collected: List[Tuple[str, str, str]] = [] + idx_to_node = self._nodes + + for (src_idx, tgt_idx), hashes in self._edge_hash_map.items(): + if idx not in {src_idx, tgt_idx}: + continue + src_name = idx_to_node[src_idx] if 0 <= src_idx < len(idx_to_node) else "" + tgt_name = idx_to_node[tgt_idx] if 0 <= tgt_idx < len(idx_to_node) else "" + for hash_value in hashes: + hash_text = str(hash_value).strip() + if not hash_text: + continue + collected.append((src_name, tgt_name, hash_text)) + + collected.sort(key=lambda x: (x[0].lower(), x[1].lower(), x[2])) + out: List[str] = [] + seen = set() + for _, _, hash_value in collected: + if hash_value in seen: + continue + seen.add(hash_value) + out.append(hash_value) + if limit_val is not None and len(out) >= limit_val: + break + return out + + def edge_contains_relation_hash(self, source: str, target: str, hash_value: str) -> bool: + """判断边是否包含指定关系哈希。""" + if not hash_value: + return False + return str(hash_value) in self.get_relation_hashes_for_edge(source, target) + + def iter_edge_hash_entries(self) -> List[Tuple[str, str, Set[str]]]: + """以节点名形式遍历 edge-hash-map。""" + out: List[Tuple[str, str, Set[str]]] = [] + if not self._edge_hash_map: + return out + idx_to_node = self._nodes + for (s_idx, t_idx), hashes in self._edge_hash_map.items(): + if not hashes: + continue + if s_idx < 0 or t_idx < 0: + continue + if s_idx >= len(idx_to_node) or t_idx >= len(idx_to_node): + continue + out.append((idx_to_node[s_idx], idx_to_node[t_idx], set(hashes))) + return out + + def deactivate_edges(self, edges: List[Tuple[str, str]]) -> int: + """ + 冻结边 (将权重设为0.0,使其在计算意义上消失,但保留在Map中) + + Args: + edges: [(s1, t1), (s2, t2)...] + """ + if not edges or self._adjacency is None: + return 0 + + deactivated_count = 0 + with self.batch_update(): + # 我们需要 explicit set to 0. 
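# A minimal lookup sketch for the read helpers above, continuing the `store` built in
# the earlier sketch; the annotated values follow from those example edges.
print(store.find_node("ALICE"))                            # lookups are canonicalized -> "Alice"
print(store.get_neighbors("alice"))                        # out-neighbours of Alice
print(store.get_edge_weight("Alice", "Bob"))
print(store.get_relation_hashes_for_edge("Alice", "Bob"))  # -> {"relhash-1"}
print(store.get_incident_relation_hashes("Alice", limit=10))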
+ # 使用增量更新模式覆盖 + for s, t in edges: + s_canon = self._canonicalize(s) + t_canon = self._canonicalize(t) + if s_canon in self._node_to_idx and t_canon in self._node_to_idx: + idx_s = self._node_to_idx[s_canon] + idx_t = self._node_to_idx[t_canon] + self._adjacency[idx_s, idx_t] = 0.0 + deactivated_count += 1 + + self._adjacency_dirty = True + return deactivated_count + + def _ensure_adjacency_T(self): + """确保转置邻接矩阵是最新的""" + if self._adjacency is None: + self._adjacency_T = None + return + + if self._adjacency_dirty or self._adjacency_T is None: + # 只有在确实需要时才计算转置 + # find_paths 以“按行读取邻居”为主,因此统一缓存为 CSR,避免 + # CSR->CSC 转置后按行切片读出错误的索引视图。 + self._adjacency_T = self._adjacency.transpose().tocsr() + + self._adjacency_dirty = False + # logger.debug("重建转置邻接矩阵缓存") + + @staticmethod + def _row_neighbor_indices( + matrix: Optional[Union[csr_matrix, csc_matrix]], + row_idx: int, + ) -> np.ndarray: + """返回指定行的非零列索引。""" + if matrix is None: + return np.asarray([], dtype=np.int32) + + if isinstance(matrix, csr_matrix): + return matrix.indices[matrix.indptr[row_idx]:matrix.indptr[row_idx + 1]] + + row = matrix[row_idx, :] + _, indices = row.nonzero() + return np.asarray(indices, dtype=np.int32) + + def find_paths( + self, + start_node: str, + end_node: str, + max_depth: int = 3, + max_paths: int = 5, + max_expansions: int = 20000 + ) -> List[List[str]]: + """ + 查找两个节点之间的路径 (BFS) + 支持有向和无向 (视作双向) 探索 + + Args: + start_node: 起始节点 + end_node: 目标节点 + max_depth: 最大深度 + max_paths: 最大路径数 (找到这么多就停止) + max_expansions: 最大扩展次数 (防止爆炸) + + Returns: + 路径列表 [[n1, n2, n3], ...] + """ + start_canon = self._canonicalize(start_node) + end_canon = self._canonicalize(end_node) + + if start_canon not in self._node_to_idx or end_canon not in self._node_to_idx: + return [] + + if self._adjacency is None: + return [] + + # 确保转置矩阵可用 (用于查找入边) + self._ensure_adjacency_T() + + start_idx = self._node_to_idx[start_canon] + end_idx = self._node_to_idx[end_canon] + + # 队列: (current_idx, path_indices) + queue = [(start_idx, [start_idx])] + found_paths = [] + expansions = 0 + + unique_paths = set() + + while queue: + curr, path = queue.pop(0) + + if len(path) > max_depth + 1: + continue + + if curr == end_idx: + # 找到路径 + # 转换回节点名 + path_names = [self._nodes[i] for i in path] + path_tuple = tuple(path_names) + if path_tuple not in unique_paths: + found_paths.append(path_names) + unique_paths.add(path_tuple) + + if len(found_paths) >= max_paths: + break + continue + + if expansions >= max_expansions: + break + + expansions += 1 + + # 获取邻居 (出边 + 入边) + out_indices = self._row_neighbor_indices(self._adjacency, curr) + + # 2. 
入边 (使用转置矩阵) + if self._adjacency_T is not None: + in_indices = self._row_neighbor_indices(self._adjacency_T, curr) + neighbors = np.concatenate((out_indices, in_indices)) + else: + neighbors = out_indices + + # 去重并过滤已在路径中的节点 (防止环) + # 注意: 这里简单去重,可能包含重复的邻居(如果既是出又是入) + seen_in_path = set(path) + queued_neighbors = set() + + for neighbor_idx in neighbors: + neighbor = int(neighbor_idx) + if neighbor not in seen_in_path and neighbor not in queued_neighbors: + # 只有未访问过的才加入 + queue.append((neighbor, path + [neighbor])) + queued_neighbors.add(neighbor) + + return found_paths + + def compute_pagerank( + self, + personalization: Optional[Dict[str, float]] = None, + alpha: float = 0.85, + max_iter: int = 100, + tol: float = 1e-6, + ) -> Dict[str, float]: + """ + 计算Personalized PageRank + + Args: + personalization: 个性化向量 {node: weight},默认为均匀分布 + alpha: 阻尼系数(0-1之间) + max_iter: 最大迭代次数 + tol: 收敛阈值 + + Returns: + 节点PageRank值字典 {node: score} + """ + if self._adjacency is None or len(self._nodes) == 0: + logger.warning("图为空,无法计算PageRank") + return {} + + n = len(self._nodes) + + # 构建列归一化的转移矩阵 + adj = self._adjacency.astype(np.float32) + + # 计算出度 + out_degrees = np.array(adj.sum(axis=1)).flatten() + + # 处理悬挂节点(出度为0) + dangling = (out_degrees == 0) + out_degrees_inv = np.zeros_like(out_degrees) + out_degrees_inv[~dangling] = 1.0 / out_degrees[~dangling] + + # 归一化 (使用稀疏对角阵避免内存溢出) + from scipy.sparse import diags + D_inv = diags(out_degrees_inv) + M = adj.T @ D_inv # 转移矩阵 + + # 初始化个性化向量 + if personalization is None: + # 均匀分布 + p = np.ones(n) / n + else: + # 使用指定的个性化向量 + p = np.zeros(n) + total_weight = sum(personalization.values()) + for node, weight in personalization.items(): + if node in self._node_to_idx: + idx = self._node_to_idx[node] + p[idx] = weight / total_weight + + # 确保和为1 + if p.sum() == 0: + p = np.ones(n) / n + else: + p = p / p.sum() + + # 幂迭代法 + p_orig = p.copy() + for i in range(max_iter): + # p_new = alpha * M * p + (1-alpha) * personalization + p_new = alpha * (M @ p) + (1 - alpha) * p_orig + + # 处理因为悬挂节点导致的概率流失 + current_sum = p_new.sum() + if current_sum < 1.0: + p_new += (1.0 - current_sum) * p_orig + + # 检查收敛 + diff = np.linalg.norm(p_new - p, 1) + if diff < tol: + logger.debug(f"PageRank在 {i+1} 次迭代后收敛") + p = p_new + break + p = p_new + else: + logger.warning(f"PageRank未在 {max_iter} 次迭代内收敛") + + # 转换为真实节点名称字典 + return {self._nodes[idx]: float(val) for idx, val in enumerate(p)} + + def get_saliency_scores(self) -> Dict[str, float]: + """ + 获取节点显著性得分 (带有缓存机制) + """ + if self._saliency_cache is not None and not self._adjacency_dirty: + return self._saliency_cache + + logger.debug("正在计算节点显著性得分 (PageRank)...") + scores = self.compute_pagerank() + self._saliency_cache = scores + # 注意:这里我们不把 _adjacency_dirty 设为 False,因为其它逻辑(如_adjacency_T)也依赖它 + return scores + + def connect_synonyms( + self, + similarity_matrix: np.ndarray, + node_list: List[str], + threshold: float = 0.85, + ) -> int: + """ + 连接相似节点(同义词) + + Args: + similarity_matrix: 相似度矩阵 (N x N) + node_list: 对应的节点列表(长度为N) + threshold: 相似度阈值 + + Returns: + 添加的边数量 + """ + if len(node_list) != similarity_matrix.shape[0]: + raise ValueError( + f"节点列表长度与相似度矩阵维度不匹配: " + f"{len(node_list)} vs {similarity_matrix.shape[0]}" + ) + + # 找到相似的节点对(上三角,排除对角线) + similar_pairs = np.argwhere( + (triu(similarity_matrix, k=1) >= threshold) & + (triu(similarity_matrix, k=1) < 1.0) # 排除完全相同的 + ) + + # 添加边 + edges = [] + for i, j in similar_pairs: + if i < len(node_list) and j < len(node_list): + src = node_list[i] + tgt = node_list[j] + # 使用相似度作为权重 + weight = 
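# A minimal sketch of path finding and personalized PageRank as implemented above,
# continuing the same `store`. Note that personalization keys are matched against the
# canonical (lower-cased) node index, so lower-cased names are used here.
paths = store.find_paths("Alice", "oolong tea", max_depth=2, max_paths=3)
scores = store.compute_pagerank(personalization={"alice": 1.0}, alpha=0.85)
top_nodes = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:5]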
float(similarity_matrix[i, j]) + edges.append((src, tgt, weight)) + + if edges: + edge_pairs = [(src, tgt) for src, tgt, _ in edges] + weights = [w for _, _, w in edges] + count = self.add_edges(edge_pairs, weights) + logger.info(f"连接 {count} 对相似节点(阈值={threshold})") + return count + return 0 + + + # ========================================================================= + # V5 Memory System Methods (Graph Level) + # ========================================================================= + + def decay(self, factor: float, min_active_weight: float = 0.0) -> None: + """ + 全图衰减 (Atomic Decay) + + Args: + factor: 衰减因子 (0.0 < factor < 1.0) + min_active_weight: 最小活跃权重 (低于此值可能被视为无效,但在物理修剪前仍保留) + """ + if self._adjacency is None or factor >= 1.0 or factor <= 0.0: + return + + logger.debug(f"正在执行全图衰减,因子: {factor}") + + # 直接矩阵乘法,SciPy CSR/CSC 非常高效 + self._adjacency *= factor + + # 如果需要处理极小值 (可选,防止下溢,但通常浮点数足够小) + # if min_active_weight > 0: + # ... (复杂操作,暂不需要,由 prune 逻辑处理) + + self._adjacency_dirty = True + + def prune_relation_hashes(self, operations: List[Tuple[str, str, str]]) -> None: + """ + 修剪特定关系哈希 (从 _edge_hash_map 移除; 如果边变空则从矩阵移除) + + Args: + operations: List[(src, tgt, relation_hash)] + """ + if not operations: + return + + edges_to_check_removal = set() + + # 1. 更新映射 (Update Map) + for src, tgt, h in operations: + src_canon = self._canonicalize(src) + tgt_canon = self._canonicalize(tgt) + if src_canon in self._node_to_idx and tgt_canon in self._node_to_idx: + s_idx = self._node_to_idx[src_canon] + t_idx = self._node_to_idx[tgt_canon] + + key = (s_idx, t_idx) + if key in self._edge_hash_map: + if h in self._edge_hash_map[key]: + self._edge_hash_map[key].remove(h) + + if not self._edge_hash_map[key]: + del self._edge_hash_map[key] + edges_to_check_removal.add((src, tgt)) + + # 2. 
从矩阵中移除空边 (Remove Empty Edges from Matrix) + if edges_to_check_removal: + self.deactivate_edges(list(edges_to_check_removal)) + self._total_edges_deleted += len(edges_to_check_removal) + + def get_low_weight_edges(self, threshold: float) -> List[Tuple[str, str]]: + """ + 获取低于阈值的边 (candidates for pruning/freezing) + + Args: + threshold: 权重阈值 + + Returns: + List[(src, tgt)]: 边列表 + """ + if self._adjacency is None: + return [] + + # 获取所有非零元素 + rows, cols = self._adjacency.nonzero() + data = self._adjacency.data + + low_weight_indices = np.where(data < threshold)[0] + + results = [] + for idx in low_weight_indices: + r = rows[idx] + c = cols[idx] + src = self._nodes[r] + tgt = self._nodes[c] + results.append((src, tgt)) + + return results + + def get_isolated_nodes(self, include_inactive: bool = True) -> List[str]: + """ + 获取孤儿节点 (Active Degree = 0) + + Args: + include_inactive: 是否包含参与了inactive边(冻结边)的节点。 + 如果 True (默认推荐): 排除掉虽然active degree=0但存在于_edge_hash_map(冻结边)中的节点。 + 如果 False: 只要在 active matrix 里度为0就返回 (可能会误删冻结节点)。 + + Returns: + 孤儿节点名称列表 + """ + if self._adjacency is None: + # 如果全空,则所有节点都是孤儿 + return self._nodes.copy() + + n = len(self._nodes) + + # 计算 Active Degree (In + Out) + # 用 sum(axis) 会得到 dense matrix/array + active_adj = self._adjacency + out_degrees = np.array(active_adj.sum(axis=1)).flatten() + in_degrees = np.array(active_adj.sum(axis=0)).flatten() + + # 处理悬挂节点 (dangling node check not really needed here, just sum) + total_degrees = out_degrees + in_degrees + + # 找到 active degree = 0 的索引 + isolated_indices = np.where(total_degrees == 0)[0] + + if len(isolated_indices) == 0: + return [] + + isolated_nodes_set = {self._nodes[i] for i in isolated_indices} + + # 如果需要排除 Inactive 参与者 + if include_inactive and self._edge_hash_map: + # 收集所有在冻结边中的 unique 节点索引 + frozen_participant_indices = set() + for (u_idx, v_idx), hashes in self._edge_hash_map.items(): + if hashes: # 只要有 hash 存在 (哪怕 inactive) + frozen_participant_indices.add(u_idx) + frozen_participant_indices.add(v_idx) + + # 过滤 + final_isolated = [] + for idx in isolated_indices: + if idx not in frozen_participant_indices: + final_isolated.append(self._nodes[idx]) + return final_isolated + + else: + return list(isolated_nodes_set) + + def clear(self) -> None: + """清空所有数据""" + self._nodes.clear() + self._node_to_idx.clear() + self._node_attrs.clear() + self._adjacency = None + self._edge_hash_map.clear() + self._adjacency_T = None + self._adjacency_dirty = True + self._total_nodes_added = 0 + self._total_edges_added = 0 + self._total_nodes_deleted = 0 + self._total_edges_deleted = 0 + logger.info("图存储已清空") + + def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: + """ + 保存到磁盘 + + Args: + data_dir: 数据目录(默认使用初始化时的目录) + """ + if data_dir is None: + data_dir = self.data_dir + + if data_dir is None: + raise ValueError("未指定数据目录") + + data_dir = Path(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + + # 保存邻接矩阵 + if self._adjacency is not None: + matrix_path = data_dir / "graph_adjacency.npz" + with atomic_write(matrix_path, "wb") as f: + save_npz(f, self._adjacency) + logger.debug(f"保存邻接矩阵: {matrix_path}") + + # 保存元数据 + metadata = { + "nodes": self._nodes, + "node_to_idx": self._node_to_idx, + "node_attrs": self._node_attrs, + "matrix_format": self.matrix_format, + "total_nodes_added": self._total_nodes_added, + "total_edges_added": self._total_edges_added, + "total_nodes_deleted": self._total_nodes_deleted, + "total_edges_deleted": self._total_edges_deleted, + "edge_hash_map": dict(self._edge_hash_map), # 持久化 V5 
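# A minimal sketch of a periodic decay-and-prune pass over the same store, using the
# weight decay, low-weight scan, edge freezing, and orphan detection defined above.
store.decay(0.95)                                   # multiply every active edge weight by 0.95
weak_edges = store.get_low_weight_edges(0.2)        # candidates for freezing
store.deactivate_edges(weak_edges)                  # zero them out while keeping their hash mapping
orphans = store.get_isolated_nodes(include_inactive=True)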
映射 (将 defaultdict 转换为普通 dict) + } + + metadata_path = data_dir / "graph_metadata.pkl" + with atomic_write(metadata_path, "wb") as f: + pickle.dump(metadata, f) + logger.debug(f"保存元数据: {metadata_path}") + + logger.info(f"图存储已保存到: {data_dir}") + + def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: + """ + 从磁盘加载 + + Args: + data_dir: 数据目录(默认使用初始化时的目录) + """ + if data_dir is None: + data_dir = self.data_dir + + if data_dir is None: + raise ValueError("未指定数据目录") + + data_dir = Path(data_dir) + if not data_dir.exists(): + raise FileNotFoundError(f"数据目录不存在: {data_dir}") + + # 加载元数据 + metadata_path = data_dir / "graph_metadata.pkl" + if not metadata_path.exists(): + raise FileNotFoundError(f"元数据文件不存在: {metadata_path}") + + with open(metadata_path, "rb") as f: + metadata = pickle.load(f) + + # 恢复状态,并通过规范化处理旧数据中的重复项 + self._nodes = metadata["nodes"] + self._node_attrs = {} # 重新构建以确保键名 (Key) 规范化 + self._node_to_idx = {} # 重新构建以确保键名 (Key) 规范化 + + # 重新构建映射,处理旧数据中的碰撞 + for idx, node_name in enumerate(self._nodes): + canon = self._canonicalize(node_name) + if canon not in self._node_to_idx: + self._node_to_idx[canon] = idx + + # 处理属性 (优先保留已有的) + orig_attrs = metadata.get("node_attrs", {}) + if node_name in orig_attrs and canon not in self._node_attrs: + self._node_attrs[canon] = orig_attrs[node_name] + + self.matrix_format = metadata["matrix_format"] + self._total_nodes_added = metadata["total_nodes_added"] + self._total_edges_added = metadata["total_edges_added"] + self._total_nodes_deleted = metadata["total_nodes_deleted"] + self._total_edges_deleted = metadata["total_edges_deleted"] + + # 恢复 V5 边哈希映射 (Restore V5 edge hash map) + edge_map_data = metadata.get("edge_hash_map", {}) + # 重新初始化为 defaultdict(set) + self._edge_hash_map = defaultdict(set) + if edge_map_data: + for k, v in edge_map_data.items(): + self._edge_hash_map[k] = set(v) # 确保类型为 set + + # 加载邻接矩阵 + matrix_path = data_dir / "graph_adjacency.npz" + if matrix_path.exists(): + self._adjacency = load_npz(str(matrix_path)) + + # 确保格式正确 + if self.matrix_format == "csc" and isinstance(self._adjacency, csr_matrix): + self._adjacency = self._adjacency.tocsc() + elif self.matrix_format == "csr" and isinstance(self._adjacency, csc_matrix): + self._adjacency = self._adjacency.tocsr() + + logger.debug(f"加载邻接矩阵: {matrix_path}, shape={self._adjacency.shape}") + + # 检查维度不匹配并修复 + if self._adjacency is not None: + adj_n = self._adjacency.shape[0] + current_n = len(self._nodes) + if current_n > adj_n: + logger.warning(f"检测到图存储维度不匹配: 节点数={current_n}, 矩阵大小={adj_n}. 
正在自动修复...") + self._expand_adjacency_matrix(current_n - adj_n) + + self._adjacency_dirty = True + logger.info( + f"图存储已加载: {len(self._nodes)} 个节点, " + f"{self._adjacency.nnz if self._adjacency is not None else 0} 条边" + ) + + def _expand_adjacency_matrix(self, added_nodes: int) -> None: + """ + 扩展邻接矩阵以容纳新节点 + + Args: + added_nodes: 新增节点数量 + """ + if self._adjacency is None: + n = len(self._nodes) + # 根据模式初始化 + + if self._modification_mode == GraphModificationMode.INCREMENTAL: + self._adjacency = lil_matrix((n, n), dtype=np.float32) + else: + self._adjacency = csr_matrix((n, n), dtype=np.float32) + return + + old_n = self._adjacency.shape[0] + new_n = old_n + added_nodes + + # 优化:根据模式选择不同的扩容策略 + if self._modification_mode == GraphModificationMode.INCREMENTAL: + # LIL 格式可以直接 resize,非常高效 + try: + if not isinstance(self._adjacency, lil_matrix): + self._adjacency = self._adjacency.tolil() + + self._adjacency.resize((new_n, new_n)) + # logger.debug(f"扩展 LIL 矩阵: {old_n} -> {new_n}") + except Exception as e: + logger.warning(f"LIL resize 失败,回退到通用方法: {e}") + self._expand_generic(new_n, old_n) + + else: + # CSR/CSC 格式使用 bmat 避免结构破坏警告 + try: + # bmat 需要明确的形状,不能全部依赖 None + added = new_n - old_n + # 创建零矩阵块 + # 注意: 这里统一创建 CSR 零矩阵,bmat 会处理合并 + z_tr = csr_matrix((old_n, added), dtype=np.float32) + z_bl = csr_matrix((added, old_n), dtype=np.float32) + z_br = csr_matrix((added, added), dtype=np.float32) + + self._adjacency = bmat( + [[self._adjacency, z_tr], [z_bl, z_br]], + format=self.matrix_format, + dtype=np.float32 + ) + # logger.debug(f"扩展矩阵 (bmat): {old_n} -> {new_n}") + except Exception as e: + logger.warning(f"bmat 扩展失败: {e}") + self._expand_generic(new_n, old_n) + + def _expand_generic(self, new_n: int, old_n: int): + """通用扩展方法(回退方案)""" + if self.matrix_format == "csr": + new_adjacency = csr_matrix((new_n, new_n), dtype=np.float32) + new_adjacency[:old_n, :old_n] = self._adjacency + else: + new_adjacency = csc_matrix((new_n, new_n), dtype=np.float32) + new_adjacency[:old_n, :old_n] = self._adjacency + self._adjacency = new_adjacency + self._adjacency_dirty = True + + # 如果都在增量模式,确保是LIL + if self._modification_mode == GraphModificationMode.INCREMENTAL: + try: + self._adjacency = self._adjacency.tolil() + except: + pass + + @property + def num_nodes(self) -> int: + """节点数量""" + return len(self._nodes) + + @property + def num_edges(self) -> int: + """边数量""" + if self._adjacency is None: + return 0 + return int(self._adjacency.nnz) + + @property + def density(self) -> float: + """ + 图密度(实际边数 / 可能的最大边数) + + 有向图: E / (V * (V - 1)) + 无向图: 2E / (V * (V - 1)) + + 这里按有向图计算 + """ + if self.num_nodes < 2: + return 0.0 + + max_edges = self.num_nodes * (self.num_nodes - 1) + return self.num_edges / max_edges if max_edges > 0 else 0.0 + + def __len__(self) -> int: + """节点数量""" + return self.num_nodes + + def has_data(self) -> bool: + """检查磁盘上是否存在现有数据""" + if self.data_dir is None: + return False + return (self.data_dir / "graph_metadata.pkl").exists() + + def __repr__(self) -> str: + return ( + f"GraphStore(nodes={self.num_nodes}, edges={self.num_edges}, " + f"density={self.density:.4f}, format={self.matrix_format})" + ) + + def rebuild_edge_hash_map(self, triples: List[Tuple[str, str, str, str]]) -> int: + """ + 从元数据重建 V5 边哈希映射 (Migration Tool) + + Args: + triples: List of (s, p, o, hash) + + Returns: + count of mapped hashes + """ + count = 0 + self._edge_hash_map = defaultdict(set) + + for s, p, o, h in triples: + if not h: continue + + s_canon = self._canonicalize(s) + o_canon = self._canonicalize(o) + + if 
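# A minimal persistence round-trip using save()/load() above; the directory is illustrative.
from plugins.A_memorix.core.storage import GraphStore  # assumed import path

store.save("./data/graph")
restored = GraphStore(matrix_format="csr", data_dir="./data/graph")
if restored.has_data():
    restored.load()
print(restored.num_nodes, restored.num_edges, f"{restored.density:.4f}")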
s_canon in self._node_to_idx and o_canon in self._node_to_idx: + u = self._node_to_idx[s_canon] + v = self._node_to_idx[o_canon] + + # 如果是双向的,通常在元数据中存储为有向,而 GraphStore 也通常是有向的。 + # 映射键对应特定的边方向。 + self._edge_hash_map[(u, v)].add(h) + count += 1 + + self._adjacency_dirty = True + logger.info(f"已从 {count} 条哈希重建边哈希映射,覆盖 {len(self._edge_hash_map)} 条边") + return count + diff --git a/plugins/A_memorix/core/storage/knowledge_types.py b/plugins/A_memorix/core/storage/knowledge_types.py new file mode 100644 index 00000000..4ab91218 --- /dev/null +++ b/plugins/A_memorix/core/storage/knowledge_types.py @@ -0,0 +1,183 @@ +"""Knowledge type and import strategy helpers.""" + +from __future__ import annotations + +from enum import Enum +from typing import Any, Optional + + +class KnowledgeType(str, Enum): + """持久化到 paragraphs.knowledge_type 的合法类型。""" + + STRUCTURED = "structured" + NARRATIVE = "narrative" + FACTUAL = "factual" + QUOTE = "quote" + MIXED = "mixed" + + +class ImportStrategy(str, Enum): + """文本导入阶段的策略选择。""" + + AUTO = "auto" + NARRATIVE = "narrative" + FACTUAL = "factual" + QUOTE = "quote" + + +def allowed_knowledge_type_values() -> tuple[str, ...]: + return tuple(item.value for item in KnowledgeType) + + +def allowed_import_strategy_values() -> tuple[str, ...]: + return tuple(item.value for item in ImportStrategy) + + +def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]: + """从字符串解析合法的落库知识类型。""" + + if not isinstance(type_str, str): + return None + normalized = type_str.lower().strip() + try: + return KnowledgeType(normalized) + except ValueError: + return None + + +def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]: + """从字符串解析文本导入策略。""" + + if not isinstance(value, str): + return None + normalized = value.lower().strip() + try: + return ImportStrategy(normalized) + except ValueError: + return None + + +def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy: + """解析 import strategy;非法值直接报错。""" + + if value is None: + return default + if isinstance(value, ImportStrategy): + return value + + normalized = str(value or "").strip().lower() + if not normalized: + return default + + strategy = get_import_strategy_from_string(normalized) + if strategy is None: + allowed = "/".join(allowed_import_strategy_values()) + raise ValueError(f"strategy_override 必须为 {allowed}") + return strategy + + +def validate_stored_knowledge_type(value: Any) -> KnowledgeType: + """校验写库 knowledge_type,仅允许合法落库类型。""" + + if isinstance(value, KnowledgeType): + return value + + resolved = get_knowledge_type_from_string(value) + if resolved is None: + allowed = "/".join(allowed_knowledge_type_values()) + raise ValueError(f"knowledge_type 必须为 {allowed}") + return resolved + + +def resolve_stored_knowledge_type( + value: Any, + *, + content: str = "", + allow_legacy: bool = False, + unknown_fallback: Optional[KnowledgeType] = None, +) -> KnowledgeType: + """ + 将策略/字符串/旧值解析为合法落库类型。 + + `allow_legacy=True` 仅供迁移使用。 + """ + + if isinstance(value, KnowledgeType): + return value + + if isinstance(value, ImportStrategy): + if value == ImportStrategy.AUTO: + if not str(content or "").strip(): + raise ValueError("knowledge_type=auto 需要 content 才能推断") + from .type_detection import detect_knowledge_type + + return detect_knowledge_type(content) + return KnowledgeType(value.value) + + raw = str(value or "").strip() + if not raw: + if str(content or "").strip(): + from .type_detection import detect_knowledge_type + + return 
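Because KnowledgeType and ImportStrategy subclass both str and Enum, constructing one from a raw string either succeeds or raises ValueError, which is exactly what the *_from_string helpers above catch; a self-contained illustration of that pattern:

from enum import Enum

class Color(str, Enum):          # same str+Enum pattern as KnowledgeType / ImportStrategy
    RED = "red"
    BLUE = "blue"

assert Color("red") is Color.RED          # lookup by value
assert Color.BLUE == "blue"               # str subclass compares equal to the raw string
try:
    Color("green")
except ValueError:
    print("unknown values raise ValueError; the helpers translate that into None")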
detect_knowledge_type(content) + raise ValueError("knowledge_type 不能为空") + + direct = get_knowledge_type_from_string(raw) + if direct is not None: + return direct + + strategy = get_import_strategy_from_string(raw) + if strategy is not None: + return resolve_stored_knowledge_type(strategy, content=content) + + if allow_legacy: + normalized = raw.lower() + if normalized == "imported": + return KnowledgeType.FACTUAL + if str(content or "").strip(): + from .type_detection import detect_knowledge_type + + detected = detect_knowledge_type(content) + if detected is not None: + return detected + if unknown_fallback is not None: + return unknown_fallback + + allowed = "/".join(allowed_knowledge_type_values()) + raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed})") + + +def should_extract_relations(knowledge_type: KnowledgeType) -> bool: + """判断是否应该做关系抽取。""" + + return knowledge_type in [ + KnowledgeType.STRUCTURED, + KnowledgeType.FACTUAL, + KnowledgeType.MIXED, + ] + + +def get_default_chunk_size(knowledge_type: KnowledgeType) -> int: + """获取默认分块大小。""" + + chunk_sizes = { + KnowledgeType.STRUCTURED: 300, + KnowledgeType.NARRATIVE: 800, + KnowledgeType.FACTUAL: 500, + KnowledgeType.QUOTE: 400, + KnowledgeType.MIXED: 500, + } + return chunk_sizes.get(knowledge_type, 500) + + +def get_type_display_name(knowledge_type: KnowledgeType) -> str: + """获取知识类型中文名称。""" + + display_names = { + KnowledgeType.STRUCTURED: "结构化知识", + KnowledgeType.NARRATIVE: "叙事性文本", + KnowledgeType.FACTUAL: "事实陈述", + KnowledgeType.QUOTE: "引用文本", + KnowledgeType.MIXED: "混合类型", + } + return display_names.get(knowledge_type, "未知类型") diff --git a/plugins/A_memorix/core/storage/metadata_store.py b/plugins/A_memorix/core/storage/metadata_store.py new file mode 100644 index 00000000..e94610f0 --- /dev/null +++ b/plugins/A_memorix/core/storage/metadata_store.py @@ -0,0 +1,5225 @@ +""" +元数据存储模块 + +基于SQLite的元数据管理,存储段落、实体、关系等信息。 +""" + +import sqlite3 +import pickle +import json +from datetime import datetime +from pathlib import Path +from typing import Optional, Union, List, Dict, Any, Tuple + +from src.common.logger import get_logger +from ..utils.hash import compute_hash, normalize_text +from ..utils.time_parser import normalize_time_meta +from .knowledge_types import ( + KnowledgeType, + allowed_knowledge_type_values, + resolve_stored_knowledge_type, + validate_stored_knowledge_type, +) + +logger = get_logger("A_Memorix.MetadataStore") + + +SCHEMA_VERSION = 7 + + +class MetadataStore: + """ + 元数据存储类 + + 功能: + - SQLite数据库管理 + - 段落/实体/关系元数据存储 + - 增删改查操作 + - 事务支持 + - 索引优化 + + 参数: + data_dir: 数据目录 + db_name: 数据库文件名(默认metadata.db) + """ + + def __init__( + self, + data_dir: Optional[Union[str, Path]] = None, + db_name: str = "metadata.db", + ): + """ + 初始化元数据存储 + + Args: + data_dir: 数据目录 + db_name: 数据库文件名 + """ + self.data_dir = Path(data_dir) if data_dir else None + self.db_name = db_name + self._conn: Optional[sqlite3.Connection] = None + self._is_initialized = False + self._db_path: Optional[Path] = None + + logger.info(f"MetadataStore 初始化: db={db_name}") + + def connect( + self, + data_dir: Optional[Union[str, Path]] = None, + *, + enforce_schema: bool = True, + ) -> None: + """ + 连接到数据库 + + Args: + data_dir: 数据目录(默认使用初始化时的目录) + """ + if data_dir is None: + data_dir = self.data_dir + + if data_dir is None: + raise ValueError("未指定数据目录") + + data_dir = Path(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + + db_path = data_dir / self.db_name + db_existed = db_path.exists() + self._db_path = db_path + + # 连接数据库 + self._conn 
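Assuming the module is importable as knowledge_types (the import path is an assumption; inside the plugin it lives under core.storage), the resolution order sketched above can be exercised like this:

from knowledge_types import (
    ImportStrategy, KnowledgeType, parse_import_strategy, resolve_stored_knowledge_type,
)

assert parse_import_strategy(None) is ImportStrategy.AUTO             # missing value -> default
assert parse_import_strategy(" Factual ") is ImportStrategy.FACTUAL   # trimmed, case-insensitive
assert resolve_stored_knowledge_type("quote") is KnowledgeType.QUOTE  # direct stored type wins
assert resolve_stored_knowledge_type(                                 # legacy value, migration only
    "imported", allow_legacy=True
) is KnowledgeType.FACTUAL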
= sqlite3.connect( + str(db_path), + check_same_thread=False, + timeout=30.0, + ) + self._conn.row_factory = sqlite3.Row # 使用字典式访问 + + # 优化性能 + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA cache_size=-64000") # 64MB缓存 + self._conn.execute("PRAGMA temp_store=MEMORY") + self._conn.execute("PRAGMA foreign_keys = ON") # 开启外键约束支持级联删除 + + logger.info(f"连接到数据库: {db_path}") + + # 初始化或校验 schema + if not self._is_initialized: + if not db_existed: + self._initialize_tables() + if enforce_schema: + self._assert_schema_compatible(db_existed=db_existed) + self._is_initialized = True + + # 初始化 FTS schema(幂等) + try: + self.ensure_fts_schema() + except Exception as e: + logger.warning(f"初始化 FTS schema 失败,将跳过 BM25 检索: {e}") + + def _assert_schema_compatible(self, db_existed: bool) -> None: + """vNext 运行时只做 schema 版本校验,不做隐式迁移。""" + cursor = self._conn.cursor() + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" + ) + has_version_table = cursor.fetchone() is not None + if not has_version_table: + if db_existed: + raise RuntimeError( + "检测到旧版 metadata schema(缺少 schema_migrations)。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + return + + cursor.execute("SELECT MAX(version) FROM schema_migrations") + row = cursor.fetchone() + version = int(row[0]) if row and row[0] is not None else 0 + if version != SCHEMA_VERSION: + raise RuntimeError( + f"metadata schema 版本不匹配: current={version}, expected={SCHEMA_VERSION}。" + " 请执行 scripts/release_vnext_migrate.py migrate。" + ) + + def close(self) -> None: + """关闭数据库连接""" + if self._conn: + self._conn.close() + self._conn = None + logger.info("数据库连接已关闭") + + def _initialize_tables(self) -> None: + """初始化数据库表结构""" + cursor = self._conn.cursor() + + # 段落表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraphs ( + hash TEXT PRIMARY KEY, + content TEXT NOT NULL, + vector_index INTEGER, + created_at REAL, + updated_at REAL, + metadata TEXT, + source TEXT, + word_count INTEGER, + event_time REAL, + event_time_start REAL, + event_time_end REAL, + time_granularity TEXT, + time_confidence REAL DEFAULT 1.0, + knowledge_type TEXT DEFAULT 'mixed', + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_deleted INTEGER DEFAULT 0, + deleted_at REAL + ) + """) + + # 实体表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS entities ( + hash TEXT PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + vector_index INTEGER, + appearance_count INTEGER DEFAULT 1, + created_at REAL, + metadata TEXT, + is_deleted INTEGER DEFAULT 0, + deleted_at REAL + ) + """) + + # 关系表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relations ( + hash TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + vector_index INTEGER, + confidence REAL DEFAULT 1.0, + vector_state TEXT DEFAULT 'none', + vector_updated_at REAL, + vector_error TEXT, + vector_retry_count INTEGER DEFAULT 0, + created_at REAL, + source_paragraph TEXT, + metadata TEXT, + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_inactive BOOLEAN DEFAULT 0, + inactive_since REAL, + is_pinned BOOLEAN DEFAULT 0, + protected_until REAL, + last_reinforced REAL, + UNIQUE(subject, predicate, object) + ) + """) + + # 回收站关系表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS deleted_relations ( + hash TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + vector_index 
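The connection tuning in connect() can be reproduced in isolation with the standard-library sqlite3 module; the database file name below is a throwaway:

import sqlite3

conn = sqlite3.connect("demo_metadata.db", check_same_thread=False, timeout=30.0)
conn.row_factory = sqlite3.Row                      # rows become dict-like
conn.execute("PRAGMA journal_mode=WAL")             # write-ahead logging for concurrent readers
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA cache_size=-64000")            # negative value = size in KiB, roughly 64 MB
conn.execute("PRAGMA foreign_keys=ON")              # needed for the ON DELETE CASCADE tables

print(conn.execute("PRAGMA journal_mode").fetchone()[0])   # -> 'wal'
print(conn.execute("PRAGMA foreign_keys").fetchone()[0])   # -> 1
conn.close()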
INTEGER, + confidence REAL DEFAULT 1.0, + vector_state TEXT DEFAULT 'none', + vector_updated_at REAL, + vector_error TEXT, + vector_retry_count INTEGER DEFAULT 0, + created_at REAL, + source_paragraph TEXT, + metadata TEXT, + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_inactive BOOLEAN DEFAULT 0, + inactive_since REAL, + is_pinned BOOLEAN DEFAULT 0, + protected_until REAL, + last_reinforced REAL, + deleted_at REAL + ) + """) + + # 32位哈希别名映射(用于 vNext 唯一解析) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relation_hash_aliases ( + alias32 TEXT PRIMARY KEY, + hash TEXT NOT NULL + ) + """) + + # Schema 版本 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at REAL NOT NULL + ) + """) + + # 三元组与段落的关联表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_relations ( + paragraph_hash TEXT NOT NULL, + relation_hash TEXT NOT NULL, + PRIMARY KEY (paragraph_hash, relation_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, + FOREIGN KEY (relation_hash) REFERENCES relations(hash) ON DELETE CASCADE + ) + """) + + # 实体与段落的关联表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_entities ( + paragraph_hash TEXT NOT NULL, + entity_hash TEXT NOT NULL, + mention_count INTEGER DEFAULT 1, + PRIMARY KEY (paragraph_hash, entity_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE, + FOREIGN KEY (entity_hash) REFERENCES entities(hash) ON DELETE CASCADE + ) + """) + + # 创建索引 + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraphs_vector + ON paragraphs(vector_index) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_entities_vector + ON entities(vector_index) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_vector + ON relations(vector_index) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_subject + ON relations(subject) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_object + ON relations(object) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_entities_name + ON entities(name) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraphs_source + ON paragraphs(source) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraphs_deleted + ON paragraphs(is_deleted, deleted_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_entities_deleted + ON entities(is_deleted, deleted_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_inactive + ON relations(is_inactive, inactive_since) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_relations_protected + ON relations(is_pinned, protected_until) + """) + + # 人物画像开关表(按 stream_id + user_id 维度) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_switches ( + stream_id TEXT NOT NULL, + user_id TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 0, + updated_at REAL NOT NULL, + PRIMARY KEY (stream_id, user_id) + ) + """) + + # 人物画像快照表(版本化) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_snapshots ( + snapshot_id INTEGER PRIMARY KEY AUTOINCREMENT, + person_id TEXT NOT NULL, + profile_version INTEGER NOT NULL, + profile_text TEXT NOT NULL, + aliases_json TEXT, + relation_edges_json TEXT, + vector_evidence_json TEXT, + evidence_ids_json TEXT, + updated_at REAL NOT NULL, + expires_at REAL, + source_note TEXT, + UNIQUE(person_id, profile_version) + ) + """) + + # 已开启范围内的活跃人物集合 + cursor.execute(""" + CREATE 
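Because foreign_keys is switched on at connect time, the ON DELETE CASCADE clauses on the link tables (paragraph_relations, paragraph_entities, and the episode tables further down) remove dependent rows automatically; a minimal demonstration against an in-memory database:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys=ON")
conn.execute("CREATE TABLE paragraphs (hash TEXT PRIMARY KEY, content TEXT)")
conn.execute("""
    CREATE TABLE paragraph_relations (
        paragraph_hash TEXT NOT NULL,
        relation_hash TEXT NOT NULL,
        PRIMARY KEY (paragraph_hash, relation_hash),
        FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE
    )
""")
conn.execute("INSERT INTO paragraphs VALUES ('p1', 'hello')")
conn.execute("INSERT INTO paragraph_relations VALUES ('p1', 'r1')")
conn.execute("DELETE FROM paragraphs WHERE hash = 'p1'")
print(conn.execute("SELECT COUNT(*) FROM paragraph_relations").fetchone()[0])  # -> 0, cascaded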
TABLE IF NOT EXISTS person_profile_active_persons ( + stream_id TEXT NOT NULL, + user_id TEXT NOT NULL, + person_id TEXT NOT NULL, + last_seen_at REAL NOT NULL, + PRIMARY KEY (stream_id, user_id, person_id) + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS person_profile_overrides ( + person_id TEXT PRIMARY KEY, + override_text TEXT NOT NULL, + updated_at REAL NOT NULL, + updated_by TEXT, + source TEXT + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_switches_enabled + ON person_profile_switches(enabled) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_snapshots_person + ON person_profile_snapshots(person_id, updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_active_seen + ON person_profile_active_persons(last_seen_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_person_profile_overrides_updated + ON person_profile_overrides(updated_at DESC) + """) + + # Episode 情景记忆表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + episode_id TEXT PRIMARY KEY, + source TEXT, + title TEXT NOT NULL, + summary TEXT NOT NULL, + event_time_start REAL, + event_time_end REAL, + time_granularity TEXT, + time_confidence REAL DEFAULT 1.0, + participants_json TEXT, + keywords_json TEXT, + evidence_ids_json TEXT, + paragraph_count INTEGER DEFAULT 0, + llm_confidence REAL DEFAULT 0.0, + segmentation_model TEXT, + segmentation_version TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + + # Episode -> Paragraph 映射 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_paragraphs ( + episode_id TEXT NOT NULL, + paragraph_hash TEXT NOT NULL, + position INTEGER DEFAULT 0, + PRIMARY KEY (episode_id, paragraph_hash), + FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE + ) + """) + + # Episode 生成队列(异步) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( + paragraph_hash TEXT PRIMARY KEY, + source TEXT, + created_at REAL, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( + source TEXT PRIMARY KEY, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + reason TEXT, + requested_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end + ON episodes(source, event_time_end DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_updated_at + ON episodes(updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph + ON episode_paragraphs(paragraph_hash) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated + ON episode_pending_paragraphs(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created + ON episode_pending_paragraphs(source, created_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated + ON episode_rebuild_sources(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at + ON episode_rebuild_sources(updated_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS external_memory_refs ( + external_id TEXT PRIMARY KEY, 
+ paragraph_hash TEXT NOT NULL, + source_type TEXT, + created_at REAL NOT NULL, + metadata_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_external_memory_refs_paragraph + ON external_memory_refs(paragraph_hash) + """) + # 新版 schema 包含完整字段,直接写入版本信息 + cursor.execute("INSERT OR IGNORE INTO schema_migrations(version, applied_at) VALUES (?, ?)", (SCHEMA_VERSION, datetime.now().timestamp())) + self._conn.commit() + logger.debug("数据库表结构初始化完成") + + def _migrate_schema(self) -> None: + """执行数据库schema迁移""" + cursor = self._conn.cursor() + + # vNext 关键表兜底:历史库可能缺失,需在迁移阶段主动补齐。 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relation_hash_aliases ( + alias32 TEXT PRIMARY KEY, + hash TEXT NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at REAL NOT NULL + ) + """) + + # Episode MVP 表结构补齐 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episodes ( + episode_id TEXT PRIMARY KEY, + source TEXT, + title TEXT NOT NULL, + summary TEXT NOT NULL, + event_time_start REAL, + event_time_end REAL, + time_granularity TEXT, + time_confidence REAL DEFAULT 1.0, + participants_json TEXT, + keywords_json TEXT, + evidence_ids_json TEXT, + paragraph_count INTEGER DEFAULT 0, + llm_confidence REAL DEFAULT 0.0, + segmentation_model TEXT, + segmentation_version TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_paragraphs ( + episode_id TEXT NOT NULL, + paragraph_hash TEXT NOT NULL, + position INTEGER DEFAULT 0, + PRIMARY KEY (episode_id, paragraph_hash), + FOREIGN KEY (episode_id) REFERENCES episodes(episode_id) ON DELETE CASCADE, + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_pending_paragraphs ( + paragraph_hash TEXT PRIMARY KEY, + source TEXT, + created_at REAL, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS episode_rebuild_sources ( + source TEXT PRIMARY KEY, + status TEXT DEFAULT 'pending', + retry_count INTEGER DEFAULT 0, + last_error TEXT, + reason TEXT, + requested_at REAL NOT NULL, + updated_at REAL NOT NULL + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_source_time_end + ON episodes(source, event_time_end DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episodes_updated_at + ON episodes(updated_at DESC) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_paragraphs_paragraph + ON episode_paragraphs(paragraph_hash) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_status_updated + ON episode_pending_paragraphs(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_pending_source_created + ON episode_pending_paragraphs(source, created_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_status_updated + ON episode_rebuild_sources(status, updated_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_episode_rebuild_updated_at + ON episode_rebuild_sources(updated_at DESC) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS external_memory_refs ( + external_id TEXT PRIMARY KEY, + paragraph_hash TEXT NOT NULL, + source_type TEXT, + created_at REAL NOT NULL, + metadata_json TEXT + ) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS 
idx_external_memory_refs_paragraph + ON external_memory_refs(paragraph_hash) + """) + + # 检查paragraphs表是否有knowledge_type列 + cursor.execute("PRAGMA table_info(paragraphs)") + columns = [row[1] for row in cursor.fetchall()] + + if "knowledge_type" not in columns: + logger.info("检测到旧版schema,正在迁移添加knowledge_type字段...") + try: + cursor.execute(""" + ALTER TABLE paragraphs + ADD COLUMN knowledge_type TEXT DEFAULT 'mixed' + """) + self._conn.commit() + logger.info("Schema迁移完成:已添加knowledge_type字段") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败(可能已存在): {e}") + + # 问题2: 时序字段迁移 + cursor.execute("PRAGMA table_info(paragraphs)") + columns = [row[1] for row in cursor.fetchall()] + temporal_columns = { + "event_time": "ALTER TABLE paragraphs ADD COLUMN event_time REAL", + "event_time_start": "ALTER TABLE paragraphs ADD COLUMN event_time_start REAL", + "event_time_end": "ALTER TABLE paragraphs ADD COLUMN event_time_end REAL", + "time_granularity": "ALTER TABLE paragraphs ADD COLUMN time_granularity TEXT", + "time_confidence": "ALTER TABLE paragraphs ADD COLUMN time_confidence REAL DEFAULT 1.0", + } + for col, sql in temporal_columns.items(): + if col not in columns: + try: + cursor.execute(sql) + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败({col}): {e}") + + # 时序索引(仅在列存在时创建,兼容旧库迁移) + self._create_temporal_indexes_if_ready() + self._conn.commit() + + # 检查paragraphs表是否有is_permanent列 + cursor.execute("PRAGMA table_info(paragraphs)") + columns = [row[1] for row in cursor.fetchall()] + + if "is_permanent" not in columns: + logger.info("正在迁移: 添加记忆动态字段...") + try: + # 段落表 + cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_permanent BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE paragraphs ADD COLUMN last_accessed REAL") + cursor.execute("ALTER TABLE paragraphs ADD COLUMN access_count INTEGER DEFAULT 0") + + # 关系表 + cursor.execute("ALTER TABLE relations ADD COLUMN is_permanent BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE relations ADD COLUMN last_accessed REAL") + cursor.execute("ALTER TABLE relations ADD COLUMN access_count INTEGER DEFAULT 0") + + self._conn.commit() + logger.info("Schema迁移完成:已添加记忆动态字段") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败: {e}") + + # 检查relations表是否有is_inactive列 (V5 Memory System) + cursor.execute("PRAGMA table_info(relations)") + columns = [row[1] for row in cursor.fetchall()] + + if "is_inactive" not in columns: + logger.info("正在迁移: 添加V5记忆动态字段 (inactive, protected)...") + try: + # 关系表 V5 新增字段 + cursor.execute("ALTER TABLE relations ADD COLUMN is_inactive BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE relations ADD COLUMN inactive_since REAL") + cursor.execute("ALTER TABLE relations ADD COLUMN is_pinned BOOLEAN DEFAULT 0") + cursor.execute("ALTER TABLE relations ADD COLUMN protected_until REAL") + cursor.execute("ALTER TABLE relations ADD COLUMN last_reinforced REAL") + + # 为回收站创建 deleted_relations 表 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS deleted_relations ( + hash TEXT PRIMARY KEY, + subject TEXT NOT NULL, + predicate TEXT NOT NULL, + object TEXT NOT NULL, + vector_index INTEGER, + confidence REAL DEFAULT 1.0, + vector_state TEXT DEFAULT 'none', + vector_updated_at REAL, + vector_error TEXT, + vector_retry_count INTEGER DEFAULT 0, + created_at REAL, + source_paragraph TEXT, + metadata TEXT, + is_permanent BOOLEAN DEFAULT 0, + last_accessed REAL, + access_count INTEGER DEFAULT 0, + is_inactive BOOLEAN DEFAULT 0, + inactive_since REAL, + is_pinned BOOLEAN DEFAULT 0, + protected_until REAL, 
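All of the migration blocks in _migrate_schema follow the same idempotent pattern: read PRAGMA table_info, add a column only when it is missing, and tolerate the "duplicate column" OperationalError. A condensed, standalone sketch of that pattern:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE paragraphs (hash TEXT PRIMARY KEY, content TEXT)")

def add_column_if_missing(conn, table, column, ddl):
    cols = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    if column in cols:
        return False
    try:
        conn.execute(ddl)
        return True
    except sqlite3.OperationalError:
        return False                      # column raced in by another writer

ddl = "ALTER TABLE paragraphs ADD COLUMN knowledge_type TEXT DEFAULT 'mixed'"
print(add_column_if_missing(conn, "paragraphs", "knowledge_type", ddl))   # -> True (added)
print(add_column_if_missing(conn, "paragraphs", "knowledge_type", ddl))   # -> False (already present)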
+ last_reinforced REAL, + deleted_at REAL -- 用于记录删除时间的额外列 + ) + """) + + self._conn.commit() + logger.info("Schema迁移完成:已添加V5记忆动态字段及回收站表") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (V5): {e}") + + # 关系向量状态字段迁移 + cursor.execute("PRAGMA table_info(relations)") + relation_columns = {row[1] for row in cursor.fetchall()} + relation_vector_columns = { + "vector_state": "ALTER TABLE relations ADD COLUMN vector_state TEXT DEFAULT 'none'", + "vector_updated_at": "ALTER TABLE relations ADD COLUMN vector_updated_at REAL", + "vector_error": "ALTER TABLE relations ADD COLUMN vector_error TEXT", + "vector_retry_count": "ALTER TABLE relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", + } + for col, sql in relation_vector_columns.items(): + if col not in relation_columns: + try: + cursor.execute(sql) + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (relations.{col}): {e}") + + # 回收站同步字段迁移(用于 restore 保留向量状态) + cursor.execute("PRAGMA table_info(deleted_relations)") + deleted_relation_columns = {row[1] for row in cursor.fetchall()} + deleted_relation_vector_columns = { + "vector_state": "ALTER TABLE deleted_relations ADD COLUMN vector_state TEXT DEFAULT 'none'", + "vector_updated_at": "ALTER TABLE deleted_relations ADD COLUMN vector_updated_at REAL", + "vector_error": "ALTER TABLE deleted_relations ADD COLUMN vector_error TEXT", + "vector_retry_count": "ALTER TABLE deleted_relations ADD COLUMN vector_retry_count INTEGER DEFAULT 0", + } + for col, sql in deleted_relation_vector_columns.items(): + if col not in deleted_relation_columns: + try: + cursor.execute(sql) + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (deleted_relations.{col}): {e}") + + # 检查 entities 表是否有 is_deleted 列 (Soft Delete System) + cursor.execute("PRAGMA table_info(entities)") + columns = [row[1] for row in cursor.fetchall()] + + if "is_deleted" not in columns: + logger.info("正在迁移: 添加软删除字段 (Soft Delete)...") + try: + # 实体表 + cursor.execute("ALTER TABLE entities ADD COLUMN is_deleted INTEGER DEFAULT 0") + cursor.execute("ALTER TABLE entities ADD COLUMN deleted_at REAL") + + # 段落表 + cursor.execute("ALTER TABLE paragraphs ADD COLUMN is_deleted INTEGER DEFAULT 0") + cursor.execute("ALTER TABLE paragraphs ADD COLUMN deleted_at REAL") + + self._conn.commit() + logger.info("Schema迁移完成:已添加软删除字段") + except sqlite3.OperationalError as e: + logger.warning(f"Schema迁移失败 (Soft Delete): {e}") + + # 数据修复: 检查是否存在 source/vector_index 列错位的情况 + # 症状: vector_index (本应是int) 变成了文件名字符串, source (本应是文件名) 变成了类型字符串 + try: + cursor.execute(""" + SELECT count(*) FROM paragraphs + WHERE typeof(vector_index) = 'text' + AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') + """) + count = cursor.fetchone()[0] + if count > 0: + logger.warning(f"检测到 {count} 条数据存在列错位(文件名误存入vector_index),正在自动修复...") + cursor.execute(""" + UPDATE paragraphs + SET + knowledge_type = source, + source = vector_index, + vector_index = NULL + WHERE typeof(vector_index) = 'text' + AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto') + """) + self._conn.commit() + logger.info(f"自动修复完成: 已校正 {cursor.rowcount} 条数据") + except Exception as e: + logger.error(f"数据自动修复失败: {e}") + + def _create_temporal_indexes_if_ready(self) -> None: + """ + 仅当时序列已存在时创建索引。 + + 旧库升级时,_initialize_tables 不能提前对不存在的列建索引; + 因此统一在迁移阶段按列存在性安全创建。 + """ + cursor = self._conn.cursor() + cursor.execute("PRAGMA table_info(paragraphs)") + columns = {row[1] for row in cursor.fetchall()} + + if "event_time" in columns: + 
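The self-repair step above relies on SQLite's dynamic typing: typeof() can spot a text filename that ended up in vector_index even though the column is declared INTEGER. A small reproduction with made-up values:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE paragraphs (hash TEXT PRIMARY KEY, vector_index INTEGER, source TEXT, knowledge_type TEXT)"
)
# simulate the misalignment: a filename stored in vector_index, a knowledge type stored in source
conn.execute("INSERT INTO paragraphs VALUES ('p1', 'notes.txt', 'factual', 'mixed')")

conn.execute("""
    UPDATE paragraphs
    SET knowledge_type = source, source = vector_index, vector_index = NULL
    WHERE typeof(vector_index) = 'text'
      AND source IN ('mixed', 'factual', 'narrative', 'structured', 'auto')
""")
print(conn.execute("SELECT vector_index, source, knowledge_type FROM paragraphs").fetchone())
# -> (None, 'notes.txt', 'factual'); SET expressions all read the pre-update row values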
cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_time ON paragraphs(event_time)" + ) + if "event_time_start" in columns: + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_start ON paragraphs(event_time_start)" + ) + if "event_time_end" in columns: + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_paragraphs_event_end ON paragraphs(event_time_end)" + ) + + def run_legacy_migration_for_vnext(self) -> Dict[str, Any]: + """ + 离线迁移入口: + - 复用旧迁移逻辑补齐历史库字段 + - 重建 relation 32位别名 + - 归一化历史 knowledge_type + - 写入 vNext schema 版本 + """ + self._migrate_schema() + alias_result = self.rebuild_relation_hash_aliases() + knowledge_type_result = self.normalize_paragraph_knowledge_types() + self.set_schema_version(SCHEMA_VERSION) + return { + "schema_version": SCHEMA_VERSION, + "alias_result": alias_result, + "knowledge_type_result": knowledge_type_result, + } + + def list_invalid_paragraph_knowledge_types(self) -> List[str]: + """列出当前库中不合法的段落 knowledge_type。""" + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT DISTINCT knowledge_type + FROM paragraphs + WHERE knowledge_type IS NULL + OR TRIM(COALESCE(knowledge_type, '')) = '' + OR LOWER(TRIM(knowledge_type)) NOT IN ({placeholders}) + ORDER BY knowledge_type + """.format(placeholders=", ".join("?" for _ in allowed_knowledge_type_values())), + tuple(allowed_knowledge_type_values()), + ) + invalid: List[str] = [] + for row in cursor.fetchall(): + raw = row[0] + invalid.append(str(raw) if raw is not None else "") + return invalid + + def normalize_paragraph_knowledge_types(self) -> Dict[str, Any]: + """将历史非法 knowledge_type 归一化为合法值。""" + + cursor = self._conn.cursor() + cursor.execute("SELECT hash, content, knowledge_type FROM paragraphs") + rows = cursor.fetchall() + + normalized_count = 0 + normalized_map: Dict[str, int] = {} + invalid_before: List[str] = [] + invalid_seen = set() + + for row in rows: + paragraph_hash = str(row["hash"]) + content = str(row["content"] or "") + raw_value = row["knowledge_type"] + try: + validate_stored_knowledge_type(raw_value) + continue + except ValueError: + raw_text = str(raw_value) if raw_value is not None else "" + if raw_text not in invalid_seen: + invalid_seen.add(raw_text) + invalid_before.append(raw_text) + + normalized_type = resolve_stored_knowledge_type( + raw_value, + content=content, + allow_legacy=True, + unknown_fallback=KnowledgeType.MIXED, + ) + cursor.execute( + "UPDATE paragraphs SET knowledge_type = ? 
WHERE hash = ?", + (normalized_type.value, paragraph_hash), + ) + normalized_count += 1 + normalized_map[normalized_type.value] = normalized_map.get(normalized_type.value, 0) + 1 + + self._conn.commit() + return { + "normalized": normalized_count, + "invalid_before": sorted(invalid_before), + "normalized_to": normalized_map, + } + + def _resolve_conn(self, conn: Optional[sqlite3.Connection] = None) -> sqlite3.Connection: + """解析可用连接。""" + resolved = conn or self._conn + if resolved is None: + raise RuntimeError("MetadataStore 未连接数据库") + return resolved + + def get_db_path(self) -> Path: + """获取 SQLite 数据库文件路径。""" + if self._db_path is not None: + return self._db_path + if self.data_dir is None: + raise RuntimeError("MetadataStore 未配置 data_dir") + return Path(self.data_dir) / self.db_name + + def ensure_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """ + 确保 FTS5 schema 存在(幂等)。 + + 采用 external-content 方式,不在 FTS 表重复存储正文。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS paragraphs_fts + USING fts5( + content, + content='paragraphs', + content_rowid='rowid', + tokenize='unicode61' + ) + """) + + # insert trigger + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS paragraphs_ai + AFTER INSERT ON paragraphs + BEGIN + INSERT INTO paragraphs_fts(rowid, content) + VALUES (new.rowid, new.content); + END + """) + + # delete trigger + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS paragraphs_ad + AFTER DELETE ON paragraphs + BEGIN + INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) + VALUES ('delete', old.rowid, old.content); + END + """) + + # update trigger + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS paragraphs_au + AFTER UPDATE OF content ON paragraphs + BEGIN + INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) + VALUES ('delete', old.rowid, old.content); + INSERT INTO paragraphs_fts(rowid, content) + VALUES (new.rowid, new.content); + END + """) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS5 schema 创建失败(可能不支持 FTS5): {e}") + c.rollback() + return False + + def ensure_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """ + 确保 FTS 索引已回填。 + + 当历史数据存在但 FTS 表为空/不一致时执行 rebuild。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute("SELECT COUNT(1) AS n FROM paragraphs") + para_count = int(cur.fetchone()[0]) + cur.execute("SELECT COUNT(1) AS n FROM paragraphs_fts") + fts_count = int(cur.fetchone()[0]) + + if para_count > 0 and fts_count != para_count: + cur.execute("INSERT INTO paragraphs_fts(paragraphs_fts) VALUES ('rebuild')") + c.commit() + logger.info(f"FTS 回填完成: paragraphs={para_count}, fts={para_count}") + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS 回填失败: {e}") + c.rollback() + return False + + def ensure_relations_fts_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """ + 确保关系 FTS5 schema 存在(幂等)。 + + 注意:relations 表没有 content 列,因此使用独立 FTS 表并通过触发器同步。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS relations_fts + USING fts5( + relation_hash UNINDEXED, + content, + tokenize='unicode61' + ) + """) + + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS relations_ai + AFTER INSERT ON relations + BEGIN + INSERT INTO relations_fts(relation_hash, content) + VALUES ( + new.hash, + COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') + ); + END + 
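Provided the Python build links an SQLite with FTS5 enabled (an assumption), the external-content pattern used by ensure_fts_schema can be tried end to end. Note that bm25() scores are negative and smaller means more relevant, which is why the queries above order ascending:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE paragraphs (hash TEXT PRIMARY KEY, content TEXT);
    CREATE VIRTUAL TABLE paragraphs_fts USING fts5(
        content, content='paragraphs', content_rowid='rowid', tokenize='unicode61'
    );
    CREATE TRIGGER paragraphs_ai AFTER INSERT ON paragraphs BEGIN
        INSERT INTO paragraphs_fts(rowid, content) VALUES (new.rowid, new.content);
    END;
""")
conn.execute("INSERT INTO paragraphs VALUES ('p1', 'graph based memory retrieval')")
conn.execute("INSERT INTO paragraphs VALUES ('p2', 'unrelated text')")
rows = conn.execute("""
    SELECT p.hash, bm25(paragraphs_fts) AS score
    FROM paragraphs_fts JOIN paragraphs p ON p.rowid = paragraphs_fts.rowid
    WHERE paragraphs_fts MATCH 'memory' ORDER BY score ASC
""").fetchall()
print(rows)   # -> only 'p1', with a negative bm25 score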
""") + + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS relations_ad + AFTER DELETE ON relations + BEGIN + DELETE FROM relations_fts WHERE relation_hash = old.hash; + END + """) + + cur.execute(""" + CREATE TRIGGER IF NOT EXISTS relations_au + AFTER UPDATE OF subject, predicate, object ON relations + BEGIN + DELETE FROM relations_fts WHERE relation_hash = new.hash; + INSERT INTO relations_fts(relation_hash, content) + VALUES ( + new.hash, + COALESCE(new.subject, '') || ' ' || COALESCE(new.predicate, '') || ' ' || COALESCE(new.object, '') + ); + END + """) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"relations FTS5 schema 创建失败(可能不支持 FTS5): {e}") + c.rollback() + return False + + def ensure_relations_fts_backfilled(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """确保关系 FTS 索引已回填。""" + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute("SELECT COUNT(1) AS n FROM relations") + rel_count = int(cur.fetchone()[0]) + cur.execute("SELECT COUNT(1) AS n FROM relations_fts") + fts_count = int(cur.fetchone()[0]) + + if rel_count != fts_count: + cur.execute("DELETE FROM relations_fts") + cur.execute(""" + INSERT INTO relations_fts(relation_hash, content) + SELECT + r.hash, + COALESCE(r.subject, '') || ' ' || COALESCE(r.predicate, '') || ' ' || COALESCE(r.object, '') + FROM relations r + """) + c.commit() + logger.info(f"relations FTS 回填完成: relations={rel_count}, fts={rel_count}") + return True + except sqlite3.OperationalError as e: + logger.warning(f"relations FTS 回填失败: {e}") + c.rollback() + return False + + def ensure_paragraph_ngram_schema(self, conn: Optional[sqlite3.Connection] = None) -> bool: + """确保段落 ngram 倒排表存在。""" + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_ngrams ( + term TEXT NOT NULL, + paragraph_hash TEXT NOT NULL, + PRIMARY KEY (term, paragraph_hash), + FOREIGN KEY (paragraph_hash) REFERENCES paragraphs(hash) ON DELETE CASCADE + ) + """) + cur.execute(""" + CREATE INDEX IF NOT EXISTS idx_paragraph_ngrams_hash + ON paragraph_ngrams(paragraph_hash) + """) + cur.execute(""" + CREATE TABLE IF NOT EXISTS paragraph_ngram_meta ( + key TEXT PRIMARY KEY, + value TEXT + ) + """) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"paragraph ngram schema 创建失败: {e}") + c.rollback() + return False + + @staticmethod + def _char_ngrams(text: str, n: int) -> List[str]: + compact = "".join(str(text or "").lower().split()) + if not compact: + return [] + if len(compact) < n: + return [compact] + return [compact[i : i + n] for i in range(0, len(compact) - n + 1)] + + def ensure_paragraph_ngram_backfilled( + self, + n: int = 2, + conn: Optional[sqlite3.Connection] = None, + ) -> bool: + """ + 确保段落 ngram 倒排索引已回填。 + + 仅在 n 变化或文档数量变化时重建,避免每次加载都全量重建。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + n = max(1, int(n)) + try: + cur.execute("SELECT value FROM paragraph_ngram_meta WHERE key='ngram_n'") + row = cur.fetchone() + current_n = int(row[0]) if row and row[0] is not None else None + + cur.execute("SELECT COUNT(1) FROM paragraphs WHERE is_deleted IS NULL OR is_deleted = 0") + para_count = int(cur.fetchone()[0]) + cur.execute("SELECT COUNT(DISTINCT paragraph_hash) FROM paragraph_ngrams") + indexed_docs = int(cur.fetchone()[0]) + + need_rebuild = (current_n != n) or (para_count != indexed_docs) + if not need_rebuild: + return True + + cur.execute("DELETE FROM paragraph_ngrams") + cur.execute(""" + SELECT hash, content + FROM 
paragraphs + WHERE is_deleted IS NULL OR is_deleted = 0 + """) + rows = cur.fetchall() + + batch: List[Tuple[str, str]] = [] + batch_size = 2000 + for row in rows: + p_hash = str(row["hash"]) + terms = list(dict.fromkeys(self._char_ngrams(str(row["content"] or ""), n))) + for term in terms: + batch.append((term, p_hash)) + if len(batch) >= batch_size: + cur.executemany( + "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", + batch, + ) + batch.clear() + if batch: + cur.executemany( + "INSERT OR IGNORE INTO paragraph_ngrams(term, paragraph_hash) VALUES (?, ?)", + batch, + ) + + cur.execute(""" + INSERT INTO paragraph_ngram_meta(key, value) VALUES('ngram_n', ?) + ON CONFLICT(key) DO UPDATE SET value=excluded.value + """, (str(n),)) + cur.execute(""" + INSERT INTO paragraph_ngram_meta(key, value) VALUES('paragraph_count', ?) + ON CONFLICT(key) DO UPDATE SET value=excluded.value + """, (str(para_count),)) + c.commit() + logger.info(f"paragraph ngram 回填完成: n={n}, paragraphs={para_count}") + return True + except Exception as e: + logger.warning(f"paragraph ngram 回填失败: {e}") + c.rollback() + return False + + def fts_upsert_paragraph( + self, + paragraph_hash: str, + conn: Optional[sqlite3.Connection] = None, + ) -> bool: + """ + 将段落写入(或覆盖)到 FTS 索引。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + "SELECT rowid, content FROM paragraphs WHERE hash = ?", + (paragraph_hash,), + ) + row = cur.fetchone() + if not row: + return False + rowid = int(row[0]) + content = str(row[1] or "") + cur.execute( + "INSERT OR REPLACE INTO paragraphs_fts(rowid, content) VALUES (?, ?)", + (rowid, content), + ) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS upsert 失败: {e}") + c.rollback() + return False + + def fts_delete_paragraph( + self, + paragraph_hash: str, + conn: Optional[sqlite3.Connection] = None, + ) -> bool: + """ + 从 FTS 索引删除段落。 + """ + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + "SELECT rowid, content FROM paragraphs WHERE hash = ?", + (paragraph_hash,), + ) + row = cur.fetchone() + if not row: + return False + rowid = int(row[0]) + content = str(row[1] or "") + cur.execute( + "INSERT INTO paragraphs_fts(paragraphs_fts, rowid, content) VALUES ('delete', ?, ?)", + (rowid, content), + ) + c.commit() + return True + except sqlite3.OperationalError as e: + logger.warning(f"FTS delete 失败: {e}") + c.rollback() + return False + + def fts_search_bm25( + self, + match_query: str, + limit: int = 20, + max_doc_len: int = 2000, + conn: Optional[sqlite3.Connection] = None, + ) -> List[Dict[str, Any]]: + """ + 使用 FTS5 + bm25 执行全文检索。 + """ + if not match_query.strip(): + return [] + + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + """ + SELECT p.hash, p.content, bm25(paragraphs_fts) AS bm25_score + FROM paragraphs_fts + JOIN paragraphs p ON p.rowid = paragraphs_fts.rowid + WHERE paragraphs_fts MATCH ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + ORDER BY bm25_score ASC + LIMIT ? 
+ """, + (match_query, max(1, int(limit))), + ) + rows = cur.fetchall() + results: List[Dict[str, Any]] = [] + for row in rows: + content = str(row["content"] or "") + if max_doc_len > 0: + content = content[:max_doc_len] + results.append( + { + "hash": row["hash"], + "content": content, + "bm25_score": float(row["bm25_score"]), + } + ) + return results + except sqlite3.OperationalError as e: + logger.warning(f"FTS 查询失败: {e}") + return [] + + def fts_search_relations_bm25( + self, + match_query: str, + limit: int = 20, + max_doc_len: int = 512, + conn: Optional[sqlite3.Connection] = None, + ) -> List[Dict[str, Any]]: + """使用 FTS5 + bm25 执行关系全文检索。""" + if not match_query.strip(): + return [] + + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute( + """ + SELECT + r.hash, + r.subject, + r.predicate, + r.object, + bm25(relations_fts) AS bm25_score + FROM relations_fts + JOIN relations r ON r.hash = relations_fts.relation_hash + WHERE relations_fts MATCH ? + ORDER BY bm25_score ASC + LIMIT ? + """, + (match_query, max(1, int(limit))), + ) + rows = cur.fetchall() + out: List[Dict[str, Any]] = [] + for row in rows: + content = f"{row['subject']} {row['predicate']} {row['object']}" + if max_doc_len > 0: + content = content[:max_doc_len] + out.append( + { + "hash": row["hash"], + "subject": row["subject"], + "predicate": row["predicate"], + "object": row["object"], + "content": content, + "bm25_score": float(row["bm25_score"]), + } + ) + return out + except sqlite3.OperationalError as e: + logger.warning(f"relations FTS 查询失败: {e}") + return [] + + def ngram_search_paragraphs( + self, + tokens: List[str], + limit: int = 20, + max_doc_len: int = 2000, + conn: Optional[sqlite3.Connection] = None, + ) -> List[Dict[str, Any]]: + """按 ngram 倒排索引检索段落,避免 LIKE 全表扫描。""" + uniq = [t for t in dict.fromkeys([str(x).strip().lower() for x in tokens]) if t] + if not uniq: + return [] + + c = self._resolve_conn(conn) + cur = c.cursor() + placeholders = ",".join(["?"] * len(uniq)) + try: + cur.execute( + f""" + SELECT + p.hash, + p.content, + COUNT(*) AS hit_terms + FROM paragraph_ngrams ng + JOIN paragraphs p ON p.hash = ng.paragraph_hash + WHERE ng.term IN ({placeholders}) + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + GROUP BY p.hash, p.content + ORDER BY hit_terms DESC + LIMIT ? 
+ """, + tuple(uniq + [max(1, int(limit))]), + ) + rows = cur.fetchall() + out: List[Dict[str, Any]] = [] + token_count = max(1, len(uniq)) + for row in rows: + hit_terms = int(row["hit_terms"]) + score = float(hit_terms / token_count) + content = str(row["content"] or "") + if max_doc_len > 0: + content = content[:max_doc_len] + out.append( + { + "hash": row["hash"], + "content": content, + "bm25_score": -score, + "fallback_score": score, + } + ) + return out + except sqlite3.OperationalError as e: + logger.warning(f"ngram 倒排查询失败: {e}") + return [] + + def fts_doc_count(self, conn: Optional[sqlite3.Connection] = None) -> int: + """获取 FTS 文档数量。""" + c = self._resolve_conn(conn) + cur = c.cursor() + try: + cur.execute("SELECT COUNT(1) FROM paragraphs_fts") + return int(cur.fetchone()[0]) + except sqlite3.OperationalError: + return 0 + + def shrink_memory(self, conn: Optional[sqlite3.Connection] = None) -> None: + """请求 SQLite 收缩当前连接缓存。""" + c = self._resolve_conn(conn) + try: + c.execute("PRAGMA shrink_memory") + except sqlite3.OperationalError: + pass + + @staticmethod + def _normalize_episode_source(source: Any) -> str: + return str(source or "").strip() + + def _dedupe_episode_sources(self, sources: List[Any]) -> List[str]: + normalized: List[str] = [] + seen = set() + for item in sources or []: + token = self._normalize_episode_source(item) + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + return normalized + + def _get_sources_for_paragraph_hashes( + self, + hashes: List[str], + *, + include_deleted: bool = True, + ) -> List[str]: + normalized_hashes = [ + str(item or "").strip() + for item in (hashes or []) + if str(item or "").strip() + ] + if not normalized_hashes: + return [] + + placeholders = ",".join(["?"] * len(normalized_hashes)) + conditions = ["hash IN ({})".format(placeholders), "TRIM(COALESCE(source, '')) != ''"] + if not include_deleted: + conditions.append("(is_deleted IS NULL OR is_deleted = 0)") + + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT DISTINCT TRIM(source) AS source + FROM paragraphs + WHERE {' AND '.join(conditions)} + """, + tuple(normalized_hashes), + ) + return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) + + def _enqueue_episode_source_rebuilds(self, sources: List[Any], reason: str = "") -> int: + normalized_sources = self._dedupe_episode_sources(sources) + if not normalized_sources: + return 0 + + now = datetime.now().timestamp() + reason_text = str(reason or "").strip()[:200] or None + cursor = self._conn.cursor() + cursor.executemany( + """ + INSERT INTO episode_rebuild_sources ( + source, status, retry_count, last_error, reason, requested_at, updated_at + ) VALUES (?, 'pending', 0, NULL, ?, ?, ?) 
+ ON CONFLICT(source) DO UPDATE SET + status = 'pending', + last_error = NULL, + reason = excluded.reason, + requested_at = excluded.requested_at, + updated_at = excluded.updated_at + """, + [ + (source, reason_text, now, now) + for source in normalized_sources + ], + ) + self._conn.commit() + return len(normalized_sources) + + def add_paragraph( + self, + content: str, + vector_index: Optional[int] = None, + source: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + knowledge_type: str = "mixed", + time_meta: Optional[Dict[str, Any]] = None, + ) -> str: + """ + 添加段落 + + Args: + content: 段落内容 + vector_index: 向量索引 + source: 来源 + metadata: 额外元数据 + knowledge_type: 知识类型 (narrative/factual/quote/structured/mixed) + time_meta: 时间元信息 (event_time/event_time_start/event_time_end/...) + + Returns: + 段落哈希值 + """ + content_normalized = normalize_text(content) + hash_value = compute_hash(content_normalized) + resolved_knowledge_type = validate_stored_knowledge_type(knowledge_type) + + now = datetime.now().timestamp() + word_count = len(content_normalized.split()) + normalized_time = normalize_time_meta(time_meta) + + cursor = self._conn.cursor() + try: + cursor.execute(""" + INSERT INTO paragraphs + ( + hash, content, vector_index, created_at, updated_at, metadata, source, word_count, + event_time, event_time_start, event_time_end, time_granularity, time_confidence, + knowledge_type + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + hash_value, + content, + vector_index, + now, + now, + pickle.dumps(metadata or {}), + source, + word_count, + normalized_time.get("event_time"), + normalized_time.get("event_time_start"), + normalized_time.get("event_time_end"), + normalized_time.get("time_granularity"), + normalized_time.get("time_confidence", 1.0), + resolved_knowledge_type.value, + )) + self._conn.commit() + try: + self.enqueue_episode_source_rebuild( + source=source, + reason="paragraph_added", + ) + except Exception as e: + logger.warning(f"Episode source 重建入队失败: hash={hash_value[:16]}..., err={e}") + logger.debug( + f"添加段落: hash={hash_value[:16]}..., words={word_count}, type={resolved_knowledge_type.value}" + ) + return hash_value + except sqlite3.IntegrityError: + logger.debug(f"段落已存在: {hash_value[:16]}...") + # 尝试复活 + self.revive_if_deleted(paragraph_hashes=[hash_value]) + return hash_value + + def _canonicalize_name(self, name: str) -> str: + """ + 规范化名称 (统一小写并去除首尾空格) + + Args: + name: 原始名称 + + Returns: + 规范化后的名称 + """ + if not name: + return "" + return name.strip().lower() + + def add_entity( + self, + name: str, + vector_index: Optional[int] = None, + source_paragraph: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """ + 添加实体 + + Args: + name: 实体名称 + vector_index: 向量索引 + source_paragraph: 来源段落哈希 (如果提供,将建立关联) + metadata: 额外元数据 + + Returns: + 实体哈希值 + """ + # 1. 规范化名称 + name_normalized = self._canonicalize_name(name) + if not name_normalized: + raise ValueError("Entity name cannot be empty") + + hash_value = compute_hash(name_normalized) + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + + # 2. 插入实体 (INSERT OR IGNORE) + # 注意:这里我们保留原有的 name 字段存储,可以是 display name, + # 但 hash 必须由 canonical name 生成。 + # 如果实体已存在,我们其实不一定要更新 name (保留第一次的 display name 往往更好) + # 或者我们也可以选择不作为唯一键冲突,而是逻辑判断。 + # 考虑到 entities.hash 是主键,entities.name 是 UNIQUE。 + # 如果 name 大小写不同但 hash 相同 (冲突),或者 name 不同但 canonical name 相同? 
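The discussion in the comments that follow hinges on one invariant: hashes are computed from the canonical (trimmed, lower-cased) name, so "Apple" and "  apple " collide on purpose and INSERT OR IGNORE deduplicates them. compute_hash itself lives in core.utils.hash; the md5 stand-in below is an assumption used only for illustration:

import hashlib

def canonicalize(name: str) -> str:
    return name.strip().lower() if name else ""

def compute_hash(text: str) -> str:          # stand-in; the real helper is core.utils.hash.compute_hash
    return hashlib.md5(text.encode("utf-8")).hexdigest()

assert compute_hash(canonicalize("Apple")) == compute_hash(canonicalize("  apple "))
# relations use the same idea: one hash over the canonical 's|p|o' key
print(compute_hash("|".join(canonicalize(x) for x in ("Alice", "likes", "Tea"))))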
+ # 由于 hash 是由 canonical name 算出来的,所以 hash 相同意味着 canonical name 相同。 + # 如果 db 中已存在的 name 是 "Apple",新来的 name 是 "apple",它们 canonical name 都是 "apple",hash 一样。 + # 此时 INSERT OR IGNORE 会忽略。 + + try: + cursor.execute(""" + INSERT INTO entities + (hash, name, vector_index, appearance_count, created_at, metadata) + VALUES (?, ?, ?, 1, ?, ?) + """, ( + hash_value, + name, + vector_index, + now, + pickle.dumps(metadata or {}), + )) + + logger.debug(f"添加实体: {name} ({hash_value[:8]})") + self._conn.commit() + + # 3. 建立来源关联 + if source_paragraph: + self.link_paragraph_entity(source_paragraph, hash_value) + + return hash_value + + except sqlite3.IntegrityError: + # 实体已存在 + # 1. 尝试复活 (自动复活) + self.revive_if_deleted(entity_hashes=[hash_value]) + + # 2. 更新计数 + cursor.execute(""" + UPDATE entities + SET appearance_count = appearance_count + 1 + WHERE hash = ? + """, (hash_value,)) + self._conn.commit() + + logger.debug(f"实体已存在(复活/计数+1): {name}") + + # 3. 建立来源关联 + if source_paragraph: + self.link_paragraph_entity(source_paragraph, hash_value) + + return hash_value + + def add_relation( + self, + subject: str, + predicate: str, + obj: str, + vector_index: Optional[int] = None, + confidence: float = 1.0, + source_paragraph: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """ + 添加关系 + + Args: + subject: 主语 + predicate: 谓语 + obj: 宾语 + vector_index: 向量索引 + confidence: 置信度 + source_paragraph: 来源段落哈希 + metadata: 额外元数据 + + Returns: + 关系哈希值 + """ + # 1. 规范化输入 + s_canon = self._canonicalize_name(subject) + p_canon = self._canonicalize_name(predicate) + o_canon = self._canonicalize_name(obj) + + if not all([s_canon, p_canon, o_canon]): + raise ValueError("Relation components cannot be empty") + + # 2. 计算组合哈希 + # 公式: md5(s|p|o) + relation_key = f"{s_canon}|{p_canon}|{o_canon}" + hash_value = compute_hash(relation_key) + + now = datetime.now().timestamp() + + # 记录原始 display name 到 metadata (如果需要的话,或者直接存到 DB 字段) + # 这里我们直接存入 subject, predicate, object 字段, + # 注意:如果 DB 里已存在该关系 (hash 相同),则不会更新这些字段,保留第一次的拼写。 + + cursor = self._conn.cursor() + try: + cursor.execute(""" + INSERT OR IGNORE INTO relations + (hash, subject, predicate, object, vector_index, confidence, created_at, source_paragraph, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + hash_value, + subject, # 原始拼写 + predicate, + obj, + vector_index, + confidence, + now, + source_paragraph, # 这里的 source_paragraph 仅作为 "首次发现地" 记录,也可留空 + pickle.dumps(metadata or {}), + )) + self._conn.commit() + + if cursor.rowcount > 0: + logger.debug(f"添加关系: {subject} -{predicate}-> {obj}") + else: + logger.debug(f"关系已存在: {subject} -{predicate}-> {obj}") + + # 3. 建立来源关联 (幂等) + # 无论关系是新创建的还是已存在的,只要提供了 source_paragraph,都要建立连接 + if source_paragraph: + self.link_paragraph_relation(source_paragraph, hash_value) + + return hash_value + + except sqlite3.IntegrityError as e: + logger.warning(f"添加关系异常: {e}") + return hash_value + + def link_paragraph_relation( + self, + paragraph_hash: str, + relation_hash: str, + ) -> bool: + """ + 关联段落和关系 (幂等) + """ + cursor = self._conn.cursor() + try: + # 使用 INSERT OR IGNORE 避免重复报错 + cursor.execute(""" + INSERT OR IGNORE INTO paragraph_relations + (paragraph_hash, relation_hash) + VALUES (?, ?) 
+ """, (paragraph_hash, relation_hash)) + self._conn.commit() + self._enqueue_episode_source_rebuilds( + self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), + reason="paragraph_relation_linked", + ) + return True + except sqlite3.IntegrityError: + return False + + def link_paragraph_entity( + self, + paragraph_hash: str, + entity_hash: str, + mention_count: int = 1, + ) -> bool: + """ + 关联段落和实体 (幂等) + """ + cursor = self._conn.cursor() + try: + # 首先尝试插入 + cursor.execute(""" + INSERT OR IGNORE INTO paragraph_entities + (paragraph_hash, entity_hash, mention_count) + VALUES (?, ?, ?) + """, (paragraph_hash, entity_hash, mention_count)) + + if cursor.rowcount == 0: + # 如果已存在 (IGNORE生效),则更新计数 + cursor.execute(""" + UPDATE paragraph_entities + SET mention_count = mention_count + ? + WHERE paragraph_hash = ? AND entity_hash = ? + """, (mention_count, paragraph_hash, entity_hash)) + + self._conn.commit() + self._enqueue_episode_source_rebuilds( + self._get_sources_for_paragraph_hashes([paragraph_hash], include_deleted=True), + reason="paragraph_entity_linked", + ) + return True + except sqlite3.IntegrityError: + return False + + def get_paragraph(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 获取段落 + + Args: + hash_value: 段落哈希 + + Returns: + 段落信息字典,不存在则返回None + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM paragraphs WHERE hash = ? + """, (hash_value,)) + row = cursor.fetchone() + + if row: + return self._row_to_dict(row, "paragraph") + return None + + def update_paragraph_time_meta( + self, + paragraph_hash: str, + time_meta: Dict[str, Any], + ) -> bool: + """ + 更新段落时间元信息。 + """ + normalized = normalize_time_meta(time_meta) + if not normalized: + return False + source_to_rebuild = self._get_sources_for_paragraph_hashes( + [paragraph_hash], + include_deleted=True, + ) + + updates: List[str] = [] + params: List[Any] = [] + for key in [ + "event_time", + "event_time_start", + "event_time_end", + "time_granularity", + "time_confidence", + ]: + if key in normalized: + updates.append(f"{key} = ?") + params.append(normalized[key]) + + if not updates: + return False + + updates.append("updated_at = ?") + params.append(datetime.now().timestamp()) + params.append(paragraph_hash) + + cursor = self._conn.cursor() + cursor.execute( + f"UPDATE paragraphs SET {', '.join(updates)} WHERE hash = ?", + tuple(params), + ) + self._conn.commit() + changed = cursor.rowcount > 0 + if changed: + self._enqueue_episode_source_rebuilds( + source_to_rebuild, + reason="paragraph_time_updated", + ) + return changed + + def query_paragraphs_temporal( + self, + start_ts: Optional[float] = None, + end_ts: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + limit: int = 100, + allow_created_fallback: bool = True, + ) -> List[Dict[str, Any]]: + """ + 查询时序命中的段落(区间相交语义)。 + """ + if limit <= 0: + return [] + + effective_start = "COALESCE(p.event_time_start, p.event_time, p.event_time_end" + effective_end = "COALESCE(p.event_time_end, p.event_time, p.event_time_start" + if allow_created_fallback: + effective_start += ", p.created_at)" + effective_end += ", p.created_at)" + else: + effective_start += ")" + effective_end += ")" + + conditions = ["(p.is_deleted IS NULL OR p.is_deleted = 0)"] + params: List[Any] = [] + + if source: + conditions.append("p.source = ?") + params.append(source) + + if person: + conditions.append( + """ + EXISTS ( + SELECT 1 + FROM paragraph_entities pe + JOIN entities e ON e.hash = pe.entity_hash + WHERE 
pe.paragraph_hash = p.hash + AND LOWER(e.name) LIKE ? + ) + """ + ) + params.append(f"%{str(person).strip().lower()}%") + + if start_ts is not None and end_ts is not None: + conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") + params.extend([start_ts, end_ts]) + elif start_ts is not None: + conditions.append(f"({effective_end} >= ?)") + params.append(start_ts) + elif end_ts is not None: + conditions.append(f"({effective_start} <= ?)") + params.append(end_ts) + + where_sql = " AND ".join(conditions) + sql = f""" + SELECT p.* + FROM paragraphs p + WHERE {where_sql} + ORDER BY {effective_end} DESC, p.updated_at DESC + LIMIT ? + """ + params.append(limit) + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params)) + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def get_entity(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 获取实体 + + Args: + hash_value: 实体哈希 + + Returns: + 实体信息字典,不存在则返回None + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM entities WHERE hash = ? + """, (hash_value,)) + row = cursor.fetchone() + + if row: + return self._row_to_dict(row, "entity") + return None + + def get_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 获取关系 + + Args: + hash_value: 关系哈希 + + Returns: + 关系信息字典,不存在则返回None + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM relations WHERE hash = ? + """, (hash_value,)) + row = cursor.fetchone() + + if row: + return self._row_to_dict(row, "relation") + return None + + def get_paragraph_relations(self, paragraph_hash: str) -> List[Dict[str, Any]]: + """ + 获取段落的所有关系 + + Args: + paragraph_hash: 段落哈希 + + Returns: + 关系列表 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT r.* FROM relations r + JOIN paragraph_relations pr ON r.hash = pr.relation_hash + WHERE pr.paragraph_hash = ? + """, (paragraph_hash,)) + + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def get_paragraph_entities(self, paragraph_hash: str) -> List[Dict[str, Any]]: + """ + 获取段落的所有实体 + + Args: + paragraph_hash: 段落哈希 + + Returns: + 实体列表 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT e.*, pe.mention_count + FROM entities e + JOIN paragraph_entities pe ON e.hash = pe.entity_hash + WHERE pe.paragraph_hash = ? + """, (paragraph_hash,)) + + return [self._row_to_dict(row, "entity") for row in cursor.fetchall()] + + def get_paragraphs_by_entity(self, entity_name: str) -> List[Dict[str, Any]]: + """ + 获取包含指定实体的所有段落 (自动处理规范化) + + Args: + entity_name: 实体名称 (支持任意大小写) + + Returns: + 段落列表 + """ + # 1. 计算规范化 Hash + name_canon = self._canonicalize_name(entity_name) + if not name_canon: + return [] + + entity_hash = compute_hash(name_canon) + + cursor = self._conn.cursor() + # 2. 直接使用 Hash 查询中间表,完全避开 Name 匹配 + cursor.execute(""" + SELECT p.* + FROM paragraphs p + JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash + WHERE pe.entity_hash = ? 
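query_paragraphs_temporal treats every paragraph as an interval (the COALESCE expressions pick the best available start/end, optionally falling back to created_at) and keeps it when that interval intersects the query window; the SQL predicate reduces to the classic overlap test:

def intersects(eff_start: float, eff_end: float, q_start: float, q_end: float) -> bool:
    # same condition as the SQL: effective_end >= ? AND effective_start <= ?
    return eff_end >= q_start and eff_start <= q_end

assert intersects(100, 200, 150, 300)        # partial overlap -> hit
assert intersects(100, 200, 50, 120)         # overlaps the left edge -> hit
assert not intersects(100, 200, 250, 300)    # disjoint -> filtered out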
+ """, (entity_hash,)) + + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def get_relations( + self, + subject: Optional[str] = None, + predicate: Optional[str] = None, + object: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + 查询关系(大小写不敏感) + + Args: + subject: 主语(可选) + predicate: 谓语(可选) + object: 宾语(可选) + + Returns: + 关系列表 + """ + # 构建查询条件 + conditions = [] + params = [] + + if subject: + conditions.append("LOWER(subject) = ?") + params.append(self._canonicalize_name(subject)) + if predicate: + conditions.append("LOWER(predicate) = ?") + params.append(self._canonicalize_name(predicate)) + if object: + conditions.append("LOWER(object) = ?") + params.append(self._canonicalize_name(object)) + + sql = "SELECT * FROM relations" + if conditions: + sql += " WHERE " + " AND ".join(conditions) + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params)) + + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def get_all_triples(self) -> List[Tuple[str, str, str, str]]: + """ + 高效获取所有三元组 (subject, predicate, object, hash) + 直接返回元组,跳过字典转换和pickle反序列化,用于构建 V5 Map 缓存。 + """ + cursor = self._conn.cursor() + cursor.execute("SELECT subject, predicate, object, hash FROM relations") + return list(cursor.fetchall()) + + def get_paragraphs_by_relation(self, relation_hash: str) -> List[Dict[str, Any]]: + """ + 获取支持指定关系的所有段落 + + Args: + relation_hash: 关系哈希 + + Returns: + 段落列表 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT p.* + FROM paragraphs p + JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash + WHERE pr.relation_hash = ? + """, (relation_hash,)) + + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def get_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: + """ + 按来源获取段落 + + Args: + source: 来源标识符 + + Returns: + 段落列表 + """ + return self.query("SELECT * FROM paragraphs WHERE source = ?", (source,)) + + def get_all_sources(self) -> List[Dict[str, Any]]: + """ + 获取所有来源文件统计信息 + + Returns: + 来源列表 [{'source': 'name', 'count': int, 'last_updated': timestamp}] + """ + cursor = self._conn.cursor() + # 排除 source 为 NULL 或空的记录 + cursor.execute(""" + SELECT source, COUNT(*) as count, MAX(created_at) as last_updated + FROM paragraphs + WHERE source IS NOT NULL AND source != '' + GROUP BY source + ORDER BY last_updated DESC + """) + + results = [] + for row in cursor.fetchall(): + results.append({ + "source": row[0], + "count": row[1], + "last_updated": row[2] + }) + return results + + + def search_paragraphs_by_content(self, content_query: str) -> List[Dict[str, Any]]: + """按内容模糊搜索段落""" + cursor = self._conn.cursor() + cursor.execute(""" + SELECT * FROM paragraphs WHERE content LIKE ? + """, (f"%{content_query}%",)) + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def delete_paragraph(self, hash_value: str) -> bool: + """ + 删除段落(级联删除相关关联) + + Args: + hash_value: 段落哈希 + + Returns: + 是否成功删除 + """ + cursor = self._conn.cursor() + cursor.execute(""" + DELETE FROM paragraphs WHERE hash = ? + """, (hash_value,)) + self._conn.commit() + + deleted = cursor.rowcount > 0 + if deleted: + logger.info(f"删除段落: {hash_value[:16]}...") + + return deleted + + def delete_entity(self, hash_or_name: str) -> bool: + """ + 删除实体(级联删除相关关联) + 支持通过哈希值或名称删除 + + 注意:会同时删除所有引用该实体(作为主语或宾语)的关系 + """ + cursor = self._conn.cursor() + + # 1. 
解析实体信息 (获取 Name 和 Hash) + entity_name = None + entity_hash = None + + # 尝试作为 Hash 查询 + cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (hash_or_name,)) + row = cursor.fetchone() + if row: + entity_name = row[0] + entity_hash = row[1] + else: + # 尝试作为 Name 查询 (原始匹配) + cursor.execute("SELECT name, hash FROM entities WHERE name = ?", (hash_or_name,)) + row = cursor.fetchone() + if row: + entity_name = row[0] + entity_hash = row[1] + else: + # 最后的最后:尝试规范化名称 (Canonical) 查询,解决大小写或 WebUI 手动输入导致的不匹配 + name_canon = self._canonicalize_name(hash_or_name) + canon_hash = compute_hash(name_canon) + cursor.execute("SELECT name, hash FROM entities WHERE hash = ?", (canon_hash,)) + row = cursor.fetchone() + if row: + entity_name = row[0] + entity_hash = row[1] + + if not entity_name or not entity_hash: + logger.debug(f"删除实体请求跳过:未在元数据记录中找到 {hash_or_name}") + return False + + logger.info(f"开始删除实体: {entity_name} (Hash: {entity_hash[:8]}...)") + + try: + # 2. 查找相关关系 (Subject 或 Object 为该实体) + cursor.execute(""" + SELECT hash FROM relations + WHERE subject = ? OR object = ? + """, (entity_name, entity_name)) + + relation_hashes = [r[0] for r in cursor.fetchall()] + + if relation_hashes: + logger.info(f"发现 {len(relation_hashes)} 个相关关系,准备级联删除") + + # 3. 删除这些关系与段落的关联 + # SQLite 不支持直接 DELETE ... WHERE ... IN (...) 的列表参数,需要拼接占位符 + placeholders = ','.join(['?'] * len(relation_hashes)) + + cursor.execute(f""" + DELETE FROM paragraph_relations + WHERE relation_hash IN ({placeholders}) + """, relation_hashes) + + # 4. 删除关系本体 + cursor.execute(f""" + DELETE FROM relations + WHERE hash IN ({placeholders}) + """, relation_hashes) + + logger.info("相关关系已级联删除") + + # 5. 删除实体与段落的关联 + cursor.execute("DELETE FROM paragraph_entities WHERE entity_hash = ?", (entity_hash,)) + + # 6. 删除实体本体 + cursor.execute("DELETE FROM entities WHERE hash = ?", (entity_hash,)) + + self._conn.commit() + logger.info("实体删除完成") + return True + + except Exception as e: + logger.error(f"删除实体时发生错误: {e}") + self._conn.rollback() + return False + + def delete_relation(self, hash_value: str) -> bool: + """ + 删除关系(级联删除相关关联) + + Args: + hash_value: 关系哈希 + + Returns: + 是否成功删除 + """ + cursor = self._conn.cursor() + cursor.execute(""" + DELETE FROM relations WHERE hash = ? + """, (hash_value,)) + self._conn.commit() + + deleted = cursor.rowcount > 0 + if deleted: + logger.info(f"删除关系: {hash_value[:16]}...") + + return deleted + + def set_relation_vector_state( + self, + hash_value: str, + state: str, + error: Optional[str] = None, + bump_retry: bool = False, + ) -> bool: + """ + 更新关系向量状态。 + """ + state_norm = str(state or "").strip().lower() + if state_norm not in {"none", "pending", "ready", "failed"}: + raise ValueError(f"无效 vector_state: {state}") + + now = datetime.now().timestamp() + err_text = (str(error).strip() if error is not None else None) + if err_text: + err_text = err_text[:500] + clear_error = state_norm in {"none", "pending", "ready"} + + cursor = self._conn.cursor() + if bump_retry: + cursor.execute( + """ + UPDATE relations + SET vector_state = ?, + vector_updated_at = ?, + vector_error = ?, + vector_retry_count = COALESCE(vector_retry_count, 0) + 1 + WHERE hash = ? + """, + (state_norm, now, None if clear_error else err_text, hash_value), + ) + else: + cursor.execute( + """ + UPDATE relations + SET vector_state = ?, + vector_updated_at = ?, + vector_error = ? + WHERE hash = ? 
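+ -- 普通状态更新分支:与上面的 bump_retry 分支不同,这里不累计 vector_retry_count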
+ """, + (state_norm, now, None if clear_error else err_text, hash_value), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def list_relations_by_vector_state( + self, + states: List[str], + limit: int = 200, + max_retry: Optional[int] = None, + ) -> List[Dict[str, Any]]: + """ + 根据向量状态列出关系,用于回填任务。 + """ + normalized_states = [ + str(s or "").strip().lower() + for s in (states or []) + if str(s or "").strip() + ] + normalized_states = [ + s for s in normalized_states + if s in {"none", "pending", "ready", "failed"} + ] + if not normalized_states: + return [] + + placeholders = ",".join(["?"] * len(normalized_states)) + params: List[Any] = list(normalized_states) + sql = f""" + SELECT hash, subject, predicate, object, confidence, source_paragraph, + vector_state, vector_updated_at, vector_error, vector_retry_count, created_at + FROM relations + WHERE vector_state IN ({placeholders}) + """ + if max_retry is not None: + sql += " AND COALESCE(vector_retry_count, 0) < ?" + params.append(int(max_retry)) + sql += " ORDER BY COALESCE(vector_updated_at, created_at, 0) ASC LIMIT ?" + params.append(max(1, int(limit))) + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params)) + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def count_relations_by_vector_state(self) -> Dict[str, int]: + """ + 统计关系向量状态分布。 + """ + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT COALESCE(vector_state, 'none') AS state, COUNT(*) AS cnt + FROM relations + GROUP BY COALESCE(vector_state, 'none') + """ + ) + result: Dict[str, int] = {"none": 0, "pending": 0, "ready": 0, "failed": 0} + total = 0 + for row in cursor.fetchall(): + state = str(row["state"] or "none").lower() + count = int(row["cnt"] or 0) + if state not in result: + result[state] = 0 + result[state] += count + total += count + result["total"] = total + return result + + def update_vector_index( + self, + item_type: str, + hash_value: str, + vector_index: int, + ) -> bool: + """ + 更新向量索引 + + Args: + item_type: 类型(paragraph/entity/relation) + hash_value: 哈希值 + vector_index: 向量索引 + + Returns: + 是否成功更新 + """ + valid_types = ["paragraph", "entity", "relation"] + if item_type not in valid_types: + raise ValueError(f"无效的类型: {item_type}") + + table_map = { + "paragraph": "paragraphs", + "entity": "entities", + "relation": "relations", + } + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table_map[item_type]} + SET vector_index = ? + WHERE hash = ? + """, (vector_index, hash_value)) + self._conn.commit() + + return cursor.rowcount > 0 + + def set_permanence(self, hash_value: str, item_type: str, is_permanent: bool) -> bool: + """设置永久记忆标记""" + table_map = { + "paragraph": "paragraphs", + "relation": "relations", + } + if item_type not in table_map: + raise ValueError(f"类型 {item_type} 不支持设置永久性") + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table_map[item_type]} + SET is_permanent = ? + WHERE hash = ? 
+ """, (1 if is_permanent else 0, hash_value)) + self._conn.commit() + + if cursor.rowcount > 0: + logger.debug(f"设置永久记忆: {item_type}/{hash_value[:8]} -> {is_permanent}") + return True + return False + + def record_access(self, hash_value: str, item_type: str) -> bool: + """记录访问(更新时间和次数)""" + table_map = { + "paragraph": "paragraphs", + "relation": "relations", + } + if item_type not in table_map: + return False + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table_map[item_type]} + SET last_accessed = ?, access_count = access_count + 1 + WHERE hash = ? + """, (now, hash_value)) + self._conn.commit() + return cursor.rowcount > 0 + + def query( + self, + sql: str, + params: Optional[Tuple] = None, + ) -> List[Dict[str, Any]]: + """ + 执行自定义查询 + + Args: + sql: SQL语句 + params: 参数 + + Returns: + 查询结果列表 + """ + cursor = self._conn.cursor() + if params: + cursor.execute(sql, params) + else: + cursor.execute(sql) + + return [dict(row) for row in cursor.fetchall()] + + def get_external_memory_ref(self, external_id: str) -> Optional[Dict[str, Any]]: + """按 external_id 查询外部记忆映射。""" + token = str(external_id or "").strip() + if not token: + return None + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT external_id, paragraph_hash, source_type, created_at, metadata_json + FROM external_memory_refs + WHERE external_id = ? + LIMIT 1 + """, + (token,), + ) + row = cursor.fetchone() + if row is None: + return None + + payload = dict(row) + raw_metadata = payload.get("metadata_json") + if raw_metadata: + try: + payload["metadata"] = json.loads(raw_metadata) + except Exception: + payload["metadata"] = {} + else: + payload["metadata"] = {} + return payload + + def upsert_external_memory_ref( + self, + *, + external_id: str, + paragraph_hash: str, + source_type: str = "", + metadata: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """注册 external_id 到段落哈希的幂等映射。""" + external_token = str(external_id or "").strip() + paragraph_token = str(paragraph_hash or "").strip() + if not external_token: + raise ValueError("external_id 不能为空") + if not paragraph_token: + raise ValueError("paragraph_hash 不能为空") + + now = datetime.now().timestamp() + metadata_json = json.dumps(metadata or {}, ensure_ascii=False) + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO external_memory_refs ( + external_id, paragraph_hash, source_type, created_at, metadata_json + ) + VALUES (?, ?, ?, ?, ?) 
+ ON CONFLICT(external_id) DO UPDATE SET + paragraph_hash = excluded.paragraph_hash, + source_type = excluded.source_type, + metadata_json = excluded.metadata_json + """, + ( + external_token, + paragraph_token, + str(source_type or "").strip() or None, + now, + metadata_json, + ), + ) + self._conn.commit() + return self.get_external_memory_ref(external_token) or { + "external_id": external_token, + "paragraph_hash": paragraph_token, + "source_type": str(source_type or "").strip(), + "created_at": now, + "metadata": metadata or {}, + } + + def get_statistics(self) -> Dict[str, int]: + """ + 获取统计信息 + + Returns: + 统计信息字典 + """ + cursor = self._conn.cursor() + + stats = {} + + # 段落数量 + cursor.execute("SELECT COUNT(*) FROM paragraphs") + stats["paragraph_count"] = cursor.fetchone()[0] + + # 实体数量 + cursor.execute("SELECT COUNT(*) FROM entities") + stats["entity_count"] = cursor.fetchone()[0] + + # 关系数量 + cursor.execute("SELECT COUNT(*) FROM relations") + stats["relation_count"] = cursor.fetchone()[0] + + # 总词数 + cursor.execute("SELECT SUM(word_count) FROM paragraphs") + result = cursor.fetchone()[0] + stats["total_words"] = result if result else 0 + + return stats + + def count_paragraphs(self, include_deleted: bool = False, only_deleted: bool = False) -> int: + """ + 获取段落数量 + """ + cursor = self._conn.cursor() + if only_deleted: + cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 1") + return cursor.fetchone()[0] + if include_deleted: + cursor.execute("SELECT COUNT(*) FROM paragraphs") + return cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM paragraphs WHERE is_deleted = 0") + return cursor.fetchone()[0] + + def count_relations(self, include_deleted: bool = False, only_deleted: bool = False) -> int: + """ + 获取关系数量 + """ + cursor = self._conn.cursor() + if only_deleted: + cursor.execute("SELECT COUNT(*) FROM deleted_relations") + return cursor.fetchone()[0] + cursor.execute("SELECT COUNT(*) FROM relations") + active_count = cursor.fetchone()[0] + if not include_deleted: + return active_count + cursor.execute("SELECT COUNT(*) FROM deleted_relations") + deleted_count = cursor.fetchone()[0] + return int(active_count) + int(deleted_count) + + def count_entities(self) -> int: + """ + 获取实体数量 + + Returns: + 实体数量 + """ + cursor = self._conn.cursor() + cursor.execute("SELECT COUNT(*) FROM entities") + return cursor.fetchone()[0] + + def get_knowledge_type_distribution(self) -> Dict[str, int]: + """获取段落知识类型分布。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT knowledge_type, COUNT(*) as count + FROM paragraphs + WHERE is_deleted = 0 + GROUP BY knowledge_type + """ + ) + result: Dict[str, int] = {} + for row in cursor.fetchall(): + type_name = row[0] if row[0] else "未分类" + result[str(type_name)] = int(row[1] or 0) + return result + + def get_memory_status_summary(self, now_ts: Optional[float] = None) -> Dict[str, int]: + """聚合 memory status 统计。""" + now_ts = float(now_ts) if now_ts is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 0") + active_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM relations WHERE is_inactive = 1") + inactive_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM deleted_relations") + deleted_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM relations WHERE is_pinned = 1") + pinned_count = int(cursor.fetchone()[0] or 0) + cursor.execute("SELECT COUNT(*) FROM 
relations WHERE protected_until > ?", (now_ts,)) + ttl_count = int(cursor.fetchone()[0] or 0) + return { + "active_count": active_count, + "inactive_count": inactive_count, + "deleted_count": deleted_count, + "pinned_count": pinned_count, + "temp_protected_count": ttl_count, + } + + def get_relations_subject_object_map(self, hashes: List[str]) -> Dict[str, Tuple[str, str]]: + """批量获取关系 hash 对应的 (subject, object)。""" + if not hashes: + return {} + cursor = self._conn.cursor() + placeholders = ",".join(["?"] * len(hashes)) + cursor.execute( + f"SELECT hash, subject, object FROM relations WHERE hash IN ({placeholders})", + hashes, + ) + return {str(row[0]): (str(row[1]), str(row[2])) for row in cursor.fetchall()} + + def get_connection(self) -> sqlite3.Connection: + """公开连接访问(用于离线脚本),替代外部访问私有字段。""" + return self._resolve_conn() + + def get_relation_db_snapshot(self) -> Tuple[int, float, str]: + """返回关系快照:(relation_count, max_created_at, max_hash)。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT + COUNT(*) AS relation_count, + COALESCE(MAX(created_at), 0) AS max_created_at, + COALESCE(MAX(hash), '') AS max_hash + FROM relations + """ + ) + row = cursor.fetchone() + if not row: + return (0, 0.0, "") + return ( + int(row[0] or 0), + float(row[1] or 0.0), + str(row[2] or ""), + ) + + def is_entity_still_referenced(self, entity_hash: str, entity_name: str = "") -> bool: + """ + 判断实体是否仍被引用: + 1) 被 paragraph_entities 引用 + 2) 在 relations.subject/object 中出现 + """ + token_hash = str(entity_hash or "").strip() + if token_hash: + cursor = self._conn.cursor() + cursor.execute( + "SELECT 1 FROM paragraph_entities WHERE entity_hash = ? LIMIT 1", + (token_hash,), + ) + if cursor.fetchone() is not None: + return True + + canon_name = self._canonicalize_name(entity_name) + if canon_name: + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT 1 + FROM relations + WHERE LOWER(TRIM(subject)) = ? OR LOWER(TRIM(object)) = ? + LIMIT 1 + """, + (canon_name, canon_name), + ) + if cursor.fetchone() is not None: + return True + return False + + def search_relations_by_subject_or_object( + self, + query: str, + *, + limit: int = 5, + include_deleted: bool = False, + ) -> List[Dict[str, Any]]: + """按 subject/object 模糊查询关系。""" + q = str(query or "").strip() + if not q: + return [] + max_limit = int(max(1, limit)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT * + FROM relations + WHERE subject LIKE ? OR object LIKE ? + LIMIT ? + """, + (f"%{q}%", f"%{q}%", max_limit), + ) + rows = [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + if rows or not include_deleted: + return rows + + cursor.execute( + """ + SELECT * + FROM deleted_relations + WHERE subject LIKE ? OR object LIKE ? + LIMIT ? 
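+ -- 兜底查询:仅当活跃关系无命中且 include_deleted 为真时才会执行到这里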
+ """, + (f"%{q}%", f"%{q}%", max_limit), + ) + return [self._row_to_dict(row, "relation") for row in cursor.fetchall()] + + def list_hashes(self, table: str) -> List[str]: + """安全枚举指定表的 hash 列。""" + allowed = {"paragraphs", "entities", "relations", "deleted_relations"} + token = str(table or "").strip().lower() + if token not in allowed: + raise ValueError(f"unsupported table for list_hashes: {table}") + cursor = self._conn.cursor() + cursor.execute(f"SELECT hash FROM {token}") + return [str(row[0]) for row in cursor.fetchall()] + + def get_orphan_deleted_relation_hashes(self, limit: int = 200) -> List[str]: + """获取 deleted_relations 中已不在 relations 的孤儿 hash。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT d.hash + FROM deleted_relations d + LEFT JOIN relations r ON r.hash = d.hash + WHERE r.hash IS NULL + LIMIT ? + """, + (int(max(1, limit)),), + ) + return [str(row[0]) for row in cursor.fetchall()] + + def resolve_relation_hash_alias( + self, + value: str, + *, + include_deleted: bool = False, + ) -> List[str]: + """ + 解析关系哈希输入: + - 64位:直接校验存在性 + - 32位:通过 relation_hash_aliases 唯一映射 + """ + token = str(value or "").strip().lower() + if not token: + return [] + if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token): + cursor = self._conn.cursor() + cursor.execute("SELECT 1 FROM relations WHERE hash = ? LIMIT 1", (token,)) + if cursor.fetchone(): + return [token] + if include_deleted: + cursor.execute("SELECT 1 FROM deleted_relations WHERE hash = ? LIMIT 1", (token,)) + if cursor.fetchone(): + return [token] + return [] + + if len(token) != 32 or not all(ch in "0123456789abcdef" for ch in token): + return [] + + cursor = self._conn.cursor() + cursor.execute("SELECT hash FROM relation_hash_aliases WHERE alias32 = ?", (token,)) + row = cursor.fetchone() + if not row: + return [] + resolved = str(row[0]) + return [resolved] + + def rebuild_relation_hash_aliases(self) -> Dict[str, Any]: + """重建 32 位 relation hash 别名映射。""" + cursor = self._conn.cursor() + # 历史库兜底:缺表时先创建,避免迁移过程直接中断。 + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relation_hash_aliases ( + alias32 TEXT PRIMARY KEY, + hash TEXT NOT NULL + ) + """) + cursor.execute("DELETE FROM relation_hash_aliases") + + cursor.execute("SELECT hash FROM relations") + hashes = [str(r[0]) for r in cursor.fetchall()] + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='deleted_relations'" + ) + has_deleted_relations = cursor.fetchone() is not None + if has_deleted_relations: + cursor.execute("SELECT hash FROM deleted_relations") + hashes.extend(str(r[0]) for r in cursor.fetchall()) + + alias_map: Dict[str, str] = {} + conflicts: Dict[str, set[str]] = {} + for h in hashes: + if len(h) != 64: + continue + alias = h[:32] + old = alias_map.get(alias) + if old is None: + alias_map[alias] = h + elif old != h: + conflicts.setdefault(alias, set()).update({old, h}) + + for alias, full_hash in alias_map.items(): + if alias in conflicts: + continue + cursor.execute( + "INSERT INTO relation_hash_aliases(alias32, hash) VALUES (?, ?)", + (alias, full_hash), + ) + self._conn.commit() + return { + "inserted": len(alias_map) - len(conflicts), + "conflict_count": len(conflicts), + "conflicts": sorted(conflicts.keys()), + } + + def search_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: + """按 relation 内容模糊查询 hash。""" + q = str(query or "").strip() + if not q: + return [] + cursor = self._conn.cursor() + cursor.execute( + "SELECT hash FROM relations WHERE subject LIKE ? 
OR object LIKE ? LIMIT ?", + (f"%{q}%", f"%{q}%", int(max(1, limit))), + ) + return [str(row[0]) for row in cursor.fetchall()] + + def search_deleted_relation_hashes_by_text(self, query: str, limit: int = 5) -> List[str]: + """按 deleted_relations 内容模糊查询 hash。""" + q = str(query or "").strip() + if not q: + return [] + cursor = self._conn.cursor() + cursor.execute( + "SELECT hash FROM deleted_relations WHERE subject LIKE ? OR object LIKE ? LIMIT ?", + (f"%{q}%", f"%{q}%", int(max(1, limit))), + ) + return [str(row[0]) for row in cursor.fetchall()] + + def restore_entity_by_hash(self, entity_hash: str) -> bool: + """恢复软删除实体。""" + cursor = self._conn.cursor() + cursor.execute( + "UPDATE entities SET is_deleted=0, deleted_at=NULL WHERE hash=?", + (str(entity_hash),), + ) + changed = cursor.rowcount > 0 + if changed: + self._conn.commit() + return changed + + def backfill_temporal_metadata_from_created_at( + self, + *, + limit: int = 100000, + dry_run: bool = False, + no_created_fallback: bool = False, + ) -> Dict[str, int]: + """回填段落 event_time 字段(created_at 兜底)。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT hash, created_at, source + FROM paragraphs + WHERE (event_time IS NULL AND event_time_start IS NULL AND event_time_end IS NULL) + ORDER BY created_at DESC + LIMIT ? + """, + (int(max(1, limit)),), + ) + rows = cursor.fetchall() + candidates = len(rows) + if dry_run: + return {"candidates": candidates, "updated": 0} + if no_created_fallback: + return {"candidates": candidates, "updated": 0} + + updated = 0 + touched_sources: List[str] = [] + for row in rows: + created_at = row["created_at"] + if created_at is None: + continue + cursor.execute( + """ + UPDATE paragraphs + SET event_time = ?, time_granularity = ?, time_confidence = ?, updated_at = ? + WHERE hash = ? 
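+ -- created_at 兜底回填:时间粒度记为 day,置信度固定为较低的 0.2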
+ """, + (float(created_at), "day", 0.2, float(created_at), row["hash"]), + ) + if cursor.rowcount > 0: + updated += 1 + touched_sources.append(row["source"]) + self._conn.commit() + if updated > 0: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_time_backfill", + ) + return {"candidates": candidates, "updated": updated} + + def get_schema_version(self) -> int: + cursor = self._conn.cursor() + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'" + ) + if cursor.fetchone() is None: + return 0 + cursor.execute("SELECT MAX(version) FROM schema_migrations") + row = cursor.fetchone() + return int(row[0]) if row and row[0] is not None else 0 + + def set_schema_version(self, version: int = SCHEMA_VERSION) -> None: + cursor = self._conn.cursor() + cursor.execute( + "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, applied_at REAL NOT NULL)" + ) + cursor.execute( + "INSERT OR REPLACE INTO schema_migrations(version, applied_at) VALUES (?, ?)", + (int(version), datetime.now().timestamp()), + ) + self._conn.commit() + + def delete_paragraph_atomic(self, paragraph_hash: str) -> Dict[str, Any]: + """ + 两阶段删除段落:DB 事务内计算 + 提交后执行清理 + + Args: + paragraph_hash: 段落哈希 + + Returns: + cleanup_plan: 包含需要后续从 Vector/GraphStore 中移除的 ID 列表 + """ + cleanup_plan = { + "paragraph_hash": paragraph_hash, + "vector_id_to_remove": None, + "edges_to_remove": [], # (src, tgt) 元组列表 (fallback) + "relation_prune_ops": [], # (subject, object, relation_hash) 精准裁剪 + "episode_sources_to_rebuild": [], + } + + cursor = self._conn.cursor() + try: + # === Phase 1: DB Transaction (可回滚) === + # 使用 IMMEDIATE 模式,一旦开启事务立即锁定 DB (防止其他写操作插队导致幻读) + cursor.execute("BEGIN IMMEDIATE") + + # 1. [快照] 获取候选关系 + cursor.execute("SELECT relation_hash FROM paragraph_relations WHERE paragraph_hash = ?", (paragraph_hash,)) + candidate_relations = [row[0] for row in cursor.fetchall()] + + # 2. [快照] 确认该段落存在并记录 ID 用于向量删除 + cursor.execute("SELECT hash, source FROM paragraphs WHERE hash = ?", (paragraph_hash,)) + paragraph_row = cursor.fetchone() + if paragraph_row: + cleanup_plan["vector_id_to_remove"] = paragraph_hash + cleanup_plan["episode_sources_to_rebuild"] = self._dedupe_episode_sources( + [paragraph_row["source"]] + ) + + # 3. [主删除] 删除段落 (触发 CASCADE 删 paragraph_relations) + cursor.execute("DELETE FROM paragraphs WHERE hash = ?", (paragraph_hash,)) + + # 4. [计算孤儿] + orphaned_hashes = [] + for rel_hash in candidate_relations: + count = cursor.execute( + "SELECT count(*) FROM paragraph_relations WHERE relation_hash = ?", + (rel_hash,) + ).fetchone()[0] + + if count == 0: + # 是孤儿:记录边信息以便后续删 Graph + cursor.execute("SELECT subject, object FROM relations WHERE hash = ?", (rel_hash,)) + rel_info = cursor.fetchone() + if rel_info: + s_val, o_val = rel_info[0], rel_info[1] + cleanup_plan["relation_prune_ops"].append((s_val, o_val, rel_hash)) + + # 仅当 (subject, object) 不再有任何关系时,才计划删整条边(兼容旧实现)。 + sibling_count = cursor.execute( + """ + SELECT count(*) FROM relations + WHERE LOWER(TRIM(subject)) = LOWER(TRIM(?)) + AND LOWER(TRIM(object)) = LOWER(TRIM(?)) + AND hash != ? + """, + (s_val, o_val, rel_hash) + ).fetchone()[0] + if sibling_count == 0: + cleanup_plan["edges_to_remove"].append((s_val, o_val)) + + orphaned_hashes.append(rel_hash) + + # 5. 
[DB清理] 删除孤儿关系记录
+ if orphaned_hashes:
+ placeholders = ','.join(['?'] * len(orphaned_hashes))
+ cursor.execute(f"DELETE FROM relations WHERE hash IN ({placeholders})", orphaned_hashes)
+
+ self._conn.commit()
+ if cleanup_plan["episode_sources_to_rebuild"]:
+ self._enqueue_episode_source_rebuilds(
+ cleanup_plan["episode_sources_to_rebuild"],
+ reason="paragraph_deleted",
+ )
+ if cleanup_plan["vector_id_to_remove"]:
+ logger.debug(f"原子删除段落成功: {paragraph_hash}, 计划清理 {len(orphaned_hashes)} 个孤儿关系")
+ return cleanup_plan
+
+ except Exception as e:
+ self._conn.rollback()
+ logger.error(f"DB Transaction failed: {e}")
+ raise e
+
+
+ def clear_all(self) -> None:
+ """清空所有表数据"""
+ cursor = self._conn.cursor()
+ tables = [
+ "paragraphs", "entities", "relations",
+ "paragraph_relations", "paragraph_entities",
+ "episodes", "episode_paragraphs",
+ "episode_rebuild_sources", "episode_pending_paragraphs",
+ ]
+ for table in tables:
+ cursor.execute(f"DELETE FROM {table}")
+ self._conn.commit()
+ logger.info("元数据存储所有表已清空")
+
+
+
+ def update_relation_timestamp(self, hash_value: str, access_count_delta: int = 1) -> None:
+ """更新关系的访问时间和计数"""
+ now = datetime.now().timestamp()
+
+ # 仅更新 last_accessed 与 access_count;last_reinforced 由 reinforce_relations 单独维护
+
+ cursor = self._conn.cursor()
+ cursor.execute("""
+ UPDATE relations
+ SET last_accessed = ?,
+ access_count = access_count + ?
+ WHERE hash = ?
+ """, (now, access_count_delta, hash_value))
+ self._conn.commit()
+
+ # =========================================================================
+ # V5 Memory System Methods
+ # =========================================================================
+
+ def get_relation_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]:
+ """
+ 批量获取关系状态 (V5)
+
+ Args:
+ hashes: 关系哈希列表
+
+ Returns:
+ Dict[hash, status_dict]
+ status_dict 包含: is_inactive, weight(confidence), is_pinned, protected_until, last_reinforced, inactive_since
+ """
+ if not hashes:
+ return {}
+
+ placeholders = ",".join(["?"] * len(hashes))
+ cursor = self._conn.cursor()
+ cursor.execute(f"""
+ SELECT hash, is_inactive, confidence, is_pinned, protected_until, last_reinforced, inactive_since
+ FROM relations
+ WHERE hash IN ({placeholders})
+ """, hashes)
+
+ result = {}
+ for row in cursor.fetchall():
+ result[row["hash"]] = {
+ "is_inactive": bool(row["is_inactive"]),
+ "weight": row["confidence"],
+ "is_pinned": bool(row["is_pinned"]),
+ "protected_until": row["protected_until"],
+ "last_reinforced": row["last_reinforced"],
+ "inactive_since": row["inactive_since"]
+ }
+ return result
+
+ def mark_relations_active(self, hashes: List[str], boost_weight: Optional[float] = None) -> None:
+ """
+ 批量标记关系为活跃 (Active/Revive)
+
+ Args:
+ hashes: 关系哈希列表
+ boost_weight: 如果提供,将设置 confidence = max(confidence, boost_weight)
+ """
+ if not hashes:
+ return
+
+ placeholders = ",".join(["?"] * len(hashes))
+ cursor = self._conn.cursor()
+
+ if boost_weight is not None:
+ cursor.execute(f"""
+ UPDATE relations
+ SET is_inactive = 0,
+ inactive_since = NULL,
+ confidence = MAX(confidence, ?)
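+ -- MAX 确保 boost_weight 只会抬升置信度,不会覆盖更高的已有值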
+ WHERE hash IN ({placeholders}) + """, (boost_weight, *hashes)) + else: + cursor.execute(f""" + UPDATE relations + SET is_inactive = 0, + inactive_since = NULL + WHERE hash IN ({placeholders}) + """, hashes) + + self._conn.commit() + + def update_relations_protection( + self, + hashes: List[str], + protected_until: Optional[float] = None, + is_pinned: Optional[bool] = None, + last_reinforced: Optional[float] = None + ) -> None: + """ + 批量更新关系保护状态 + """ + if not hashes: + return + + updates = [] + params = [] + + if protected_until is not None: + updates.append("protected_until = ?") + params.append(protected_until) + if is_pinned is not None: + updates.append("is_pinned = ?") + params.append(1 if is_pinned else 0) + if last_reinforced is not None: + updates.append("last_reinforced = ?") + params.append(last_reinforced) + + if not updates: + return + + sql_set = ", ".join(updates) + placeholders = ",".join(["?"] * len(hashes)) + + params.extend(hashes) + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE relations + SET {sql_set} + WHERE hash IN ({placeholders}) + """, params) + self._conn.commit() + + def get_prune_candidates(self, cutoff_time: float, limit: int = 1000) -> List[str]: + """ + 获取待修剪候选 (已过冷冻保留期) + + Args: + cutoff_time: 截止时间 (now - 冷冻时长) + limit: 限制数量 + """ + cursor = self._conn.cursor() + cursor.execute(""" + SELECT hash FROM relations + WHERE is_inactive = 1 + AND inactive_since < ? + LIMIT ? + """, (cutoff_time, limit)) + return [row[0] for row in cursor.fetchall()] + + def backup_and_delete_relations(self, hashes: List[str]) -> int: + """ + 备份并删除关系 (Prune) + + Returns: + 删除的数量 + """ + if not hashes: + return 0 + + placeholders = ",".join(["?"] * len(hashes)) + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + try: + # 1. 备份 + cursor.execute(f""" + INSERT OR REPLACE INTO deleted_relations + (hash, subject, predicate, object, vector_index, confidence, created_at, + vector_state, vector_updated_at, vector_error, vector_retry_count, + source_paragraph, metadata, is_permanent, last_accessed, access_count, + is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, deleted_at) + SELECT + hash, subject, predicate, object, vector_index, confidence, created_at, + vector_state, vector_updated_at, vector_error, vector_retry_count, + source_paragraph, metadata, is_permanent, last_accessed, access_count, + is_inactive, inactive_since, is_pinned, protected_until, last_reinforced, ? + FROM relations + WHERE hash IN ({placeholders}) + """, (now, *hashes)) + + # 2. 删除 (级联删除会自动处理 paragraph_relations 关联) + cursor.execute(f""" + DELETE FROM relations + WHERE hash IN ({placeholders}) + """, hashes) + + deleted_count = cursor.rowcount + self._conn.commit() + return deleted_count + + except Exception as e: + logger.error(f"备份删除失败: {e}") + self._conn.rollback() + return 0 + + def restore_relation_metadata(self, hash_value: str) -> Optional[Dict[str, Any]]: + """ + 从回收站恢复关系元数据 + + Returns: + 恢复后的关系数据 (字典),失败返回 None + """ + cursor = self._conn.cursor() + try: + # 1. 查询备份数据 + cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,)) + row = cursor.fetchone() + if not row: + return None + + data = dict(row) + # 移除 deleted_at 字段 + if "deleted_at" in data: + del data["deleted_at"] + + # 2. 
插入回 relations 表 + # 动态构建 SQL 以适应字段变化 + columns = list(data.keys()) + placeholders = ",".join(["?"] * len(columns)) + cols_str = ",".join(columns) + values = list(data.values()) + + cursor.execute(f""" + INSERT OR REPLACE INTO relations ({cols_str}) + VALUES ({placeholders}) + """, values) + + # 3. 从备份表删除 + cursor.execute("DELETE FROM deleted_relations WHERE hash = ?", (hash_value,)) + + self._conn.commit() + return self._row_to_dict(row, "relation") # 使用助手函数将原始行转换为字典 + + except Exception as e: + logger.error(f"恢复关系失败: {hash_value} - {e}") + self._conn.rollback() + return None + + def restore_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: + """兼容旧调用名:恢复关系。""" + return self.restore_relation_metadata(hash_value) + + def get_protected_relations_hashes(self) -> List[str]: + """获取所有受保护关系的哈希 (Pinned 或 Protected Until > Now)""" + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + cursor.execute(""" + SELECT hash FROM relations + WHERE is_pinned = 1 OR protected_until > ? + """, (now,)) + + return [row[0] for row in cursor.fetchall()] + + + + def get_deleted_relations(self, limit: int = 50) -> List[Dict[str, Any]]: + """获取回收站中的关系记录""" + cursor = self._conn.cursor() + cursor.execute("SELECT * FROM deleted_relations ORDER BY deleted_at DESC LIMIT ?", (limit,)) + data = [] + for row in cursor.fetchall(): + d = dict(row) + # 是否需要解码元数据?是的,与普通行相同 + if "metadata" in d and d["metadata"]: + try: + d["metadata"] = pickle.loads(d["metadata"]) + except Exception: + d["metadata"] = {} + data.append(d) + return data + + def get_deleted_relation(self, hash_value: str) -> Optional[Dict[str, Any]]: + """获取单条回收站记录""" + cursor = self._conn.cursor() + cursor.execute("SELECT * FROM deleted_relations WHERE hash = ?", (hash_value,)) + row = cursor.fetchone() + if not row: return None + + d = dict(row) + if "metadata" in d and d["metadata"]: + try: + d["metadata"] = pickle.loads(d["metadata"]) + except Exception: + d["metadata"] = {} + return d + + def reinforce_relations(self, hashes: List[str]) -> None: + """强化关系 (更新 last_reinforced, is_inactive=0)""" + if not hashes: return + now = datetime.now().timestamp() + + cursor = self._conn.cursor() + # Batch update? chunking + chunk_size = 500 + for i in range(0, len(hashes), chunk_size): + chunk = hashes[i:i+chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + sql = f""" + UPDATE relations + SET last_reinforced = ?, is_inactive = 0, inactive_since = NULL + WHERE hash IN ({placeholders}) + """ + cursor.execute(sql, [now] + chunk) + + self._conn.commit() + + def mark_relations_inactive(self, hashes: List[str], inactive_since: Optional[float] = None) -> None: + """标记关系为非活跃 (Freeze)。兼容显式 inactive_since 或默认当前时间。""" + if not hashes: + return + mark_time = inactive_since if inactive_since is not None else datetime.now().timestamp() + + cursor = self._conn.cursor() + chunk_size = 500 + for i in range(0, len(hashes), chunk_size): + chunk = hashes[i:i+chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + sql = f""" + UPDATE relations + SET is_inactive = 1, inactive_since = ? 
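+ -- inactive_since 记录冻结时间,供 get_prune_candidates 判断冷冻保留期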
+ WHERE hash IN ({placeholders}) + """ + cursor.execute(sql, [mark_time] + chunk) + + self._conn.commit() + + def protect_relations( + self, + hashes: List[str], + is_pinned: bool = False, + ttl_seconds: float = 0 + ) -> None: + """ + 设置保护状态 + """ + if not hashes: return + now = datetime.now().timestamp() + protected_until = (now + ttl_seconds) if ttl_seconds > 0 else 0 + + cursor = self._conn.cursor() + chunk_size = 500 + for i in range(0, len(hashes), chunk_size): + chunk = hashes[i:i+chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + + # 由于 is_pinned 和 protected_until 是分开的,如果请求固定(pin),我们会同时更新这两项, + # 但通常用户要么切换固定状态,要么设置 TTL。 + # 如果 is_pinned=True,TTL 通常就不重要了。 + # 但目前的逻辑是正交处理它们的。 + + # 如果用户取消固定 (is_pinned=False),我们是否应该尊重已设置的 TTL? + # 当前的 API 会同时设置这两项。 + + sql = f""" + UPDATE relations + SET is_pinned = ?, protected_until = ? + WHERE hash IN ({placeholders}) + """ + cursor.execute(sql, [is_pinned, protected_until] + chunk) + + self._conn.commit() + + def vacuum(self) -> None: + """优化数据库""" + cursor = self._conn.cursor() + cursor.execute("VACUUM") + self._conn.commit() + logger.info("数据库优化完成") + + def _row_to_dict(self, row: sqlite3.Row, row_type: str) -> Dict[str, Any]: + """ + 将数据库行转换为字典 + + Args: + row: 数据库行 + row_type: 行类型 + + Returns: + 字典 + """ + d = dict(row) + + # 解码pickle字段 + if "metadata" in d and d["metadata"]: + try: + d["metadata"] = pickle.loads(d["metadata"]) + except Exception: + d["metadata"] = {} + + return d + + @property + def is_connected(self) -> bool: + """是否已连接""" + return self._conn is not None + + def __enter__(self): + """上下文管理器入口""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口""" + self.close() + + # ========================================================================= + # V5 Soft Delete & Garbage Collection + # ========================================================================= + + def get_entity_gc_candidates(self, isolated_hashes: List[str], retention_seconds: float) -> List[str]: + """ + 获取实体 GC 候选列表 (Soft Delete Candidates) + 条件: + 1. 在 isolated_hashes 列表中 (由 GraphStore 提供;通常是实体名称) + 2. is_deleted = 0 (未被标记) + 3. created_at < now - retention (过了新手保护期) + 4. 不被任何 active paragraph 引用 (paragraph_entities check) + + Args: + isolated_hashes: 孤儿实体名称列表(兼容传入 hash) + retention_seconds: 保留时间 (秒) + """ + if not isolated_hashes: + return [] + + # GraphStore.get_isolated_nodes 返回节点名,这里做 canonicalize -> entity hash 映射。 + # 同时兼容历史调用直接传 hash。 + normalized_hashes: List[str] = [] + for item in isolated_hashes: + if not item: + continue + v = str(item).strip() + if len(v) == 64 and all(c in "0123456789abcdefABCDEF" for c in v): + normalized_hashes.append(v.lower()) + else: + canon = self._canonicalize_name(v) + if canon: + normalized_hashes.append(compute_hash(canon)) + + normalized_hashes = list(dict.fromkeys(normalized_hashes)) + if not normalized_hashes: + return [] + + now = datetime.now().timestamp() + cutoff = now - retention_seconds + + candidates = [] + batch_size = 900 + + # 分批处理 IN 查询 + for i in range(0, len(normalized_hashes), batch_size): + batch = normalized_hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + # 使用 NOT EXISTS 子查询检查引用 + # 注意: paragraph_entities 中引用的 paragraph 如果被软删了,是否算引用? + # 这里的语义: 只要有 rows 存在于 paragraph_entities 且该 row 对应的 paragraph 没被彻底物理删除,就算引用。 + # 更严格: ... OR (EXISTS ... AND entity_hash=... AND is_deleted=0) + # 但 paragraph_entities 表没有 is_deleted 字段(它是关联表). 我们检查关联是否存在。 + # 如果 paragraph 本身 soft deleted, 它的引用应该失效吗? 
+ # 策略: 只有当 paragraph 也是 active 时,引用才有效。 + # JOIN paragraphs p ON pe.paragraph_hash = p.hash WHERE p.is_deleted = 0 + + query = f""" + SELECT e.hash FROM entities e + WHERE e.hash IN ({placeholders}) + AND e.is_deleted = 0 + AND (e.created_at IS NULL OR e.created_at < ?) + AND NOT EXISTS ( + SELECT 1 FROM paragraph_entities pe + JOIN paragraphs p ON pe.paragraph_hash = p.hash + WHERE pe.entity_hash = e.hash + AND p.is_deleted = 0 + ) + """ + + cursor = self._conn.cursor() + cursor.execute(query, [*batch, cutoff]) + candidates.extend([row[0] for row in cursor.fetchall()]) + + return candidates + + def get_paragraph_gc_candidates(self, retention_seconds: float) -> List[str]: + """ + 获取段落 GC 候选列表 + 条件: + 1. is_deleted = 0 + 2. created_at < cutoff + 3. 没有 Relations (paragraph_relations empty) + 4. 没有 Entities 引用 (paragraph_entities empty) + OR 引用的 Entities 全是软删状态? (太复杂,简单点: 无引用) + + Refined Strategy: + 段落孤儿判定 = + (Left Join paragraph_relations -> NULL) AND + (Left Join paragraph_entities -> NULL) + """ + now = datetime.now().timestamp() + cutoff = now - retention_seconds + + query = """ + SELECT p.hash FROM paragraphs p + LEFT JOIN paragraph_relations pr ON p.hash = pr.paragraph_hash + LEFT JOIN paragraph_entities pe ON p.hash = pe.paragraph_hash + WHERE p.is_deleted = 0 + AND (p.created_at IS NULL OR p.created_at < ?) + AND pr.relation_hash IS NULL + AND pe.entity_hash IS NULL + """ + + cursor = self._conn.cursor() + cursor.execute(query, (cutoff,)) + return [row[0] for row in cursor.fetchall()] + + def mark_as_deleted(self, hashes: List[str], type_: str) -> int: + """ + 标记为软删除 (Mark Phase) + + Args: + hashes: Hash 列表 + type_: 'entity' | 'paragraph' + """ + if not hashes: + return 0 + + table = "entities" if type_ == "entity" else "paragraphs" + now = datetime.now().timestamp() + touched_sources: List[str] = [] + if type_ == "paragraph": + touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True) + + count = 0 + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + # 幂等更新: 只更那些 is_deleted=0 的 + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE {table} + SET is_deleted = 1, deleted_at = ? + WHERE is_deleted = 0 AND hash IN ({placeholders}) + """, [now] + batch) + count += cursor.rowcount + + self._conn.commit() + if type_ == "paragraph" and count > 0: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_soft_deleted", + ) + if count > 0: + logger.info(f"软删除标记 ({table}): {count} 项") + return count + + def sweep_deleted_items(self, type_: str, grace_period_seconds: float) -> List[Tuple[str, str]]: + """ + 扫描可物理清理的项目 (Sweep Phase - Selection) + + Args: + type_: 'entity' | 'paragraph' + grace_period_seconds: 宽限期 + + Returns: + List[(hash, name)]: 待删除项列表 (paragraph name为空) + """ + table = "entities" if type_ == "entity" else "paragraphs" + now = datetime.now().timestamp() + cutoff = now - grace_period_seconds + + cols = "hash, name" if type_ == "entity" else "hash, '' as name" + + cursor = self._conn.cursor() + cursor.execute(f""" + SELECT {cols} FROM {table} + WHERE is_deleted = 1 + AND deleted_at < ? 
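+ -- 仅选出软删时间早于宽限期截止点的条目,仍在宽限期内的可被 revive_if_deleted 复活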
+ """, (cutoff,)) + + return [(row[0], row[1]) for row in cursor.fetchall()] + + def physically_delete_entities(self, hashes: List[str]) -> int: + """物理删除实体 (批量)""" + if not hashes: return 0 + + count = 0 + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f"DELETE FROM entities WHERE hash IN ({placeholders})", batch) + count += cursor.rowcount + + self._conn.commit() + return count + + def physically_delete_paragraphs(self, hashes: List[str]) -> int: + """物理删除段落 (批量)""" + if not hashes: return 0 + touched_sources = self._get_sources_for_paragraph_hashes(hashes, include_deleted=True) + + count = 0 + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f"DELETE FROM paragraphs WHERE hash IN ({placeholders})", batch) + count += cursor.rowcount + + self._conn.commit() + if count > 0: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_physically_deleted", + ) + return count + + def revive_if_deleted(self, entity_hashes: List[str] = None, paragraph_hashes: List[str] = None) -> int: + """ + 复活已软删的项目 (Auto Revival) + 当数据被再次访问、引用或导入时调用。 + """ + count = 0 + + if entity_hashes: + batch_size = 900 + for i in range(0, len(entity_hashes), batch_size): + batch = entity_hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE entities + SET is_deleted = 0, deleted_at = NULL + WHERE is_deleted = 1 AND hash IN ({placeholders}) + """, batch) + count += cursor.rowcount + + if paragraph_hashes: + touched_sources = self._get_sources_for_paragraph_hashes(paragraph_hashes, include_deleted=True) + batch_size = 900 + for i in range(0, len(paragraph_hashes), batch_size): + batch = paragraph_hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f""" + UPDATE paragraphs + SET is_deleted = 0, deleted_at = NULL + WHERE is_deleted = 1 AND hash IN ({placeholders}) + """, batch) + count += cursor.rowcount + else: + touched_sources = [] + + if count > 0: + self._conn.commit() + if touched_sources: + self._enqueue_episode_source_rebuilds( + touched_sources, + reason="paragraph_revived", + ) + logger.info(f"自动复活: {count} 项 (Soft Delete Revived)") + + return count + + def revive_entities_by_names(self, names: List[str]) -> int: + """ + 根据名称复活实体 (Convenience wrapper) + """ + if not names: return 0 + + # 使用内部方法计算哈希 + hashes = [compute_hash(self._canonicalize_name(n)) for n in names] + return self.revive_if_deleted(entity_hashes=hashes) + + def get_entity_status_batch(self, hashes: List[str]) -> Dict[str, Dict[str, Any]]: + """批量获取实体状态 (WebUI用)""" + if not hashes: return {} + + result = {} + batch_size = 900 + for i in range(0, len(hashes), batch_size): + batch = hashes[i:i+batch_size] + placeholders = ",".join(["?"] * len(batch)) + + cursor = self._conn.cursor() + cursor.execute(f""" + SELECT hash, is_deleted, deleted_at + FROM entities + WHERE hash IN ({placeholders}) + """, batch) + + for row in cursor.fetchall(): + result[row[0]] = { + "is_deleted": bool(row[1]), + "deleted_at": row[2] + } + return result + + # ========================================================================= + # Person Profile (问题3) - Switches / Active Set / Snapshots + # 
========================================================================= + + def set_person_profile_switch( + self, + stream_id: str, + user_id: str, + enabled: bool, + updated_at: Optional[float] = None, + ) -> None: + """设置人物画像自动注入开关(按 stream_id + user_id)。""" + if not stream_id or not user_id: + raise ValueError("stream_id 和 user_id 不能为空") + + ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO person_profile_switches (stream_id, user_id, enabled, updated_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(stream_id, user_id) DO UPDATE SET + enabled = excluded.enabled, + updated_at = excluded.updated_at + """, + (str(stream_id), str(user_id), 1 if enabled else 0, ts), + ) + self._conn.commit() + + def get_person_profile_switch(self, stream_id: str, user_id: str, default: bool = False) -> bool: + """读取人物画像自动注入开关。""" + if not stream_id or not user_id: + return bool(default) + + cursor = self._conn.cursor() + cursor.execute( + "SELECT enabled FROM person_profile_switches WHERE stream_id = ? AND user_id = ?", + (str(stream_id), str(user_id)), + ) + row = cursor.fetchone() + if not row: + return bool(default) + return bool(row[0]) + + def get_enabled_person_profile_switches(self, limit: int = 1000) -> List[Dict[str, Any]]: + """获取已开启人物画像注入开关的会话范围。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT stream_id, user_id, enabled, updated_at + FROM person_profile_switches + WHERE enabled = 1 + ORDER BY updated_at DESC + LIMIT ? + """, + (int(max(1, limit)),), + ) + return [ + { + "stream_id": row[0], + "user_id": row[1], + "enabled": bool(row[2]), + "updated_at": row[3], + } + for row in cursor.fetchall() + ] + + def mark_person_profile_active( + self, + stream_id: str, + user_id: str, + person_id: str, + seen_at: Optional[float] = None, + ) -> None: + """记录活跃人物(用于定时按需刷新)。""" + if not stream_id or not user_id or not person_id: + return + ts = float(seen_at) if seen_at is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO person_profile_active_persons (stream_id, user_id, person_id, last_seen_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(stream_id, user_id, person_id) DO UPDATE SET + last_seen_at = excluded.last_seen_at + """, + (str(stream_id), str(user_id), str(person_id), ts), + ) + self._conn.commit() + + def get_active_person_ids_for_enabled_switches( + self, + active_after: Optional[float] = None, + limit: int = 200, + ) -> List[str]: + """获取“已开启开关范围内”的活跃人物集合。""" + cursor = self._conn.cursor() + sql = """ + SELECT a.person_id, MAX(a.last_seen_at) AS last_seen + FROM person_profile_active_persons a + JOIN person_profile_switches s + ON a.stream_id = s.stream_id AND a.user_id = s.user_id + WHERE s.enabled = 1 + """ + params: List[Any] = [] + if active_after is not None: + sql += " AND a.last_seen_at >= ?" + params.append(float(active_after)) + sql += """ + GROUP BY a.person_id + ORDER BY last_seen DESC + LIMIT ? 
+ """ + params.append(int(max(1, limit))) + cursor.execute(sql, tuple(params)) + return [str(row[0]) for row in cursor.fetchall() if row and row[0]] + + def get_latest_person_profile_snapshot(self, person_id: str) -> Optional[Dict[str, Any]]: + """获取人物最新画像快照。""" + if not person_id: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT + snapshot_id, person_id, profile_version, profile_text, + aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, + updated_at, expires_at, source_note + FROM person_profile_snapshots + WHERE person_id = ? + ORDER BY profile_version DESC + LIMIT 1 + """, + (str(person_id),), + ) + row = cursor.fetchone() + if not row: + return None + + def _load_list(raw: Any) -> List[Any]: + if not raw: + return [] + try: + data = json.loads(raw) + return data if isinstance(data, list) else [] + except Exception: + return [] + + return { + "snapshot_id": row[0], + "person_id": row[1], + "profile_version": int(row[2]), + "profile_text": row[3] or "", + "aliases": _load_list(row[4]), + "relation_edges": _load_list(row[5]), + "vector_evidence": _load_list(row[6]), + "evidence_ids": _load_list(row[7]), + "updated_at": row[8], + "expires_at": row[9], + "source_note": row[10] or "", + } + + def upsert_person_profile_snapshot( + self, + person_id: str, + profile_text: str, + aliases: Optional[List[str]] = None, + relation_edges: Optional[List[Dict[str, Any]]] = None, + vector_evidence: Optional[List[Dict[str, Any]]] = None, + evidence_ids: Optional[List[str]] = None, + expires_at: Optional[float] = None, + source_note: str = "", + updated_at: Optional[float] = None, + ) -> Dict[str, Any]: + """写入人物画像快照(按 person_id 自动递增版本)。""" + if not person_id: + raise ValueError("person_id 不能为空") + + aliases = aliases or [] + relation_edges = relation_edges or [] + vector_evidence = vector_evidence or [] + evidence_ids = evidence_ids or [] + ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() + + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT profile_version + FROM person_profile_snapshots + WHERE person_id = ? + ORDER BY profile_version DESC + LIMIT 1 + """, + (str(person_id),), + ) + row = cursor.fetchone() + next_version = int(row[0]) + 1 if row else 1 + + cursor.execute( + """ + INSERT INTO person_profile_snapshots ( + person_id, profile_version, profile_text, + aliases_json, relation_edges_json, vector_evidence_json, evidence_ids_json, + updated_at, expires_at, source_note + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
+ """, + ( + str(person_id), + next_version, + str(profile_text or ""), + json.dumps(aliases, ensure_ascii=False), + json.dumps(relation_edges, ensure_ascii=False), + json.dumps(vector_evidence, ensure_ascii=False), + json.dumps(evidence_ids, ensure_ascii=False), + ts, + float(expires_at) if expires_at is not None else None, + str(source_note or ""), + ), + ) + self._conn.commit() + latest = self.get_latest_person_profile_snapshot(person_id) + return latest or { + "person_id": person_id, + "profile_version": next_version, + "profile_text": str(profile_text or ""), + "aliases": aliases, + "relation_edges": relation_edges, + "vector_evidence": vector_evidence, + "evidence_ids": evidence_ids, + "updated_at": ts, + "expires_at": expires_at, + "source_note": source_note, + } + + def get_person_profile_override(self, person_id: str) -> Optional[Dict[str, Any]]: + """获取人物画像手工覆盖内容。""" + if not person_id: + return None + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT person_id, override_text, updated_at, updated_by, source + FROM person_profile_overrides + WHERE person_id = ? + LIMIT 1 + """, + (str(person_id),), + ) + row = cursor.fetchone() + if not row: + return None + return { + "person_id": str(row[0]), + "override_text": str(row[1] or ""), + "updated_at": row[2], + "updated_by": str(row[3] or ""), + "source": str(row[4] or ""), + } + + def set_person_profile_override( + self, + person_id: str, + override_text: str, + updated_by: str = "", + source: str = "webui", + updated_at: Optional[float] = None, + ) -> Dict[str, Any]: + """写入人物画像手工覆盖;空文本等价于清除覆盖。""" + if not person_id: + raise ValueError("person_id 不能为空") + + text = str(override_text or "").strip() + if not text: + self.delete_person_profile_override(person_id) + return { + "person_id": str(person_id), + "override_text": "", + "updated_at": None, + "updated_by": str(updated_by or ""), + "source": str(source or ""), + } + + ts = float(updated_at) if updated_at is not None else datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO person_profile_overrides ( + person_id, override_text, updated_at, updated_by, source + ) VALUES (?, ?, ?, ?, ?) 
+ ON CONFLICT(person_id) DO UPDATE SET + override_text = excluded.override_text, + updated_at = excluded.updated_at, + updated_by = excluded.updated_by, + source = excluded.source + """, + ( + str(person_id), + text, + ts, + str(updated_by or ""), + str(source or ""), + ), + ) + self._conn.commit() + return self.get_person_profile_override(person_id) or { + "person_id": str(person_id), + "override_text": text, + "updated_at": ts, + "updated_by": str(updated_by or ""), + "source": str(source or ""), + } + + def delete_person_profile_override(self, person_id: str) -> bool: + """删除人物画像手工覆盖。""" + if not person_id: + return False + cursor = self._conn.cursor() + cursor.execute( + "DELETE FROM person_profile_overrides WHERE person_id = ?", + (str(person_id),), + ) + self._conn.commit() + return cursor.rowcount > 0 + + # ========================================================================= + # Episode MVP + # ========================================================================= + + def enqueue_episode_source_rebuild(self, source: str, reason: str = "") -> bool: + """将 source 入队到 episode 重建队列。""" + return bool(self._enqueue_episode_source_rebuilds([source], reason=reason)) + + def fetch_episode_source_rebuild_batch( + self, + limit: int = 20, + max_retry: int = 3, + ) -> List[Dict[str, Any]]: + """获取待处理的 source 重建任务。""" + safe_limit = max(1, int(limit)) + safe_retry = max(0, int(max_retry)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT source, status, retry_count, last_error, reason, requested_at, updated_at + FROM episode_rebuild_sources + WHERE status = 'pending' + OR (status = 'failed' AND retry_count < ?) + ORDER BY requested_at ASC, updated_at ASC + LIMIT ? + """, + (safe_retry, safe_limit), + ) + return [dict(row) for row in cursor.fetchall()] + + def mark_episode_source_running( + self, + source: str, + *, + requested_at: Optional[float] = None, + ) -> bool: + """将 source 标记为 running。""" + token = self._normalize_episode_source(source) + if not token: + return False + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + params: List[Any] = [now, token] + sql = """ + UPDATE episode_rebuild_sources + SET status = 'running', + updated_at = ? + WHERE source = ? + AND status IN ('pending', 'failed') + """ + if requested_at is not None: + sql += " AND requested_at = ?" + params.append(float(requested_at)) + cursor.execute(sql, tuple(params)) + self._conn.commit() + return cursor.rowcount > 0 + + def mark_episode_source_done( + self, + source: str, + *, + requested_at: Optional[float] = None, + ) -> bool: + """将 source 标记为 done;若运行期间发生新写入,则保持 pending。""" + token = self._normalize_episode_source(source) + if not token: + return False + + now = datetime.now().timestamp() + cursor = self._conn.cursor() + if requested_at is None: + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = 'done', + last_error = NULL, + updated_at = ? + WHERE source = ? + """, + (now, token), + ) + else: + req_ts = float(requested_at) + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = CASE + WHEN requested_at > ? THEN 'pending' + ELSE 'done' + END, + last_error = NULL, + updated_at = ? + WHERE source = ? 
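+ -- 若行内 requested_at 晚于本次处理的请求时间,说明处理期间又有新写入,状态保持 pending 等待下一轮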
+ """, + (req_ts, now, token), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def mark_episode_source_failed( + self, + source: str, + error: str = "", + *, + requested_at: Optional[float] = None, + ) -> bool: + """标记 source 失败;若运行期间发生新写入,则重新回到 pending。""" + token = self._normalize_episode_source(source) + if not token: + return False + + err_text = str(error or "").strip()[:500] + now = datetime.now().timestamp() + cursor = self._conn.cursor() + if requested_at is None: + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = 'failed', + retry_count = COALESCE(retry_count, 0) + 1, + last_error = ?, + updated_at = ? + WHERE source = ? + """, + (err_text, now, token), + ) + else: + req_ts = float(requested_at) + cursor.execute( + """ + UPDATE episode_rebuild_sources + SET status = CASE + WHEN requested_at > ? THEN 'pending' + ELSE 'failed' + END, + retry_count = CASE + WHEN requested_at > ? THEN COALESCE(retry_count, 0) + ELSE COALESCE(retry_count, 0) + 1 + END, + last_error = CASE + WHEN requested_at > ? THEN NULL + ELSE ? + END, + updated_at = ? + WHERE source = ? + """, + (req_ts, req_ts, req_ts, err_text, now, token), + ) + self._conn.commit() + return cursor.rowcount > 0 + + def list_episode_source_rebuilds( + self, + *, + statuses: Optional[List[str]] = None, + limit: int = 100, + ) -> List[Dict[str, Any]]: + """列出 source 重建状态。""" + safe_limit = max(1, int(limit)) + params: List[Any] = [] + conditions: List[str] = [] + normalized_statuses = [ + str(item or "").strip().lower() + for item in (statuses or []) + if str(item or "").strip().lower() in {"pending", "running", "done", "failed"} + ] + if normalized_statuses: + placeholders = ",".join(["?"] * len(normalized_statuses)) + conditions.append(f"status IN ({placeholders})") + params.extend(normalized_statuses) + + where_sql = f"WHERE {' AND '.join(conditions)}" if conditions else "" + params.append(safe_limit) + cursor = self._conn.cursor() + cursor.execute( + f""" + SELECT source, status, retry_count, last_error, reason, requested_at, updated_at + FROM episode_rebuild_sources + {where_sql} + ORDER BY updated_at DESC, source ASC + LIMIT ? + """, + tuple(params), + ) + return [dict(row) for row in cursor.fetchall()] + + def get_episode_source_rebuild_summary(self, failed_limit: int = 20) -> Dict[str, Any]: + """汇总 source 重建队列状态。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT status, COUNT(*) AS cnt + FROM episode_rebuild_sources + GROUP BY status + """ + ) + counts = {"pending": 0, "running": 0, "done": 0, "failed": 0, "total": 0} + for row in cursor.fetchall(): + status = str(row["status"] or "").strip().lower() + cnt = int(row["cnt"] or 0) + counts[status] = counts.get(status, 0) + cnt + counts["total"] += cnt + + running = self.list_episode_source_rebuilds(statuses=["running"], limit=20) + failed = self.list_episode_source_rebuilds( + statuses=["failed"], + limit=max(1, int(failed_limit)), + ) + return { + "counts": counts, + "running": running, + "failed": failed, + } + + def get_live_paragraphs_by_source(self, source: str) -> List[Dict[str, Any]]: + """获取指定 source 下所有 live paragraphs。""" + token = self._normalize_episode_source(source) + if not token: + return [] + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT * + FROM paragraphs + WHERE TRIM(COALESCE(source, '')) = ? 
+ AND (is_deleted IS NULL OR is_deleted = 0) + ORDER BY created_at ASC, hash ASC + """, + (token,), + ) + return [self._row_to_dict(row, "paragraph") for row in cursor.fetchall()] + + def list_episode_sources_for_rebuild(self) -> List[str]: + """列出全量重建涉及的 source(live paragraphs + stale episodes)。""" + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT DISTINCT source + FROM ( + SELECT TRIM(source) AS source + FROM paragraphs + WHERE TRIM(COALESCE(source, '')) != '' + AND (is_deleted IS NULL OR is_deleted = 0) + UNION + SELECT TRIM(source) AS source + FROM episodes + WHERE TRIM(COALESCE(source, '')) != '' + ) + WHERE TRIM(COALESCE(source, '')) != '' + ORDER BY source ASC + """ + ) + return self._dedupe_episode_sources([row["source"] for row in cursor.fetchall()]) + + def is_episode_source_query_blocked(self, source: str) -> bool: + """判断 source 是否处于重建中或失败状态。""" + token = self._normalize_episode_source(source) + if not token: + return False + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT 1 + FROM episode_rebuild_sources + WHERE source = ? + AND status IN ('pending', 'running', 'failed') + LIMIT 1 + """, + (token,), + ) + return cursor.fetchone() is not None + + def replace_episodes_for_source( + self, + source: str, + episodes_payloads: List[Dict[str, Any]], + ) -> Dict[str, Any]: + """按 source 全量替换 episode 结果。""" + token = self._normalize_episode_source(source) + if not token: + return {"source": "", "episode_count": 0} + + payloads = [dict(item) for item in (episodes_payloads or []) if isinstance(item, dict)] + now = datetime.now().timestamp() + cursor = self._conn.cursor() + + try: + cursor.execute("BEGIN IMMEDIATE") + cursor.execute( + """ + SELECT episode_id, created_at + FROM episodes + WHERE TRIM(COALESCE(source, '')) = ? 
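+                -- Remember each episode's original created_at so regenerated episodes keep their first-seen time.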
+ """, + (token,), + ) + existing_created_at = { + str(row["episode_id"]): self._as_optional_float(row["created_at"]) + for row in cursor.fetchall() + } + + cursor.execute( + "DELETE FROM episodes WHERE TRIM(COALESCE(source, '')) = ?", + (token,), + ) + + inserted_count = 0 + for raw_payload in payloads: + title = str(raw_payload.get("title", "") or "").strip() + summary = str(raw_payload.get("summary", "") or "").strip() + evidence_ids = [ + str(item).strip() + for item in (raw_payload.get("evidence_ids") or []) + if str(item).strip() + ] + evidence_ids = list(dict.fromkeys(evidence_ids)) + if not title or not summary or not evidence_ids: + continue + + episode_id = str(raw_payload.get("episode_id", "") or "").strip() + if not episode_id: + seed = json.dumps( + { + "source": token, + "title": title, + "summary": summary, + "event_time_start": raw_payload.get("event_time_start"), + "event_time_end": raw_payload.get("event_time_end"), + "evidence_ids": evidence_ids, + }, + ensure_ascii=False, + sort_keys=True, + ) + episode_id = compute_hash(seed) + + participants = [ + str(item).strip() + for item in (raw_payload.get("participants") or []) + if str(item).strip() + ][:16] + keywords = [ + str(item).strip() + for item in (raw_payload.get("keywords") or []) + if str(item).strip() + ][:20] + paragraph_count = raw_payload.get("paragraph_count", len(evidence_ids)) + try: + paragraph_count = max(0, int(paragraph_count)) + except Exception: + paragraph_count = len(evidence_ids) + if paragraph_count <= 0: + paragraph_count = len(evidence_ids) + if paragraph_count <= 0: + continue + + time_confidence = raw_payload.get("time_confidence", 1.0) + llm_confidence = raw_payload.get("llm_confidence", 0.0) + try: + time_confidence = float(time_confidence) + except Exception: + time_confidence = 1.0 + try: + llm_confidence = float(llm_confidence) + except Exception: + llm_confidence = 0.0 + + created_at = existing_created_at.get(episode_id) + created_ts = created_at if created_at is not None else now + updated_ts = self._as_optional_float(raw_payload.get("updated_at")) or now + + cursor.execute( + """ + INSERT INTO episodes ( + episode_id, source, title, summary, + event_time_start, event_time_end, time_granularity, time_confidence, + participants_json, keywords_json, evidence_ids_json, + paragraph_count, llm_confidence, segmentation_model, segmentation_version, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + episode_id, + token, + title[:120], + summary[:2000], + self._as_optional_float(raw_payload.get("event_time_start")), + self._as_optional_float(raw_payload.get("event_time_end")), + str(raw_payload.get("time_granularity", "") or "").strip() or None, + time_confidence, + json.dumps(participants, ensure_ascii=False), + json.dumps(keywords, ensure_ascii=False), + json.dumps(evidence_ids, ensure_ascii=False), + paragraph_count, + llm_confidence, + str(raw_payload.get("segmentation_model", "") or "").strip() or None, + str(raw_payload.get("segmentation_version", "") or "").strip() or None, + created_ts, + updated_ts, + ), + ) + cursor.executemany( + """ + INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) + VALUES (?, ?, ?) 
+ """, + [(episode_id, hash_value, idx) for idx, hash_value in enumerate(evidence_ids)], + ) + inserted_count += 1 + + self._conn.commit() + return {"source": token, "episode_count": inserted_count} + except Exception: + self._conn.rollback() + raise + + def enqueue_episode_pending( + self, + paragraph_hash: str, + source: Optional[str] = None, + created_at: Optional[float] = None, + ) -> None: + """将段落入队到 episode 异步生成队列。""" + token = str(paragraph_hash or "").strip() + if not token: + return + now = datetime.now().timestamp() + created_ts = float(created_at) if created_at is not None else now + src = str(source or "").strip() or None + + cursor = self._conn.cursor() + cursor.execute( + """ + INSERT INTO episode_pending_paragraphs ( + paragraph_hash, source, created_at, status, retry_count, last_error, updated_at + ) VALUES (?, ?, ?, 'pending', 0, NULL, ?) + ON CONFLICT(paragraph_hash) DO UPDATE SET + source = excluded.source, + created_at = COALESCE(episode_pending_paragraphs.created_at, excluded.created_at), + status = CASE + WHEN episode_pending_paragraphs.status = 'done' THEN 'done' + ELSE 'pending' + END, + last_error = CASE + WHEN episode_pending_paragraphs.status = 'done' THEN episode_pending_paragraphs.last_error + ELSE NULL + END, + updated_at = excluded.updated_at + """, + (token, src, created_ts, now), + ) + self._conn.commit() + + def fetch_episode_pending_batch(self, limit: int = 20, max_retry: int = 3) -> List[Dict[str, Any]]: + """获取待处理 episode 队列批次。""" + safe_limit = max(1, int(limit)) + safe_retry = max(0, int(max_retry)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT paragraph_hash, source, created_at, status, retry_count, last_error, updated_at + FROM episode_pending_paragraphs + WHERE status = 'pending' + OR (status = 'failed' AND retry_count < ?) + ORDER BY updated_at ASC + LIMIT ? + """, + (safe_retry, safe_limit), + ) + return [dict(row) for row in cursor.fetchall()] + + def mark_episode_pending_running(self, hashes: List[str]) -> None: + """批量标记队列项为 running。""" + if not hashes: + return + now = datetime.now().timestamp() + cursor = self._conn.cursor() + chunk_size = 500 + uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) + for i in range(0, len(uniq), chunk_size): + chunk = uniq[i:i + chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + cursor.execute( + f""" + UPDATE episode_pending_paragraphs + SET status = 'running', updated_at = ? + WHERE paragraph_hash IN ({placeholders}) + AND status IN ('pending', 'failed') + """, + [now] + chunk, + ) + self._conn.commit() + + def mark_episode_pending_done(self, hashes: List[str]) -> None: + """批量标记队列项为 done。""" + if not hashes: + return + now = datetime.now().timestamp() + cursor = self._conn.cursor() + chunk_size = 500 + uniq = list(dict.fromkeys([str(h).strip() for h in hashes if str(h).strip()])) + for i in range(0, len(uniq), chunk_size): + chunk = uniq[i:i + chunk_size] + placeholders = ",".join(["?"] * len(chunk)) + cursor.execute( + f""" + UPDATE episode_pending_paragraphs + SET status = 'done', + last_error = NULL, + updated_at = ? 
+ WHERE paragraph_hash IN ({placeholders}) + """, + [now] + chunk, + ) + self._conn.commit() + + def mark_episode_pending_failed(self, hash_value: str, error: str = "") -> None: + """标记单条队列项失败并累加重试次数。""" + token = str(hash_value or "").strip() + if not token: + return + now = datetime.now().timestamp() + cursor = self._conn.cursor() + cursor.execute( + """ + UPDATE episode_pending_paragraphs + SET status = 'failed', + retry_count = COALESCE(retry_count, 0) + 1, + last_error = ?, + updated_at = ? + WHERE paragraph_hash = ? + """, + (str(error or ""), now, token), + ) + self._conn.commit() + + def _episode_row_to_dict(self, row: sqlite3.Row) -> Dict[str, Any]: + data = dict(row) + + def _load_list(raw: Any) -> List[Any]: + if not raw: + return [] + try: + val = json.loads(raw) + return val if isinstance(val, list) else [] + except Exception: + return [] + + data["participants"] = _load_list(data.pop("participants_json", None)) + data["keywords"] = _load_list(data.pop("keywords_json", None)) + data["evidence_ids"] = _load_list(data.pop("evidence_ids_json", None)) + return data + + @staticmethod + def _as_optional_float(value: Any) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except Exception: + return None + + def upsert_episode(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """写入或更新 episode。""" + if not isinstance(payload, dict): + raise ValueError("payload 必须是字典") + + title = str(payload.get("title", "") or "").strip() + summary = str(payload.get("summary", "") or "").strip() + if not title: + raise ValueError("episode.title 不能为空") + if not summary: + raise ValueError("episode.summary 不能为空") + + source = str(payload.get("source", "") or "").strip() or None + participants_raw = payload.get("participants", []) or [] + keywords_raw = payload.get("keywords", []) or [] + evidence_ids_raw = payload.get("evidence_ids", []) or [] + participants = [str(x).strip() for x in participants_raw if str(x).strip()] + keywords = [str(x).strip() for x in keywords_raw if str(x).strip()] + evidence_ids = [str(x).strip() for x in evidence_ids_raw if str(x).strip()] + + now = datetime.now().timestamp() + created_at = self._as_optional_float(payload.get("created_at")) + updated_at = self._as_optional_float(payload.get("updated_at")) + created_ts = created_at if created_at is not None else now + updated_ts = updated_at if updated_at is not None else now + + episode_id = str(payload.get("episode_id", "") or "").strip() + if not episode_id: + seed = json.dumps( + { + "source": source, + "title": title, + "summary": summary, + "event_time_start": payload.get("event_time_start"), + "event_time_end": payload.get("event_time_end"), + "evidence_ids": evidence_ids, + }, + ensure_ascii=False, + sort_keys=True, + ) + episode_id = compute_hash(seed) + + paragraph_count = payload.get("paragraph_count") + if paragraph_count is None: + paragraph_count = len(evidence_ids) + try: + paragraph_count = int(paragraph_count) + except Exception: + paragraph_count = len(evidence_ids) + + time_conf = payload.get("time_confidence", 1.0) + llm_conf = payload.get("llm_confidence", 0.0) + try: + time_conf = float(time_conf) + except Exception: + time_conf = 1.0 + try: + llm_conf = float(llm_conf) + except Exception: + llm_conf = 0.0 + + cursor = self._conn.cursor() + cursor.execute( + "SELECT created_at FROM episodes WHERE episode_id = ? 
LIMIT 1", + (episode_id,), + ) + existed = cursor.fetchone() + if existed and existed[0] is not None: + created_ts = float(existed[0]) + + cursor.execute( + """ + INSERT INTO episodes ( + episode_id, source, title, summary, + event_time_start, event_time_end, time_granularity, time_confidence, + participants_json, keywords_json, evidence_ids_json, + paragraph_count, llm_confidence, segmentation_model, segmentation_version, + created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(episode_id) DO UPDATE SET + source = excluded.source, + title = excluded.title, + summary = excluded.summary, + event_time_start = excluded.event_time_start, + event_time_end = excluded.event_time_end, + time_granularity = excluded.time_granularity, + time_confidence = excluded.time_confidence, + participants_json = excluded.participants_json, + keywords_json = excluded.keywords_json, + evidence_ids_json = excluded.evidence_ids_json, + paragraph_count = excluded.paragraph_count, + llm_confidence = excluded.llm_confidence, + segmentation_model = excluded.segmentation_model, + segmentation_version = excluded.segmentation_version, + updated_at = excluded.updated_at + """, + ( + episode_id, + source, + title, + summary, + self._as_optional_float(payload.get("event_time_start")), + self._as_optional_float(payload.get("event_time_end")), + str(payload.get("time_granularity", "") or "").strip() or None, + time_conf, + json.dumps(participants, ensure_ascii=False), + json.dumps(keywords, ensure_ascii=False), + json.dumps(evidence_ids, ensure_ascii=False), + max(0, paragraph_count), + llm_conf, + str(payload.get("segmentation_model", "") or "").strip() or None, + str(payload.get("segmentation_version", "") or "").strip() or None, + created_ts, + updated_ts, + ), + ) + self._conn.commit() + return self.get_episode_by_id(episode_id) or {"episode_id": episode_id} + + def bind_episode_paragraphs(self, episode_id: str, paragraph_hashes_ordered: List[str]) -> int: + """重建 episode 与段落映射。""" + token = str(episode_id or "").strip() + if not token: + raise ValueError("episode_id 不能为空") + + normalized: List[str] = [] + seen = set() + for item in paragraph_hashes_ordered or []: + h = str(item or "").strip() + if not h or h in seen: + continue + seen.add(h) + normalized.append(h) + + cursor = self._conn.cursor() + cursor.execute("DELETE FROM episode_paragraphs WHERE episode_id = ?", (token,)) + + if normalized: + cursor.executemany( + """ + INSERT OR IGNORE INTO episode_paragraphs (episode_id, paragraph_hash, position) + VALUES (?, ?, ?) + """, + [(token, h, idx) for idx, h in enumerate(normalized)], + ) + + now = datetime.now().timestamp() + cursor.execute( + """ + UPDATE episodes + SET paragraph_count = ?, updated_at = ? + WHERE episode_id = ? 
+ """, + (len(normalized), now, token), + ) + self._conn.commit() + return len(normalized) + + def _build_episode_query_components( + self, + *, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + ) -> Tuple[str, str, str, List[str], List[Any]]: + source_expr = "TRIM(COALESCE(e.source, ''))" + effective_start = "COALESCE(e.event_time_start, e.event_time_end, e.updated_at)" + effective_end = "COALESCE(e.event_time_end, e.event_time_start, e.updated_at)" + conditions: List[str] = [] + params: List[Any] = [] + + conditions.append(f"{source_expr} != ''") + conditions.append("COALESCE(e.paragraph_count, 0) > 0") + conditions.append( + """ + NOT EXISTS ( + SELECT 1 + FROM episode_rebuild_sources ers + WHERE ers.source = TRIM(COALESCE(e.source, '')) + AND ers.status IN ('pending', 'running', 'failed') + ) + """ + ) + + if source: + token = self._normalize_episode_source(source) + if not token: + return source_expr, effective_start, effective_end, ["1 = 0"], [] + conditions.append(f"{source_expr} = ?") + params.append(token) + + p = str(person or "").strip().lower() + if p: + like_person = f"%{p}%" + conditions.append( + """ + ( + LOWER(COALESCE(e.participants_json, '')) LIKE ? + OR EXISTS ( + SELECT 1 + FROM episode_paragraphs ep_person + JOIN paragraph_entities pe ON pe.paragraph_hash = ep_person.paragraph_hash + JOIN entities en ON en.hash = pe.entity_hash + WHERE ep_person.episode_id = e.episode_id + AND LOWER(en.name) LIKE ? + ) + ) + """ + ) + params.extend([like_person, like_person]) + + if time_from is not None and time_to is not None: + conditions.append(f"({effective_end} >= ? AND {effective_start} <= ?)") + params.extend([float(time_from), float(time_to)]) + elif time_from is not None: + conditions.append(f"({effective_end} >= ?)") + params.append(float(time_from)) + elif time_to is not None: + conditions.append(f"({effective_start} <= ?)") + params.append(float(time_to)) + + return source_expr, effective_start, effective_end, conditions, params + + def get_episode_rows_by_paragraph_hashes( + self, + paragraph_hashes: List[str], + *, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + ) -> List[Dict[str, Any]]: + normalized: List[str] = [] + seen = set() + for item in paragraph_hashes or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + if not normalized: + return [] + + _, _, _, conditions, params = self._build_episode_query_components( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + placeholders = ",".join(["?"] * len(normalized)) + conditions.append(f"ep.paragraph_hash IN ({placeholders})") + conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") + where_sql = "WHERE " + " AND ".join(conditions) + + sql = f""" + SELECT e.*, ep.paragraph_hash AS matched_paragraph_hash + FROM episodes e + JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id + JOIN paragraphs p ON p.hash = ep.paragraph_hash + {where_sql} + ORDER BY e.updated_at DESC + """ + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params + normalized)) + + grouped: Dict[str, Dict[str, Any]] = {} + for row in cursor.fetchall(): + episode_id = str(row["episode_id"] or "").strip() + if not episode_id: + continue + payload = grouped.get(episode_id) + if payload is None: + payload = self._episode_row_to_dict(row) + 
payload["matched_paragraph_hashes"] = [] + grouped[episode_id] = payload + matched_hash = str(row["matched_paragraph_hash"] or "").strip() + if matched_hash and matched_hash not in payload["matched_paragraph_hashes"]: + payload["matched_paragraph_hashes"].append(matched_hash) + + out = list(grouped.values()) + for item in out: + item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) + return out + + def get_episode_rows_by_relation_hashes( + self, + relation_hashes: List[str], + *, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + ) -> List[Dict[str, Any]]: + normalized: List[str] = [] + seen = set() + for item in relation_hashes or []: + token = str(item or "").strip() + if not token or token in seen: + continue + seen.add(token) + normalized.append(token) + if not normalized: + return [] + + _, _, _, conditions, params = self._build_episode_query_components( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + placeholders = ",".join(["?"] * len(normalized)) + conditions.append(f"pr.relation_hash IN ({placeholders})") + conditions.append("(p.is_deleted IS NULL OR p.is_deleted = 0)") + where_sql = "WHERE " + " AND ".join(conditions) + + sql = f""" + SELECT + e.*, + p.hash AS matched_paragraph_hash, + pr.relation_hash AS matched_relation_hash + FROM episodes e + JOIN episode_paragraphs ep ON ep.episode_id = e.episode_id + JOIN paragraphs p ON p.hash = ep.paragraph_hash + JOIN paragraph_relations pr ON pr.paragraph_hash = p.hash + {where_sql} + ORDER BY e.updated_at DESC + """ + cursor = self._conn.cursor() + cursor.execute(sql, tuple(params + normalized)) + + grouped: Dict[str, Dict[str, Any]] = {} + for row in cursor.fetchall(): + episode_id = str(row["episode_id"] or "").strip() + if not episode_id: + continue + payload = grouped.get(episode_id) + if payload is None: + payload = self._episode_row_to_dict(row) + payload["matched_paragraph_hashes"] = [] + payload["matched_relation_hashes"] = [] + grouped[episode_id] = payload + matched_paragraph = str(row["matched_paragraph_hash"] or "").strip() + matched_relation = str(row["matched_relation_hash"] or "").strip() + if matched_paragraph and matched_paragraph not in payload["matched_paragraph_hashes"]: + payload["matched_paragraph_hashes"].append(matched_paragraph) + if matched_relation and matched_relation not in payload["matched_relation_hashes"]: + payload["matched_relation_hashes"].append(matched_relation) + + out = list(grouped.values()) + for item in out: + item["matched_paragraph_count"] = len(item.get("matched_paragraph_hashes", [])) + item["matched_relation_count"] = len(item.get("matched_relation_hashes", [])) + return out + + def query_episodes( + self, + query: str = "", + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + limit: int = 20, + ) -> List[Dict[str, Any]]: + """查询 episode 列表。""" + safe_limit = max(1, int(limit)) + _, effective_start, effective_end, conditions, params = self._build_episode_query_components( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + + q = str(query or "").strip().lower() + select_score_sql = "0.0 AS lexical_score" + order_sql = f"{effective_end} DESC, e.updated_at DESC" + select_params: List[Any] = [] + query_params: List[Any] = [] + if q: + like = f"%{q}%" + title_expr = "LOWER(COALESCE(e.title, '')) LIKE ?" 
+ summary_expr = "LOWER(COALESCE(e.summary, '')) LIKE ?" + keywords_expr = "LOWER(COALESCE(e.keywords_json, '')) LIKE ?" + participants_expr = "LOWER(COALESCE(e.participants_json, '')) LIKE ?" + conditions.append( + f"({title_expr} OR {summary_expr} OR {keywords_expr} OR {participants_expr})" + ) + select_score_sql = ( + f"(CASE WHEN {title_expr} THEN 4.0 ELSE 0.0 END + " + f"CASE WHEN {keywords_expr} THEN 3.0 ELSE 0.0 END + " + f"CASE WHEN {summary_expr} THEN 2.0 ELSE 0.0 END + " + f"CASE WHEN {participants_expr} THEN 1.0 ELSE 0.0 END) AS lexical_score" + ) + select_params.extend([like, like, like, like]) + query_params.extend([like, like, like, like]) + order_sql = f"lexical_score DESC, {effective_end} DESC, e.updated_at DESC" + + where_sql = ("WHERE " + " AND ".join(conditions)) if conditions else "" + sql = f""" + SELECT e.*, {select_score_sql} + FROM episodes e + {where_sql} + ORDER BY {order_sql} + LIMIT ? + """ + final_params = list(select_params) + list(params) + list(query_params) + [safe_limit] + + cursor = self._conn.cursor() + cursor.execute(sql, tuple(final_params)) + return [self._episode_row_to_dict(row) for row in cursor.fetchall()] + + def get_episode_by_id(self, episode_id: str) -> Optional[Dict[str, Any]]: + """获取单条 episode。""" + token = str(episode_id or "").strip() + if not token: + return None + cursor = self._conn.cursor() + cursor.execute( + "SELECT * FROM episodes WHERE episode_id = ? LIMIT 1", + (token,), + ) + row = cursor.fetchone() + if not row: + return None + return self._episode_row_to_dict(row) + + def get_episode_paragraphs(self, episode_id: str, limit: int = 100) -> List[Dict[str, Any]]: + """获取 episode 关联段落(按 position 排序)。""" + token = str(episode_id or "").strip() + if not token: + return [] + safe_limit = max(1, int(limit)) + cursor = self._conn.cursor() + cursor.execute( + """ + SELECT p.*, ep.position + FROM episode_paragraphs ep + JOIN paragraphs p ON p.hash = ep.paragraph_hash + WHERE ep.episode_id = ? + AND (p.is_deleted IS NULL OR p.is_deleted = 0) + ORDER BY ep.position ASC + LIMIT ? + """, + (token, safe_limit), + ) + items = [] + for row in cursor.fetchall(): + payload = self._row_to_dict(row, "paragraph") + payload["position"] = row["position"] + items.append(payload) + return items + + def has_table(self, table_name: str) -> bool: + """检查数据库是否存在指定表。""" + if not self._conn: + return False + cursor = self._conn.cursor() + cursor.execute( + "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? LIMIT 1", + (table_name,), + ) + return cursor.fetchone() is not None + + def get_deleted_entities(self, limit: int = 50) -> List[Dict[str, Any]]: + """获取已软删除的实体 (回收站用)""" + if not self.has_table("entities"): return [] + + cursor = self._conn.cursor() + cursor.execute(""" + SELECT hash, name, deleted_at + FROM entities + WHERE is_deleted = 1 + ORDER BY deleted_at DESC + LIMIT ? 
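+            -- Recycle-bin view: most recently soft-deleted entities first.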
+ """, (limit,)) + + items = [] + for row in cursor.fetchall(): + items.append({ + "hash": row[0], + "name": row[1], + "type": "entity", # 标记为实体 + "deleted_at": row[2] + }) + return items + + def __repr__(self) -> str: + stats = self.get_statistics() if self.is_connected else {} + return ( + f"MetadataStore(paragraphs={stats.get('paragraph_count', 0)}, " + f"entities={stats.get('entity_count', 0)}, " + f"relations={stats.get('relation_count', 0)})" + ) + + def has_data(self) -> bool: + """检查磁盘上是否存在现有数据""" + if self.data_dir is None: + return False + return (self.data_dir / self.db_name).exists() diff --git a/plugins/A_memorix/core/storage/type_detection.py b/plugins/A_memorix/core/storage/type_detection.py new file mode 100644 index 00000000..c20d2cb4 --- /dev/null +++ b/plugins/A_memorix/core/storage/type_detection.py @@ -0,0 +1,137 @@ +"""Heuristic detection for import strategies and stored knowledge types.""" + +from __future__ import annotations + +import re +from typing import Optional + +from .knowledge_types import ( + ImportStrategy, + KnowledgeType, + parse_import_strategy, + resolve_stored_knowledge_type, +) + + +_NARRATIVE_MARKERS = [ + r"然后", + r"接着", + r"于是", + r"后来", + r"最后", + r"突然", + r"一天", + r"曾经", + r"有一次", + r"从前", + r"说道", + r"问道", + r"想着", + r"觉得", +] +_FACTUAL_MARKERS = [ + r"是", + r"有", + r"在", + r"为", + r"属于", + r"位于", + r"包含", + r"拥有", + r"成立于", + r"出生于", +] + + +def _non_empty_lines(content: str) -> list[str]: + return [line for line in str(content or "").splitlines() if line.strip()] + + +def looks_like_structured_text(content: str) -> bool: + text = str(content or "").strip() + if "|" not in text or text.count("|") < 2: + return False + parts = text.split("|") + return len(parts) == 3 and all(part.strip() for part in parts) + + +def looks_like_quote_text(content: str) -> bool: + lines = _non_empty_lines(content) + if len(lines) < 5: + return False + avg_len = sum(len(line) for line in lines) / len(lines) + return avg_len < 20 + + +def looks_like_narrative_text(content: str) -> bool: + text = str(content or "").strip() + if not text: + return False + + narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text)) + has_dialogue = bool(re.search(r'["「『].*?["」』]', text)) + has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###")) + return has_chapter or has_dialogue or narrative_score >= 2 + + +def looks_like_factual_text(content: str) -> bool: + text = str(content or "").strip() + if not text: + return False + if looks_like_structured_text(text) or looks_like_quote_text(text): + return False + + factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text)) + if factual_score <= 0: + return False + + if len(text) <= 240: + return True + return factual_score >= 2 and not looks_like_narrative_text(text) + + +def select_import_strategy( + content: str, + *, + override: Optional[str | ImportStrategy] = None, + chat_log: bool = False, +) -> ImportStrategy: + """文本导入策略选择:override > quote > factual > narrative。""" + + if chat_log: + return ImportStrategy.NARRATIVE + + strategy = parse_import_strategy(override, default=ImportStrategy.AUTO) + if strategy != ImportStrategy.AUTO: + return strategy + + if looks_like_quote_text(content): + return ImportStrategy.QUOTE + if looks_like_factual_text(content): + return ImportStrategy.FACTUAL + return ImportStrategy.NARRATIVE + + +def detect_knowledge_type(content: str) -> KnowledgeType: + """自动检测落库 knowledge_type;无法可靠判断时回退 mixed。""" + + text = 
str(content or "").strip() + if not text: + return KnowledgeType.MIXED + if looks_like_structured_text(text): + return KnowledgeType.STRUCTURED + if looks_like_quote_text(text): + return KnowledgeType.QUOTE + if looks_like_factual_text(text): + return KnowledgeType.FACTUAL + if looks_like_narrative_text(text): + return KnowledgeType.NARRATIVE + return KnowledgeType.MIXED + + +def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType: + """优先使用显式 type_hint,否则自动检测。""" + + if type_hint: + return resolve_stored_knowledge_type(type_hint, content=content) + return detect_knowledge_type(content) diff --git a/plugins/A_memorix/core/storage/vector_store.py b/plugins/A_memorix/core/storage/vector_store.py new file mode 100644 index 00000000..97a9144c --- /dev/null +++ b/plugins/A_memorix/core/storage/vector_store.py @@ -0,0 +1,776 @@ +""" +向量存储模块 + +基于Faiss的高效向量存储与检索,支持SQ8量化、Append-Only磁盘存储和内存映射。 +""" + +import os +import pickle +import hashlib +import shutil +import time +from pathlib import Path +from typing import Optional, Union, Tuple, List, Dict, Set, Any +import random +import threading # Added threading import + +import numpy as np + +try: + import faiss + HAS_FAISS = True +except ImportError: + HAS_FAISS = False + +from src.common.logger import get_logger +from ..utils.quantization import QuantizationType +from ..utils.io import atomic_write, atomic_save_path + +logger = get_logger("A_Memorix.VectorStore") + + +class VectorStore: + """ + 向量存储类 (SQ8 + Append-Only Disk) + + 特性: + - 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit)) + - 存储: float16 on-disk binary (vectors.bin) + - 内存: 仅索引常驻 RAM (<512MB for 100k vectors) + - ID: SHA1-based stable int64 IDs + - 一致性: 强制 L2 Normalization (IP == Cosine) + """ + + # 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化) + DEFAULT_MIN_TRAIN = 40 + # 强制训练样本量 + TRAIN_SIZE = 10000 + # 储水池采样上限 (流式处理前 50k 数据) + RESERVOIR_CAPACITY = 10000 + RESERVOIR_SAMPLE_SCOPE = 50000 + + def __init__( + self, + dimension: int, + quantization_type: QuantizationType = QuantizationType.INT8, + index_type: str = "sq8", + data_dir: Optional[Union[str, Path]] = None, + use_mmap: bool = True, + buffer_size: int = 1024, + ): + if not HAS_FAISS: + raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu") + + self.dimension = dimension + self.data_dir = Path(data_dir) if data_dir else None + if self.data_dir: + self.data_dir.mkdir(parents=True, exist_ok=True) + if quantization_type != QuantizationType.INT8: + raise ValueError( + "vNext 仅支持 quantization_type=int8(SQ8)。" + " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" + ) + normalized_index_type = str(index_type or "sq8").strip().lower() + if normalized_index_type not in {"sq8", "int8"}: + raise ValueError( + "vNext 仅支持 index_type=sq8。" + " 请更新配置并执行 scripts/release_vnext_migrate.py migrate。" + ) + self.quantization_type = QuantizationType.INT8 + self.index_type = "sq8" + self.buffer_size = buffer_size + + self._index: Optional[faiss.IndexIDMap2] = None + self._init_index() + + self._is_trained = False + self._vector_norm = "l2" + + # Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力 + # 必须使用 IndexIDMap2 以保证 ID 与主索引一致 + self._fallback_index: Optional[faiss.IndexIDMap2] = None + self._init_fallback_index() + + self._known_hashes: Set[str] = set() + self._deleted_ids: Set[int] = set() + + self._reservoir_buffer: List[np.ndarray] = [] + self._seen_count_for_reservoir = 0 + + self._write_buffer_vecs: List[np.ndarray] = [] + self._write_buffer_ids: List[int] = [] + + self._total_added = 0 + self._total_deleted = 0 + 
self._bin_count = 0 + + # Thread safety lock + self._lock = threading.RLock() + + logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage") + + def _init_index(self): + """初始化空的 Faiss 索引""" + quantizer = faiss.IndexScalarQuantizer( + self.dimension, + faiss.ScalarQuantizer.QT_8bit, + faiss.METRIC_INNER_PRODUCT + ) + self._index = faiss.IndexIDMap2(quantizer) + self._is_trained = False + + def _init_fallback_index(self): + """初始化 Flat 回退索引""" + flat_index = faiss.IndexFlatIP(self.dimension) + self._fallback_index = faiss.IndexIDMap2(flat_index) + logger.debug("Fallback index (Flat) initialized.") + + @staticmethod + def _generate_id(key: str) -> int: + """生成稳定的 int64 ID (SHA1 截断)""" + h = hashlib.sha1(key.encode("utf-8")).digest() + val = int.from_bytes(h[:8], byteorder="big", signed=False) + return val & 0x7FFFFFFFFFFFFFFF + + @property + def _bin_path(self) -> Path: + return self.data_dir / "vectors.bin" + + @property + def _ids_bin_path(self) -> Path: + return self.data_dir / "vectors_ids.bin" + + @property + def _int_to_str_map(self) -> Dict[int, str]: + """Lazy build volatile map from known hashes""" + # Note: This is read-heavy and cached, might need lock if _known_hashes updates concurrently + # But add/delete are now locked, so checking len mismatch is somewhat safe-ish for quick dirty cache + if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes): + with self._lock: # Protect cache rebuild + self._cached_map = {self._generate_id(k): k for k in self._known_hashes} + return self._cached_map + + def add(self, vectors: np.ndarray, ids: List[str]) -> int: + with self._lock: + if vectors.shape[1] != self.dimension: + raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}") + + vectors = np.ascontiguousarray(vectors, dtype=np.float32) + faiss.normalize_L2(vectors) + + processed_vecs = [] + processed_int_ids = [] + + for i, str_id in enumerate(ids): + if str_id in self._known_hashes: + continue + + int_id = self._generate_id(str_id) + self._known_hashes.add(str_id) + + processed_vecs.append(vectors[i]) + processed_int_ids.append(int_id) + + if not processed_vecs: + return 0 + + batch_vecs = np.array(processed_vecs, dtype=np.float32) + batch_ids = np.array(processed_int_ids, dtype=np.int64) + + self._write_buffer_vecs.append(batch_vecs) + self._write_buffer_ids.extend(processed_int_ids) + + if len(self._write_buffer_ids) >= self.buffer_size: + self._flush_write_buffer_unlocked() + + if not self._is_trained: + # 双写到回退索引 + self._fallback_index.add_with_ids(batch_vecs, batch_ids) + + self._update_reservoir(batch_vecs) + # 这里的 TRAIN_SIZE 取默认 10k,或者根据当前数据量动态判断 + if len(self._reservoir_buffer) >= 10000: + logger.info(f"训练样本达到上限,开始训练...") + self._train_and_replay_unlocked() + + self._total_added += len(batch_ids) + return len(batch_ids) + + def _flush_write_buffer(self): + with self._lock: + self._flush_write_buffer_unlocked() + + def _flush_write_buffer_unlocked(self): + if not self._write_buffer_vecs: + return + + batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0) + batch_ids = np.array(self._write_buffer_ids, dtype=np.int64) + + vecs_fp16 = batch_vecs.astype(np.float16) + + with open(self._bin_path, "ab") as f: + f.write(vecs_fp16.tobytes()) + + ids_bytes = batch_ids.astype('>i8').tobytes() + with open(self._ids_bin_path, "ab") as f: + f.write(ids_bytes) + + self._bin_count += len(batch_ids) + + if self._is_trained and self._index.is_trained: + self._index.add_with_ids(batch_vecs, batch_ids) + else: + # 即使在 
flush 时,如果未训练,也要同步到 fallback + self._fallback_index.add_with_ids(batch_vecs, batch_ids) + + self._write_buffer_vecs.clear() + self._write_buffer_ids.clear() + + def _update_reservoir(self, vectors: np.ndarray): + for vec in vectors: + self._seen_count_for_reservoir += 1 + if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY: + self._reservoir_buffer.append(vec) + else: + if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE: + r = random.randint(0, self._seen_count_for_reservoir - 1) + if r < self.RESERVOIR_CAPACITY: + self._reservoir_buffer[r] = vec + + def _train_and_replay(self): + with self._lock: + self._train_and_replay_unlocked() + + def _train_and_replay_unlocked(self): + if not self._reservoir_buffer: + logger.warning("No training data available.") + return + + train_data = np.array(self._reservoir_buffer, dtype=np.float32) + logger.info(f"Training Index with {len(train_data)} samples...") + + try: + self._index.train(train_data) + except Exception as e: + logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.") + return + + self._is_trained = True + self._reservoir_buffer = [] + + logger.info("Replaying data from disk to populate index...") + try: + replay_count = self._replay_vectors_to_index() + # 只有当 replay 成功且数据量一致时,才释放回退索引 + if self._index.ntotal >= self._bin_count: + logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.") + self._fallback_index.reset() + else: + logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.") + except Exception as e: + logger.error(f"Replay failed: {e}. Keeping fallback index as backup.") + + def _replay_vectors_to_index(self) -> int: + """从 vectors.bin 读取并添加到 index""" + if not self._bin_path.exists() or not self._ids_bin_path.exists(): + return 0 + + vec_item_size = self.dimension * 2 + id_item_size = 8 + chunk_size = 10000 + + with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: + while True: + vec_data = f_vec.read(chunk_size * vec_item_size) + id_data = f_id.read(chunk_size * id_item_size) + + if not vec_data: + break + + batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) + batch_fp32 = batch_fp16.astype(np.float32) + faiss.normalize_L2(batch_fp32) + + batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) + + valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] + if not all(valid_mask): + batch_fp32 = batch_fp32[valid_mask] + batch_ids = batch_ids[valid_mask] + + if len(batch_ids) > 0: + self._index.add_with_ids(batch_fp32, batch_ids) + + def search( + self, + query: np.ndarray, + k: int = 10, + filter_deleted: bool = True, + ) -> Tuple[List[str], List[float]]: + query_local = np.array(query, dtype=np.float32, order="C", copy=True) + if query_local.ndim == 1: + got_dim = int(query_local.shape[0]) + query_local = query_local.reshape(1, -1) + elif query_local.ndim == 2: + if query_local.shape[0] != 1: + raise ValueError( + f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" + ) + got_dim = int(query_local.shape[1]) + else: + raise ValueError( + f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}" + ) + + if got_dim != self.dimension: + raise ValueError( + f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}" + ) + if not np.all(np.isfinite(query_local)): + raise ValueError("query embedding contains non-finite values") + + faiss.normalize_L2(query_local) + + # 
查询路径仅负责检索,不在此触发训练/回放。 + # 训练/回放前置到 warmup_index(),并由插件启动阶段触发。 + # Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。 + with self._lock: + self._flush_write_buffer_unlocked() + search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index + if search_index.ntotal == 0: + logger.warning("Indices are empty. No data to search.") + return [], [] + # 执行检索 + dists, ids = search_index.search(query_local, k * 2) + + # Faiss search 返回的是 (1, K) 的数组,取第一行 + dists = dists[0] + ids = ids[0] + + results = [] + for id_val, score in zip(ids, dists): + if id_val == -1: continue + if filter_deleted and id_val in self._deleted_ids: + continue + + str_id = self._int_to_str_map.get(id_val) + if str_id: + results.append((str_id, float(score))) + + # Sort and trim just in case filtering reduced count + results.sort(key=lambda x: x[1], reverse=True) + results = results[:k] + + if not results: + return [], [] + + return [r[0] for r in results], [r[1] for r in results] + + def warmup_index(self, force_train: bool = True) -> Dict[str, Any]: + """ + 预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。 + + Args: + force_train: 是否在满足阈值时强制训练 SQ8 索引 + + Returns: + 预热状态摘要 + """ + started = time.perf_counter() + logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}") + + try: + with self._lock: + self._flush_write_buffer() + + if self._bin_path.exists(): + self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2) + else: + self._bin_count = 0 + + needs_fallback_bootstrap = ( + self._bin_count > 0 + and self._fallback_index.ntotal == 0 + and (not self._is_trained or self._index.ntotal == 0) + ) + if needs_fallback_bootstrap: + self._bootstrap_fallback_from_disk() + + min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN))) + needs_train = ( + bool(force_train) + and self._bin_count >= min_train + and not self._is_trained + ) + if needs_train: + self._force_train_small_data() + + duration_ms = (time.perf_counter() - started) * 1000.0 + summary = { + "ok": True, + "trained": bool(self._is_trained), + "index_ntotal": int(self._index.ntotal), + "fallback_ntotal": int(self._fallback_index.ntotal), + "bin_count": int(self._bin_count), + "duration_ms": duration_ms, + "error": None, + } + except Exception as e: + duration_ms = (time.perf_counter() - started) * 1000.0 + summary = { + "ok": False, + "trained": bool(self._is_trained), + "index_ntotal": int(self._index.ntotal) if self._index is not None else 0, + "fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0, + "bin_count": int(getattr(self, "_bin_count", 0)), + "duration_ms": duration_ms, + "error": str(e), + } + logger.error( + "metric.vector_index_prewarm_fail=1 " + f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} " + f"error={e}" + ) + return summary + + logger.info( + "metric.vector_index_prewarm_success=1 " + f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} " + f"trained={summary['trained']} " + f"index_ntotal={summary['index_ntotal']} " + f"fallback_ntotal={summary['fallback_ntotal']} " + f"bin_count={summary['bin_count']}" + ) + return summary + + def _bootstrap_fallback_from_disk(self): + with self._lock: + self._bootstrap_fallback_from_disk_unlocked() + + def _bootstrap_fallback_from_disk_unlocked(self): + """重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引""" + if not self._bin_path.exists() or not self._ids_bin_path.exists(): + return + + logger.info("Replaying all disk vectors to fallback index...") + vec_item_size = 
self.dimension * 2 + id_item_size = 8 + chunk_size = 10000 + + with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id: + while True: + vec_data = f_vec.read(chunk_size * vec_item_size) + id_data = f_id.read(chunk_size * id_item_size) + if not vec_data: break + + batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) + batch_fp32 = batch_fp16.astype(np.float32) + faiss.normalize_L2(batch_fp32) + batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) + + valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids] + if any(valid_mask): + self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask]) + + logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.") + + def _force_train_small_data(self): + with self._lock: + self._force_train_small_data_unlocked() + + def _force_train_small_data_unlocked(self): + logger.info("Forcing training on small dataset...") + self._reservoir_buffer = [] + + chunk_size = 10000 + vec_item_size = self.dimension * 2 + + with open(self._bin_path, "rb") as f: + while len(self._reservoir_buffer) < self.TRAIN_SIZE: + data = f.read(chunk_size * vec_item_size) + if not data: break + fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension) + fp32 = fp16.astype(np.float32) + faiss.normalize_L2(fp32) + + for vec in fp32: + self._reservoir_buffer.append(vec) + if len(self._reservoir_buffer) >= self.TRAIN_SIZE: + break + + self._train_and_replay_unlocked() + + def delete(self, ids: List[str]) -> int: + with self._lock: + count = 0 + for str_id in ids: + if str_id not in self._known_hashes: + continue + int_id = self._generate_id(str_id) + if int_id not in self._deleted_ids: + self._deleted_ids.add(int_id) + if self._index.is_trained: + self._index.remove_ids(np.array([int_id], dtype=np.int64)) + # 同步从 fallback 移除 + if self._fallback_index.ntotal > 0: + self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64)) + count += 1 + self._total_deleted += count + + # Check GC + self._check_rebuild_needed() + return count + + def _check_rebuild_needed(self): + """GC Excution Check""" + if self._bin_count == 0: return + ratio = len(self._deleted_ids) / self._bin_count + if ratio > 0.3 and len(self._deleted_ids) > 1000: + logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})") + self.rebuild_index() + + def rebuild_index(self): + """GC: 重建索引,压缩 bin 文件""" + with self._lock: + self._rebuild_index_locked() + + def _rebuild_index_locked(self): + """实际 GC 重建逻辑。""" + logger.info("Starting Compaction (GC)...") + + tmp_bin = self.data_dir / "vectors.bin.tmp" + tmp_ids = self.data_dir / "vectors_ids.bin.tmp" + + vec_item_size = self.dimension * 2 + id_item_size = 8 + chunk_size = 10000 + + new_count = 0 + + # 1. Compact Files + with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \ + open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id: + while True: + vec_data = f_vec.read(chunk_size * vec_item_size) + id_data = f_id.read(chunk_size * id_item_size) + if not vec_data: break + + batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension) + batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64) + + keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids] + + if any(keep_mask): + keep_vecs = batch_fp16[keep_mask] + keep_ids = batch_ids[keep_mask] + + w_vec.write(keep_vecs.tobytes()) + w_id.write(keep_ids.astype('>i8').tobytes()) + new_count += len(keep_ids) + + # 2. 
Reset State & Atomic Swap + self._bin_count = new_count + + # Close current index + self._index.reset() + if self._fallback_index: self._fallback_index.reset() # Also clear fallback + self._is_trained = False + + # Swap files + shutil.move(str(tmp_bin), str(self._bin_path)) + shutil.move(str(tmp_ids), str(self._ids_bin_path)) + + # Reset Tombstones (Critical) + self._deleted_ids.clear() + + # 3. Reload/Rebuild Index (Fresh Train) + # We need to re-train because data distribution might have changed significantly after deletion + self._init_index() + self._init_fallback_index() # Re-init fallback too + self._force_train_small_data() # This will train and replay from the NEW compact file + + logger.info("Compaction Complete.") + + def save(self, data_dir: Optional[Union[str, Path]] = None) -> None: + with self._lock: + if not data_dir: + data_dir = self.data_dir + if not data_dir: + raise ValueError("No data_dir") + + data_dir = Path(data_dir) + data_dir.mkdir(parents=True, exist_ok=True) + + self._flush_write_buffer_unlocked() + + if self._is_trained: + index_path = data_dir / "vectors.index" + with atomic_save_path(index_path) as tmp: + faiss.write_index(self._index, tmp) + + meta = { + "dimension": self.dimension, + "quantization_type": self.quantization_type.value, + "is_trained": self._is_trained, + "vector_norm": self._vector_norm, + "deleted_ids": list(self._deleted_ids), + "known_hashes": list(self._known_hashes), + } + + with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f: + pickle.dump(meta, f) + + logger.info("VectorStore saved.") + + def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]: + """ + 离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。 + """ + with self._lock: + target_dir = Path(data_dir) if data_dir else self.data_dir + if target_dir is None: + raise ValueError("No data_dir") + target_dir = Path(target_dir) + npy_path = target_dir / "vectors.npy" + idx_path = target_dir / "vectors.index" + bin_path = target_dir / "vectors.bin" + ids_bin_path = target_dir / "vectors_ids.bin" + meta_path = target_dir / "vectors_metadata.pkl" + + if not npy_path.exists(): + return {"migrated": False, "reason": "npy_missing"} + if not meta_path.exists(): + raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl") + if bin_path.exists() and ids_bin_path.exists(): + return {"migrated": False, "reason": "bin_exists"} + + # Reset in-memory state to avoid appending to stale runtime buffers. 
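+            # Legacy vectors are replayed through add() below, so buffers, indexes and dedup state must start empty.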
+ self._known_hashes.clear() + self._deleted_ids.clear() + self._write_buffer_vecs.clear() + self._write_buffer_ids.clear() + self._init_index() + self._init_fallback_index() + self._is_trained = False + self._bin_count = 0 + + self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir) + self.save(target_dir) + return {"migrated": True, "reason": "ok"} + + def load(self, data_dir: Optional[Union[str, Path]] = None) -> None: + with self._lock: + if not data_dir: data_dir = self.data_dir + data_dir = Path(data_dir) + + npy_path = data_dir / "vectors.npy" + idx_path = data_dir / "vectors.index" + bin_path = data_dir / "vectors.bin" + + if npy_path.exists() and not bin_path.exists(): + raise RuntimeError( + "检测到 legacy vectors.npy,vNext 不再支持运行时自动迁移。" + " 请先执行 scripts/release_vnext_migrate.py migrate。" + ) + + meta_path = data_dir / "vectors_metadata.pkl" + if not meta_path.exists(): + logger.warning("No metadata found, initialized empty.") + return + + with open(meta_path, "rb") as f: + meta = pickle.load(f) + + if meta.get("vector_norm") != "l2": + logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...") + self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", [])) + self._deleted_ids = set(meta.get("deleted_ids", [])) + self._init_index() + self._force_train_small_data() + return + + self._is_trained = meta.get("is_trained", False) + self._vector_norm = meta.get("vector_norm", "l2") + self._deleted_ids = set(meta.get("deleted_ids", [])) + self._known_hashes = set(meta.get("known_hashes", [])) + + if self._is_trained: + if idx_path.exists(): + try: + self._index = faiss.read_index(str(idx_path)) + if not isinstance(self._index, faiss.IndexIDMap2): + logger.warning("Loaded index type mismatch. Rebuilding...") + self._init_index() + self._force_train_small_data() + except Exception as e: + logger.error(f"Failed to load index: {e}. Rebuilding...") + self._init_index() + self._force_train_small_data() + else: + logger.warning("Index file missing despite metadata indicating trained. 
Rebuilding from bin...") + self._init_index() + self._force_train_small_data() + + if bin_path.exists(): + self._bin_count = bin_path.stat().st_size // (self.dimension * 2) + + def _migrate_from_npy(self, npy_path, idx_path, data_dir): + with self._lock: + self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir) + + def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir): + try: + arr = np.load(npy_path, mmap_mode="r") + except Exception: + arr = np.load(npy_path) + + meta_path = data_dir / "vectors_metadata.pkl" + old_ids = [] + if meta_path.exists(): + with open(meta_path, "rb") as f: + m = pickle.load(f) + old_ids = m.get("ids", []) + + if len(arr) != len(old_ids): + logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}") + return + + logger.info(f"Migrating {len(arr)} vectors...") + + chunk = 1000 + for i in range(0, len(arr), chunk): + sub_arr = arr[i : i+chunk] + sub_ids = old_ids[i : i+chunk] + self.add(sub_arr, sub_ids) + + if not self._is_trained: + self._force_train_small_data() + + shutil.move(str(npy_path), str(npy_path) + ".bak") + if idx_path.exists(): + shutil.move(str(idx_path), str(idx_path) + ".bak") + + logger.info("Migration complete.") + + def clear(self) -> None: + with self._lock: + self._ids_bin_path.unlink(missing_ok=True) + self._bin_path.unlink(missing_ok=True) + self._init_index() + self._known_hashes.clear() + self._deleted_ids.clear() + self._bin_count = 0 + logger.info("VectorStore cleared.") + + def has_data(self) -> bool: + return (self.data_dir / "vectors_metadata.pkl").exists() + + @property + def num_vectors(self) -> int: + return len(self._known_hashes) - len(self._deleted_ids) + + def __contains__(self, hash_value: str) -> bool: + """Check if a hash exists in the store""" + return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids + diff --git a/plugins/A_memorix/core/strategies/__init__.py b/plugins/A_memorix/core/strategies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/A_memorix/core/strategies/base.py b/plugins/A_memorix/core/strategies/base.py new file mode 100644 index 00000000..ff250cdf --- /dev/null +++ b/plugins/A_memorix/core/strategies/base.py @@ -0,0 +1,89 @@ +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union +from dataclasses import dataclass, field +from enum import Enum +import hashlib + +class KnowledgeType(str, Enum): + NARRATIVE = "narrative" + FACTUAL = "factual" + QUOTE = "quote" + MIXED = "mixed" + +@dataclass +class SourceInfo: + file: str + offset_start: int + offset_end: int + checksum: str = "" + +@dataclass +class ChunkContext: + chunk_id: str + index: int + context: Dict[str, Any] = field(default_factory=dict) + text: str = "" + +@dataclass +class ChunkFlags: + verbatim: bool = False + requires_llm: bool = True + +@dataclass +class ProcessedChunk: + type: KnowledgeType + source: SourceInfo + chunk: ChunkContext + data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities + flags: ChunkFlags = field(default_factory=ChunkFlags) + + def to_dict(self) -> Dict: + return { + "type": self.type.value, + "source": { + "file": self.source.file, + "offset_start": self.source.offset_start, + "offset_end": self.source.offset_end, + "checksum": self.source.checksum + }, + "chunk": { + "text": self.chunk.text, + "chunk_id": self.chunk.chunk_id, + "index": self.chunk.index, + "context": self.chunk.context + }, + "data": self.data, + "flags": { + "verbatim": 
self.flags.verbatim, + "requires_llm": self.flags.requires_llm + } + } + +class BaseStrategy(ABC): + def __init__(self, filename: str): + self.filename = filename + + @abstractmethod + def split(self, text: str) -> List[ProcessedChunk]: + """按策略将文本切分为块。""" + pass + + @abstractmethod + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + """从文本块中抽取结构化信息。""" + pass + + def calculate_checksum(self, text: str) -> str: + return hashlib.sha256(text.encode("utf-8")).hexdigest() + + def build_language_guard(self, text: str) -> str: + """ + 构建统一的输出语言约束。 + 不区分语言类型,仅要求抽取值保持原文语言,不做翻译。 + """ + _ = text # 预留参数,便于后续按需扩展 + return ( + "Focus on the original source language. Keep extracted events, entities, predicates " + "and objects in the same language as the source text, preserve names/terms as-is, " + "and do not translate." + ) diff --git a/plugins/A_memorix/core/strategies/factual.py b/plugins/A_memorix/core/strategies/factual.py new file mode 100644 index 00000000..4b7d6e56 --- /dev/null +++ b/plugins/A_memorix/core/strategies/factual.py @@ -0,0 +1,98 @@ +import re +from typing import List, Dict, Any +from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext + +class FactualStrategy(BaseStrategy): + def split(self, text: str) -> List[ProcessedChunk]: + # 结构感知切分 + lines = text.split('\n') + chunks = [] + current_chunk_lines = [] + current_len = 0 + target_size = 600 + + for i, line in enumerate(lines): + # 判断是否应当切分 + # 若当前行为列表项/定义/表格行,则尽量不切分 + is_structure = self._is_structural_line(line) + + current_len += len(line) + 1 + current_chunk_lines.append(line) + + # 达到目标长度且不在紧凑结构块内时切分(过长时强制切分) + if current_len >= target_size and not is_structure: + chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) + current_chunk_lines = [] + current_len = 0 + elif current_len >= target_size * 2: # 超长时强制切分 + chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) + current_chunk_lines = [] + current_len = 0 + + if current_chunk_lines: + chunks.append(self._create_chunk(current_chunk_lines, len(chunks))) + + return chunks + + def _is_structural_line(self, line: str) -> bool: + line = line.strip() + if not line: return False + # 列表项 + if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line): + return True + # 定义项(术语: 定义) + if re.match(r'^[^::]+[::].+', line): + return True + # 表格行(按 markdown 语法假设) + if line.startswith('|') and line.endswith('|'): + return True + return False + + def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk: + text = "\n".join(lines) + return ProcessedChunk( + type=KnowledgeType.FACTUAL, + source=SourceInfo( + file=self.filename, + offset_start=0, # 简化处理:真实偏移跟踪需要额外状态 + offset_end=0, + checksum=self.calculate_checksum(text) + ), + chunk=ChunkContext( + chunk_id=f"{self.filename}_{index}", + index=index, + text=text + ) + ) + + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + if not llm_func: + raise ValueError("LLM function required for Factual extraction") + + language_guard = self.build_language_guard(chunk.chunk.text) + prompt = f"""You are a factual knowledge extraction engine. +Extract factual triples and entities from the text. +Preserve lists and definitions accurately. + +Language constraints: +- {language_guard} +- Preserve original names and domain terms exactly when possible. +- JSON keys must stay exactly as: triples, entities, subject, predicate, object. 
+ +Text: +{chunk.chunk.text} + +Return ONLY valid JSON: +{{ + "triples": [ + {{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}} + ], + "entities": ["Entity1", "Entity2"] +}} +""" + result = await llm_func(prompt) + + # 结果保持原样存入 data,后续统一归一化流程会处理 + # vector_store 侧期望关系字段为 subject/predicate/object 映射形式 + chunk.data = result + return chunk diff --git a/plugins/A_memorix/core/strategies/narrative.py b/plugins/A_memorix/core/strategies/narrative.py new file mode 100644 index 00000000..731414f7 --- /dev/null +++ b/plugins/A_memorix/core/strategies/narrative.py @@ -0,0 +1,126 @@ +import re +from typing import List, Dict, Any +from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext + +class NarrativeStrategy(BaseStrategy): + def split(self, text: str) -> List[ProcessedChunk]: + scenes = self._split_into_scenes(text) + chunks = [] + + for scene_idx, (scene_text, scene_title) in enumerate(scenes): + scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx) + chunks.extend(scene_chunks) + + return chunks + + def _split_into_scenes(self, text: str) -> List[tuple[str, str]]: + """按标题或分隔符把文本切分为场景。""" + # 简单启发式:按 markdown 标题或特定分隔符切分 + # 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行 + # 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行 + scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)' + + # 保留分隔符,以便识别场景起点 + parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE) + + scenes = [] + current_scene_title = "Start" + current_scene_content = [] + + if parts and parts[0].strip() == "": + parts = parts[1:] + + for part in parts: + if re.match(scene_pattern_str, part, re.MULTILINE): + # 先保存上一段场景 + if current_scene_content: + scenes.append(("".join(current_scene_content), current_scene_title)) + current_scene_content = [] + current_scene_title = part.strip() + else: + current_scene_content.append(part) + + if current_scene_content: + scenes.append(("".join(current_scene_content), current_scene_title)) + + # 若未识别到场景,则把全文视作单一场景 + if not scenes: + scenes = [(text, "Whole Text")] + + return scenes + + def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]: + chunks = [] + if len(text) <= window_size: + chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0)) + return chunks + + stride = window_size - overlap + start = 0 + local_idx = 0 + while start < len(text): + end = min(start + window_size, len(text)) + chunk_text = text[start:end] + + # 尽量对齐到最近换行,避免生硬截断句子 + # 仅在未到文本尾部时进行回退 + if end < len(text): + last_newline = chunk_text.rfind('\n') + if last_newline > window_size // 2: # 仅在回退距离可接受时启用 + end = start + last_newline + 1 + chunk_text = text[start:end] + + chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start)) + + start += len(chunk_text) - overlap if end < len(text) else len(chunk_text) + local_idx += 1 + + return chunks + + def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk: + return ProcessedChunk( + type=KnowledgeType.NARRATIVE, + source=SourceInfo( + file=self.filename, + offset_start=offset, + offset_end=offset + len(text), + checksum=self.calculate_checksum(text) + ), + chunk=ChunkContext( + chunk_id=f"{self.filename}_{scene_idx}_{local_idx}", + index=local_idx, + text=text, + context={"scene_id": scene_id} + ) + ) + + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + if not llm_func: + raise ValueError("LLM function required for 
Narrative extraction") + + language_guard = self.build_language_guard(chunk.chunk.text) + prompt = f"""You are a narrative knowledge extraction engine. +Extract key events and character relations from the scene text. + +Language constraints: +- {language_guard} +- Preserve original names and terms exactly when possible. +- JSON keys must stay exactly as: events, relations, subject, predicate, object. + +Scene: +{chunk.chunk.context.get('scene_id')} + +Text: +{chunk.chunk.text} + +Return ONLY valid JSON: +{{ + "events": ["event description 1", "event description 2"], + "relations": [ + {{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}} + ] +}} +""" + result = await llm_func(prompt) + chunk.data = result + return chunk diff --git a/plugins/A_memorix/core/strategies/quote.py b/plugins/A_memorix/core/strategies/quote.py new file mode 100644 index 00000000..10733d64 --- /dev/null +++ b/plugins/A_memorix/core/strategies/quote.py @@ -0,0 +1,52 @@ +from typing import List, Dict, Any +from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags + +class QuoteStrategy(BaseStrategy): + def split(self, text: str) -> List[ProcessedChunk]: + # Split by double newlines (stanzas) + stanzas = text.split("\n\n") + chunks = [] + offset = 0 + + for idx, stanza in enumerate(stanzas): + if not stanza.strip(): + offset += len(stanza) + 2 + continue + + chunk = ProcessedChunk( + type=KnowledgeType.QUOTE, + source=SourceInfo( + file=self.filename, + offset_start=offset, + offset_end=offset + len(stanza), + checksum=self.calculate_checksum(stanza) + ), + chunk=ChunkContext( + chunk_id=f"{self.filename}_{idx}", + index=idx, + text=stanza + ), + flags=ChunkFlags( + verbatim=True, + requires_llm=False # Default to no LLM, but can be overridden + ) + ) + chunks.append(chunk) + offset += len(stanza) + 2 # +2 for \n\n + + return chunks + + async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk: + # For quotes, the text itself is the entity/knowledge + # We might use LLM to extract headers/metadata if requested, but core logic is pass-through + + # Treat the whole chunk text as a verbatim entity + chunk.data = { + "verbatim_entities": [chunk.chunk.text] + } + + if llm_func and chunk.flags.requires_llm: + # Optional: Extract metadata + pass + + return chunk diff --git a/plugins/A_memorix/core/utils/__init__.py b/plugins/A_memorix/core/utils/__init__.py new file mode 100644 index 00000000..e0d763cf --- /dev/null +++ b/plugins/A_memorix/core/utils/__init__.py @@ -0,0 +1,33 @@ +"""工具模块 - 哈希、监控等辅助功能""" + +from .hash import compute_hash, normalize_text +from .monitor import MemoryMonitor +from .quantization import quantize_vector, dequantize_vector +from .time_parser import ( + parse_query_datetime_to_timestamp, + parse_query_time_range, + parse_ingest_datetime_to_timestamp, + normalize_time_meta, + format_timestamp, +) +from .relation_write_service import RelationWriteService, RelationWriteResult +from .relation_query import RelationQuerySpec, parse_relation_query_spec +from .plugin_id_policy import PluginIdPolicy + +__all__ = [ + "compute_hash", + "normalize_text", + "MemoryMonitor", + "quantize_vector", + "dequantize_vector", + "parse_query_datetime_to_timestamp", + "parse_query_time_range", + "parse_ingest_datetime_to_timestamp", + "normalize_time_meta", + "format_timestamp", + "RelationWriteService", + "RelationWriteResult", + "RelationQuerySpec", + "parse_relation_query_spec", + "PluginIdPolicy", +] diff --git 
a/plugins/A_memorix/core/utils/aggregate_query_service.py b/plugins/A_memorix/core/utils/aggregate_query_service.py new file mode 100644 index 00000000..a87a4913 --- /dev/null +++ b/plugins/A_memorix/core/utils/aggregate_query_service.py @@ -0,0 +1,360 @@ +""" +聚合查询服务: +- 并发执行 search/time/episode 分支 +- 统一分支结果结构 +- 可选混合排序(Weighted RRF) +""" + +from __future__ import annotations + +import asyncio +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.AggregateQueryService") + +BranchRunner = Callable[[], Awaitable[Dict[str, Any]]] + + +class AggregateQueryService: + """聚合查询执行服务(search/time/episode)。""" + + def __init__(self, plugin_config: Optional[Any] = None): + self.plugin_config = plugin_config or {} + + def _cfg(self, key: str, default: Any = None) -> Any: + getter = getattr(self.plugin_config, "get_config", None) + if callable(getter): + return getter(key, default) + + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + @staticmethod + def _as_float(value: Any, default: float = 0.0) -> float: + try: + return float(value) + except Exception: + return float(default) + + @staticmethod + def _as_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except Exception: + return int(default) + + def _rrf_k(self) -> float: + raw = self._cfg("retrieval.aggregate.rrf_k", 60.0) + value = self._as_float(raw, 60.0) + return max(1.0, value) + + def _weights(self) -> Dict[str, float]: + defaults = {"search": 1.0, "time": 1.0, "episode": 1.0} + raw = self._cfg("retrieval.aggregate.weights", {}) + if not isinstance(raw, dict): + return defaults + + out = dict(defaults) + for key in ("search", "time", "episode"): + if key in raw: + out[key] = max(0.0, self._as_float(raw.get(key), defaults[key])) + return out + + @staticmethod + def _normalize_branch_payload( + name: str, + payload: Optional[Dict[str, Any]], + ) -> Dict[str, Any]: + data = payload if isinstance(payload, dict) else {} + results_raw = data.get("results", []) + results = results_raw if isinstance(results_raw, list) else [] + count = data.get("count") + if count is None: + count = len(results) + return { + "name": name, + "success": bool(data.get("success", False)), + "skipped": bool(data.get("skipped", False)), + "skip_reason": str(data.get("skip_reason", "") or "").strip(), + "error": str(data.get("error", "") or "").strip(), + "results": results, + "count": max(0, int(count)), + "elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)), + "content": str(data.get("content", "") or ""), + "query_type": str(data.get("query_type", "") or name), + } + + @staticmethod + def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str: + item_type = str(item.get("type", "") or "").strip().lower() + if item_type == "episode": + episode_id = str(item.get("episode_id", "") or "").strip() + if episode_id: + return f"episode:{episode_id}" + + item_hash = str(item.get("hash", "") or "").strip() + if item_hash: + return f"{item_type}:{item_hash}" + + return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}" + + def _build_mixed_results( + self, + *, + branches: Dict[str, Dict[str, Any]], + top_k: int, + ) -> List[Dict[str, Any]]: + rrf_k = self._rrf_k() + weights = self._weights() + bucket: Dict[str, Dict[str, Any]] = {} + + for branch_name, branch in branches.items(): + if not 
branch.get("success", False): + continue + results = branch.get("results", []) + if not isinstance(results, list): + continue + + weight = max(0.0, float(weights.get(branch_name, 1.0))) + for idx, item in enumerate(results, start=1): + if not isinstance(item, dict): + continue + key = self._mix_key(item, branch_name, idx) + score = weight / (rrf_k + float(idx)) + if key not in bucket: + merged = dict(item) + merged["fusion_score"] = 0.0 + merged["_source_branches"] = set() + bucket[key] = merged + + target = bucket[key] + target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score + target["_source_branches"].add(branch_name) + + mixed = list(bucket.values()) + mixed.sort( + key=lambda x: ( + -float(x.get("fusion_score", 0.0)), + str(x.get("type", "") or ""), + str(x.get("hash", "") or x.get("episode_id", "") or ""), + ) + ) + + out: List[Dict[str, Any]] = [] + for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1): + merged = dict(item) + branches_set = merged.pop("_source_branches", set()) + merged["source_branches"] = sorted(list(branches_set)) + merged["rank"] = rank + out.append(merged) + return out + + @staticmethod + def _status(branch: Dict[str, Any]) -> str: + if branch.get("skipped", False): + return "skipped" + if branch.get("success", False): + return "success" + return "failed" + + def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: + summary: Dict[str, Dict[str, Any]] = {} + for name, branch in branches.items(): + status = self._status(branch) + summary[name] = { + "status": status, + "count": int(branch.get("count", 0) or 0), + } + if status == "skipped": + summary[name]["reason"] = str(branch.get("skip_reason", "") or "") + if status == "failed": + summary[name]["error"] = str(branch.get("error", "") or "") + return summary + + def _build_content( + self, + *, + query: str, + branches: Dict[str, Dict[str, Any]], + errors: List[Dict[str, str]], + mixed_results: Optional[List[Dict[str, Any]]], + ) -> str: + lines: List[str] = [ + f"🔀 聚合查询结果(query='{query or 'N/A'}')", + "", + "分支状态:", + ] + for name in ("search", "time", "episode"): + branch = branches.get(name, {}) + status = self._status(branch) + count = int(branch.get("count", 0) or 0) + line = f"- {name}: {status}, count={count}" + reason = str(branch.get("skip_reason", "") or "").strip() + err = str(branch.get("error", "") or "").strip() + if status == "skipped" and reason: + line += f" ({reason})" + if status == "failed" and err: + line += f" ({err})" + lines.append(line) + + if errors: + lines.append("") + lines.append("错误:") + for item in errors[:6]: + lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}") + + if mixed_results is not None: + lines.append("") + lines.append(f"🧩 混合结果({len(mixed_results)} 条):") + for idx, item in enumerate(mixed_results[:5], start=1): + src = ",".join(item.get("source_branches", []) or []) + if str(item.get("type", "") or "") == "episode": + title = str(item.get("title", "") or "Untitled") + lines.append(f"{idx}. 🧠 {title} [{src}]") + else: + text = str(item.get("content", "") or "") + if len(text) > 80: + text = text[:80] + "..." + lines.append(f"{idx}. 
{text} [{src}]") + + return "\n".join(lines) + + async def execute( + self, + *, + query: str, + top_k: int, + mix: bool, + mix_top_k: Optional[int], + time_from: Optional[str], + time_to: Optional[str], + search_runner: Optional[BranchRunner], + time_runner: Optional[BranchRunner], + episode_runner: Optional[BranchRunner], + ) -> Dict[str, Any]: + clean_query = str(query or "").strip() + safe_top_k = max(1, int(top_k)) + safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k)) + + branches: Dict[str, Dict[str, Any]] = {} + errors: List[Dict[str, str]] = [] + scheduled: List[Tuple[str, asyncio.Task]] = [] + + if clean_query: + if search_runner is not None: + scheduled.append(("search", asyncio.create_task(search_runner()))) + else: + branches["search"] = self._normalize_branch_payload( + "search", + {"success": False, "error": "search runner unavailable", "results": []}, + ) + else: + branches["search"] = self._normalize_branch_payload( + "search", + { + "success": False, + "skipped": True, + "skip_reason": "missing_query", + "results": [], + "count": 0, + }, + ) + + if time_from or time_to: + if time_runner is not None: + scheduled.append(("time", asyncio.create_task(time_runner()))) + else: + branches["time"] = self._normalize_branch_payload( + "time", + {"success": False, "error": "time runner unavailable", "results": []}, + ) + else: + branches["time"] = self._normalize_branch_payload( + "time", + { + "success": False, + "skipped": True, + "skip_reason": "missing_time_window", + "results": [], + "count": 0, + }, + ) + + if episode_runner is not None: + scheduled.append(("episode", asyncio.create_task(episode_runner()))) + else: + branches["episode"] = self._normalize_branch_payload( + "episode", + {"success": False, "error": "episode runner unavailable", "results": []}, + ) + + if scheduled: + done = await asyncio.gather( + *[task for _, task in scheduled], + return_exceptions=True, + ) + for (branch_name, _), payload in zip(scheduled, done): + if isinstance(payload, Exception): + logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload) + normalized = self._normalize_branch_payload( + branch_name, + { + "success": False, + "error": str(payload), + "results": [], + }, + ) + else: + normalized = self._normalize_branch_payload(branch_name, payload) + branches[branch_name] = normalized + + for name in ("search", "time", "episode"): + branch = branches.get(name) + if not branch: + continue + if branch.get("skipped", False): + continue + if not branch.get("success", False): + errors.append( + { + "branch": name, + "error": str(branch.get("error", "") or "unknown error"), + } + ) + + success = any( + bool(branches.get(name, {}).get("success", False)) + for name in ("search", "time", "episode") + ) + mixed_results: Optional[List[Dict[str, Any]]] = None + if mix: + mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k) + + payload: Dict[str, Any] = { + "success": success, + "query_type": "aggregate", + "query": clean_query, + "top_k": safe_top_k, + "mix": bool(mix), + "mix_top_k": safe_mix_top_k, + "branches": branches, + "errors": errors, + "summary": self._build_summary(branches), + } + if mixed_results is not None: + payload["mixed_results"] = mixed_results + + payload["content"] = self._build_content( + query=clean_query, + branches=branches, + errors=errors, + mixed_results=mixed_results, + ) + return payload diff --git a/plugins/A_memorix/core/utils/episode_retrieval_service.py 
b/plugins/A_memorix/core/utils/episode_retrieval_service.py new file mode 100644 index 00000000..5a4cd24d --- /dev/null +++ b/plugins/A_memorix/core/utils/episode_retrieval_service.py @@ -0,0 +1,182 @@ +"""Episode hybrid retrieval service.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger + +from ..retrieval import DualPathRetriever, TemporalQueryOptions + +logger = get_logger("A_Memorix.EpisodeRetrievalService") + + +class EpisodeRetrievalService: + """Hybrid episode retrieval backed by lexical rows and evidence projection.""" + + _RRF_K = 60.0 + _BRANCH_WEIGHTS = { + "lexical": 1.0, + "paragraph_evidence": 1.0, + "relation_evidence": 0.85, + } + + def __init__( + self, + *, + metadata_store: Any, + retriever: Optional[DualPathRetriever] = None, + ) -> None: + self.metadata_store = metadata_store + self.retriever = retriever + + async def query( + self, + *, + query: str = "", + top_k: int = 5, + time_from: Optional[float] = None, + time_to: Optional[float] = None, + person: Optional[str] = None, + source: Optional[str] = None, + include_paragraphs: bool = False, + ) -> List[Dict[str, Any]]: + clean_query = str(query or "").strip() + safe_top_k = max(1, int(top_k)) + candidate_k = max(30, safe_top_k * 6) + + branches: Dict[str, List[Dict[str, Any]]] = { + "lexical": self.metadata_store.query_episodes( + query=clean_query, + time_from=time_from, + time_to=time_to, + person=person, + source=source, + limit=(candidate_k if clean_query else safe_top_k), + ) + } + + if clean_query and self.retriever is not None: + try: + temporal = TemporalQueryOptions( + time_from=time_from, + time_to=time_to, + person=person, + source=source, + ) + results = await self.retriever.retrieve( + query=clean_query, + top_k=candidate_k, + temporal=temporal, + ) + except Exception as exc: + logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc) + else: + paragraph_rank_map: Dict[str, int] = {} + relation_rank_map: Dict[str, int] = {} + for rank, item in enumerate(results, start=1): + hash_value = str(getattr(item, "hash_value", "") or "").strip() + result_type = str(getattr(item, "result_type", "") or "").strip().lower() + if not hash_value: + continue + if result_type == "paragraph" and hash_value not in paragraph_rank_map: + paragraph_rank_map[hash_value] = rank + elif result_type == "relation" and hash_value not in relation_rank_map: + relation_rank_map[hash_value] = rank + + if paragraph_rank_map: + paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes( + list(paragraph_rank_map.keys()), + source=source, + ) + if paragraph_rows: + branches["paragraph_evidence"] = self._rank_projected_rows( + paragraph_rows, + rank_map=paragraph_rank_map, + support_key="matched_paragraph_hashes", + ) + + if relation_rank_map: + relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes( + list(relation_rank_map.keys()), + source=source, + ) + if relation_rows: + branches["relation_evidence"] = self._rank_projected_rows( + relation_rows, + rank_map=relation_rank_map, + support_key="matched_relation_hashes", + ) + + fused = self._fuse_branches(branches, top_k=safe_top_k) + if include_paragraphs: + for item in fused: + item["paragraphs"] = self.metadata_store.get_episode_paragraphs( + episode_id=str(item.get("episode_id") or ""), + limit=50, + ) + return fused + + @staticmethod + def _rank_projected_rows( + rows: List[Dict[str, Any]], + *, + rank_map: Dict[str, int], + support_key: str, 
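+        # support_key names the projected evidence column on each row, e.g.
+        # "matched_paragraph_hashes" or "matched_relation_hashes"; rows whose
+        # supporting hashes rank highest in rank_map are sorted first.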
+ ) -> List[Dict[str, Any]]: + sentinel = 10**9 + ranked = [dict(item) for item in rows] + + def _first_support_rank(item: Dict[str, Any]) -> int: + support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])] + ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map] + return min(ranks) if ranks else sentinel + + ranked.sort( + key=lambda item: ( + _first_support_rank(item), + -int(item.get("matched_paragraph_count") or 0), + -float(item.get("updated_at") or 0.0), + str(item.get("episode_id") or ""), + ) + ) + return ranked + + def _fuse_branches( + self, + branches: Dict[str, List[Dict[str, Any]]], + *, + top_k: int, + ) -> List[Dict[str, Any]]: + bucket: Dict[str, Dict[str, Any]] = {} + for branch_name, rows in branches.items(): + weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0) + if weight <= 0.0: + continue + for rank, row in enumerate(rows, start=1): + episode_id = str(row.get("episode_id", "") or "").strip() + if not episode_id: + continue + if episode_id not in bucket: + payload = dict(row) + payload.pop("matched_paragraph_hashes", None) + payload.pop("matched_relation_hashes", None) + payload.pop("matched_paragraph_count", None) + payload.pop("matched_relation_count", None) + payload["_fusion_score"] = 0.0 + bucket[episode_id] = payload + bucket[episode_id]["_fusion_score"] = float( + bucket[episode_id].get("_fusion_score", 0.0) + ) + weight / (self._RRF_K + float(rank)) + + out = list(bucket.values()) + out.sort( + key=lambda item: ( + -float(item.get("_fusion_score", 0.0)), + -float(item.get("updated_at") or 0.0), + str(item.get("episode_id") or ""), + ) + ) + for item in out: + item.pop("_fusion_score", None) + return out[: max(1, int(top_k))] diff --git a/plugins/A_memorix/core/utils/hash.py b/plugins/A_memorix/core/utils/hash.py new file mode 100644 index 00000000..b6363257 --- /dev/null +++ b/plugins/A_memorix/core/utils/hash.py @@ -0,0 +1,129 @@ +""" +哈希工具模块 + +提供文本哈希计算功能,用于唯一标识和去重。 +""" + +import hashlib +import re +from typing import Union + + +def compute_hash(text: str, hash_type: str = "sha256") -> str: + """ + 计算文本的哈希值 + + Args: + text: 输入文本 + hash_type: 哈希算法类型(sha256, md5等) + + Returns: + 哈希值字符串 + """ + if hash_type == "sha256": + return hashlib.sha256(text.encode("utf-8")).hexdigest() + elif hash_type == "md5": + return hashlib.md5(text.encode("utf-8")).hexdigest() + else: + raise ValueError(f"不支持的哈希算法: {hash_type}") + + +def normalize_text(text: str) -> str: + """ + 规范化文本用于哈希计算 + + 执行以下操作: + - 去除首尾空白 + - 统一换行符为\\n + - 压缩多个连续空格 + - 去除不可见字符(保留换行和制表符) + + Args: + text: 输入文本 + + Returns: + 规范化后的文本 + """ + # 去除首尾空白 + text = text.strip() + + # 统一换行符 + text = text.replace("\r\n", "\n").replace("\r", "\n") + + # 压缩多个连续空格为一个(但保留换行和制表符) + text = re.sub(r"[^\S\n]+", " ", text) + + return text + + +def compute_paragraph_hash(paragraph: str) -> str: + """ + 计算段落的哈希值 + + Args: + paragraph: 段落文本 + + Returns: + 段落哈希值(用于paragraph-前缀) + """ + normalized = normalize_text(paragraph) + return compute_hash(normalized) + + +def compute_entity_hash(entity: str) -> str: + """ + 计算实体的哈希值 + + Args: + entity: 实体名称 + + Returns: + 实体哈希值(用于entity-前缀) + """ + normalized = entity.strip().lower() + return compute_hash(normalized) + + +def compute_relation_hash(relation: tuple) -> str: + """ + 计算关系的哈希值 + + Args: + relation: 关系元组 (subject, predicate, object) + + Returns: + 关系哈希值(用于relation-前缀) + """ + # 将关系元组转为字符串 + relation_str = str(tuple(relation)) + return compute_hash(relation_str) + + +def format_hash_key(hash_type: str, hash_value: str) -> str: + 
""" + 格式化哈希键 + + Args: + hash_type: 类型前缀(paragraph, entity, relation) + hash_value: 哈希值 + + Returns: + 格式化的键(如 paragraph-abc123...) + """ + return f"{hash_type}-{hash_value}" + + +def parse_hash_key(key: str) -> tuple[str, str]: + """ + 解析哈希键 + + Args: + key: 格式化的键(如 paragraph-abc123...) + + Returns: + (类型, 哈希值) 元组 + """ + parts = key.split("-", 1) + if len(parts) != 2: + raise ValueError(f"无效的哈希键格式: {key}") + return parts[0], parts[1] diff --git a/plugins/A_memorix/core/utils/import_payloads.py b/plugins/A_memorix/core/utils/import_payloads.py new file mode 100644 index 00000000..6986a4c1 --- /dev/null +++ b/plugins/A_memorix/core/utils/import_payloads.py @@ -0,0 +1,110 @@ +"""Shared import payload normalization helpers.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from ..storage import KnowledgeType, resolve_stored_knowledge_type +from .time_parser import normalize_time_meta + + +def _normalize_entities(raw_entities: Any) -> List[str]: + if not isinstance(raw_entities, list): + return [] + out: List[str] = [] + seen = set() + for item in raw_entities: + name = str(item or "").strip() + if not name: + continue + key = name.lower() + if key in seen: + continue + seen.add(key) + out.append(name) + return out + + +def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]: + if not isinstance(raw_relations, list): + return [] + out: List[Dict[str, str]] = [] + for item in raw_relations: + if not isinstance(item, dict): + continue + subject = str(item.get("subject", "")).strip() + predicate = str(item.get("predicate", "")).strip() + obj = str(item.get("object", "")).strip() + if not (subject and predicate and obj): + continue + out.append( + { + "subject": subject, + "predicate": predicate, + "object": obj, + } + ) + return out + + +def normalize_paragraph_import_item( + item: Any, + *, + default_source: str, +) -> Dict[str, Any]: + """Normalize one paragraph import item from text/json payloads.""" + + if isinstance(item, str): + content = str(item) + knowledge_type = resolve_stored_knowledge_type(None, content=content) + return { + "content": content, + "knowledge_type": knowledge_type.value, + "source": str(default_source or "").strip(), + "time_meta": None, + "entities": [], + "relations": [], + } + + if not isinstance(item, dict) or "content" not in item: + raise ValueError("段落项必须为字符串或包含 content 的对象") + + content = str(item.get("content", "") or "") + if not content.strip(): + raise ValueError("段落 content 不能为空") + + raw_time_meta = { + "event_time": item.get("event_time"), + "event_time_start": item.get("event_time_start"), + "event_time_end": item.get("event_time_end"), + "time_range": item.get("time_range"), + "time_granularity": item.get("time_granularity"), + "time_confidence": item.get("time_confidence"), + } + time_meta_field = item.get("time_meta") + if isinstance(time_meta_field, dict): + raw_time_meta.update(time_meta_field) + + knowledge_type_raw = item.get("knowledge_type") + if knowledge_type_raw is None: + knowledge_type_raw = item.get("type") + knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content) + source = str(item.get("source") or default_source or "").strip() + if not source: + source = str(default_source or "").strip() + + normalized_time_meta = normalize_time_meta(raw_time_meta) + return { + "content": content, + "knowledge_type": knowledge_type.value, + "source": source, + "time_meta": normalized_time_meta if normalized_time_meta else None, + "entities": 
_normalize_entities(item.get("entities")), + "relations": _normalize_relations(item.get("relations")), + } + + +def normalize_summary_knowledge_type(value: Any) -> KnowledgeType: + """Normalize config-driven summary knowledge type.""" + + return resolve_stored_knowledge_type(value, content="") diff --git a/plugins/A_memorix/core/utils/io.py b/plugins/A_memorix/core/utils/io.py new file mode 100644 index 00000000..ed14df43 --- /dev/null +++ b/plugins/A_memorix/core/utils/io.py @@ -0,0 +1,84 @@ +""" +IO Utilities + +提供原子文件写入等IO辅助功能。 +""" + +import os +import shutil +import contextlib +from pathlib import Path +from typing import Union + +@contextlib.contextmanager +def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs): + """ + 原子文件写入上下文管理器 + + 原理: + 1. 写入 .tmp 临时文件 + 2. 写入成功后,使用 os.replace 原子替换目标文件 + 3. 如果失败,自动删除临时文件 + + Args: + file_path: 目标文件路径 + mode: 打开模式 ('w', 'wb' 等) + encoding: 编码 + **kwargs: 传给 open() 的其他参数 + """ + path = Path(file_path) + # 确保父目录存在 + path.parent.mkdir(parents=True, exist_ok=True) + + # 临时文件路径 + tmp_path = path.with_suffix(path.suffix + ".tmp") + + try: + with open(tmp_path, mode, encoding=encoding, **kwargs) as f: + yield f + + # 确保写入磁盘 + f.flush() + os.fsync(f.fileno()) + + # 原子替换 (Windows下可能需要先删除目标文件,但 os.replace 在 Py3.3+ 尽可能原子) + # 注意: Windows 上如果有其他进程占用文件,os.replace 可能会失败 + os.replace(tmp_path, path) + + except Exception as e: + # 清理临时文件 + if tmp_path.exists(): + try: + os.remove(tmp_path) + except: + pass + raise e + +@contextlib.contextmanager +def atomic_save_path(file_path: Union[str, Path]): + """ + 提供临时路径用于原子保存 (针对只接受路径的API,如Faiss) + + Args: + file_path: 最终目标路径 + + Yields: + tmp_path: 临时文件路径 (str) + """ + path = Path(file_path) + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + + try: + yield str(tmp_path) + + if Path(tmp_path).exists(): + os.replace(tmp_path, path) + + except Exception as e: + if Path(tmp_path).exists(): + try: + os.remove(tmp_path) + except: + pass + raise e diff --git a/plugins/A_memorix/core/utils/matcher.py b/plugins/A_memorix/core/utils/matcher.py new file mode 100644 index 00000000..bddff5ee --- /dev/null +++ b/plugins/A_memorix/core/utils/matcher.py @@ -0,0 +1,89 @@ +""" +高效文本匹配工具模块 + +实现 Aho-Corasick 算法用于多模式匹配。 +""" + +from typing import List, Dict, Tuple, Set, Any +from collections import deque + + +class AhoCorasick: + """ + Aho-Corasick 自动机实现高效多模式匹配 + """ + + def __init__(self): + # next_states[state][char] = next_state + self.next_states: List[Dict[str, int]] = [{}] + # fail[state] = fail_state + self.fail: List[int] = [0] + # output[state] = set of patterns ending at this state + self.output: List[Set[str]] = [set()] + self.patterns: Set[str] = set() + + def add_pattern(self, pattern: str): + """添加模式""" + if not pattern: + return + self.patterns.add(pattern) + state = 0 + for char in pattern: + if char not in self.next_states[state]: + new_state = len(self.next_states) + self.next_states[state][char] = new_state + self.next_states.append({}) + self.fail.append(0) + self.output.append(set()) + state = self.next_states[state][char] + self.output[state].add(pattern) + + def build(self): + """构建失败指针""" + queue = deque() + # 处理第一层 + for char, state in self.next_states[0].items(): + queue.append(state) + self.fail[state] = 0 + + while queue: + r = queue.popleft() + for char, s in self.next_states[r].items(): + queue.append(s) + # 找到失败路径 + state = self.fail[r] + while char not in self.next_states[state] and state != 0: + state = 
self.fail[state] + self.fail[s] = self.next_states[state].get(char, 0) + # 合并输出 + self.output[s].update(self.output[self.fail[s]]) + + def search(self, text: str) -> List[Tuple[int, str]]: + """ + 在文本中搜索所有模式 + + Returns: + [(结束索引, 匹配到的模式), ...] + """ + state = 0 + results = [] + for i, char in enumerate(text): + while char not in self.next_states[state] and state != 0: + state = self.fail[state] + state = self.next_states[state].get(char, 0) + for pattern in self.output[state]: + results.append((i, pattern)) + return results + + def find_all(self, text: str) -> Dict[str, int]: + """ + 查找并统计所有模式出现次数 + + Returns: + {模式: 出现次数} + """ + results = self.search(text) + stats = {} + for _, pattern in results: + stats[pattern] = stats.get(pattern, 0) + 1 + return stats diff --git a/plugins/A_memorix/core/utils/monitor.py b/plugins/A_memorix/core/utils/monitor.py new file mode 100644 index 00000000..39c794ab --- /dev/null +++ b/plugins/A_memorix/core/utils/monitor.py @@ -0,0 +1,189 @@ +""" +内存监控模块 + +提供内存使用监控和预警功能。 +""" + +import gc +import threading +import time +from typing import Callable, Optional + +try: + import psutil + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.MemoryMonitor") + + +class MemoryMonitor: + """ + 内存监控器 + + 功能: + - 实时监控内存使用 + - 超过阈值时触发警告 + - 支持自动垃圾回收 + """ + + def __init__( + self, + max_memory_mb: int, + warning_threshold: float = 0.9, + check_interval: float = 10.0, + enable_auto_gc: bool = True, + ): + """ + 初始化内存监控器 + + Args: + max_memory_mb: 最大内存限制(MB) + warning_threshold: 警告阈值(0-1之间,默认0.9表示90%) + check_interval: 检查间隔(秒) + enable_auto_gc: 是否启用自动垃圾回收 + """ + self.max_memory_mb = max_memory_mb + self.warning_threshold = warning_threshold + self.check_interval = check_interval + self.enable_auto_gc = enable_auto_gc + + self._running = False + self._thread: Optional[threading.Thread] = None + self._callbacks: list[Callable[[float, float], None]] = [] + + def start(self): + """启动监控""" + if self._running: + logger.warning("内存监控已在运行") + return + + if not HAS_PSUTIL: + logger.warning("psutil 未安装,内存监控功能不可用") + return + + self._running = True + self._thread = threading.Thread(target=self._monitor_loop, daemon=True) + self._thread.start() + logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)") + + def stop(self): + """停止监控""" + if not self._running: + return + + self._running = False + if self._thread: + self._thread.join(timeout=5.0) + logger.info("内存监控已停止") + + def register_callback(self, callback: Callable[[float, float], None]): + """ + 注册内存超限回调函数 + + Args: + callback: 回调函数,接收 (当前使用MB, 限制MB) 参数 + """ + self._callbacks.append(callback) + + def get_current_memory_mb(self) -> float: + """ + 获取当前进程内存使用量(MB) + + Returns: + 内存使用量(MB) + """ + if not HAS_PSUTIL: + # 降级方案:使用内置函数 + import sys + return sys.getsizeof(gc.get_objects()) / 1024 / 1024 + + process = psutil.Process() + return process.memory_info().rss / 1024 / 1024 + + def get_memory_usage_ratio(self) -> float: + """ + 获取内存使用率 + + Returns: + 使用率(0-1之间) + """ + current = self.get_current_memory_mb() + return current / self.max_memory_mb if self.max_memory_mb > 0 else 0 + + def _monitor_loop(self): + """监控循环""" + while self._running: + try: + current_mb = self.get_current_memory_mb() + ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0 + + # 检查是否超过阈值 + if ratio >= self.warning_threshold: + logger.warning( + f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB " + f"({ratio*100:.1f}%)" + ) + + # 触发回调 + for callback in 
self._callbacks: + try: + callback(current_mb, self.max_memory_mb) + except Exception as e: + logger.error(f"内存回调执行失败: {e}") + + # 自动垃圾回收 + if self.enable_auto_gc: + before = self.get_current_memory_mb() + gc.collect() + after = self.get_current_memory_mb() + freed = before - after + if freed > 1: # 释放超过1MB才记录 + logger.info(f"垃圾回收释放: {freed:.1f}MB") + + # 定期垃圾回收(即使未超限) + elif self.enable_auto_gc and int(time.time()) % 60 == 0: + gc.collect() + + except Exception as e: + logger.error(f"内存监控出错: {e}") + + # 等待下次检查 + time.sleep(self.check_interval) + + def __enter__(self): + """上下文管理器入口""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口""" + self.stop() + + +def get_memory_info() -> dict: + """ + 获取系统内存信息 + + Returns: + 内存信息字典 + """ + if not HAS_PSUTIL: + return {"error": "psutil 未安装"} + + try: + mem = psutil.virtual_memory() + process = psutil.Process() + + return { + "system_total_gb": mem.total / 1024 / 1024 / 1024, + "system_available_gb": mem.available / 1024 / 1024 / 1024, + "system_usage_percent": mem.percent, + "process_mb": process.memory_info().rss / 1024 / 1024, + "process_percent": (process.memory_info().rss / mem.total) * 100, + } + except Exception as e: + return {"error": str(e)} diff --git a/plugins/A_memorix/core/utils/path_fallback_service.py b/plugins/A_memorix/core/utils/path_fallback_service.py new file mode 100644 index 00000000..7a802743 --- /dev/null +++ b/plugins/A_memorix/core/utils/path_fallback_service.py @@ -0,0 +1,165 @@ +"""Shared path-fallback helpers for search post-processing.""" + +from __future__ import annotations + +import hashlib +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from ..retrieval.dual_path import RetrievalResult + + +def extract_entities(query: str, graph_store: Any) -> List[str]: + """Extract up to two graph nodes from a query using n-gram matching.""" + if not graph_store: + return [] + + text = str(query or "").strip() + if not text: + return [] + + # Keep the heuristic aligned with previous legacy behavior. 
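+    # Greedy longest-match scan: spans of up to four whitespace-delimited tokens
+    # are checked against graph_store.find_node (ignore_case=True), and matched
+    # token indices are skipped so shorter overlapping spans are not re-matched.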
+ tokens = ( + text.replace("?", " ") + .replace("!", " ") + .replace(".", " ") + .split() + ) + if not tokens: + return [] + + found_entities = set() + skip_indices = set() + max_n = min(4, len(tokens)) + + for size in range(max_n, 0, -1): + for i in range(len(tokens) - size + 1): + if any(idx in skip_indices for idx in range(i, i + size)): + continue + span = " ".join(tokens[i : i + size]) + matched_node = graph_store.find_node(span, ignore_case=True) + if not matched_node: + continue + found_entities.add(matched_node) + for idx in range(i, i + size): + skip_indices.add(idx) + + return list(found_entities) + + +def find_paths_between_entities( + start_node: str, + end_node: str, + graph_store: Any, + metadata_store: Any, + *, + max_depth: int = 3, + max_paths: int = 5, +) -> List[Dict[str, Any]]: + """Find and enrich indirect paths between two nodes.""" + if not graph_store or not metadata_store: + return [] + + try: + paths = graph_store.find_paths( + start_node, + end_node, + max_depth=max_depth, + max_paths=max_paths, + ) + except Exception: + return [] + + if not paths: + return [] + + edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {} + formatted_paths: List[Dict[str, Any]] = [] + + for path_nodes in paths: + if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2: + continue + + path_desc: List[str] = [] + for i in range(len(path_nodes) - 1): + u = str(path_nodes[i]) + v = str(path_nodes[i + 1]) + + cache_key = tuple(sorted((u, v))) + if cache_key in edge_cache: + pred, direction = edge_cache[cache_key] + else: + pred = "related" + direction = "->" + rels = metadata_store.get_relations(subject=u, object=v) + if not rels: + rels = metadata_store.get_relations(subject=v, object=u) + direction = "<-" + if rels: + best_rel = max(rels, key=lambda x: x.get("confidence", 1.0)) + pred = str(best_rel.get("predicate", "related") or "related") + edge_cache[cache_key] = (pred, direction) + + step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-" + path_desc.append(step_str) + + full_path_str = str(path_nodes[0]) + for i, step in enumerate(path_desc): + full_path_str += f" {step} {path_nodes[i + 1]}" + + formatted_paths.append( + { + "nodes": list(path_nodes), + "description": full_path_str, + } + ) + + return formatted_paths + + +def find_paths_from_query( + query: str, + graph_store: Any, + metadata_store: Any, + *, + max_depth: int = 3, + max_paths: int = 5, +) -> List[Dict[str, Any]]: + """Extract entities from query and resolve indirect paths.""" + entities = extract_entities(query, graph_store) + if len(entities) != 2: + return [] + return find_paths_between_entities( + entities[0], + entities[1], + graph_store, + metadata_store, + max_depth=max_depth, + max_paths=max_paths, + ) + + +def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]: + """Convert path results into retrieval results for the unified pipeline.""" + converted: List[RetrievalResult] = [] + for item in paths: + description = str(item.get("description", "")).strip() + if not description: + continue + hash_seed = description.encode("utf-8") + path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}" + converted.append( + RetrievalResult( + hash_value=path_hash, + content=f"[Indirect Relation] {description}", + score=0.95, + result_type="relation", + source="graph_path", + metadata={ + "source": "graph_path", + "is_indirect": True, + "nodes": list(item.get("nodes", [])), + }, + ) + ) + return converted + diff --git a/plugins/A_memorix/core/utils/person_profile_service.py 
b/plugins/A_memorix/core/utils/person_profile_service.py new file mode 100644 index 00000000..ccbbaf90 --- /dev/null +++ b/plugins/A_memorix/core/utils/person_profile_service.py @@ -0,0 +1,495 @@ +""" +人物画像服务 + +主链路: +person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储 +""" + +import json +import time +from typing import Any, Dict, List, Optional, Tuple + +from src.common.logger import get_logger +from src.common.database.database_model import PersonInfo + +from ..embedding import EmbeddingAPIAdapter +from ..retrieval import ( + DualPathRetriever, + RetrievalStrategy, + DualPathRetrieverConfig, + SparseBM25Config, + FusionConfig, + GraphRelationRecallConfig, +) +from ..storage import MetadataStore, GraphStore, VectorStore + +logger = get_logger("A_Memorix.PersonProfileService") + + +class PersonProfileService: + """人物画像聚合/刷新服务。""" + + def __init__( + self, + metadata_store: MetadataStore, + graph_store: Optional[GraphStore] = None, + vector_store: Optional[VectorStore] = None, + embedding_manager: Optional[EmbeddingAPIAdapter] = None, + sparse_index: Any = None, + plugin_config: Optional[dict] = None, + retriever: Optional[DualPathRetriever] = None, + ): + self.metadata_store = metadata_store + self.graph_store = graph_store + self.vector_store = vector_store + self.embedding_manager = embedding_manager + self.sparse_index = sparse_index + self.plugin_config = plugin_config or {} + self.retriever = retriever or self._build_retriever() + + def _cfg(self, key: str, default: Any = None) -> Any: + """读取嵌套配置。""" + if not isinstance(self.plugin_config, dict): + return default + current: Any = self.plugin_config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + def _build_retriever(self) -> Optional[DualPathRetriever]: + """按需构建检索器(无依赖时返回 None)。""" + if not all( + [ + self.vector_store is not None, + self.graph_store is not None, + self.metadata_store is not None, + self.embedding_manager is not None, + ] + ): + return None + try: + sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {} + fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {} + graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {} + if not isinstance(sparse_cfg_raw, dict): + sparse_cfg_raw = {} + if not isinstance(fusion_cfg_raw, dict): + fusion_cfg_raw = {} + if not isinstance(graph_recall_cfg_raw, dict): + graph_recall_cfg_raw = {} + + sparse_cfg = SparseBM25Config(**sparse_cfg_raw) + fusion_cfg = FusionConfig(**fusion_cfg_raw) + graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw) + config = DualPathRetrieverConfig( + top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)), + top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)), + top_k_final=int(self._cfg("retrieval.top_k_final", 10)), + alpha=float(self._cfg("retrieval.alpha", 0.5)), + enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)), + ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)), + ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)), + enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)), + retrieval_strategy=RetrievalStrategy.DUAL_PATH, + debug=bool(self._cfg("advanced.debug", False)), + sparse=sparse_cfg, + fusion=fusion_cfg, + graph_recall=graph_recall_cfg, + ) + return DualPathRetriever( + vector_store=self.vector_store, + graph_store=self.graph_store, + metadata_store=self.metadata_store, + embedding_manager=self.embedding_manager, + 
sparse_index=self.sparse_index, + config=config, + ) + except Exception as e: + logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}") + return None + + @staticmethod + def resolve_person_id(identifier: str) -> str: + """按 person_id 或姓名/别名解析 person_id。""" + if not identifier: + return "" + key = str(identifier).strip() + if not key: + return "" + + if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key): + return key.lower() + + try: + record = ( + PersonInfo.select(PersonInfo.person_id) + .where((PersonInfo.person_name == key) | (PersonInfo.nickname == key)) + .first() + ) + if record and record.person_id: + return str(record.person_id) + except Exception: + pass + + try: + record = ( + PersonInfo.select(PersonInfo.person_id) + .where(PersonInfo.group_nick_name.contains(key)) + .first() + ) + if record and record.person_id: + return str(record.person_id) + except Exception: + pass + + return "" + + def _parse_group_nicks(self, raw_value: Any) -> List[str]: + if not raw_value: + return [] + if isinstance(raw_value, list): + items = raw_value + else: + try: + items = json.loads(raw_value) + except Exception: + return [] + names: List[str] = [] + for item in items: + if isinstance(item, dict): + value = str(item.get("group_nick_name", "")).strip() + if value: + names.append(value) + elif isinstance(item, str): + value = item.strip() + if value: + names.append(value) + return names + + def _parse_memory_traits(self, raw_value: Any) -> List[str]: + if not raw_value: + return [] + try: + values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value + except Exception: + return [] + if not isinstance(values, list): + return [] + traits: List[str] = [] + for item in values: + text = str(item).strip() + if not text: + continue + if ":" in text: + parts = text.split(":") + if len(parts) >= 3: + content = ":".join(parts[1:-1]).strip() + if content: + traits.append(content) + continue + traits.append(text) + return traits[:10] + + def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]: + """获取人物别名集合、主展示名、记忆特征。""" + aliases: List[str] = [] + primary_name = "" + memory_traits: List[str] = [] + if not person_id: + return aliases, primary_name, memory_traits + try: + record = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + if not record: + return aliases, primary_name, memory_traits + person_name = str(getattr(record, "person_name", "") or "").strip() + nickname = str(getattr(record, "nickname", "") or "").strip() + group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None)) + memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None)) + + primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id + + candidates = [person_name, nickname] + group_nicks + seen = set() + for item in candidates: + norm = str(item or "").strip() + if not norm or norm in seen: + continue + seen.add(norm) + aliases.append(norm) + except Exception as e: + logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}") + return aliases, primary_name, memory_traits + + def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]: + relation_by_hash: Dict[str, Dict[str, Any]] = {} + for alias in aliases: + for rel in self.metadata_store.get_relations(subject=alias): + h = str(rel.get("hash", "")) + if h: + relation_by_hash[h] = rel + for rel in self.metadata_store.get_relations(object=alias): + h = str(rel.get("hash", "")) + if h: + relation_by_hash[h] = 
rel + + relations = list(relation_by_hash.values()) + relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True) + relations = relations[: max(1, int(limit))] + + edges: List[Dict[str, Any]] = [] + for rel in relations: + edges.append( + { + "hash": str(rel.get("hash", "")), + "subject": str(rel.get("subject", "")), + "predicate": str(rel.get("predicate", "")), + "object": str(rel.get("object", "")), + "confidence": float(rel.get("confidence", 1.0) or 1.0), + } + ) + return edges + + async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]: + alias_queries = [a for a in aliases if a] + if not alias_queries: + return [] + + if self.retriever is None: + # 回退:无检索器时只做简单内容匹配 + fallback: List[Dict[str, Any]] = [] + seen_hash = set() + for alias in alias_queries: + for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]: + h = str(para.get("hash", "")) + if not h or h in seen_hash: + continue + seen_hash.add(h) + fallback.append( + { + "hash": h, + "type": "paragraph", + "score": 0.0, + "content": str(para.get("content", ""))[:180], + "metadata": {}, + } + ) + return fallback[:top_k] + + per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries)))) + seen_hash = set() + evidence: List[Dict[str, Any]] = [] + for alias in alias_queries: + try: + results = await self.retriever.retrieve(alias, top_k=per_alias_top_k) + except Exception as e: + logger.warning(f"向量证据召回失败: alias={alias}, err={e}") + continue + for item in results: + h = str(getattr(item, "hash_value", "") or "") + if not h or h in seen_hash: + continue + seen_hash.add(h) + evidence.append( + { + "hash": h, + "type": str(getattr(item, "result_type", "")), + "score": float(getattr(item, "score", 0.0) or 0.0), + "content": str(getattr(item, "content", "") or "")[:220], + "metadata": dict(getattr(item, "metadata", {}) or {}), + } + ) + evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True) + return evidence[:top_k] + + def _build_profile_text( + self, + person_id: str, + primary_name: str, + aliases: List[str], + relation_edges: List[Dict[str, Any]], + vector_evidence: List[Dict[str, Any]], + memory_traits: List[str], + ) -> str: + """基于证据构建画像文本(供 LLM 上下文注入)。""" + lines: List[str] = [] + lines.append(f"人物ID: {person_id}") + if primary_name: + lines.append(f"主称呼: {primary_name}") + if aliases: + lines.append(f"别名: {', '.join(aliases[:8])}") + if memory_traits: + lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}") + + if relation_edges: + lines.append("关系证据:") + for rel in relation_edges[:6]: + s = rel.get("subject", "") + p = rel.get("predicate", "") + o = rel.get("object", "") + conf = float(rel.get("confidence", 0.0)) + lines.append(f"- {s} {p} {o} (conf={conf:.2f})") + + if vector_evidence: + lines.append("向量证据摘要:") + for item in vector_evidence[:4]: + content = str(item.get("content", "")).strip() + if content: + lines.append(f"- {content}") + + if len(lines) <= 2: + lines.append("暂无足够证据形成稳定画像。") + + return "\n".join(lines) + + @staticmethod + def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool: + if not snapshot: + return True + now = time.time() + expires_at = snapshot.get("expires_at") + if expires_at is not None: + try: + return now >= float(expires_at) + except Exception: + return True + updated_at = float(snapshot.get("updated_at") or 0.0) + return (now - updated_at) >= ttl_seconds + + def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]: 
+ """将手工覆盖并入画像结果(覆盖 profile_text,同时保留 auto_profile_text)。""" + payload = dict(profile_payload or {}) + auto_text = str(payload.get("profile_text", "") or "") + payload["auto_profile_text"] = auto_text + payload["has_manual_override"] = False + payload["manual_override_text"] = "" + payload["override_updated_at"] = None + payload["override_updated_by"] = "" + payload["profile_source"] = "auto_snapshot" + + if not person_id or self.metadata_store is None: + return payload + + try: + override = self.metadata_store.get_person_profile_override(person_id) + except Exception as e: + logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}") + return payload + + if not override: + return payload + + manual_text = str(override.get("override_text", "") or "").strip() + if not manual_text: + return payload + + payload["has_manual_override"] = True + payload["manual_override_text"] = manual_text + payload["override_updated_at"] = override.get("updated_at") + payload["override_updated_by"] = str(override.get("updated_by", "") or "") + payload["profile_text"] = manual_text + payload["profile_source"] = "manual_override" + return payload + + async def query_person_profile( + self, + person_id: str = "", + person_keyword: str = "", + top_k: int = 12, + ttl_seconds: float = 6 * 3600, + force_refresh: bool = False, + source_note: str = "", + ) -> Dict[str, Any]: + """查询或刷新人物画像。""" + pid = str(person_id or "").strip() + if not pid and person_keyword: + pid = self.resolve_person_id(person_keyword) + + if not pid: + return { + "success": False, + "error": "person_id 无效,且未能通过别名解析", + } + + latest = self.metadata_store.get_latest_person_profile_snapshot(pid) + if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds): + aliases, primary_name, _ = self.get_person_aliases(pid) + payload = { + "success": True, + "person_id": pid, + "person_name": primary_name, + "from_cache": True, + **(latest or {}), + } + if aliases and not payload.get("aliases"): + payload["aliases"] = aliases + return { + **self._apply_manual_override(pid, payload), + } + + aliases, primary_name, memory_traits = self.get_person_aliases(pid) + if not aliases and person_keyword: + aliases = [person_keyword.strip()] + primary_name = person_keyword.strip() + relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2)) + vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k)) + + evidence_ids = [ + str(item.get("hash", "")) + for item in (relation_edges + vector_evidence) + if str(item.get("hash", "")).strip() + ] + dedup_ids: List[str] = [] + seen = set() + for item in evidence_ids: + if item in seen: + continue + seen.add(item) + dedup_ids.append(item) + + profile_text = self._build_profile_text( + person_id=pid, + primary_name=primary_name, + aliases=aliases, + relation_edges=relation_edges, + vector_evidence=vector_evidence, + memory_traits=memory_traits, + ) + + expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None + snapshot = self.metadata_store.upsert_person_profile_snapshot( + person_id=pid, + profile_text=profile_text, + aliases=aliases, + relation_edges=relation_edges, + vector_evidence=vector_evidence, + evidence_ids=dedup_ids, + expires_at=expires_at, + source_note=source_note, + ) + payload = { + "success": True, + "person_id": pid, + "person_name": primary_name, + "from_cache": False, + **snapshot, + } + return { + **self._apply_manual_override(pid, payload), + } + + @staticmethod + def format_persona_profile_block(profile: Dict[str, Any]) -> 
str: + """格式化给 replyer 的注入块。""" + if not profile or not profile.get("success"): + return "" + text = str(profile.get("profile_text", "") or "").strip() + if not text: + return "" + return ( + "【人物画像-内部参考】\n" + f"{text}\n" + "仅供内部推理,不要向用户逐字复述。" + ) diff --git a/plugins/A_memorix/core/utils/plugin_id_policy.py b/plugins/A_memorix/core/utils/plugin_id_policy.py new file mode 100644 index 00000000..8e730e12 --- /dev/null +++ b/plugins/A_memorix/core/utils/plugin_id_policy.py @@ -0,0 +1,27 @@ +"""Plugin ID matching policy for A_Memorix.""" + +from __future__ import annotations + +from typing import Any + + +class PluginIdPolicy: + """Centralized plugin id normalization/matching policy.""" + + CANONICAL_ID = "a_memorix" + + @classmethod + def normalize(cls, plugin_id: Any) -> str: + if not isinstance(plugin_id, str): + return "" + return plugin_id.strip().lower() + + @classmethod + def is_target_plugin_id(cls, plugin_id: Any) -> bool: + normalized = cls.normalize(plugin_id) + if not normalized: + return False + if normalized == cls.CANONICAL_ID: + return True + return normalized.split(".")[-1] == cls.CANONICAL_ID + diff --git a/plugins/A_memorix/core/utils/quantization.py b/plugins/A_memorix/core/utils/quantization.py new file mode 100644 index 00000000..4e84f977 --- /dev/null +++ b/plugins/A_memorix/core/utils/quantization.py @@ -0,0 +1,344 @@ +""" +向量量化工具模块 + +提供向量量化与反量化功能,用于压缩存储空间。 +""" + +import numpy as np +from enum import Enum +from typing import Tuple, Union + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.Quantization") + + +class QuantizationType(Enum): + """量化类型枚举""" + FLOAT32 = "float32" # 无量化 + INT8 = "int8" # 标量量化(8位整数) + PQ = "pq" # 乘积量化(Product Quantization) + + +def quantize_vector( + vector: np.ndarray, + quant_type: QuantizationType = QuantizationType.INT8, +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """ + 量化向量 + + Args: + vector: 输入向量(float32) + quant_type: 量化类型 + + Returns: + 量化后的向量: + - INT8: int8向量 + - PQ: (编码向量, 聚类中心) 元组 + """ + if quant_type == QuantizationType.FLOAT32: + return vector.astype(np.float32) + + elif quant_type == QuantizationType.INT8: + return _scalar_quantize_int8(vector) + + elif quant_type == QuantizationType.PQ: + return _product_quantize(vector) + + else: + raise ValueError(f"不支持的量化类型: {quant_type}") + + +def dequantize_vector( + quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], + quant_type: QuantizationType = QuantizationType.INT8, + original_shape: Tuple[int, ...] 
= None,
+) -> np.ndarray:
+    """
+    反量化向量
+
+    Args:
+        quantized_vector: 量化后的向量
+        quant_type: 量化类型
+        original_shape: 原始向量形状(用于PQ,可选)
+
+    Returns:
+        反量化后的向量(float32)
+    """
+    if quant_type == QuantizationType.FLOAT32:
+        return quantized_vector.astype(np.float32)
+
+    elif quant_type == QuantizationType.INT8:
+        return _scalar_dequantize_int8(quantized_vector)
+
+    elif quant_type == QuantizationType.PQ:
+        if not isinstance(quantized_vector, tuple):
+            raise ValueError("PQ 反量化需要元组格式: (codes, centroids)")
+        return _product_dequantize(quantized_vector[0], quantized_vector[1])
+
+    else:
+        raise ValueError(f"不支持的量化类型: {quant_type}")
+
+
+def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray:
+    """
+    标量量化:float32 -> int8
+
+    按 min/max 将向量归一化到 [0, 255],再平移到 [-128, 127] 并转换为 int8。
+    min/max 归一化参数需要由调用方单独持久化,反量化时才能恢复原始数值范围。
+
+    Args:
+        vector: 输入向量
+
+    Returns:
+        量化后的 int8 向量
+    """
+    # 计算最小最大值
+    min_val = np.min(vector)
+    max_val = np.max(vector)
+
+    # 避免除零:常数向量直接量化为全零
+    if max_val == min_val:
+        return np.zeros_like(vector, dtype=np.int8)
+
+    # 归一化到 [0, 255]
+    normalized = (vector - min_val) / (max_val - min_val) * 255
+
+    # 平移到 [-128, 127] 后再转换为 int8,避免直接转换溢出
+    return np.round(normalized - 128.0).astype(np.int8)
+
+
+def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray:
+    """
+    标量反量化:int8 -> float32
+
+    本函数拿不到量化时的 min/max 参数,默认假设原始数值范围为 [-1, 1]
+    (归一化后的 embedding 通常满足该假设)。需要精确还原时,
+    请使用 dequantize_matrix 并显式传入 min_val/max_val。
+
+    Args:
+        quantized: 量化后的 int8 向量
+
+    Returns:
+        反量化后的 float32 向量
+    """
+    return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
+
+
+def quantize_matrix(
+    matrix: np.ndarray,
+    quant_type: QuantizationType = QuantizationType.INT8,
+) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+    """
+    量化矩阵(批量量化向量)
+
+    Args:
+        matrix: 输入矩阵(N x D,每行是一个向量)
+        quant_type: 量化类型
+
+    Returns:
+        量化后的矩阵;INT8 时调用方应同时保存 np.min(matrix)/np.max(matrix),
+        以便 dequantize_matrix 反量化
+    """
+    if quant_type == QuantizationType.FLOAT32:
+        return matrix.astype(np.float32)
+
+    elif quant_type == QuantizationType.INT8:
+        # 对整个矩阵进行全局归一化
+        min_val = np.min(matrix)
+        max_val = np.max(matrix)
+
+        if max_val == min_val:
+            return np.zeros_like(matrix, dtype=np.int8)
+
+        # 归一化到 [0, 255],再平移到 [-128, 127],避免直接转 int8 溢出
+        normalized = (matrix - min_val) / (max_val - min_val) * 255
+        return np.round(normalized - 128.0).astype(np.int8)
+
+    else:
+        raise ValueError(f"不支持的量化类型: {quant_type}")
+
+
+def dequantize_matrix(
+    quantized_matrix: np.ndarray,
+    quant_type: QuantizationType = QuantizationType.INT8,
+    min_val: float = None,
+    max_val: float = None,
+) -> np.ndarray:
+    """
+    反量化矩阵
+
+    Args:
+        quantized_matrix: 量化后的矩阵
+        quant_type: 量化类型
+        min_val: 归一化最小值(int8 反量化需要)
+        max_val: 归一化最大值(int8 反量化需要)
+
+    Returns:
+        反量化后的矩阵(float32)
+    """
+    if quant_type == QuantizationType.FLOAT32:
+        return quantized_matrix.astype(np.float32)
+
+    elif quant_type == QuantizationType.INT8:
+        # 先还原到 [0, 1]
+        normalized = (quantized_matrix.astype(np.float32) + 128.0) / 255.0
+        if min_val is None or max_val is None:
+            # 未提供归一化参数时,默认假设原始范围为 [-1, 1]
+            return normalized * 2.0 - 1.0
+        # 恢复到原始范围
+        return normalized * (max_val - min_val) + min_val
+
+    else:
+        raise ValueError(f"不支持的量化类型: {quant_type}")
+
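+# 用法示意(非插件正式 API,函数名与随机样例均为演示用假设):
+# INT8 矩阵量化的往返流程——调用方需要自行持久化全局 min/max 归一化参数,
+# 反量化时一并传回,才能把数值恢复到原始范围。
+def _int8_roundtrip_example() -> float:
+    """量化 -> 反量化一个随机矩阵,返回最大绝对重建误差(仅演示)。"""
+    rng = np.random.default_rng(42)
+    matrix = rng.standard_normal((4, 8)).astype(np.float32)
+    # 量化前记录全局 min/max,实际使用时应与 int8 数据一起保存
+    min_val, max_val = float(np.min(matrix)), float(np.max(matrix))
+    quantized = quantize_matrix(matrix, QuantizationType.INT8)
+    restored = dequantize_matrix(
+        quantized,
+        QuantizationType.INT8,
+        min_val=min_val,
+        max_val=max_val,
+    )
+    return float(np.max(np.abs(restored - matrix)))
+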
+ +def estimate_memory_reduction( + num_vectors: int, + dimension: int, + from_type: QuantizationType, + to_type: QuantizationType, +) -> Tuple[float, float]: + """ + 估算内存节省量 + + Args: + num_vectors: 向量数量 + dimension: 向量维度 + from_type: 原始量化类型 + to_type: 目标量化类型 + + Returns: + (原始大小MB, 量化后大小MB, 节省比例) + """ + # 计算每个向量占用的字节数 + bytes_per_element = { + QuantizationType.FLOAT32: 4, + QuantizationType.INT8: 1, + QuantizationType.PQ: 0.25, # 假设压缩到1/4 + } + + original_bytes = num_vectors * dimension * bytes_per_element[from_type] + quantized_bytes = num_vectors * dimension * bytes_per_element[to_type] + + original_mb = original_bytes / 1024 / 1024 + quantized_mb = quantized_bytes / 1024 / 1024 + reduction_ratio = (original_bytes - quantized_bytes) / original_bytes + + return original_mb, quantized_mb, reduction_ratio + + +def estimate_compression_stats( + num_vectors: int, + dimension: int, + quant_type: QuantizationType, +) -> dict: + """ + 估算压缩统计信息 + + Args: + num_vectors: 向量数量 + dimension: 向量维度 + quant_type: 量化类型 + + Returns: + 统计信息字典 + """ + original_mb, quantized_mb, ratio = estimate_memory_reduction( + num_vectors, dimension, QuantizationType.FLOAT32, quant_type + ) + + return { + "num_vectors": num_vectors, + "dimension": dimension, + "quantization_type": quant_type.value, + "original_size_mb": round(original_mb, 2), + "quantized_size_mb": round(quantized_mb, 2), + "saved_mb": round(original_mb - quantized_mb, 2), + "compression_ratio": round(ratio * 100, 2), + } + + +def _product_quantize( + vector: np.ndarray, m: int = 8, k: int = 256 +) -> Tuple[np.ndarray, np.ndarray]: + """ + 乘积量化 (PQ) 简化实现 + + Args: + vector: 输入向量 (D,) + m: 子空间数量 + k: 每个子空间的聚类中心数 + + Returns: + (编码后的向量, 聚类中心) + """ + d = vector.shape[0] + if d % m != 0: + raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除") + + ds = d // m # 子空间维度 + codes = np.zeros(m, dtype=np.uint8) + centroids = np.zeros((m, k, ds), dtype=np.float32) + + # 这里采用一种简化的 PQ:不进行 K-means 训练, + # 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练) + # 为了演示,我们直接将子空间切分为 k 份进行量化 + for i in range(m): + sub_vec = vector[i * ds : (i + 1) * ds] + # 简化:假定每个子空间的取值范围并划分 + # 实际 PQ 应使用 k-means 产生的 centroids + # 这里为演示创建一个随机 codebook 并找到最接近的核心 + sub_min, sub_max = np.min(sub_vec), np.max(sub_vec) + if sub_max == sub_min: + linspace = np.zeros(k) + else: + linspace = np.linspace(sub_min, sub_max, k) + + for j in range(k): + centroids[i, j, :] = linspace[j] + + # 编码:这里简化为取子空间均值找最接近的 centroid + sub_mean = np.mean(sub_vec) + code = np.argmin(np.abs(linspace - sub_mean)) + codes[i] = code + + return codes, centroids + + +def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray: + """ + PQ 反量化 + + Args: + codes: 编码向量 (M,) + centroids: 聚类中心 (M, K, DS) + + Returns: + 恢复后的向量 (D,) + """ + m, k, ds = centroids.shape + vector = np.zeros(m * ds, dtype=np.float32) + + for i in range(m): + code = codes[i] + vector[i * ds : (i + 1) * ds] = centroids[i, code, :] + + return vector diff --git a/plugins/A_memorix/core/utils/relation_query.py b/plugins/A_memorix/core/utils/relation_query.py new file mode 100644 index 00000000..ffde9cac --- /dev/null +++ b/plugins/A_memorix/core/utils/relation_query.py @@ -0,0 +1,121 @@ +"""关系查询规格解析工具。""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class RelationQuerySpec: + raw: str + is_structured: bool + subject: Optional[str] + predicate: Optional[str] + object: Optional[str] + error: Optional[str] = None + + +_NATURAL_LANGUAGE_PATTERN = re.compile( + 
r"(^\s*(what|who|which|how|why|when|where)\b|" + r"\?|?|" + r"\b(relation|related|between)\b|" + r"(什么关系|有哪些关系|之间|关联))", + re.IGNORECASE, +) + + +def _looks_like_natural_language(raw: str) -> bool: + text = str(raw or "").strip() + if not text: + return False + return _NATURAL_LANGUAGE_PATTERN.search(text) is not None + + +def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec: + raw = str(relation_spec or "").strip() + if not raw: + return RelationQuerySpec( + raw=raw, + is_structured=False, + subject=None, + predicate=None, + object=None, + error="empty", + ) + + if "|" in raw: + parts = [p.strip() for p in raw.split("|")] + if len(parts) < 2: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=None, + predicate=None, + object=None, + error="invalid_pipe_format", + ) + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0] or None, + predicate=parts[1] or None, + object=parts[2] if len(parts) > 2 and parts[2] else None, + ) + + if "->" in raw: + parts = [p.strip() for p in raw.split("->") if p.strip()] + if len(parts) >= 3: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0], + predicate=parts[1], + object=parts[2], + ) + if len(parts) == 2: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0], + predicate=None, + object=parts[1], + ) + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=None, + predicate=None, + object=None, + error="invalid_arrow_format", + ) + + if _looks_like_natural_language(raw): + return RelationQuerySpec( + raw=raw, + is_structured=False, + subject=None, + predicate=None, + object=None, + ) + + # 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。 + # 两词形式过于模糊,不再视为结构化关系查询。 + parts = raw.split() + if len(parts) == 3: + return RelationQuerySpec( + raw=raw, + is_structured=True, + subject=parts[0], + predicate=parts[1], + object=parts[2], + ) + + return RelationQuerySpec( + raw=raw, + is_structured=False, + subject=None, + predicate=None, + object=None, + ) diff --git a/plugins/A_memorix/core/utils/relation_write_service.py b/plugins/A_memorix/core/utils/relation_write_service.py new file mode 100644 index 00000000..b73e1260 --- /dev/null +++ b/plugins/A_memorix/core/utils/relation_write_service.py @@ -0,0 +1,164 @@ +""" +统一关系写入与关系向量化服务。 + +规则: +1. 元数据是主数据源,向量是从索引。 +2. 关系先写 metadata,再写向量。 +3. 
向量失败不回滚 metadata,依赖状态机与回填任务修复。 +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from src.common.logger import get_logger + + +logger = get_logger("A_Memorix.RelationWriteService") + + +@dataclass +class RelationWriteResult: + hash_value: str + vector_written: bool + vector_already_exists: bool + vector_state: str + + +class RelationWriteService: + """关系写入收口服务。""" + + ERROR_MAX_LEN = 500 + + def __init__( + self, + metadata_store: Any, + graph_store: Any, + vector_store: Any, + embedding_manager: Any, + ): + self.metadata_store = metadata_store + self.graph_store = graph_store + self.vector_store = vector_store + self.embedding_manager = embedding_manager + + @staticmethod + def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str: + s = str(subject or "").strip() + p = str(predicate or "").strip() + o = str(obj or "").strip() + # 双表达:兼容关键词检索与自然语言问句 + return f"{s} {p} {o}\n{s}和{o}的关系是{p}" + + async def ensure_relation_vector( + self, + hash_value: str, + subject: str, + predicate: str, + obj: str, + *, + max_error_len: int = ERROR_MAX_LEN, + ) -> RelationWriteResult: + """ + 为已有关系确保向量存在并更新状态。 + """ + if hash_value in self.vector_store: + self.metadata_store.set_relation_vector_state(hash_value, "ready") + return RelationWriteResult( + hash_value=hash_value, + vector_written=False, + vector_already_exists=True, + vector_state="ready", + ) + + self.metadata_store.set_relation_vector_state(hash_value, "pending") + try: + vector_text = self.build_relation_vector_text(subject, predicate, obj) + embedding = await self.embedding_manager.encode(vector_text) + self.vector_store.add( + vectors=embedding.reshape(1, -1), + ids=[hash_value], + ) + self.metadata_store.set_relation_vector_state(hash_value, "ready") + logger.info( + "metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s", + hash_value[:16], + ) + return RelationWriteResult( + hash_value=hash_value, + vector_written=True, + vector_already_exists=False, + vector_state="ready", + ) + except ValueError: + # 向量已存在冲突,按成功处理 + self.metadata_store.set_relation_vector_state(hash_value, "ready") + return RelationWriteResult( + hash_value=hash_value, + vector_written=False, + vector_already_exists=True, + vector_state="ready", + ) + except Exception as e: + err = str(e)[:max_error_len] + self.metadata_store.set_relation_vector_state( + hash_value, + "failed", + error=err, + bump_retry=True, + ) + logger.warning( + "metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s", + hash_value[:16], + err, + ) + return RelationWriteResult( + hash_value=hash_value, + vector_written=False, + vector_already_exists=False, + vector_state="failed", + ) + + async def upsert_relation_with_vector( + self, + subject: str, + predicate: str, + obj: str, + confidence: float = 1.0, + source_paragraph: str = "", + metadata: Optional[Dict[str, Any]] = None, + *, + write_vector: bool = True, + ) -> RelationWriteResult: + """ + 统一关系写入: + 1) 写 metadata relation + 2) 写 graph edge relation_hash + 3) 按需写 relation vector + """ + rel_hash = self.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + confidence=confidence, + source_paragraph=source_paragraph, + metadata=metadata or {}, + ) + self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) + + if not write_vector: + self.metadata_store.set_relation_vector_state(rel_hash, "none") + return RelationWriteResult( + 
hash_value=rel_hash, + vector_written=False, + vector_already_exists=False, + vector_state="none", + ) + + return await self.ensure_relation_vector( + hash_value=rel_hash, + subject=subject, + predicate=predicate, + obj=obj, + ) diff --git a/plugins/A_memorix/core/utils/runtime_self_check.py b/plugins/A_memorix/core/utils/runtime_self_check.py new file mode 100644 index 00000000..36a2cf7e --- /dev/null +++ b/plugins/A_memorix/core/utils/runtime_self_check.py @@ -0,0 +1,197 @@ +"""Runtime self-check helpers for A_Memorix.""" + +from __future__ import annotations + +import time +from typing import Any, Dict, Optional + +import numpy as np + +from src.common.logger import get_logger + +logger = get_logger("A_Memorix.RuntimeSelfCheck") + +_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check" + + +def _safe_int(value: Any, default: int = 0) -> int: + try: + return int(value) + except Exception: + return int(default) + + +def _get_config_value(config: Any, key: str, default: Any = None) -> Any: + getter = getattr(config, "get_config", None) + if callable(getter): + return getter(key, default) + + current: Any = config + for part in key.split("."): + if isinstance(current, dict) and part in current: + current = current[part] + else: + return default + return current + + +def _build_report( + *, + ok: bool, + code: str, + message: str, + configured_dimension: int, + vector_store_dimension: int, + detected_dimension: int, + encoded_dimension: int, + elapsed_ms: float, + sample_text: str, +) -> Dict[str, Any]: + return { + "ok": bool(ok), + "code": str(code or "").strip(), + "message": str(message or "").strip(), + "configured_dimension": int(configured_dimension), + "vector_store_dimension": int(vector_store_dimension), + "detected_dimension": int(detected_dimension), + "encoded_dimension": int(encoded_dimension), + "elapsed_ms": float(elapsed_ms), + "sample_text": str(sample_text or ""), + "checked_at": time.time(), + } + + +async def run_embedding_runtime_self_check( + *, + config: Any, + vector_store: Optional[Any], + embedding_manager: Optional[Any], + sample_text: str = _DEFAULT_SAMPLE_TEXT, +) -> Dict[str, Any]: + """Probe the real embedding path and compare dimensions with runtime storage.""" + configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0) + vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0) + + if vector_store is None or embedding_manager is None: + return _build_report( + ok=False, + code="runtime_components_missing", + message="vector_store 或 embedding_manager 未初始化", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=0, + encoded_dimension=0, + elapsed_ms=0.0, + sample_text=sample_text, + ) + + start = time.perf_counter() + detected_dimension = 0 + encoded_dimension = 0 + try: + detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0) + encoded = await embedding_manager.encode(sample_text) + if isinstance(encoded, np.ndarray): + encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1]) + else: + encoded_dimension = len(encoded) if encoded is not None else 0 + except Exception as exc: + elapsed_ms = (time.perf_counter() - start) * 1000.0 + logger.warning("embedding runtime self-check failed: %s", exc) + return _build_report( + ok=False, + code="embedding_probe_failed", + message=f"embedding probe failed: {exc}", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, 
+ detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + elapsed_ms = (time.perf_counter() - start) * 1000.0 + expected_dimension = vector_store_dimension or configured_dimension or detected_dimension + if expected_dimension <= 0: + return _build_report( + ok=False, + code="invalid_expected_dimension", + message="无法确定期望 embedding 维度", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + if encoded_dimension != expected_dimension: + msg = ( + "embedding 真实输出维度与当前向量存储不一致: " + f"expected={expected_dimension}, encoded={encoded_dimension}" + ) + logger.error(msg) + return _build_report( + ok=False, + code="embedding_dimension_mismatch", + message=msg, + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + return _build_report( + ok=True, + code="ok", + message="embedding runtime self-check passed", + configured_dimension=configured_dimension, + vector_store_dimension=vector_store_dimension, + detected_dimension=detected_dimension, + encoded_dimension=encoded_dimension, + elapsed_ms=elapsed_ms, + sample_text=sample_text, + ) + + +async def ensure_runtime_self_check( + plugin_or_config: Any, + *, + force: bool = False, + sample_text: str = _DEFAULT_SAMPLE_TEXT, +) -> Dict[str, Any]: + """Run or reuse cached runtime self-check report.""" + if plugin_or_config is None: + return _build_report( + ok=False, + code="missing_plugin_or_config", + message="plugin/config unavailable", + configured_dimension=0, + vector_store_dimension=0, + detected_dimension=0, + encoded_dimension=0, + elapsed_ms=0.0, + sample_text=sample_text, + ) + + cache = getattr(plugin_or_config, "_runtime_self_check_report", None) + if isinstance(cache, dict) and cache and not force: + return cache + + report = await run_embedding_runtime_self_check( + config=getattr(plugin_or_config, "config", plugin_or_config), + vector_store=getattr(plugin_or_config, "vector_store", None) + if not isinstance(plugin_or_config, dict) + else plugin_or_config.get("vector_store"), + embedding_manager=getattr(plugin_or_config, "embedding_manager", None) + if not isinstance(plugin_or_config, dict) + else plugin_or_config.get("embedding_manager"), + sample_text=sample_text, + ) + try: + setattr(plugin_or_config, "_runtime_self_check_report", report) + except Exception: + pass + return report diff --git a/plugins/A_memorix/core/utils/search_postprocess.py b/plugins/A_memorix/core/utils/search_postprocess.py new file mode 100644 index 00000000..52688e08 --- /dev/null +++ b/plugins/A_memorix/core/utils/search_postprocess.py @@ -0,0 +1,90 @@ +"""Post-processing helpers for unified search execution.""" + +from __future__ import annotations + +from typing import Any, List, Tuple + +from .path_fallback_service import find_paths_from_query, to_retrieval_results + + +def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]: + """Deduplicate results by hash/content while preserving at least one entry.""" + if not results: + return [], 0 + + unique_results: List[Any] = [] + seen_hashes = set() + seen_contents = set() + + for item in results: + content = str(getattr(item, "content", "") or "").strip() + if not 
content: + continue + + hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content)) + if hash_value in seen_hashes: + continue + + is_dup = False + for seen in seen_contents: + if content in seen or seen in content: + is_dup = True + break + if is_dup: + continue + + seen_hashes.add(hash_value) + seen_contents.add(content) + unique_results.append(item) + + if not unique_results: + unique_results.append(results[0]) + + removed_count = max(0, len(results) - len(unique_results)) + return unique_results, removed_count + + +def maybe_apply_smart_path_fallback( + *, + query: str, + results: List[Any], + graph_store: Any, + metadata_store: Any, + enabled: bool, + threshold: float, + max_depth: int = 3, + max_paths: int = 5, +) -> Tuple[List[Any], bool, int]: + """Append indirect relation paths when semantic results are weak.""" + if not enabled or not str(query or "").strip(): + return results, False, 0 + if graph_store is None or metadata_store is None: + return results, False, 0 + + max_score = 0.0 + if results: + try: + max_score = float(getattr(results[0], "score", 0.0) or 0.0) + except Exception: + max_score = 0.0 + + if max_score >= float(threshold): + return results, False, 0 + + paths = find_paths_from_query( + query=query, + graph_store=graph_store, + metadata_store=metadata_store, + max_depth=max_depth, + max_paths=max_paths, + ) + if not paths: + return results, False, 0 + + path_results = to_retrieval_results(paths) + if not path_results: + return results, False, 0 + + merged = list(path_results) + list(results) + return merged, True, len(path_results) + diff --git a/plugins/A_memorix/core/utils/time_parser.py b/plugins/A_memorix/core/utils/time_parser.py new file mode 100644 index 00000000..8e577974 --- /dev/null +++ b/plugins/A_memorix/core/utils/time_parser.py @@ -0,0 +1,170 @@ +""" +时间解析工具。 + +约束: +1. 查询参数(Action/Command/Tool)仅接受结构化绝对时间: + - YYYY/MM/DD + - YYYY/MM/DD HH:mm +2. 
入库时允许更宽松格式(含时间戳、YYYY-MM-DD 等)。 +""" + +from __future__ import annotations + +import re +from datetime import datetime +from typing import Any, Dict, Optional, Tuple + + +_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$") +_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$") +_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$") + +_INGEST_FORMATS = [ + "%Y/%m/%d %H:%M:%S", + "%Y/%m/%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%d %H:%M", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M", + "%Y/%m/%d", + "%Y-%m-%d", +] + +_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"} + + +def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float: + """解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。""" + text = str(value).strip() + if not text: + raise ValueError("时间不能为空") + + if _QUERY_DATE_RE.fullmatch(text): + dt = datetime.strptime(text, "%Y/%m/%d") + if is_end: + dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) + return dt.timestamp() + + if _QUERY_MINUTE_RE.fullmatch(text): + dt = datetime.strptime(text, "%Y/%m/%d %H:%M") + return dt.timestamp() + + raise ValueError( + f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm" + ) + + +def parse_query_time_range( + time_from: Optional[str], + time_to: Optional[str], +) -> Tuple[Optional[float], Optional[float]]: + """解析查询窗口并验证区间。""" + ts_from = ( + parse_query_datetime_to_timestamp(time_from, is_end=False) + if time_from + else None + ) + ts_to = ( + parse_query_datetime_to_timestamp(time_to, is_end=True) + if time_to + else None + ) + + if ts_from is not None and ts_to is not None and ts_from > ts_to: + raise ValueError("time_from 不能晚于 time_to") + + return ts_from, ts_to + + +def parse_ingest_datetime_to_timestamp( + value: Any, + is_end: bool = False, +) -> Optional[float]: + """解析入库时间,允许 timestamp/常见字符串格式。""" + if value is None: + return None + + if isinstance(value, (int, float)): + return float(value) + + text = str(value).strip() + if not text: + return None + + if _NUMERIC_RE.fullmatch(text): + return float(text) + + for fmt in _INGEST_FORMATS: + try: + dt = datetime.strptime(text, fmt) + if fmt in _INGEST_DATE_FORMATS and is_end: + dt = dt.replace(hour=23, minute=59, second=0, microsecond=0) + return dt.timestamp() + except ValueError: + continue + + raise ValueError(f"无法解析时间: {text}") + + +def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """归一化 time_meta 到存储层字段。""" + if not time_meta: + return {} + + normalized: Dict[str, Any] = {} + + event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time")) + event_start = parse_ingest_datetime_to_timestamp( + time_meta.get("event_time_start"), + is_end=False, + ) + event_end = parse_ingest_datetime_to_timestamp( + time_meta.get("event_time_end"), + is_end=True, + ) + + time_range = time_meta.get("time_range") + if ( + isinstance(time_range, (list, tuple)) + and len(time_range) == 2 + ): + if event_start is None: + event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False) + if event_end is None: + event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True) + + if event_start is not None and event_end is not None and event_start > event_end: + raise ValueError("event_time_start 不能晚于 event_time_end") + + if event_time is not None: + normalized["event_time"] = event_time + if event_start is not None: + normalized["event_time_start"] = event_start + if event_end is not None: + normalized["event_time_end"] = event_end + + granularity = time_meta.get("time_granularity") + if granularity: + 
normalized["time_granularity"] = str(granularity) + else: + raw_time_values = [ + time_meta.get("event_time"), + time_meta.get("event_time_start"), + time_meta.get("event_time_end"), + ] + has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None) + normalized["time_granularity"] = "minute" if has_minute else "day" + + confidence = time_meta.get("time_confidence") + if confidence is not None: + normalized["time_confidence"] = float(confidence) + + return normalized + + +def format_timestamp(ts: Optional[float]) -> Optional[str]: + """将 timestamp 格式化为 YYYY/MM/DD HH:mm。""" + if ts is None: + return None + return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M") + diff --git a/plugins/A_memorix/plugin.py b/plugins/A_memorix/plugin.py new file mode 100644 index 00000000..56df45b9 --- /dev/null +++ b/plugins/A_memorix/plugin.py @@ -0,0 +1,207 @@ +"""A_Memorix SDK plugin entry.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, List, Optional + +from maibot_sdk import MaiBotPlugin, Tool +from maibot_sdk.types import ToolParameterInfo, ToolParamType + +from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel + + +def _tool_param(name: str, param_type: ToolParamType, description: str, required: bool) -> ToolParameterInfo: + return ToolParameterInfo(name=name, param_type=param_type, description=description, required=required) + + +class AMemorixPlugin(MaiBotPlugin): + def __init__(self) -> None: + super().__init__() + self._plugin_root = Path(__file__).resolve().parent + self._plugin_config: Dict[str, Any] = {} + self._kernel: Optional[SDKMemoryKernel] = None + + def set_plugin_config(self, config: Dict[str, Any]) -> None: + self._plugin_config = config or {} + if self._kernel is not None: + self._kernel.close() + self._kernel = None + + async def on_load(self): + await self._get_kernel() + + async def on_unload(self): + if self._kernel is not None: + self._kernel.close() + self._kernel = None + + async def _get_kernel(self) -> SDKMemoryKernel: + if self._kernel is None: + self._kernel = SDKMemoryKernel(plugin_root=self._plugin_root, config=self._plugin_config) + await self._kernel.initialize() + return self._kernel + + @Tool( + "search_memory", + description="搜索长期记忆", + parameters=[ + _tool_param("query", ToolParamType.STRING, "查询文本", False), + _tool_param("limit", ToolParamType.INTEGER, "返回条数", False), + _tool_param("mode", ToolParamType.STRING, "search/time/hybrid/episode/aggregate", False), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), + _tool_param("person_id", ToolParamType.STRING, "人物 ID", False), + _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), + _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + ], + ) + async def handle_search_memory( + self, + query: str = "", + limit: int = 5, + mode: str = "hybrid", + chat_id: str = "", + person_id: str = "", + time_start: float | None = None, + time_end: float | None = None, + **kwargs, + ): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.search_memory( + KernelSearchRequest( + query=query, + limit=limit, + mode=mode, + chat_id=chat_id, + person_id=person_id, + time_start=time_start, + time_end=time_end, + ) + ) + + @Tool( + "ingest_summary", + description="写入聊天摘要到长期记忆", + parameters=[ + _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", True), + _tool_param("text", ToolParamType.STRING, 
"摘要文本", True), + _tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False), + _tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False), + ], + ) + async def handle_ingest_summary( + self, + external_id: str, + chat_id: str, + text: str, + participants: Optional[List[str]] = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.ingest_summary( + external_id=external_id, + chat_id=chat_id, + text=text, + participants=participants, + time_start=time_start, + time_end=time_end, + tags=tags, + metadata=metadata, + ) + + @Tool( + "ingest_text", + description="写入普通长期记忆文本", + parameters=[ + _tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True), + _tool_param("source_type", ToolParamType.STRING, "来源类型", True), + _tool_param("text", ToolParamType.STRING, "原始文本", True), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), + _tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False), + ], + ) + async def handle_ingest_text( + self, + external_id: str, + source_type: str, + text: str, + chat_id: str = "", + person_ids: Optional[List[str]] = None, + participants: Optional[List[str]] = None, + timestamp: float | None = None, + time_start: float | None = None, + time_end: float | None = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + **kwargs, + ): + relations = kwargs.get("relations") + entities = kwargs.get("entities") + kernel = await self._get_kernel() + return await kernel.ingest_text( + external_id=external_id, + source_type=source_type, + text=text, + chat_id=chat_id, + person_ids=person_ids, + participants=participants, + timestamp=timestamp, + time_start=time_start, + time_end=time_end, + tags=tags, + metadata=metadata, + entities=entities, + relations=relations, + ) + + @Tool( + "get_person_profile", + description="获取人物画像", + parameters=[ + _tool_param("person_id", ToolParamType.STRING, "人物 ID", True), + _tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False), + _tool_param("limit", ToolParamType.INTEGER, "证据条数", False), + ], + ) + async def handle_get_person_profile(self, person_id: str, chat_id: str = "", limit: int = 10, **kwargs): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.get_person_profile(person_id=person_id, chat_id=chat_id, limit=limit) + + @Tool( + "maintain_memory", + description="维护长期记忆关系状态", + parameters=[ + _tool_param("action", ToolParamType.STRING, "reinforce/protect/restore", True), + _tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", True), + _tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False), + ], + ) + async def handle_maintain_memory( + self, + action: str, + target: str, + hours: float | None = None, + reason: str = "", + **kwargs, + ): + _ = kwargs + kernel = await self._get_kernel() + return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason) + + @Tool("memory_stats", description="获取长期记忆统计", parameters=[]) + async def handle_memory_stats(self, **kwargs): + _ = kwargs + kernel = await self._get_kernel() + return kernel.memory_stats() + + +def create_plugin(): + return AMemorixPlugin() diff --git a/plugins/A_memorix/scripts/convert_lpmm.py b/plugins/A_memorix/scripts/convert_lpmm.py new file mode 100644 index 00000000..5ff284fb --- /dev/null +++ b/plugins/A_memorix/scripts/convert_lpmm.py @@ -0,0 +1,535 @@ +#!/usr/bin/env python3 
+""" +LPMM 到 A_memorix 存储转换器 + +功能: +1. 读取 LPMM parquet 文件 (paragraph.parquet, entity.parquet, relation.parquet) +2. 读取 LPMM 图文件 (graph.graphml 或 graph_structure.pkl) +3. 直接写入 A_memorix 二进制 VectorStore 和稀疏 GraphStore +4. 绕过 Embedding 生成以节省 Token +""" + +import sys +import os +import json +import argparse +import asyncio +import pickle +import logging +from pathlib import Path +from typing import Dict, Any, List, Tuple +import numpy as np +import tomlkit + +# 设置路径 +current_dir = Path(__file__).resolve().parent +plugin_root = current_dir.parent +project_root = plugin_root.parent.parent +sys.path.insert(0, str(project_root)) + +def _build_arg_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式") + parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)") + parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录") + parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)") + parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)") + parser.add_argument( + "--skip-relation-vector-rebuild", + action="store_true", + help="跳过按关系元数据重建关系向量(默认开启)", + ) + return parser + + +# --help/-h fast path: avoid heavy host/plugin bootstrap +if any(arg in {"-h", "--help"} for arg in sys.argv[1:]): + _build_arg_parser().print_help() + sys.exit(0) + +# 设置日志 +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logger = logging.getLogger("LPMM_Converter") + +try: + import networkx as nx + from scipy import sparse + import pyarrow.parquet as pq +except ImportError as e: + logger.error(f"缺少依赖: {e}") + logger.error("请安装: pip install pandas pyarrow networkx scipy") + sys.exit(1) + +try: + # 优先采取相对导入 (将插件根目录加入路径) + # 这样可以避免硬编码插件名称 (plugins.A_memorix) + if str(plugin_root) not in sys.path: + sys.path.insert(0, str(plugin_root)) + + from core.storage.vector_store import VectorStore + from core.storage.graph_store import GraphStore + from core.storage.metadata_store import MetadataStore + from core.storage import QuantizationType, SparseMatrixFormat + from core.embedding import create_embedding_api_adapter + from core.utils.relation_write_service import RelationWriteService + +except ImportError as e: + logger.error(f"无法导入 A_memorix 核心模块: {e}") + logger.error("请确保在正确的环境中运行,且已安装所有依赖。") + sys.exit(1) + + +class LPMMConverter: + def __init__( + self, + lpmm_data_dir: Path, + output_dir: Path, + dimension: int = 384, + batch_size: int = 1024, + rebuild_relation_vectors: bool = True, + ): + self.lpmm_dir = lpmm_data_dir + self.output_dir = output_dir + self.dimension = dimension + self.batch_size = max(1, int(batch_size)) + self.rebuild_relation_vectors = bool(rebuild_relation_vectors) + + self.vector_dir = output_dir / "vectors" + self.graph_dir = output_dir / "graph" + self.metadata_dir = output_dir / "metadata" + + self.vector_store = None + self.graph_store = None + self.metadata_store = None + self.embedding_manager = None + self.relation_write_service = None + # LPMM 原 ID -> A_memorix ID 映射(用于图重写) + self.id_mapping: Dict[str, str] = {} + + def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None: + """记录 ID 映射,兼容带/不带类型前缀两种格式。""" + if raw_id is None: + return + + raw = str(raw_id).strip() + if not raw: + return + + self.id_mapping[raw] = mapped_id + + prefix = f"{p_type}-" + if raw.startswith(prefix): + self.id_mapping[raw[len(prefix):]] = mapped_id + else: + 
self.id_mapping[prefix + raw] = mapped_id + + def _map_node_id(self, node: Any) -> str: + """将图节点 ID 映射到转换后的 A_memorix ID。""" + node_key = str(node) + return self.id_mapping.get(node_key, node_key) + + def initialize_stores(self): + """初始化空的 A_memorix 存储""" + logger.info(f"正在初始化存储于 {self.output_dir}...") + + # 初始化 VectorStore (A_memorix 默认使用 INT8 量化) + self.vector_store = VectorStore( + dimension=self.dimension, + quantization_type=QuantizationType.INT8, + data_dir=self.vector_dir + ) + self.vector_store.clear() # 清空旧数据 + + # 初始化 GraphStore (使用 CSR 格式) + self.graph_store = GraphStore( + matrix_format=SparseMatrixFormat.CSR, + data_dir=self.graph_dir + ) + self.graph_store.clear() + + # 初始化 MetadataStore + self.metadata_store = MetadataStore(data_dir=self.metadata_dir) + self.metadata_store.connect() + # 清空元数据表?理想情况下是的,但要小心。 + # 对于转换,我们假设是全新的开始或覆盖。 + # A_memorix 中的 MetadataStore 通常使用 SQLite。 + # 如果目录是新的,我们会依赖它创建新文件。 + if self.rebuild_relation_vectors: + self._init_relation_vector_service() + + def _load_plugin_config(self) -> Dict[str, Any]: + config_path = plugin_root / "config.toml" + if not config_path.exists(): + return {} + try: + with open(config_path, "r", encoding="utf-8") as f: + parsed = tomlkit.load(f) + return dict(parsed) if isinstance(parsed, dict) else {} + except Exception as e: + logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}") + return {} + + def _init_relation_vector_service(self) -> None: + if not self.rebuild_relation_vectors: + return + cfg = self._load_plugin_config() + emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {} + if not isinstance(emb_cfg, dict): + emb_cfg = {} + try: + self.embedding_manager = create_embedding_api_adapter( + batch_size=int(emb_cfg.get("batch_size", 32)), + max_concurrent=int(emb_cfg.get("max_concurrent", 5)), + default_dimension=int(emb_cfg.get("dimension", self.dimension)), + model_name=str(emb_cfg.get("model_name", "auto")), + retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {}, + ) + self.relation_write_service = RelationWriteService( + metadata_store=self.metadata_store, + graph_store=self.graph_store, + vector_store=self.vector_store, + embedding_manager=self.embedding_manager, + ) + except Exception as e: + self.embedding_manager = None + self.relation_write_service = None + logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}") + + async def _rebuild_relation_vectors(self) -> None: + if not self.rebuild_relation_vectors: + return + if self.relation_write_service is None: + logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。") + return + + rows = self.metadata_store.get_relations() + if not rows: + logger.info("未发现关系元数据,无需重建关系向量。") + return + + success = 0 + failed = 0 + skipped = 0 + for row in rows: + result = await self.relation_write_service.ensure_relation_vector( + hash_value=str(row["hash"]), + subject=str(row.get("subject", "")), + predicate=str(row.get("predicate", "")), + obj=str(row.get("object", "")), + ) + if result.vector_state == "ready": + if result.vector_written: + success += 1 + else: + skipped += 1 + else: + failed += 1 + + logger.info( + "关系向量重建完成: total=%s success=%s skipped=%s failed=%s", + len(rows), + success, + skipped, + failed, + ) + + @staticmethod + def _parse_relation_text(text: str) -> Tuple[str, str, str]: + raw = str(text or "").strip() + if not raw: + return "", "", "" + if "|" in raw: + parts = [p.strip() for p in raw.split("|") if p.strip()] + if len(parts) >= 3: + return parts[0], parts[1], parts[2] + if "->" in raw: + parts = [p.strip() for p in 
raw.split("->") if p.strip()] + if len(parts) >= 3: + return parts[0], parts[1], parts[2] + pieces = raw.split() + if len(pieces) >= 3: + return pieces[0], pieces[1], " ".join(pieces[2:]) + return "", "", "" + + def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int: + if not relation_path.exists(): + return 0 + + try: + parquet_file = pq.ParquetFile(relation_path) + except Exception as e: + logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}") + return 0 + + cols = set(parquet_file.schema_arrow.names) + has_triple_cols = {"subject", "predicate", "object"}.issubset(cols) + content_col = "str" if "str" in cols else ("content" if "content" in cols else "") + + imported_hashes = set() + imported = 0 + for record_batch in parquet_file.iter_batches(batch_size=self.batch_size): + df_batch = record_batch.to_pandas() + for _, row in df_batch.iterrows(): + subject = "" + predicate = "" + obj = "" + if has_triple_cols: + subject = str(row.get("subject", "") or "").strip() + predicate = str(row.get("predicate", "") or "").strip() + obj = str(row.get("object", "") or "").strip() + elif content_col: + subject, predicate, obj = self._parse_relation_text(row.get(content_col, "")) + + if not (subject and predicate and obj): + continue + + rel_hash = self.metadata_store.add_relation( + subject=subject, + predicate=predicate, + obj=obj, + source_paragraph=None, + ) + if rel_hash in imported_hashes: + continue + imported_hashes.add(rel_hash) + self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash]) + try: + self.metadata_store.set_relation_vector_state(rel_hash, "none") + except Exception: + pass + imported += 1 + + return imported + + def convert_vectors(self): + """将 Parquet 向量转换为 VectorStore""" + # LPMM 默认文件名 + parquet_files = { + "paragraph": self.lpmm_dir / "paragraph.parquet", + "entity": self.lpmm_dir / "entity.parquet", + "relation": self.lpmm_dir / "relation.parquet" + } + + total_vectors = 0 + + for p_type, p_path in parquet_files.items(): + # 关系向量在当前脚本中无法保证与 MetadataStore 的关系记录一一对应, + # 直接导入会污染召回结果(命中后无法反查 relation 元数据)。 + if p_type == "relation": + relation_count = self._import_relation_metadata_from_parquet(p_path) + logger.warning( + "跳过 relation.parquet 向量导入(保持一致性);已导入关系元数据: %s", + relation_count, + ) + continue + + if not p_path.exists(): + logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。") + continue + + logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...") + try: + parquet_file = pq.ParquetFile(p_path) + total_rows = parquet_file.metadata.num_rows + if total_rows == 0: + logger.info(f"{p_path} 为空,跳过。") + continue + + # LPMM Schema: 'hash', 'embedding', 'str' + cols = parquet_file.schema_arrow.names + # 兼容性检查 + content_col = 'str' if 'str' in cols else 'content' + emb_col = 'embedding' + hash_col = 'hash' + + if content_col not in cols or emb_col not in cols: + logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}") + continue + + batch_columns = [content_col, emb_col] + if hash_col in cols: + batch_columns.append(hash_col) + + processed_rows = 0 + added_for_type = 0 + batch_idx = 0 + + for record_batch in parquet_file.iter_batches( + batch_size=self.batch_size, + columns=batch_columns, + ): + batch_idx += 1 + df_batch = record_batch.to_pandas() + + embeddings_list = [] + ids_list = [] + + # 同时处理元数据映射 + for _, row in df_batch.iterrows(): + processed_rows += 1 + content = row[content_col] + emb = row[emb_col] + + if content is None or (isinstance(content, float) and np.isnan(content)): + continue + content = str(content).strip() + if 
not content: + continue + + if emb is None or len(emb) == 0: + continue + + # 先写 MetadataStore,并使用其返回的真实 hash 作为向量 ID + # 保证检索返回 ID 可以直接反查元数据。 + store_id = None + if p_type == "paragraph": + store_id = self.metadata_store.add_paragraph( + content=content, + source="lpmm_import", + knowledge_type="factual", + ) + elif p_type == "entity": + store_id = self.metadata_store.add_entity(name=content) + else: + continue + + raw_hash = row[hash_col] if hash_col in df_batch.columns else None + if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)): + self._register_id_mapping(raw_hash, store_id, p_type) + + # 确保 embedding 是 numpy 数组 + emb_np = np.array(emb, dtype=np.float32) + if emb_np.shape[0] != self.dimension: + logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}") + continue + + embeddings_list.append(emb_np) + ids_list.append(store_id) + + if embeddings_list: + # 分批添加到向量存储 + vectors_np = np.stack(embeddings_list) + count = self.vector_store.add(vectors_np, ids_list) + added_for_type += count + total_vectors += count + + if batch_idx == 1 or batch_idx % 10 == 0: + logger.info( + f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}" + ) + + logger.info( + f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}" + ) + + except Exception as e: + logger.error(f"处理 {p_path} 时出错: {e}") + + # 提交向量存储 + self.vector_store.save() + logger.info(f"向量转换完成。总向量数: {total_vectors}") + + def convert_graph(self): + """将 LPMM 图转换为 GraphStore""" + # LPMM 默认文件名是 rag-graph.graphml + graph_files = [ + self.lpmm_dir / "rag-graph.graphml", + self.lpmm_dir / "graph.graphml", + self.lpmm_dir / "graph_structure.pkl" + ] + + nx_graph = None + + for g_path in graph_files: + if g_path.exists(): + logger.info(f"发现图文件: {g_path}") + try: + if g_path.suffix == ".graphml": + nx_graph = nx.read_graphml(g_path) + elif g_path.suffix == ".pkl": + with open(g_path, "rb") as f: + data = pickle.load(f) + # LPMM 可能会将图存储在包装类中 + if hasattr(data, "graph") and isinstance(data.graph, nx.Graph): + nx_graph = data.graph + elif isinstance(data, nx.Graph): + nx_graph = data + break + except Exception as e: + logger.error(f"加载 {g_path} 失败: {e}") + + if nx_graph is None: + logger.warning("未找到有效的图文件。跳过图转换。") + return + + logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。") + + # 1. 添加节点 + # LPMM 节点通常是哈希或带前缀的字符串。 + # 我们需要将它们映射到 A_memorix 格式。 + # 如果 LPMM 使用 "entity-HASH",则与 A_memorix 匹配。 + + nodes_to_add = [] + node_attrs = {} + + for node, attrs in nx_graph.nodes(data=True): + # 假设 LPMM 使用一致的命名 "entity-..." 或 "paragraph-..." + mapped_node = self._map_node_id(node) + nodes_to_add.append(mapped_node) + if attrs: + node_attrs[mapped_node] = attrs + + self.graph_store.add_nodes(nodes_to_add, node_attrs) + + # 2. 
添加边 + edges_to_add = [] + weights = [] + + for u, v, data in nx_graph.edges(data=True): + weight = data.get("weight", 1.0) + edges_to_add.append((self._map_node_id(u), self._map_node_id(v))) + weights.append(float(weight)) + + # 如果可能,将关系同步到 MetadataStore + # 但图的边并不总是包含关系谓词 + # 如果 LPMM 边数据有 'predicate',我们可以添加到元数据 + # 通常 LPMM 边是加权和,谓词信息可能在简单图中丢失 + + if edges_to_add: + self.graph_store.add_edges(edges_to_add, weights) + + self.graph_store.save() + logger.info("图转换完成。") + + def run(self): + self.initialize_stores() + self.convert_vectors() + self.convert_graph() + asyncio.run(self._rebuild_relation_vectors()) + self.vector_store.save() + self.graph_store.save() + self.metadata_store.close() + logger.info("所有转换成功完成。") + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + input_path = Path(args.input) + output_path = Path(args.output) + + if not input_path.exists(): + logger.error(f"输入目录不存在: {input_path}") + sys.exit(1) + + converter = LPMMConverter( + input_path, + output_path, + dimension=args.dim, + batch_size=args.batch_size, + rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild), + ) + converter.run() + +if __name__ == "__main__": + main() diff --git a/plugins/A_memorix/scripts/migrate_chat_history.py b/plugins/A_memorix/scripts/migrate_chat_history.py new file mode 100644 index 00000000..0fb0bfe1 --- /dev/null +++ b/plugins/A_memorix/scripts/migrate_chat_history.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import sqlite3 +import sys +from datetime import datetime +from pathlib import Path +from typing import Any, Dict + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" + +if str(WORKSPACE_ROOT) not in sys.path: + sys.path.insert(0, str(WORKSPACE_ROOT)) +if str(MAIBOT_ROOT) not in sys.path: + sys.path.insert(0, str(MAIBOT_ROOT)) + +from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="迁移 MaiBot chat_history 到 A_Memorix") + parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") + parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") + parser.add_argument("--limit", type=int, default=0, help="限制迁移条数,0 表示全部") + parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") + return parser.parse_args() + + +def _to_timestamp(value: Any) -> float | None: + if value is None: + return None + if isinstance(value, (int, float)): + return float(value) + text = str(value).strip() + if not text: + return None + try: + return datetime.fromisoformat(text).timestamp() + except ValueError: + return None + + +async def _main() -> int: + args = _parse_args() + db_path = Path(args.db_path).resolve() + if not db_path.exists(): + print(f"数据库不存在: {db_path}") + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + sql = """ + SELECT id, session_id, start_timestamp, end_timestamp, participants, theme, keywords, summary + FROM chat_history + ORDER BY id ASC + """ + if int(args.limit or 0) > 0: + sql += " LIMIT ?" 
+ rows = conn.execute(sql, (int(args.limit),)).fetchall() + else: + rows = conn.execute(sql).fetchall() + conn.close() + + print(f"chat_history 待处理: {len(rows)}") + if args.dry_run: + for row in rows[:5]: + print(f"- id={row['id']} session={row['session_id']} theme={row['theme']}") + return 0 + + kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) + await kernel.initialize() + migrated = 0 + skipped = 0 + for row in rows: + participants = json.loads(row["participants"]) if row["participants"] else [] + keywords = json.loads(row["keywords"]) if row["keywords"] else [] + theme = str(row["theme"] or "").strip() + summary = str(row["summary"] or "").strip() + text = f"主题:{theme}\n概括:{summary}".strip() + result: Dict[str, Any] = await kernel.ingest_summary( + external_id=f"chat_history:{row['id']}", + chat_id=str(row["session_id"] or ""), + text=text, + participants=participants, + time_start=_to_timestamp(row["start_timestamp"]), + time_end=_to_timestamp(row["end_timestamp"]), + tags=keywords, + metadata={"theme": theme, "source_row_id": int(row["id"])}, + ) + if result.get("stored_ids"): + migrated += 1 + else: + skipped += 1 + + print(f"迁移完成: migrated={migrated} skipped={skipped}") + print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) + kernel.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(_main())) diff --git a/plugins/A_memorix/scripts/migrate_person_memory_points.py b/plugins/A_memorix/scripts/migrate_person_memory_points.py new file mode 100644 index 00000000..a03a8914 --- /dev/null +++ b/plugins/A_memorix/scripts/migrate_person_memory_points.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import asyncio +import json +import sqlite3 +import sys +from pathlib import Path +from typing import Any, Dict, List + + +CURRENT_DIR = Path(__file__).resolve().parent +PLUGIN_ROOT = CURRENT_DIR.parent +WORKSPACE_ROOT = PLUGIN_ROOT.parent +MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot" +DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db" + +if str(WORKSPACE_ROOT) not in sys.path: + sys.path.insert(0, str(WORKSPACE_ROOT)) +if str(MAIBOT_ROOT) not in sys.path: + sys.path.insert(0, str(MAIBOT_ROOT)) + +from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402 + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="迁移 MaiBot person_info.memory_points 到 A_Memorix") + parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径") + parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录") + parser.add_argument("--limit", type=int, default=0, help="限制迁移人数,0 表示全部") + parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入") + return parser.parse_args() + + +def _parse_memory_points(raw_value: Any) -> List[Dict[str, Any]]: + try: + values = json.loads(raw_value) if raw_value else [] + except Exception: + values = [] + items: List[Dict[str, Any]] = [] + for index, item in enumerate(values): + text = str(item or "").strip() + if not text: + continue + parts = text.split(":") + if len(parts) >= 3: + category = parts[0].strip() + content = ":".join(parts[1:-1]).strip() + weight = parts[-1].strip() + else: + category = "其他" + content = text + weight = "1.0" + if content: + items.append({"index": index, "category": category or "其他", "content": content, "weight": weight or "1.0"}) + return items + + +async def _main() -> int: + args = _parse_args() + db_path = 
Path(args.db_path).resolve() + if not db_path.exists(): + print(f"数据库不存在: {db_path}") + return 1 + + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + sql = """ + SELECT person_id, person_name, user_nickname, memory_points + FROM person_info + WHERE memory_points IS NOT NULL AND memory_points != '' + ORDER BY id ASC + """ + if int(args.limit or 0) > 0: + sql += " LIMIT ?" + rows = conn.execute(sql, (int(args.limit),)).fetchall() + else: + rows = conn.execute(sql).fetchall() + conn.close() + + preview_total = sum(len(_parse_memory_points(row["memory_points"])) for row in rows) + print(f"person_info 待迁移人物: {len(rows)} 记忆点: {preview_total}") + if args.dry_run: + for row in rows[:5]: + print(f"- person_id={row['person_id']} person_name={row['person_name'] or row['user_nickname']}") + return 0 + + kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}}) + await kernel.initialize() + migrated = 0 + skipped = 0 + for row in rows: + person_id = str(row["person_id"] or "").strip() + if not person_id: + continue + display_name = str(row["person_name"] or row["user_nickname"] or "").strip() + for item in _parse_memory_points(row["memory_points"]): + result: Dict[str, Any] = await kernel.ingest_text( + external_id=f"person_memory:{person_id}:{item['index']}", + source_type="person_fact", + text=f"[{item['category']}] {item['content']}", + person_ids=[person_id], + tags=[item["category"]], + entities=[person_id, display_name] if display_name else [person_id], + metadata={"category": item["category"], "weight": item["weight"], "display_name": display_name}, + ) + if result.get("stored_ids"): + migrated += 1 + else: + skipped += 1 + + print(f"迁移完成: migrated={migrated} skipped={skipped}") + print(json.dumps(kernel.memory_stats(), ensure_ascii=False)) + kernel.close() + return 0 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(_main()))
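+
+
+# 用法示意(样例数据为假设值,仅演示 memory_points 的 "类别:内容:权重" 解析约定,
+# 不属于迁移流程本身):
+def _demo_parse_memory_points() -> List[Dict[str, Any]]:
+    sample = json.dumps(["喜好:喜欢喝咖啡:0.8", "只有内容没有权重"], ensure_ascii=False)
+    items = _parse_memory_points(sample)
+    # 期望:第一条拆成 类别=喜好 / 内容=喜欢喝咖啡 / 权重=0.8;
+    # 第二条缺少分隔符,落入默认类别 "其他",权重 "1.0"。
+    return items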