feat:新增 A_Memorix 记忆插件

引入 A_Memorix 插件(v2.0.0)——一个轻量级的长期记忆提供器。新增插件清单(manifest)和入口(AMemorixPlugin),并提供完整的核心能力:嵌入(基于哈希的 EmbeddingAPIAdapter、EmbeddingManager、预设)、检索(双路径检索器、PageRank、图关系召回、BM25 稀疏索引、阈值与融合配置)、存储与元数据层,以及大量实用工具和迁移/转换脚本。同时更新 .gitignore 以允许 /plugins/A_memorix。该变更为在宿主应用中实现统一的记忆摄取、检索、分析与维护奠定了基础。
This commit is contained in:
DawnARC
2026-03-18 21:33:15 +08:00
parent a5a6d2cb26
commit 999e7246e2
48 changed files with 17070 additions and 0 deletions

View File

@@ -0,0 +1,84 @@
"""核心模块 - 存储、嵌入、检索引擎"""
# 存储模块(已实现)
from .storage import (
VectorStore,
GraphStore,
MetadataStore,
ImportStrategy,
KnowledgeType,
parse_import_strategy,
resolve_stored_knowledge_type,
detect_knowledge_type,
select_import_strategy,
should_extract_relations,
get_type_display_name,
)
# 嵌入模块(使用主程序 API
from .embedding import (
EmbeddingAPIAdapter,
create_embedding_api_adapter,
)
# 检索模块(已实现)
from .retrieval import (
DualPathRetriever,
RetrievalStrategy,
RetrievalResult,
DualPathRetrieverConfig,
TemporalQueryOptions,
FusionConfig,
GraphRelationRecallConfig,
RelationIntentConfig,
PersonalizedPageRank,
PageRankConfig,
create_ppr_from_graph,
DynamicThresholdFilter,
ThresholdMethod,
ThresholdConfig,
SparseBM25Index,
SparseBM25Config,
)
from .utils import (
RelationWriteService,
RelationWriteResult,
)
__all__ = [
# Storage
"VectorStore",
"GraphStore",
"MetadataStore",
"ImportStrategy",
"KnowledgeType",
"parse_import_strategy",
"resolve_stored_knowledge_type",
"detect_knowledge_type",
"select_import_strategy",
"should_extract_relations",
"get_type_display_name",
# Embedding
"EmbeddingAPIAdapter",
"create_embedding_api_adapter",
# Retrieval
"DualPathRetriever",
"RetrievalStrategy",
"RetrievalResult",
"DualPathRetrieverConfig",
"TemporalQueryOptions",
"FusionConfig",
"GraphRelationRecallConfig",
"RelationIntentConfig",
"PersonalizedPageRank",
"PageRankConfig",
"create_ppr_from_graph",
"DynamicThresholdFilter",
"ThresholdMethod",
"ThresholdConfig",
"SparseBM25Index",
"SparseBM25Config",
"RelationWriteService",
"RelationWriteResult",
]

View File

@@ -0,0 +1,18 @@
"""嵌入模块 - 向量生成与量化"""
# 新的 API 适配器(主程序嵌入 API
from .api_adapter import (
EmbeddingAPIAdapter,
create_embedding_api_adapter,
)
from ..utils.quantization import QuantizationType
__all__ = [
# 新的 API 适配器(推荐使用)
"EmbeddingAPIAdapter",
"create_embedding_api_adapter",
# 量化
"QuantizationType",
]

View File

@@ -0,0 +1,174 @@
"""
Hash-based embedding adapter used by the SDK runtime.
The plugin runtime cannot import MaiBot host embedding internals from ``src.chat``
or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in
Runner by generating deterministic dense vectors locally.
"""
from __future__ import annotations
import hashlib
import re
import time
from typing import List, Optional, Union
import numpy as np
from src.common.logger import get_logger
logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}")
class EmbeddingAPIAdapter:
"""Deterministic local embedding adapter."""
def __init__(
self,
batch_size: int = 32,
max_concurrent: int = 5,
default_dimension: int = 256,
enable_cache: bool = False,
model_name: str = "hash-v1",
retry_config: Optional[dict] = None,
) -> None:
self.batch_size = max(1, int(batch_size))
self.max_concurrent = max(1, int(max_concurrent))
self.default_dimension = max(32, int(default_dimension))
self.enable_cache = bool(enable_cache)
self.model_name = str(model_name or "hash-v1")
self.retry_config = retry_config or {}
self._dimension: Optional[int] = None
self._dimension_detected = False
self._total_encoded = 0
self._total_errors = 0
self._total_time = 0.0
logger.info(
"EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s",
self.model_name,
self.batch_size,
self.default_dimension,
)
async def _detect_dimension(self) -> int:
if self._dimension_detected and self._dimension is not None:
return self._dimension
self._dimension = self.default_dimension
self._dimension_detected = True
return self._dimension
@staticmethod
def _tokenize(text: str) -> List[str]:
clean = str(text or "").strip().lower()
if not clean:
return []
return _TOKEN_PATTERN.findall(clean)
@staticmethod
def _feature_weight(token: str) -> float:
digest = hashlib.sha256(token.encode("utf-8")).digest()
return 1.0 + (digest[10] / 255.0) * 0.5
def _encode_single(self, text: str, dimension: int) -> np.ndarray:
vector = np.zeros(dimension, dtype=np.float32)
content = str(text or "").strip()
tokens = self._tokenize(content)
if not tokens and content:
tokens = [content.lower()]
if not tokens:
vector[0] = 1.0
return vector
for token in tokens:
digest = hashlib.sha256(token.encode("utf-8")).digest()
bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension
sign = 1.0 if digest[8] % 2 == 0 else -1.0
vector[bucket] += sign * self._feature_weight(token)
second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension
if second_bucket != bucket:
vector[second_bucket] += (sign * 0.35)
norm = float(np.linalg.norm(vector))
if norm > 1e-8:
vector /= norm
else:
vector[0] = 1.0
return vector
async def encode(
self,
texts: Union[str, List[str]],
batch_size: Optional[int] = None,
show_progress: bool = False,
normalize: bool = True,
dimensions: Optional[int] = None,
) -> np.ndarray:
_ = batch_size
_ = show_progress
_ = normalize
started_at = time.time()
target_dimension = max(32, int(dimensions or await self._detect_dimension()))
if isinstance(texts, str):
single_input = True
normalized_texts = [texts]
else:
single_input = False
normalized_texts = list(texts or [])
if not normalized_texts:
empty = np.zeros((0, target_dimension), dtype=np.float32)
return empty[0] if single_input else empty
try:
matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts])
self._total_encoded += len(normalized_texts)
self._total_time += time.time() - started_at
except Exception:
self._total_errors += 1
raise
return matrix[0] if single_input else matrix
def get_statistics(self) -> dict:
avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0
return {
"model_name": self.model_name,
"dimension": self._dimension or self.default_dimension,
"total_encoded": self._total_encoded,
"total_errors": self._total_errors,
"total_time": self._total_time,
"avg_time_per_text": avg_time,
}
def __repr__(self) -> str:
return (
f"EmbeddingAPIAdapter(model_name={self.model_name}, "
f"dimension={self._dimension or self.default_dimension}, "
f"total_encoded={self._total_encoded})"
)
def create_embedding_api_adapter(
batch_size: int = 32,
max_concurrent: int = 5,
default_dimension: int = 256,
enable_cache: bool = False,
model_name: str = "hash-v1",
retry_config: Optional[dict] = None,
) -> EmbeddingAPIAdapter:
return EmbeddingAPIAdapter(
batch_size=batch_size,
max_concurrent=max_concurrent,
default_dimension=default_dimension,
enable_cache=enable_cache,
model_name=model_name,
retry_config=retry_config,
)

View File

@@ -0,0 +1,510 @@
"""
嵌入管理器
负责嵌入模型的加载、缓存和批量生成。
"""
import hashlib
import pickle
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional, Union, List, Dict, Any, Tuple
import numpy as np
try:
from sentence_transformers import SentenceTransformer
HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
HAS_SENTENCE_TRANSFORMERS = False
from src.common.logger import get_logger
from .presets import (
EmbeddingModelConfig,
get_custom_config,
validate_config_compatibility,
are_models_compatible,
)
from ..utils.quantization import QuantizationType
logger = get_logger("A_Memorix.EmbeddingManager")
class EmbeddingManager:
"""
嵌入管理器
功能:
- 模型加载与缓存
- 批量生成嵌入
- 多线程/多进程支持
- 模型一致性检查
- 智能分批
参数:
config: 模型配置
cache_dir: 缓存目录
enable_cache: 是否启用缓存
num_workers: 工作线程数
"""
def __init__(
self,
config: EmbeddingModelConfig,
cache_dir: Optional[Union[str, Path]] = None,
enable_cache: bool = True,
num_workers: int = 1,
):
"""
初始化嵌入管理器
Args:
config: 模型配置
cache_dir: 缓存目录
enable_cache: 是否启用缓存
num_workers: 工作线程数
"""
if not HAS_SENTENCE_TRANSFORMERS:
raise ImportError(
"sentence-transformers 未安装,请安装: "
"pip install sentence-transformers"
)
self.config = config
self.cache_dir = Path(cache_dir) if cache_dir else None
self.enable_cache = enable_cache
self.num_workers = max(1, num_workers)
# 模型实例
self._model: Optional[SentenceTransformer] = None
self._model_lock = threading.Lock()
# 缓存
self._embedding_cache: Dict[str, np.ndarray] = {}
self._cache_lock = threading.Lock()
# 统计
self._total_encoded = 0
self._cache_hits = 0
self._cache_misses = 0
logger.info(
f"EmbeddingManager 初始化: model={config.model_name}, "
f"dim={config.dimension}, workers={num_workers}"
)
def load_model(self) -> None:
"""加载模型(懒加载)"""
if self._model is not None:
return
with self._model_lock:
# 双重检查
if self._model is not None:
return
logger.info(f"正在加载模型: {self.config.model_name}")
try:
# 构建模型参数
model_kwargs = {}
if self.config.cache_dir:
model_kwargs["cache_folder"] = self.config.cache_dir
# 加载模型
self._model = SentenceTransformer(
self.config.model_path,
**model_kwargs,
)
logger.info(f"模型加载成功: {self.config.model_name}")
except Exception as e:
logger.error(f"模型加载失败: {e}")
raise
def encode(
self,
texts: Union[str, List[str]],
batch_size: Optional[int] = None,
show_progress: bool = False,
normalize: bool = True,
) -> np.ndarray:
"""
生成文本嵌入
Args:
texts: 文本或文本列表
batch_size: 批次大小(默认使用配置值)
show_progress: 是否显示进度条
normalize: 是否归一化
Returns:
嵌入向量 (N x D)
"""
# 确保模型已加载
self.load_model()
# 标准化输入
if isinstance(texts, str):
texts = [texts]
single_input = True
else:
single_input = False
if not texts:
return np.zeros((0, self.config.dimension), dtype=np.float32)
# 使用配置的批次大小
if batch_size is None:
batch_size = self.config.batch_size
# 生成嵌入
try:
embeddings = self._model.encode(
texts,
batch_size=batch_size,
show_progress_bar=show_progress,
normalize_embeddings=normalize and self.config.normalization,
convert_to_numpy=True,
)
# 确保是2D数组
if embeddings.ndim == 1:
embeddings = embeddings.reshape(1, -1)
self._total_encoded += len(texts)
# 如果是单个输入返回1D数组
if single_input:
return embeddings[0]
return embeddings
except Exception as e:
logger.error(f"生成嵌入失败: {e}")
raise
def encode_batch(
self,
texts: List[str],
batch_size: Optional[int] = None,
num_workers: Optional[int] = None,
show_progress: bool = False,
) -> np.ndarray:
"""
批量生成嵌入(多线程优化)
Args:
texts: 文本列表
batch_size: 批次大小
num_workers: 工作线程数(默认使用初始化时的值)
show_progress: 是否显示进度条
Returns:
嵌入向量 (N x D)
"""
if not texts:
return np.zeros((0, self.config.dimension), dtype=np.float32)
# 单线程模式
num_workers = num_workers if num_workers is not None else self.num_workers
if num_workers == 1:
return self.encode(texts, batch_size=batch_size, show_progress=show_progress)
# 多线程模式
logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入")
# 分批
batch_size = batch_size or self.config.batch_size
batches = [
texts[i:i + batch_size]
for i in range(0, len(texts), batch_size)
]
# 多线程生成
all_embeddings = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# 提交任务
future_to_batch = {
executor.submit(
self.encode,
batch,
batch_size,
False, # 不显示进度条(多线程时会混乱)
): i
for i, batch in enumerate(batches)
}
# 收集结果
for future in as_completed(future_to_batch):
batch_idx = future_to_batch[future]
try:
embeddings = future.result()
all_embeddings.append((batch_idx, embeddings))
except Exception as e:
logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}")
raise
# 按顺序合并
all_embeddings.sort(key=lambda x: x[0])
final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0)
return final_embeddings
def encode_with_cache(
self,
texts: List[str],
batch_size: Optional[int] = None,
show_progress: bool = False,
) -> np.ndarray:
"""
生成嵌入(带缓存)
Args:
texts: 文本列表
batch_size: 批次大小
show_progress: 是否显示进度条
Returns:
嵌入向量 (N x D)
"""
if not self.enable_cache:
return self.encode(texts, batch_size, show_progress)
# 分离缓存命中和未命中的文本
cached_embeddings = []
uncached_texts = []
uncached_indices = []
for i, text in enumerate(texts):
cache_key = self._get_cache_key(text)
with self._cache_lock:
if cache_key in self._embedding_cache:
cached_embeddings.append((i, self._embedding_cache[cache_key]))
self._cache_hits += 1
else:
uncached_texts.append(text)
uncached_indices.append(i)
self._cache_misses += 1
# 生成未缓存的嵌入
if uncached_texts:
new_embeddings = self.encode(
uncached_texts,
batch_size,
show_progress,
)
# 更新缓存
with self._cache_lock:
for text, embedding in zip(uncached_texts, new_embeddings):
cache_key = self._get_cache_key(text)
self._embedding_cache[cache_key] = embedding.copy()
# 合并结果
for idx, embedding in zip(uncached_indices, new_embeddings):
cached_embeddings.append((idx, embedding))
# 按原始顺序排序
cached_embeddings.sort(key=lambda x: x[0])
final_embeddings = np.array([emb for _, emb in cached_embeddings])
return final_embeddings
def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
"""
保存缓存到磁盘
Args:
cache_path: 缓存文件路径默认使用cache_dir/embeddings_cache.pkl
"""
if cache_path is None:
if self.cache_dir is None:
raise ValueError("未指定缓存目录")
cache_path = self.cache_dir / "embeddings_cache.pkl"
cache_path = Path(cache_path)
cache_path.parent.mkdir(parents=True, exist_ok=True)
with self._cache_lock:
with open(cache_path, "wb") as f:
pickle.dump(self._embedding_cache, f)
logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)")
def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
"""
从磁盘加载缓存
Args:
cache_path: 缓存文件路径默认使用cache_dir/embeddings_cache.pkl
"""
if cache_path is None:
if self.cache_dir is None:
raise ValueError("未指定缓存目录")
cache_path = self.cache_dir / "embeddings_cache.pkl"
cache_path = Path(cache_path)
if not cache_path.exists():
logger.warning(f"缓存文件不存在: {cache_path}")
return
with self._cache_lock:
with open(cache_path, "rb") as f:
self._embedding_cache = pickle.load(f)
logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)")
def clear_cache(self) -> None:
"""清空缓存"""
with self._cache_lock:
count = len(self._embedding_cache)
self._embedding_cache.clear()
logger.info(f"已清空缓存: {count}")
def check_model_consistency(
self,
stored_embeddings: np.ndarray,
sample_texts: List[str] = None,
) -> Tuple[bool, str]:
"""
检查模型一致性
Args:
stored_embeddings: 存储的嵌入向量
sample_texts: 样本文本(用于重新生成对比)
Returns:
(是否一致, 详细信息)
"""
# 检查维度
if stored_embeddings.shape[1] != self.config.dimension:
return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}"
# 如果提供了样本文本,重新生成并比较
if sample_texts:
try:
new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个
# 计算相似度
similarities = np.dot(
stored_embeddings[:5],
new_embeddings.T,
).diagonal()
# 检查相似度
if np.mean(similarities) < 0.95:
return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}"
return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}"
except Exception as e:
return False, f"一致性检查失败: {e}"
return True, "维度匹配"
def get_model_info(self) -> Dict[str, Any]:
"""
获取模型信息
Returns:
模型信息字典
"""
return {
"model_name": self.config.model_name,
"dimension": self.config.dimension,
"max_seq_length": self.config.max_seq_length,
"batch_size": self.config.batch_size,
"normalization": self.config.normalization,
"pooling": self.config.pooling,
"model_loaded": self._model is not None,
"cache_enabled": self.enable_cache,
"cache_size": len(self._embedding_cache),
"total_encoded": self._total_encoded,
"cache_hits": self._cache_hits,
"cache_misses": self._cache_misses,
}
def get_embedding_dimension(self) -> int:
"""获取嵌入维度"""
return self.config.dimension
def _get_cache_key(self, text: str) -> str:
"""
生成缓存键
Args:
text: 文本内容
Returns:
缓存键SHA256哈希
"""
return hashlib.sha256(text.encode("utf-8")).hexdigest()
@property
def is_model_loaded(self) -> bool:
"""模型是否已加载"""
return self._model is not None
@property
def cache_hit_rate(self) -> float:
"""缓存命中率"""
total = self._cache_hits + self._cache_misses
if total == 0:
return 0.0
return self._cache_hits / total
def __repr__(self) -> str:
return (
f"EmbeddingManager(model={self.config.model_name}, "
f"dim={self.config.dimension}, "
f"loaded={self.is_model_loaded}, "
f"cache={len(self._embedding_cache)})"
)
def create_embedding_manager_from_config(
model_name: str,
model_path: str,
dimension: int,
cache_dir: Optional[Union[str, Path]] = None,
enable_cache: bool = True,
num_workers: int = 1,
**config_kwargs,
) -> EmbeddingManager:
"""
从自定义配置创建嵌入管理器
Args:
model_name: 模型名称
model_path: HuggingFace模型路径
dimension: 输出维度
cache_dir: 缓存目录
enable_cache: 是否启用缓存
num_workers: 工作线程数
**config_kwargs: 其他配置参数
Returns:
嵌入管理器实例
"""
# 创建自定义配置
config = get_custom_config(
model_name=model_name,
model_path=model_path,
dimension=dimension,
cache_dir=cache_dir,
**config_kwargs,
)
# 创建管理器
return EmbeddingManager(
config=config,
cache_dir=cache_dir,
enable_cache=enable_cache,
num_workers=num_workers,
)

View File

@@ -0,0 +1,72 @@
"""
嵌入模型配置模块
"""
from dataclasses import dataclass
from typing import Optional, Dict, Any, Union
from pathlib import Path
@dataclass
class EmbeddingModelConfig:
"""
嵌入模型配置
属性:
model_name: 模型描述名称
model_path: 实际加载路径Local or HF
dimension: 嵌入向量维度
max_seq_length: 最大序列长度
batch_size: 编码批次大小
model_size_mb: 估计显存占用
description: 模型说明
normalization: 是否自动归一化
pooling: 池化策略 (mean, cls, max)
cache_dir: 模型缓存目录
"""
model_name: str
model_path: str
dimension: int
max_seq_length: int = 512
batch_size: int = 32
model_size_mb: int = 100
description: str = ""
normalization: bool = True
pooling: str = "mean"
cache_dir: Optional[Union[str, Path]] = None
def validate_config_compatibility(
config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
) -> bool:
"""检查两个配置是否兼容(主要看维度)"""
return config1.dimension == config2.dimension
def are_models_compatible(
config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
) -> bool:
"""检查模型是否完全相同(用于热切换判断)"""
return (
config1.model_path == config2.model_path
and config1.dimension == config2.dimension
and config1.pooling == config2.pooling
)
def get_custom_config(
model_name: str,
model_path: str,
dimension: int,
cache_dir: Optional[Union[str, Path]] = None,
**kwargs,
) -> EmbeddingModelConfig:
"""创建自定义模型配置"""
return EmbeddingModelConfig(
model_name=model_name,
model_path=model_path,
dimension=dimension,
cache_dir=cache_dir,
**kwargs,
)

View File

@@ -0,0 +1,54 @@
"""检索模块 - 双路检索与排序"""
from .dual_path import (
DualPathRetriever,
RetrievalStrategy,
RetrievalResult,
DualPathRetrieverConfig,
TemporalQueryOptions,
FusionConfig,
RelationIntentConfig,
)
from .pagerank import (
PersonalizedPageRank,
PageRankConfig,
create_ppr_from_graph,
)
from .threshold import (
DynamicThresholdFilter,
ThresholdMethod,
ThresholdConfig,
)
from .sparse_bm25 import (
SparseBM25Index,
SparseBM25Config,
)
from .graph_relation_recall import (
GraphRelationRecallConfig,
GraphRelationRecallService,
)
__all__ = [
# DualPathRetriever
"DualPathRetriever",
"RetrievalStrategy",
"RetrievalResult",
"DualPathRetrieverConfig",
"TemporalQueryOptions",
"FusionConfig",
"RelationIntentConfig",
# PersonalizedPageRank
"PersonalizedPageRank",
"PageRankConfig",
"create_ppr_from_graph",
# DynamicThresholdFilter
"DynamicThresholdFilter",
"ThresholdMethod",
"ThresholdConfig",
# Sparse BM25
"SparseBM25Index",
"SparseBM25Config",
# Graph relation recall
"GraphRelationRecallConfig",
"GraphRelationRecallService",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,272 @@
"""Graph-assisted relation candidate recall for relation-oriented queries."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set
from src.common.logger import get_logger
logger = get_logger("A_Memorix.GraphRelationRecall")
@dataclass
class GraphRelationRecallConfig:
"""Configuration for controlled graph relation recall."""
enabled: bool = True
candidate_k: int = 24
max_hop: int = 1
allow_two_hop_pair: bool = True
max_paths: int = 4
def __post_init__(self) -> None:
self.enabled = bool(self.enabled)
self.candidate_k = max(1, int(self.candidate_k))
self.max_hop = max(1, int(self.max_hop))
self.allow_two_hop_pair = bool(self.allow_two_hop_pair)
self.max_paths = max(1, int(self.max_paths))
@dataclass
class GraphRelationCandidate:
"""A graph-derived relation candidate before retriever-side fusion."""
hash_value: str
subject: str
predicate: str
object: str
confidence: float
graph_seed_entities: List[str]
graph_hops: int
graph_candidate_type: str
supporting_paragraph_count: int
def to_payload(self) -> Dict[str, Any]:
content = f"{self.subject} {self.predicate} {self.object}"
return {
"hash": self.hash_value,
"content": content,
"subject": self.subject,
"predicate": self.predicate,
"object": self.object,
"confidence": self.confidence,
"graph_seed_entities": list(self.graph_seed_entities),
"graph_hops": int(self.graph_hops),
"graph_candidate_type": self.graph_candidate_type,
"supporting_paragraph_count": int(self.supporting_paragraph_count),
}
class GraphRelationRecallService:
"""Collect relation candidates from the entity graph in a controlled way."""
def __init__(
self,
*,
graph_store: Any,
metadata_store: Any,
config: Optional[GraphRelationRecallConfig] = None,
) -> None:
self.graph_store = graph_store
self.metadata_store = metadata_store
self.config = config or GraphRelationRecallConfig()
def recall(
self,
*,
seed_entities: Sequence[str],
) -> List[GraphRelationCandidate]:
if not self.config.enabled:
return []
if self.graph_store is None or self.metadata_store is None:
return []
seeds = self._normalize_seed_entities(seed_entities)
if not seeds:
return []
seen_hashes: Set[str] = set()
candidates: List[GraphRelationCandidate] = []
if len(seeds) >= 2:
self._collect_direct_pair_candidates(
seed_a=seeds[0],
seed_b=seeds[1],
seen_hashes=seen_hashes,
out=candidates,
)
if (
len(candidates) < 3
and self.config.allow_two_hop_pair
and len(candidates) < self.config.candidate_k
):
self._collect_two_hop_pair_candidates(
seed_a=seeds[0],
seed_b=seeds[1],
seen_hashes=seen_hashes,
out=candidates,
)
else:
self._collect_one_hop_seed_candidates(
seed=seeds[0],
seen_hashes=seen_hashes,
out=candidates,
)
return candidates[: self.config.candidate_k]
def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]:
out: List[str] = []
seen = set()
for raw in list(seed_entities)[:2]:
resolved = None
try:
resolved = self.graph_store.find_node(str(raw), ignore_case=True)
except Exception:
resolved = None
if not resolved:
continue
canon = str(resolved).strip().lower()
if not canon or canon in seen:
continue
seen.add(canon)
out.append(str(resolved))
return out
def _collect_direct_pair_candidates(
self,
*,
seed_a: str,
seed_b: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
relation_hashes = []
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b))
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a))
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="direct_pair",
graph_hops=1,
graph_seed_entities=[seed_a, seed_b],
)
def _collect_two_hop_pair_candidates(
self,
*,
seed_a: str,
seed_b: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
try:
paths = self.graph_store.find_paths(
seed_a,
seed_b,
max_depth=2,
max_paths=self.config.max_paths,
)
except Exception as e:
logger.debug("graph two-hop recall skipped: %s", e)
return
for path_nodes in paths:
if len(out) >= self.config.candidate_k:
break
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 3:
continue
if len(path_nodes) != 3:
continue
for idx in range(len(path_nodes) - 1):
if len(out) >= self.config.candidate_k:
break
u = str(path_nodes[idx])
v = str(path_nodes[idx + 1])
relation_hashes = []
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v))
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u))
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="two_hop_pair",
graph_hops=2,
graph_seed_entities=[seed_a, seed_b],
)
def _collect_one_hop_seed_candidates(
self,
*,
seed: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
try:
relation_hashes = self.graph_store.get_incident_relation_hashes(
seed,
limit=self.config.candidate_k,
)
except Exception as e:
logger.debug("graph one-hop recall skipped: %s", e)
return
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="one_hop_seed",
graph_hops=min(1, self.config.max_hop),
graph_seed_entities=[seed],
)
def _append_relation_hashes(
self,
*,
relation_hashes: Sequence[str],
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
candidate_type: str,
graph_hops: int,
graph_seed_entities: Sequence[str],
) -> None:
for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}):
if len(out) >= self.config.candidate_k:
break
if relation_hash in seen_hashes:
continue
candidate = self._build_candidate(
relation_hash=relation_hash,
candidate_type=candidate_type,
graph_hops=graph_hops,
graph_seed_entities=graph_seed_entities,
)
if candidate is None:
continue
seen_hashes.add(relation_hash)
out.append(candidate)
def _build_candidate(
self,
*,
relation_hash: str,
candidate_type: str,
graph_hops: int,
graph_seed_entities: Sequence[str],
) -> Optional[GraphRelationCandidate]:
relation = self.metadata_store.get_relation(relation_hash)
if relation is None:
return None
supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
return GraphRelationCandidate(
hash_value=relation_hash,
subject=str(relation.get("subject", "")),
predicate=str(relation.get("predicate", "")),
object=str(relation.get("object", "")),
confidence=float(relation.get("confidence", 1.0) or 1.0),
graph_seed_entities=[str(x) for x in graph_seed_entities],
graph_hops=int(graph_hops),
graph_candidate_type=str(candidate_type),
supporting_paragraph_count=len(supporting_paragraphs),
)

View File

@@ -0,0 +1,482 @@
"""
Personalized PageRank实现
提供个性化的图节点排序功能。
"""
from typing import Dict, List, Optional, Tuple, Union, Any
from dataclasses import dataclass
import numpy as np
from src.common.logger import get_logger
from ..storage import GraphStore
from ..utils.matcher import AhoCorasick
logger = get_logger("A_Memorix.PersonalizedPageRank")
@dataclass
class PageRankConfig:
"""
PageRank配置
属性:
alpha: 阻尼系数0-1之间
max_iter: 最大迭代次数
tol: 收敛阈值
normalize: 是否归一化结果
min_iterations: 最小迭代次数
"""
alpha: float = 0.85
max_iter: int = 100
tol: float = 1e-6
normalize: bool = True
min_iterations: int = 20
def __post_init__(self):
"""验证配置"""
if not 0 <= self.alpha < 1:
raise ValueError(f"alpha必须在[0, 1)之间: {self.alpha}")
if self.max_iter <= 0:
raise ValueError(f"max_iter必须大于0: {self.max_iter}")
if self.tol <= 0:
raise ValueError(f"tol必须大于0: {self.tol}")
if self.min_iterations < 0:
raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}")
if self.min_iterations >= self.max_iter:
raise ValueError(f"min_iterations必须小于max_iter")
class PersonalizedPageRank:
"""
Personalized PageRank计算器
功能:
- 个性化向量支持
- 快速收敛检测
- 结果归一化
- 批量计算
- 统计信息
参数:
graph_store: 图存储
config: PageRank配置
"""
def __init__(
self,
graph_store: GraphStore,
config: Optional[PageRankConfig] = None,
):
"""
初始化PPR计算器
Args:
graph_store: 图存储
config: PageRank配置
"""
self.graph_store = graph_store
self.config = config or PageRankConfig()
# 统计信息
self._total_computations = 0
self._total_iterations = 0
self._convergence_history: List[int] = []
logger.info(
f"PersonalizedPageRank 初始化: "
f"alpha={self.config.alpha}, "
f"max_iter={self.config.max_iter}"
)
# 缓存 Aho-Corasick 匹配器
self._ac_matcher: Optional[AhoCorasick] = None
self._ac_nodes_count = 0
def compute(
self,
personalization: Optional[Dict[str, float]] = None,
alpha: Optional[float] = None,
max_iter: Optional[int] = None,
normalize: Optional[bool] = None,
) -> Dict[str, float]:
"""
计算Personalized PageRank
Args:
personalization: 个性化向量 {节点名: 权重}
alpha: 阻尼系数(覆盖配置值)
max_iter: 最大迭代次数(覆盖配置值)
normalize: 是否归一化(覆盖配置值)
Returns:
节点PageRank值字典 {节点名: 分数}
"""
# 使用覆盖值或配置值
alpha = alpha if alpha is not None else self.config.alpha
max_iter = max_iter if max_iter is not None else self.config.max_iter
normalize = normalize if normalize is not None else self.config.normalize
# 调用GraphStore的compute_pagerank
scores = self.graph_store.compute_pagerank(
personalization=personalization,
alpha=alpha,
max_iter=max_iter,
tol=self.config.tol,
)
# 归一化(如果需要)
if normalize and scores:
total = sum(scores.values())
if total > 0:
scores = {node: score / total for node, score in scores.items()}
# 更新统计
self._total_computations += 1
logger.debug(
f"PPR计算完成: {len(scores)} 个节点, "
f"personalization_nodes={len(personalization) if personalization else 0}"
)
return scores
def compute_batch(
self,
personalization_list: List[Dict[str, float]],
normalize: bool = True,
) -> List[Dict[str, float]]:
"""
批量计算PPR
Args:
personalization_list: 个性化向量列表
normalize: 是否归一化
Returns:
PageRank值字典列表
"""
results = []
for i, personalization in enumerate(personalization_list):
logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR")
scores = self.compute(personalization=personalization, normalize=normalize)
results.append(scores)
return results
def compute_for_entities(
self,
entities: List[str],
weights: Optional[List[float]] = None,
normalize: bool = True,
) -> Dict[str, float]:
"""
为实体列表计算PPR
Args:
entities: 实体列表
weights: 权重列表(默认均匀权重)
normalize: 是否归一化
Returns:
PageRank值字典
"""
if not entities:
logger.warning("实体列表为空返回均匀PPR")
return self.compute(personalization=None, normalize=normalize)
# 构建个性化向量
if weights is None:
weights = [1.0] * len(entities)
if len(weights) != len(entities):
raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}")
personalization = {entity: weight for entity, weight in zip(entities, weights)}
return self.compute(personalization=personalization, normalize=normalize)
def compute_for_query(
self,
query: str,
entity_extractor: Optional[callable] = None,
normalize: bool = True,
) -> Dict[str, float]:
"""
为查询计算PPR
Args:
query: 查询文本
entity_extractor: 实体提取函数(可选)
normalize: 是否归一化
Returns:
PageRank值字典
"""
# 提取实体
if entity_extractor is not None:
entities = entity_extractor(query)
else:
# 简单实现:基于图中的节点匹配
entities = self._extract_entities_from_query(query)
if not entities:
logger.debug(f"未从查询中提取到实体: '{query}'")
return self.compute(personalization=None, normalize=normalize)
# 计算PPR
return self.compute_for_entities(entities, normalize=normalize)
def rank_nodes(
self,
scores: Dict[str, float],
top_k: Optional[int] = None,
min_score: float = 0.0,
) -> List[Tuple[str, float]]:
"""
对节点排序
Args:
scores: PageRank分数字典
top_k: 返回前k个节点None表示全部
min_score: 最小分数阈值
Returns:
排序后的节点列表 [(节点名, 分数), ...]
"""
# 过滤低分节点
filtered = [(node, score) for node, score in scores.items() if score >= min_score]
# 按分数降序排序
sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True)
# 返回top_k
if top_k is not None:
sorted_nodes = sorted_nodes[:top_k]
return sorted_nodes
def get_personalization_vector(
self,
nodes: List[str],
method: str = "uniform",
) -> Dict[str, float]:
"""
生成个性化向量
Args:
nodes: 节点列表
method: 生成方法
- "uniform": 均匀权重
- "degree": 按度数加权
- "inverse_degree": 按度数反比加权
Returns:
个性化向量 {节点名: 权重}
"""
if not nodes:
return {}
if method == "uniform":
# 均匀权重
weight = 1.0 / len(nodes)
return {node: weight for node in nodes}
elif method == "degree":
# 按度数加权
node_degrees = {}
for node in nodes:
neighbors = self.graph_store.get_neighbors(node)
node_degrees[node] = len(neighbors)
total_degree = sum(node_degrees.values())
if total_degree > 0:
return {node: degree / total_degree for node, degree in node_degrees.items()}
else:
return {node: 1.0 / len(nodes) for node in nodes}
elif method == "inverse_degree":
# 按度数反比加权
node_degrees = {}
for node in nodes:
neighbors = self.graph_store.get_neighbors(node)
node_degrees[node] = len(neighbors)
# 反度数
inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()}
total_inv = sum(inv_degrees.values())
if total_inv > 0:
return {node: inv / total_inv for node, inv in inv_degrees.items()}
else:
return {node: 1.0 / len(nodes) for node in nodes}
else:
raise ValueError(f"不支持的个性化向量生成方法: {method}")
def compare_scores(
self,
scores1: Dict[str, float],
scores2: Dict[str, float],
) -> Dict[str, Dict[str, float]]:
"""
比较两组PPR分数
Args:
scores1: 第一组分数
scores2: 第二组分数
Returns:
比较结果 {
"common_nodes": {节点: (score1, score2)},
"only_in_1": {节点: score1},
"only_in_2": {节点: score2},
}
"""
common_nodes = {}
only_in_1 = {}
only_in_2 = {}
all_nodes = set(scores1.keys()) | set(scores2.keys())
for node in all_nodes:
if node in scores1 and node in scores2:
common_nodes[node] = (scores1[node], scores2[node])
elif node in scores1:
only_in_1[node] = scores1[node]
else:
only_in_2[node] = scores2[node]
return {
"common_nodes": common_nodes,
"only_in_1": only_in_1,
"only_in_2": only_in_2,
}
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
avg_iterations = (
self._total_iterations / self._total_computations
if self._total_computations > 0
else 0
)
return {
"config": {
"alpha": self.config.alpha,
"max_iter": self.config.max_iter,
"tol": self.config.tol,
"normalize": self.config.normalize,
"min_iterations": self.config.min_iterations,
},
"statistics": {
"total_computations": self._total_computations,
"total_iterations": self._total_iterations,
"avg_iterations": avg_iterations,
"convergence_history": self._convergence_history.copy(),
},
"graph": {
"num_nodes": self.graph_store.num_nodes,
"num_edges": self.graph_store.num_edges,
},
}
def reset_statistics(self) -> None:
"""重置统计信息"""
self._total_computations = 0
self._total_iterations = 0
self._convergence_history.clear()
logger.info("统计信息已重置")
def _extract_entities_from_query(self, query: str) -> List[str]:
"""
从查询中提取实体(简化实现)
Args:
query: 查询文本
Returns:
实体列表
"""
# 获取所有节点
all_nodes = self.graph_store.get_nodes()
if not all_nodes:
return []
# 检查是否需要更新 Aho-Corasick 匹配器
if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes):
self._ac_matcher = AhoCorasick()
for node in all_nodes:
# 统一转为小写进行不区分大小写匹配
self._ac_matcher.add_pattern(node.lower())
self._ac_matcher.build()
self._ac_nodes_count = len(all_nodes)
# 执行匹配
query_lower = query.lower()
stats = self._ac_matcher.find_all(query_lower)
# 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值)
# 为了简单AC 中 add_pattern 存的是小写
# 我们需要一个映射:小写 -> 原始
node_map = {node.lower(): node for node in all_nodes}
entities = [node_map[low_name] for low_name in stats.keys()]
return entities
@property
def num_computations(self) -> int:
"""计算次数"""
return self._total_computations
@property
def avg_iterations(self) -> float:
"""平均迭代次数"""
if self._total_computations == 0:
return 0.0
return self._total_iterations / self._total_computations
def __repr__(self) -> str:
return (
f"PersonalizedPageRank("
f"alpha={self.config.alpha}, "
f"computations={self._total_computations})"
)
def create_ppr_from_graph(
graph_store: GraphStore,
alpha: float = 0.85,
max_iter: int = 100,
) -> PersonalizedPageRank:
"""
从图存储创建PPR计算器
Args:
graph_store: 图存储
alpha: 阻尼系数
max_iter: 最大迭代次数
Returns:
PPR计算器实例
"""
config = PageRankConfig(
alpha=alpha,
max_iter=max_iter,
)
return PersonalizedPageRank(
graph_store=graph_store,
config=config,
)

View File

@@ -0,0 +1,402 @@
"""
稀疏检索组件FTS5 + BM25
支持:
- 懒加载索引连接
- jieba / char n-gram 分词
- 可卸载并收缩 SQLite 内存缓存
"""
from __future__ import annotations
import re
import sqlite3
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
from ..storage import MetadataStore
logger = get_logger("A_Memorix.SparseBM25")
try:
import jieba # type: ignore
HAS_JIEBA = True
except Exception:
HAS_JIEBA = False
jieba = None
@dataclass
class SparseBM25Config:
"""BM25 稀疏检索配置。"""
enabled: bool = True
backend: str = "fts5"
lazy_load: bool = True
mode: str = "auto" # auto | fallback_only | hybrid
tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram
jieba_user_dict: str = ""
char_ngram_n: int = 2
candidate_k: int = 80
max_doc_len: int = 2000
enable_ngram_fallback_index: bool = True
enable_like_fallback: bool = False
enable_relation_sparse_fallback: bool = True
relation_candidate_k: int = 60
relation_max_doc_len: int = 512
unload_on_disable: bool = True
shrink_memory_on_unload: bool = True
def __post_init__(self) -> None:
self.backend = str(self.backend or "fts5").strip().lower()
self.mode = str(self.mode or "auto").strip().lower()
self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower()
self.char_ngram_n = max(1, int(self.char_ngram_n))
self.candidate_k = max(1, int(self.candidate_k))
self.max_doc_len = max(0, int(self.max_doc_len))
self.relation_candidate_k = max(1, int(self.relation_candidate_k))
self.relation_max_doc_len = max(0, int(self.relation_max_doc_len))
if self.backend != "fts5":
raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}")
if self.mode not in {"auto", "fallback_only", "hybrid"}:
raise ValueError(f"sparse.mode 非法: {self.mode}")
if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}:
raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}")
class SparseBM25Index:
"""
基于 SQLite FTS5 的 BM25 检索适配层。
"""
def __init__(
self,
metadata_store: MetadataStore,
config: Optional[SparseBM25Config] = None,
):
self.metadata_store = metadata_store
self.config = config or SparseBM25Config()
self._conn: Optional[sqlite3.Connection] = None
self._loaded: bool = False
self._jieba_dict_loaded: bool = False
@property
def loaded(self) -> bool:
return self._loaded and self._conn is not None
def ensure_loaded(self) -> bool:
"""按需加载 FTS 连接与索引。"""
if not self.config.enabled:
return False
if self.loaded:
return True
db_path = self.metadata_store.get_db_path()
conn = sqlite3.connect(
str(db_path),
check_same_thread=False,
timeout=30.0,
)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA temp_store=MEMORY")
if not self.metadata_store.ensure_fts_schema(conn=conn):
conn.close()
return False
self.metadata_store.ensure_fts_backfilled(conn=conn)
# 关系稀疏检索按独立开关加载,避免不必要的初始化开销。
if self.config.enable_relation_sparse_fallback:
self.metadata_store.ensure_relations_fts_schema(conn=conn)
self.metadata_store.ensure_relations_fts_backfilled(conn=conn)
if self.config.enable_ngram_fallback_index:
self.metadata_store.ensure_paragraph_ngram_schema(conn=conn)
self.metadata_store.ensure_paragraph_ngram_backfilled(
n=self.config.char_ngram_n,
conn=conn,
)
self._conn = conn
self._loaded = True
self._prepare_tokenizer()
logger.info(
"SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s",
self.config.tokenizer_mode,
self.config.mode,
)
return True
def _prepare_tokenizer(self) -> None:
if self._jieba_dict_loaded:
return
if self.config.tokenizer_mode not in {"jieba", "mixed"}:
return
if not HAS_JIEBA:
logger.warning("jieba 不可用tokenizer 将退化为 char n-gram")
return
user_dict = str(self.config.jieba_user_dict or "").strip()
if user_dict:
try:
jieba.load_userdict(user_dict) # type: ignore[union-attr]
logger.info("已加载 jieba 用户词典: %s", user_dict)
except Exception as e:
logger.warning("加载 jieba 用户词典失败: %s", e)
self._jieba_dict_loaded = True
def _tokenize_jieba(self, text: str) -> List[str]:
if not HAS_JIEBA:
return []
try:
tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr]
return [t.strip().lower() for t in tokens if t and t.strip()]
except Exception:
return []
def _tokenize_char_ngram(self, text: str, n: int) -> List[str]:
compact = re.sub(r"\s+", "", text.lower())
if not compact:
return []
if len(compact) < n:
return [compact]
return [compact[i : i + n] for i in range(0, len(compact) - n + 1)]
def _tokenize(self, text: str) -> List[str]:
text = str(text or "").strip()
if not text:
return []
mode = self.config.tokenizer_mode
if mode == "jieba":
tokens = self._tokenize_jieba(text)
if tokens:
return list(dict.fromkeys(tokens))
return self._tokenize_char_ngram(text, self.config.char_ngram_n)
if mode == "mixed":
toks = self._tokenize_jieba(text)
toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n))
return list(dict.fromkeys([t for t in toks if t]))
return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n)))
def _build_match_query(self, tokens: List[str]) -> str:
safe_tokens: List[str] = []
for token in tokens:
t = token.replace('"', '""').strip()
if not t:
continue
safe_tokens.append(f'"{t}"')
if not safe_tokens:
return ""
# 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。
return " OR ".join(safe_tokens[:64])
def _fallback_substring_search(
self,
tokens: List[str],
limit: int,
) -> List[Dict[str, Any]]:
"""
当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。
说明:
- FTS 索引当前采用 unicode61 tokenizer。
- 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。
- 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。
"""
if not tokens:
return []
# 去重并裁剪 token 数量,避免生成超长 SQL。
uniq_tokens = [t for t in dict.fromkeys(tokens) if t]
uniq_tokens = uniq_tokens[:32]
if not uniq_tokens:
return []
if self.config.enable_ngram_fallback_index:
try:
# 允许运行时切换开关后按需补齐 schema/回填。
self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn)
self.metadata_store.ensure_paragraph_ngram_backfilled(
n=self.config.char_ngram_n,
conn=self._conn,
)
rows = self.metadata_store.ngram_search_paragraphs(
tokens=uniq_tokens,
limit=limit,
max_doc_len=self.config.max_doc_len,
conn=self._conn,
)
if rows:
return rows
except Exception as e:
logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}")
if not self.config.enable_like_fallback:
return []
conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens))
params: List[Any] = [f"%{tok}%" for tok in uniq_tokens]
scan_limit = max(int(limit) * 8, 200)
params.append(scan_limit)
sql = f"""
SELECT p.hash, p.content
FROM paragraphs p
WHERE (p.is_deleted IS NULL OR p.is_deleted = 0)
AND ({conditions})
LIMIT ?
"""
rows = self.metadata_store.query(sql, tuple(params))
if not rows:
return []
scored: List[Dict[str, Any]] = []
token_count = max(1, len(uniq_tokens))
for row in rows:
content = str(row.get("content") or "")
content_low = content.lower()
matched = [tok for tok in uniq_tokens if tok in content_low]
if not matched:
continue
coverage = len(matched) / token_count
length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low))
# 兜底路径使用相对分,保持与上层接口兼容。
fallback_score = coverage * 0.8 + length_bonus * 0.2
scored.append(
{
"hash": row["hash"],
"content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content,
"bm25_score": -float(fallback_score),
"fallback_score": float(fallback_score),
}
)
scored.sort(key=lambda x: x["fallback_score"], reverse=True)
return scored[:limit]
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
"""执行 BM25 检索。"""
if not self.config.enabled:
return []
if self.config.lazy_load and not self.loaded:
if not self.ensure_loaded():
return []
if not self.loaded:
return []
# 关系稀疏检索可独立开关,运行时开启后也能按需补齐 schema/回填。
self.metadata_store.ensure_relations_fts_schema(conn=self._conn)
self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn)
tokens = self._tokenize(query)
match_query = self._build_match_query(tokens)
if not match_query:
return []
limit = max(1, int(k))
rows = self.metadata_store.fts_search_bm25(
match_query=match_query,
limit=limit,
max_doc_len=self.config.max_doc_len,
conn=self._conn,
)
if not rows:
rows = self._fallback_substring_search(tokens=tokens, limit=limit)
results: List[Dict[str, Any]] = []
for rank, row in enumerate(rows, start=1):
bm25_score = float(row.get("bm25_score", 0.0))
results.append(
{
"hash": row["hash"],
"content": row["content"],
"rank": rank,
"bm25_score": bm25_score,
"score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数
}
)
return results
def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
"""执行关系稀疏检索FTS5 + BM25"""
if not self.config.enabled or not self.config.enable_relation_sparse_fallback:
return []
if self.config.lazy_load and not self.loaded:
if not self.ensure_loaded():
return []
if not self.loaded:
return []
tokens = self._tokenize(query)
match_query = self._build_match_query(tokens)
if not match_query:
return []
rows = self.metadata_store.fts_search_relations_bm25(
match_query=match_query,
limit=max(1, int(k)),
max_doc_len=self.config.relation_max_doc_len,
conn=self._conn,
)
out: List[Dict[str, Any]] = []
for rank, row in enumerate(rows, start=1):
bm25_score = float(row.get("bm25_score", 0.0))
out.append(
{
"hash": row["hash"],
"subject": row["subject"],
"predicate": row["predicate"],
"object": row["object"],
"content": row["content"],
"rank": rank,
"bm25_score": bm25_score,
"score": -bm25_score,
}
)
return out
def upsert_paragraph(self, paragraph_hash: str) -> bool:
if not self.loaded:
return False
return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn)
def delete_paragraph(self, paragraph_hash: str) -> bool:
if not self.loaded:
return False
return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn)
def unload(self) -> None:
"""卸载 BM25 连接并尽量释放内存。"""
if self._conn is not None:
try:
if self.config.shrink_memory_on_unload:
self.metadata_store.shrink_memory(conn=self._conn)
except Exception:
pass
try:
self._conn.close()
except Exception:
pass
self._conn = None
self._loaded = False
logger.info("SparseBM25Index unloaded")
def stats(self) -> Dict[str, Any]:
doc_count = 0
if self.loaded:
doc_count = self.metadata_store.fts_doc_count(conn=self._conn)
return {
"enabled": self.config.enabled,
"backend": self.config.backend,
"mode": self.config.mode,
"tokenizer_mode": self.config.tokenizer_mode,
"enable_ngram_fallback_index": self.config.enable_ngram_fallback_index,
"enable_like_fallback": self.config.enable_like_fallback,
"enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback,
"loaded": self.loaded,
"has_jieba": HAS_JIEBA,
"doc_count": doc_count,
}

View File

@@ -0,0 +1,450 @@
"""
动态阈值过滤器
根据检索结果的分布特征自适应调整过滤阈值。
"""
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum
from src.common.logger import get_logger
from .dual_path import RetrievalResult
logger = get_logger("A_Memorix.DynamicThresholdFilter")
class ThresholdMethod(Enum):
"""阈值计算方法"""
PERCENTILE = "percentile" # 百分位数
STD_DEV = "std_dev" # 标准差
GAP_DETECTION = "gap_detection" # 跳变检测
ADAPTIVE = "adaptive" # 自适应(综合多种方法)
@dataclass
class ThresholdConfig:
"""
阈值配置
属性:
method: 阈值计算方法
min_threshold: 最小阈值(绝对值)
max_threshold: 最大阈值(绝对值)
percentile: 百分位数用于percentile方法
std_multiplier: 标准差倍数用于std_dev方法
min_results: 最少保留结果数
enable_auto_adjust: 是否自动调整参数
"""
method: ThresholdMethod = ThresholdMethod.ADAPTIVE
min_threshold: float = 0.3
max_threshold: float = 0.95
percentile: float = 75.0 # 百分位数
std_multiplier: float = 1.5 # 标准差倍数
min_results: int = 3 # 最少保留结果数
enable_auto_adjust: bool = True
def __post_init__(self):
"""验证配置"""
if not 0 <= self.min_threshold <= 1:
raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}")
if not 0 <= self.max_threshold <= 1:
raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")
if self.min_threshold >= self.max_threshold:
raise ValueError(f"min_threshold必须小于max_threshold")
if not 0 <= self.percentile <= 100:
raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")
if self.std_multiplier <= 0:
raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}")
if self.min_results < 0:
raise ValueError(f"min_results必须大于等于0: {self.min_results}")
class DynamicThresholdFilter:
"""
动态阈值过滤器
功能:
- 基于结果分布自适应计算阈值
- 多种阈值计算方法
- 自动参数调整
- 统计信息收集
参数:
config: 阈值配置
"""
def __init__(
self,
config: Optional[ThresholdConfig] = None,
):
"""
初始化动态阈值过滤器
Args:
config: 阈值配置
"""
self.config = config or ThresholdConfig()
# 统计信息
self._total_filtered = 0
self._total_processed = 0
self._threshold_history: List[float] = []
logger.info(
f"DynamicThresholdFilter 初始化: "
f"method={self.config.method.value}, "
f"min_threshold={self.config.min_threshold}"
)
def filter(
self,
results: List[RetrievalResult],
return_threshold: bool = False,
) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]:
"""
过滤检索结果
Args:
results: 检索结果列表
return_threshold: 是否返回使用的阈值
Returns:
过滤后的结果列表,或 (结果列表, 阈值) 元组
"""
if not results:
logger.debug("结果列表为空,无需过滤")
return ([], 0.0) if return_threshold else []
self._total_processed += len(results)
# 提取分数
scores = np.array([r.score for r in results])
# 计算阈值
threshold = self._compute_threshold(scores, results)
# 记录阈值
self._threshold_history.append(threshold)
# 应用阈值过滤
filtered_results = [
r for r in results
if r.score >= threshold
]
# 确保至少保留min_results个结果
if len(filtered_results) < self.config.min_results:
# 按分数排序取前min_results个
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
filtered_results = sorted_results[:self.config.min_results]
threshold = filtered_results[-1].score if filtered_results else 0.0
self._total_filtered += len(results) - len(filtered_results)
logger.info(
f"过滤完成: {len(results)} -> {len(filtered_results)} "
f"(threshold={threshold:.3f})"
)
if return_threshold:
return filtered_results, threshold
return filtered_results
def _compute_threshold(
self,
scores: np.ndarray,
results: List[RetrievalResult],
) -> float:
"""
计算阈值
Args:
scores: 分数数组
results: 检索结果列表
Returns:
阈值
"""
if self.config.method == ThresholdMethod.PERCENTILE:
threshold = self._percentile_threshold(scores)
elif self.config.method == ThresholdMethod.STD_DEV:
threshold = self._std_dev_threshold(scores)
elif self.config.method == ThresholdMethod.GAP_DETECTION:
threshold = self._gap_detection_threshold(scores)
else: # ADAPTIVE
# 自适应方法:综合多种方法
thresholds = [
self._percentile_threshold(scores),
self._std_dev_threshold(scores),
self._gap_detection_threshold(scores),
]
# 使用中位数作为最终阈值
threshold = float(np.median(thresholds))
# 限制在[min_threshold, max_threshold]范围内
threshold = np.clip(
threshold,
self.config.min_threshold,
self.config.max_threshold,
)
# 自动调整
if self.config.enable_auto_adjust:
threshold = self._auto_adjust_threshold(threshold, scores)
return float(threshold)
def _percentile_threshold(self, scores: np.ndarray) -> float:
"""
基于百分位数计算阈值
Args:
scores: 分数数组
Returns:
阈值
"""
percentile = self.config.percentile
threshold = float(np.percentile(scores, percentile))
logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})")
return threshold
def _std_dev_threshold(self, scores: np.ndarray) -> float:
"""
基于标准差计算阈值
threshold = mean - std_multiplier * std
Args:
scores: 分数数组
Returns:
阈值
"""
mean = float(np.mean(scores))
std = float(np.std(scores))
multiplier = self.config.std_multiplier
threshold = mean - multiplier * std
logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})")
return threshold
def _gap_detection_threshold(self, scores: np.ndarray) -> float:
"""
基于跳变检测计算阈值
找到分数分布中最大的"跳变"位置,以此为阈值
Args:
scores: 分数数组(降序排列)
Returns:
阈值
"""
# 降序排列
sorted_scores = np.sort(scores)[::-1]
if len(sorted_scores) < 2:
return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0
# 计算相邻分数的差值
gaps = np.diff(sorted_scores)
# 找到最大的跳变位置
max_gap_idx = int(np.argmax(gaps))
# 阈值为跳变后的分数
threshold = float(sorted_scores[max_gap_idx + 1])
logger.debug(
f"跳变检测阈值: {threshold:.3f} "
f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})"
)
return threshold
def _auto_adjust_threshold(
self,
threshold: float,
scores: np.ndarray,
) -> float:
"""
自动调整阈值
基于历史阈值和当前分数分布调整
Args:
threshold: 当前阈值
scores: 分数数组
Returns:
调整后的阈值
"""
if not self._threshold_history:
return threshold
# 计算历史阈值的移动平均
recent_thresholds = self._threshold_history[-10:] # 最近10次
avg_threshold = float(np.mean(recent_thresholds))
# 当前阈值与历史平均的差异
diff = threshold - avg_threshold
# 如果差异过大(>0.2),向历史平均靠拢
if abs(diff) > 0.2:
adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50%
logger.debug(
f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} "
f"(历史平均={avg_threshold:.3f})"
)
return adjusted_threshold
return threshold
def filter_by_confidence(
self,
results: List[RetrievalResult],
min_confidence: float = 0.5,
) -> List[RetrievalResult]:
"""
基于置信度过滤结果
Args:
results: 检索结果列表
min_confidence: 最小置信度
Returns:
过滤后的结果列表
"""
filtered = []
for result in results:
# 对于关系结果使用confidence字段
if result.result_type == "relation":
confidence = result.metadata.get("confidence", 1.0)
if confidence >= min_confidence:
filtered.append(result)
else:
# 对于段落结果,直接使用分数
if result.score >= min_confidence:
filtered.append(result)
logger.info(
f"置信度过滤: {len(results)} -> {len(filtered)} "
f"(min_confidence={min_confidence})"
)
return filtered
def filter_by_diversity(
self,
results: List[RetrievalResult],
similarity_threshold: float = 0.9,
top_k: int = 10,
) -> List[RetrievalResult]:
"""
基于多样性过滤结果(去除重复)
Args:
results: 检索结果列表
similarity_threshold: 相似度阈值(高于此值视为重复)
top_k: 最多保留结果数
Returns:
过滤后的结果列表
"""
if not results:
return []
# 按分数排序
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
# 贪心选择:选择与已选结果相似度低的结果
selected = []
selected_hashes = []
for result in sorted_results:
if len(selected) >= top_k:
break
# 检查与已选结果的相似度
is_duplicate = False
for selected_hash in selected_hashes:
# 简单判断基于hash的前缀
if result.hash_value[:8] == selected_hash[:8]:
is_duplicate = True
break
if not is_duplicate:
selected.append(result)
selected_hashes.append(result.hash_value)
logger.info(
f"多样性过滤: {len(results)} -> {len(selected)} "
f"(similarity_threshold={similarity_threshold})"
)
return selected
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
filter_rate = (
self._total_filtered / self._total_processed
if self._total_processed > 0
else 0.0
)
stats = {
"config": {
"method": self.config.method.value,
"min_threshold": self.config.min_threshold,
"max_threshold": self.config.max_threshold,
"percentile": self.config.percentile,
"std_multiplier": self.config.std_multiplier,
"min_results": self.config.min_results,
"enable_auto_adjust": self.config.enable_auto_adjust,
},
"statistics": {
"total_processed": self._total_processed,
"total_filtered": self._total_filtered,
"filter_rate": filter_rate,
"avg_threshold": float(np.mean(self._threshold_history))
if self._threshold_history else 0.0,
"threshold_count": len(self._threshold_history),
},
}
if self._threshold_history:
stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history))
stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history))
return stats
def reset_statistics(self) -> None:
"""重置统计信息"""
self._total_filtered = 0
self._total_processed = 0
self._threshold_history.clear()
logger.info("统计信息已重置")
def __repr__(self) -> str:
return (
f"DynamicThresholdFilter("
f"method={self.config.method.value}, "
f"min_threshold={self.config.min_threshold}, "
f"filtered={self._total_filtered}/{self._total_processed})"
)

View File

@@ -0,0 +1,8 @@
"""SDK runtime exports for A_Memorix."""
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
__all__ = [
"KernelSearchRequest",
"SDKMemoryKernel",
]

View File

@@ -0,0 +1,579 @@
from __future__ import annotations
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence
from src.common.logger import get_logger
from ..embedding import create_embedding_api_adapter
from ..retrieval import (
DualPathRetriever,
DualPathRetrieverConfig,
RetrievalResult,
SparseBM25Config,
SparseBM25Index,
TemporalQueryOptions,
)
from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
from ..utils.aggregate_query_service import AggregateQueryService
from ..utils.episode_retrieval_service import EpisodeRetrievalService
from ..utils.hash import normalize_text
from ..utils.relation_write_service import RelationWriteService
logger = get_logger("A_Memorix.SDKMemoryKernel")
@dataclass
class KernelSearchRequest:
query: str = ""
limit: int = 5
mode: str = "hybrid"
chat_id: str = ""
person_id: str = ""
time_start: Optional[float] = None
time_end: Optional[float] = None
class SDKMemoryKernel:
def __init__(self, *, plugin_root: Path, config: Optional[Dict[str, Any]] = None) -> None:
self.plugin_root = Path(plugin_root).resolve()
self.config = config or {}
storage_cfg = self._cfg("storage", {}) or {}
data_dir = str(storage_cfg.get("data_dir", "./data") or "./data")
self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir)
self.embedding_dimension = max(32, int(self._cfg("embedding.dimension", 256)))
self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False))
self.embedding_manager = None
self.vector_store: Optional[VectorStore] = None
self.graph_store: Optional[GraphStore] = None
self.metadata_store: Optional[MetadataStore] = None
self.relation_write_service: Optional[RelationWriteService] = None
self.sparse_index = None
self.retriever: Optional[DualPathRetriever] = None
self.episode_retriever: Optional[EpisodeRetrievalService] = None
self.aggregate_query_service: Optional[AggregateQueryService] = None
self._initialized = False
self._last_maintenance_at: Optional[float] = None
def _cfg(self, key: str, default: Any = None) -> Any:
current: Any = self.config
if key in {"storage", "embedding", "retrieval"} and isinstance(current, dict):
return current.get(key, default)
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
async def initialize(self) -> None:
if self._initialized:
return
self.data_dir.mkdir(parents=True, exist_ok=True)
self.embedding_manager = create_embedding_api_adapter(
batch_size=int(self._cfg("embedding.batch_size", 32)),
max_concurrent=int(self._cfg("embedding.max_concurrent", 5)),
default_dimension=self.embedding_dimension,
model_name=str(self._cfg("embedding.model_name", "hash-v1")),
retry_config=self._cfg("embedding.retry", {}) or {},
)
self.embedding_dimension = int(await self.embedding_manager._detect_dimension())
self.vector_store = VectorStore(
dimension=self.embedding_dimension,
quantization_type=QuantizationType.INT8,
data_dir=self.data_dir / "vectors",
)
self.graph_store = GraphStore(matrix_format=SparseMatrixFormat.CSR, data_dir=self.data_dir / "graph")
self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata")
self.metadata_store.connect()
if self.vector_store.has_data():
self.vector_store.load()
self.vector_store.warmup_index(force_train=True)
if self.graph_store.has_data():
self.graph_store.load()
sparse_cfg = self._cfg("retrieval.sparse", {}) or {}
self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=SparseBM25Config(**sparse_cfg))
if getattr(self.sparse_index.config, "enabled", False):
self.sparse_index.ensure_loaded()
self.relation_write_service = RelationWriteService(
metadata_store=self.metadata_store,
graph_store=self.graph_store,
vector_store=self.vector_store,
embedding_manager=self.embedding_manager,
)
self.retriever = DualPathRetriever(
vector_store=self.vector_store,
graph_store=self.graph_store,
metadata_store=self.metadata_store,
embedding_manager=self.embedding_manager,
sparse_index=self.sparse_index,
config=DualPathRetrieverConfig(
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 24)),
top_k_relations=int(self._cfg("retrieval.top_k_relations", 12)),
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
alpha=float(self._cfg("retrieval.alpha", 0.5)),
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
sparse=sparse_cfg,
fusion=self._cfg("retrieval.fusion", {}) or {},
graph_recall=self._cfg("retrieval.search.graph_recall", {}) or {},
relation_intent=self._cfg("retrieval.search.relation_intent", {}) or {},
),
)
self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever)
self.aggregate_query_service = AggregateQueryService(plugin_config=self.config)
self._initialized = True
def close(self) -> None:
if self.vector_store is not None:
self.vector_store.save()
if self.graph_store is not None:
self.graph_store.save()
if self.metadata_store is not None:
self.metadata_store.close()
self._initialized = False
async def ingest_summary(
self,
*,
external_id: str,
chat_id: str,
text: str,
participants: Optional[Sequence[str]] = None,
time_start: Optional[float] = None,
time_end: Optional[float] = None,
tags: Optional[Sequence[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
summary_meta = dict(metadata or {})
summary_meta.setdefault("kind", "chat_summary")
return await self.ingest_text(
external_id=external_id,
source_type="chat_summary",
text=text,
chat_id=chat_id,
participants=participants,
time_start=time_start,
time_end=time_end,
tags=tags,
metadata=summary_meta,
)
async def ingest_text(
self,
*,
external_id: str,
source_type: str,
text: str,
chat_id: str = "",
person_ids: Optional[Sequence[str]] = None,
participants: Optional[Sequence[str]] = None,
timestamp: Optional[float] = None,
time_start: Optional[float] = None,
time_end: Optional[float] = None,
tags: Optional[Sequence[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
entities: Optional[Sequence[str]] = None,
relations: Optional[Sequence[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
await self.initialize()
assert self.metadata_store and self.vector_store and self.graph_store and self.embedding_manager
assert self.relation_write_service
content = normalize_text(text)
if not content:
return {"stored_ids": [], "skipped_ids": [external_id], "reason": "empty_text"}
if ref := self.metadata_store.get_external_memory_ref(external_id):
return {"stored_ids": [], "skipped_ids": [str(ref.get("paragraph_hash", ""))], "reason": "exists"}
person_tokens = self._tokens(person_ids)
participant_tokens = self._tokens(participants)
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
source = self._build_source(source_type, chat_id, person_tokens)
paragraph_meta = dict(metadata or {})
paragraph_meta.update(
{
"external_id": external_id,
"source_type": str(source_type or "").strip(),
"chat_id": str(chat_id or "").strip(),
"person_ids": person_tokens,
"participants": participant_tokens,
"tags": self._tokens(tags),
}
)
paragraph_hash = self.metadata_store.add_paragraph(
content=content,
source=source,
metadata=paragraph_meta,
knowledge_type="factual" if source_type == "person_fact" else "narrative" if source_type == "chat_summary" else "mixed",
time_meta=self._time_meta(timestamp, time_start, time_end),
)
embedding = await self.embedding_manager.encode(content)
self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash])
for name in entity_tokens:
self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash)
stored_relations: List[str] = []
for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]:
s = str(row.get("subject", "") or "").strip()
p = str(row.get("predicate", "") or "").strip()
o = str(row.get("object", "") or "").strip()
if not (s and p and o):
continue
result = await self.relation_write_service.upsert_relation_with_vector(
subject=s,
predicate=p,
obj=o,
confidence=float(row.get("confidence", 1.0) or 1.0),
source_paragraph=paragraph_hash,
metadata={"external_id": external_id, "source_type": source_type},
write_vector=self.relation_vectors_enabled,
)
self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value)
stored_relations.append(result.hash_value)
self.metadata_store.upsert_external_memory_ref(
external_id=external_id,
paragraph_hash=paragraph_hash,
source_type=source_type,
metadata={"chat_id": chat_id, "person_ids": person_tokens},
)
self._persist()
self.rebuild_episodes_for_sources([source])
for person_id in person_tokens:
await self.refresh_person_profile(person_id)
return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []}
async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]:
await self.initialize()
assert self.retriever and self.episode_retriever and self.aggregate_query_service
mode = str(request.mode or "hybrid").strip().lower() or "hybrid"
clean_query = str(request.query or "").strip()
limit = max(1, int(request.limit or 5))
temporal = self._temporal(request)
if mode == "episode":
rows = await self.episode_retriever.query(
query=clean_query,
top_k=limit,
time_from=request.time_start,
time_to=request.time_end,
source=self._chat_source(request.chat_id),
)
hits = [self._episode_hit(row) for row in rows]
return {"summary": self._summary(hits), "hits": hits}
if mode == "aggregate":
payload = await self.aggregate_query_service.execute(
query=clean_query,
top_k=limit,
mix=True,
mix_top_k=limit,
time_from=str(request.time_start) if request.time_start is not None else None,
time_to=str(request.time_end) if request.time_end is not None else None,
search_runner=lambda: self._aggregate_search(clean_query, limit, temporal),
time_runner=lambda: self._aggregate_time(clean_query, limit, temporal),
episode_runner=lambda: self._aggregate_episode(clean_query, limit, request),
)
hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)]
for item in hits:
item.setdefault("metadata", {})
return {"summary": self._summary(hits), "hits": hits}
results = await self.retriever.retrieve(query=clean_query, top_k=limit, temporal=temporal)
hits = [self._retrieval_hit(item) for item in results]
return {"summary": self._summary(self._filter_hits(hits, request.person_id)), "hits": self._filter_hits(hits, request.person_id)}
async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]:
_ = chat_id
await self.initialize()
assert self.metadata_store
snapshot = self.metadata_store.get_latest_person_profile_snapshot(person_id) or await self.refresh_person_profile(person_id, limit=limit)
evidence = []
for hash_value in snapshot.get("evidence_ids", [])[: max(1, int(limit))]:
paragraph = self.metadata_store.get_paragraph(hash_value)
if paragraph is not None:
evidence.append({"hash": hash_value, "content": str(paragraph.get("content", "") or "")[:220], "metadata": paragraph.get("metadata", {}) or {}})
text = str(snapshot.get("profile_text", "") or "").strip()
traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8]
return {"summary": text, "traits": traits, "evidence": evidence}
async def refresh_person_profile(self, person_id: str, limit: int = 10) -> Dict[str, Any]:
await self.initialize()
assert self.metadata_store
rows = self.metadata_store.query(
"""
SELECT DISTINCT p.*
FROM paragraphs p
JOIN paragraph_entities pe ON pe.paragraph_hash = p.hash
JOIN entities e ON e.hash = pe.entity_hash
WHERE e.name = ?
AND (p.is_deleted IS NULL OR p.is_deleted = 0)
ORDER BY COALESCE(p.event_time_end, p.event_time_start, p.event_time, p.updated_at, p.created_at) DESC
LIMIT ?
""",
(person_id, max(1, int(limit)) * 3),
)
evidence_ids = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()]
vector_evidence = [{"hash": str(row.get("hash", "") or ""), "type": "paragraph", "score": 0.0, "content": str(row.get("content", "") or "")[:220], "metadata": row.get("metadata", {}) or {}} for row in rows[: max(1, int(limit))]]
relation_edges = [{"hash": str(row.get("hash", "") or ""), "subject": str(row.get("subject", "") or ""), "predicate": str(row.get("predicate", "") or ""), "object": str(row.get("object", "") or ""), "confidence": float(row.get("confidence", 1.0) or 1.0)} for row in self.metadata_store.get_relations(subject=person_id)[:limit]]
if relation_edges:
profile_text = "\n".join(f"{item['subject']} {item['predicate']} {item['object']}" for item in relation_edges[:6])
elif vector_evidence:
profile_text = "\n".join(f"- {item['content']}" for item in vector_evidence[:6])
else:
profile_text = "暂无稳定画像证据。"
return self.metadata_store.upsert_person_profile_snapshot(
person_id=person_id,
profile_text=profile_text,
aliases=[person_id],
relation_edges=relation_edges,
vector_evidence=vector_evidence,
evidence_ids=evidence_ids[: max(1, int(limit))],
expires_at=time.time() + 6 * 3600,
source_note="sdk_memory_kernel",
)
async def maintain_memory(self, *, action: str, target: str, hours: Optional[float] = None, reason: str = "") -> Dict[str, Any]:
_ = reason
await self.initialize()
assert self.metadata_store
hashes = self._resolve_relation_hashes(target)
if not hashes:
return {"success": False, "detail": "未命中可维护关系"}
act = str(action or "").strip().lower()
if act == "reinforce":
self.metadata_store.reinforce_relations(hashes)
elif act == "protect":
ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0
self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0)
elif act == "restore":
restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value))
if restored <= 0:
return {"success": False, "detail": "未恢复任何关系"}
else:
return {"success": False, "detail": f"不支持的维护动作: {act}"}
self._last_maintenance_at = time.time()
self._persist()
return {"success": True, "detail": f"{act} {len(hashes)} 条关系"}
def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> int:
assert self.metadata_store
rebuilt = 0
for source in self._tokens(sources):
rows = self.metadata_store.query(
"""
SELECT * FROM paragraphs
WHERE source = ?
AND (is_deleted IS NULL OR is_deleted = 0)
ORDER BY COALESCE(event_time_start, event_time, created_at) ASC, hash ASC
""",
(source,),
)
if not rows:
continue
paragraph_hashes = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()]
payload = self.metadata_store.upsert_episode(
{
"source": source,
"title": str((rows[0].get("metadata", {}) or {}).get("theme", "") or f"{source} 情景记忆")[:80],
"summary": "".join(str(row.get("content", "") or "").strip().replace("\n", " ")[:120] for row in rows[:3] if str(row.get("content", "") or "").strip())[:500] or "自动构建的情景记忆。",
"participants": self._episode_participants(rows),
"keywords": self._episode_keywords(rows),
"evidence_ids": paragraph_hashes,
"paragraph_count": len(paragraph_hashes),
"event_time_start": self._time_bound(rows, "event_time_start", "event_time", reverse=False),
"event_time_end": self._time_bound(rows, "event_time_end", "event_time", reverse=True),
"time_granularity": "day",
"time_confidence": 0.7,
"llm_confidence": 0.0,
"segmentation_model": "rule_based_sdk",
"segmentation_version": "1",
}
)
self.metadata_store.bind_episode_paragraphs(payload["episode_id"], paragraph_hashes)
rebuilt += 1
return rebuilt
def memory_stats(self) -> Dict[str, Any]:
assert self.metadata_store
stats = self.metadata_store.get_statistics()
episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"]
profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"]
return {"paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "last_maintenance_at": self._last_maintenance_at}
async def _aggregate_search(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]:
assert self.retriever
hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)]
return {"success": True, "results": hits, "count": len(hits), "query_type": "search"}
async def _aggregate_time(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]:
if temporal is None:
return {"success": False, "error": "missing temporal window", "results": []}
assert self.retriever
hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)]
return {"success": True, "results": hits, "count": len(hits), "query_type": "time"}
async def _aggregate_episode(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]:
assert self.episode_retriever
rows = await self.episode_retriever.query(query=query, top_k=limit, time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id))
hits = [self._episode_hit(row) for row in rows]
return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"}
def _persist(self) -> None:
if self.vector_store is not None:
self.vector_store.save()
if self.graph_store is not None:
self.graph_store.save()
if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False):
self.sparse_index.ensure_loaded()
@staticmethod
def _tokens(values: Optional[Iterable[Any]]) -> List[str]:
result: List[str] = []
seen = set()
for item in values or []:
token = str(item or "").strip()
if not token or token in seen:
continue
seen.add(token)
result.append(token)
return result
@classmethod
def _merge_tokens(cls, *groups: Optional[Iterable[Any]]) -> List[str]:
merged: List[str] = []
seen = set()
for group in groups:
for item in cls._tokens(group):
if item in seen:
continue
seen.add(item)
merged.append(item)
return merged
@staticmethod
def _build_source(source_type: str, chat_id: str, person_ids: Sequence[str]) -> str:
clean_type = str(source_type or "").strip() or "memory"
if clean_type == "chat_summary" and chat_id:
return f"chat_summary:{chat_id}"
if clean_type == "person_fact" and person_ids:
return f"person_fact:{person_ids[0]}"
return f"{clean_type}:{chat_id}" if chat_id else clean_type
@staticmethod
def _chat_source(chat_id: str) -> Optional[str]:
clean = str(chat_id or "").strip()
return f"chat_summary:{clean}" if clean else None
@staticmethod
def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]:
payload: Dict[str, Any] = {}
if timestamp is not None:
payload["event_time"] = float(timestamp)
if time_start is not None:
payload["event_time_start"] = float(time_start)
if time_end is not None:
payload["event_time_end"] = float(time_end)
if payload:
payload["time_granularity"] = "minute"
payload["time_confidence"] = 0.95
return payload
def _temporal(self, request: KernelSearchRequest) -> Optional[TemporalQueryOptions]:
if request.time_start is None and request.time_end is None and not request.chat_id:
return None
return TemporalQueryOptions(time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id))
@staticmethod
def _retrieval_hit(item: RetrievalResult) -> Dict[str, Any]:
payload = item.to_dict()
return {"hash": payload.get("hash", ""), "content": payload.get("content", ""), "score": payload.get("score", 0.0), "type": payload.get("type", ""), "source": payload.get("source", ""), "metadata": payload.get("metadata", {}) or {}}
@staticmethod
def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]:
return {"type": "episode", "episode_id": str(row.get("episode_id", "") or ""), "title": str(row.get("title", "") or ""), "content": str(row.get("summary", "") or ""), "score": float(row.get("lexical_score", 0.0) or 0.0), "source": "episode", "metadata": {"participants": row.get("participants", []) or [], "keywords": row.get("keywords", []) or [], "source": row.get("source"), "event_time_start": row.get("event_time_start"), "event_time_end": row.get("event_time_end")}}
@staticmethod
def _summary(hits: Sequence[Dict[str, Any]]) -> str:
if not hits:
return ""
lines = []
for index, item in enumerate(hits[:5], start=1):
content = str(item.get("content", "") or "").strip().replace("\n", " ")
lines.append(f"{index}. {(content[:120] + '...') if len(content) > 120 else content}")
return "\n".join(lines)
@staticmethod
def _filter_hits(hits: List[Dict[str, Any]], person_id: str) -> List[Dict[str, Any]]:
if not person_id:
return hits
filtered = []
for item in hits:
metadata = item.get("metadata", {}) or {}
if person_id in (metadata.get("person_ids", []) or []):
filtered.append(item)
continue
if person_id and person_id in str(item.get("content", "") or ""):
filtered.append(item)
return filtered or hits
@staticmethod
def _episode_participants(rows: Sequence[Dict[str, Any]]) -> List[str]:
seen = set()
result: List[str] = []
for row in rows:
meta = row.get("metadata", {}) or {}
for key in ("participants", "person_ids"):
for item in meta.get(key, []) or []:
token = str(item or "").strip()
if not token or token in seen:
continue
seen.add(token)
result.append(token)
return result[:16]
@staticmethod
def _episode_keywords(rows: Sequence[Dict[str, Any]]) -> List[str]:
seen = set()
result: List[str] = []
for row in rows:
meta = row.get("metadata", {}) or {}
for item in meta.get("tags", []) or []:
token = str(item or "").strip()
if not token or token in seen:
continue
seen.add(token)
result.append(token)
return result[:12]
@staticmethod
def _time_bound(rows: Sequence[Dict[str, Any]], primary: str, fallback: str, reverse: bool) -> Optional[float]:
values: List[float] = []
for row in rows:
for key in (primary, fallback):
value = row.get(key)
try:
if value is not None:
values.append(float(value))
break
except Exception:
continue
if not values:
return None
return max(values) if reverse else min(values)
def _resolve_relation_hashes(self, target: str) -> List[str]:
assert self.metadata_store
token = str(target or "").strip()
if not token:
return []
if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()):
return [token]
hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10)
if hashes:
return hashes
return [str(row.get("hash", "") or "") for row in self.metadata_store.get_relations(subject=token)[:10] if str(row.get("hash", "")).strip()]

View File

@@ -0,0 +1,53 @@
"""存储层"""
from .vector_store import VectorStore, QuantizationType
from .graph_store import GraphStore, SparseMatrixFormat
from .metadata_store import MetadataStore
from .knowledge_types import (
ImportStrategy,
KnowledgeType,
allowed_import_strategy_values,
allowed_knowledge_type_values,
get_knowledge_type_from_string,
get_import_strategy_from_string,
parse_import_strategy,
resolve_stored_knowledge_type,
should_extract_relations,
get_default_chunk_size,
get_type_display_name,
validate_stored_knowledge_type,
)
from .type_detection import (
detect_knowledge_type,
get_type_from_user_input,
looks_like_factual_text,
looks_like_quote_text,
looks_like_structured_text,
select_import_strategy,
)
__all__ = [
"VectorStore",
"GraphStore",
"MetadataStore",
"QuantizationType",
"SparseMatrixFormat",
"ImportStrategy",
"KnowledgeType",
"allowed_import_strategy_values",
"allowed_knowledge_type_values",
"get_knowledge_type_from_string",
"get_import_strategy_from_string",
"parse_import_strategy",
"resolve_stored_knowledge_type",
"should_extract_relations",
"get_default_chunk_size",
"get_type_display_name",
"validate_stored_knowledge_type",
"detect_knowledge_type",
"get_type_from_user_input",
"looks_like_factual_text",
"looks_like_quote_text",
"looks_like_structured_text",
"select_import_strategy",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,183 @@
"""Knowledge type and import strategy helpers."""
from __future__ import annotations
from enum import Enum
from typing import Any, Optional
class KnowledgeType(str, Enum):
"""持久化到 paragraphs.knowledge_type 的合法类型。"""
STRUCTURED = "structured"
NARRATIVE = "narrative"
FACTUAL = "factual"
QUOTE = "quote"
MIXED = "mixed"
class ImportStrategy(str, Enum):
"""文本导入阶段的策略选择。"""
AUTO = "auto"
NARRATIVE = "narrative"
FACTUAL = "factual"
QUOTE = "quote"
def allowed_knowledge_type_values() -> tuple[str, ...]:
return tuple(item.value for item in KnowledgeType)
def allowed_import_strategy_values() -> tuple[str, ...]:
return tuple(item.value for item in ImportStrategy)
def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]:
"""从字符串解析合法的落库知识类型。"""
if not isinstance(type_str, str):
return None
normalized = type_str.lower().strip()
try:
return KnowledgeType(normalized)
except ValueError:
return None
def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]:
"""从字符串解析文本导入策略。"""
if not isinstance(value, str):
return None
normalized = value.lower().strip()
try:
return ImportStrategy(normalized)
except ValueError:
return None
def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy:
"""解析 import strategy非法值直接报错。"""
if value is None:
return default
if isinstance(value, ImportStrategy):
return value
normalized = str(value or "").strip().lower()
if not normalized:
return default
strategy = get_import_strategy_from_string(normalized)
if strategy is None:
allowed = "/".join(allowed_import_strategy_values())
raise ValueError(f"strategy_override 必须为 {allowed}")
return strategy
def validate_stored_knowledge_type(value: Any) -> KnowledgeType:
"""校验写库 knowledge_type仅允许合法落库类型。"""
if isinstance(value, KnowledgeType):
return value
resolved = get_knowledge_type_from_string(value)
if resolved is None:
allowed = "/".join(allowed_knowledge_type_values())
raise ValueError(f"knowledge_type 必须为 {allowed}")
return resolved
def resolve_stored_knowledge_type(
value: Any,
*,
content: str = "",
allow_legacy: bool = False,
unknown_fallback: Optional[KnowledgeType] = None,
) -> KnowledgeType:
"""
将策略/字符串/旧值解析为合法落库类型。
`allow_legacy=True` 仅供迁移使用。
"""
if isinstance(value, KnowledgeType):
return value
if isinstance(value, ImportStrategy):
if value == ImportStrategy.AUTO:
if not str(content or "").strip():
raise ValueError("knowledge_type=auto 需要 content 才能推断")
from .type_detection import detect_knowledge_type
return detect_knowledge_type(content)
return KnowledgeType(value.value)
raw = str(value or "").strip()
if not raw:
if str(content or "").strip():
from .type_detection import detect_knowledge_type
return detect_knowledge_type(content)
raise ValueError("knowledge_type 不能为空")
direct = get_knowledge_type_from_string(raw)
if direct is not None:
return direct
strategy = get_import_strategy_from_string(raw)
if strategy is not None:
return resolve_stored_knowledge_type(strategy, content=content)
if allow_legacy:
normalized = raw.lower()
if normalized == "imported":
return KnowledgeType.FACTUAL
if str(content or "").strip():
from .type_detection import detect_knowledge_type
detected = detect_knowledge_type(content)
if detected is not None:
return detected
if unknown_fallback is not None:
return unknown_fallback
allowed = "/".join(allowed_knowledge_type_values())
raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed}")
def should_extract_relations(knowledge_type: KnowledgeType) -> bool:
"""判断是否应该做关系抽取。"""
return knowledge_type in [
KnowledgeType.STRUCTURED,
KnowledgeType.FACTUAL,
KnowledgeType.MIXED,
]
def get_default_chunk_size(knowledge_type: KnowledgeType) -> int:
"""获取默认分块大小。"""
chunk_sizes = {
KnowledgeType.STRUCTURED: 300,
KnowledgeType.NARRATIVE: 800,
KnowledgeType.FACTUAL: 500,
KnowledgeType.QUOTE: 400,
KnowledgeType.MIXED: 500,
}
return chunk_sizes.get(knowledge_type, 500)
def get_type_display_name(knowledge_type: KnowledgeType) -> str:
"""获取知识类型中文名称。"""
display_names = {
KnowledgeType.STRUCTURED: "结构化知识",
KnowledgeType.NARRATIVE: "叙事性文本",
KnowledgeType.FACTUAL: "事实陈述",
KnowledgeType.QUOTE: "引用文本",
KnowledgeType.MIXED: "混合类型",
}
return display_names.get(knowledge_type, "未知类型")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,137 @@
"""Heuristic detection for import strategies and stored knowledge types."""
from __future__ import annotations
import re
from typing import Optional
from .knowledge_types import (
ImportStrategy,
KnowledgeType,
parse_import_strategy,
resolve_stored_knowledge_type,
)
_NARRATIVE_MARKERS = [
r"然后",
r"接着",
r"于是",
r"后来",
r"最后",
r"突然",
r"一天",
r"曾经",
r"有一次",
r"从前",
r"说道",
r"问道",
r"想着",
r"觉得",
]
_FACTUAL_MARKERS = [
r"",
r"",
r"",
r"",
r"属于",
r"位于",
r"包含",
r"拥有",
r"成立于",
r"出生于",
]
def _non_empty_lines(content: str) -> list[str]:
return [line for line in str(content or "").splitlines() if line.strip()]
def looks_like_structured_text(content: str) -> bool:
text = str(content or "").strip()
if "|" not in text or text.count("|") < 2:
return False
parts = text.split("|")
return len(parts) == 3 and all(part.strip() for part in parts)
def looks_like_quote_text(content: str) -> bool:
lines = _non_empty_lines(content)
if len(lines) < 5:
return False
avg_len = sum(len(line) for line in lines) / len(lines)
return avg_len < 20
def looks_like_narrative_text(content: str) -> bool:
text = str(content or "").strip()
if not text:
return False
narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text))
has_dialogue = bool(re.search(r'["「『].*?["」』]', text))
has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###"))
return has_chapter or has_dialogue or narrative_score >= 2
def looks_like_factual_text(content: str) -> bool:
text = str(content or "").strip()
if not text:
return False
if looks_like_structured_text(text) or looks_like_quote_text(text):
return False
factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text))
if factual_score <= 0:
return False
if len(text) <= 240:
return True
return factual_score >= 2 and not looks_like_narrative_text(text)
def select_import_strategy(
content: str,
*,
override: Optional[str | ImportStrategy] = None,
chat_log: bool = False,
) -> ImportStrategy:
"""文本导入策略选择override > quote > factual > narrative。"""
if chat_log:
return ImportStrategy.NARRATIVE
strategy = parse_import_strategy(override, default=ImportStrategy.AUTO)
if strategy != ImportStrategy.AUTO:
return strategy
if looks_like_quote_text(content):
return ImportStrategy.QUOTE
if looks_like_factual_text(content):
return ImportStrategy.FACTUAL
return ImportStrategy.NARRATIVE
def detect_knowledge_type(content: str) -> KnowledgeType:
"""自动检测落库 knowledge_type无法可靠判断时回退 mixed。"""
text = str(content or "").strip()
if not text:
return KnowledgeType.MIXED
if looks_like_structured_text(text):
return KnowledgeType.STRUCTURED
if looks_like_quote_text(text):
return KnowledgeType.QUOTE
if looks_like_factual_text(text):
return KnowledgeType.FACTUAL
if looks_like_narrative_text(text):
return KnowledgeType.NARRATIVE
return KnowledgeType.MIXED
def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType:
"""优先使用显式 type_hint否则自动检测。"""
if type_hint:
return resolve_stored_knowledge_type(type_hint, content=content)
return detect_knowledge_type(content)

View File

@@ -0,0 +1,776 @@
"""
向量存储模块
基于Faiss的高效向量存储与检索支持SQ8量化、Append-Only磁盘存储和内存映射。
"""
import os
import pickle
import hashlib
import shutil
import time
from pathlib import Path
from typing import Optional, Union, Tuple, List, Dict, Set, Any
import random
import threading # Added threading import
import numpy as np
try:
import faiss
HAS_FAISS = True
except ImportError:
HAS_FAISS = False
from src.common.logger import get_logger
from ..utils.quantization import QuantizationType
from ..utils.io import atomic_write, atomic_save_path
logger = get_logger("A_Memorix.VectorStore")
class VectorStore:
"""
向量存储类 (SQ8 + Append-Only Disk)
特性:
- 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit))
- 存储: float16 on-disk binary (vectors.bin)
- 内存: 仅索引常驻 RAM (<512MB for 100k vectors)
- ID: SHA1-based stable int64 IDs
- 一致性: 强制 L2 Normalization (IP == Cosine)
"""
# 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化)
DEFAULT_MIN_TRAIN = 40
# 强制训练样本量
TRAIN_SIZE = 10000
# 储水池采样上限 (流式处理前 50k 数据)
RESERVOIR_CAPACITY = 10000
RESERVOIR_SAMPLE_SCOPE = 50000
def __init__(
self,
dimension: int,
quantization_type: QuantizationType = QuantizationType.INT8,
index_type: str = "sq8",
data_dir: Optional[Union[str, Path]] = None,
use_mmap: bool = True,
buffer_size: int = 1024,
):
if not HAS_FAISS:
raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu")
self.dimension = dimension
self.data_dir = Path(data_dir) if data_dir else None
if self.data_dir:
self.data_dir.mkdir(parents=True, exist_ok=True)
if quantization_type != QuantizationType.INT8:
raise ValueError(
"vNext 仅支持 quantization_type=int8(SQ8)。"
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
)
normalized_index_type = str(index_type or "sq8").strip().lower()
if normalized_index_type not in {"sq8", "int8"}:
raise ValueError(
"vNext 仅支持 index_type=sq8。"
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
)
self.quantization_type = QuantizationType.INT8
self.index_type = "sq8"
self.buffer_size = buffer_size
self._index: Optional[faiss.IndexIDMap2] = None
self._init_index()
self._is_trained = False
self._vector_norm = "l2"
# Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力
# 必须使用 IndexIDMap2 以保证 ID 与主索引一致
self._fallback_index: Optional[faiss.IndexIDMap2] = None
self._init_fallback_index()
self._known_hashes: Set[str] = set()
self._deleted_ids: Set[int] = set()
self._reservoir_buffer: List[np.ndarray] = []
self._seen_count_for_reservoir = 0
self._write_buffer_vecs: List[np.ndarray] = []
self._write_buffer_ids: List[int] = []
self._total_added = 0
self._total_deleted = 0
self._bin_count = 0
# Thread safety lock
self._lock = threading.RLock()
logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage")
def _init_index(self):
"""初始化空的 Faiss 索引"""
quantizer = faiss.IndexScalarQuantizer(
self.dimension,
faiss.ScalarQuantizer.QT_8bit,
faiss.METRIC_INNER_PRODUCT
)
self._index = faiss.IndexIDMap2(quantizer)
self._is_trained = False
def _init_fallback_index(self):
"""初始化 Flat 回退索引"""
flat_index = faiss.IndexFlatIP(self.dimension)
self._fallback_index = faiss.IndexIDMap2(flat_index)
logger.debug("Fallback index (Flat) initialized.")
@staticmethod
def _generate_id(key: str) -> int:
"""生成稳定的 int64 ID (SHA1 截断)"""
h = hashlib.sha1(key.encode("utf-8")).digest()
val = int.from_bytes(h[:8], byteorder="big", signed=False)
return val & 0x7FFFFFFFFFFFFFFF
@property
def _bin_path(self) -> Path:
return self.data_dir / "vectors.bin"
@property
def _ids_bin_path(self) -> Path:
return self.data_dir / "vectors_ids.bin"
@property
def _int_to_str_map(self) -> Dict[int, str]:
"""Lazy build volatile map from known hashes"""
# Note: This is read-heavy and cached, might need lock if _known_hashes updates concurrently
# But add/delete are now locked, so checking len mismatch is somewhat safe-ish for quick dirty cache
if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes):
with self._lock: # Protect cache rebuild
self._cached_map = {self._generate_id(k): k for k in self._known_hashes}
return self._cached_map
def add(self, vectors: np.ndarray, ids: List[str]) -> int:
with self._lock:
if vectors.shape[1] != self.dimension:
raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}")
vectors = np.ascontiguousarray(vectors, dtype=np.float32)
faiss.normalize_L2(vectors)
processed_vecs = []
processed_int_ids = []
for i, str_id in enumerate(ids):
if str_id in self._known_hashes:
continue
int_id = self._generate_id(str_id)
self._known_hashes.add(str_id)
processed_vecs.append(vectors[i])
processed_int_ids.append(int_id)
if not processed_vecs:
return 0
batch_vecs = np.array(processed_vecs, dtype=np.float32)
batch_ids = np.array(processed_int_ids, dtype=np.int64)
self._write_buffer_vecs.append(batch_vecs)
self._write_buffer_ids.extend(processed_int_ids)
if len(self._write_buffer_ids) >= self.buffer_size:
self._flush_write_buffer_unlocked()
if not self._is_trained:
# 双写到回退索引
self._fallback_index.add_with_ids(batch_vecs, batch_ids)
self._update_reservoir(batch_vecs)
# 这里的 TRAIN_SIZE 取默认 10k或者根据当前数据量动态判断
if len(self._reservoir_buffer) >= 10000:
logger.info(f"训练样本达到上限,开始训练...")
self._train_and_replay_unlocked()
self._total_added += len(batch_ids)
return len(batch_ids)
def _flush_write_buffer(self):
with self._lock:
self._flush_write_buffer_unlocked()
def _flush_write_buffer_unlocked(self):
if not self._write_buffer_vecs:
return
batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0)
batch_ids = np.array(self._write_buffer_ids, dtype=np.int64)
vecs_fp16 = batch_vecs.astype(np.float16)
with open(self._bin_path, "ab") as f:
f.write(vecs_fp16.tobytes())
ids_bytes = batch_ids.astype('>i8').tobytes()
with open(self._ids_bin_path, "ab") as f:
f.write(ids_bytes)
self._bin_count += len(batch_ids)
if self._is_trained and self._index.is_trained:
self._index.add_with_ids(batch_vecs, batch_ids)
else:
# 即使在 flush 时,如果未训练,也要同步到 fallback
self._fallback_index.add_with_ids(batch_vecs, batch_ids)
self._write_buffer_vecs.clear()
self._write_buffer_ids.clear()
def _update_reservoir(self, vectors: np.ndarray):
for vec in vectors:
self._seen_count_for_reservoir += 1
if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY:
self._reservoir_buffer.append(vec)
else:
if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE:
r = random.randint(0, self._seen_count_for_reservoir - 1)
if r < self.RESERVOIR_CAPACITY:
self._reservoir_buffer[r] = vec
def _train_and_replay(self):
with self._lock:
self._train_and_replay_unlocked()
def _train_and_replay_unlocked(self):
if not self._reservoir_buffer:
logger.warning("No training data available.")
return
train_data = np.array(self._reservoir_buffer, dtype=np.float32)
logger.info(f"Training Index with {len(train_data)} samples...")
try:
self._index.train(train_data)
except Exception as e:
logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.")
return
self._is_trained = True
self._reservoir_buffer = []
logger.info("Replaying data from disk to populate index...")
try:
replay_count = self._replay_vectors_to_index()
# 只有当 replay 成功且数据量一致时,才释放回退索引
if self._index.ntotal >= self._bin_count:
logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.")
self._fallback_index.reset()
else:
logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.")
except Exception as e:
logger.error(f"Replay failed: {e}. Keeping fallback index as backup.")
def _replay_vectors_to_index(self) -> int:
"""从 vectors.bin 读取并添加到 index"""
if not self._bin_path.exists() or not self._ids_bin_path.exists():
return 0
vec_item_size = self.dimension * 2
id_item_size = 8
chunk_size = 10000
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
while True:
vec_data = f_vec.read(chunk_size * vec_item_size)
id_data = f_id.read(chunk_size * id_item_size)
if not vec_data:
break
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
batch_fp32 = batch_fp16.astype(np.float32)
faiss.normalize_L2(batch_fp32)
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
if not all(valid_mask):
batch_fp32 = batch_fp32[valid_mask]
batch_ids = batch_ids[valid_mask]
if len(batch_ids) > 0:
self._index.add_with_ids(batch_fp32, batch_ids)
def search(
self,
query: np.ndarray,
k: int = 10,
filter_deleted: bool = True,
) -> Tuple[List[str], List[float]]:
query_local = np.array(query, dtype=np.float32, order="C", copy=True)
if query_local.ndim == 1:
got_dim = int(query_local.shape[0])
query_local = query_local.reshape(1, -1)
elif query_local.ndim == 2:
if query_local.shape[0] != 1:
raise ValueError(
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
)
got_dim = int(query_local.shape[1])
else:
raise ValueError(
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
)
if got_dim != self.dimension:
raise ValueError(
f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}"
)
if not np.all(np.isfinite(query_local)):
raise ValueError("query embedding contains non-finite values")
faiss.normalize_L2(query_local)
# 查询路径仅负责检索,不在此触发训练/回放。
# 训练/回放前置到 warmup_index(),并由插件启动阶段触发。
# Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。
with self._lock:
self._flush_write_buffer_unlocked()
search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index
if search_index.ntotal == 0:
logger.warning("Indices are empty. No data to search.")
return [], []
# 执行检索
dists, ids = search_index.search(query_local, k * 2)
# Faiss search 返回的是 (1, K) 的数组,取第一行
dists = dists[0]
ids = ids[0]
results = []
for id_val, score in zip(ids, dists):
if id_val == -1: continue
if filter_deleted and id_val in self._deleted_ids:
continue
str_id = self._int_to_str_map.get(id_val)
if str_id:
results.append((str_id, float(score)))
# Sort and trim just in case filtering reduced count
results.sort(key=lambda x: x[1], reverse=True)
results = results[:k]
if not results:
return [], []
return [r[0] for r in results], [r[1] for r in results]
def warmup_index(self, force_train: bool = True) -> Dict[str, Any]:
"""
预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。
Args:
force_train: 是否在满足阈值时强制训练 SQ8 索引
Returns:
预热状态摘要
"""
started = time.perf_counter()
logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}")
try:
with self._lock:
self._flush_write_buffer()
if self._bin_path.exists():
self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2)
else:
self._bin_count = 0
needs_fallback_bootstrap = (
self._bin_count > 0
and self._fallback_index.ntotal == 0
and (not self._is_trained or self._index.ntotal == 0)
)
if needs_fallback_bootstrap:
self._bootstrap_fallback_from_disk()
min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN)))
needs_train = (
bool(force_train)
and self._bin_count >= min_train
and not self._is_trained
)
if needs_train:
self._force_train_small_data()
duration_ms = (time.perf_counter() - started) * 1000.0
summary = {
"ok": True,
"trained": bool(self._is_trained),
"index_ntotal": int(self._index.ntotal),
"fallback_ntotal": int(self._fallback_index.ntotal),
"bin_count": int(self._bin_count),
"duration_ms": duration_ms,
"error": None,
}
except Exception as e:
duration_ms = (time.perf_counter() - started) * 1000.0
summary = {
"ok": False,
"trained": bool(self._is_trained),
"index_ntotal": int(self._index.ntotal) if self._index is not None else 0,
"fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0,
"bin_count": int(getattr(self, "_bin_count", 0)),
"duration_ms": duration_ms,
"error": str(e),
}
logger.error(
"metric.vector_index_prewarm_fail=1 "
f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} "
f"error={e}"
)
return summary
logger.info(
"metric.vector_index_prewarm_success=1 "
f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} "
f"trained={summary['trained']} "
f"index_ntotal={summary['index_ntotal']} "
f"fallback_ntotal={summary['fallback_ntotal']} "
f"bin_count={summary['bin_count']}"
)
return summary
def _bootstrap_fallback_from_disk(self):
with self._lock:
self._bootstrap_fallback_from_disk_unlocked()
def _bootstrap_fallback_from_disk_unlocked(self):
"""重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引"""
if not self._bin_path.exists() or not self._ids_bin_path.exists():
return
logger.info("Replaying all disk vectors to fallback index...")
vec_item_size = self.dimension * 2
id_item_size = 8
chunk_size = 10000
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
while True:
vec_data = f_vec.read(chunk_size * vec_item_size)
id_data = f_id.read(chunk_size * id_item_size)
if not vec_data: break
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
batch_fp32 = batch_fp16.astype(np.float32)
faiss.normalize_L2(batch_fp32)
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
if any(valid_mask):
self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask])
logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.")
def _force_train_small_data(self):
with self._lock:
self._force_train_small_data_unlocked()
def _force_train_small_data_unlocked(self):
logger.info("Forcing training on small dataset...")
self._reservoir_buffer = []
chunk_size = 10000
vec_item_size = self.dimension * 2
with open(self._bin_path, "rb") as f:
while len(self._reservoir_buffer) < self.TRAIN_SIZE:
data = f.read(chunk_size * vec_item_size)
if not data: break
fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension)
fp32 = fp16.astype(np.float32)
faiss.normalize_L2(fp32)
for vec in fp32:
self._reservoir_buffer.append(vec)
if len(self._reservoir_buffer) >= self.TRAIN_SIZE:
break
self._train_and_replay_unlocked()
def delete(self, ids: List[str]) -> int:
with self._lock:
count = 0
for str_id in ids:
if str_id not in self._known_hashes:
continue
int_id = self._generate_id(str_id)
if int_id not in self._deleted_ids:
self._deleted_ids.add(int_id)
if self._index.is_trained:
self._index.remove_ids(np.array([int_id], dtype=np.int64))
# 同步从 fallback 移除
if self._fallback_index.ntotal > 0:
self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64))
count += 1
self._total_deleted += count
# Check GC
self._check_rebuild_needed()
return count
def _check_rebuild_needed(self):
"""GC Excution Check"""
if self._bin_count == 0: return
ratio = len(self._deleted_ids) / self._bin_count
if ratio > 0.3 and len(self._deleted_ids) > 1000:
logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})")
self.rebuild_index()
def rebuild_index(self):
"""GC: 重建索引,压缩 bin 文件"""
with self._lock:
self._rebuild_index_locked()
def _rebuild_index_locked(self):
"""实际 GC 重建逻辑。"""
logger.info("Starting Compaction (GC)...")
tmp_bin = self.data_dir / "vectors.bin.tmp"
tmp_ids = self.data_dir / "vectors_ids.bin.tmp"
vec_item_size = self.dimension * 2
id_item_size = 8
chunk_size = 10000
new_count = 0
# 1. Compact Files
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \
open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id:
while True:
vec_data = f_vec.read(chunk_size * vec_item_size)
id_data = f_id.read(chunk_size * id_item_size)
if not vec_data: break
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
if any(keep_mask):
keep_vecs = batch_fp16[keep_mask]
keep_ids = batch_ids[keep_mask]
w_vec.write(keep_vecs.tobytes())
w_id.write(keep_ids.astype('>i8').tobytes())
new_count += len(keep_ids)
# 2. Reset State & Atomic Swap
self._bin_count = new_count
# Close current index
self._index.reset()
if self._fallback_index: self._fallback_index.reset() # Also clear fallback
self._is_trained = False
# Swap files
shutil.move(str(tmp_bin), str(self._bin_path))
shutil.move(str(tmp_ids), str(self._ids_bin_path))
# Reset Tombstones (Critical)
self._deleted_ids.clear()
# 3. Reload/Rebuild Index (Fresh Train)
# We need to re-train because data distribution might have changed significantly after deletion
self._init_index()
self._init_fallback_index() # Re-init fallback too
self._force_train_small_data() # This will train and replay from the NEW compact file
logger.info("Compaction Complete.")
def save(self, data_dir: Optional[Union[str, Path]] = None) -> None:
with self._lock:
if not data_dir:
data_dir = self.data_dir
if not data_dir:
raise ValueError("No data_dir")
data_dir = Path(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
self._flush_write_buffer_unlocked()
if self._is_trained:
index_path = data_dir / "vectors.index"
with atomic_save_path(index_path) as tmp:
faiss.write_index(self._index, tmp)
meta = {
"dimension": self.dimension,
"quantization_type": self.quantization_type.value,
"is_trained": self._is_trained,
"vector_norm": self._vector_norm,
"deleted_ids": list(self._deleted_ids),
"known_hashes": list(self._known_hashes),
}
with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f:
pickle.dump(meta, f)
logger.info("VectorStore saved.")
def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
"""
离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。
"""
with self._lock:
target_dir = Path(data_dir) if data_dir else self.data_dir
if target_dir is None:
raise ValueError("No data_dir")
target_dir = Path(target_dir)
npy_path = target_dir / "vectors.npy"
idx_path = target_dir / "vectors.index"
bin_path = target_dir / "vectors.bin"
ids_bin_path = target_dir / "vectors_ids.bin"
meta_path = target_dir / "vectors_metadata.pkl"
if not npy_path.exists():
return {"migrated": False, "reason": "npy_missing"}
if not meta_path.exists():
raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl")
if bin_path.exists() and ids_bin_path.exists():
return {"migrated": False, "reason": "bin_exists"}
# Reset in-memory state to avoid appending to stale runtime buffers.
self._known_hashes.clear()
self._deleted_ids.clear()
self._write_buffer_vecs.clear()
self._write_buffer_ids.clear()
self._init_index()
self._init_fallback_index()
self._is_trained = False
self._bin_count = 0
self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir)
self.save(target_dir)
return {"migrated": True, "reason": "ok"}
def load(self, data_dir: Optional[Union[str, Path]] = None) -> None:
with self._lock:
if not data_dir: data_dir = self.data_dir
data_dir = Path(data_dir)
npy_path = data_dir / "vectors.npy"
idx_path = data_dir / "vectors.index"
bin_path = data_dir / "vectors.bin"
if npy_path.exists() and not bin_path.exists():
raise RuntimeError(
"检测到 legacy vectors.npyvNext 不再支持运行时自动迁移。"
" 请先执行 scripts/release_vnext_migrate.py migrate。"
)
meta_path = data_dir / "vectors_metadata.pkl"
if not meta_path.exists():
logger.warning("No metadata found, initialized empty.")
return
with open(meta_path, "rb") as f:
meta = pickle.load(f)
if meta.get("vector_norm") != "l2":
logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...")
self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", []))
self._deleted_ids = set(meta.get("deleted_ids", []))
self._init_index()
self._force_train_small_data()
return
self._is_trained = meta.get("is_trained", False)
self._vector_norm = meta.get("vector_norm", "l2")
self._deleted_ids = set(meta.get("deleted_ids", []))
self._known_hashes = set(meta.get("known_hashes", []))
if self._is_trained:
if idx_path.exists():
try:
self._index = faiss.read_index(str(idx_path))
if not isinstance(self._index, faiss.IndexIDMap2):
logger.warning("Loaded index type mismatch. Rebuilding...")
self._init_index()
self._force_train_small_data()
except Exception as e:
logger.error(f"Failed to load index: {e}. Rebuilding...")
self._init_index()
self._force_train_small_data()
else:
logger.warning("Index file missing despite metadata indicating trained. Rebuilding from bin...")
self._init_index()
self._force_train_small_data()
if bin_path.exists():
self._bin_count = bin_path.stat().st_size // (self.dimension * 2)
def _migrate_from_npy(self, npy_path, idx_path, data_dir):
with self._lock:
self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir)
def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir):
try:
arr = np.load(npy_path, mmap_mode="r")
except Exception:
arr = np.load(npy_path)
meta_path = data_dir / "vectors_metadata.pkl"
old_ids = []
if meta_path.exists():
with open(meta_path, "rb") as f:
m = pickle.load(f)
old_ids = m.get("ids", [])
if len(arr) != len(old_ids):
logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}")
return
logger.info(f"Migrating {len(arr)} vectors...")
chunk = 1000
for i in range(0, len(arr), chunk):
sub_arr = arr[i : i+chunk]
sub_ids = old_ids[i : i+chunk]
self.add(sub_arr, sub_ids)
if not self._is_trained:
self._force_train_small_data()
shutil.move(str(npy_path), str(npy_path) + ".bak")
if idx_path.exists():
shutil.move(str(idx_path), str(idx_path) + ".bak")
logger.info("Migration complete.")
def clear(self) -> None:
with self._lock:
self._ids_bin_path.unlink(missing_ok=True)
self._bin_path.unlink(missing_ok=True)
self._init_index()
self._known_hashes.clear()
self._deleted_ids.clear()
self._bin_count = 0
logger.info("VectorStore cleared.")
def has_data(self) -> bool:
return (self.data_dir / "vectors_metadata.pkl").exists()
@property
def num_vectors(self) -> int:
return len(self._known_hashes) - len(self._deleted_ids)
def __contains__(self, hash_value: str) -> bool:
"""Check if a hash exists in the store"""
return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids

View File

@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
import hashlib
class KnowledgeType(str, Enum):
NARRATIVE = "narrative"
FACTUAL = "factual"
QUOTE = "quote"
MIXED = "mixed"
@dataclass
class SourceInfo:
file: str
offset_start: int
offset_end: int
checksum: str = ""
@dataclass
class ChunkContext:
chunk_id: str
index: int
context: Dict[str, Any] = field(default_factory=dict)
text: str = ""
@dataclass
class ChunkFlags:
verbatim: bool = False
requires_llm: bool = True
@dataclass
class ProcessedChunk:
type: KnowledgeType
source: SourceInfo
chunk: ChunkContext
data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities
flags: ChunkFlags = field(default_factory=ChunkFlags)
def to_dict(self) -> Dict:
return {
"type": self.type.value,
"source": {
"file": self.source.file,
"offset_start": self.source.offset_start,
"offset_end": self.source.offset_end,
"checksum": self.source.checksum
},
"chunk": {
"text": self.chunk.text,
"chunk_id": self.chunk.chunk_id,
"index": self.chunk.index,
"context": self.chunk.context
},
"data": self.data,
"flags": {
"verbatim": self.flags.verbatim,
"requires_llm": self.flags.requires_llm
}
}
class BaseStrategy(ABC):
def __init__(self, filename: str):
self.filename = filename
@abstractmethod
def split(self, text: str) -> List[ProcessedChunk]:
"""按策略将文本切分为块。"""
pass
@abstractmethod
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
"""从文本块中抽取结构化信息。"""
pass
def calculate_checksum(self, text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def build_language_guard(self, text: str) -> str:
"""
构建统一的输出语言约束。
不区分语言类型,仅要求抽取值保持原文语言,不做翻译。
"""
_ = text # 预留参数,便于后续按需扩展
return (
"Focus on the original source language. Keep extracted events, entities, predicates "
"and objects in the same language as the source text, preserve names/terms as-is, "
"and do not translate."
)

View File

@@ -0,0 +1,98 @@
import re
from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
class FactualStrategy(BaseStrategy):
def split(self, text: str) -> List[ProcessedChunk]:
# 结构感知切分
lines = text.split('\n')
chunks = []
current_chunk_lines = []
current_len = 0
target_size = 600
for i, line in enumerate(lines):
# 判断是否应当切分
# 若当前行为列表项/定义/表格行,则尽量不切分
is_structure = self._is_structural_line(line)
current_len += len(line) + 1
current_chunk_lines.append(line)
# 达到目标长度且不在紧凑结构块内时切分(过长时强制切分)
if current_len >= target_size and not is_structure:
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
current_chunk_lines = []
current_len = 0
elif current_len >= target_size * 2: # 超长时强制切分
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
current_chunk_lines = []
current_len = 0
if current_chunk_lines:
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
return chunks
def _is_structural_line(self, line: str) -> bool:
line = line.strip()
if not line: return False
# 列表项
if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line):
return True
# 定义项(术语: 定义)
if re.match(r'^[^:]+[:].+', line):
return True
# 表格行(按 markdown 语法假设)
if line.startswith('|') and line.endswith('|'):
return True
return False
def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk:
text = "\n".join(lines)
return ProcessedChunk(
type=KnowledgeType.FACTUAL,
source=SourceInfo(
file=self.filename,
offset_start=0, # 简化处理:真实偏移跟踪需要额外状态
offset_end=0,
checksum=self.calculate_checksum(text)
),
chunk=ChunkContext(
chunk_id=f"{self.filename}_{index}",
index=index,
text=text
)
)
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
if not llm_func:
raise ValueError("LLM function required for Factual extraction")
language_guard = self.build_language_guard(chunk.chunk.text)
prompt = f"""You are a factual knowledge extraction engine.
Extract factual triples and entities from the text.
Preserve lists and definitions accurately.
Language constraints:
- {language_guard}
- Preserve original names and domain terms exactly when possible.
- JSON keys must stay exactly as: triples, entities, subject, predicate, object.
Text:
{chunk.chunk.text}
Return ONLY valid JSON:
{{
"triples": [
{{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}}
],
"entities": ["Entity1", "Entity2"]
}}
"""
result = await llm_func(prompt)
# 结果保持原样存入 data后续统一归一化流程会处理
# vector_store 侧期望关系字段为 subject/predicate/object 映射形式
chunk.data = result
return chunk

View File

@@ -0,0 +1,126 @@
import re
from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
class NarrativeStrategy(BaseStrategy):
def split(self, text: str) -> List[ProcessedChunk]:
scenes = self._split_into_scenes(text)
chunks = []
for scene_idx, (scene_text, scene_title) in enumerate(scenes):
scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx)
chunks.extend(scene_chunks)
return chunks
def _split_into_scenes(self, text: str) -> List[tuple[str, str]]:
"""按标题或分隔符把文本切分为场景。"""
# 简单启发式:按 markdown 标题或特定分隔符切分
# 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行
# 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行
scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)'
# 保留分隔符,以便识别场景起点
parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE)
scenes = []
current_scene_title = "Start"
current_scene_content = []
if parts and parts[0].strip() == "":
parts = parts[1:]
for part in parts:
if re.match(scene_pattern_str, part, re.MULTILINE):
# 先保存上一段场景
if current_scene_content:
scenes.append(("".join(current_scene_content), current_scene_title))
current_scene_content = []
current_scene_title = part.strip()
else:
current_scene_content.append(part)
if current_scene_content:
scenes.append(("".join(current_scene_content), current_scene_title))
# 若未识别到场景,则把全文视作单一场景
if not scenes:
scenes = [(text, "Whole Text")]
return scenes
def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]:
chunks = []
if len(text) <= window_size:
chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0))
return chunks
stride = window_size - overlap
start = 0
local_idx = 0
while start < len(text):
end = min(start + window_size, len(text))
chunk_text = text[start:end]
# 尽量对齐到最近换行,避免生硬截断句子
# 仅在未到文本尾部时进行回退
if end < len(text):
last_newline = chunk_text.rfind('\n')
if last_newline > window_size // 2: # 仅在回退距离可接受时启用
end = start + last_newline + 1
chunk_text = text[start:end]
chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start))
start += len(chunk_text) - overlap if end < len(text) else len(chunk_text)
local_idx += 1
return chunks
def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk:
return ProcessedChunk(
type=KnowledgeType.NARRATIVE,
source=SourceInfo(
file=self.filename,
offset_start=offset,
offset_end=offset + len(text),
checksum=self.calculate_checksum(text)
),
chunk=ChunkContext(
chunk_id=f"{self.filename}_{scene_idx}_{local_idx}",
index=local_idx,
text=text,
context={"scene_id": scene_id}
)
)
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
if not llm_func:
raise ValueError("LLM function required for Narrative extraction")
language_guard = self.build_language_guard(chunk.chunk.text)
prompt = f"""You are a narrative knowledge extraction engine.
Extract key events and character relations from the scene text.
Language constraints:
- {language_guard}
- Preserve original names and terms exactly when possible.
- JSON keys must stay exactly as: events, relations, subject, predicate, object.
Scene:
{chunk.chunk.context.get('scene_id')}
Text:
{chunk.chunk.text}
Return ONLY valid JSON:
{{
"events": ["event description 1", "event description 2"],
"relations": [
{{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}}
]
}}
"""
result = await llm_func(prompt)
chunk.data = result
return chunk

View File

@@ -0,0 +1,52 @@
from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags
class QuoteStrategy(BaseStrategy):
def split(self, text: str) -> List[ProcessedChunk]:
# Split by double newlines (stanzas)
stanzas = text.split("\n\n")
chunks = []
offset = 0
for idx, stanza in enumerate(stanzas):
if not stanza.strip():
offset += len(stanza) + 2
continue
chunk = ProcessedChunk(
type=KnowledgeType.QUOTE,
source=SourceInfo(
file=self.filename,
offset_start=offset,
offset_end=offset + len(stanza),
checksum=self.calculate_checksum(stanza)
),
chunk=ChunkContext(
chunk_id=f"{self.filename}_{idx}",
index=idx,
text=stanza
),
flags=ChunkFlags(
verbatim=True,
requires_llm=False # Default to no LLM, but can be overridden
)
)
chunks.append(chunk)
offset += len(stanza) + 2 # +2 for \n\n
return chunks
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
# For quotes, the text itself is the entity/knowledge
# We might use LLM to extract headers/metadata if requested, but core logic is pass-through
# Treat the whole chunk text as a verbatim entity
chunk.data = {
"verbatim_entities": [chunk.chunk.text]
}
if llm_func and chunk.flags.requires_llm:
# Optional: Extract metadata
pass
return chunk

View File

@@ -0,0 +1,33 @@
"""工具模块 - 哈希、监控等辅助功能"""
from .hash import compute_hash, normalize_text
from .monitor import MemoryMonitor
from .quantization import quantize_vector, dequantize_vector
from .time_parser import (
parse_query_datetime_to_timestamp,
parse_query_time_range,
parse_ingest_datetime_to_timestamp,
normalize_time_meta,
format_timestamp,
)
from .relation_write_service import RelationWriteService, RelationWriteResult
from .relation_query import RelationQuerySpec, parse_relation_query_spec
from .plugin_id_policy import PluginIdPolicy
__all__ = [
"compute_hash",
"normalize_text",
"MemoryMonitor",
"quantize_vector",
"dequantize_vector",
"parse_query_datetime_to_timestamp",
"parse_query_time_range",
"parse_ingest_datetime_to_timestamp",
"normalize_time_meta",
"format_timestamp",
"RelationWriteService",
"RelationWriteResult",
"RelationQuerySpec",
"parse_relation_query_spec",
"PluginIdPolicy",
]

View File

@@ -0,0 +1,360 @@
"""
聚合查询服务:
- 并发执行 search/time/episode 分支
- 统一分支结果结构
- 可选混合排序Weighted RRF
"""
from __future__ import annotations
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
logger = get_logger("A_Memorix.AggregateQueryService")
BranchRunner = Callable[[], Awaitable[Dict[str, Any]]]
class AggregateQueryService:
"""聚合查询执行服务search/time/episode"""
def __init__(self, plugin_config: Optional[Any] = None):
self.plugin_config = plugin_config or {}
def _cfg(self, key: str, default: Any = None) -> Any:
getter = getattr(self.plugin_config, "get_config", None)
if callable(getter):
return getter(key, default)
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
@staticmethod
def _as_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except Exception:
return float(default)
@staticmethod
def _as_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _rrf_k(self) -> float:
raw = self._cfg("retrieval.aggregate.rrf_k", 60.0)
value = self._as_float(raw, 60.0)
return max(1.0, value)
def _weights(self) -> Dict[str, float]:
defaults = {"search": 1.0, "time": 1.0, "episode": 1.0}
raw = self._cfg("retrieval.aggregate.weights", {})
if not isinstance(raw, dict):
return defaults
out = dict(defaults)
for key in ("search", "time", "episode"):
if key in raw:
out[key] = max(0.0, self._as_float(raw.get(key), defaults[key]))
return out
@staticmethod
def _normalize_branch_payload(
name: str,
payload: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
data = payload if isinstance(payload, dict) else {}
results_raw = data.get("results", [])
results = results_raw if isinstance(results_raw, list) else []
count = data.get("count")
if count is None:
count = len(results)
return {
"name": name,
"success": bool(data.get("success", False)),
"skipped": bool(data.get("skipped", False)),
"skip_reason": str(data.get("skip_reason", "") or "").strip(),
"error": str(data.get("error", "") or "").strip(),
"results": results,
"count": max(0, int(count)),
"elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)),
"content": str(data.get("content", "") or ""),
"query_type": str(data.get("query_type", "") or name),
}
@staticmethod
def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str:
item_type = str(item.get("type", "") or "").strip().lower()
if item_type == "episode":
episode_id = str(item.get("episode_id", "") or "").strip()
if episode_id:
return f"episode:{episode_id}"
item_hash = str(item.get("hash", "") or "").strip()
if item_hash:
return f"{item_type}:{item_hash}"
return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}"
def _build_mixed_results(
self,
*,
branches: Dict[str, Dict[str, Any]],
top_k: int,
) -> List[Dict[str, Any]]:
rrf_k = self._rrf_k()
weights = self._weights()
bucket: Dict[str, Dict[str, Any]] = {}
for branch_name, branch in branches.items():
if not branch.get("success", False):
continue
results = branch.get("results", [])
if not isinstance(results, list):
continue
weight = max(0.0, float(weights.get(branch_name, 1.0)))
for idx, item in enumerate(results, start=1):
if not isinstance(item, dict):
continue
key = self._mix_key(item, branch_name, idx)
score = weight / (rrf_k + float(idx))
if key not in bucket:
merged = dict(item)
merged["fusion_score"] = 0.0
merged["_source_branches"] = set()
bucket[key] = merged
target = bucket[key]
target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score
target["_source_branches"].add(branch_name)
mixed = list(bucket.values())
mixed.sort(
key=lambda x: (
-float(x.get("fusion_score", 0.0)),
str(x.get("type", "") or ""),
str(x.get("hash", "") or x.get("episode_id", "") or ""),
)
)
out: List[Dict[str, Any]] = []
for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1):
merged = dict(item)
branches_set = merged.pop("_source_branches", set())
merged["source_branches"] = sorted(list(branches_set))
merged["rank"] = rank
out.append(merged)
return out
@staticmethod
def _status(branch: Dict[str, Any]) -> str:
if branch.get("skipped", False):
return "skipped"
if branch.get("success", False):
return "success"
return "failed"
def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
summary: Dict[str, Dict[str, Any]] = {}
for name, branch in branches.items():
status = self._status(branch)
summary[name] = {
"status": status,
"count": int(branch.get("count", 0) or 0),
}
if status == "skipped":
summary[name]["reason"] = str(branch.get("skip_reason", "") or "")
if status == "failed":
summary[name]["error"] = str(branch.get("error", "") or "")
return summary
def _build_content(
self,
*,
query: str,
branches: Dict[str, Dict[str, Any]],
errors: List[Dict[str, str]],
mixed_results: Optional[List[Dict[str, Any]]],
) -> str:
lines: List[str] = [
f"🔀 聚合查询结果query='{query or 'N/A'}'",
"",
"分支状态:",
]
for name in ("search", "time", "episode"):
branch = branches.get(name, {})
status = self._status(branch)
count = int(branch.get("count", 0) or 0)
line = f"- {name}: {status}, count={count}"
reason = str(branch.get("skip_reason", "") or "").strip()
err = str(branch.get("error", "") or "").strip()
if status == "skipped" and reason:
line += f" ({reason})"
if status == "failed" and err:
line += f" ({err})"
lines.append(line)
if errors:
lines.append("")
lines.append("错误:")
for item in errors[:6]:
lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}")
if mixed_results is not None:
lines.append("")
lines.append(f"🧩 混合结果({len(mixed_results)} 条):")
for idx, item in enumerate(mixed_results[:5], start=1):
src = ",".join(item.get("source_branches", []) or [])
if str(item.get("type", "") or "") == "episode":
title = str(item.get("title", "") or "Untitled")
lines.append(f"{idx}. 🧠 {title} [{src}]")
else:
text = str(item.get("content", "") or "")
if len(text) > 80:
text = text[:80] + "..."
lines.append(f"{idx}. {text} [{src}]")
return "\n".join(lines)
async def execute(
self,
*,
query: str,
top_k: int,
mix: bool,
mix_top_k: Optional[int],
time_from: Optional[str],
time_to: Optional[str],
search_runner: Optional[BranchRunner],
time_runner: Optional[BranchRunner],
episode_runner: Optional[BranchRunner],
) -> Dict[str, Any]:
clean_query = str(query or "").strip()
safe_top_k = max(1, int(top_k))
safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k))
branches: Dict[str, Dict[str, Any]] = {}
errors: List[Dict[str, str]] = []
scheduled: List[Tuple[str, asyncio.Task]] = []
if clean_query:
if search_runner is not None:
scheduled.append(("search", asyncio.create_task(search_runner())))
else:
branches["search"] = self._normalize_branch_payload(
"search",
{"success": False, "error": "search runner unavailable", "results": []},
)
else:
branches["search"] = self._normalize_branch_payload(
"search",
{
"success": False,
"skipped": True,
"skip_reason": "missing_query",
"results": [],
"count": 0,
},
)
if time_from or time_to:
if time_runner is not None:
scheduled.append(("time", asyncio.create_task(time_runner())))
else:
branches["time"] = self._normalize_branch_payload(
"time",
{"success": False, "error": "time runner unavailable", "results": []},
)
else:
branches["time"] = self._normalize_branch_payload(
"time",
{
"success": False,
"skipped": True,
"skip_reason": "missing_time_window",
"results": [],
"count": 0,
},
)
if episode_runner is not None:
scheduled.append(("episode", asyncio.create_task(episode_runner())))
else:
branches["episode"] = self._normalize_branch_payload(
"episode",
{"success": False, "error": "episode runner unavailable", "results": []},
)
if scheduled:
done = await asyncio.gather(
*[task for _, task in scheduled],
return_exceptions=True,
)
for (branch_name, _), payload in zip(scheduled, done):
if isinstance(payload, Exception):
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
normalized = self._normalize_branch_payload(
branch_name,
{
"success": False,
"error": str(payload),
"results": [],
},
)
else:
normalized = self._normalize_branch_payload(branch_name, payload)
branches[branch_name] = normalized
for name in ("search", "time", "episode"):
branch = branches.get(name)
if not branch:
continue
if branch.get("skipped", False):
continue
if not branch.get("success", False):
errors.append(
{
"branch": name,
"error": str(branch.get("error", "") or "unknown error"),
}
)
success = any(
bool(branches.get(name, {}).get("success", False))
for name in ("search", "time", "episode")
)
mixed_results: Optional[List[Dict[str, Any]]] = None
if mix:
mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k)
payload: Dict[str, Any] = {
"success": success,
"query_type": "aggregate",
"query": clean_query,
"top_k": safe_top_k,
"mix": bool(mix),
"mix_top_k": safe_mix_top_k,
"branches": branches,
"errors": errors,
"summary": self._build_summary(branches),
}
if mixed_results is not None:
payload["mixed_results"] = mixed_results
payload["content"] = self._build_content(
query=clean_query,
branches=branches,
errors=errors,
mixed_results=mixed_results,
)
return payload

View File

@@ -0,0 +1,182 @@
"""Episode hybrid retrieval service."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
from ..retrieval import DualPathRetriever, TemporalQueryOptions
logger = get_logger("A_Memorix.EpisodeRetrievalService")
class EpisodeRetrievalService:
"""Hybrid episode retrieval backed by lexical rows and evidence projection."""
_RRF_K = 60.0
_BRANCH_WEIGHTS = {
"lexical": 1.0,
"paragraph_evidence": 1.0,
"relation_evidence": 0.85,
}
def __init__(
self,
*,
metadata_store: Any,
retriever: Optional[DualPathRetriever] = None,
) -> None:
self.metadata_store = metadata_store
self.retriever = retriever
async def query(
self,
*,
query: str = "",
top_k: int = 5,
time_from: Optional[float] = None,
time_to: Optional[float] = None,
person: Optional[str] = None,
source: Optional[str] = None,
include_paragraphs: bool = False,
) -> List[Dict[str, Any]]:
clean_query = str(query or "").strip()
safe_top_k = max(1, int(top_k))
candidate_k = max(30, safe_top_k * 6)
branches: Dict[str, List[Dict[str, Any]]] = {
"lexical": self.metadata_store.query_episodes(
query=clean_query,
time_from=time_from,
time_to=time_to,
person=person,
source=source,
limit=(candidate_k if clean_query else safe_top_k),
)
}
if clean_query and self.retriever is not None:
try:
temporal = TemporalQueryOptions(
time_from=time_from,
time_to=time_to,
person=person,
source=source,
)
results = await self.retriever.retrieve(
query=clean_query,
top_k=candidate_k,
temporal=temporal,
)
except Exception as exc:
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
else:
paragraph_rank_map: Dict[str, int] = {}
relation_rank_map: Dict[str, int] = {}
for rank, item in enumerate(results, start=1):
hash_value = str(getattr(item, "hash_value", "") or "").strip()
result_type = str(getattr(item, "result_type", "") or "").strip().lower()
if not hash_value:
continue
if result_type == "paragraph" and hash_value not in paragraph_rank_map:
paragraph_rank_map[hash_value] = rank
elif result_type == "relation" and hash_value not in relation_rank_map:
relation_rank_map[hash_value] = rank
if paragraph_rank_map:
paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes(
list(paragraph_rank_map.keys()),
source=source,
)
if paragraph_rows:
branches["paragraph_evidence"] = self._rank_projected_rows(
paragraph_rows,
rank_map=paragraph_rank_map,
support_key="matched_paragraph_hashes",
)
if relation_rank_map:
relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes(
list(relation_rank_map.keys()),
source=source,
)
if relation_rows:
branches["relation_evidence"] = self._rank_projected_rows(
relation_rows,
rank_map=relation_rank_map,
support_key="matched_relation_hashes",
)
fused = self._fuse_branches(branches, top_k=safe_top_k)
if include_paragraphs:
for item in fused:
item["paragraphs"] = self.metadata_store.get_episode_paragraphs(
episode_id=str(item.get("episode_id") or ""),
limit=50,
)
return fused
@staticmethod
def _rank_projected_rows(
rows: List[Dict[str, Any]],
*,
rank_map: Dict[str, int],
support_key: str,
) -> List[Dict[str, Any]]:
sentinel = 10**9
ranked = [dict(item) for item in rows]
def _first_support_rank(item: Dict[str, Any]) -> int:
support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])]
ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map]
return min(ranks) if ranks else sentinel
ranked.sort(
key=lambda item: (
_first_support_rank(item),
-int(item.get("matched_paragraph_count") or 0),
-float(item.get("updated_at") or 0.0),
str(item.get("episode_id") or ""),
)
)
return ranked
def _fuse_branches(
self,
branches: Dict[str, List[Dict[str, Any]]],
*,
top_k: int,
) -> List[Dict[str, Any]]:
bucket: Dict[str, Dict[str, Any]] = {}
for branch_name, rows in branches.items():
weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0)
if weight <= 0.0:
continue
for rank, row in enumerate(rows, start=1):
episode_id = str(row.get("episode_id", "") or "").strip()
if not episode_id:
continue
if episode_id not in bucket:
payload = dict(row)
payload.pop("matched_paragraph_hashes", None)
payload.pop("matched_relation_hashes", None)
payload.pop("matched_paragraph_count", None)
payload.pop("matched_relation_count", None)
payload["_fusion_score"] = 0.0
bucket[episode_id] = payload
bucket[episode_id]["_fusion_score"] = float(
bucket[episode_id].get("_fusion_score", 0.0)
) + weight / (self._RRF_K + float(rank))
out = list(bucket.values())
out.sort(
key=lambda item: (
-float(item.get("_fusion_score", 0.0)),
-float(item.get("updated_at") or 0.0),
str(item.get("episode_id") or ""),
)
)
for item in out:
item.pop("_fusion_score", None)
return out[: max(1, int(top_k))]

View File

@@ -0,0 +1,129 @@
"""
哈希工具模块
提供文本哈希计算功能,用于唯一标识和去重。
"""
import hashlib
import re
from typing import Union
def compute_hash(text: str, hash_type: str = "sha256") -> str:
"""
计算文本的哈希值
Args:
text: 输入文本
hash_type: 哈希算法类型sha256, md5等
Returns:
哈希值字符串
"""
if hash_type == "sha256":
return hashlib.sha256(text.encode("utf-8")).hexdigest()
elif hash_type == "md5":
return hashlib.md5(text.encode("utf-8")).hexdigest()
else:
raise ValueError(f"不支持的哈希算法: {hash_type}")
def normalize_text(text: str) -> str:
"""
规范化文本用于哈希计算
执行以下操作:
- 去除首尾空白
- 统一换行符为\\n
- 压缩多个连续空格
- 去除不可见字符(保留换行和制表符)
Args:
text: 输入文本
Returns:
规范化后的文本
"""
# 去除首尾空白
text = text.strip()
# 统一换行符
text = text.replace("\r\n", "\n").replace("\r", "\n")
# 压缩多个连续空格为一个(但保留换行和制表符)
text = re.sub(r"[^\S\n]+", " ", text)
return text
def compute_paragraph_hash(paragraph: str) -> str:
"""
计算段落的哈希值
Args:
paragraph: 段落文本
Returns:
段落哈希值用于paragraph-前缀)
"""
normalized = normalize_text(paragraph)
return compute_hash(normalized)
def compute_entity_hash(entity: str) -> str:
"""
计算实体的哈希值
Args:
entity: 实体名称
Returns:
实体哈希值用于entity-前缀)
"""
normalized = entity.strip().lower()
return compute_hash(normalized)
def compute_relation_hash(relation: tuple) -> str:
"""
计算关系的哈希值
Args:
relation: 关系元组 (subject, predicate, object)
Returns:
关系哈希值用于relation-前缀)
"""
# 将关系元组转为字符串
relation_str = str(tuple(relation))
return compute_hash(relation_str)
def format_hash_key(hash_type: str, hash_value: str) -> str:
"""
格式化哈希键
Args:
hash_type: 类型前缀paragraph, entity, relation
hash_value: 哈希值
Returns:
格式化的键(如 paragraph-abc123...
"""
return f"{hash_type}-{hash_value}"
def parse_hash_key(key: str) -> tuple[str, str]:
"""
解析哈希键
Args:
key: 格式化的键(如 paragraph-abc123...
Returns:
(类型, 哈希值) 元组
"""
parts = key.split("-", 1)
if len(parts) != 2:
raise ValueError(f"无效的哈希键格式: {key}")
return parts[0], parts[1]

View File

@@ -0,0 +1,110 @@
"""Shared import payload normalization helpers."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from ..storage import KnowledgeType, resolve_stored_knowledge_type
from .time_parser import normalize_time_meta
def _normalize_entities(raw_entities: Any) -> List[str]:
if not isinstance(raw_entities, list):
return []
out: List[str] = []
seen = set()
for item in raw_entities:
name = str(item or "").strip()
if not name:
continue
key = name.lower()
if key in seen:
continue
seen.add(key)
out.append(name)
return out
def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]:
if not isinstance(raw_relations, list):
return []
out: List[Dict[str, str]] = []
for item in raw_relations:
if not isinstance(item, dict):
continue
subject = str(item.get("subject", "")).strip()
predicate = str(item.get("predicate", "")).strip()
obj = str(item.get("object", "")).strip()
if not (subject and predicate and obj):
continue
out.append(
{
"subject": subject,
"predicate": predicate,
"object": obj,
}
)
return out
def normalize_paragraph_import_item(
item: Any,
*,
default_source: str,
) -> Dict[str, Any]:
"""Normalize one paragraph import item from text/json payloads."""
if isinstance(item, str):
content = str(item)
knowledge_type = resolve_stored_knowledge_type(None, content=content)
return {
"content": content,
"knowledge_type": knowledge_type.value,
"source": str(default_source or "").strip(),
"time_meta": None,
"entities": [],
"relations": [],
}
if not isinstance(item, dict) or "content" not in item:
raise ValueError("段落项必须为字符串或包含 content 的对象")
content = str(item.get("content", "") or "")
if not content.strip():
raise ValueError("段落 content 不能为空")
raw_time_meta = {
"event_time": item.get("event_time"),
"event_time_start": item.get("event_time_start"),
"event_time_end": item.get("event_time_end"),
"time_range": item.get("time_range"),
"time_granularity": item.get("time_granularity"),
"time_confidence": item.get("time_confidence"),
}
time_meta_field = item.get("time_meta")
if isinstance(time_meta_field, dict):
raw_time_meta.update(time_meta_field)
knowledge_type_raw = item.get("knowledge_type")
if knowledge_type_raw is None:
knowledge_type_raw = item.get("type")
knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content)
source = str(item.get("source") or default_source or "").strip()
if not source:
source = str(default_source or "").strip()
normalized_time_meta = normalize_time_meta(raw_time_meta)
return {
"content": content,
"knowledge_type": knowledge_type.value,
"source": source,
"time_meta": normalized_time_meta if normalized_time_meta else None,
"entities": _normalize_entities(item.get("entities")),
"relations": _normalize_relations(item.get("relations")),
}
def normalize_summary_knowledge_type(value: Any) -> KnowledgeType:
"""Normalize config-driven summary knowledge type."""
return resolve_stored_knowledge_type(value, content="")

View File

@@ -0,0 +1,84 @@
"""
IO Utilities
提供原子文件写入等IO辅助功能。
"""
import os
import shutil
import contextlib
from pathlib import Path
from typing import Union
@contextlib.contextmanager
def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs):
"""
原子文件写入上下文管理器
原理:
1. 写入 .tmp 临时文件
2. 写入成功后,使用 os.replace 原子替换目标文件
3. 如果失败,自动删除临时文件
Args:
file_path: 目标文件路径
mode: 打开模式 ('w', 'wb' 等)
encoding: 编码
**kwargs: 传给 open() 的其他参数
"""
path = Path(file_path)
# 确保父目录存在
path.parent.mkdir(parents=True, exist_ok=True)
# 临时文件路径
tmp_path = path.with_suffix(path.suffix + ".tmp")
try:
with open(tmp_path, mode, encoding=encoding, **kwargs) as f:
yield f
# 确保写入磁盘
f.flush()
os.fsync(f.fileno())
# 原子替换 (Windows下可能需要先删除目标文件但 os.replace 在 Py3.3+ 尽可能原子)
# 注意: Windows 上如果有其他进程占用文件os.replace 可能会失败
os.replace(tmp_path, path)
except Exception as e:
# 清理临时文件
if tmp_path.exists():
try:
os.remove(tmp_path)
except:
pass
raise e
@contextlib.contextmanager
def atomic_save_path(file_path: Union[str, Path]):
"""
提供临时路径用于原子保存 (针对只接受路径的API如Faiss)
Args:
file_path: 最终目标路径
Yields:
tmp_path: 临时文件路径 (str)
"""
path = Path(file_path)
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(path.suffix + ".tmp")
try:
yield str(tmp_path)
if Path(tmp_path).exists():
os.replace(tmp_path, path)
except Exception as e:
if Path(tmp_path).exists():
try:
os.remove(tmp_path)
except:
pass
raise e

View File

@@ -0,0 +1,89 @@
"""
高效文本匹配工具模块
实现 Aho-Corasick 算法用于多模式匹配。
"""
from typing import List, Dict, Tuple, Set, Any
from collections import deque
class AhoCorasick:
"""
Aho-Corasick 自动机实现高效多模式匹配
"""
def __init__(self):
# next_states[state][char] = next_state
self.next_states: List[Dict[str, int]] = [{}]
# fail[state] = fail_state
self.fail: List[int] = [0]
# output[state] = set of patterns ending at this state
self.output: List[Set[str]] = [set()]
self.patterns: Set[str] = set()
def add_pattern(self, pattern: str):
"""添加模式"""
if not pattern:
return
self.patterns.add(pattern)
state = 0
for char in pattern:
if char not in self.next_states[state]:
new_state = len(self.next_states)
self.next_states[state][char] = new_state
self.next_states.append({})
self.fail.append(0)
self.output.append(set())
state = self.next_states[state][char]
self.output[state].add(pattern)
def build(self):
"""构建失败指针"""
queue = deque()
# 处理第一层
for char, state in self.next_states[0].items():
queue.append(state)
self.fail[state] = 0
while queue:
r = queue.popleft()
for char, s in self.next_states[r].items():
queue.append(s)
# 找到失败路径
state = self.fail[r]
while char not in self.next_states[state] and state != 0:
state = self.fail[state]
self.fail[s] = self.next_states[state].get(char, 0)
# 合并输出
self.output[s].update(self.output[self.fail[s]])
def search(self, text: str) -> List[Tuple[int, str]]:
"""
在文本中搜索所有模式
Returns:
[(结束索引, 匹配到的模式), ...]
"""
state = 0
results = []
for i, char in enumerate(text):
while char not in self.next_states[state] and state != 0:
state = self.fail[state]
state = self.next_states[state].get(char, 0)
for pattern in self.output[state]:
results.append((i, pattern))
return results
def find_all(self, text: str) -> Dict[str, int]:
"""
查找并统计所有模式出现次数
Returns:
{模式: 出现次数}
"""
results = self.search(text)
stats = {}
for _, pattern in results:
stats[pattern] = stats.get(pattern, 0) + 1
return stats

View File

@@ -0,0 +1,189 @@
"""
内存监控模块
提供内存使用监控和预警功能。
"""
import gc
import threading
import time
from typing import Callable, Optional
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
from src.common.logger import get_logger
logger = get_logger("A_Memorix.MemoryMonitor")
class MemoryMonitor:
"""
内存监控器
功能:
- 实时监控内存使用
- 超过阈值时触发警告
- 支持自动垃圾回收
"""
def __init__(
self,
max_memory_mb: int,
warning_threshold: float = 0.9,
check_interval: float = 10.0,
enable_auto_gc: bool = True,
):
"""
初始化内存监控器
Args:
max_memory_mb: 最大内存限制MB
warning_threshold: 警告阈值0-1之间默认0.9表示90%
check_interval: 检查间隔(秒)
enable_auto_gc: 是否启用自动垃圾回收
"""
self.max_memory_mb = max_memory_mb
self.warning_threshold = warning_threshold
self.check_interval = check_interval
self.enable_auto_gc = enable_auto_gc
self._running = False
self._thread: Optional[threading.Thread] = None
self._callbacks: list[Callable[[float, float], None]] = []
def start(self):
"""启动监控"""
if self._running:
logger.warning("内存监控已在运行")
return
if not HAS_PSUTIL:
logger.warning("psutil 未安装,内存监控功能不可用")
return
self._running = True
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
self._thread.start()
logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)")
def stop(self):
"""停止监控"""
if not self._running:
return
self._running = False
if self._thread:
self._thread.join(timeout=5.0)
logger.info("内存监控已停止")
def register_callback(self, callback: Callable[[float, float], None]):
"""
注册内存超限回调函数
Args:
callback: 回调函数,接收 (当前使用MB, 限制MB) 参数
"""
self._callbacks.append(callback)
def get_current_memory_mb(self) -> float:
"""
获取当前进程内存使用量MB
Returns:
内存使用量MB
"""
if not HAS_PSUTIL:
# 降级方案:使用内置函数
import sys
return sys.getsizeof(gc.get_objects()) / 1024 / 1024
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def get_memory_usage_ratio(self) -> float:
"""
获取内存使用率
Returns:
使用率0-1之间
"""
current = self.get_current_memory_mb()
return current / self.max_memory_mb if self.max_memory_mb > 0 else 0
def _monitor_loop(self):
"""监控循环"""
while self._running:
try:
current_mb = self.get_current_memory_mb()
ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0
# 检查是否超过阈值
if ratio >= self.warning_threshold:
logger.warning(
f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB "
f"({ratio*100:.1f}%)"
)
# 触发回调
for callback in self._callbacks:
try:
callback(current_mb, self.max_memory_mb)
except Exception as e:
logger.error(f"内存回调执行失败: {e}")
# 自动垃圾回收
if self.enable_auto_gc:
before = self.get_current_memory_mb()
gc.collect()
after = self.get_current_memory_mb()
freed = before - after
if freed > 1: # 释放超过1MB才记录
logger.info(f"垃圾回收释放: {freed:.1f}MB")
# 定期垃圾回收(即使未超限)
elif self.enable_auto_gc and int(time.time()) % 60 == 0:
gc.collect()
except Exception as e:
logger.error(f"内存监控出错: {e}")
# 等待下次检查
time.sleep(self.check_interval)
def __enter__(self):
"""上下文管理器入口"""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器出口"""
self.stop()
def get_memory_info() -> dict:
"""
获取系统内存信息
Returns:
内存信息字典
"""
if not HAS_PSUTIL:
return {"error": "psutil 未安装"}
try:
mem = psutil.virtual_memory()
process = psutil.Process()
return {
"system_total_gb": mem.total / 1024 / 1024 / 1024,
"system_available_gb": mem.available / 1024 / 1024 / 1024,
"system_usage_percent": mem.percent,
"process_mb": process.memory_info().rss / 1024 / 1024,
"process_percent": (process.memory_info().rss / mem.total) * 100,
}
except Exception as e:
return {"error": str(e)}

View File

@@ -0,0 +1,165 @@
"""Shared path-fallback helpers for search post-processing."""
from __future__ import annotations
import hashlib
from typing import Any, Dict, List, Optional, Sequence, Tuple
from ..retrieval.dual_path import RetrievalResult
def extract_entities(query: str, graph_store: Any) -> List[str]:
"""Extract up to two graph nodes from a query using n-gram matching."""
if not graph_store:
return []
text = str(query or "").strip()
if not text:
return []
# Keep the heuristic aligned with previous legacy behavior.
tokens = (
text.replace("?", " ")
.replace("!", " ")
.replace(".", " ")
.split()
)
if not tokens:
return []
found_entities = set()
skip_indices = set()
max_n = min(4, len(tokens))
for size in range(max_n, 0, -1):
for i in range(len(tokens) - size + 1):
if any(idx in skip_indices for idx in range(i, i + size)):
continue
span = " ".join(tokens[i : i + size])
matched_node = graph_store.find_node(span, ignore_case=True)
if not matched_node:
continue
found_entities.add(matched_node)
for idx in range(i, i + size):
skip_indices.add(idx)
return list(found_entities)
def find_paths_between_entities(
start_node: str,
end_node: str,
graph_store: Any,
metadata_store: Any,
*,
max_depth: int = 3,
max_paths: int = 5,
) -> List[Dict[str, Any]]:
"""Find and enrich indirect paths between two nodes."""
if not graph_store or not metadata_store:
return []
try:
paths = graph_store.find_paths(
start_node,
end_node,
max_depth=max_depth,
max_paths=max_paths,
)
except Exception:
return []
if not paths:
return []
edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {}
formatted_paths: List[Dict[str, Any]] = []
for path_nodes in paths:
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2:
continue
path_desc: List[str] = []
for i in range(len(path_nodes) - 1):
u = str(path_nodes[i])
v = str(path_nodes[i + 1])
cache_key = tuple(sorted((u, v)))
if cache_key in edge_cache:
pred, direction = edge_cache[cache_key]
else:
pred = "related"
direction = "->"
rels = metadata_store.get_relations(subject=u, object=v)
if not rels:
rels = metadata_store.get_relations(subject=v, object=u)
direction = "<-"
if rels:
best_rel = max(rels, key=lambda x: x.get("confidence", 1.0))
pred = str(best_rel.get("predicate", "related") or "related")
edge_cache[cache_key] = (pred, direction)
step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-"
path_desc.append(step_str)
full_path_str = str(path_nodes[0])
for i, step in enumerate(path_desc):
full_path_str += f" {step} {path_nodes[i + 1]}"
formatted_paths.append(
{
"nodes": list(path_nodes),
"description": full_path_str,
}
)
return formatted_paths
def find_paths_from_query(
query: str,
graph_store: Any,
metadata_store: Any,
*,
max_depth: int = 3,
max_paths: int = 5,
) -> List[Dict[str, Any]]:
"""Extract entities from query and resolve indirect paths."""
entities = extract_entities(query, graph_store)
if len(entities) != 2:
return []
return find_paths_between_entities(
entities[0],
entities[1],
graph_store,
metadata_store,
max_depth=max_depth,
max_paths=max_paths,
)
def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]:
"""Convert path results into retrieval results for the unified pipeline."""
converted: List[RetrievalResult] = []
for item in paths:
description = str(item.get("description", "")).strip()
if not description:
continue
hash_seed = description.encode("utf-8")
path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}"
converted.append(
RetrievalResult(
hash_value=path_hash,
content=f"[Indirect Relation] {description}",
score=0.95,
result_type="relation",
source="graph_path",
metadata={
"source": "graph_path",
"is_indirect": True,
"nodes": list(item.get("nodes", [])),
},
)
)
return converted

View File

@@ -0,0 +1,495 @@
"""
人物画像服务
主链路:
person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储
"""
import json
import time
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from ..embedding import EmbeddingAPIAdapter
from ..retrieval import (
DualPathRetriever,
RetrievalStrategy,
DualPathRetrieverConfig,
SparseBM25Config,
FusionConfig,
GraphRelationRecallConfig,
)
from ..storage import MetadataStore, GraphStore, VectorStore
logger = get_logger("A_Memorix.PersonProfileService")
class PersonProfileService:
"""人物画像聚合/刷新服务。"""
def __init__(
self,
metadata_store: MetadataStore,
graph_store: Optional[GraphStore] = None,
vector_store: Optional[VectorStore] = None,
embedding_manager: Optional[EmbeddingAPIAdapter] = None,
sparse_index: Any = None,
plugin_config: Optional[dict] = None,
retriever: Optional[DualPathRetriever] = None,
):
self.metadata_store = metadata_store
self.graph_store = graph_store
self.vector_store = vector_store
self.embedding_manager = embedding_manager
self.sparse_index = sparse_index
self.plugin_config = plugin_config or {}
self.retriever = retriever or self._build_retriever()
def _cfg(self, key: str, default: Any = None) -> Any:
"""读取嵌套配置。"""
if not isinstance(self.plugin_config, dict):
return default
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _build_retriever(self) -> Optional[DualPathRetriever]:
"""按需构建检索器(无依赖时返回 None"""
if not all(
[
self.vector_store is not None,
self.graph_store is not None,
self.metadata_store is not None,
self.embedding_manager is not None,
]
):
return None
try:
sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {}
fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {}
graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {}
if not isinstance(sparse_cfg_raw, dict):
sparse_cfg_raw = {}
if not isinstance(fusion_cfg_raw, dict):
fusion_cfg_raw = {}
if not isinstance(graph_recall_cfg_raw, dict):
graph_recall_cfg_raw = {}
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
fusion_cfg = FusionConfig(**fusion_cfg_raw)
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
config = DualPathRetrieverConfig(
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)),
top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)),
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
alpha=float(self._cfg("retrieval.alpha", 0.5)),
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
debug=bool(self._cfg("advanced.debug", False)),
sparse=sparse_cfg,
fusion=fusion_cfg,
graph_recall=graph_recall_cfg,
)
return DualPathRetriever(
vector_store=self.vector_store,
graph_store=self.graph_store,
metadata_store=self.metadata_store,
embedding_manager=self.embedding_manager,
sparse_index=self.sparse_index,
config=config,
)
except Exception as e:
logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}")
return None
@staticmethod
def resolve_person_id(identifier: str) -> str:
"""按 person_id 或姓名/别名解析 person_id。"""
if not identifier:
return ""
key = str(identifier).strip()
if not key:
return ""
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
return key.lower()
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where(PersonInfo.group_nick_name.contains(key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
return ""
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
if not raw_value:
return []
if isinstance(raw_value, list):
items = raw_value
else:
try:
items = json.loads(raw_value)
except Exception:
return []
names: List[str] = []
for item in items:
if isinstance(item, dict):
value = str(item.get("group_nick_name", "")).strip()
if value:
names.append(value)
elif isinstance(item, str):
value = item.strip()
if value:
names.append(value)
return names
def _parse_memory_traits(self, raw_value: Any) -> List[str]:
if not raw_value:
return []
try:
values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
except Exception:
return []
if not isinstance(values, list):
return []
traits: List[str] = []
for item in values:
text = str(item).strip()
if not text:
continue
if ":" in text:
parts = text.split(":")
if len(parts) >= 3:
content = ":".join(parts[1:-1]).strip()
if content:
traits.append(content)
continue
traits.append(text)
return traits[:10]
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
"""获取人物别名集合、主展示名、记忆特征。"""
aliases: List[str] = []
primary_name = ""
memory_traits: List[str] = []
if not person_id:
return aliases, primary_name, memory_traits
try:
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not record:
return aliases, primary_name, memory_traits
person_name = str(getattr(record, "person_name", "") or "").strip()
nickname = str(getattr(record, "nickname", "") or "").strip()
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
candidates = [person_name, nickname] + group_nicks
seen = set()
for item in candidates:
norm = str(item or "").strip()
if not norm or norm in seen:
continue
seen.add(norm)
aliases.append(norm)
except Exception as e:
logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}")
return aliases, primary_name, memory_traits
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
relation_by_hash: Dict[str, Dict[str, Any]] = {}
for alias in aliases:
for rel in self.metadata_store.get_relations(subject=alias):
h = str(rel.get("hash", ""))
if h:
relation_by_hash[h] = rel
for rel in self.metadata_store.get_relations(object=alias):
h = str(rel.get("hash", ""))
if h:
relation_by_hash[h] = rel
relations = list(relation_by_hash.values())
relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True)
relations = relations[: max(1, int(limit))]
edges: List[Dict[str, Any]] = []
for rel in relations:
edges.append(
{
"hash": str(rel.get("hash", "")),
"subject": str(rel.get("subject", "")),
"predicate": str(rel.get("predicate", "")),
"object": str(rel.get("object", "")),
"confidence": float(rel.get("confidence", 1.0) or 1.0),
}
)
return edges
async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]:
alias_queries = [a for a in aliases if a]
if not alias_queries:
return []
if self.retriever is None:
# 回退:无检索器时只做简单内容匹配
fallback: List[Dict[str, Any]] = []
seen_hash = set()
for alias in alias_queries:
for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]:
h = str(para.get("hash", ""))
if not h or h in seen_hash:
continue
seen_hash.add(h)
fallback.append(
{
"hash": h,
"type": "paragraph",
"score": 0.0,
"content": str(para.get("content", ""))[:180],
"metadata": {},
}
)
return fallback[:top_k]
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
seen_hash = set()
evidence: List[Dict[str, Any]] = []
for alias in alias_queries:
try:
results = await self.retriever.retrieve(alias, top_k=per_alias_top_k)
except Exception as e:
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
continue
for item in results:
h = str(getattr(item, "hash_value", "") or "")
if not h or h in seen_hash:
continue
seen_hash.add(h)
evidence.append(
{
"hash": h,
"type": str(getattr(item, "result_type", "")),
"score": float(getattr(item, "score", 0.0) or 0.0),
"content": str(getattr(item, "content", "") or "")[:220],
"metadata": dict(getattr(item, "metadata", {}) or {}),
}
)
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
return evidence[:top_k]
def _build_profile_text(
self,
person_id: str,
primary_name: str,
aliases: List[str],
relation_edges: List[Dict[str, Any]],
vector_evidence: List[Dict[str, Any]],
memory_traits: List[str],
) -> str:
"""基于证据构建画像文本(供 LLM 上下文注入)。"""
lines: List[str] = []
lines.append(f"人物ID: {person_id}")
if primary_name:
lines.append(f"主称呼: {primary_name}")
if aliases:
lines.append(f"别名: {', '.join(aliases[:8])}")
if memory_traits:
lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}")
if relation_edges:
lines.append("关系证据:")
for rel in relation_edges[:6]:
s = rel.get("subject", "")
p = rel.get("predicate", "")
o = rel.get("object", "")
conf = float(rel.get("confidence", 0.0))
lines.append(f"- {s} {p} {o} (conf={conf:.2f})")
if vector_evidence:
lines.append("向量证据摘要:")
for item in vector_evidence[:4]:
content = str(item.get("content", "")).strip()
if content:
lines.append(f"- {content}")
if len(lines) <= 2:
lines.append("暂无足够证据形成稳定画像。")
return "\n".join(lines)
@staticmethod
def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool:
if not snapshot:
return True
now = time.time()
expires_at = snapshot.get("expires_at")
if expires_at is not None:
try:
return now >= float(expires_at)
except Exception:
return True
updated_at = float(snapshot.get("updated_at") or 0.0)
return (now - updated_at) >= ttl_seconds
def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]:
"""将手工覆盖并入画像结果(覆盖 profile_text同时保留 auto_profile_text"""
payload = dict(profile_payload or {})
auto_text = str(payload.get("profile_text", "") or "")
payload["auto_profile_text"] = auto_text
payload["has_manual_override"] = False
payload["manual_override_text"] = ""
payload["override_updated_at"] = None
payload["override_updated_by"] = ""
payload["profile_source"] = "auto_snapshot"
if not person_id or self.metadata_store is None:
return payload
try:
override = self.metadata_store.get_person_profile_override(person_id)
except Exception as e:
logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}")
return payload
if not override:
return payload
manual_text = str(override.get("override_text", "") or "").strip()
if not manual_text:
return payload
payload["has_manual_override"] = True
payload["manual_override_text"] = manual_text
payload["override_updated_at"] = override.get("updated_at")
payload["override_updated_by"] = str(override.get("updated_by", "") or "")
payload["profile_text"] = manual_text
payload["profile_source"] = "manual_override"
return payload
async def query_person_profile(
self,
person_id: str = "",
person_keyword: str = "",
top_k: int = 12,
ttl_seconds: float = 6 * 3600,
force_refresh: bool = False,
source_note: str = "",
) -> Dict[str, Any]:
"""查询或刷新人物画像。"""
pid = str(person_id or "").strip()
if not pid and person_keyword:
pid = self.resolve_person_id(person_keyword)
if not pid:
return {
"success": False,
"error": "person_id 无效,且未能通过别名解析",
}
latest = self.metadata_store.get_latest_person_profile_snapshot(pid)
if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds):
aliases, primary_name, _ = self.get_person_aliases(pid)
payload = {
"success": True,
"person_id": pid,
"person_name": primary_name,
"from_cache": True,
**(latest or {}),
}
if aliases and not payload.get("aliases"):
payload["aliases"] = aliases
return {
**self._apply_manual_override(pid, payload),
}
aliases, primary_name, memory_traits = self.get_person_aliases(pid)
if not aliases and person_keyword:
aliases = [person_keyword.strip()]
primary_name = person_keyword.strip()
relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2))
vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k))
evidence_ids = [
str(item.get("hash", ""))
for item in (relation_edges + vector_evidence)
if str(item.get("hash", "")).strip()
]
dedup_ids: List[str] = []
seen = set()
for item in evidence_ids:
if item in seen:
continue
seen.add(item)
dedup_ids.append(item)
profile_text = self._build_profile_text(
person_id=pid,
primary_name=primary_name,
aliases=aliases,
relation_edges=relation_edges,
vector_evidence=vector_evidence,
memory_traits=memory_traits,
)
expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None
snapshot = self.metadata_store.upsert_person_profile_snapshot(
person_id=pid,
profile_text=profile_text,
aliases=aliases,
relation_edges=relation_edges,
vector_evidence=vector_evidence,
evidence_ids=dedup_ids,
expires_at=expires_at,
source_note=source_note,
)
payload = {
"success": True,
"person_id": pid,
"person_name": primary_name,
"from_cache": False,
**snapshot,
}
return {
**self._apply_manual_override(pid, payload),
}
@staticmethod
def format_persona_profile_block(profile: Dict[str, Any]) -> str:
"""格式化给 replyer 的注入块。"""
if not profile or not profile.get("success"):
return ""
text = str(profile.get("profile_text", "") or "").strip()
if not text:
return ""
return (
"【人物画像-内部参考】\n"
f"{text}\n"
"仅供内部推理,不要向用户逐字复述。"
)

View File

@@ -0,0 +1,27 @@
"""Plugin ID matching policy for A_Memorix."""
from __future__ import annotations
from typing import Any
class PluginIdPolicy:
"""Centralized plugin id normalization/matching policy."""
CANONICAL_ID = "a_memorix"
@classmethod
def normalize(cls, plugin_id: Any) -> str:
if not isinstance(plugin_id, str):
return ""
return plugin_id.strip().lower()
@classmethod
def is_target_plugin_id(cls, plugin_id: Any) -> bool:
normalized = cls.normalize(plugin_id)
if not normalized:
return False
if normalized == cls.CANONICAL_ID:
return True
return normalized.split(".")[-1] == cls.CANONICAL_ID

View File

@@ -0,0 +1,344 @@
"""
向量量化工具模块
提供向量量化与反量化功能,用于压缩存储空间。
"""
import numpy as np
from enum import Enum
from typing import Tuple, Union
from src.common.logger import get_logger
logger = get_logger("A_Memorix.Quantization")
class QuantizationType(Enum):
"""量化类型枚举"""
FLOAT32 = "float32" # 无量化
INT8 = "int8" # 标量量化8位整数
PQ = "pq" # 乘积量化Product Quantization
def quantize_vector(
vector: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
量化向量
Args:
vector: 输入向量float32
quant_type: 量化类型
Returns:
量化后的向量:
- INT8: int8向量
- PQ: (编码向量, 聚类中心) 元组
"""
if quant_type == QuantizationType.FLOAT32:
return vector.astype(np.float32)
elif quant_type == QuantizationType.INT8:
return _scalar_quantize_int8(vector)
elif quant_type == QuantizationType.PQ:
return _product_quantize(vector)
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def dequantize_vector(
quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
quant_type: QuantizationType = QuantizationType.INT8,
original_shape: Tuple[int, ...] = None,
) -> np.ndarray:
"""
反量化向量
Args:
quantized_vector: 量化后的向量
quant_type: 量化类型
original_shape: 原始向量形状用于PQ
Returns:
反量化后的向量float32
"""
if quant_type == QuantizationType.FLOAT32:
return quantized_vector.astype(np.float32)
elif quant_type == QuantizationType.INT8:
return _scalar_dequantize_int8(quantized_vector)
elif quant_type == QuantizationType.PQ:
if not isinstance(quantized_vector, tuple):
raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)")
return _product_dequantize(quantized_vector[0], quantized_vector[1])
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray:
"""
标量量化float32 -> int8
将向量归一化到 [0, 255] 范围,然后映射到 int8
Args:
vector: 输入向量
Returns:
量化后的 int8 向量
"""
# 计算最小最大值
min_val = np.min(vector)
max_val = np.max(vector)
# 避免除零
if max_val == min_val:
return np.zeros_like(vector, dtype=np.int8)
# 归一化到 [0, 255]
normalized = (vector - min_val) / (max_val - min_val) * 255
# 映射到 [-128, 127] 并转换为 int8
# np.round might return float, minus 128 then cast
quantized = np.round(normalized - 128.0).astype(np.int8)
# 存储归一化参数(用于反量化)
# 在实际存储中,这些参数需要单独保存
# 这里为了简单,我们使用一个全局字典来模拟
if not hasattr(_scalar_quantize_int8, "_params"):
_scalar_quantize_int8._params = {}
vector_id = id(vector)
_scalar_quantize_int8._params[vector_id] = (min_val, max_val)
return quantized
def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray:
"""
标量反量化int8 -> float32
Args:
quantized: 量化后的 int8 向量
Returns:
反量化后的 float32 向量
"""
# 计算归一化参数(如果提供了)
# 在实际应用中min_val 和 max_val 应该被保存
if not hasattr(_scalar_dequantize_int8, "_params"):
# 默认假设范围是 [-1, 1]
return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
# 尝试查找参数 (这里只是演示逻辑,实际应从存储中读取)
# return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min
return (quantized.astype(np.float32) + 128.0) / 255.0
def quantize_matrix(
matrix: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
量化矩阵(批量量化向量)
Args:
matrix: 输入矩阵N x D每行是一个向量
quant_type: 量化类型
Returns:
量化后的矩阵
"""
if quant_type == QuantizationType.FLOAT32:
return matrix.astype(np.float32)
elif quant_type == QuantizationType.INT8:
# 对整个矩阵进行全局归一化
min_val = np.min(matrix)
max_val = np.max(matrix)
if max_val == min_val:
return np.zeros_like(matrix, dtype=np.int8)
# 归一化到 [0, 255]
normalized = (matrix - min_val) / (max_val - min_val) * 255
quantized = np.round(normalized).astype(np.int8)
return quantized
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def dequantize_matrix(
quantized_matrix: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
min_val: float = None,
max_val: float = None,
) -> np.ndarray:
"""
反量化矩阵
Args:
quantized_matrix: 量化后的矩阵
quant_type: 量化类型
min_val: 归一化最小值int8反量化需要
max_val: 归一化最大值int8反量化需要
Returns:
反量化后的矩阵
"""
if quant_type == QuantizationType.FLOAT32:
return quantized_matrix.astype(np.float32)
elif quant_type == QuantizationType.INT8:
# 使用提供的归一化参数反量化
if min_val is None or max_val is None:
# 默认假设范围是 [0, 255] -> [-1, 1]
return quantized_matrix.astype(np.float32) / 127.0
else:
# 恢复到原始范围
normalized = quantized_matrix.astype(np.float32) / 255.0
return normalized * (max_val - min_val) + min_val
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def estimate_memory_reduction(
num_vectors: int,
dimension: int,
from_type: QuantizationType,
to_type: QuantizationType,
) -> Tuple[float, float]:
"""
估算内存节省量
Args:
num_vectors: 向量数量
dimension: 向量维度
from_type: 原始量化类型
to_type: 目标量化类型
Returns:
(原始大小MB, 量化后大小MB, 节省比例)
"""
# 计算每个向量占用的字节数
bytes_per_element = {
QuantizationType.FLOAT32: 4,
QuantizationType.INT8: 1,
QuantizationType.PQ: 0.25, # 假设压缩到1/4
}
original_bytes = num_vectors * dimension * bytes_per_element[from_type]
quantized_bytes = num_vectors * dimension * bytes_per_element[to_type]
original_mb = original_bytes / 1024 / 1024
quantized_mb = quantized_bytes / 1024 / 1024
reduction_ratio = (original_bytes - quantized_bytes) / original_bytes
return original_mb, quantized_mb, reduction_ratio
def estimate_compression_stats(
num_vectors: int,
dimension: int,
quant_type: QuantizationType,
) -> dict:
"""
估算压缩统计信息
Args:
num_vectors: 向量数量
dimension: 向量维度
quant_type: 量化类型
Returns:
统计信息字典
"""
original_mb, quantized_mb, ratio = estimate_memory_reduction(
num_vectors, dimension, QuantizationType.FLOAT32, quant_type
)
return {
"num_vectors": num_vectors,
"dimension": dimension,
"quantization_type": quant_type.value,
"original_size_mb": round(original_mb, 2),
"quantized_size_mb": round(quantized_mb, 2),
"saved_mb": round(original_mb - quantized_mb, 2),
"compression_ratio": round(ratio * 100, 2),
}
def _product_quantize(
vector: np.ndarray, m: int = 8, k: int = 256
) -> Tuple[np.ndarray, np.ndarray]:
"""
乘积量化 (PQ) 简化实现
Args:
vector: 输入向量 (D,)
m: 子空间数量
k: 每个子空间的聚类中心数
Returns:
(编码后的向量, 聚类中心)
"""
d = vector.shape[0]
if d % m != 0:
raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除")
ds = d // m # 子空间维度
codes = np.zeros(m, dtype=np.uint8)
centroids = np.zeros((m, k, ds), dtype=np.float32)
# 这里采用一种简化的 PQ不进行 K-means 训练,
# 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练)
# 为了演示,我们直接将子空间切分为 k 份进行量化
for i in range(m):
sub_vec = vector[i * ds : (i + 1) * ds]
# 简化:假定每个子空间的取值范围并划分
# 实际 PQ 应使用 k-means 产生的 centroids
# 这里为演示创建一个随机 codebook 并找到最接近的核心
sub_min, sub_max = np.min(sub_vec), np.max(sub_vec)
if sub_max == sub_min:
linspace = np.zeros(k)
else:
linspace = np.linspace(sub_min, sub_max, k)
for j in range(k):
centroids[i, j, :] = linspace[j]
# 编码:这里简化为取子空间均值找最接近的 centroid
sub_mean = np.mean(sub_vec)
code = np.argmin(np.abs(linspace - sub_mean))
codes[i] = code
return codes, centroids
def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray:
"""
PQ 反量化
Args:
codes: 编码向量 (M,)
centroids: 聚类中心 (M, K, DS)
Returns:
恢复后的向量 (D,)
"""
m, k, ds = centroids.shape
vector = np.zeros(m * ds, dtype=np.float32)
for i in range(m):
code = codes[i]
vector[i * ds : (i + 1) * ds] = centroids[i, code, :]
return vector

View File

@@ -0,0 +1,121 @@
"""关系查询规格解析工具。"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
@dataclass
class RelationQuerySpec:
raw: str
is_structured: bool
subject: Optional[str]
predicate: Optional[str]
object: Optional[str]
error: Optional[str] = None
_NATURAL_LANGUAGE_PATTERN = re.compile(
r"(^\s*(what|who|which|how|why|when|where)\b|"
r"\?||"
r"\b(relation|related|between)\b|"
r"(什么关系|有哪些关系|之间|关联))",
re.IGNORECASE,
)
def _looks_like_natural_language(raw: str) -> bool:
text = str(raw or "").strip()
if not text:
return False
return _NATURAL_LANGUAGE_PATTERN.search(text) is not None
def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec:
raw = str(relation_spec or "").strip()
if not raw:
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
error="empty",
)
if "|" in raw:
parts = [p.strip() for p in raw.split("|")]
if len(parts) < 2:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=None,
predicate=None,
object=None,
error="invalid_pipe_format",
)
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0] or None,
predicate=parts[1] or None,
object=parts[2] if len(parts) > 2 and parts[2] else None,
)
if "->" in raw:
parts = [p.strip() for p in raw.split("->") if p.strip()]
if len(parts) >= 3:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=parts[1],
object=parts[2],
)
if len(parts) == 2:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=None,
object=parts[1],
)
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=None,
predicate=None,
object=None,
error="invalid_arrow_format",
)
if _looks_like_natural_language(raw):
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
)
# 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。
# 两词形式过于模糊,不再视为结构化关系查询。
parts = raw.split()
if len(parts) == 3:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=parts[1],
object=parts[2],
)
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
)

View File

@@ -0,0 +1,164 @@
"""
统一关系写入与关系向量化服务。
规则:
1. 元数据是主数据源,向量是从索引。
2. 关系先写 metadata再写向量。
3. 向量失败不回滚 metadata依赖状态机与回填任务修复。
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
from src.common.logger import get_logger
logger = get_logger("A_Memorix.RelationWriteService")
@dataclass
class RelationWriteResult:
hash_value: str
vector_written: bool
vector_already_exists: bool
vector_state: str
class RelationWriteService:
"""关系写入收口服务。"""
ERROR_MAX_LEN = 500
def __init__(
self,
metadata_store: Any,
graph_store: Any,
vector_store: Any,
embedding_manager: Any,
):
self.metadata_store = metadata_store
self.graph_store = graph_store
self.vector_store = vector_store
self.embedding_manager = embedding_manager
@staticmethod
def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str:
s = str(subject or "").strip()
p = str(predicate or "").strip()
o = str(obj or "").strip()
# 双表达:兼容关键词检索与自然语言问句
return f"{s} {p} {o}\n{s}{o}的关系是{p}"
async def ensure_relation_vector(
self,
hash_value: str,
subject: str,
predicate: str,
obj: str,
*,
max_error_len: int = ERROR_MAX_LEN,
) -> RelationWriteResult:
"""
为已有关系确保向量存在并更新状态。
"""
if hash_value in self.vector_store:
self.metadata_store.set_relation_vector_state(hash_value, "ready")
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=True,
vector_state="ready",
)
self.metadata_store.set_relation_vector_state(hash_value, "pending")
try:
vector_text = self.build_relation_vector_text(subject, predicate, obj)
embedding = await self.embedding_manager.encode(vector_text)
self.vector_store.add(
vectors=embedding.reshape(1, -1),
ids=[hash_value],
)
self.metadata_store.set_relation_vector_state(hash_value, "ready")
logger.info(
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
hash_value[:16],
)
return RelationWriteResult(
hash_value=hash_value,
vector_written=True,
vector_already_exists=False,
vector_state="ready",
)
except ValueError:
# 向量已存在冲突,按成功处理
self.metadata_store.set_relation_vector_state(hash_value, "ready")
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=True,
vector_state="ready",
)
except Exception as e:
err = str(e)[:max_error_len]
self.metadata_store.set_relation_vector_state(
hash_value,
"failed",
error=err,
bump_retry=True,
)
logger.warning(
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
hash_value[:16],
err,
)
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=False,
vector_state="failed",
)
async def upsert_relation_with_vector(
self,
subject: str,
predicate: str,
obj: str,
confidence: float = 1.0,
source_paragraph: str = "",
metadata: Optional[Dict[str, Any]] = None,
*,
write_vector: bool = True,
) -> RelationWriteResult:
"""
统一关系写入:
1) 写 metadata relation
2) 写 graph edge relation_hash
3) 按需写 relation vector
"""
rel_hash = self.metadata_store.add_relation(
subject=subject,
predicate=predicate,
obj=obj,
confidence=confidence,
source_paragraph=source_paragraph,
metadata=metadata or {},
)
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
if not write_vector:
self.metadata_store.set_relation_vector_state(rel_hash, "none")
return RelationWriteResult(
hash_value=rel_hash,
vector_written=False,
vector_already_exists=False,
vector_state="none",
)
return await self.ensure_relation_vector(
hash_value=rel_hash,
subject=subject,
predicate=predicate,
obj=obj,
)

View File

@@ -0,0 +1,197 @@
"""Runtime self-check helpers for A_Memorix."""
from __future__ import annotations
import time
from typing import Any, Dict, Optional
import numpy as np
from src.common.logger import get_logger
logger = get_logger("A_Memorix.RuntimeSelfCheck")
_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check"
def _safe_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _get_config_value(config: Any, key: str, default: Any = None) -> Any:
getter = getattr(config, "get_config", None)
if callable(getter):
return getter(key, default)
current: Any = config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _build_report(
*,
ok: bool,
code: str,
message: str,
configured_dimension: int,
vector_store_dimension: int,
detected_dimension: int,
encoded_dimension: int,
elapsed_ms: float,
sample_text: str,
) -> Dict[str, Any]:
return {
"ok": bool(ok),
"code": str(code or "").strip(),
"message": str(message or "").strip(),
"configured_dimension": int(configured_dimension),
"vector_store_dimension": int(vector_store_dimension),
"detected_dimension": int(detected_dimension),
"encoded_dimension": int(encoded_dimension),
"elapsed_ms": float(elapsed_ms),
"sample_text": str(sample_text or ""),
"checked_at": time.time(),
}
async def run_embedding_runtime_self_check(
*,
config: Any,
vector_store: Optional[Any],
embedding_manager: Optional[Any],
sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
"""Probe the real embedding path and compare dimensions with runtime storage."""
configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0)
vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0)
if vector_store is None or embedding_manager is None:
return _build_report(
ok=False,
code="runtime_components_missing",
message="vector_store 或 embedding_manager 未初始化",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=0,
encoded_dimension=0,
elapsed_ms=0.0,
sample_text=sample_text,
)
start = time.perf_counter()
detected_dimension = 0
encoded_dimension = 0
try:
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
encoded = await embedding_manager.encode(sample_text)
if isinstance(encoded, np.ndarray):
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
else:
encoded_dimension = len(encoded) if encoded is not None else 0
except Exception as exc:
elapsed_ms = (time.perf_counter() - start) * 1000.0
logger.warning("embedding runtime self-check failed: %s", exc)
return _build_report(
ok=False,
code="embedding_probe_failed",
message=f"embedding probe failed: {exc}",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
elapsed_ms = (time.perf_counter() - start) * 1000.0
expected_dimension = vector_store_dimension or configured_dimension or detected_dimension
if expected_dimension <= 0:
return _build_report(
ok=False,
code="invalid_expected_dimension",
message="无法确定期望 embedding 维度",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
if encoded_dimension != expected_dimension:
msg = (
"embedding 真实输出维度与当前向量存储不一致: "
f"expected={expected_dimension}, encoded={encoded_dimension}"
)
logger.error(msg)
return _build_report(
ok=False,
code="embedding_dimension_mismatch",
message=msg,
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
return _build_report(
ok=True,
code="ok",
message="embedding runtime self-check passed",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
async def ensure_runtime_self_check(
plugin_or_config: Any,
*,
force: bool = False,
sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
"""Run or reuse cached runtime self-check report."""
if plugin_or_config is None:
return _build_report(
ok=False,
code="missing_plugin_or_config",
message="plugin/config unavailable",
configured_dimension=0,
vector_store_dimension=0,
detected_dimension=0,
encoded_dimension=0,
elapsed_ms=0.0,
sample_text=sample_text,
)
cache = getattr(plugin_or_config, "_runtime_self_check_report", None)
if isinstance(cache, dict) and cache and not force:
return cache
report = await run_embedding_runtime_self_check(
config=getattr(plugin_or_config, "config", plugin_or_config),
vector_store=getattr(plugin_or_config, "vector_store", None)
if not isinstance(plugin_or_config, dict)
else plugin_or_config.get("vector_store"),
embedding_manager=getattr(plugin_or_config, "embedding_manager", None)
if not isinstance(plugin_or_config, dict)
else plugin_or_config.get("embedding_manager"),
sample_text=sample_text,
)
try:
setattr(plugin_or_config, "_runtime_self_check_report", report)
except Exception:
pass
return report

View File

@@ -0,0 +1,90 @@
"""Post-processing helpers for unified search execution."""
from __future__ import annotations
from typing import Any, List, Tuple
from .path_fallback_service import find_paths_from_query, to_retrieval_results
def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]:
"""Deduplicate results by hash/content while preserving at least one entry."""
if not results:
return [], 0
unique_results: List[Any] = []
seen_hashes = set()
seen_contents = set()
for item in results:
content = str(getattr(item, "content", "") or "").strip()
if not content:
continue
hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content))
if hash_value in seen_hashes:
continue
is_dup = False
for seen in seen_contents:
if content in seen or seen in content:
is_dup = True
break
if is_dup:
continue
seen_hashes.add(hash_value)
seen_contents.add(content)
unique_results.append(item)
if not unique_results:
unique_results.append(results[0])
removed_count = max(0, len(results) - len(unique_results))
return unique_results, removed_count
def maybe_apply_smart_path_fallback(
*,
query: str,
results: List[Any],
graph_store: Any,
metadata_store: Any,
enabled: bool,
threshold: float,
max_depth: int = 3,
max_paths: int = 5,
) -> Tuple[List[Any], bool, int]:
"""Append indirect relation paths when semantic results are weak."""
if not enabled or not str(query or "").strip():
return results, False, 0
if graph_store is None or metadata_store is None:
return results, False, 0
max_score = 0.0
if results:
try:
max_score = float(getattr(results[0], "score", 0.0) or 0.0)
except Exception:
max_score = 0.0
if max_score >= float(threshold):
return results, False, 0
paths = find_paths_from_query(
query=query,
graph_store=graph_store,
metadata_store=metadata_store,
max_depth=max_depth,
max_paths=max_paths,
)
if not paths:
return results, False, 0
path_results = to_retrieval_results(paths)
if not path_results:
return results, False, 0
merged = list(path_results) + list(results)
return merged, True, len(path_results)

View File

@@ -0,0 +1,170 @@
"""
时间解析工具。
约束:
1. 查询参数Action/Command/Tool仅接受结构化绝对时间
- YYYY/MM/DD
- YYYY/MM/DD HH:mm
2. 入库时允许更宽松格式含时间戳、YYYY-MM-DD 等)。
"""
from __future__ import annotations
import re
from datetime import datetime
from typing import Any, Dict, Optional, Tuple
_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$")
_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$")
_INGEST_FORMATS = [
"%Y/%m/%d %H:%M:%S",
"%Y/%m/%d %H:%M",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M",
"%Y/%m/%d",
"%Y-%m-%d",
]
_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"}
def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float:
"""解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。"""
text = str(value).strip()
if not text:
raise ValueError("时间不能为空")
if _QUERY_DATE_RE.fullmatch(text):
dt = datetime.strptime(text, "%Y/%m/%d")
if is_end:
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
return dt.timestamp()
if _QUERY_MINUTE_RE.fullmatch(text):
dt = datetime.strptime(text, "%Y/%m/%d %H:%M")
return dt.timestamp()
raise ValueError(
f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm"
)
def parse_query_time_range(
time_from: Optional[str],
time_to: Optional[str],
) -> Tuple[Optional[float], Optional[float]]:
"""解析查询窗口并验证区间。"""
ts_from = (
parse_query_datetime_to_timestamp(time_from, is_end=False)
if time_from
else None
)
ts_to = (
parse_query_datetime_to_timestamp(time_to, is_end=True)
if time_to
else None
)
if ts_from is not None and ts_to is not None and ts_from > ts_to:
raise ValueError("time_from 不能晚于 time_to")
return ts_from, ts_to
def parse_ingest_datetime_to_timestamp(
value: Any,
is_end: bool = False,
) -> Optional[float]:
"""解析入库时间,允许 timestamp/常见字符串格式。"""
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
if not text:
return None
if _NUMERIC_RE.fullmatch(text):
return float(text)
for fmt in _INGEST_FORMATS:
try:
dt = datetime.strptime(text, fmt)
if fmt in _INGEST_DATE_FORMATS and is_end:
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
return dt.timestamp()
except ValueError:
continue
raise ValueError(f"无法解析时间: {text}")
def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""归一化 time_meta 到存储层字段。"""
if not time_meta:
return {}
normalized: Dict[str, Any] = {}
event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time"))
event_start = parse_ingest_datetime_to_timestamp(
time_meta.get("event_time_start"),
is_end=False,
)
event_end = parse_ingest_datetime_to_timestamp(
time_meta.get("event_time_end"),
is_end=True,
)
time_range = time_meta.get("time_range")
if (
isinstance(time_range, (list, tuple))
and len(time_range) == 2
):
if event_start is None:
event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False)
if event_end is None:
event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True)
if event_start is not None and event_end is not None and event_start > event_end:
raise ValueError("event_time_start 不能晚于 event_time_end")
if event_time is not None:
normalized["event_time"] = event_time
if event_start is not None:
normalized["event_time_start"] = event_start
if event_end is not None:
normalized["event_time_end"] = event_end
granularity = time_meta.get("time_granularity")
if granularity:
normalized["time_granularity"] = str(granularity)
else:
raw_time_values = [
time_meta.get("event_time"),
time_meta.get("event_time_start"),
time_meta.get("event_time_end"),
]
has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None)
normalized["time_granularity"] = "minute" if has_minute else "day"
confidence = time_meta.get("time_confidence")
if confidence is not None:
normalized["time_confidence"] = float(confidence)
return normalized
def format_timestamp(ts: Optional[float]) -> Optional[str]:
"""将 timestamp 格式化为 YYYY/MM/DD HH:mm。"""
if ts is None:
return None
return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M")