feat: add the A_Memorix memory plugin

Introduce the A_Memorix plugin (v2.0.0), a lightweight long-term memory provider. Add the plugin manifest and entry point (AMemorixPlugin) along with the full set of core capabilities: embedding (hash-based EmbeddingAPIAdapter, EmbeddingManager, presets), retrieval (dual-path retriever, PageRank, graph relation recall, sparse BM25 index, threshold and fusion configuration), the storage and metadata layers, plus a large collection of utilities and migration/conversion scripts. Also update .gitignore to un-ignore /plugins/A_memorix. This change lays the groundwork for unified memory ingestion, retrieval, analysis, and maintenance in the host application.
DawnARC committed 2026-03-18 21:33:15 +08:00
parent a5a6d2cb26
commit 999e7246e2
48 changed files with 17070 additions and 0 deletions

.gitignore vendored (1 addition)
View File

@@ -339,6 +339,7 @@ run_pet.bat
/plugins/*
!/plugins
!/plugins/A_memorix
!/plugins/hello_world_plugin
!/plugins/emoji_manage_plugin
!/plugins/take_picture_plugin

View File

@@ -0,0 +1,12 @@
"""
A_Memorix - 轻量级知识库插件
完全独立的记忆增强系统,优化低资源环境下的知识存储与检索。
"""
__version__ = "2.0.0"
__author__ = "A_Dawn"
from .plugin import AMemorixPlugin
__all__ = ["AMemorixPlugin"]

View File

@@ -0,0 +1,62 @@
{
"manifest_version": 1,
"name": "A_Memorix",
"version": "2.0.0",
"description": "MaiBot SDK 长期记忆插件,负责统一检索、写入、画像与记忆维护。",
"author": {
"name": "A_Dawn"
},
"license": "AGPL-3.0",
"repository_url": "https://github.com/A-Dawn/A_memorix/",
"host_application": {
"min_version": "1.0.0"
},
"keywords": [
"memory",
"knowledge",
"retrieval",
"profile",
"episode"
],
"categories": [
"Memory",
"Data"
],
"plugin_info": {
"is_built_in": false,
"plugin_type": "memory_provider",
"components": [
{
"type": "tool",
"name": "search_memory",
"description": "搜索长期记忆"
},
{
"type": "tool",
"name": "ingest_summary",
"description": "写入聊天摘要"
},
{
"type": "tool",
"name": "ingest_text",
"description": "写入普通长期记忆文本"
},
{
"type": "tool",
"name": "get_person_profile",
"description": "查询人物画像"
},
{
"type": "tool",
"name": "maintain_memory",
"description": "维护记忆关系"
},
{
"type": "tool",
"name": "memory_stats",
"description": "查询记忆统计"
}
]
},
"capabilities": []
}

View File

@@ -0,0 +1,84 @@
"""核心模块 - 存储、嵌入、检索引擎"""
# 存储模块(已实现)
from .storage import (
VectorStore,
GraphStore,
MetadataStore,
ImportStrategy,
KnowledgeType,
parse_import_strategy,
resolve_stored_knowledge_type,
detect_knowledge_type,
select_import_strategy,
should_extract_relations,
get_type_display_name,
)
# 嵌入模块(使用主程序 API)
from .embedding import (
EmbeddingAPIAdapter,
create_embedding_api_adapter,
)
# 检索模块(已实现)
from .retrieval import (
DualPathRetriever,
RetrievalStrategy,
RetrievalResult,
DualPathRetrieverConfig,
TemporalQueryOptions,
FusionConfig,
GraphRelationRecallConfig,
RelationIntentConfig,
PersonalizedPageRank,
PageRankConfig,
create_ppr_from_graph,
DynamicThresholdFilter,
ThresholdMethod,
ThresholdConfig,
SparseBM25Index,
SparseBM25Config,
)
from .utils import (
RelationWriteService,
RelationWriteResult,
)
__all__ = [
# Storage
"VectorStore",
"GraphStore",
"MetadataStore",
"ImportStrategy",
"KnowledgeType",
"parse_import_strategy",
"resolve_stored_knowledge_type",
"detect_knowledge_type",
"select_import_strategy",
"should_extract_relations",
"get_type_display_name",
# Embedding
"EmbeddingAPIAdapter",
"create_embedding_api_adapter",
# Retrieval
"DualPathRetriever",
"RetrievalStrategy",
"RetrievalResult",
"DualPathRetrieverConfig",
"TemporalQueryOptions",
"FusionConfig",
"GraphRelationRecallConfig",
"RelationIntentConfig",
"PersonalizedPageRank",
"PageRankConfig",
"create_ppr_from_graph",
"DynamicThresholdFilter",
"ThresholdMethod",
"ThresholdConfig",
"SparseBM25Index",
"SparseBM25Config",
"RelationWriteService",
"RelationWriteResult",
]

View File

@@ -0,0 +1,18 @@
"""嵌入模块 - 向量生成与量化"""
# 新的 API 适配器(主程序嵌入 API)
from .api_adapter import (
EmbeddingAPIAdapter,
create_embedding_api_adapter,
)
from ..utils.quantization import QuantizationType
__all__ = [
# 新的 API 适配器(推荐使用)
"EmbeddingAPIAdapter",
"create_embedding_api_adapter",
# 量化
"QuantizationType",
]

View File

@@ -0,0 +1,174 @@
"""
Hash-based embedding adapter used by the SDK runtime.
The plugin runtime cannot import MaiBot host embedding internals from ``src.chat``
or ``src.llm_models``. This adapter keeps A_Memorix self-contained and stable in
Runner by generating deterministic dense vectors locally.
"""
from __future__ import annotations
import hashlib
import re
import time
from typing import List, Optional, Union
import numpy as np
from src.common.logger import get_logger
logger = get_logger("A_Memorix.EmbeddingAPIAdapter")
_TOKEN_PATTERN = re.compile(r"[A-Za-z0-9_\u4e00-\u9fff]{1,}")
class EmbeddingAPIAdapter:
"""Deterministic local embedding adapter."""
def __init__(
self,
batch_size: int = 32,
max_concurrent: int = 5,
default_dimension: int = 256,
enable_cache: bool = False,
model_name: str = "hash-v1",
retry_config: Optional[dict] = None,
) -> None:
self.batch_size = max(1, int(batch_size))
self.max_concurrent = max(1, int(max_concurrent))
self.default_dimension = max(32, int(default_dimension))
self.enable_cache = bool(enable_cache)
self.model_name = str(model_name or "hash-v1")
self.retry_config = retry_config or {}
self._dimension: Optional[int] = None
self._dimension_detected = False
self._total_encoded = 0
self._total_errors = 0
self._total_time = 0.0
logger.info(
"EmbeddingAPIAdapter 初始化: model=%s, batch_size=%s, dimension=%s",
self.model_name,
self.batch_size,
self.default_dimension,
)
async def _detect_dimension(self) -> int:
if self._dimension_detected and self._dimension is not None:
return self._dimension
self._dimension = self.default_dimension
self._dimension_detected = True
return self._dimension
@staticmethod
def _tokenize(text: str) -> List[str]:
clean = str(text or "").strip().lower()
if not clean:
return []
return _TOKEN_PATTERN.findall(clean)
@staticmethod
def _feature_weight(token: str) -> float:
digest = hashlib.sha256(token.encode("utf-8")).digest()
return 1.0 + (digest[10] / 255.0) * 0.5
def _encode_single(self, text: str, dimension: int) -> np.ndarray:
vector = np.zeros(dimension, dtype=np.float32)
content = str(text or "").strip()
tokens = self._tokenize(content)
if not tokens and content:
tokens = [content.lower()]
if not tokens:
vector[0] = 1.0
return vector
for token in tokens:
digest = hashlib.sha256(token.encode("utf-8")).digest()
bucket = int.from_bytes(digest[:8], byteorder="big", signed=False) % dimension
sign = 1.0 if digest[8] % 2 == 0 else -1.0
vector[bucket] += sign * self._feature_weight(token)
second_bucket = int.from_bytes(digest[12:20], byteorder="big", signed=False) % dimension
if second_bucket != bucket:
vector[second_bucket] += (sign * 0.35)
norm = float(np.linalg.norm(vector))
if norm > 1e-8:
vector /= norm
else:
vector[0] = 1.0
return vector
async def encode(
self,
texts: Union[str, List[str]],
batch_size: Optional[int] = None,
show_progress: bool = False,
normalize: bool = True,
dimensions: Optional[int] = None,
) -> np.ndarray:
_ = batch_size
_ = show_progress
_ = normalize
started_at = time.time()
target_dimension = max(32, int(dimensions or await self._detect_dimension()))
if isinstance(texts, str):
single_input = True
normalized_texts = [texts]
else:
single_input = False
normalized_texts = list(texts or [])
if not normalized_texts:
empty = np.zeros((0, target_dimension), dtype=np.float32)
return empty[0] if single_input else empty
try:
matrix = np.vstack([self._encode_single(item, target_dimension) for item in normalized_texts])
self._total_encoded += len(normalized_texts)
self._total_time += time.time() - started_at
except Exception:
self._total_errors += 1
raise
return matrix[0] if single_input else matrix
def get_statistics(self) -> dict:
avg_time = self._total_time / self._total_encoded if self._total_encoded else 0.0
return {
"model_name": self.model_name,
"dimension": self._dimension or self.default_dimension,
"total_encoded": self._total_encoded,
"total_errors": self._total_errors,
"total_time": self._total_time,
"avg_time_per_text": avg_time,
}
def __repr__(self) -> str:
return (
f"EmbeddingAPIAdapter(model_name={self.model_name}, "
f"dimension={self._dimension or self.default_dimension}, "
f"total_encoded={self._total_encoded})"
)
def create_embedding_api_adapter(
batch_size: int = 32,
max_concurrent: int = 5,
default_dimension: int = 256,
enable_cache: bool = False,
model_name: str = "hash-v1",
retry_config: Optional[dict] = None,
) -> EmbeddingAPIAdapter:
return EmbeddingAPIAdapter(
batch_size=batch_size,
max_concurrent=max_concurrent,
default_dimension=default_dimension,
enable_cache=enable_cache,
model_name=model_name,
retry_config=retry_config,
)
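Illustrative usage sketch (not part of this commit): the adapter is deterministic, so identical inputs always map to identical vectors. The import path below is an assumption based on the plugin living under /plugins/A_memorix, and the snippet assumes the host environment that provides src.common.logger.

import asyncio
import numpy as np

# Hypothetical import path; adjust to the actual plugin package location.
from plugins.A_memorix.core.embedding import create_embedding_api_adapter

async def demo() -> None:
    adapter = create_embedding_api_adapter(default_dimension=256, model_name="hash-v1")
    # Batch input -> (2, 256) matrix; a single string input -> 1D vector.
    batch = await adapter.encode(["long-term memory", "episodic recall"])
    single = await adapter.encode("long-term memory")
    assert batch.shape == (2, 256)
    # Deterministic hashing: the same text always yields the same vector.
    assert np.allclose(batch[0], single)
    print(adapter.get_statistics())

asyncio.run(demo())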

View File

@@ -0,0 +1,510 @@
"""
嵌入管理器
负责嵌入模型的加载、缓存和批量生成。
"""
import hashlib
import pickle
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Optional, Union, List, Dict, Any, Tuple
import numpy as np
try:
from sentence_transformers import SentenceTransformer
HAS_SENTENCE_TRANSFORMERS = True
except ImportError:
HAS_SENTENCE_TRANSFORMERS = False
from src.common.logger import get_logger
from .presets import (
EmbeddingModelConfig,
get_custom_config,
validate_config_compatibility,
are_models_compatible,
)
from ..utils.quantization import QuantizationType
logger = get_logger("A_Memorix.EmbeddingManager")
class EmbeddingManager:
"""
嵌入管理器
功能:
- 模型加载与缓存
- 批量生成嵌入
- 多线程/多进程支持
- 模型一致性检查
- 智能分批
参数:
config: 模型配置
cache_dir: 缓存目录
enable_cache: 是否启用缓存
num_workers: 工作线程数
"""
def __init__(
self,
config: EmbeddingModelConfig,
cache_dir: Optional[Union[str, Path]] = None,
enable_cache: bool = True,
num_workers: int = 1,
):
"""
初始化嵌入管理器
Args:
config: 模型配置
cache_dir: 缓存目录
enable_cache: 是否启用缓存
num_workers: 工作线程数
"""
if not HAS_SENTENCE_TRANSFORMERS:
raise ImportError(
"sentence-transformers 未安装,请安装: "
"pip install sentence-transformers"
)
self.config = config
self.cache_dir = Path(cache_dir) if cache_dir else None
self.enable_cache = enable_cache
self.num_workers = max(1, num_workers)
# 模型实例
self._model: Optional[SentenceTransformer] = None
self._model_lock = threading.Lock()
# 缓存
self._embedding_cache: Dict[str, np.ndarray] = {}
self._cache_lock = threading.Lock()
# 统计
self._total_encoded = 0
self._cache_hits = 0
self._cache_misses = 0
logger.info(
f"EmbeddingManager 初始化: model={config.model_name}, "
f"dim={config.dimension}, workers={num_workers}"
)
def load_model(self) -> None:
"""加载模型(懒加载)"""
if self._model is not None:
return
with self._model_lock:
# 双重检查
if self._model is not None:
return
logger.info(f"正在加载模型: {self.config.model_name}")
try:
# 构建模型参数
model_kwargs = {}
if self.config.cache_dir:
model_kwargs["cache_folder"] = self.config.cache_dir
# 加载模型
self._model = SentenceTransformer(
self.config.model_path,
**model_kwargs,
)
logger.info(f"模型加载成功: {self.config.model_name}")
except Exception as e:
logger.error(f"模型加载失败: {e}")
raise
def encode(
self,
texts: Union[str, List[str]],
batch_size: Optional[int] = None,
show_progress: bool = False,
normalize: bool = True,
) -> np.ndarray:
"""
生成文本嵌入
Args:
texts: 文本或文本列表
batch_size: 批次大小(默认使用配置值)
show_progress: 是否显示进度条
normalize: 是否归一化
Returns:
嵌入向量 (N x D)
"""
# 确保模型已加载
self.load_model()
# 标准化输入
if isinstance(texts, str):
texts = [texts]
single_input = True
else:
single_input = False
if not texts:
return np.zeros((0, self.config.dimension), dtype=np.float32)
# 使用配置的批次大小
if batch_size is None:
batch_size = self.config.batch_size
# 生成嵌入
try:
embeddings = self._model.encode(
texts,
batch_size=batch_size,
show_progress_bar=show_progress,
normalize_embeddings=normalize and self.config.normalization,
convert_to_numpy=True,
)
# 确保是2D数组
if embeddings.ndim == 1:
embeddings = embeddings.reshape(1, -1)
self._total_encoded += len(texts)
# 如果是单个输入返回1D数组
if single_input:
return embeddings[0]
return embeddings
except Exception as e:
logger.error(f"生成嵌入失败: {e}")
raise
def encode_batch(
self,
texts: List[str],
batch_size: Optional[int] = None,
num_workers: Optional[int] = None,
show_progress: bool = False,
) -> np.ndarray:
"""
批量生成嵌入(多线程优化)
Args:
texts: 文本列表
batch_size: 批次大小
num_workers: 工作线程数(默认使用初始化时的值)
show_progress: 是否显示进度条
Returns:
嵌入向量 (N x D)
"""
if not texts:
return np.zeros((0, self.config.dimension), dtype=np.float32)
# 单线程模式
num_workers = num_workers if num_workers is not None else self.num_workers
if num_workers == 1:
return self.encode(texts, batch_size=batch_size, show_progress=show_progress)
# 多线程模式
logger.info(f"使用 {num_workers} 个线程生成 {len(texts)} 个嵌入")
# 分批
batch_size = batch_size or self.config.batch_size
batches = [
texts[i:i + batch_size]
for i in range(0, len(texts), batch_size)
]
# 多线程生成
all_embeddings = []
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# 提交任务
future_to_batch = {
executor.submit(
self.encode,
batch,
batch_size,
False, # 不显示进度条(多线程时会混乱)
): i
for i, batch in enumerate(batches)
}
# 收集结果
for future in as_completed(future_to_batch):
batch_idx = future_to_batch[future]
try:
embeddings = future.result()
all_embeddings.append((batch_idx, embeddings))
except Exception as e:
logger.error(f"批次 {batch_idx} 生成嵌入失败: {e}")
raise
# 按顺序合并
all_embeddings.sort(key=lambda x: x[0])
final_embeddings = np.concatenate([emb for _, emb in all_embeddings], axis=0)
return final_embeddings
def encode_with_cache(
self,
texts: List[str],
batch_size: Optional[int] = None,
show_progress: bool = False,
) -> np.ndarray:
"""
生成嵌入(带缓存)
Args:
texts: 文本列表
batch_size: 批次大小
show_progress: 是否显示进度条
Returns:
嵌入向量 (N x D)
"""
if not self.enable_cache:
return self.encode(texts, batch_size, show_progress)
# 分离缓存命中和未命中的文本
cached_embeddings = []
uncached_texts = []
uncached_indices = []
for i, text in enumerate(texts):
cache_key = self._get_cache_key(text)
with self._cache_lock:
if cache_key in self._embedding_cache:
cached_embeddings.append((i, self._embedding_cache[cache_key]))
self._cache_hits += 1
else:
uncached_texts.append(text)
uncached_indices.append(i)
self._cache_misses += 1
# 生成未缓存的嵌入
if uncached_texts:
new_embeddings = self.encode(
uncached_texts,
batch_size,
show_progress,
)
# 更新缓存
with self._cache_lock:
for text, embedding in zip(uncached_texts, new_embeddings):
cache_key = self._get_cache_key(text)
self._embedding_cache[cache_key] = embedding.copy()
# 合并结果
for idx, embedding in zip(uncached_indices, new_embeddings):
cached_embeddings.append((idx, embedding))
# 按原始顺序排序
cached_embeddings.sort(key=lambda x: x[0])
final_embeddings = np.array([emb for _, emb in cached_embeddings])
return final_embeddings
def save_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
"""
保存缓存到磁盘
Args:
cache_path: 缓存文件路径,默认使用 cache_dir/embeddings_cache.pkl
"""
if cache_path is None:
if self.cache_dir is None:
raise ValueError("未指定缓存目录")
cache_path = self.cache_dir / "embeddings_cache.pkl"
cache_path = Path(cache_path)
cache_path.parent.mkdir(parents=True, exist_ok=True)
with self._cache_lock:
with open(cache_path, "wb") as f:
pickle.dump(self._embedding_cache, f)
logger.info(f"缓存已保存: {cache_path} ({len(self._embedding_cache)} 条)")
def load_cache(self, cache_path: Optional[Union[str, Path]] = None) -> None:
"""
从磁盘加载缓存
Args:
cache_path: 缓存文件路径,默认使用 cache_dir/embeddings_cache.pkl
"""
if cache_path is None:
if self.cache_dir is None:
raise ValueError("未指定缓存目录")
cache_path = self.cache_dir / "embeddings_cache.pkl"
cache_path = Path(cache_path)
if not cache_path.exists():
logger.warning(f"缓存文件不存在: {cache_path}")
return
with self._cache_lock:
with open(cache_path, "rb") as f:
self._embedding_cache = pickle.load(f)
logger.info(f"缓存已加载: {cache_path} ({len(self._embedding_cache)} 条)")
def clear_cache(self) -> None:
"""清空缓存"""
with self._cache_lock:
count = len(self._embedding_cache)
self._embedding_cache.clear()
logger.info(f"已清空缓存: {count}")
def check_model_consistency(
self,
stored_embeddings: np.ndarray,
sample_texts: Optional[List[str]] = None,
) -> Tuple[bool, str]:
"""
检查模型一致性
Args:
stored_embeddings: 存储的嵌入向量
sample_texts: 样本文本(用于重新生成对比)
Returns:
(是否一致, 详细信息)
"""
# 检查维度
if stored_embeddings.shape[1] != self.config.dimension:
return False, f"维度不匹配: 期望 {self.config.dimension}, 实际 {stored_embeddings.shape[1]}"
# 如果提供了样本文本,重新生成并比较
if sample_texts:
try:
new_embeddings = self.encode(sample_texts[:5]) # 只比较前5个
# 计算相似度
similarities = np.dot(
stored_embeddings[:5],
new_embeddings.T,
).diagonal()
# 检查相似度
if np.mean(similarities) < 0.95:
return False, f"模型可能已更改,平均相似度: {np.mean(similarities):.3f}"
return True, f"模型一致,平均相似度: {np.mean(similarities):.3f}"
except Exception as e:
return False, f"一致性检查失败: {e}"
return True, "维度匹配"
def get_model_info(self) -> Dict[str, Any]:
"""
获取模型信息
Returns:
模型信息字典
"""
return {
"model_name": self.config.model_name,
"dimension": self.config.dimension,
"max_seq_length": self.config.max_seq_length,
"batch_size": self.config.batch_size,
"normalization": self.config.normalization,
"pooling": self.config.pooling,
"model_loaded": self._model is not None,
"cache_enabled": self.enable_cache,
"cache_size": len(self._embedding_cache),
"total_encoded": self._total_encoded,
"cache_hits": self._cache_hits,
"cache_misses": self._cache_misses,
}
def get_embedding_dimension(self) -> int:
"""获取嵌入维度"""
return self.config.dimension
def _get_cache_key(self, text: str) -> str:
"""
生成缓存键
Args:
text: 文本内容
Returns:
缓存键(SHA256 哈希)
"""
return hashlib.sha256(text.encode("utf-8")).hexdigest()
@property
def is_model_loaded(self) -> bool:
"""模型是否已加载"""
return self._model is not None
@property
def cache_hit_rate(self) -> float:
"""缓存命中率"""
total = self._cache_hits + self._cache_misses
if total == 0:
return 0.0
return self._cache_hits / total
def __repr__(self) -> str:
return (
f"EmbeddingManager(model={self.config.model_name}, "
f"dim={self.config.dimension}, "
f"loaded={self.is_model_loaded}, "
f"cache={len(self._embedding_cache)})"
)
def create_embedding_manager_from_config(
model_name: str,
model_path: str,
dimension: int,
cache_dir: Optional[Union[str, Path]] = None,
enable_cache: bool = True,
num_workers: int = 1,
**config_kwargs,
) -> EmbeddingManager:
"""
从自定义配置创建嵌入管理器
Args:
model_name: 模型名称
model_path: HuggingFace模型路径
dimension: 输出维度
cache_dir: 缓存目录
enable_cache: 是否启用缓存
num_workers: 工作线程数
**config_kwargs: 其他配置参数
Returns:
嵌入管理器实例
"""
# 创建自定义配置
config = get_custom_config(
model_name=model_name,
model_path=model_path,
dimension=dimension,
cache_dir=cache_dir,
**config_kwargs,
)
# 创建管理器
return EmbeddingManager(
config=config,
cache_dir=cache_dir,
enable_cache=enable_cache,
num_workers=num_workers,
)
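Illustrative usage sketch (not part of this commit): it assumes sentence-transformers is installed, and the HuggingFace model id, dimension, and module path are placeholders; substitute whatever model the deployment actually uses.

from pathlib import Path

# Hypothetical import path and model id, for illustration only.
from plugins.A_memorix.core.embedding.manager import create_embedding_manager_from_config

manager = create_embedding_manager_from_config(
    model_name="demo-zh-small",              # descriptive label, assumption
    model_path="BAAI/bge-small-zh-v1.5",     # placeholder HF model id (512-dim)
    dimension=512,
    cache_dir=Path("./data/embedding_cache"),
    enable_cache=True,
    num_workers=2,
)
# The first call loads the model lazily and fills the in-memory cache.
vectors = manager.encode_with_cache(["知识库插件", "记忆检索"])
print(vectors.shape, manager.cache_hit_rate)
manager.save_cache()  # persists the cache to cache_dir/embeddings_cache.pkl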

View File

@@ -0,0 +1,72 @@
"""
嵌入模型配置模块
"""
from dataclasses import dataclass
from typing import Optional, Dict, Any, Union
from pathlib import Path
@dataclass
class EmbeddingModelConfig:
"""
嵌入模型配置
属性:
model_name: 模型描述名称
model_path: 实际加载路径(Local or HF)
dimension: 嵌入向量维度
max_seq_length: 最大序列长度
batch_size: 编码批次大小
model_size_mb: 估计显存占用
description: 模型说明
normalization: 是否自动归一化
pooling: 池化策略 (mean, cls, max)
cache_dir: 模型缓存目录
"""
model_name: str
model_path: str
dimension: int
max_seq_length: int = 512
batch_size: int = 32
model_size_mb: int = 100
description: str = ""
normalization: bool = True
pooling: str = "mean"
cache_dir: Optional[Union[str, Path]] = None
def validate_config_compatibility(
config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
) -> bool:
"""检查两个配置是否兼容(主要看维度)"""
return config1.dimension == config2.dimension
def are_models_compatible(
config1: EmbeddingModelConfig, config2: EmbeddingModelConfig
) -> bool:
"""检查模型是否完全相同(用于热切换判断)"""
return (
config1.model_path == config2.model_path
and config1.dimension == config2.dimension
and config1.pooling == config2.pooling
)
def get_custom_config(
model_name: str,
model_path: str,
dimension: int,
cache_dir: Optional[Union[str, Path]] = None,
**kwargs,
) -> EmbeddingModelConfig:
"""创建自定义模型配置"""
return EmbeddingModelConfig(
model_name=model_name,
model_path=model_path,
dimension=dimension,
cache_dir=cache_dir,
**kwargs,
)
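Illustrative sketch (not part of this commit) of how the preset helpers compare configurations; the import path and model id are placeholders.

# Hypothetical import path, for illustration only.
from plugins.A_memorix.core.embedding.presets import (
    get_custom_config,
    validate_config_compatibility,
    are_models_compatible,
)

cfg_a = get_custom_config("demo-a", "BAAI/bge-small-zh-v1.5", dimension=512)
cfg_b = get_custom_config("demo-b", "BAAI/bge-small-zh-v1.5", dimension=512, pooling="mean")

# Same dimension -> vectors from either config can live in the same store.
print(validate_config_compatibility(cfg_a, cfg_b))  # True
# Same path, dimension, and pooling -> safe to hot-swap the model.
print(are_models_compatible(cfg_a, cfg_b))          # True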

View File

@@ -0,0 +1,54 @@
"""检索模块 - 双路检索与排序"""
from .dual_path import (
DualPathRetriever,
RetrievalStrategy,
RetrievalResult,
DualPathRetrieverConfig,
TemporalQueryOptions,
FusionConfig,
RelationIntentConfig,
)
from .pagerank import (
PersonalizedPageRank,
PageRankConfig,
create_ppr_from_graph,
)
from .threshold import (
DynamicThresholdFilter,
ThresholdMethod,
ThresholdConfig,
)
from .sparse_bm25 import (
SparseBM25Index,
SparseBM25Config,
)
from .graph_relation_recall import (
GraphRelationRecallConfig,
GraphRelationRecallService,
)
__all__ = [
# DualPathRetriever
"DualPathRetriever",
"RetrievalStrategy",
"RetrievalResult",
"DualPathRetrieverConfig",
"TemporalQueryOptions",
"FusionConfig",
"RelationIntentConfig",
# PersonalizedPageRank
"PersonalizedPageRank",
"PageRankConfig",
"create_ppr_from_graph",
# DynamicThresholdFilter
"DynamicThresholdFilter",
"ThresholdMethod",
"ThresholdConfig",
# Sparse BM25
"SparseBM25Index",
"SparseBM25Config",
# Graph relation recall
"GraphRelationRecallConfig",
"GraphRelationRecallService",
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,272 @@
"""Graph-assisted relation candidate recall for relation-oriented queries."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set
from src.common.logger import get_logger
logger = get_logger("A_Memorix.GraphRelationRecall")
@dataclass
class GraphRelationRecallConfig:
"""Configuration for controlled graph relation recall."""
enabled: bool = True
candidate_k: int = 24
max_hop: int = 1
allow_two_hop_pair: bool = True
max_paths: int = 4
def __post_init__(self) -> None:
self.enabled = bool(self.enabled)
self.candidate_k = max(1, int(self.candidate_k))
self.max_hop = max(1, int(self.max_hop))
self.allow_two_hop_pair = bool(self.allow_two_hop_pair)
self.max_paths = max(1, int(self.max_paths))
@dataclass
class GraphRelationCandidate:
"""A graph-derived relation candidate before retriever-side fusion."""
hash_value: str
subject: str
predicate: str
object: str
confidence: float
graph_seed_entities: List[str]
graph_hops: int
graph_candidate_type: str
supporting_paragraph_count: int
def to_payload(self) -> Dict[str, Any]:
content = f"{self.subject} {self.predicate} {self.object}"
return {
"hash": self.hash_value,
"content": content,
"subject": self.subject,
"predicate": self.predicate,
"object": self.object,
"confidence": self.confidence,
"graph_seed_entities": list(self.graph_seed_entities),
"graph_hops": int(self.graph_hops),
"graph_candidate_type": self.graph_candidate_type,
"supporting_paragraph_count": int(self.supporting_paragraph_count),
}
class GraphRelationRecallService:
"""Collect relation candidates from the entity graph in a controlled way."""
def __init__(
self,
*,
graph_store: Any,
metadata_store: Any,
config: Optional[GraphRelationRecallConfig] = None,
) -> None:
self.graph_store = graph_store
self.metadata_store = metadata_store
self.config = config or GraphRelationRecallConfig()
def recall(
self,
*,
seed_entities: Sequence[str],
) -> List[GraphRelationCandidate]:
if not self.config.enabled:
return []
if self.graph_store is None or self.metadata_store is None:
return []
seeds = self._normalize_seed_entities(seed_entities)
if not seeds:
return []
seen_hashes: Set[str] = set()
candidates: List[GraphRelationCandidate] = []
if len(seeds) >= 2:
self._collect_direct_pair_candidates(
seed_a=seeds[0],
seed_b=seeds[1],
seen_hashes=seen_hashes,
out=candidates,
)
if (
len(candidates) < 3
and self.config.allow_two_hop_pair
and len(candidates) < self.config.candidate_k
):
self._collect_two_hop_pair_candidates(
seed_a=seeds[0],
seed_b=seeds[1],
seen_hashes=seen_hashes,
out=candidates,
)
else:
self._collect_one_hop_seed_candidates(
seed=seeds[0],
seen_hashes=seen_hashes,
out=candidates,
)
return candidates[: self.config.candidate_k]
def _normalize_seed_entities(self, seed_entities: Sequence[str]) -> List[str]:
out: List[str] = []
seen = set()
for raw in list(seed_entities)[:2]:
resolved = None
try:
resolved = self.graph_store.find_node(str(raw), ignore_case=True)
except Exception:
resolved = None
if not resolved:
continue
canon = str(resolved).strip().lower()
if not canon or canon in seen:
continue
seen.add(canon)
out.append(str(resolved))
return out
def _collect_direct_pair_candidates(
self,
*,
seed_a: str,
seed_b: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
relation_hashes = []
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_a, seed_b))
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(seed_b, seed_a))
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="direct_pair",
graph_hops=1,
graph_seed_entities=[seed_a, seed_b],
)
def _collect_two_hop_pair_candidates(
self,
*,
seed_a: str,
seed_b: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
try:
paths = self.graph_store.find_paths(
seed_a,
seed_b,
max_depth=2,
max_paths=self.config.max_paths,
)
except Exception as e:
logger.debug("graph two-hop recall skipped: %s", e)
return
for path_nodes in paths:
if len(out) >= self.config.candidate_k:
break
if not isinstance(path_nodes, Sequence) or len(path_nodes) != 3:
continue
for idx in range(len(path_nodes) - 1):
if len(out) >= self.config.candidate_k:
break
u = str(path_nodes[idx])
v = str(path_nodes[idx + 1])
relation_hashes = []
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(u, v))
relation_hashes.extend(self.graph_store.get_relation_hashes_for_edge(v, u))
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="two_hop_pair",
graph_hops=2,
graph_seed_entities=[seed_a, seed_b],
)
def _collect_one_hop_seed_candidates(
self,
*,
seed: str,
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
) -> None:
try:
relation_hashes = self.graph_store.get_incident_relation_hashes(
seed,
limit=self.config.candidate_k,
)
except Exception as e:
logger.debug("graph one-hop recall skipped: %s", e)
return
self._append_relation_hashes(
relation_hashes=relation_hashes,
seen_hashes=seen_hashes,
out=out,
candidate_type="one_hop_seed",
graph_hops=min(1, self.config.max_hop),
graph_seed_entities=[seed],
)
def _append_relation_hashes(
self,
*,
relation_hashes: Sequence[str],
seen_hashes: Set[str],
out: List[GraphRelationCandidate],
candidate_type: str,
graph_hops: int,
graph_seed_entities: Sequence[str],
) -> None:
for relation_hash in sorted({str(h) for h in relation_hashes if str(h).strip()}):
if len(out) >= self.config.candidate_k:
break
if relation_hash in seen_hashes:
continue
candidate = self._build_candidate(
relation_hash=relation_hash,
candidate_type=candidate_type,
graph_hops=graph_hops,
graph_seed_entities=graph_seed_entities,
)
if candidate is None:
continue
seen_hashes.add(relation_hash)
out.append(candidate)
def _build_candidate(
self,
*,
relation_hash: str,
candidate_type: str,
graph_hops: int,
graph_seed_entities: Sequence[str],
) -> Optional[GraphRelationCandidate]:
relation = self.metadata_store.get_relation(relation_hash)
if relation is None:
return None
supporting_paragraphs = self.metadata_store.get_paragraphs_by_relation(relation_hash)
return GraphRelationCandidate(
hash_value=relation_hash,
subject=str(relation.get("subject", "")),
predicate=str(relation.get("predicate", "")),
object=str(relation.get("object", "")),
confidence=float(relation.get("confidence", 1.0) or 1.0),
graph_seed_entities=[str(x) for x in graph_seed_entities],
graph_hops=int(graph_hops),
graph_candidate_type=str(candidate_type),
supporting_paragraph_count=len(supporting_paragraphs),
)
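Illustrative sketch (not part of this commit): the recall service only relies on a handful of duck-typed store methods, so two minimal stubs are enough to show the contract; real callers should pass the GraphStore and MetadataStore instances from core.storage. The import path and the stub classes below are assumptions made for the example.

# Hypothetical import path, for illustration only.
from plugins.A_memorix.core.retrieval.graph_relation_recall import GraphRelationRecallService

class _StubGraphStore:
    """Implements only the graph methods the recall service touches."""
    def find_node(self, name, ignore_case=True):
        canon = {"alice": "Alice", "bob": "Bob"}
        return canon.get(str(name).lower())
    def get_relation_hashes_for_edge(self, a, b):
        return ["rel-1"] if {a, b} == {"Alice", "Bob"} else []
    def find_paths(self, a, b, max_depth=2, max_paths=4):
        return []
    def get_incident_relation_hashes(self, node, limit=24):
        return []

class _StubMetadataStore:
    """Implements only the metadata lookups the recall service touches."""
    def get_relation(self, relation_hash):
        return {"subject": "Alice", "predicate": "knows", "object": "Bob", "confidence": 0.9}
    def get_paragraphs_by_relation(self, relation_hash):
        return ["p1", "p2"]

service = GraphRelationRecallService(
    graph_store=_StubGraphStore(),
    metadata_store=_StubMetadataStore(),
)
for candidate in service.recall(seed_entities=["alice", "bob"]):
    print(candidate.to_payload())  # direct_pair candidate backed by 2 paragraphs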

View File

@@ -0,0 +1,482 @@
"""
Personalized PageRank实现
提供个性化的图节点排序功能。
"""
from typing import Dict, List, Optional, Tuple, Union, Any
from dataclasses import dataclass
import numpy as np
from src.common.logger import get_logger
from ..storage import GraphStore
from ..utils.matcher import AhoCorasick
logger = get_logger("A_Memorix.PersonalizedPageRank")
@dataclass
class PageRankConfig:
"""
PageRank配置
属性:
alpha: 阻尼系数(0-1 之间)
max_iter: 最大迭代次数
tol: 收敛阈值
normalize: 是否归一化结果
min_iterations: 最小迭代次数
"""
alpha: float = 0.85
max_iter: int = 100
tol: float = 1e-6
normalize: bool = True
min_iterations: int = 20
def __post_init__(self):
"""验证配置"""
if not 0 <= self.alpha < 1:
raise ValueError(f"alpha必须在[0, 1)之间: {self.alpha}")
if self.max_iter <= 0:
raise ValueError(f"max_iter必须大于0: {self.max_iter}")
if self.tol <= 0:
raise ValueError(f"tol必须大于0: {self.tol}")
if self.min_iterations < 0:
raise ValueError(f"min_iterations必须大于等于0: {self.min_iterations}")
if self.min_iterations >= self.max_iter:
raise ValueError(f"min_iterations必须小于max_iter")
class PersonalizedPageRank:
"""
Personalized PageRank计算器
功能:
- 个性化向量支持
- 快速收敛检测
- 结果归一化
- 批量计算
- 统计信息
参数:
graph_store: 图存储
config: PageRank配置
"""
def __init__(
self,
graph_store: GraphStore,
config: Optional[PageRankConfig] = None,
):
"""
初始化PPR计算器
Args:
graph_store: 图存储
config: PageRank配置
"""
self.graph_store = graph_store
self.config = config or PageRankConfig()
# 统计信息
self._total_computations = 0
self._total_iterations = 0
self._convergence_history: List[int] = []
logger.info(
f"PersonalizedPageRank 初始化: "
f"alpha={self.config.alpha}, "
f"max_iter={self.config.max_iter}"
)
# 缓存 Aho-Corasick 匹配器
self._ac_matcher: Optional[AhoCorasick] = None
self._ac_nodes_count = 0
def compute(
self,
personalization: Optional[Dict[str, float]] = None,
alpha: Optional[float] = None,
max_iter: Optional[int] = None,
normalize: Optional[bool] = None,
) -> Dict[str, float]:
"""
计算Personalized PageRank
Args:
personalization: 个性化向量 {节点名: 权重}
alpha: 阻尼系数(覆盖配置值)
max_iter: 最大迭代次数(覆盖配置值)
normalize: 是否归一化(覆盖配置值)
Returns:
节点PageRank值字典 {节点名: 分数}
"""
# 使用覆盖值或配置值
alpha = alpha if alpha is not None else self.config.alpha
max_iter = max_iter if max_iter is not None else self.config.max_iter
normalize = normalize if normalize is not None else self.config.normalize
# 调用GraphStore的compute_pagerank
scores = self.graph_store.compute_pagerank(
personalization=personalization,
alpha=alpha,
max_iter=max_iter,
tol=self.config.tol,
)
# 归一化(如果需要)
if normalize and scores:
total = sum(scores.values())
if total > 0:
scores = {node: score / total for node, score in scores.items()}
# 更新统计
self._total_computations += 1
logger.debug(
f"PPR计算完成: {len(scores)} 个节点, "
f"personalization_nodes={len(personalization) if personalization else 0}"
)
return scores
def compute_batch(
self,
personalization_list: List[Dict[str, float]],
normalize: bool = True,
) -> List[Dict[str, float]]:
"""
批量计算PPR
Args:
personalization_list: 个性化向量列表
normalize: 是否归一化
Returns:
PageRank值字典列表
"""
results = []
for i, personalization in enumerate(personalization_list):
logger.debug(f"计算第 {i+1}/{len(personalization_list)} 个PPR")
scores = self.compute(personalization=personalization, normalize=normalize)
results.append(scores)
return results
def compute_for_entities(
self,
entities: List[str],
weights: Optional[List[float]] = None,
normalize: bool = True,
) -> Dict[str, float]:
"""
为实体列表计算PPR
Args:
entities: 实体列表
weights: 权重列表(默认均匀权重)
normalize: 是否归一化
Returns:
PageRank值字典
"""
if not entities:
logger.warning("实体列表为空返回均匀PPR")
return self.compute(personalization=None, normalize=normalize)
# 构建个性化向量
if weights is None:
weights = [1.0] * len(entities)
if len(weights) != len(entities):
raise ValueError(f"权重数量与实体数量不匹配: {len(weights)} vs {len(entities)}")
personalization = {entity: weight for entity, weight in zip(entities, weights)}
return self.compute(personalization=personalization, normalize=normalize)
def compute_for_query(
self,
query: str,
entity_extractor: Optional[callable] = None,
normalize: bool = True,
) -> Dict[str, float]:
"""
为查询计算PPR
Args:
query: 查询文本
entity_extractor: 实体提取函数(可选)
normalize: 是否归一化
Returns:
PageRank值字典
"""
# 提取实体
if entity_extractor is not None:
entities = entity_extractor(query)
else:
# 简单实现:基于图中的节点匹配
entities = self._extract_entities_from_query(query)
if not entities:
logger.debug(f"未从查询中提取到实体: '{query}'")
return self.compute(personalization=None, normalize=normalize)
# 计算PPR
return self.compute_for_entities(entities, normalize=normalize)
def rank_nodes(
self,
scores: Dict[str, float],
top_k: Optional[int] = None,
min_score: float = 0.0,
) -> List[Tuple[str, float]]:
"""
对节点排序
Args:
scores: PageRank分数字典
top_k: 返回前k个节点None表示全部
min_score: 最小分数阈值
Returns:
排序后的节点列表 [(节点名, 分数), ...]
"""
# 过滤低分节点
filtered = [(node, score) for node, score in scores.items() if score >= min_score]
# 按分数降序排序
sorted_nodes = sorted(filtered, key=lambda x: x[1], reverse=True)
# 返回top_k
if top_k is not None:
sorted_nodes = sorted_nodes[:top_k]
return sorted_nodes
def get_personalization_vector(
self,
nodes: List[str],
method: str = "uniform",
) -> Dict[str, float]:
"""
生成个性化向量
Args:
nodes: 节点列表
method: 生成方法
- "uniform": 均匀权重
- "degree": 按度数加权
- "inverse_degree": 按度数反比加权
Returns:
个性化向量 {节点名: 权重}
"""
if not nodes:
return {}
if method == "uniform":
# 均匀权重
weight = 1.0 / len(nodes)
return {node: weight for node in nodes}
elif method == "degree":
# 按度数加权
node_degrees = {}
for node in nodes:
neighbors = self.graph_store.get_neighbors(node)
node_degrees[node] = len(neighbors)
total_degree = sum(node_degrees.values())
if total_degree > 0:
return {node: degree / total_degree for node, degree in node_degrees.items()}
else:
return {node: 1.0 / len(nodes) for node in nodes}
elif method == "inverse_degree":
# 按度数反比加权
node_degrees = {}
for node in nodes:
neighbors = self.graph_store.get_neighbors(node)
node_degrees[node] = len(neighbors)
# 反度数
inv_degrees = {node: 1.0 / (degree + 1) for node, degree in node_degrees.items()}
total_inv = sum(inv_degrees.values())
if total_inv > 0:
return {node: inv / total_inv for node, inv in inv_degrees.items()}
else:
return {node: 1.0 / len(nodes) for node in nodes}
else:
raise ValueError(f"不支持的个性化向量生成方法: {method}")
def compare_scores(
self,
scores1: Dict[str, float],
scores2: Dict[str, float],
) -> Dict[str, Dict[str, float]]:
"""
比较两组PPR分数
Args:
scores1: 第一组分数
scores2: 第二组分数
Returns:
比较结果 {
"common_nodes": {节点: (score1, score2)},
"only_in_1": {节点: score1},
"only_in_2": {节点: score2},
}
"""
common_nodes = {}
only_in_1 = {}
only_in_2 = {}
all_nodes = set(scores1.keys()) | set(scores2.keys())
for node in all_nodes:
if node in scores1 and node in scores2:
common_nodes[node] = (scores1[node], scores2[node])
elif node in scores1:
only_in_1[node] = scores1[node]
else:
only_in_2[node] = scores2[node]
return {
"common_nodes": common_nodes,
"only_in_1": only_in_1,
"only_in_2": only_in_2,
}
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
avg_iterations = (
self._total_iterations / self._total_computations
if self._total_computations > 0
else 0
)
return {
"config": {
"alpha": self.config.alpha,
"max_iter": self.config.max_iter,
"tol": self.config.tol,
"normalize": self.config.normalize,
"min_iterations": self.config.min_iterations,
},
"statistics": {
"total_computations": self._total_computations,
"total_iterations": self._total_iterations,
"avg_iterations": avg_iterations,
"convergence_history": self._convergence_history.copy(),
},
"graph": {
"num_nodes": self.graph_store.num_nodes,
"num_edges": self.graph_store.num_edges,
},
}
def reset_statistics(self) -> None:
"""重置统计信息"""
self._total_computations = 0
self._total_iterations = 0
self._convergence_history.clear()
logger.info("统计信息已重置")
def _extract_entities_from_query(self, query: str) -> List[str]:
"""
从查询中提取实体(简化实现)
Args:
query: 查询文本
Returns:
实体列表
"""
# 获取所有节点
all_nodes = self.graph_store.get_nodes()
if not all_nodes:
return []
# 检查是否需要更新 Aho-Corasick 匹配器
if self._ac_matcher is None or self._ac_nodes_count != len(all_nodes):
self._ac_matcher = AhoCorasick()
for node in all_nodes:
# 统一转为小写进行不区分大小写匹配
self._ac_matcher.add_pattern(node.lower())
self._ac_matcher.build()
self._ac_nodes_count = len(all_nodes)
# 执行匹配
query_lower = query.lower()
stats = self._ac_matcher.find_all(query_lower)
# 转换回原始的大小写(这里简化为从 all_nodes 中找,或者 AC 存原始值)
# 为了简单AC 中 add_pattern 存的是小写
# 我们需要一个映射:小写 -> 原始
node_map = {node.lower(): node for node in all_nodes}
entities = [node_map[low_name] for low_name in stats.keys()]
return entities
@property
def num_computations(self) -> int:
"""计算次数"""
return self._total_computations
@property
def avg_iterations(self) -> float:
"""平均迭代次数"""
if self._total_computations == 0:
return 0.0
return self._total_iterations / self._total_computations
def __repr__(self) -> str:
return (
f"PersonalizedPageRank("
f"alpha={self.config.alpha}, "
f"computations={self._total_computations})"
)
def create_ppr_from_graph(
graph_store: GraphStore,
alpha: float = 0.85,
max_iter: int = 100,
) -> PersonalizedPageRank:
"""
从图存储创建PPR计算器
Args:
graph_store: 图存储
alpha: 阻尼系数
max_iter: 最大迭代次数
Returns:
PPR计算器实例
"""
config = PageRankConfig(
alpha=alpha,
max_iter=max_iter,
)
return PersonalizedPageRank(
graph_store=graph_store,
config=config,
)
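Illustrative sketch (not part of this commit): PersonalizedPageRank delegates the actual iteration to GraphStore.compute_pagerank, so a stub store returning fixed scores is enough to show how personalization, normalization, and ranking fit together; in production, pass the loaded GraphStore instead. The import path and stub values are assumptions.

# Hypothetical import path, for illustration only.
from plugins.A_memorix.core.retrieval.pagerank import PersonalizedPageRank, PageRankConfig

class _StubGraphStore:
    """Implements only what the PPR wrapper needs; real scores come from GraphStore.compute_pagerank."""
    num_nodes = 3
    num_edges = 2
    def compute_pagerank(self, personalization=None, alpha=0.85, max_iter=100, tol=1e-6):
        return {"Alice": 0.5, "Bob": 0.3, "Carol": 0.2}   # fixed scores, demo only
    def get_neighbors(self, node):
        return ["Bob"] if node == "Alice" else ["Alice"]
    def get_nodes(self):
        return ["Alice", "Bob", "Carol"]

ppr = PersonalizedPageRank(_StubGraphStore(), PageRankConfig(alpha=0.85, max_iter=50))
# Seed the walk from two query entities with uneven weights.
scores = ppr.compute_for_entities(["Alice", "Bob"], weights=[0.7, 0.3])
print(ppr.rank_nodes(scores, top_k=2))   # [('Alice', 0.5), ('Bob', 0.3)]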

View File

@@ -0,0 +1,402 @@
"""
稀疏检索组件(FTS5 + BM25)
支持:
- 懒加载索引连接
- jieba / char n-gram 分词
- 可卸载并收缩 SQLite 内存缓存
"""
from __future__ import annotations
import re
import sqlite3
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
from ..storage import MetadataStore
logger = get_logger("A_Memorix.SparseBM25")
try:
import jieba # type: ignore
HAS_JIEBA = True
except Exception:
HAS_JIEBA = False
jieba = None
@dataclass
class SparseBM25Config:
"""BM25 稀疏检索配置。"""
enabled: bool = True
backend: str = "fts5"
lazy_load: bool = True
mode: str = "auto" # auto | fallback_only | hybrid
tokenizer_mode: str = "jieba" # jieba | mixed | char_2gram
jieba_user_dict: str = ""
char_ngram_n: int = 2
candidate_k: int = 80
max_doc_len: int = 2000
enable_ngram_fallback_index: bool = True
enable_like_fallback: bool = False
enable_relation_sparse_fallback: bool = True
relation_candidate_k: int = 60
relation_max_doc_len: int = 512
unload_on_disable: bool = True
shrink_memory_on_unload: bool = True
def __post_init__(self) -> None:
self.backend = str(self.backend or "fts5").strip().lower()
self.mode = str(self.mode or "auto").strip().lower()
self.tokenizer_mode = str(self.tokenizer_mode or "jieba").strip().lower()
self.char_ngram_n = max(1, int(self.char_ngram_n))
self.candidate_k = max(1, int(self.candidate_k))
self.max_doc_len = max(0, int(self.max_doc_len))
self.relation_candidate_k = max(1, int(self.relation_candidate_k))
self.relation_max_doc_len = max(0, int(self.relation_max_doc_len))
if self.backend != "fts5":
raise ValueError(f"sparse.backend 暂仅支持 fts5: {self.backend}")
if self.mode not in {"auto", "fallback_only", "hybrid"}:
raise ValueError(f"sparse.mode 非法: {self.mode}")
if self.tokenizer_mode not in {"jieba", "mixed", "char_2gram"}:
raise ValueError(f"sparse.tokenizer_mode 非法: {self.tokenizer_mode}")
class SparseBM25Index:
"""
基于 SQLite FTS5 的 BM25 检索适配层。
"""
def __init__(
self,
metadata_store: MetadataStore,
config: Optional[SparseBM25Config] = None,
):
self.metadata_store = metadata_store
self.config = config or SparseBM25Config()
self._conn: Optional[sqlite3.Connection] = None
self._loaded: bool = False
self._jieba_dict_loaded: bool = False
@property
def loaded(self) -> bool:
return self._loaded and self._conn is not None
def ensure_loaded(self) -> bool:
"""按需加载 FTS 连接与索引。"""
if not self.config.enabled:
return False
if self.loaded:
return True
db_path = self.metadata_store.get_db_path()
conn = sqlite3.connect(
str(db_path),
check_same_thread=False,
timeout=30.0,
)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
conn.execute("PRAGMA temp_store=MEMORY")
if not self.metadata_store.ensure_fts_schema(conn=conn):
conn.close()
return False
self.metadata_store.ensure_fts_backfilled(conn=conn)
# 关系稀疏检索按独立开关加载,避免不必要的初始化开销。
if self.config.enable_relation_sparse_fallback:
self.metadata_store.ensure_relations_fts_schema(conn=conn)
self.metadata_store.ensure_relations_fts_backfilled(conn=conn)
if self.config.enable_ngram_fallback_index:
self.metadata_store.ensure_paragraph_ngram_schema(conn=conn)
self.metadata_store.ensure_paragraph_ngram_backfilled(
n=self.config.char_ngram_n,
conn=conn,
)
self._conn = conn
self._loaded = True
self._prepare_tokenizer()
logger.info(
"SparseBM25Index loaded: backend=fts5, tokenizer=%s, mode=%s",
self.config.tokenizer_mode,
self.config.mode,
)
return True
def _prepare_tokenizer(self) -> None:
if self._jieba_dict_loaded:
return
if self.config.tokenizer_mode not in {"jieba", "mixed"}:
return
if not HAS_JIEBA:
logger.warning("jieba 不可用tokenizer 将退化为 char n-gram")
return
user_dict = str(self.config.jieba_user_dict or "").strip()
if user_dict:
try:
jieba.load_userdict(user_dict) # type: ignore[union-attr]
logger.info("已加载 jieba 用户词典: %s", user_dict)
except Exception as e:
logger.warning("加载 jieba 用户词典失败: %s", e)
self._jieba_dict_loaded = True
def _tokenize_jieba(self, text: str) -> List[str]:
if not HAS_JIEBA:
return []
try:
tokens = list(jieba.cut_for_search(text)) # type: ignore[union-attr]
return [t.strip().lower() for t in tokens if t and t.strip()]
except Exception:
return []
def _tokenize_char_ngram(self, text: str, n: int) -> List[str]:
compact = re.sub(r"\s+", "", text.lower())
if not compact:
return []
if len(compact) < n:
return [compact]
return [compact[i : i + n] for i in range(0, len(compact) - n + 1)]
def _tokenize(self, text: str) -> List[str]:
text = str(text or "").strip()
if not text:
return []
mode = self.config.tokenizer_mode
if mode == "jieba":
tokens = self._tokenize_jieba(text)
if tokens:
return list(dict.fromkeys(tokens))
return self._tokenize_char_ngram(text, self.config.char_ngram_n)
if mode == "mixed":
toks = self._tokenize_jieba(text)
toks.extend(self._tokenize_char_ngram(text, self.config.char_ngram_n))
return list(dict.fromkeys([t for t in toks if t]))
return list(dict.fromkeys(self._tokenize_char_ngram(text, self.config.char_ngram_n)))
def _build_match_query(self, tokens: List[str]) -> str:
safe_tokens: List[str] = []
for token in tokens:
t = token.replace('"', '""').strip()
if not t:
continue
safe_tokens.append(f'"{t}"')
if not safe_tokens:
return ""
# 采用 OR 提升召回,再交由 RRF 和阈值做稳健排序。
return " OR ".join(safe_tokens[:64])
def _fallback_substring_search(
self,
tokens: List[str],
limit: int,
) -> List[Dict[str, Any]]:
"""
当 FTS5 因分词不一致召回为空时,退化为子串匹配召回。
说明:
- FTS 索引当前采用 unicode61 tokenizer。
- 若查询 token 来源为 char n-gram 或中文词元,可能与索引 token 不一致。
- 这里使用 SQL LIKE 做兜底,按命中 token 覆盖度打分。
"""
if not tokens:
return []
# 去重并裁剪 token 数量,避免生成超长 SQL。
uniq_tokens = [t for t in dict.fromkeys(tokens) if t]
uniq_tokens = uniq_tokens[:32]
if not uniq_tokens:
return []
if self.config.enable_ngram_fallback_index:
try:
# 允许运行时切换开关后按需补齐 schema/回填。
self.metadata_store.ensure_paragraph_ngram_schema(conn=self._conn)
self.metadata_store.ensure_paragraph_ngram_backfilled(
n=self.config.char_ngram_n,
conn=self._conn,
)
rows = self.metadata_store.ngram_search_paragraphs(
tokens=uniq_tokens,
limit=limit,
max_doc_len=self.config.max_doc_len,
conn=self._conn,
)
if rows:
return rows
except Exception as e:
logger.warning(f"ngram 倒排回退失败,将按配置决定是否使用 LIKE 回退: {e}")
if not self.config.enable_like_fallback:
return []
conditions = " OR ".join(["p.content LIKE ?"] * len(uniq_tokens))
params: List[Any] = [f"%{tok}%" for tok in uniq_tokens]
scan_limit = max(int(limit) * 8, 200)
params.append(scan_limit)
sql = f"""
SELECT p.hash, p.content
FROM paragraphs p
WHERE (p.is_deleted IS NULL OR p.is_deleted = 0)
AND ({conditions})
LIMIT ?
"""
rows = self.metadata_store.query(sql, tuple(params))
if not rows:
return []
scored: List[Dict[str, Any]] = []
token_count = max(1, len(uniq_tokens))
for row in rows:
content = str(row.get("content") or "")
content_low = content.lower()
matched = [tok for tok in uniq_tokens if tok in content_low]
if not matched:
continue
coverage = len(matched) / token_count
length_bonus = sum(len(tok) for tok in matched) / max(1, len(content_low))
# 兜底路径使用相对分,保持与上层接口兼容。
fallback_score = coverage * 0.8 + length_bonus * 0.2
scored.append(
{
"hash": row["hash"],
"content": content[: self.config.max_doc_len] if self.config.max_doc_len > 0 else content,
"bm25_score": -float(fallback_score),
"fallback_score": float(fallback_score),
}
)
scored.sort(key=lambda x: x["fallback_score"], reverse=True)
return scored[:limit]
def search(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
"""执行 BM25 检索。"""
if not self.config.enabled:
return []
if self.config.lazy_load and not self.loaded:
if not self.ensure_loaded():
return []
if not self.loaded:
return []
# 关系稀疏检索可独立开关,运行时开启后在此按需补齐 schema/回填。
if self.config.enable_relation_sparse_fallback:
    self.metadata_store.ensure_relations_fts_schema(conn=self._conn)
    self.metadata_store.ensure_relations_fts_backfilled(conn=self._conn)
tokens = self._tokenize(query)
match_query = self._build_match_query(tokens)
if not match_query:
return []
limit = max(1, int(k))
rows = self.metadata_store.fts_search_bm25(
match_query=match_query,
limit=limit,
max_doc_len=self.config.max_doc_len,
conn=self._conn,
)
if not rows:
rows = self._fallback_substring_search(tokens=tokens, limit=limit)
results: List[Dict[str, Any]] = []
for rank, row in enumerate(rows, start=1):
bm25_score = float(row.get("bm25_score", 0.0))
results.append(
{
"hash": row["hash"],
"content": row["content"],
"rank": rank,
"bm25_score": bm25_score,
"score": -bm25_score, # bm25 越小越相关,这里取反作为相对分数
}
)
return results
def search_relations(self, query: str, k: int = 20) -> List[Dict[str, Any]]:
"""执行关系稀疏检索FTS5 + BM25"""
if not self.config.enabled or not self.config.enable_relation_sparse_fallback:
return []
if self.config.lazy_load and not self.loaded:
if not self.ensure_loaded():
return []
if not self.loaded:
return []
tokens = self._tokenize(query)
match_query = self._build_match_query(tokens)
if not match_query:
return []
rows = self.metadata_store.fts_search_relations_bm25(
match_query=match_query,
limit=max(1, int(k)),
max_doc_len=self.config.relation_max_doc_len,
conn=self._conn,
)
out: List[Dict[str, Any]] = []
for rank, row in enumerate(rows, start=1):
bm25_score = float(row.get("bm25_score", 0.0))
out.append(
{
"hash": row["hash"],
"subject": row["subject"],
"predicate": row["predicate"],
"object": row["object"],
"content": row["content"],
"rank": rank,
"bm25_score": bm25_score,
"score": -bm25_score,
}
)
return out
def upsert_paragraph(self, paragraph_hash: str) -> bool:
if not self.loaded:
return False
return self.metadata_store.fts_upsert_paragraph(paragraph_hash, conn=self._conn)
def delete_paragraph(self, paragraph_hash: str) -> bool:
if not self.loaded:
return False
return self.metadata_store.fts_delete_paragraph(paragraph_hash, conn=self._conn)
def unload(self) -> None:
"""卸载 BM25 连接并尽量释放内存。"""
if self._conn is not None:
try:
if self.config.shrink_memory_on_unload:
self.metadata_store.shrink_memory(conn=self._conn)
except Exception:
pass
try:
self._conn.close()
except Exception:
pass
self._conn = None
self._loaded = False
logger.info("SparseBM25Index unloaded")
def stats(self) -> Dict[str, Any]:
doc_count = 0
if self.loaded:
doc_count = self.metadata_store.fts_doc_count(conn=self._conn)
return {
"enabled": self.config.enabled,
"backend": self.config.backend,
"mode": self.config.mode,
"tokenizer_mode": self.config.tokenizer_mode,
"enable_ngram_fallback_index": self.config.enable_ngram_fallback_index,
"enable_like_fallback": self.config.enable_like_fallback,
"enable_relation_sparse_fallback": self.config.enable_relation_sparse_fallback,
"loaded": self.loaded,
"has_jieba": HAS_JIEBA,
"doc_count": doc_count,
}
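Illustrative sketch (not part of this commit): it only exercises the config validation and the tokenizer / MATCH-query helpers, which need no database; the real entry points are ensure_loaded() and search() with a connected MetadataStore. The import path is an assumption, and the private helpers are called purely for demonstration.

# Hypothetical import path, for illustration only.
from plugins.A_memorix.core.retrieval.sparse_bm25 import SparseBM25Config, SparseBM25Index

config = SparseBM25Config(tokenizer_mode="char_2gram", char_ngram_n=2, candidate_k=40)
# metadata_store=None is acceptable here because no query is actually executed.
index = SparseBM25Index(metadata_store=None, config=config)

tokens = index._tokenize("长期记忆检索")
print(tokens)                             # ['长期', '期记', '记忆', '忆检', '检索']
print(index._build_match_query(tokens))   # '"长期" OR "期记" OR ...' fed to FTS5 MATCH

# With a connected MetadataStore the flow would be:
#   index = SparseBM25Index(metadata_store=metadata_store, config=config)
#   if index.ensure_loaded():
#       hits = index.search("插件 记忆 检索", k=10)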

View File

@@ -0,0 +1,450 @@
"""
动态阈值过滤器
根据检索结果的分布特征自适应调整过滤阈值。
"""
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from dataclasses import dataclass
from enum import Enum
from src.common.logger import get_logger
from .dual_path import RetrievalResult
logger = get_logger("A_Memorix.DynamicThresholdFilter")
class ThresholdMethod(Enum):
"""阈值计算方法"""
PERCENTILE = "percentile" # 百分位数
STD_DEV = "std_dev" # 标准差
GAP_DETECTION = "gap_detection" # 跳变检测
ADAPTIVE = "adaptive" # 自适应(综合多种方法)
@dataclass
class ThresholdConfig:
"""
阈值配置
属性:
method: 阈值计算方法
min_threshold: 最小阈值(绝对值)
max_threshold: 最大阈值(绝对值)
percentile: 百分位数(用于 percentile 方法)
std_multiplier: 标准差倍数(用于 std_dev 方法)
min_results: 最少保留结果数
enable_auto_adjust: 是否自动调整参数
"""
method: ThresholdMethod = ThresholdMethod.ADAPTIVE
min_threshold: float = 0.3
max_threshold: float = 0.95
percentile: float = 75.0 # 百分位数
std_multiplier: float = 1.5 # 标准差倍数
min_results: int = 3 # 最少保留结果数
enable_auto_adjust: bool = True
def __post_init__(self):
"""验证配置"""
if not 0 <= self.min_threshold <= 1:
raise ValueError(f"min_threshold必须在[0, 1]之间: {self.min_threshold}")
if not 0 <= self.max_threshold <= 1:
raise ValueError(f"max_threshold必须在[0, 1]之间: {self.max_threshold}")
if self.min_threshold >= self.max_threshold:
raise ValueError(f"min_threshold必须小于max_threshold")
if not 0 <= self.percentile <= 100:
raise ValueError(f"percentile必须在[0, 100]之间: {self.percentile}")
if self.std_multiplier <= 0:
raise ValueError(f"std_multiplier必须大于0: {self.std_multiplier}")
if self.min_results < 0:
raise ValueError(f"min_results必须大于等于0: {self.min_results}")
class DynamicThresholdFilter:
"""
动态阈值过滤器
功能:
- 基于结果分布自适应计算阈值
- 多种阈值计算方法
- 自动参数调整
- 统计信息收集
参数:
config: 阈值配置
"""
def __init__(
self,
config: Optional[ThresholdConfig] = None,
):
"""
初始化动态阈值过滤器
Args:
config: 阈值配置
"""
self.config = config or ThresholdConfig()
# 统计信息
self._total_filtered = 0
self._total_processed = 0
self._threshold_history: List[float] = []
logger.info(
f"DynamicThresholdFilter 初始化: "
f"method={self.config.method.value}, "
f"min_threshold={self.config.min_threshold}"
)
def filter(
self,
results: List[RetrievalResult],
return_threshold: bool = False,
) -> Union[List[RetrievalResult], Tuple[List[RetrievalResult], float]]:
"""
过滤检索结果
Args:
results: 检索结果列表
return_threshold: 是否返回使用的阈值
Returns:
过滤后的结果列表,或 (结果列表, 阈值) 元组
"""
if not results:
logger.debug("结果列表为空,无需过滤")
return ([], 0.0) if return_threshold else []
self._total_processed += len(results)
# 提取分数
scores = np.array([r.score for r in results])
# 计算阈值
threshold = self._compute_threshold(scores, results)
# 记录阈值
self._threshold_history.append(threshold)
# 应用阈值过滤
filtered_results = [
r for r in results
if r.score >= threshold
]
# 确保至少保留min_results个结果
if len(filtered_results) < self.config.min_results:
# 按分数排序取前min_results个
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
filtered_results = sorted_results[:self.config.min_results]
threshold = filtered_results[-1].score if filtered_results else 0.0
self._total_filtered += len(results) - len(filtered_results)
logger.info(
f"过滤完成: {len(results)} -> {len(filtered_results)} "
f"(threshold={threshold:.3f})"
)
if return_threshold:
return filtered_results, threshold
return filtered_results
def _compute_threshold(
self,
scores: np.ndarray,
results: List[RetrievalResult],
) -> float:
"""
计算阈值
Args:
scores: 分数数组
results: 检索结果列表
Returns:
阈值
"""
if self.config.method == ThresholdMethod.PERCENTILE:
threshold = self._percentile_threshold(scores)
elif self.config.method == ThresholdMethod.STD_DEV:
threshold = self._std_dev_threshold(scores)
elif self.config.method == ThresholdMethod.GAP_DETECTION:
threshold = self._gap_detection_threshold(scores)
else: # ADAPTIVE
# 自适应方法:综合多种方法
thresholds = [
self._percentile_threshold(scores),
self._std_dev_threshold(scores),
self._gap_detection_threshold(scores),
]
# 使用中位数作为最终阈值
threshold = float(np.median(thresholds))
# 限制在[min_threshold, max_threshold]范围内
threshold = np.clip(
threshold,
self.config.min_threshold,
self.config.max_threshold,
)
# 自动调整
if self.config.enable_auto_adjust:
threshold = self._auto_adjust_threshold(threshold, scores)
return float(threshold)
def _percentile_threshold(self, scores: np.ndarray) -> float:
"""
基于百分位数计算阈值
Args:
scores: 分数数组
Returns:
阈值
"""
percentile = self.config.percentile
threshold = float(np.percentile(scores, percentile))
logger.debug(f"百分位数阈值: {threshold:.3f} (percentile={percentile})")
return threshold
def _std_dev_threshold(self, scores: np.ndarray) -> float:
"""
基于标准差计算阈值
threshold = mean - std_multiplier * std
Args:
scores: 分数数组
Returns:
阈值
"""
mean = float(np.mean(scores))
std = float(np.std(scores))
multiplier = self.config.std_multiplier
threshold = mean - multiplier * std
logger.debug(f"标准差阈值: {threshold:.3f} (mean={mean:.3f}, std={std:.3f})")
return threshold
def _gap_detection_threshold(self, scores: np.ndarray) -> float:
"""
基于跳变检测计算阈值
找到分数分布中最大的"跳变"位置,以此为阈值
Args:
scores: 分数数组(降序排列)
Returns:
阈值
"""
# 降序排列
sorted_scores = np.sort(scores)[::-1]
if len(sorted_scores) < 2:
return float(sorted_scores[0]) if len(sorted_scores) > 0 else 0.0
# 计算相邻分数的跳变幅度(降序排列时 np.diff 为负,取反得到正的跳变量)
gaps = -np.diff(sorted_scores)
# 找到最大的跳变位置
max_gap_idx = int(np.argmax(gaps))
# 阈值为跳变后的分数
threshold = float(sorted_scores[max_gap_idx + 1])
logger.debug(
f"跳变检测阈值: {threshold:.3f} "
f"(gap={gaps[max_gap_idx]:.3f}, idx={max_gap_idx})"
)
return threshold
def _auto_adjust_threshold(
self,
threshold: float,
scores: np.ndarray,
) -> float:
"""
自动调整阈值
基于历史阈值和当前分数分布调整
Args:
threshold: 当前阈值
scores: 分数数组
Returns:
调整后的阈值
"""
if not self._threshold_history:
return threshold
# 计算历史阈值的移动平均
recent_thresholds = self._threshold_history[-10:] # 最近10次
avg_threshold = float(np.mean(recent_thresholds))
# 当前阈值与历史平均的差异
diff = threshold - avg_threshold
# 如果差异过大(>0.2),向历史平均靠拢
if abs(diff) > 0.2:
adjusted_threshold = avg_threshold + diff * 0.5 # 向中间靠拢50%
logger.debug(
f"阈值调整: {threshold:.3f} -> {adjusted_threshold:.3f} "
f"(历史平均={avg_threshold:.3f})"
)
return adjusted_threshold
return threshold
def filter_by_confidence(
self,
results: List[RetrievalResult],
min_confidence: float = 0.5,
) -> List[RetrievalResult]:
"""
基于置信度过滤结果
Args:
results: 检索结果列表
min_confidence: 最小置信度
Returns:
过滤后的结果列表
"""
filtered = []
for result in results:
# 对于关系结果使用confidence字段
if result.result_type == "relation":
confidence = result.metadata.get("confidence", 1.0)
if confidence >= min_confidence:
filtered.append(result)
else:
# 对于段落结果,直接使用分数
if result.score >= min_confidence:
filtered.append(result)
logger.info(
f"置信度过滤: {len(results)} -> {len(filtered)} "
f"(min_confidence={min_confidence})"
)
return filtered
def filter_by_diversity(
self,
results: List[RetrievalResult],
similarity_threshold: float = 0.9,
top_k: int = 10,
) -> List[RetrievalResult]:
"""
基于多样性过滤结果(去除重复)
Args:
results: 检索结果列表
similarity_threshold: 相似度阈值(高于此值视为重复)
top_k: 最多保留结果数
Returns:
过滤后的结果列表
"""
if not results:
return []
# 按分数排序
sorted_results = sorted(results, key=lambda x: x.score, reverse=True)
# 贪心选择:选择与已选结果相似度低的结果
selected = []
selected_hashes = []
for result in sorted_results:
if len(selected) >= top_k:
break
# 检查与已选结果的相似度
is_duplicate = False
for selected_hash in selected_hashes:
# 简单判断基于hash的前缀
if result.hash_value[:8] == selected_hash[:8]:
is_duplicate = True
break
if not is_duplicate:
selected.append(result)
selected_hashes.append(result.hash_value)
logger.info(
f"多样性过滤: {len(results)} -> {len(selected)} "
f"(similarity_threshold={similarity_threshold})"
)
return selected
def get_statistics(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
filter_rate = (
self._total_filtered / self._total_processed
if self._total_processed > 0
else 0.0
)
stats = {
"config": {
"method": self.config.method.value,
"min_threshold": self.config.min_threshold,
"max_threshold": self.config.max_threshold,
"percentile": self.config.percentile,
"std_multiplier": self.config.std_multiplier,
"min_results": self.config.min_results,
"enable_auto_adjust": self.config.enable_auto_adjust,
},
"statistics": {
"total_processed": self._total_processed,
"total_filtered": self._total_filtered,
"filter_rate": filter_rate,
"avg_threshold": float(np.mean(self._threshold_history))
if self._threshold_history else 0.0,
"threshold_count": len(self._threshold_history),
},
}
if self._threshold_history:
stats["statistics"]["min_threshold_used"] = float(np.min(self._threshold_history))
stats["statistics"]["max_threshold_used"] = float(np.max(self._threshold_history))
return stats
def reset_statistics(self) -> None:
"""重置统计信息"""
self._total_filtered = 0
self._total_processed = 0
self._threshold_history.clear()
logger.info("统计信息已重置")
def __repr__(self) -> str:
return (
f"DynamicThresholdFilter("
f"method={self.config.method.value}, "
f"min_threshold={self.config.min_threshold}, "
f"filtered={self._total_filtered}/{self._total_processed})"
)
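Illustrative sketch (not part of this commit): the filter only reads the score attribute of each result, so a tiny stand-in dataclass is used here instead of the real RetrievalResult from dual_path; with gap detection, the largest drop in the sorted scores (between 0.88 and 0.55 below) is expected to become the cut-off. The import path and the stand-in class are assumptions.

from dataclasses import dataclass

# Hypothetical import path, for illustration only.
from plugins.A_memorix.core.retrieval.threshold import (
    DynamicThresholdFilter,
    ThresholdConfig,
    ThresholdMethod,
)

@dataclass
class _FakeResult:
    """Stand-in carrying only the score field used by filter()."""
    score: float

results = [_FakeResult(s) for s in (0.92, 0.90, 0.88, 0.55, 0.52, 0.20, 0.18)]
flt = DynamicThresholdFilter(ThresholdConfig(method=ThresholdMethod.GAP_DETECTION, min_results=3))
kept, threshold = flt.filter(results, return_threshold=True)
print(round(threshold, 2), [r.score for r in kept])   # expected cut near 0.55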

View File

@@ -0,0 +1,8 @@
"""SDK runtime exports for A_Memorix."""
from .sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
__all__ = [
"KernelSearchRequest",
"SDKMemoryKernel",
]

View File

@@ -0,0 +1,579 @@
from __future__ import annotations
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence
from src.common.logger import get_logger
from ..embedding import create_embedding_api_adapter
from ..retrieval import (
DualPathRetriever,
DualPathRetrieverConfig,
RetrievalResult,
SparseBM25Config,
SparseBM25Index,
TemporalQueryOptions,
)
from ..storage import GraphStore, MetadataStore, QuantizationType, SparseMatrixFormat, VectorStore
from ..utils.aggregate_query_service import AggregateQueryService
from ..utils.episode_retrieval_service import EpisodeRetrievalService
from ..utils.hash import normalize_text
from ..utils.relation_write_service import RelationWriteService
logger = get_logger("A_Memorix.SDKMemoryKernel")
@dataclass
class KernelSearchRequest:
query: str = ""
limit: int = 5
mode: str = "hybrid"
chat_id: str = ""
person_id: str = ""
time_start: Optional[float] = None
time_end: Optional[float] = None
class SDKMemoryKernel:
def __init__(self, *, plugin_root: Path, config: Optional[Dict[str, Any]] = None) -> None:
self.plugin_root = Path(plugin_root).resolve()
self.config = config or {}
storage_cfg = self._cfg("storage", {}) or {}
data_dir = str(storage_cfg.get("data_dir", "./data") or "./data")
self.data_dir = (self.plugin_root / data_dir).resolve() if data_dir.startswith(".") else Path(data_dir)
self.embedding_dimension = max(32, int(self._cfg("embedding.dimension", 256)))
self.relation_vectors_enabled = bool(self._cfg("retrieval.relation_vectorization.enabled", False))
self.embedding_manager = None
self.vector_store: Optional[VectorStore] = None
self.graph_store: Optional[GraphStore] = None
self.metadata_store: Optional[MetadataStore] = None
self.relation_write_service: Optional[RelationWriteService] = None
self.sparse_index = None
self.retriever: Optional[DualPathRetriever] = None
self.episode_retriever: Optional[EpisodeRetrievalService] = None
self.aggregate_query_service: Optional[AggregateQueryService] = None
self._initialized = False
self._last_maintenance_at: Optional[float] = None
def _cfg(self, key: str, default: Any = None) -> Any:
current: Any = self.config
if key in {"storage", "embedding", "retrieval"} and isinstance(current, dict):
return current.get(key, default)
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
async def initialize(self) -> None:
if self._initialized:
return
self.data_dir.mkdir(parents=True, exist_ok=True)
self.embedding_manager = create_embedding_api_adapter(
batch_size=int(self._cfg("embedding.batch_size", 32)),
max_concurrent=int(self._cfg("embedding.max_concurrent", 5)),
default_dimension=self.embedding_dimension,
model_name=str(self._cfg("embedding.model_name", "hash-v1")),
retry_config=self._cfg("embedding.retry", {}) or {},
)
self.embedding_dimension = int(await self.embedding_manager._detect_dimension())
self.vector_store = VectorStore(
dimension=self.embedding_dimension,
quantization_type=QuantizationType.INT8,
data_dir=self.data_dir / "vectors",
)
self.graph_store = GraphStore(matrix_format=SparseMatrixFormat.CSR, data_dir=self.data_dir / "graph")
self.metadata_store = MetadataStore(data_dir=self.data_dir / "metadata")
self.metadata_store.connect()
if self.vector_store.has_data():
self.vector_store.load()
self.vector_store.warmup_index(force_train=True)
if self.graph_store.has_data():
self.graph_store.load()
sparse_cfg = self._cfg("retrieval.sparse", {}) or {}
self.sparse_index = SparseBM25Index(metadata_store=self.metadata_store, config=SparseBM25Config(**sparse_cfg))
if getattr(self.sparse_index.config, "enabled", False):
self.sparse_index.ensure_loaded()
self.relation_write_service = RelationWriteService(
metadata_store=self.metadata_store,
graph_store=self.graph_store,
vector_store=self.vector_store,
embedding_manager=self.embedding_manager,
)
self.retriever = DualPathRetriever(
vector_store=self.vector_store,
graph_store=self.graph_store,
metadata_store=self.metadata_store,
embedding_manager=self.embedding_manager,
sparse_index=self.sparse_index,
config=DualPathRetrieverConfig(
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 24)),
top_k_relations=int(self._cfg("retrieval.top_k_relations", 12)),
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
alpha=float(self._cfg("retrieval.alpha", 0.5)),
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
sparse=sparse_cfg,
fusion=self._cfg("retrieval.fusion", {}) or {},
graph_recall=self._cfg("retrieval.search.graph_recall", {}) or {},
relation_intent=self._cfg("retrieval.search.relation_intent", {}) or {},
),
)
self.episode_retriever = EpisodeRetrievalService(metadata_store=self.metadata_store, retriever=self.retriever)
self.aggregate_query_service = AggregateQueryService(plugin_config=self.config)
self._initialized = True
def close(self) -> None:
if self.vector_store is not None:
self.vector_store.save()
if self.graph_store is not None:
self.graph_store.save()
if self.metadata_store is not None:
self.metadata_store.close()
self._initialized = False
async def ingest_summary(
self,
*,
external_id: str,
chat_id: str,
text: str,
participants: Optional[Sequence[str]] = None,
time_start: Optional[float] = None,
time_end: Optional[float] = None,
tags: Optional[Sequence[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
summary_meta = dict(metadata or {})
summary_meta.setdefault("kind", "chat_summary")
return await self.ingest_text(
external_id=external_id,
source_type="chat_summary",
text=text,
chat_id=chat_id,
participants=participants,
time_start=time_start,
time_end=time_end,
tags=tags,
metadata=summary_meta,
)
async def ingest_text(
self,
*,
external_id: str,
source_type: str,
text: str,
chat_id: str = "",
person_ids: Optional[Sequence[str]] = None,
participants: Optional[Sequence[str]] = None,
timestamp: Optional[float] = None,
time_start: Optional[float] = None,
time_end: Optional[float] = None,
tags: Optional[Sequence[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
entities: Optional[Sequence[str]] = None,
relations: Optional[Sequence[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
await self.initialize()
assert self.metadata_store and self.vector_store and self.graph_store and self.embedding_manager
assert self.relation_write_service
content = normalize_text(text)
if not content:
return {"stored_ids": [], "skipped_ids": [external_id], "reason": "empty_text"}
if ref := self.metadata_store.get_external_memory_ref(external_id):
return {"stored_ids": [], "skipped_ids": [str(ref.get("paragraph_hash", ""))], "reason": "exists"}
person_tokens = self._tokens(person_ids)
participant_tokens = self._tokens(participants)
entity_tokens = self._merge_tokens(entities, person_tokens, participant_tokens)
source = self._build_source(source_type, chat_id, person_tokens)
paragraph_meta = dict(metadata or {})
paragraph_meta.update(
{
"external_id": external_id,
"source_type": str(source_type or "").strip(),
"chat_id": str(chat_id or "").strip(),
"person_ids": person_tokens,
"participants": participant_tokens,
"tags": self._tokens(tags),
}
)
paragraph_hash = self.metadata_store.add_paragraph(
content=content,
source=source,
metadata=paragraph_meta,
knowledge_type="factual" if source_type == "person_fact" else "narrative" if source_type == "chat_summary" else "mixed",
time_meta=self._time_meta(timestamp, time_start, time_end),
)
embedding = await self.embedding_manager.encode(content)
self.vector_store.add(vectors=embedding.reshape(1, -1), ids=[paragraph_hash])
for name in entity_tokens:
self.metadata_store.add_entity(name=name, source_paragraph=paragraph_hash)
stored_relations: List[str] = []
for row in [dict(item) for item in (relations or []) if isinstance(item, dict)]:
s = str(row.get("subject", "") or "").strip()
p = str(row.get("predicate", "") or "").strip()
o = str(row.get("object", "") or "").strip()
if not (s and p and o):
continue
result = await self.relation_write_service.upsert_relation_with_vector(
subject=s,
predicate=p,
obj=o,
confidence=float(row.get("confidence", 1.0) or 1.0),
source_paragraph=paragraph_hash,
metadata={"external_id": external_id, "source_type": source_type},
write_vector=self.relation_vectors_enabled,
)
self.metadata_store.link_paragraph_relation(paragraph_hash, result.hash_value)
stored_relations.append(result.hash_value)
self.metadata_store.upsert_external_memory_ref(
external_id=external_id,
paragraph_hash=paragraph_hash,
source_type=source_type,
metadata={"chat_id": chat_id, "person_ids": person_tokens},
)
self._persist()
self.rebuild_episodes_for_sources([source])
for person_id in person_tokens:
await self.refresh_person_profile(person_id)
return {"stored_ids": [paragraph_hash, *stored_relations], "skipped_ids": []}
async def search_memory(self, request: KernelSearchRequest) -> Dict[str, Any]:
await self.initialize()
assert self.retriever and self.episode_retriever and self.aggregate_query_service
mode = str(request.mode or "hybrid").strip().lower() or "hybrid"
clean_query = str(request.query or "").strip()
limit = max(1, int(request.limit or 5))
temporal = self._temporal(request)
if mode == "episode":
rows = await self.episode_retriever.query(
query=clean_query,
top_k=limit,
time_from=request.time_start,
time_to=request.time_end,
source=self._chat_source(request.chat_id),
)
hits = [self._episode_hit(row) for row in rows]
return {"summary": self._summary(hits), "hits": hits}
if mode == "aggregate":
payload = await self.aggregate_query_service.execute(
query=clean_query,
top_k=limit,
mix=True,
mix_top_k=limit,
time_from=str(request.time_start) if request.time_start is not None else None,
time_to=str(request.time_end) if request.time_end is not None else None,
search_runner=lambda: self._aggregate_search(clean_query, limit, temporal),
time_runner=lambda: self._aggregate_time(clean_query, limit, temporal),
episode_runner=lambda: self._aggregate_episode(clean_query, limit, request),
)
hits = [dict(item) for item in payload.get("mixed_results", []) if isinstance(item, dict)]
for item in hits:
item.setdefault("metadata", {})
return {"summary": self._summary(hits), "hits": hits}
results = await self.retriever.retrieve(query=clean_query, top_k=limit, temporal=temporal)
hits = [self._retrieval_hit(item) for item in results]
return {"summary": self._summary(self._filter_hits(hits, request.person_id)), "hits": self._filter_hits(hits, request.person_id)}
async def get_person_profile(self, *, person_id: str, chat_id: str = "", limit: int = 10) -> Dict[str, Any]:
_ = chat_id
await self.initialize()
assert self.metadata_store
snapshot = self.metadata_store.get_latest_person_profile_snapshot(person_id) or await self.refresh_person_profile(person_id, limit=limit)
evidence = []
for hash_value in snapshot.get("evidence_ids", [])[: max(1, int(limit))]:
paragraph = self.metadata_store.get_paragraph(hash_value)
if paragraph is not None:
evidence.append({"hash": hash_value, "content": str(paragraph.get("content", "") or "")[:220], "metadata": paragraph.get("metadata", {}) or {}})
text = str(snapshot.get("profile_text", "") or "").strip()
traits = [line.strip("- ").strip() for line in text.splitlines() if line.strip()][:8]
return {"summary": text, "traits": traits, "evidence": evidence}
async def refresh_person_profile(self, person_id: str, limit: int = 10) -> Dict[str, Any]:
await self.initialize()
assert self.metadata_store
rows = self.metadata_store.query(
"""
SELECT DISTINCT p.*
FROM paragraphs p
JOIN paragraph_entities pe ON pe.paragraph_hash = p.hash
JOIN entities e ON e.hash = pe.entity_hash
WHERE e.name = ?
AND (p.is_deleted IS NULL OR p.is_deleted = 0)
ORDER BY COALESCE(p.event_time_end, p.event_time_start, p.event_time, p.updated_at, p.created_at) DESC
LIMIT ?
""",
(person_id, max(1, int(limit)) * 3),
)
evidence_ids = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()]
vector_evidence = [{"hash": str(row.get("hash", "") or ""), "type": "paragraph", "score": 0.0, "content": str(row.get("content", "") or "")[:220], "metadata": row.get("metadata", {}) or {}} for row in rows[: max(1, int(limit))]]
relation_edges = [{"hash": str(row.get("hash", "") or ""), "subject": str(row.get("subject", "") or ""), "predicate": str(row.get("predicate", "") or ""), "object": str(row.get("object", "") or ""), "confidence": float(row.get("confidence", 1.0) or 1.0)} for row in self.metadata_store.get_relations(subject=person_id)[:limit]]
if relation_edges:
profile_text = "\n".join(f"{item['subject']} {item['predicate']} {item['object']}" for item in relation_edges[:6])
elif vector_evidence:
profile_text = "\n".join(f"- {item['content']}" for item in vector_evidence[:6])
else:
profile_text = "暂无稳定画像证据。"
return self.metadata_store.upsert_person_profile_snapshot(
person_id=person_id,
profile_text=profile_text,
aliases=[person_id],
relation_edges=relation_edges,
vector_evidence=vector_evidence,
evidence_ids=evidence_ids[: max(1, int(limit))],
expires_at=time.time() + 6 * 3600,
source_note="sdk_memory_kernel",
)
async def maintain_memory(self, *, action: str, target: str, hours: Optional[float] = None, reason: str = "") -> Dict[str, Any]:
_ = reason
await self.initialize()
assert self.metadata_store
hashes = self._resolve_relation_hashes(target)
if not hashes:
return {"success": False, "detail": "未命中可维护关系"}
act = str(action or "").strip().lower()
if act == "reinforce":
self.metadata_store.reinforce_relations(hashes)
elif act == "protect":
ttl_seconds = max(0.0, float(hours or 0.0)) * 3600.0
self.metadata_store.protect_relations(hashes, ttl_seconds=ttl_seconds, is_pinned=ttl_seconds <= 0)
elif act == "restore":
restored = sum(1 for hash_value in hashes if self.metadata_store.restore_relation(hash_value))
if restored <= 0:
return {"success": False, "detail": "未恢复任何关系"}
else:
return {"success": False, "detail": f"不支持的维护动作: {act}"}
self._last_maintenance_at = time.time()
self._persist()
return {"success": True, "detail": f"{act} {len(hashes)} 条关系"}
def rebuild_episodes_for_sources(self, sources: Iterable[str]) -> int:
assert self.metadata_store
rebuilt = 0
for source in self._tokens(sources):
rows = self.metadata_store.query(
"""
SELECT * FROM paragraphs
WHERE source = ?
AND (is_deleted IS NULL OR is_deleted = 0)
ORDER BY COALESCE(event_time_start, event_time, created_at) ASC, hash ASC
""",
(source,),
)
if not rows:
continue
paragraph_hashes = [str(row.get("hash", "") or "") for row in rows if str(row.get("hash", "")).strip()]
payload = self.metadata_store.upsert_episode(
{
"source": source,
"title": str((rows[0].get("metadata", {}) or {}).get("theme", "") or f"{source} 情景记忆")[:80],
"summary": "".join(str(row.get("content", "") or "").strip().replace("\n", " ")[:120] for row in rows[:3] if str(row.get("content", "") or "").strip())[:500] or "自动构建的情景记忆。",
"participants": self._episode_participants(rows),
"keywords": self._episode_keywords(rows),
"evidence_ids": paragraph_hashes,
"paragraph_count": len(paragraph_hashes),
"event_time_start": self._time_bound(rows, "event_time_start", "event_time", reverse=False),
"event_time_end": self._time_bound(rows, "event_time_end", "event_time", reverse=True),
"time_granularity": "day",
"time_confidence": 0.7,
"llm_confidence": 0.0,
"segmentation_model": "rule_based_sdk",
"segmentation_version": "1",
}
)
self.metadata_store.bind_episode_paragraphs(payload["episode_id"], paragraph_hashes)
rebuilt += 1
return rebuilt
def memory_stats(self) -> Dict[str, Any]:
assert self.metadata_store
stats = self.metadata_store.get_statistics()
episodes = self.metadata_store.query("SELECT COUNT(*) AS c FROM episodes")[0]["c"]
profiles = self.metadata_store.query("SELECT COUNT(*) AS c FROM person_profile_snapshots")[0]["c"]
return {"paragraphs": int(stats.get("paragraph_count", 0) or 0), "relations": int(stats.get("relation_count", 0) or 0), "episodes": int(episodes or 0), "profiles": int(profiles or 0), "last_maintenance_at": self._last_maintenance_at}
async def _aggregate_search(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]:
assert self.retriever
hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)]
return {"success": True, "results": hits, "count": len(hits), "query_type": "search"}
async def _aggregate_time(self, query: str, limit: int, temporal: Optional[TemporalQueryOptions]) -> Dict[str, Any]:
if temporal is None:
return {"success": False, "error": "missing temporal window", "results": []}
assert self.retriever
hits = [self._retrieval_hit(item) for item in await self.retriever.retrieve(query=query, top_k=limit, temporal=temporal)]
return {"success": True, "results": hits, "count": len(hits), "query_type": "time"}
async def _aggregate_episode(self, query: str, limit: int, request: KernelSearchRequest) -> Dict[str, Any]:
assert self.episode_retriever
rows = await self.episode_retriever.query(query=query, top_k=limit, time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id))
hits = [self._episode_hit(row) for row in rows]
return {"success": True, "results": hits, "count": len(hits), "query_type": "episode"}
def _persist(self) -> None:
if self.vector_store is not None:
self.vector_store.save()
if self.graph_store is not None:
self.graph_store.save()
if self.sparse_index is not None and getattr(self.sparse_index.config, "enabled", False):
self.sparse_index.ensure_loaded()
@staticmethod
def _tokens(values: Optional[Iterable[Any]]) -> List[str]:
result: List[str] = []
seen = set()
for item in values or []:
token = str(item or "").strip()
if not token or token in seen:
continue
seen.add(token)
result.append(token)
return result
@classmethod
def _merge_tokens(cls, *groups: Optional[Iterable[Any]]) -> List[str]:
merged: List[str] = []
seen = set()
for group in groups:
for item in cls._tokens(group):
if item in seen:
continue
seen.add(item)
merged.append(item)
return merged
@staticmethod
def _build_source(source_type: str, chat_id: str, person_ids: Sequence[str]) -> str:
clean_type = str(source_type or "").strip() or "memory"
if clean_type == "chat_summary" and chat_id:
return f"chat_summary:{chat_id}"
if clean_type == "person_fact" and person_ids:
return f"person_fact:{person_ids[0]}"
return f"{clean_type}:{chat_id}" if chat_id else clean_type
@staticmethod
def _chat_source(chat_id: str) -> Optional[str]:
clean = str(chat_id or "").strip()
return f"chat_summary:{clean}" if clean else None
@staticmethod
def _time_meta(timestamp: Optional[float], time_start: Optional[float], time_end: Optional[float]) -> Dict[str, Any]:
payload: Dict[str, Any] = {}
if timestamp is not None:
payload["event_time"] = float(timestamp)
if time_start is not None:
payload["event_time_start"] = float(time_start)
if time_end is not None:
payload["event_time_end"] = float(time_end)
if payload:
payload["time_granularity"] = "minute"
payload["time_confidence"] = 0.95
return payload
def _temporal(self, request: KernelSearchRequest) -> Optional[TemporalQueryOptions]:
if request.time_start is None and request.time_end is None and not request.chat_id:
return None
return TemporalQueryOptions(time_from=request.time_start, time_to=request.time_end, source=self._chat_source(request.chat_id))
@staticmethod
def _retrieval_hit(item: RetrievalResult) -> Dict[str, Any]:
payload = item.to_dict()
return {"hash": payload.get("hash", ""), "content": payload.get("content", ""), "score": payload.get("score", 0.0), "type": payload.get("type", ""), "source": payload.get("source", ""), "metadata": payload.get("metadata", {}) or {}}
@staticmethod
def _episode_hit(row: Dict[str, Any]) -> Dict[str, Any]:
return {"type": "episode", "episode_id": str(row.get("episode_id", "") or ""), "title": str(row.get("title", "") or ""), "content": str(row.get("summary", "") or ""), "score": float(row.get("lexical_score", 0.0) or 0.0), "source": "episode", "metadata": {"participants": row.get("participants", []) or [], "keywords": row.get("keywords", []) or [], "source": row.get("source"), "event_time_start": row.get("event_time_start"), "event_time_end": row.get("event_time_end")}}
@staticmethod
def _summary(hits: Sequence[Dict[str, Any]]) -> str:
if not hits:
return ""
lines = []
for index, item in enumerate(hits[:5], start=1):
content = str(item.get("content", "") or "").strip().replace("\n", " ")
lines.append(f"{index}. {(content[:120] + '...') if len(content) > 120 else content}")
return "\n".join(lines)
@staticmethod
def _filter_hits(hits: List[Dict[str, Any]], person_id: str) -> List[Dict[str, Any]]:
if not person_id:
return hits
filtered = []
for item in hits:
metadata = item.get("metadata", {}) or {}
if person_id in (metadata.get("person_ids", []) or []):
filtered.append(item)
continue
if person_id and person_id in str(item.get("content", "") or ""):
filtered.append(item)
return filtered or hits
@staticmethod
def _episode_participants(rows: Sequence[Dict[str, Any]]) -> List[str]:
seen = set()
result: List[str] = []
for row in rows:
meta = row.get("metadata", {}) or {}
for key in ("participants", "person_ids"):
for item in meta.get(key, []) or []:
token = str(item or "").strip()
if not token or token in seen:
continue
seen.add(token)
result.append(token)
return result[:16]
@staticmethod
def _episode_keywords(rows: Sequence[Dict[str, Any]]) -> List[str]:
seen = set()
result: List[str] = []
for row in rows:
meta = row.get("metadata", {}) or {}
for item in meta.get("tags", []) or []:
token = str(item or "").strip()
if not token or token in seen:
continue
seen.add(token)
result.append(token)
return result[:12]
@staticmethod
def _time_bound(rows: Sequence[Dict[str, Any]], primary: str, fallback: str, reverse: bool) -> Optional[float]:
values: List[float] = []
for row in rows:
for key in (primary, fallback):
value = row.get(key)
try:
if value is not None:
values.append(float(value))
break
except Exception:
continue
if not values:
return None
return max(values) if reverse else min(values)
def _resolve_relation_hashes(self, target: str) -> List[str]:
assert self.metadata_store
token = str(target or "").strip()
if not token:
return []
if len(token) == 64 and all(ch in "0123456789abcdef" for ch in token.lower()):
return [token]
hashes = self.metadata_store.search_relation_hashes_by_text(token, limit=10)
if hashes:
return hashes
return [str(row.get("hash", "") or "") for row in self.metadata_store.get_relations(subject=token)[:10] if str(row.get("hash", "")).strip()]

View File

@@ -0,0 +1,53 @@
"""存储层"""
from .vector_store import VectorStore, QuantizationType
from .graph_store import GraphStore, SparseMatrixFormat
from .metadata_store import MetadataStore
from .knowledge_types import (
ImportStrategy,
KnowledgeType,
allowed_import_strategy_values,
allowed_knowledge_type_values,
get_knowledge_type_from_string,
get_import_strategy_from_string,
parse_import_strategy,
resolve_stored_knowledge_type,
should_extract_relations,
get_default_chunk_size,
get_type_display_name,
validate_stored_knowledge_type,
)
from .type_detection import (
detect_knowledge_type,
get_type_from_user_input,
looks_like_factual_text,
looks_like_quote_text,
looks_like_structured_text,
select_import_strategy,
)
__all__ = [
"VectorStore",
"GraphStore",
"MetadataStore",
"QuantizationType",
"SparseMatrixFormat",
"ImportStrategy",
"KnowledgeType",
"allowed_import_strategy_values",
"allowed_knowledge_type_values",
"get_knowledge_type_from_string",
"get_import_strategy_from_string",
"parse_import_strategy",
"resolve_stored_knowledge_type",
"should_extract_relations",
"get_default_chunk_size",
"get_type_display_name",
"validate_stored_knowledge_type",
"detect_knowledge_type",
"get_type_from_user_input",
"looks_like_factual_text",
"looks_like_quote_text",
"looks_like_structured_text",
"select_import_strategy",
]

File diff suppressed because it is too large

View File

@@ -0,0 +1,183 @@
"""Knowledge type and import strategy helpers."""
from __future__ import annotations
from enum import Enum
from typing import Any, Optional
class KnowledgeType(str, Enum):
"""持久化到 paragraphs.knowledge_type 的合法类型。"""
STRUCTURED = "structured"
NARRATIVE = "narrative"
FACTUAL = "factual"
QUOTE = "quote"
MIXED = "mixed"
class ImportStrategy(str, Enum):
"""文本导入阶段的策略选择。"""
AUTO = "auto"
NARRATIVE = "narrative"
FACTUAL = "factual"
QUOTE = "quote"
def allowed_knowledge_type_values() -> tuple[str, ...]:
return tuple(item.value for item in KnowledgeType)
def allowed_import_strategy_values() -> tuple[str, ...]:
return tuple(item.value for item in ImportStrategy)
def get_knowledge_type_from_string(type_str: Any) -> Optional[KnowledgeType]:
"""从字符串解析合法的落库知识类型。"""
if not isinstance(type_str, str):
return None
normalized = type_str.lower().strip()
try:
return KnowledgeType(normalized)
except ValueError:
return None
def get_import_strategy_from_string(value: Any) -> Optional[ImportStrategy]:
"""从字符串解析文本导入策略。"""
if not isinstance(value, str):
return None
normalized = value.lower().strip()
try:
return ImportStrategy(normalized)
except ValueError:
return None
def parse_import_strategy(value: Any, default: ImportStrategy = ImportStrategy.AUTO) -> ImportStrategy:
"""解析 import strategy非法值直接报错。"""
if value is None:
return default
if isinstance(value, ImportStrategy):
return value
normalized = str(value or "").strip().lower()
if not normalized:
return default
strategy = get_import_strategy_from_string(normalized)
if strategy is None:
allowed = "/".join(allowed_import_strategy_values())
raise ValueError(f"strategy_override 必须为 {allowed}")
return strategy
def validate_stored_knowledge_type(value: Any) -> KnowledgeType:
"""校验写库 knowledge_type仅允许合法落库类型。"""
if isinstance(value, KnowledgeType):
return value
resolved = get_knowledge_type_from_string(value)
if resolved is None:
allowed = "/".join(allowed_knowledge_type_values())
raise ValueError(f"knowledge_type 必须为 {allowed}")
return resolved
def resolve_stored_knowledge_type(
value: Any,
*,
content: str = "",
allow_legacy: bool = False,
unknown_fallback: Optional[KnowledgeType] = None,
) -> KnowledgeType:
"""
将策略/字符串/旧值解析为合法落库类型。
`allow_legacy=True` 仅供迁移使用。
"""
if isinstance(value, KnowledgeType):
return value
if isinstance(value, ImportStrategy):
if value == ImportStrategy.AUTO:
if not str(content or "").strip():
raise ValueError("knowledge_type=auto 需要 content 才能推断")
from .type_detection import detect_knowledge_type
return detect_knowledge_type(content)
return KnowledgeType(value.value)
raw = str(value or "").strip()
if not raw:
if str(content or "").strip():
from .type_detection import detect_knowledge_type
return detect_knowledge_type(content)
raise ValueError("knowledge_type 不能为空")
direct = get_knowledge_type_from_string(raw)
if direct is not None:
return direct
strategy = get_import_strategy_from_string(raw)
if strategy is not None:
return resolve_stored_knowledge_type(strategy, content=content)
if allow_legacy:
normalized = raw.lower()
if normalized == "imported":
return KnowledgeType.FACTUAL
if str(content or "").strip():
from .type_detection import detect_knowledge_type
detected = detect_knowledge_type(content)
if detected is not None:
return detected
if unknown_fallback is not None:
return unknown_fallback
allowed = "/".join(allowed_knowledge_type_values())
raise ValueError(f"非法 knowledge_type: {raw}(仅允许 {allowed}")
def should_extract_relations(knowledge_type: KnowledgeType) -> bool:
"""判断是否应该做关系抽取。"""
return knowledge_type in [
KnowledgeType.STRUCTURED,
KnowledgeType.FACTUAL,
KnowledgeType.MIXED,
]
def get_default_chunk_size(knowledge_type: KnowledgeType) -> int:
"""获取默认分块大小。"""
chunk_sizes = {
KnowledgeType.STRUCTURED: 300,
KnowledgeType.NARRATIVE: 800,
KnowledgeType.FACTUAL: 500,
KnowledgeType.QUOTE: 400,
KnowledgeType.MIXED: 500,
}
return chunk_sizes.get(knowledge_type, 500)
def get_type_display_name(knowledge_type: KnowledgeType) -> str:
"""获取知识类型中文名称。"""
display_names = {
KnowledgeType.STRUCTURED: "结构化知识",
KnowledgeType.NARRATIVE: "叙事性文本",
KnowledgeType.FACTUAL: "事实陈述",
KnowledgeType.QUOTE: "引用文本",
KnowledgeType.MIXED: "混合类型",
}
return display_names.get(knowledge_type, "未知类型")
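
A short sketch of the helper behaviour above; the import path is hypothetical and depends on where the plugin package is mounted in the host application.

from plugins.A_memorix.core.storage.knowledge_types import (  # hypothetical path
    ImportStrategy,
    KnowledgeType,
    get_default_chunk_size,
    parse_import_strategy,
    resolve_stored_knowledge_type,
    should_extract_relations,
)

assert parse_import_strategy("factual") is ImportStrategy.FACTUAL
assert parse_import_strategy(None) is ImportStrategy.AUTO           # 缺省回退 auto
assert resolve_stored_knowledge_type("imported", allow_legacy=True) is KnowledgeType.FACTUAL
assert should_extract_relations(KnowledgeType.NARRATIVE) is False   # 叙事文本不抽关系
assert get_default_chunk_size(KnowledgeType.NARRATIVE) == 800

try:
    parse_import_strategy("whatever")  # 非法值直接抛 ValueError
except ValueError as err:
    print(err)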

File diff suppressed because it is too large

View File

@@ -0,0 +1,137 @@
"""Heuristic detection for import strategies and stored knowledge types."""
from __future__ import annotations
import re
from typing import Optional
from .knowledge_types import (
ImportStrategy,
KnowledgeType,
parse_import_strategy,
resolve_stored_knowledge_type,
)
_NARRATIVE_MARKERS = [
r"然后",
r"接着",
r"于是",
r"后来",
r"最后",
r"突然",
r"一天",
r"曾经",
r"有一次",
r"从前",
r"说道",
r"问道",
r"想着",
r"觉得",
]
_FACTUAL_MARKERS = [
r"",
r"",
r"",
r"",
r"属于",
r"位于",
r"包含",
r"拥有",
r"成立于",
r"出生于",
]
def _non_empty_lines(content: str) -> list[str]:
return [line for line in str(content or "").splitlines() if line.strip()]
def looks_like_structured_text(content: str) -> bool:
text = str(content or "").strip()
if "|" not in text or text.count("|") < 2:
return False
parts = text.split("|")
return len(parts) == 3 and all(part.strip() for part in parts)
def looks_like_quote_text(content: str) -> bool:
lines = _non_empty_lines(content)
if len(lines) < 5:
return False
avg_len = sum(len(line) for line in lines) / len(lines)
return avg_len < 20
def looks_like_narrative_text(content: str) -> bool:
text = str(content or "").strip()
if not text:
return False
narrative_score = sum(1 for marker in _NARRATIVE_MARKERS if re.search(marker, text))
has_dialogue = bool(re.search(r'["「『].*?["」』]', text))
has_chapter = any(token in text[:500] for token in ("Chapter", "CHAPTER", "###"))
return has_chapter or has_dialogue or narrative_score >= 2
def looks_like_factual_text(content: str) -> bool:
text = str(content or "").strip()
if not text:
return False
if looks_like_structured_text(text) or looks_like_quote_text(text):
return False
factual_score = sum(1 for marker in _FACTUAL_MARKERS if re.search(r"\s*" + marker + r"\s*", text))
if factual_score <= 0:
return False
if len(text) <= 240:
return True
return factual_score >= 2 and not looks_like_narrative_text(text)
def select_import_strategy(
content: str,
*,
override: Optional[str | ImportStrategy] = None,
chat_log: bool = False,
) -> ImportStrategy:
"""文本导入策略选择override > quote > factual > narrative。"""
if chat_log:
return ImportStrategy.NARRATIVE
strategy = parse_import_strategy(override, default=ImportStrategy.AUTO)
if strategy != ImportStrategy.AUTO:
return strategy
if looks_like_quote_text(content):
return ImportStrategy.QUOTE
if looks_like_factual_text(content):
return ImportStrategy.FACTUAL
return ImportStrategy.NARRATIVE
def detect_knowledge_type(content: str) -> KnowledgeType:
"""自动检测落库 knowledge_type无法可靠判断时回退 mixed。"""
text = str(content or "").strip()
if not text:
return KnowledgeType.MIXED
if looks_like_structured_text(text):
return KnowledgeType.STRUCTURED
if looks_like_quote_text(text):
return KnowledgeType.QUOTE
if looks_like_factual_text(text):
return KnowledgeType.FACTUAL
if looks_like_narrative_text(text):
return KnowledgeType.NARRATIVE
return KnowledgeType.MIXED
def get_type_from_user_input(type_hint: Optional[str], content: str) -> KnowledgeType:
"""优先使用显式 type_hint否则自动检测。"""
if type_hint:
return resolve_stored_knowledge_type(type_hint, content=content)
return detect_knowledge_type(content)
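
Illustrative behaviour of the heuristics above, again with a hypothetical import path; the examples stick to the structured, quote, and chat-log branches, which are the easiest to pin down from the code.

from plugins.A_memorix.core.storage.type_detection import (  # hypothetical path
    detect_knowledge_type,
    select_import_strategy,
)
from plugins.A_memorix.core.storage.knowledge_types import ImportStrategy, KnowledgeType

# 形如 "主体|谓词|客体" 的单行三元组会被判为结构化知识
assert detect_knowledge_type("苏轼|号|东坡居士") is KnowledgeType.STRUCTURED

# 多行且平均行长很短的文本会被判为引用(语录)类型
quotes = "\n".join(["学而时习之", "温故而知新", "三人行必有我师", "知之为知之", "不知为不知"])
assert detect_knowledge_type(quotes) is KnowledgeType.QUOTE

# 聊天记录导入强制使用 narrative 策略override 被忽略
assert select_import_strategy("随便什么内容", chat_log=True) is ImportStrategy.NARRATIVE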

View File

@@ -0,0 +1,776 @@
"""
向量存储模块
基于 Faiss 的高效向量存储与检索,支持 SQ8 量化、Append-Only 磁盘存储和内存映射。
"""
import os
import pickle
import hashlib
import shutil
import time
from pathlib import Path
from typing import Optional, Union, Tuple, List, Dict, Set, Any
import random
import threading
import numpy as np
try:
import faiss
HAS_FAISS = True
except ImportError:
HAS_FAISS = False
from src.common.logger import get_logger
from ..utils.quantization import QuantizationType
from ..utils.io import atomic_write, atomic_save_path
logger = get_logger("A_Memorix.VectorStore")
class VectorStore:
"""
向量存储类 (SQ8 + Append-Only Disk)
特性:
- 索引: IndexIDMap2(IndexScalarQuantizer(QT_8bit))
- 存储: float16 on-disk binary (vectors.bin)
- 内存: 仅索引常驻 RAM (<512MB for 100k vectors)
- ID: SHA1-based stable int64 IDs
- 一致性: 强制 L2 Normalization (IP == Cosine)
"""
# 默认训练触发阈值 (40 样本,过大可能导致小数据集不生效,过小可能量化退化)
DEFAULT_MIN_TRAIN = 40
# 强制训练样本量
TRAIN_SIZE = 10000
# 蓄水池采样容量上限(仅对前 50k 条流式数据做采样)
RESERVOIR_CAPACITY = 10000
RESERVOIR_SAMPLE_SCOPE = 50000
def __init__(
self,
dimension: int,
quantization_type: QuantizationType = QuantizationType.INT8,
index_type: str = "sq8",
data_dir: Optional[Union[str, Path]] = None,
use_mmap: bool = True,
buffer_size: int = 1024,
):
if not HAS_FAISS:
raise ImportError("Faiss 未安装,请安装: pip install faiss-cpu")
self.dimension = dimension
self.data_dir = Path(data_dir) if data_dir else None
if self.data_dir:
self.data_dir.mkdir(parents=True, exist_ok=True)
if quantization_type != QuantizationType.INT8:
raise ValueError(
"vNext 仅支持 quantization_type=int8(SQ8)。"
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
)
normalized_index_type = str(index_type or "sq8").strip().lower()
if normalized_index_type not in {"sq8", "int8"}:
raise ValueError(
"vNext 仅支持 index_type=sq8。"
" 请更新配置并执行 scripts/release_vnext_migrate.py migrate。"
)
self.quantization_type = QuantizationType.INT8
self.index_type = "sq8"
self.buffer_size = buffer_size
self._index: Optional[faiss.IndexIDMap2] = None
self._init_index()
self._is_trained = False
self._vector_norm = "l2"
# Fallback Index (Flat) - 用于在 SQ8 训练完成前提供检索能力
# 必须使用 IndexIDMap2 以保证 ID 与主索引一致
self._fallback_index: Optional[faiss.IndexIDMap2] = None
self._init_fallback_index()
self._known_hashes: Set[str] = set()
self._deleted_ids: Set[int] = set()
self._reservoir_buffer: List[np.ndarray] = []
self._seen_count_for_reservoir = 0
self._write_buffer_vecs: List[np.ndarray] = []
self._write_buffer_ids: List[int] = []
self._total_added = 0
self._total_deleted = 0
self._bin_count = 0
# Thread safety lock
self._lock = threading.RLock()
logger.info(f"VectorStore Init: dim={dimension}, SQ8 Mode, Append-Only Storage")
def _init_index(self):
"""初始化空的 Faiss 索引"""
quantizer = faiss.IndexScalarQuantizer(
self.dimension,
faiss.ScalarQuantizer.QT_8bit,
faiss.METRIC_INNER_PRODUCT
)
self._index = faiss.IndexIDMap2(quantizer)
self._is_trained = False
def _init_fallback_index(self):
"""初始化 Flat 回退索引"""
flat_index = faiss.IndexFlatIP(self.dimension)
self._fallback_index = faiss.IndexIDMap2(flat_index)
logger.debug("Fallback index (Flat) initialized.")
@staticmethod
def _generate_id(key: str) -> int:
"""生成稳定的 int64 ID (SHA1 截断)"""
h = hashlib.sha1(key.encode("utf-8")).digest()
val = int.from_bytes(h[:8], byteorder="big", signed=False)
return val & 0x7FFFFFFFFFFFFFFF
@property
def _bin_path(self) -> Path:
return self.data_dir / "vectors.bin"
@property
def _ids_bin_path(self) -> Path:
return self.data_dir / "vectors_ids.bin"
@property
def _int_to_str_map(self) -> Dict[int, str]:
"""Lazy build volatile map from known hashes"""
# Note: This is read-heavy and cached, might need lock if _known_hashes updates concurrently
# But add/delete are now locked, so checking len mismatch is somewhat safe-ish for quick dirty cache
if not hasattr(self, "_cached_map") or len(self._cached_map) != len(self._known_hashes):
with self._lock: # Protect cache rebuild
self._cached_map = {self._generate_id(k): k for k in self._known_hashes}
return self._cached_map
def add(self, vectors: np.ndarray, ids: List[str]) -> int:
with self._lock:
if vectors.shape[1] != self.dimension:
raise ValueError(f"Dimension mismatch: {vectors.shape[1]} vs {self.dimension}")
vectors = np.ascontiguousarray(vectors, dtype=np.float32)
faiss.normalize_L2(vectors)
processed_vecs = []
processed_int_ids = []
for i, str_id in enumerate(ids):
if str_id in self._known_hashes:
continue
int_id = self._generate_id(str_id)
self._known_hashes.add(str_id)
processed_vecs.append(vectors[i])
processed_int_ids.append(int_id)
if not processed_vecs:
return 0
batch_vecs = np.array(processed_vecs, dtype=np.float32)
batch_ids = np.array(processed_int_ids, dtype=np.int64)
self._write_buffer_vecs.append(batch_vecs)
self._write_buffer_ids.extend(processed_int_ids)
if len(self._write_buffer_ids) >= self.buffer_size:
self._flush_write_buffer_unlocked()
if not self._is_trained:
# 双写到回退索引
self._fallback_index.add_with_ids(batch_vecs, batch_ids)
self._update_reservoir(batch_vecs)
# 蓄水池样本量达到容量上限(默认 10k后触发一次训练
if len(self._reservoir_buffer) >= self.RESERVOIR_CAPACITY:
logger.info("训练样本达到上限,开始训练...")
self._train_and_replay_unlocked()
self._total_added += len(batch_ids)
return len(batch_ids)
def _flush_write_buffer(self):
with self._lock:
self._flush_write_buffer_unlocked()
def _flush_write_buffer_unlocked(self):
if not self._write_buffer_vecs:
return
batch_vecs = np.concatenate(self._write_buffer_vecs, axis=0)
batch_ids = np.array(self._write_buffer_ids, dtype=np.int64)
vecs_fp16 = batch_vecs.astype(np.float16)
with open(self._bin_path, "ab") as f:
f.write(vecs_fp16.tobytes())
ids_bytes = batch_ids.astype('>i8').tobytes()
with open(self._ids_bin_path, "ab") as f:
f.write(ids_bytes)
self._bin_count += len(batch_ids)
if self._is_trained and self._index.is_trained:
self._index.add_with_ids(batch_vecs, batch_ids)
# 未训练阶段不必在此重复写 fallbackadd() 已在入库时同步写入回退索引,
# 此处再次 add_with_ids 会使同一 ID 在回退索引中出现两份。
self._write_buffer_vecs.clear()
self._write_buffer_ids.clear()
def _update_reservoir(self, vectors: np.ndarray):
for vec in vectors:
self._seen_count_for_reservoir += 1
if len(self._reservoir_buffer) < self.RESERVOIR_CAPACITY:
self._reservoir_buffer.append(vec)
else:
if self._seen_count_for_reservoir <= self.RESERVOIR_SAMPLE_SCOPE:
r = random.randint(0, self._seen_count_for_reservoir - 1)
if r < self.RESERVOIR_CAPACITY:
self._reservoir_buffer[r] = vec
def _train_and_replay(self):
with self._lock:
self._train_and_replay_unlocked()
def _train_and_replay_unlocked(self):
if not self._reservoir_buffer:
logger.warning("No training data available.")
return
train_data = np.array(self._reservoir_buffer, dtype=np.float32)
logger.info(f"Training Index with {len(train_data)} samples...")
try:
self._index.train(train_data)
except Exception as e:
logger.error(f"SQ8 Training failed: {e}. Staying in fallback mode.")
return
self._is_trained = True
self._reservoir_buffer = []
logger.info("Replaying data from disk to populate index...")
try:
replay_count = self._replay_vectors_to_index()
# 只有当 replay 成功且数据量一致时,才释放回退索引
if self._index.ntotal >= self._bin_count:
logger.info(f"Replay successful ({self._index.ntotal}/{self._bin_count}). Releasing fallback index.")
self._fallback_index.reset()
else:
logger.warning(f"Replay count mismatch: {self._index.ntotal} vs {self._bin_count}. Keeping fallback index.")
except Exception as e:
logger.error(f"Replay failed: {e}. Keeping fallback index as backup.")
def _replay_vectors_to_index(self) -> int:
"""从 vectors.bin 读取并回放到主索引,返回实际写入的向量数"""
if not self._bin_path.exists() or not self._ids_bin_path.exists():
return 0
vec_item_size = self.dimension * 2
id_item_size = 8
chunk_size = 10000
replayed = 0
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
while True:
vec_data = f_vec.read(chunk_size * vec_item_size)
id_data = f_id.read(chunk_size * id_item_size)
if not vec_data:
break
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
batch_fp32 = batch_fp16.astype(np.float32)
faiss.normalize_L2(batch_fp32)
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
if not all(valid_mask):
batch_fp32 = batch_fp32[valid_mask]
batch_ids = batch_ids[valid_mask]
if len(batch_ids) > 0:
self._index.add_with_ids(batch_fp32, batch_ids)
replayed += len(batch_ids)
return replayed
def search(
self,
query: np.ndarray,
k: int = 10,
filter_deleted: bool = True,
) -> Tuple[List[str], List[float]]:
query_local = np.array(query, dtype=np.float32, order="C", copy=True)
if query_local.ndim == 1:
got_dim = int(query_local.shape[0])
query_local = query_local.reshape(1, -1)
elif query_local.ndim == 2:
if query_local.shape[0] != 1:
raise ValueError(
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
)
got_dim = int(query_local.shape[1])
else:
raise ValueError(
f"query embedding must have shape (D,) or (1, D), got {tuple(query_local.shape)}"
)
if got_dim != self.dimension:
raise ValueError(
f"query embedding dimension mismatch: expected={self.dimension} got={got_dim}"
)
if not np.all(np.isfinite(query_local)):
raise ValueError("query embedding contains non-finite values")
faiss.normalize_L2(query_local)
# 查询路径仅负责检索,不在此触发训练/回放。
# 训练/回放前置到 warmup_index(),并由插件启动阶段触发。
# Faiss 索引在并发 search 下可能出现阻塞,这里串行化检索调用保证稳定性。
with self._lock:
self._flush_write_buffer_unlocked()
search_index = self._index if (self._is_trained and self._index.ntotal > 0) else self._fallback_index
if search_index.ntotal == 0:
logger.warning("Indices are empty. No data to search.")
return [], []
# 执行检索
dists, ids = search_index.search(query_local, k * 2)
# Faiss search 返回的是 (1, K) 的数组,取第一行
dists = dists[0]
ids = ids[0]
results = []
for id_val, score in zip(ids, dists):
if id_val == -1: continue
if filter_deleted and id_val in self._deleted_ids:
continue
str_id = self._int_to_str_map.get(id_val)
if str_id:
results.append((str_id, float(score)))
# Sort and trim just in case filtering reduced count
results.sort(key=lambda x: x[1], reverse=True)
results = results[:k]
if not results:
return [], []
return [r[0] for r in results], [r[1] for r in results]
def warmup_index(self, force_train: bool = True) -> Dict[str, Any]:
"""
预热向量索引(训练/回放前置),避免首个线上查询触发重初始化。
Args:
force_train: 是否在满足阈值时强制训练 SQ8 索引
Returns:
预热状态摘要
"""
started = time.perf_counter()
logger.info(f"metric.vector_index_prewarm_started=1 force_train={bool(force_train)}")
try:
with self._lock:
self._flush_write_buffer()
if self._bin_path.exists():
self._bin_count = self._bin_path.stat().st_size // (self.dimension * 2)
else:
self._bin_count = 0
needs_fallback_bootstrap = (
self._bin_count > 0
and self._fallback_index.ntotal == 0
and (not self._is_trained or self._index.ntotal == 0)
)
if needs_fallback_bootstrap:
self._bootstrap_fallback_from_disk()
min_train = max(1, int(getattr(self, "min_train_threshold", self.DEFAULT_MIN_TRAIN)))
needs_train = (
bool(force_train)
and self._bin_count >= min_train
and not self._is_trained
)
if needs_train:
self._force_train_small_data()
duration_ms = (time.perf_counter() - started) * 1000.0
summary = {
"ok": True,
"trained": bool(self._is_trained),
"index_ntotal": int(self._index.ntotal),
"fallback_ntotal": int(self._fallback_index.ntotal),
"bin_count": int(self._bin_count),
"duration_ms": duration_ms,
"error": None,
}
except Exception as e:
duration_ms = (time.perf_counter() - started) * 1000.0
summary = {
"ok": False,
"trained": bool(self._is_trained),
"index_ntotal": int(self._index.ntotal) if self._index is not None else 0,
"fallback_ntotal": int(self._fallback_index.ntotal) if self._fallback_index is not None else 0,
"bin_count": int(getattr(self, "_bin_count", 0)),
"duration_ms": duration_ms,
"error": str(e),
}
logger.error(
"metric.vector_index_prewarm_fail=1 "
f"metric.vector_index_prewarm_duration_ms={duration_ms:.2f} "
f"error={e}"
)
return summary
logger.info(
"metric.vector_index_prewarm_success=1 "
f"metric.vector_index_prewarm_duration_ms={summary['duration_ms']:.2f} "
f"trained={summary['trained']} "
f"index_ntotal={summary['index_ntotal']} "
f"fallback_ntotal={summary['fallback_ntotal']} "
f"bin_count={summary['bin_count']}"
)
return summary
def _bootstrap_fallback_from_disk(self):
with self._lock:
self._bootstrap_fallback_from_disk_unlocked()
def _bootstrap_fallback_from_disk_unlocked(self):
"""重启后自举:从磁盘 vectors.bin 加载数据到 fallback 索引"""
if not self._bin_path.exists() or not self._ids_bin_path.exists():
return
logger.info("Replaying all disk vectors to fallback index...")
vec_item_size = self.dimension * 2
id_item_size = 8
chunk_size = 10000
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id:
while True:
vec_data = f_vec.read(chunk_size * vec_item_size)
id_data = f_id.read(chunk_size * id_item_size)
if not vec_data: break
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
batch_fp32 = batch_fp16.astype(np.float32)
faiss.normalize_L2(batch_fp32)
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
valid_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
if any(valid_mask):
self._fallback_index.add_with_ids(batch_fp32[valid_mask], batch_ids[valid_mask])
logger.info(f"Fallback index self-bootstrapped with {self._fallback_index.ntotal} items.")
def _force_train_small_data(self):
with self._lock:
self._force_train_small_data_unlocked()
def _force_train_small_data_unlocked(self):
logger.info("Forcing training on small dataset...")
self._reservoir_buffer = []
chunk_size = 10000
vec_item_size = self.dimension * 2
with open(self._bin_path, "rb") as f:
while len(self._reservoir_buffer) < self.TRAIN_SIZE:
data = f.read(chunk_size * vec_item_size)
if not data: break
fp16 = np.frombuffer(data, dtype=np.float16).reshape(-1, self.dimension)
fp32 = fp16.astype(np.float32)
faiss.normalize_L2(fp32)
for vec in fp32:
self._reservoir_buffer.append(vec)
if len(self._reservoir_buffer) >= self.TRAIN_SIZE:
break
self._train_and_replay_unlocked()
def delete(self, ids: List[str]) -> int:
with self._lock:
count = 0
for str_id in ids:
if str_id not in self._known_hashes:
continue
int_id = self._generate_id(str_id)
if int_id not in self._deleted_ids:
self._deleted_ids.add(int_id)
if self._index.is_trained:
self._index.remove_ids(np.array([int_id], dtype=np.int64))
# 同步从 fallback 移除
if self._fallback_index.ntotal > 0:
self._fallback_index.remove_ids(np.array([int_id], dtype=np.int64))
count += 1
self._total_deleted += count
# Check GC
self._check_rebuild_needed()
return count
def _check_rebuild_needed(self):
"""GC Excution Check"""
if self._bin_count == 0: return
ratio = len(self._deleted_ids) / self._bin_count
if ratio > 0.3 and len(self._deleted_ids) > 1000:
logger.info(f"Triggering GC/Rebuild (deleted ratio: {ratio:.2f})")
self.rebuild_index()
def rebuild_index(self):
"""GC: 重建索引,压缩 bin 文件"""
with self._lock:
self._rebuild_index_locked()
def _rebuild_index_locked(self):
"""实际 GC 重建逻辑。"""
logger.info("Starting Compaction (GC)...")
tmp_bin = self.data_dir / "vectors.bin.tmp"
tmp_ids = self.data_dir / "vectors_ids.bin.tmp"
vec_item_size = self.dimension * 2
id_item_size = 8
chunk_size = 10000
new_count = 0
# 1. Compact Files
with open(self._bin_path, "rb") as f_vec, open(self._ids_bin_path, "rb") as f_id, \
open(tmp_bin, "wb") as w_vec, open(tmp_ids, "wb") as w_id:
while True:
vec_data = f_vec.read(chunk_size * vec_item_size)
id_data = f_id.read(chunk_size * id_item_size)
if not vec_data: break
batch_fp16 = np.frombuffer(vec_data, dtype=np.float16).reshape(-1, self.dimension)
batch_ids = np.frombuffer(id_data, dtype='>i8').astype(np.int64)
keep_mask = [id_ not in self._deleted_ids for id_ in batch_ids]
if any(keep_mask):
keep_vecs = batch_fp16[keep_mask]
keep_ids = batch_ids[keep_mask]
w_vec.write(keep_vecs.tobytes())
w_id.write(keep_ids.astype('>i8').tobytes())
new_count += len(keep_ids)
# 2. Reset State & Atomic Swap
self._bin_count = new_count
# Close current index
self._index.reset()
if self._fallback_index: self._fallback_index.reset() # Also clear fallback
self._is_trained = False
# Swap files
shutil.move(str(tmp_bin), str(self._bin_path))
shutil.move(str(tmp_ids), str(self._ids_bin_path))
# Reset Tombstones (Critical)
self._deleted_ids.clear()
# 3. Reload/Rebuild Index (Fresh Train)
# We need to re-train because data distribution might have changed significantly after deletion
self._init_index()
self._init_fallback_index() # Re-init fallback too
self._force_train_small_data() # This will train and replay from the NEW compact file
logger.info("Compaction Complete.")
def save(self, data_dir: Optional[Union[str, Path]] = None) -> None:
with self._lock:
if not data_dir:
data_dir = self.data_dir
if not data_dir:
raise ValueError("No data_dir")
data_dir = Path(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
self._flush_write_buffer_unlocked()
if self._is_trained:
index_path = data_dir / "vectors.index"
with atomic_save_path(index_path) as tmp:
faiss.write_index(self._index, tmp)
meta = {
"dimension": self.dimension,
"quantization_type": self.quantization_type.value,
"is_trained": self._is_trained,
"vector_norm": self._vector_norm,
"deleted_ids": list(self._deleted_ids),
"known_hashes": list(self._known_hashes),
}
with atomic_write(data_dir / "vectors_metadata.pkl", "wb") as f:
pickle.dump(meta, f)
logger.info("VectorStore saved.")
def migrate_legacy_npy(self, data_dir: Optional[Union[str, Path]] = None) -> Dict[str, Any]:
"""
离线迁移入口:将 legacy vectors.npy 转为 vNext 二进制格式。
"""
with self._lock:
target_dir = Path(data_dir) if data_dir else self.data_dir
if target_dir is None:
raise ValueError("No data_dir")
target_dir = Path(target_dir)
npy_path = target_dir / "vectors.npy"
idx_path = target_dir / "vectors.index"
bin_path = target_dir / "vectors.bin"
ids_bin_path = target_dir / "vectors_ids.bin"
meta_path = target_dir / "vectors_metadata.pkl"
if not npy_path.exists():
return {"migrated": False, "reason": "npy_missing"}
if not meta_path.exists():
raise RuntimeError("legacy vectors.npy migration requires vectors_metadata.pkl")
if bin_path.exists() and ids_bin_path.exists():
return {"migrated": False, "reason": "bin_exists"}
# Reset in-memory state to avoid appending to stale runtime buffers.
self._known_hashes.clear()
self._deleted_ids.clear()
self._write_buffer_vecs.clear()
self._write_buffer_ids.clear()
self._init_index()
self._init_fallback_index()
self._is_trained = False
self._bin_count = 0
self._migrate_from_npy_unlocked(npy_path, idx_path, target_dir)
self.save(target_dir)
return {"migrated": True, "reason": "ok"}
def load(self, data_dir: Optional[Union[str, Path]] = None) -> None:
with self._lock:
if not data_dir: data_dir = self.data_dir
data_dir = Path(data_dir)
npy_path = data_dir / "vectors.npy"
idx_path = data_dir / "vectors.index"
bin_path = data_dir / "vectors.bin"
if npy_path.exists() and not bin_path.exists():
raise RuntimeError(
"检测到 legacy vectors.npyvNext 不再支持运行时自动迁移。"
" 请先执行 scripts/release_vnext_migrate.py migrate。"
)
meta_path = data_dir / "vectors_metadata.pkl"
if not meta_path.exists():
logger.warning("No metadata found, initialized empty.")
return
with open(meta_path, "rb") as f:
meta = pickle.load(f)
if meta.get("vector_norm") != "l2":
logger.warning("Index IDMap2 version mismatch (L2 Norm), forcing rebuild...")
self._known_hashes = set(meta.get("ids", [])) | set(meta.get("known_hashes", []))
self._deleted_ids = set(meta.get("deleted_ids", []))
self._init_index()
self._force_train_small_data()
return
self._is_trained = meta.get("is_trained", False)
self._vector_norm = meta.get("vector_norm", "l2")
self._deleted_ids = set(meta.get("deleted_ids", []))
self._known_hashes = set(meta.get("known_hashes", []))
if self._is_trained:
if idx_path.exists():
try:
self._index = faiss.read_index(str(idx_path))
if not isinstance(self._index, faiss.IndexIDMap2):
logger.warning("Loaded index type mismatch. Rebuilding...")
self._init_index()
self._force_train_small_data()
except Exception as e:
logger.error(f"Failed to load index: {e}. Rebuilding...")
self._init_index()
self._force_train_small_data()
else:
logger.warning("Index file missing despite metadata indicating trained. Rebuilding from bin...")
self._init_index()
self._force_train_small_data()
if bin_path.exists():
self._bin_count = bin_path.stat().st_size // (self.dimension * 2)
def _migrate_from_npy(self, npy_path, idx_path, data_dir):
with self._lock:
self._migrate_from_npy_unlocked(npy_path, idx_path, data_dir)
def _migrate_from_npy_unlocked(self, npy_path, idx_path, data_dir):
try:
arr = np.load(npy_path, mmap_mode="r")
except Exception:
arr = np.load(npy_path)
meta_path = data_dir / "vectors_metadata.pkl"
old_ids = []
if meta_path.exists():
with open(meta_path, "rb") as f:
m = pickle.load(f)
old_ids = m.get("ids", [])
if len(arr) != len(old_ids):
logger.error(f"Migration mismatch: arr {len(arr)} != ids {len(old_ids)}")
return
logger.info(f"Migrating {len(arr)} vectors...")
chunk = 1000
for i in range(0, len(arr), chunk):
sub_arr = arr[i : i+chunk]
sub_ids = old_ids[i : i+chunk]
self.add(sub_arr, sub_ids)
if not self._is_trained:
self._force_train_small_data()
shutil.move(str(npy_path), str(npy_path) + ".bak")
if idx_path.exists():
shutil.move(str(idx_path), str(idx_path) + ".bak")
logger.info("Migration complete.")
def clear(self) -> None:
with self._lock:
self._ids_bin_path.unlink(missing_ok=True)
self._bin_path.unlink(missing_ok=True)
self._init_index()
self._init_fallback_index()
self._known_hashes.clear()
self._deleted_ids.clear()
self._write_buffer_vecs.clear()
self._write_buffer_ids.clear()
self._reservoir_buffer.clear()
self._seen_count_for_reservoir = 0
self._bin_count = 0
logger.info("VectorStore cleared.")
def has_data(self) -> bool:
return (self.data_dir / "vectors_metadata.pkl").exists()
@property
def num_vectors(self) -> int:
return len(self._known_hashes) - len(self._deleted_ids)
def __contains__(self, hash_value: str) -> bool:
"""Check if a hash exists in the store"""
return hash_value in self._known_hashes and self._generate_id(hash_value) not in self._deleted_ids
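
A hedged sketch of the append-only SQ8 store above, assuming faiss-cpu is installed and the host application provides src.common.logger. The random vectors stand in for the embeddings normally produced by the embedding adapter, and the import path is illustrative.

import numpy as np
from plugins.A_memorix.core.storage import QuantizationType, VectorStore  # hypothetical path

store = VectorStore(dimension=256, quantization_type=QuantizationType.INT8, data_dir="./data/vectors")

rng = np.random.default_rng(0)
vectors = rng.standard_normal((3, 256)).astype(np.float32)  # add() 内部会做 L2 归一化
print(store.add(vectors=vectors, ids=["para-a", "para-b", "para-c"]))  # 3
print(store.num_vectors)                                               # 3

ids, scores = store.search(vectors[0], k=2)  # 归一化后内积即余弦相似度
print(ids[0], round(scores[0], 3))           # para-a, 约 1.0(未训练时走 Flat 回退索引)

store.save()  # 原子写 vectors_metadata.pkl训练完成后才会写 SQ8 索引文件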

View File

@@ -0,0 +1,89 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass, field
from enum import Enum
import hashlib
class KnowledgeType(str, Enum):
NARRATIVE = "narrative"
FACTUAL = "factual"
QUOTE = "quote"
MIXED = "mixed"
@dataclass
class SourceInfo:
file: str
offset_start: int
offset_end: int
checksum: str = ""
@dataclass
class ChunkContext:
chunk_id: str
index: int
context: Dict[str, Any] = field(default_factory=dict)
text: str = ""
@dataclass
class ChunkFlags:
verbatim: bool = False
requires_llm: bool = True
@dataclass
class ProcessedChunk:
type: KnowledgeType
source: SourceInfo
chunk: ChunkContext
data: Dict[str, Any] = field(default_factory=dict) # triples、events、verbatim_entities
flags: ChunkFlags = field(default_factory=ChunkFlags)
def to_dict(self) -> Dict:
return {
"type": self.type.value,
"source": {
"file": self.source.file,
"offset_start": self.source.offset_start,
"offset_end": self.source.offset_end,
"checksum": self.source.checksum
},
"chunk": {
"text": self.chunk.text,
"chunk_id": self.chunk.chunk_id,
"index": self.chunk.index,
"context": self.chunk.context
},
"data": self.data,
"flags": {
"verbatim": self.flags.verbatim,
"requires_llm": self.flags.requires_llm
}
}
class BaseStrategy(ABC):
def __init__(self, filename: str):
self.filename = filename
@abstractmethod
def split(self, text: str) -> List[ProcessedChunk]:
"""按策略将文本切分为块。"""
pass
@abstractmethod
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
"""从文本块中抽取结构化信息。"""
pass
def calculate_checksum(self, text: str) -> str:
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def build_language_guard(self, text: str) -> str:
"""
构建统一的输出语言约束。
不区分语言类型,仅要求抽取值保持原文语言,不做翻译。
"""
_ = text # 预留参数,便于后续按需扩展
return (
"Focus on the original source language. Keep extracted events, entities, predicates "
"and objects in the same language as the source text, preserve names/terms as-is, "
"and do not translate."
)
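
A small sketch of the chunk envelope defined above; the field values and the import path are illustrative.

from plugins.A_memorix.core.knowledge.strategies.base import (  # hypothetical path
    ChunkContext, KnowledgeType, ProcessedChunk, SourceInfo,
)

chunk = ProcessedChunk(
    type=KnowledgeType.FACTUAL,
    source=SourceInfo(file="notes.txt", offset_start=0, offset_end=21, checksum=""),
    chunk=ChunkContext(chunk_id="notes.txt_0", index=0, text="苏轼|号|东坡居士"),
    data={"triples": [{"subject": "苏轼", "predicate": "", "object": "东坡居士"}]},
)
print(chunk.flags.requires_llm)              # True缺省需要 LLM 抽取
print(chunk.to_dict()["chunk"]["chunk_id"])  # notes.txt_0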

View File

@@ -0,0 +1,98 @@
import re
from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
class FactualStrategy(BaseStrategy):
def split(self, text: str) -> List[ProcessedChunk]:
# 结构感知切分
lines = text.split('\n')
chunks = []
current_chunk_lines = []
current_len = 0
target_size = 600
for i, line in enumerate(lines):
# 判断是否应当切分
# 若当前行为列表项/定义/表格行,则尽量不切分
is_structure = self._is_structural_line(line)
current_len += len(line) + 1
current_chunk_lines.append(line)
# 达到目标长度且不在紧凑结构块内时切分(过长时强制切分)
if current_len >= target_size and not is_structure:
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
current_chunk_lines = []
current_len = 0
elif current_len >= target_size * 2: # 超长时强制切分
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
current_chunk_lines = []
current_len = 0
if current_chunk_lines:
chunks.append(self._create_chunk(current_chunk_lines, len(chunks)))
return chunks
def _is_structural_line(self, line: str) -> bool:
line = line.strip()
if not line: return False
# 列表项
if re.match(r'^[\-\*]\s+', line) or re.match(r'^\d+\.\s+', line):
return True
# 定义项(术语: 定义)
if re.match(r'^[^:]+[:].+', line):
return True
# 表格行(按 markdown 语法假设)
if line.startswith('|') and line.endswith('|'):
return True
return False
def _create_chunk(self, lines: List[str], index: int) -> ProcessedChunk:
text = "\n".join(lines)
return ProcessedChunk(
type=KnowledgeType.FACTUAL,
source=SourceInfo(
file=self.filename,
offset_start=0, # 简化处理:真实偏移跟踪需要额外状态
offset_end=0,
checksum=self.calculate_checksum(text)
),
chunk=ChunkContext(
chunk_id=f"{self.filename}_{index}",
index=index,
text=text
)
)
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
if not llm_func:
raise ValueError("LLM function required for Factual extraction")
language_guard = self.build_language_guard(chunk.chunk.text)
prompt = f"""You are a factual knowledge extraction engine.
Extract factual triples and entities from the text.
Preserve lists and definitions accurately.
Language constraints:
- {language_guard}
- Preserve original names and domain terms exactly when possible.
- JSON keys must stay exactly as: triples, entities, subject, predicate, object.
Text:
{chunk.chunk.text}
Return ONLY valid JSON:
{{
"triples": [
{{"subject": "Entity", "predicate": "Relationship", "object": "Entity"}}
],
"entities": ["Entity1", "Entity2"]
}}
"""
result = await llm_func(prompt)
# 结果保持原样存入 data,后续统一归一化流程会处理
# vector_store 侧期望关系字段为 subject/predicate/object 映射形式
chunk.data = result
return chunk

View File

@@ -0,0 +1,126 @@
import re
from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext
class NarrativeStrategy(BaseStrategy):
def split(self, text: str) -> List[ProcessedChunk]:
scenes = self._split_into_scenes(text)
chunks = []
for scene_idx, (scene_text, scene_title) in enumerate(scenes):
scene_chunks = self._sliding_window(scene_text, scene_title, scene_idx)
chunks.extend(scene_chunks)
return chunks
def _split_into_scenes(self, text: str) -> List[tuple[str, str]]:
"""按标题或分隔符把文本切分为场景。"""
# 简单启发式:按 markdown 标题或特定分隔符切分
# 该正则匹配以 #、Chapter 或 *** / === 开头的分隔行
scene_pattern_str = r'^(?:#{1,6}\s+.*|Chapter\s+\d+|^\*{3,}$|^={3,}$)'
# 保留分隔符,以便识别场景起点
parts = re.split(f"({scene_pattern_str})", text, flags=re.MULTILINE)
scenes = []
current_scene_title = "Start"
current_scene_content = []
if parts and parts[0].strip() == "":
parts = parts[1:]
for part in parts:
if re.match(scene_pattern_str, part, re.MULTILINE):
# 先保存上一段场景
if current_scene_content:
scenes.append(("".join(current_scene_content), current_scene_title))
current_scene_content = []
current_scene_title = part.strip()
else:
current_scene_content.append(part)
if current_scene_content:
scenes.append(("".join(current_scene_content), current_scene_title))
# 若未识别到场景,则把全文视作单一场景
if not scenes:
scenes = [(text, "Whole Text")]
return scenes
def _sliding_window(self, text: str, scene_id: str, scene_idx: int, window_size=800, overlap=200) -> List[ProcessedChunk]:
chunks = []
if len(text) <= window_size:
chunks.append(self._create_chunk(text, scene_id, scene_idx, 0, 0))
return chunks
stride = window_size - overlap
start = 0
local_idx = 0
while start < len(text):
end = min(start + window_size, len(text))
chunk_text = text[start:end]
# 尽量对齐到最近换行,避免生硬截断句子
# 仅在未到文本尾部时进行回退
if end < len(text):
last_newline = chunk_text.rfind('\n')
if last_newline > window_size // 2: # 仅在回退距离可接受时启用
end = start + last_newline + 1
chunk_text = text[start:end]
chunks.append(self._create_chunk(chunk_text, scene_id, scene_idx, local_idx, start))
start += len(chunk_text) - overlap if end < len(text) else len(chunk_text)
local_idx += 1
return chunks
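# Worked example (editor's note): with the defaults window_size=800 and
# overlap=200, the nominal stride is 600 characters, so a 2000-character scene
# with no usable newline near the boundaries yields windows starting at
# 0, 600, 1200 and 1800; the newline realignment above may shorten a window
# (and therefore the effective stride) whenever the last '\n' falls past the
# window midpoint.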
def _create_chunk(self, text: str, scene_id: str, scene_idx: int, local_idx: int, offset: int) -> ProcessedChunk:
return ProcessedChunk(
type=KnowledgeType.NARRATIVE,
source=SourceInfo(
file=self.filename,
offset_start=offset,
offset_end=offset + len(text),
checksum=self.calculate_checksum(text)
),
chunk=ChunkContext(
chunk_id=f"{self.filename}_{scene_idx}_{local_idx}",
index=local_idx,
text=text,
context={"scene_id": scene_id}
)
)
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
if not llm_func:
raise ValueError("LLM function required for Narrative extraction")
language_guard = self.build_language_guard(chunk.chunk.text)
prompt = f"""You are a narrative knowledge extraction engine.
Extract key events and character relations from the scene text.
Language constraints:
- {language_guard}
- Preserve original names and terms exactly when possible.
- JSON keys must stay exactly as: events, relations, subject, predicate, object.
Scene:
{chunk.chunk.context.get('scene_id')}
Text:
{chunk.chunk.text}
Return ONLY valid JSON:
{{
"events": ["event description 1", "event description 2"],
"relations": [
{{"subject": "CharacterA", "predicate": "relation", "object": "CharacterB"}}
]
}}
"""
result = await llm_func(prompt)
chunk.data = result
return chunk

View File

@@ -0,0 +1,52 @@
from typing import List, Dict, Any
from .base import BaseStrategy, ProcessedChunk, KnowledgeType, SourceInfo, ChunkContext, ChunkFlags
class QuoteStrategy(BaseStrategy):
def split(self, text: str) -> List[ProcessedChunk]:
# Split by double newlines (stanzas)
stanzas = text.split("\n\n")
chunks = []
offset = 0
for idx, stanza in enumerate(stanzas):
if not stanza.strip():
offset += len(stanza) + 2
continue
chunk = ProcessedChunk(
type=KnowledgeType.QUOTE,
source=SourceInfo(
file=self.filename,
offset_start=offset,
offset_end=offset + len(stanza),
checksum=self.calculate_checksum(stanza)
),
chunk=ChunkContext(
chunk_id=f"{self.filename}_{idx}",
index=idx,
text=stanza
),
flags=ChunkFlags(
verbatim=True,
requires_llm=False # Default to no LLM, but can be overridden
)
)
chunks.append(chunk)
offset += len(stanza) + 2 # +2 for \n\n
return chunks
async def extract(self, chunk: ProcessedChunk, llm_func=None) -> ProcessedChunk:
# For quotes, the text itself is the entity/knowledge
# We might use LLM to extract headers/metadata if requested, but core logic is pass-through
# Treat the whole chunk text as a verbatim entity
chunk.data = {
"verbatim_entities": [chunk.chunk.text]
}
if llm_func and chunk.flags.requires_llm:
# Optional: Extract metadata
pass
return chunk
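# Illustrative example (editor's note): two stanzas separated by a blank line
# become two verbatim chunks that bypass LLM extraction by default:
#
#   QuoteStrategy("poem.txt").split("Roses are red\n\nViolets are blue")
#   -> chunk 0: text="Roses are red",    verbatim=True, requires_llm=False
#      chunk 1: text="Violets are blue", verbatim=True, requires_llm=False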

View File

@@ -0,0 +1,33 @@
"""工具模块 - 哈希、监控等辅助功能"""
from .hash import compute_hash, normalize_text
from .monitor import MemoryMonitor
from .quantization import quantize_vector, dequantize_vector
from .time_parser import (
parse_query_datetime_to_timestamp,
parse_query_time_range,
parse_ingest_datetime_to_timestamp,
normalize_time_meta,
format_timestamp,
)
from .relation_write_service import RelationWriteService, RelationWriteResult
from .relation_query import RelationQuerySpec, parse_relation_query_spec
from .plugin_id_policy import PluginIdPolicy
__all__ = [
"compute_hash",
"normalize_text",
"MemoryMonitor",
"quantize_vector",
"dequantize_vector",
"parse_query_datetime_to_timestamp",
"parse_query_time_range",
"parse_ingest_datetime_to_timestamp",
"normalize_time_meta",
"format_timestamp",
"RelationWriteService",
"RelationWriteResult",
"RelationQuerySpec",
"parse_relation_query_spec",
"PluginIdPolicy",
]

View File

@@ -0,0 +1,360 @@
"""
聚合查询服务:
- 并发执行 search/time/episode 分支
- 统一分支结果结构
- 可选混合排序(Weighted RRF)
"""
from __future__ import annotations
import asyncio
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from src.common.logger import get_logger
logger = get_logger("A_Memorix.AggregateQueryService")
BranchRunner = Callable[[], Awaitable[Dict[str, Any]]]
class AggregateQueryService:
"""聚合查询执行服务search/time/episode"""
def __init__(self, plugin_config: Optional[Any] = None):
self.plugin_config = plugin_config or {}
def _cfg(self, key: str, default: Any = None) -> Any:
getter = getattr(self.plugin_config, "get_config", None)
if callable(getter):
return getter(key, default)
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
@staticmethod
def _as_float(value: Any, default: float = 0.0) -> float:
try:
return float(value)
except Exception:
return float(default)
@staticmethod
def _as_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _rrf_k(self) -> float:
raw = self._cfg("retrieval.aggregate.rrf_k", 60.0)
value = self._as_float(raw, 60.0)
return max(1.0, value)
def _weights(self) -> Dict[str, float]:
defaults = {"search": 1.0, "time": 1.0, "episode": 1.0}
raw = self._cfg("retrieval.aggregate.weights", {})
if not isinstance(raw, dict):
return defaults
out = dict(defaults)
for key in ("search", "time", "episode"):
if key in raw:
out[key] = max(0.0, self._as_float(raw.get(key), defaults[key]))
return out
@staticmethod
def _normalize_branch_payload(
name: str,
payload: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
data = payload if isinstance(payload, dict) else {}
results_raw = data.get("results", [])
results = results_raw if isinstance(results_raw, list) else []
count = data.get("count")
if count is None:
count = len(results)
return {
"name": name,
"success": bool(data.get("success", False)),
"skipped": bool(data.get("skipped", False)),
"skip_reason": str(data.get("skip_reason", "") or "").strip(),
"error": str(data.get("error", "") or "").strip(),
"results": results,
"count": max(0, int(count)),
"elapsed_ms": max(0.0, float(data.get("elapsed_ms", 0.0) or 0.0)),
"content": str(data.get("content", "") or ""),
"query_type": str(data.get("query_type", "") or name),
}
@staticmethod
def _mix_key(item: Dict[str, Any], branch: str, rank: int) -> str:
item_type = str(item.get("type", "") or "").strip().lower()
if item_type == "episode":
episode_id = str(item.get("episode_id", "") or "").strip()
if episode_id:
return f"episode:{episode_id}"
item_hash = str(item.get("hash", "") or "").strip()
if item_hash:
return f"{item_type}:{item_hash}"
return f"{branch}:{item_type}:{rank}:{str(item.get('content', '') or '')[:80]}"
def _build_mixed_results(
self,
*,
branches: Dict[str, Dict[str, Any]],
top_k: int,
) -> List[Dict[str, Any]]:
rrf_k = self._rrf_k()
weights = self._weights()
bucket: Dict[str, Dict[str, Any]] = {}
for branch_name, branch in branches.items():
if not branch.get("success", False):
continue
results = branch.get("results", [])
if not isinstance(results, list):
continue
weight = max(0.0, float(weights.get(branch_name, 1.0)))
for idx, item in enumerate(results, start=1):
if not isinstance(item, dict):
continue
key = self._mix_key(item, branch_name, idx)
score = weight / (rrf_k + float(idx))
if key not in bucket:
merged = dict(item)
merged["fusion_score"] = 0.0
merged["_source_branches"] = set()
bucket[key] = merged
target = bucket[key]
target["fusion_score"] = float(target.get("fusion_score", 0.0)) + score
target["_source_branches"].add(branch_name)
mixed = list(bucket.values())
mixed.sort(
key=lambda x: (
-float(x.get("fusion_score", 0.0)),
str(x.get("type", "") or ""),
str(x.get("hash", "") or x.get("episode_id", "") or ""),
)
)
out: List[Dict[str, Any]] = []
for rank, item in enumerate(mixed[: max(1, int(top_k))], start=1):
merged = dict(item)
branches_set = merged.pop("_source_branches", set())
merged["source_branches"] = sorted(list(branches_set))
merged["rank"] = rank
out.append(merged)
return out
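# Worked example (editor's note): with the default rrf_k=60 and equal branch
# weights, an item ranked 1st by `search` and 3rd by `episode` accumulates
#   1.0/(60+1) + 1.0/(60+3) ≈ 0.0164 + 0.0159 ≈ 0.0323,
# so it outranks an item that appears 1st in only a single branch (≈ 0.0164).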
@staticmethod
def _status(branch: Dict[str, Any]) -> str:
if branch.get("skipped", False):
return "skipped"
if branch.get("success", False):
return "success"
return "failed"
def _build_summary(self, branches: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
summary: Dict[str, Dict[str, Any]] = {}
for name, branch in branches.items():
status = self._status(branch)
summary[name] = {
"status": status,
"count": int(branch.get("count", 0) or 0),
}
if status == "skipped":
summary[name]["reason"] = str(branch.get("skip_reason", "") or "")
if status == "failed":
summary[name]["error"] = str(branch.get("error", "") or "")
return summary
def _build_content(
self,
*,
query: str,
branches: Dict[str, Dict[str, Any]],
errors: List[Dict[str, str]],
mixed_results: Optional[List[Dict[str, Any]]],
) -> str:
lines: List[str] = [
f"🔀 聚合查询结果query='{query or 'N/A'}'",
"",
"分支状态:",
]
for name in ("search", "time", "episode"):
branch = branches.get(name, {})
status = self._status(branch)
count = int(branch.get("count", 0) or 0)
line = f"- {name}: {status}, count={count}"
reason = str(branch.get("skip_reason", "") or "").strip()
err = str(branch.get("error", "") or "").strip()
if status == "skipped" and reason:
line += f" ({reason})"
if status == "failed" and err:
line += f" ({err})"
lines.append(line)
if errors:
lines.append("")
lines.append("错误:")
for item in errors[:6]:
lines.append(f"- {item.get('branch', 'unknown')}: {item.get('error', 'unknown error')}")
if mixed_results is not None:
lines.append("")
lines.append(f"🧩 混合结果({len(mixed_results)} 条):")
for idx, item in enumerate(mixed_results[:5], start=1):
src = ",".join(item.get("source_branches", []) or [])
if str(item.get("type", "") or "") == "episode":
title = str(item.get("title", "") or "Untitled")
lines.append(f"{idx}. 🧠 {title} [{src}]")
else:
text = str(item.get("content", "") or "")
if len(text) > 80:
text = text[:80] + "..."
lines.append(f"{idx}. {text} [{src}]")
return "\n".join(lines)
async def execute(
self,
*,
query: str,
top_k: int,
mix: bool,
mix_top_k: Optional[int],
time_from: Optional[str],
time_to: Optional[str],
search_runner: Optional[BranchRunner],
time_runner: Optional[BranchRunner],
episode_runner: Optional[BranchRunner],
) -> Dict[str, Any]:
clean_query = str(query or "").strip()
safe_top_k = max(1, int(top_k))
safe_mix_top_k = max(1, int(mix_top_k if mix_top_k is not None else safe_top_k))
branches: Dict[str, Dict[str, Any]] = {}
errors: List[Dict[str, str]] = []
scheduled: List[Tuple[str, asyncio.Task]] = []
if clean_query:
if search_runner is not None:
scheduled.append(("search", asyncio.create_task(search_runner())))
else:
branches["search"] = self._normalize_branch_payload(
"search",
{"success": False, "error": "search runner unavailable", "results": []},
)
else:
branches["search"] = self._normalize_branch_payload(
"search",
{
"success": False,
"skipped": True,
"skip_reason": "missing_query",
"results": [],
"count": 0,
},
)
if time_from or time_to:
if time_runner is not None:
scheduled.append(("time", asyncio.create_task(time_runner())))
else:
branches["time"] = self._normalize_branch_payload(
"time",
{"success": False, "error": "time runner unavailable", "results": []},
)
else:
branches["time"] = self._normalize_branch_payload(
"time",
{
"success": False,
"skipped": True,
"skip_reason": "missing_time_window",
"results": [],
"count": 0,
},
)
if episode_runner is not None:
scheduled.append(("episode", asyncio.create_task(episode_runner())))
else:
branches["episode"] = self._normalize_branch_payload(
"episode",
{"success": False, "error": "episode runner unavailable", "results": []},
)
if scheduled:
done = await asyncio.gather(
*[task for _, task in scheduled],
return_exceptions=True,
)
for (branch_name, _), payload in zip(scheduled, done):
if isinstance(payload, Exception):
logger.error("aggregate branch failed: branch=%s error=%s", branch_name, payload)
normalized = self._normalize_branch_payload(
branch_name,
{
"success": False,
"error": str(payload),
"results": [],
},
)
else:
normalized = self._normalize_branch_payload(branch_name, payload)
branches[branch_name] = normalized
for name in ("search", "time", "episode"):
branch = branches.get(name)
if not branch:
continue
if branch.get("skipped", False):
continue
if not branch.get("success", False):
errors.append(
{
"branch": name,
"error": str(branch.get("error", "") or "unknown error"),
}
)
success = any(
bool(branches.get(name, {}).get("success", False))
for name in ("search", "time", "episode")
)
mixed_results: Optional[List[Dict[str, Any]]] = None
if mix:
mixed_results = self._build_mixed_results(branches=branches, top_k=safe_mix_top_k)
payload: Dict[str, Any] = {
"success": success,
"query_type": "aggregate",
"query": clean_query,
"top_k": safe_top_k,
"mix": bool(mix),
"mix_top_k": safe_mix_top_k,
"branches": branches,
"errors": errors,
"summary": self._build_summary(branches),
}
if mixed_results is not None:
payload["mixed_results"] = mixed_results
payload["content"] = self._build_content(
query=clean_query,
branches=branches,
errors=errors,
mixed_results=mixed_results,
)
return payload
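# --- Illustrative usage sketch (editor's note, not part of the module) ---
# Branch runners are plain async callables returning the per-branch payload
# shape ({"success": ..., "results": [...], "count": ...}); `run_search` and
# `run_episode` below are hypothetical wrappers provided by the tool layer:
#
#   service = AggregateQueryService(plugin_config)
#   payload = await service.execute(
#       query="关于项目A的记忆", top_k=5, mix=True, mix_top_k=5,
#       time_from=None, time_to=None,
#       search_runner=run_search, time_runner=None, episode_runner=run_episode,
#   )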

View File

@@ -0,0 +1,182 @@
"""Episode hybrid retrieval service."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from src.common.logger import get_logger
from ..retrieval import DualPathRetriever, TemporalQueryOptions
logger = get_logger("A_Memorix.EpisodeRetrievalService")
class EpisodeRetrievalService:
"""Hybrid episode retrieval backed by lexical rows and evidence projection."""
_RRF_K = 60.0
_BRANCH_WEIGHTS = {
"lexical": 1.0,
"paragraph_evidence": 1.0,
"relation_evidence": 0.85,
}
def __init__(
self,
*,
metadata_store: Any,
retriever: Optional[DualPathRetriever] = None,
) -> None:
self.metadata_store = metadata_store
self.retriever = retriever
async def query(
self,
*,
query: str = "",
top_k: int = 5,
time_from: Optional[float] = None,
time_to: Optional[float] = None,
person: Optional[str] = None,
source: Optional[str] = None,
include_paragraphs: bool = False,
) -> List[Dict[str, Any]]:
clean_query = str(query or "").strip()
safe_top_k = max(1, int(top_k))
candidate_k = max(30, safe_top_k * 6)
branches: Dict[str, List[Dict[str, Any]]] = {
"lexical": self.metadata_store.query_episodes(
query=clean_query,
time_from=time_from,
time_to=time_to,
person=person,
source=source,
limit=(candidate_k if clean_query else safe_top_k),
)
}
if clean_query and self.retriever is not None:
try:
temporal = TemporalQueryOptions(
time_from=time_from,
time_to=time_to,
person=person,
source=source,
)
results = await self.retriever.retrieve(
query=clean_query,
top_k=candidate_k,
temporal=temporal,
)
except Exception as exc:
logger.warning("episode evidence retrieval failed, fallback to lexical only: %s", exc)
else:
paragraph_rank_map: Dict[str, int] = {}
relation_rank_map: Dict[str, int] = {}
for rank, item in enumerate(results, start=1):
hash_value = str(getattr(item, "hash_value", "") or "").strip()
result_type = str(getattr(item, "result_type", "") or "").strip().lower()
if not hash_value:
continue
if result_type == "paragraph" and hash_value not in paragraph_rank_map:
paragraph_rank_map[hash_value] = rank
elif result_type == "relation" and hash_value not in relation_rank_map:
relation_rank_map[hash_value] = rank
if paragraph_rank_map:
paragraph_rows = self.metadata_store.get_episode_rows_by_paragraph_hashes(
list(paragraph_rank_map.keys()),
source=source,
)
if paragraph_rows:
branches["paragraph_evidence"] = self._rank_projected_rows(
paragraph_rows,
rank_map=paragraph_rank_map,
support_key="matched_paragraph_hashes",
)
if relation_rank_map:
relation_rows = self.metadata_store.get_episode_rows_by_relation_hashes(
list(relation_rank_map.keys()),
source=source,
)
if relation_rows:
branches["relation_evidence"] = self._rank_projected_rows(
relation_rows,
rank_map=relation_rank_map,
support_key="matched_relation_hashes",
)
fused = self._fuse_branches(branches, top_k=safe_top_k)
if include_paragraphs:
for item in fused:
item["paragraphs"] = self.metadata_store.get_episode_paragraphs(
episode_id=str(item.get("episode_id") or ""),
limit=50,
)
return fused
@staticmethod
def _rank_projected_rows(
rows: List[Dict[str, Any]],
*,
rank_map: Dict[str, int],
support_key: str,
) -> List[Dict[str, Any]]:
sentinel = 10**9
ranked = [dict(item) for item in rows]
def _first_support_rank(item: Dict[str, Any]) -> int:
support_hashes = [str(x or "").strip() for x in (item.get(support_key) or [])]
ranks = [int(rank_map[h]) for h in support_hashes if h in rank_map]
return min(ranks) if ranks else sentinel
ranked.sort(
key=lambda item: (
_first_support_rank(item),
-int(item.get("matched_paragraph_count") or 0),
-float(item.get("updated_at") or 0.0),
str(item.get("episode_id") or ""),
)
)
return ranked
def _fuse_branches(
self,
branches: Dict[str, List[Dict[str, Any]]],
*,
top_k: int,
) -> List[Dict[str, Any]]:
bucket: Dict[str, Dict[str, Any]] = {}
for branch_name, rows in branches.items():
weight = float(self._BRANCH_WEIGHTS.get(branch_name, 0.0) or 0.0)
if weight <= 0.0:
continue
for rank, row in enumerate(rows, start=1):
episode_id = str(row.get("episode_id", "") or "").strip()
if not episode_id:
continue
if episode_id not in bucket:
payload = dict(row)
payload.pop("matched_paragraph_hashes", None)
payload.pop("matched_relation_hashes", None)
payload.pop("matched_paragraph_count", None)
payload.pop("matched_relation_count", None)
payload["_fusion_score"] = 0.0
bucket[episode_id] = payload
bucket[episode_id]["_fusion_score"] = float(
bucket[episode_id].get("_fusion_score", 0.0)
) + weight / (self._RRF_K + float(rank))
out = list(bucket.values())
out.sort(
key=lambda item: (
-float(item.get("_fusion_score", 0.0)),
-float(item.get("updated_at") or 0.0),
str(item.get("episode_id") or ""),
)
)
for item in out:
item.pop("_fusion_score", None)
return out[: max(1, int(top_k))]
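# --- Illustrative usage sketch (editor's note; `meta` and `retriever` are
# hypothetical handles supplied by the plugin) ---
#   service = EpisodeRetrievalService(metadata_store=meta, retriever=retriever)
#   episodes = await service.query(query="上周的部署事故", top_k=5,
#                                  include_paragraphs=True)
# Rows come from the lexical branch and, when a retriever is available, from
# paragraph/relation evidence fused via weighted RRF; the exact row fields
# (title, updated_at, ...) depend on MetadataStore and are assumptions here.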

View File

@@ -0,0 +1,129 @@
"""
哈希工具模块
提供文本哈希计算功能,用于唯一标识和去重。
"""
import hashlib
import re
from typing import Union
def compute_hash(text: str, hash_type: str = "sha256") -> str:
"""
计算文本的哈希值
Args:
text: 输入文本
hash_type: 哈希算法类型(sha256、md5 等)
Returns:
哈希值字符串
"""
if hash_type == "sha256":
return hashlib.sha256(text.encode("utf-8")).hexdigest()
elif hash_type == "md5":
return hashlib.md5(text.encode("utf-8")).hexdigest()
else:
raise ValueError(f"不支持的哈希算法: {hash_type}")
def normalize_text(text: str) -> str:
"""
规范化文本用于哈希计算
执行以下操作:
- 去除首尾空白
- 统一换行符为\\n
- 压缩多个连续空格
- 去除不可见字符(保留换行和制表符)
Args:
text: 输入文本
Returns:
规范化后的文本
"""
# 去除首尾空白
text = text.strip()
# 统一换行符
text = text.replace("\r\n", "\n").replace("\r", "\n")
# 压缩多个连续空格为一个(但保留换行和制表符)
text = re.sub(r"[^\S\n]+", " ", text)
return text
def compute_paragraph_hash(paragraph: str) -> str:
"""
计算段落的哈希值
Args:
paragraph: 段落文本
Returns:
段落哈希值(用于 paragraph- 前缀)
"""
normalized = normalize_text(paragraph)
return compute_hash(normalized)
def compute_entity_hash(entity: str) -> str:
"""
计算实体的哈希值
Args:
entity: 实体名称
Returns:
实体哈希值(用于 entity- 前缀)
"""
normalized = entity.strip().lower()
return compute_hash(normalized)
def compute_relation_hash(relation: tuple) -> str:
"""
计算关系的哈希值
Args:
relation: 关系元组 (subject, predicate, object)
Returns:
关系哈希值(用于 relation- 前缀)
"""
# 将关系元组转为字符串
relation_str = str(tuple(relation))
return compute_hash(relation_str)
def format_hash_key(hash_type: str, hash_value: str) -> str:
"""
格式化哈希键
Args:
hash_type: 类型前缀(paragraph、entity、relation)
hash_value: 哈希值
Returns:
格式化的键(如 paragraph-abc123...)
"""
return f"{hash_type}-{hash_value}"
def parse_hash_key(key: str) -> tuple[str, str]:
"""
解析哈希键
Args:
key: 格式化的键(如 paragraph-abc123...)
Returns:
(类型, 哈希值) 元组
"""
parts = key.split("-", 1)
if len(parts) != 2:
raise ValueError(f"无效的哈希键格式: {key}")
return parts[0], parts[1]
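# Illustrative examples (editor's note):
#   compute_paragraph_hash("Hello   world ") == compute_paragraph_hash("Hello world")
#   format_hash_key("paragraph", "abc123")  -> "paragraph-abc123"
#   parse_hash_key("paragraph-abc123")      -> ("paragraph", "abc123")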

View File

@@ -0,0 +1,110 @@
"""Shared import payload normalization helpers."""
from __future__ import annotations
from typing import Any, Dict, List, Optional
from ..storage import KnowledgeType, resolve_stored_knowledge_type
from .time_parser import normalize_time_meta
def _normalize_entities(raw_entities: Any) -> List[str]:
if not isinstance(raw_entities, list):
return []
out: List[str] = []
seen = set()
for item in raw_entities:
name = str(item or "").strip()
if not name:
continue
key = name.lower()
if key in seen:
continue
seen.add(key)
out.append(name)
return out
def _normalize_relations(raw_relations: Any) -> List[Dict[str, str]]:
if not isinstance(raw_relations, list):
return []
out: List[Dict[str, str]] = []
for item in raw_relations:
if not isinstance(item, dict):
continue
subject = str(item.get("subject", "")).strip()
predicate = str(item.get("predicate", "")).strip()
obj = str(item.get("object", "")).strip()
if not (subject and predicate and obj):
continue
out.append(
{
"subject": subject,
"predicate": predicate,
"object": obj,
}
)
return out
def normalize_paragraph_import_item(
item: Any,
*,
default_source: str,
) -> Dict[str, Any]:
"""Normalize one paragraph import item from text/json payloads."""
if isinstance(item, str):
content = str(item)
knowledge_type = resolve_stored_knowledge_type(None, content=content)
return {
"content": content,
"knowledge_type": knowledge_type.value,
"source": str(default_source or "").strip(),
"time_meta": None,
"entities": [],
"relations": [],
}
if not isinstance(item, dict) or "content" not in item:
raise ValueError("段落项必须为字符串或包含 content 的对象")
content = str(item.get("content", "") or "")
if not content.strip():
raise ValueError("段落 content 不能为空")
raw_time_meta = {
"event_time": item.get("event_time"),
"event_time_start": item.get("event_time_start"),
"event_time_end": item.get("event_time_end"),
"time_range": item.get("time_range"),
"time_granularity": item.get("time_granularity"),
"time_confidence": item.get("time_confidence"),
}
time_meta_field = item.get("time_meta")
if isinstance(time_meta_field, dict):
raw_time_meta.update(time_meta_field)
knowledge_type_raw = item.get("knowledge_type")
if knowledge_type_raw is None:
knowledge_type_raw = item.get("type")
knowledge_type = resolve_stored_knowledge_type(knowledge_type_raw, content=content)
source = str(item.get("source") or default_source or "").strip()
if not source:
source = str(default_source or "").strip()
normalized_time_meta = normalize_time_meta(raw_time_meta)
return {
"content": content,
"knowledge_type": knowledge_type.value,
"source": source,
"time_meta": normalized_time_meta if normalized_time_meta else None,
"entities": _normalize_entities(item.get("entities")),
"relations": _normalize_relations(item.get("relations")),
}
def normalize_summary_knowledge_type(value: Any) -> KnowledgeType:
"""Normalize config-driven summary knowledge type."""
return resolve_stored_knowledge_type(value, content="")
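# Illustrative example (editor's note): a dict payload is normalized into the
# canonical import shape; entities are de-duplicated case-insensitively and
# knowledge_type/source fall back to auto-detection and the default source.
#
#   normalize_paragraph_import_item(
#       {"content": "小明在2024年加入了项目组",
#        "entities": ["小明", "小明"],
#        "relations": [{"subject": "小明", "predicate": "加入", "object": "项目组"}]},
#       default_source="manual_import",
#   )
#   -> {"content": "小明在2024年加入了项目组", "knowledge_type": "<detected>",
#       "source": "manual_import", "time_meta": ...,
#       "entities": ["小明"], "relations": [{...}]}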

View File

@@ -0,0 +1,84 @@
"""
IO Utilities
提供原子文件写入等IO辅助功能。
"""
import os
import shutil
import contextlib
from pathlib import Path
from typing import Union
@contextlib.contextmanager
def atomic_write(file_path: Union[str, Path], mode: str = "w", encoding: str = None, **kwargs):
"""
原子文件写入上下文管理器
原理:
1. 写入 .tmp 临时文件
2. 写入成功后,使用 os.replace 原子替换目标文件
3. 如果失败,自动删除临时文件
Args:
file_path: 目标文件路径
mode: 打开模式 ('w', 'wb' 等)
encoding: 编码
**kwargs: 传给 open() 的其他参数
"""
path = Path(file_path)
# 确保父目录存在
path.parent.mkdir(parents=True, exist_ok=True)
# 临时文件路径
tmp_path = path.with_suffix(path.suffix + ".tmp")
try:
with open(tmp_path, mode, encoding=encoding, **kwargs) as f:
yield f
# 确保写入磁盘
f.flush()
os.fsync(f.fileno())
# 原子替换 (Windows下可能需要先删除目标文件但 os.replace 在 Py3.3+ 尽可能原子)
# 注意: Windows 上如果有其他进程占用文件os.replace 可能会失败
os.replace(tmp_path, path)
except Exception:
# 清理临时文件
if tmp_path.exists():
try:
os.remove(tmp_path)
except OSError:
pass
raise
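# Illustrative usage (editor's note): writing JSON atomically, assuming the
# standard-library json module:
#
#   import json
#   with atomic_write("data/meta.json", mode="w", encoding="utf-8") as f:
#       json.dump({"version": 1}, f, ensure_ascii=False)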
@contextlib.contextmanager
def atomic_save_path(file_path: Union[str, Path]):
"""
提供临时路径用于原子保存 (针对只接受路径的API如Faiss)
Args:
file_path: 最终目标路径
Yields:
tmp_path: 临时文件路径 (str)
"""
path = Path(file_path)
path.parent.mkdir(parents=True, exist_ok=True)
tmp_path = path.with_suffix(path.suffix + ".tmp")
try:
yield str(tmp_path)
if Path(tmp_path).exists():
os.replace(tmp_path, path)
except Exception:
if Path(tmp_path).exists():
try:
os.remove(tmp_path)
except OSError:
pass
raise

View File

@@ -0,0 +1,89 @@
"""
高效文本匹配工具模块
实现 Aho-Corasick 算法用于多模式匹配。
"""
from typing import List, Dict, Tuple, Set, Any
from collections import deque
class AhoCorasick:
"""
Aho-Corasick 自动机实现高效多模式匹配
"""
def __init__(self):
# next_states[state][char] = next_state
self.next_states: List[Dict[str, int]] = [{}]
# fail[state] = fail_state
self.fail: List[int] = [0]
# output[state] = set of patterns ending at this state
self.output: List[Set[str]] = [set()]
self.patterns: Set[str] = set()
def add_pattern(self, pattern: str):
"""添加模式"""
if not pattern:
return
self.patterns.add(pattern)
state = 0
for char in pattern:
if char not in self.next_states[state]:
new_state = len(self.next_states)
self.next_states[state][char] = new_state
self.next_states.append({})
self.fail.append(0)
self.output.append(set())
state = self.next_states[state][char]
self.output[state].add(pattern)
def build(self):
"""构建失败指针"""
queue = deque()
# 处理第一层
for char, state in self.next_states[0].items():
queue.append(state)
self.fail[state] = 0
while queue:
r = queue.popleft()
for char, s in self.next_states[r].items():
queue.append(s)
# 找到失败路径
state = self.fail[r]
while char not in self.next_states[state] and state != 0:
state = self.fail[state]
self.fail[s] = self.next_states[state].get(char, 0)
# 合并输出
self.output[s].update(self.output[self.fail[s]])
def search(self, text: str) -> List[Tuple[int, str]]:
"""
在文本中搜索所有模式
Returns:
[(结束索引, 匹配到的模式), ...]
"""
state = 0
results = []
for i, char in enumerate(text):
while char not in self.next_states[state] and state != 0:
state = self.fail[state]
state = self.next_states[state].get(char, 0)
for pattern in self.output[state]:
results.append((i, pattern))
return results
def find_all(self, text: str) -> Dict[str, int]:
"""
查找并统计所有模式出现次数
Returns:
{模式: 出现次数}
"""
results = self.search(text)
stats = {}
for _, pattern in results:
stats[pattern] = stats.get(pattern, 0) + 1
return stats
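# Illustrative usage (editor's note):
#   ac = AhoCorasick()
#   ac.add_pattern("he")
#   ac.add_pattern("she")
#   ac.build()
#   ac.search("ushers")    -> matches ending at index 3: "she" and "he"
#                             (order within one position is unspecified)
#   ac.find_all("ushers")  -> {"she": 1, "he": 1}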

View File

@@ -0,0 +1,189 @@
"""
内存监控模块
提供内存使用监控和预警功能。
"""
import gc
import threading
import time
from typing import Callable, Optional
try:
import psutil
HAS_PSUTIL = True
except ImportError:
HAS_PSUTIL = False
from src.common.logger import get_logger
logger = get_logger("A_Memorix.MemoryMonitor")
class MemoryMonitor:
"""
内存监控器
功能:
- 实时监控内存使用
- 超过阈值时触发警告
- 支持自动垃圾回收
"""
def __init__(
self,
max_memory_mb: int,
warning_threshold: float = 0.9,
check_interval: float = 10.0,
enable_auto_gc: bool = True,
):
"""
初始化内存监控器
Args:
max_memory_mb: 最大内存限制(MB)
warning_threshold: 警告阈值(0-1 之间,默认 0.9 表示 90%)
check_interval: 检查间隔(秒)
enable_auto_gc: 是否启用自动垃圾回收
"""
self.max_memory_mb = max_memory_mb
self.warning_threshold = warning_threshold
self.check_interval = check_interval
self.enable_auto_gc = enable_auto_gc
self._running = False
self._thread: Optional[threading.Thread] = None
self._callbacks: list[Callable[[float, float], None]] = []
def start(self):
"""启动监控"""
if self._running:
logger.warning("内存监控已在运行")
return
if not HAS_PSUTIL:
logger.warning("psutil 未安装,内存监控功能不可用")
return
self._running = True
self._thread = threading.Thread(target=self._monitor_loop, daemon=True)
self._thread.start()
logger.info(f"内存监控已启动 (限制: {self.max_memory_mb}MB)")
def stop(self):
"""停止监控"""
if not self._running:
return
self._running = False
if self._thread:
self._thread.join(timeout=5.0)
logger.info("内存监控已停止")
def register_callback(self, callback: Callable[[float, float], None]):
"""
注册内存超限回调函数
Args:
callback: 回调函数,接收 (当前使用MB, 限制MB) 参数
"""
self._callbacks.append(callback)
def get_current_memory_mb(self) -> float:
"""
获取当前进程内存使用量(MB)
Returns:
内存使用量(MB)
"""
if not HAS_PSUTIL:
# 降级方案:累加 gc 跟踪对象的浅层大小,仅为粗略估计(不含未跟踪内存)
import sys
return sum(sys.getsizeof(obj) for obj in gc.get_objects()) / 1024 / 1024
process = psutil.Process()
return process.memory_info().rss / 1024 / 1024
def get_memory_usage_ratio(self) -> float:
"""
获取内存使用率
Returns:
使用率(0-1 之间)
"""
current = self.get_current_memory_mb()
return current / self.max_memory_mb if self.max_memory_mb > 0 else 0
def _monitor_loop(self):
"""监控循环"""
while self._running:
try:
current_mb = self.get_current_memory_mb()
ratio = current_mb / self.max_memory_mb if self.max_memory_mb > 0 else 0
# 检查是否超过阈值
if ratio >= self.warning_threshold:
logger.warning(
f"内存使用率过高: {current_mb:.1f}MB / {self.max_memory_mb}MB "
f"({ratio*100:.1f}%)"
)
# 触发回调
for callback in self._callbacks:
try:
callback(current_mb, self.max_memory_mb)
except Exception as e:
logger.error(f"内存回调执行失败: {e}")
# 自动垃圾回收
if self.enable_auto_gc:
before = self.get_current_memory_mb()
gc.collect()
after = self.get_current_memory_mb()
freed = before - after
if freed > 1: # 释放超过1MB才记录
logger.info(f"垃圾回收释放: {freed:.1f}MB")
# 定期垃圾回收(即使未超限)
elif self.enable_auto_gc and int(time.time()) % 60 == 0:
gc.collect()
except Exception as e:
logger.error(f"内存监控出错: {e}")
# 等待下次检查
time.sleep(self.check_interval)
def __enter__(self):
"""上下文管理器入口"""
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器出口"""
self.stop()
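# Illustrative usage (editor's note): keep the monitor alive for the duration
# of a heavy import job and get notified near the limit.
#
#   monitor = MemoryMonitor(max_memory_mb=512, warning_threshold=0.9)
#   monitor.register_callback(
#       lambda used, limit: logger.warning(f"memory {used:.0f}/{limit}MB")
#   )
#   with monitor:
#       ...  # batch ingest work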
def get_memory_info() -> dict:
"""
获取系统内存信息
Returns:
内存信息字典
"""
if not HAS_PSUTIL:
return {"error": "psutil 未安装"}
try:
mem = psutil.virtual_memory()
process = psutil.Process()
return {
"system_total_gb": mem.total / 1024 / 1024 / 1024,
"system_available_gb": mem.available / 1024 / 1024 / 1024,
"system_usage_percent": mem.percent,
"process_mb": process.memory_info().rss / 1024 / 1024,
"process_percent": (process.memory_info().rss / mem.total) * 100,
}
except Exception as e:
return {"error": str(e)}

View File

@@ -0,0 +1,165 @@
"""Shared path-fallback helpers for search post-processing."""
from __future__ import annotations
import hashlib
from typing import Any, Dict, List, Optional, Sequence, Tuple
from ..retrieval.dual_path import RetrievalResult
def extract_entities(query: str, graph_store: Any) -> List[str]:
"""Extract up to two graph nodes from a query using n-gram matching."""
if not graph_store:
return []
text = str(query or "").strip()
if not text:
return []
# Keep the heuristic aligned with previous legacy behavior.
tokens = (
text.replace("?", " ")
.replace("!", " ")
.replace(".", " ")
.split()
)
if not tokens:
return []
found_entities = set()
skip_indices = set()
max_n = min(4, len(tokens))
for size in range(max_n, 0, -1):
for i in range(len(tokens) - size + 1):
if any(idx in skip_indices for idx in range(i, i + size)):
continue
span = " ".join(tokens[i : i + size])
matched_node = graph_store.find_node(span, ignore_case=True)
if not matched_node:
continue
found_entities.add(matched_node)
for idx in range(i, i + size):
skip_indices.add(idx)
return list(found_entities)
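# Worked example (editor's note): for the query "what is the relation between
# Alice Smith and Bob?", n-grams up to length 4 are probed longest-first, so a
# graph node named "Alice Smith" is matched before the single token "Alice",
# and its token positions are then skipped for shorter n-grams; the function
# returns the distinct matched node names as an unordered list.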
def find_paths_between_entities(
start_node: str,
end_node: str,
graph_store: Any,
metadata_store: Any,
*,
max_depth: int = 3,
max_paths: int = 5,
) -> List[Dict[str, Any]]:
"""Find and enrich indirect paths between two nodes."""
if not graph_store or not metadata_store:
return []
try:
paths = graph_store.find_paths(
start_node,
end_node,
max_depth=max_depth,
max_paths=max_paths,
)
except Exception:
return []
if not paths:
return []
edge_cache: Dict[Tuple[str, str], Tuple[str, str]] = {}
formatted_paths: List[Dict[str, Any]] = []
for path_nodes in paths:
if not isinstance(path_nodes, Sequence) or len(path_nodes) < 2:
continue
path_desc: List[str] = []
for i in range(len(path_nodes) - 1):
u = str(path_nodes[i])
v = str(path_nodes[i + 1])
cache_key = tuple(sorted((u, v)))
if cache_key in edge_cache:
pred, direction = edge_cache[cache_key]
else:
pred = "related"
direction = "->"
rels = metadata_store.get_relations(subject=u, object=v)
if not rels:
rels = metadata_store.get_relations(subject=v, object=u)
direction = "<-"
if rels:
best_rel = max(rels, key=lambda x: x.get("confidence", 1.0))
pred = str(best_rel.get("predicate", "related") or "related")
edge_cache[cache_key] = (pred, direction)
step_str = f"-[{pred}]->" if direction == "->" else f"<-[{pred}]-"
path_desc.append(step_str)
full_path_str = str(path_nodes[0])
for i, step in enumerate(path_desc):
full_path_str += f" {step} {path_nodes[i + 1]}"
formatted_paths.append(
{
"nodes": list(path_nodes),
"description": full_path_str,
}
)
return formatted_paths
def find_paths_from_query(
query: str,
graph_store: Any,
metadata_store: Any,
*,
max_depth: int = 3,
max_paths: int = 5,
) -> List[Dict[str, Any]]:
"""Extract entities from query and resolve indirect paths."""
entities = extract_entities(query, graph_store)
if len(entities) != 2:
return []
return find_paths_between_entities(
entities[0],
entities[1],
graph_store,
metadata_store,
max_depth=max_depth,
max_paths=max_paths,
)
def to_retrieval_results(paths: Sequence[Dict[str, Any]]) -> List[RetrievalResult]:
"""Convert path results into retrieval results for the unified pipeline."""
converted: List[RetrievalResult] = []
for item in paths:
description = str(item.get("description", "")).strip()
if not description:
continue
hash_seed = description.encode("utf-8")
path_hash = f"path_{hashlib.sha1(hash_seed).hexdigest()}"
converted.append(
RetrievalResult(
hash_value=path_hash,
content=f"[Indirect Relation] {description}",
score=0.95,
result_type="relation",
source="graph_path",
metadata={
"source": "graph_path",
"is_indirect": True,
"nodes": list(item.get("nodes", [])),
},
)
)
return converted

View File

@@ -0,0 +1,495 @@
"""
人物画像服务
主链路:
person_id -> 用户名/别名 -> 图谱关系 + 向量证据 -> 证据总结画像 -> 快照版本化存储
"""
import json
import time
from typing import Any, Dict, List, Optional, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from ..embedding import EmbeddingAPIAdapter
from ..retrieval import (
DualPathRetriever,
RetrievalStrategy,
DualPathRetrieverConfig,
SparseBM25Config,
FusionConfig,
GraphRelationRecallConfig,
)
from ..storage import MetadataStore, GraphStore, VectorStore
logger = get_logger("A_Memorix.PersonProfileService")
class PersonProfileService:
"""人物画像聚合/刷新服务。"""
def __init__(
self,
metadata_store: MetadataStore,
graph_store: Optional[GraphStore] = None,
vector_store: Optional[VectorStore] = None,
embedding_manager: Optional[EmbeddingAPIAdapter] = None,
sparse_index: Any = None,
plugin_config: Optional[dict] = None,
retriever: Optional[DualPathRetriever] = None,
):
self.metadata_store = metadata_store
self.graph_store = graph_store
self.vector_store = vector_store
self.embedding_manager = embedding_manager
self.sparse_index = sparse_index
self.plugin_config = plugin_config or {}
self.retriever = retriever or self._build_retriever()
def _cfg(self, key: str, default: Any = None) -> Any:
"""读取嵌套配置。"""
if not isinstance(self.plugin_config, dict):
return default
current: Any = self.plugin_config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _build_retriever(self) -> Optional[DualPathRetriever]:
"""按需构建检索器(无依赖时返回 None"""
if not all(
[
self.vector_store is not None,
self.graph_store is not None,
self.metadata_store is not None,
self.embedding_manager is not None,
]
):
return None
try:
sparse_cfg_raw = self._cfg("retrieval.sparse", {}) or {}
fusion_cfg_raw = self._cfg("retrieval.fusion", {}) or {}
graph_recall_cfg_raw = self._cfg("retrieval.search.graph_recall", {}) or {}
if not isinstance(sparse_cfg_raw, dict):
sparse_cfg_raw = {}
if not isinstance(fusion_cfg_raw, dict):
fusion_cfg_raw = {}
if not isinstance(graph_recall_cfg_raw, dict):
graph_recall_cfg_raw = {}
sparse_cfg = SparseBM25Config(**sparse_cfg_raw)
fusion_cfg = FusionConfig(**fusion_cfg_raw)
graph_recall_cfg = GraphRelationRecallConfig(**graph_recall_cfg_raw)
config = DualPathRetrieverConfig(
top_k_paragraphs=int(self._cfg("retrieval.top_k_paragraphs", 20)),
top_k_relations=int(self._cfg("retrieval.top_k_relations", 10)),
top_k_final=int(self._cfg("retrieval.top_k_final", 10)),
alpha=float(self._cfg("retrieval.alpha", 0.5)),
enable_ppr=bool(self._cfg("retrieval.enable_ppr", True)),
ppr_alpha=float(self._cfg("retrieval.ppr_alpha", 0.85)),
ppr_concurrency_limit=int(self._cfg("retrieval.ppr_concurrency_limit", 4)),
enable_parallel=bool(self._cfg("retrieval.enable_parallel", True)),
retrieval_strategy=RetrievalStrategy.DUAL_PATH,
debug=bool(self._cfg("advanced.debug", False)),
sparse=sparse_cfg,
fusion=fusion_cfg,
graph_recall=graph_recall_cfg,
)
return DualPathRetriever(
vector_store=self.vector_store,
graph_store=self.graph_store,
metadata_store=self.metadata_store,
embedding_manager=self.embedding_manager,
sparse_index=self.sparse_index,
config=config,
)
except Exception as e:
logger.warning(f"初始化人物画像检索器失败,将只使用关系证据: {e}")
return None
@staticmethod
def resolve_person_id(identifier: str) -> str:
"""按 person_id 或姓名/别名解析 person_id。"""
if not identifier:
return ""
key = str(identifier).strip()
if not key:
return ""
if len(key) == 32 and all(ch in "0123456789abcdefABCDEF" for ch in key):
return key.lower()
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where((PersonInfo.person_name == key) | (PersonInfo.nickname == key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
try:
record = (
PersonInfo.select(PersonInfo.person_id)
.where(PersonInfo.group_nick_name.contains(key))
.first()
)
if record and record.person_id:
return str(record.person_id)
except Exception:
pass
return ""
def _parse_group_nicks(self, raw_value: Any) -> List[str]:
if not raw_value:
return []
if isinstance(raw_value, list):
items = raw_value
else:
try:
items = json.loads(raw_value)
except Exception:
return []
names: List[str] = []
for item in items:
if isinstance(item, dict):
value = str(item.get("group_nick_name", "")).strip()
if value:
names.append(value)
elif isinstance(item, str):
value = item.strip()
if value:
names.append(value)
return names
def _parse_memory_traits(self, raw_value: Any) -> List[str]:
if not raw_value:
return []
try:
values = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
except Exception:
return []
if not isinstance(values, list):
return []
traits: List[str] = []
for item in values:
text = str(item).strip()
if not text:
continue
if ":" in text:
parts = text.split(":")
if len(parts) >= 3:
content = ":".join(parts[1:-1]).strip()
if content:
traits.append(content)
continue
traits.append(text)
return traits[:10]
def get_person_aliases(self, person_id: str) -> Tuple[List[str], str, List[str]]:
"""获取人物别名集合、主展示名、记忆特征。"""
aliases: List[str] = []
primary_name = ""
memory_traits: List[str] = []
if not person_id:
return aliases, primary_name, memory_traits
try:
record = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not record:
return aliases, primary_name, memory_traits
person_name = str(getattr(record, "person_name", "") or "").strip()
nickname = str(getattr(record, "nickname", "") or "").strip()
group_nicks = self._parse_group_nicks(getattr(record, "group_nick_name", None))
memory_traits = self._parse_memory_traits(getattr(record, "memory_points", None))
primary_name = person_name or nickname or str(getattr(record, "user_id", "") or "").strip() or person_id
candidates = [person_name, nickname] + group_nicks
seen = set()
for item in candidates:
norm = str(item or "").strip()
if not norm or norm in seen:
continue
seen.add(norm)
aliases.append(norm)
except Exception as e:
logger.warning(f"解析人物别名失败: person_id={person_id}, err={e}")
return aliases, primary_name, memory_traits
def _collect_relation_evidence(self, aliases: List[str], limit: int = 30) -> List[Dict[str, Any]]:
relation_by_hash: Dict[str, Dict[str, Any]] = {}
for alias in aliases:
for rel in self.metadata_store.get_relations(subject=alias):
h = str(rel.get("hash", ""))
if h:
relation_by_hash[h] = rel
for rel in self.metadata_store.get_relations(object=alias):
h = str(rel.get("hash", ""))
if h:
relation_by_hash[h] = rel
relations = list(relation_by_hash.values())
relations.sort(key=lambda item: float(item.get("confidence", 0.0)), reverse=True)
relations = relations[: max(1, int(limit))]
edges: List[Dict[str, Any]] = []
for rel in relations:
edges.append(
{
"hash": str(rel.get("hash", "")),
"subject": str(rel.get("subject", "")),
"predicate": str(rel.get("predicate", "")),
"object": str(rel.get("object", "")),
"confidence": float(rel.get("confidence", 1.0) or 1.0),
}
)
return edges
async def _collect_vector_evidence(self, aliases: List[str], top_k: int = 12) -> List[Dict[str, Any]]:
alias_queries = [a for a in aliases if a]
if not alias_queries:
return []
if self.retriever is None:
# 回退:无检索器时只做简单内容匹配
fallback: List[Dict[str, Any]] = []
seen_hash = set()
for alias in alias_queries:
for para in self.metadata_store.search_paragraphs_by_content(alias)[: max(2, top_k // 2)]:
h = str(para.get("hash", ""))
if not h or h in seen_hash:
continue
seen_hash.add(h)
fallback.append(
{
"hash": h,
"type": "paragraph",
"score": 0.0,
"content": str(para.get("content", ""))[:180],
"metadata": {},
}
)
return fallback[:top_k]
per_alias_top_k = max(2, int(top_k / max(1, len(alias_queries))))
seen_hash = set()
evidence: List[Dict[str, Any]] = []
for alias in alias_queries:
try:
results = await self.retriever.retrieve(alias, top_k=per_alias_top_k)
except Exception as e:
logger.warning(f"向量证据召回失败: alias={alias}, err={e}")
continue
for item in results:
h = str(getattr(item, "hash_value", "") or "")
if not h or h in seen_hash:
continue
seen_hash.add(h)
evidence.append(
{
"hash": h,
"type": str(getattr(item, "result_type", "")),
"score": float(getattr(item, "score", 0.0) or 0.0),
"content": str(getattr(item, "content", "") or "")[:220],
"metadata": dict(getattr(item, "metadata", {}) or {}),
}
)
evidence.sort(key=lambda x: x.get("score", 0.0), reverse=True)
return evidence[:top_k]
def _build_profile_text(
self,
person_id: str,
primary_name: str,
aliases: List[str],
relation_edges: List[Dict[str, Any]],
vector_evidence: List[Dict[str, Any]],
memory_traits: List[str],
) -> str:
"""基于证据构建画像文本(供 LLM 上下文注入)。"""
lines: List[str] = []
lines.append(f"人物ID: {person_id}")
if primary_name:
lines.append(f"主称呼: {primary_name}")
if aliases:
lines.append(f"别名: {', '.join(aliases[:8])}")
if memory_traits:
lines.append(f"记忆特征: {'; '.join(memory_traits[:6])}")
if relation_edges:
lines.append("关系证据:")
for rel in relation_edges[:6]:
s = rel.get("subject", "")
p = rel.get("predicate", "")
o = rel.get("object", "")
conf = float(rel.get("confidence", 0.0))
lines.append(f"- {s} {p} {o} (conf={conf:.2f})")
if vector_evidence:
lines.append("向量证据摘要:")
for item in vector_evidence[:4]:
content = str(item.get("content", "")).strip()
if content:
lines.append(f"- {content}")
if len(lines) <= 2:
lines.append("暂无足够证据形成稳定画像。")
return "\n".join(lines)
@staticmethod
def _is_snapshot_stale(snapshot: Optional[Dict[str, Any]], ttl_seconds: float) -> bool:
if not snapshot:
return True
now = time.time()
expires_at = snapshot.get("expires_at")
if expires_at is not None:
try:
return now >= float(expires_at)
except Exception:
return True
updated_at = float(snapshot.get("updated_at") or 0.0)
return (now - updated_at) >= ttl_seconds
def _apply_manual_override(self, person_id: str, profile_payload: Dict[str, Any]) -> Dict[str, Any]:
"""将手工覆盖并入画像结果(覆盖 profile_text同时保留 auto_profile_text"""
payload = dict(profile_payload or {})
auto_text = str(payload.get("profile_text", "") or "")
payload["auto_profile_text"] = auto_text
payload["has_manual_override"] = False
payload["manual_override_text"] = ""
payload["override_updated_at"] = None
payload["override_updated_by"] = ""
payload["profile_source"] = "auto_snapshot"
if not person_id or self.metadata_store is None:
return payload
try:
override = self.metadata_store.get_person_profile_override(person_id)
except Exception as e:
logger.warning(f"读取人物画像手工覆盖失败: person_id={person_id}, err={e}")
return payload
if not override:
return payload
manual_text = str(override.get("override_text", "") or "").strip()
if not manual_text:
return payload
payload["has_manual_override"] = True
payload["manual_override_text"] = manual_text
payload["override_updated_at"] = override.get("updated_at")
payload["override_updated_by"] = str(override.get("updated_by", "") or "")
payload["profile_text"] = manual_text
payload["profile_source"] = "manual_override"
return payload
async def query_person_profile(
self,
person_id: str = "",
person_keyword: str = "",
top_k: int = 12,
ttl_seconds: float = 6 * 3600,
force_refresh: bool = False,
source_note: str = "",
) -> Dict[str, Any]:
"""查询或刷新人物画像。"""
pid = str(person_id or "").strip()
if not pid and person_keyword:
pid = self.resolve_person_id(person_keyword)
if not pid:
return {
"success": False,
"error": "person_id 无效,且未能通过别名解析",
}
latest = self.metadata_store.get_latest_person_profile_snapshot(pid)
if not force_refresh and not self._is_snapshot_stale(latest, ttl_seconds):
aliases, primary_name, _ = self.get_person_aliases(pid)
payload = {
"success": True,
"person_id": pid,
"person_name": primary_name,
"from_cache": True,
**(latest or {}),
}
if aliases and not payload.get("aliases"):
payload["aliases"] = aliases
return {
**self._apply_manual_override(pid, payload),
}
aliases, primary_name, memory_traits = self.get_person_aliases(pid)
if not aliases and person_keyword:
aliases = [person_keyword.strip()]
primary_name = person_keyword.strip()
relation_edges = self._collect_relation_evidence(aliases, limit=max(10, top_k * 2))
vector_evidence = await self._collect_vector_evidence(aliases, top_k=max(4, top_k))
evidence_ids = [
str(item.get("hash", ""))
for item in (relation_edges + vector_evidence)
if str(item.get("hash", "")).strip()
]
dedup_ids: List[str] = []
seen = set()
for item in evidence_ids:
if item in seen:
continue
seen.add(item)
dedup_ids.append(item)
profile_text = self._build_profile_text(
person_id=pid,
primary_name=primary_name,
aliases=aliases,
relation_edges=relation_edges,
vector_evidence=vector_evidence,
memory_traits=memory_traits,
)
expires_at = time.time() + float(ttl_seconds) if ttl_seconds > 0 else None
snapshot = self.metadata_store.upsert_person_profile_snapshot(
person_id=pid,
profile_text=profile_text,
aliases=aliases,
relation_edges=relation_edges,
vector_evidence=vector_evidence,
evidence_ids=dedup_ids,
expires_at=expires_at,
source_note=source_note,
)
payload = {
"success": True,
"person_id": pid,
"person_name": primary_name,
"from_cache": False,
**snapshot,
}
return {
**self._apply_manual_override(pid, payload),
}
@staticmethod
def format_persona_profile_block(profile: Dict[str, Any]) -> str:
"""格式化给 replyer 的注入块。"""
if not profile or not profile.get("success"):
return ""
text = str(profile.get("profile_text", "") or "").strip()
if not text:
return ""
return (
"【人物画像-内部参考】\n"
f"{text}\n"
"仅供内部推理,不要向用户逐字复述。"
)
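# --- Illustrative usage sketch (editor's note, not part of the module;
# meta/graph/vectors/embedder/config are hypothetical handles from the host) ---
#   service = PersonProfileService(
#       metadata_store=meta, graph_store=graph, vector_store=vectors,
#       embedding_manager=embedder, plugin_config=config,
#   )
#   profile = await service.query_person_profile(person_keyword="小明", top_k=12)
#   block = PersonProfileService.format_persona_profile_block(profile)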

View File

@@ -0,0 +1,27 @@
"""Plugin ID matching policy for A_Memorix."""
from __future__ import annotations
from typing import Any
class PluginIdPolicy:
"""Centralized plugin id normalization/matching policy."""
CANONICAL_ID = "a_memorix"
@classmethod
def normalize(cls, plugin_id: Any) -> str:
if not isinstance(plugin_id, str):
return ""
return plugin_id.strip().lower()
@classmethod
def is_target_plugin_id(cls, plugin_id: Any) -> bool:
normalized = cls.normalize(plugin_id)
if not normalized:
return False
if normalized == cls.CANONICAL_ID:
return True
return normalized.split(".")[-1] == cls.CANONICAL_ID

View File

@@ -0,0 +1,344 @@
"""
向量量化工具模块
提供向量量化与反量化功能,用于压缩存储空间。
"""
import numpy as np
from enum import Enum
from typing import Tuple, Union
from src.common.logger import get_logger
logger = get_logger("A_Memorix.Quantization")
class QuantizationType(Enum):
"""量化类型枚举"""
FLOAT32 = "float32" # 无量化
INT8 = "int8" # 标量量化8位整数
PQ = "pq" # 乘积量化Product Quantization
def quantize_vector(
vector: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
量化向量
Args:
vector: 输入向量(float32)
quant_type: 量化类型
Returns:
量化后的向量:
- INT8: int8向量
- PQ: (编码向量, 聚类中心) 元组
"""
if quant_type == QuantizationType.FLOAT32:
return vector.astype(np.float32)
elif quant_type == QuantizationType.INT8:
return _scalar_quantize_int8(vector)
elif quant_type == QuantizationType.PQ:
return _product_quantize(vector)
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def dequantize_vector(
quantized_vector: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]],
quant_type: QuantizationType = QuantizationType.INT8,
original_shape: Tuple[int, ...] = None,
) -> np.ndarray:
"""
反量化向量
Args:
quantized_vector: 量化后的向量
quant_type: 量化类型
original_shape: 原始向量形状(用于 PQ)
Returns:
反量化后的向量float32
"""
if quant_type == QuantizationType.FLOAT32:
return quantized_vector.astype(np.float32)
elif quant_type == QuantizationType.INT8:
return _scalar_dequantize_int8(quantized_vector)
elif quant_type == QuantizationType.PQ:
if not isinstance(quantized_vector, (tuple, list)):
raise ValueError("PQ反量化需要列表/元组格式: (codes, centroids)")
return _product_dequantize(quantized_vector[0], quantized_vector[1])
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def _scalar_quantize_int8(vector: np.ndarray) -> np.ndarray:
"""
标量量化float32 -> int8
将向量归一化到 [0, 255] 范围,然后映射到 int8
Args:
vector: 输入向量
Returns:
量化后的 int8 向量
"""
# 计算最小最大值
min_val = np.min(vector)
max_val = np.max(vector)
# 避免除零
if max_val == min_val:
return np.zeros_like(vector, dtype=np.int8)
# 归一化到 [0, 255]
normalized = (vector - min_val) / (max_val - min_val) * 255
# 映射到 [-128, 127] 并转换为 int8
# np.round might return float, minus 128 then cast
quantized = np.round(normalized - 128.0).astype(np.int8)
# 存储归一化参数(用于反量化)
# 在实际存储中,这些参数需要单独保存
# 这里为了简单,我们使用一个全局字典来模拟
if not hasattr(_scalar_quantize_int8, "_params"):
_scalar_quantize_int8._params = {}
vector_id = id(vector)
_scalar_quantize_int8._params[vector_id] = (min_val, max_val)
return quantized
def _scalar_dequantize_int8(quantized: np.ndarray) -> np.ndarray:
"""
标量反量化int8 -> float32
Args:
quantized: 量化后的 int8 向量
Returns:
反量化后的 float32 向量
"""
# 计算归一化参数(如果提供了)
# 在实际应用中min_val 和 max_val 应该被保存
if not hasattr(_scalar_dequantize_int8, "_params"):
# 默认假设范围是 [-1, 1]
return (quantized.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
# 尝试查找参数 (这里只是演示逻辑,实际应从存储中读取)
# return (quantized.astype(np.float32) + 128.0) / 255.0 * (max - min) + min
return (quantized.astype(np.float32) + 128.0) / 255.0
def quantize_matrix(
matrix: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
量化矩阵(批量量化向量)
Args:
matrix: 输入矩阵(N x D),每行是一个向量
quant_type: 量化类型
Returns:
量化后的矩阵
"""
if quant_type == QuantizationType.FLOAT32:
return matrix.astype(np.float32)
elif quant_type == QuantizationType.INT8:
# 对整个矩阵进行全局归一化
min_val = np.min(matrix)
max_val = np.max(matrix)
if max_val == min_val:
return np.zeros_like(matrix, dtype=np.int8)
# 归一化到 [0, 255] 后平移到 [-128, 127],避免直接转换 int8 溢出
normalized = (matrix - min_val) / (max_val - min_val) * 255
quantized = np.round(normalized - 128.0).astype(np.int8)
return quantized
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def dequantize_matrix(
quantized_matrix: np.ndarray,
quant_type: QuantizationType = QuantizationType.INT8,
min_val: float = None,
max_val: float = None,
) -> np.ndarray:
"""
反量化矩阵
Args:
quantized_matrix: 量化后的矩阵
quant_type: 量化类型
min_val: 归一化最小值(int8 反量化需要)
max_val: 归一化最大值(int8 反量化需要)
Returns:
反量化后的矩阵
"""
if quant_type == QuantizationType.FLOAT32:
return quantized_matrix.astype(np.float32)
elif quant_type == QuantizationType.INT8:
# 使用提供的归一化参数反量化(量化时已平移到 [-128, 127])
if min_val is None or max_val is None:
# 缺少参数时,默认假设原始范围为 [-1, 1]
return (quantized_matrix.astype(np.float32) + 128.0) / 255.0 * 2.0 - 1.0
else:
# 恢复到原始范围
normalized = (quantized_matrix.astype(np.float32) + 128.0) / 255.0
return normalized * (max_val - min_val) + min_val
else:
raise ValueError(f"不支持的量化类型: {quant_type}")
def estimate_memory_reduction(
num_vectors: int,
dimension: int,
from_type: QuantizationType,
to_type: QuantizationType,
) -> Tuple[float, float, float]:
"""
估算内存节省量
Args:
num_vectors: 向量数量
dimension: 向量维度
from_type: 原始量化类型
to_type: 目标量化类型
Returns:
(原始大小MB, 量化后大小MB, 节省比例)
"""
# 计算每个向量占用的字节数
bytes_per_element = {
QuantizationType.FLOAT32: 4,
QuantizationType.INT8: 1,
QuantizationType.PQ: 0.25, # 假设压缩到1/4
}
original_bytes = num_vectors * dimension * bytes_per_element[from_type]
quantized_bytes = num_vectors * dimension * bytes_per_element[to_type]
original_mb = original_bytes / 1024 / 1024
quantized_mb = quantized_bytes / 1024 / 1024
reduction_ratio = (original_bytes - quantized_bytes) / original_bytes
return original_mb, quantized_mb, reduction_ratio
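# Worked example (editor's note): 100,000 vectors of dimension 1024
#   float32: 100000 * 1024 * 4 bytes ≈ 390.6 MB
#   int8:    100000 * 1024 * 1 byte  ≈  97.7 MB
#   -> reduction_ratio = (4 - 1) / 4 = 0.75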
def estimate_compression_stats(
num_vectors: int,
dimension: int,
quant_type: QuantizationType,
) -> dict:
"""
估算压缩统计信息
Args:
num_vectors: 向量数量
dimension: 向量维度
quant_type: 量化类型
Returns:
统计信息字典
"""
original_mb, quantized_mb, ratio = estimate_memory_reduction(
num_vectors, dimension, QuantizationType.FLOAT32, quant_type
)
return {
"num_vectors": num_vectors,
"dimension": dimension,
"quantization_type": quant_type.value,
"original_size_mb": round(original_mb, 2),
"quantized_size_mb": round(quantized_mb, 2),
"saved_mb": round(original_mb - quantized_mb, 2),
"compression_ratio": round(ratio * 100, 2),
}
def _product_quantize(
vector: np.ndarray, m: int = 8, k: int = 256
) -> Tuple[np.ndarray, np.ndarray]:
"""
乘积量化 (PQ) 简化实现
Args:
vector: 输入向量 (D,)
m: 子空间数量
k: 每个子空间的聚类中心数
Returns:
(编码后的向量, 聚类中心)
"""
d = vector.shape[0]
if d % m != 0:
raise ValueError(f"维度 {d} 必须能被子空间数量 {m} 整除")
ds = d // m # 子空间维度
codes = np.zeros(m, dtype=np.uint8)
centroids = np.zeros((m, k, ds), dtype=np.float32)
# 这里采用一种简化的 PQ不进行 K-means 训练,
# 而是预定一些量化点或针对单向量的微型聚类(实际应用中应离线训练)
# 为了演示,我们直接将子空间切分为 k 份进行量化
for i in range(m):
sub_vec = vector[i * ds : (i + 1) * ds]
# 简化:假定每个子空间的取值范围并划分
# 实际 PQ 应使用 k-means 产生的 centroids
# 这里为演示创建一个随机 codebook 并找到最接近的核心
sub_min, sub_max = np.min(sub_vec), np.max(sub_vec)
if sub_max == sub_min:
linspace = np.zeros(k)
else:
linspace = np.linspace(sub_min, sub_max, k)
for j in range(k):
centroids[i, j, :] = linspace[j]
# 编码:这里简化为取子空间均值找最接近的 centroid
sub_mean = np.mean(sub_vec)
code = np.argmin(np.abs(linspace - sub_mean))
codes[i] = code
return codes, centroids
def _product_dequantize(codes: np.ndarray, centroids: np.ndarray) -> np.ndarray:
"""
PQ 反量化
Args:
codes: 编码向量 (M,)
centroids: 聚类中心 (M, K, DS)
Returns:
恢复后的向量 (D,)
"""
m, k, ds = centroids.shape
vector = np.zeros(m * ds, dtype=np.float32)
for i in range(m):
code = codes[i]
vector[i * ds : (i + 1) * ds] = centroids[i, code, :]
return vector

View File

@@ -0,0 +1,121 @@
"""关系查询规格解析工具。"""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Optional
@dataclass
class RelationQuerySpec:
raw: str
is_structured: bool
subject: Optional[str]
predicate: Optional[str]
object: Optional[str]
error: Optional[str] = None
_NATURAL_LANGUAGE_PATTERN = re.compile(
r"(^\s*(what|who|which|how|why|when|where)\b|"
r"\?||"
r"\b(relation|related|between)\b|"
r"(什么关系|有哪些关系|之间|关联))",
re.IGNORECASE,
)
def _looks_like_natural_language(raw: str) -> bool:
text = str(raw or "").strip()
if not text:
return False
return _NATURAL_LANGUAGE_PATTERN.search(text) is not None
def parse_relation_query_spec(relation_spec: str) -> RelationQuerySpec:
raw = str(relation_spec or "").strip()
if not raw:
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
error="empty",
)
if "|" in raw:
parts = [p.strip() for p in raw.split("|")]
if len(parts) < 2:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=None,
predicate=None,
object=None,
error="invalid_pipe_format",
)
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0] or None,
predicate=parts[1] or None,
object=parts[2] if len(parts) > 2 and parts[2] else None,
)
if "->" in raw:
parts = [p.strip() for p in raw.split("->") if p.strip()]
if len(parts) >= 3:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=parts[1],
object=parts[2],
)
if len(parts) == 2:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=None,
object=parts[1],
)
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=None,
predicate=None,
object=None,
error="invalid_arrow_format",
)
if _looks_like_natural_language(raw):
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
)
# 仅保留低歧义的紧凑三元组作为兼容语法,例如 "Alice likes Apple"。
# 两词形式过于模糊,不再视为结构化关系查询。
parts = raw.split()
if len(parts) == 3:
return RelationQuerySpec(
raw=raw,
is_structured=True,
subject=parts[0],
predicate=parts[1],
object=parts[2],
)
return RelationQuerySpec(
raw=raw,
is_structured=False,
subject=None,
predicate=None,
object=None,
)
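补充示例:parse_relation_query_spec 对几类输入的解析结果示意(示例字符串为假设数据,行为依据上方实现):

spec = parse_relation_query_spec("Alice|likes|Apple")        # 管道分隔:subject|predicate|object
assert spec.is_structured and spec.predicate == "likes"

spec = parse_relation_query_spec("Alice -> likes -> Apple")  # 箭头分隔,三段视为完整三元组
assert spec.subject == "Alice" and spec.object == "Apple"

spec = parse_relation_query_spec("Alice likes Apple")        # 低歧义的紧凑三词三元组(兼容语法)
assert spec.is_structured

spec = parse_relation_query_spec("Alice 和 Apple 是什么关系")  # 命中自然语言模式,交给语义检索
assert not spec.is_structured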

View File

@@ -0,0 +1,164 @@
"""
统一关系写入与关系向量化服务。
规则:
1. 元数据是主数据源,向量是从索引。
2. 关系先写 metadata再写向量。
3. 向量失败不回滚 metadata依赖状态机与回填任务修复。
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
from src.common.logger import get_logger
logger = get_logger("A_Memorix.RelationWriteService")
@dataclass
class RelationWriteResult:
hash_value: str
vector_written: bool
vector_already_exists: bool
vector_state: str
class RelationWriteService:
"""关系写入收口服务。"""
ERROR_MAX_LEN = 500
def __init__(
self,
metadata_store: Any,
graph_store: Any,
vector_store: Any,
embedding_manager: Any,
):
self.metadata_store = metadata_store
self.graph_store = graph_store
self.vector_store = vector_store
self.embedding_manager = embedding_manager
@staticmethod
def build_relation_vector_text(subject: str, predicate: str, obj: str) -> str:
s = str(subject or "").strip()
p = str(predicate or "").strip()
o = str(obj or "").strip()
# 双表达:兼容关键词检索与自然语言问句
return f"{s} {p} {o}\n{s}{o}的关系是{p}"
async def ensure_relation_vector(
self,
hash_value: str,
subject: str,
predicate: str,
obj: str,
*,
max_error_len: int = ERROR_MAX_LEN,
) -> RelationWriteResult:
"""
为已有关系确保向量存在并更新状态。
"""
if hash_value in self.vector_store:
self.metadata_store.set_relation_vector_state(hash_value, "ready")
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=True,
vector_state="ready",
)
self.metadata_store.set_relation_vector_state(hash_value, "pending")
try:
vector_text = self.build_relation_vector_text(subject, predicate, obj)
embedding = await self.embedding_manager.encode(vector_text)
self.vector_store.add(
vectors=embedding.reshape(1, -1),
ids=[hash_value],
)
self.metadata_store.set_relation_vector_state(hash_value, "ready")
logger.info(
"metric.relation_vector_write_success=1 metric.relation_vector_write_success_count=1 hash=%s",
hash_value[:16],
)
return RelationWriteResult(
hash_value=hash_value,
vector_written=True,
vector_already_exists=False,
vector_state="ready",
)
except ValueError:
# 向量已存在冲突,按成功处理
self.metadata_store.set_relation_vector_state(hash_value, "ready")
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=True,
vector_state="ready",
)
except Exception as e:
err = str(e)[:max_error_len]
self.metadata_store.set_relation_vector_state(
hash_value,
"failed",
error=err,
bump_retry=True,
)
logger.warning(
"metric.relation_vector_write_fail=1 metric.relation_vector_write_fail_count=1 hash=%s err=%s",
hash_value[:16],
err,
)
return RelationWriteResult(
hash_value=hash_value,
vector_written=False,
vector_already_exists=False,
vector_state="failed",
)
async def upsert_relation_with_vector(
self,
subject: str,
predicate: str,
obj: str,
confidence: float = 1.0,
source_paragraph: str = "",
metadata: Optional[Dict[str, Any]] = None,
*,
write_vector: bool = True,
) -> RelationWriteResult:
"""
统一关系写入:
1) 写 metadata relation
2) 写 graph edge relation_hash
3) 按需写 relation vector
"""
rel_hash = self.metadata_store.add_relation(
subject=subject,
predicate=predicate,
obj=obj,
confidence=confidence,
source_paragraph=source_paragraph,
metadata=metadata or {},
)
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
if not write_vector:
self.metadata_store.set_relation_vector_state(rel_hash, "none")
return RelationWriteResult(
hash_value=rel_hash,
vector_written=False,
vector_already_exists=False,
vector_state="none",
)
return await self.ensure_relation_vector(
hash_value=rel_hash,
subject=subject,
predicate=predicate,
obj=obj,
)
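补充示例:RelationWriteService 的最小调用示意(假设 metadata_store / graph_store / vector_store / embedding_manager 已由本插件其他模块初始化,且运行在异步上下文中,仅演示"先元数据、后向量"的写入顺序):

service = RelationWriteService(
    metadata_store=metadata_store,
    graph_store=graph_store,
    vector_store=vector_store,
    embedding_manager=embedding_manager,
)
result = await service.upsert_relation_with_vector(
    subject="Alice",
    predicate="likes",
    obj="Apple",
    confidence=0.9,
    source_paragraph="Alice said she likes Apple.",
)
# result.vector_state == "ready" 表示向量已写入或已存在;
# "failed" 时元数据不回滚,由状态机与回填任务后续修复。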

View File

@@ -0,0 +1,197 @@
"""Runtime self-check helpers for A_Memorix."""
from __future__ import annotations
import time
from typing import Any, Dict, Optional
import numpy as np
from src.common.logger import get_logger
logger = get_logger("A_Memorix.RuntimeSelfCheck")
_DEFAULT_SAMPLE_TEXT = "A_Memorix runtime self check"
def _safe_int(value: Any, default: int = 0) -> int:
try:
return int(value)
except Exception:
return int(default)
def _get_config_value(config: Any, key: str, default: Any = None) -> Any:
getter = getattr(config, "get_config", None)
if callable(getter):
return getter(key, default)
current: Any = config
for part in key.split("."):
if isinstance(current, dict) and part in current:
current = current[part]
else:
return default
return current
def _build_report(
*,
ok: bool,
code: str,
message: str,
configured_dimension: int,
vector_store_dimension: int,
detected_dimension: int,
encoded_dimension: int,
elapsed_ms: float,
sample_text: str,
) -> Dict[str, Any]:
return {
"ok": bool(ok),
"code": str(code or "").strip(),
"message": str(message or "").strip(),
"configured_dimension": int(configured_dimension),
"vector_store_dimension": int(vector_store_dimension),
"detected_dimension": int(detected_dimension),
"encoded_dimension": int(encoded_dimension),
"elapsed_ms": float(elapsed_ms),
"sample_text": str(sample_text or ""),
"checked_at": time.time(),
}
async def run_embedding_runtime_self_check(
*,
config: Any,
vector_store: Optional[Any],
embedding_manager: Optional[Any],
sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
"""Probe the real embedding path and compare dimensions with runtime storage."""
configured_dimension = _safe_int(_get_config_value(config, "embedding.dimension", 0), 0)
vector_store_dimension = _safe_int(getattr(vector_store, "dimension", 0), 0)
if vector_store is None or embedding_manager is None:
return _build_report(
ok=False,
code="runtime_components_missing",
message="vector_store 或 embedding_manager 未初始化",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=0,
encoded_dimension=0,
elapsed_ms=0.0,
sample_text=sample_text,
)
start = time.perf_counter()
detected_dimension = 0
encoded_dimension = 0
try:
detected_dimension = _safe_int(await embedding_manager._detect_dimension(), 0)
encoded = await embedding_manager.encode(sample_text)
if isinstance(encoded, np.ndarray):
encoded_dimension = int(encoded.shape[0]) if encoded.ndim == 1 else int(encoded.shape[-1])
else:
encoded_dimension = len(encoded) if encoded is not None else 0
except Exception as exc:
elapsed_ms = (time.perf_counter() - start) * 1000.0
logger.warning("embedding runtime self-check failed: %s", exc)
return _build_report(
ok=False,
code="embedding_probe_failed",
message=f"embedding probe failed: {exc}",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
elapsed_ms = (time.perf_counter() - start) * 1000.0
expected_dimension = vector_store_dimension or configured_dimension or detected_dimension
if expected_dimension <= 0:
return _build_report(
ok=False,
code="invalid_expected_dimension",
message="无法确定期望 embedding 维度",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
if encoded_dimension != expected_dimension:
msg = (
"embedding 真实输出维度与当前向量存储不一致: "
f"expected={expected_dimension}, encoded={encoded_dimension}"
)
logger.error(msg)
return _build_report(
ok=False,
code="embedding_dimension_mismatch",
message=msg,
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
return _build_report(
ok=True,
code="ok",
message="embedding runtime self-check passed",
configured_dimension=configured_dimension,
vector_store_dimension=vector_store_dimension,
detected_dimension=detected_dimension,
encoded_dimension=encoded_dimension,
elapsed_ms=elapsed_ms,
sample_text=sample_text,
)
async def ensure_runtime_self_check(
plugin_or_config: Any,
*,
force: bool = False,
sample_text: str = _DEFAULT_SAMPLE_TEXT,
) -> Dict[str, Any]:
"""Run or reuse cached runtime self-check report."""
if plugin_or_config is None:
return _build_report(
ok=False,
code="missing_plugin_or_config",
message="plugin/config unavailable",
configured_dimension=0,
vector_store_dimension=0,
detected_dimension=0,
encoded_dimension=0,
elapsed_ms=0.0,
sample_text=sample_text,
)
cache = getattr(plugin_or_config, "_runtime_self_check_report", None)
if isinstance(cache, dict) and cache and not force:
return cache
report = await run_embedding_runtime_self_check(
config=getattr(plugin_or_config, "config", plugin_or_config),
vector_store=getattr(plugin_or_config, "vector_store", None)
if not isinstance(plugin_or_config, dict)
else plugin_or_config.get("vector_store"),
embedding_manager=getattr(plugin_or_config, "embedding_manager", None)
if not isinstance(plugin_or_config, dict)
else plugin_or_config.get("embedding_manager"),
sample_text=sample_text,
)
try:
setattr(plugin_or_config, "_runtime_self_check_report", report)
except Exception:
pass
return report
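补充示例:组件缺失时自检的降级返回示意(config 以 dict 形式传入,需在异步上下文中执行,参数值为假设数据):

report = await run_embedding_runtime_self_check(
    config={"embedding": {"dimension": 1024}},
    vector_store=None,
    embedding_manager=None,
)
# report["ok"] is False,report["code"] == "runtime_components_missing",
# report["configured_dimension"] == 1024(由 _get_config_value 按点路径取值)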

View File

@@ -0,0 +1,90 @@
"""Post-processing helpers for unified search execution."""
from __future__ import annotations
from typing import Any, List, Tuple
from .path_fallback_service import find_paths_from_query, to_retrieval_results
def apply_safe_content_dedup(results: List[Any]) -> Tuple[List[Any], int]:
"""Deduplicate results by hash/content while preserving at least one entry."""
if not results:
return [], 0
unique_results: List[Any] = []
seen_hashes = set()
seen_contents = set()
for item in results:
content = str(getattr(item, "content", "") or "").strip()
if not content:
continue
hash_value = str(getattr(item, "hash_value", "") or "").strip() or str(hash(content))
if hash_value in seen_hashes:
continue
is_dup = False
for seen in seen_contents:
if content in seen or seen in content:
is_dup = True
break
if is_dup:
continue
seen_hashes.add(hash_value)
seen_contents.add(content)
unique_results.append(item)
if not unique_results:
unique_results.append(results[0])
removed_count = max(0, len(results) - len(unique_results))
return unique_results, removed_count
def maybe_apply_smart_path_fallback(
*,
query: str,
results: List[Any],
graph_store: Any,
metadata_store: Any,
enabled: bool,
threshold: float,
max_depth: int = 3,
max_paths: int = 5,
) -> Tuple[List[Any], bool, int]:
"""Append indirect relation paths when semantic results are weak."""
if not enabled or not str(query or "").strip():
return results, False, 0
if graph_store is None or metadata_store is None:
return results, False, 0
max_score = 0.0
if results:
try:
max_score = float(getattr(results[0], "score", 0.0) or 0.0)
except Exception:
max_score = 0.0
if max_score >= float(threshold):
return results, False, 0
paths = find_paths_from_query(
query=query,
graph_store=graph_store,
metadata_store=metadata_store,
max_depth=max_depth,
max_paths=max_paths,
)
if not paths:
return results, False, 0
path_results = to_retrieval_results(paths)
if not path_results:
return results, False, 0
merged = list(path_results) + list(results)
return merged, True, len(path_results)
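补充示例:apply_safe_content_dedup 的去重示意(用 SimpleNamespace 构造带 content / hash_value 的伪检索结果,数据为假设场景):

from types import SimpleNamespace

r1 = SimpleNamespace(content="Alice likes Apple", hash_value="h1", score=0.9)
r2 = SimpleNamespace(content="Alice likes", hash_value="h2", score=0.8)  # 被 r1 的内容包含,视为重复
unique, removed = apply_safe_content_dedup([r1, r2])
# unique == [r1], removed == 1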

View File

@@ -0,0 +1,170 @@
"""
时间解析工具。
约束:
1. 查询参数Action/Command/Tool仅接受结构化绝对时间
- YYYY/MM/DD
- YYYY/MM/DD HH:mm
2. 入库时允许更宽松格式含时间戳、YYYY-MM-DD 等)。
"""
from __future__ import annotations
import re
from datetime import datetime
from typing import Any, Dict, Optional, Tuple
_QUERY_DATE_RE = re.compile(r"^\d{4}/\d{2}/\d{2}$")
_QUERY_MINUTE_RE = re.compile(r"^\d{4}/\d{2}/\d{2} \d{2}:\d{2}$")
_NUMERIC_RE = re.compile(r"^-?\d+(?:\.\d+)?$")
_INGEST_FORMATS = [
"%Y/%m/%d %H:%M:%S",
"%Y/%m/%d %H:%M",
"%Y-%m-%d %H:%M:%S",
"%Y-%m-%d %H:%M",
"%Y-%m-%dT%H:%M:%S",
"%Y-%m-%dT%H:%M",
"%Y/%m/%d",
"%Y-%m-%d",
]
_INGEST_DATE_FORMATS = {"%Y/%m/%d", "%Y-%m-%d"}
def parse_query_datetime_to_timestamp(value: str, is_end: bool = False) -> float:
"""解析查询时间,仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm。"""
text = str(value).strip()
if not text:
raise ValueError("时间不能为空")
if _QUERY_DATE_RE.fullmatch(text):
dt = datetime.strptime(text, "%Y/%m/%d")
if is_end:
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
return dt.timestamp()
if _QUERY_MINUTE_RE.fullmatch(text):
dt = datetime.strptime(text, "%Y/%m/%d %H:%M")
return dt.timestamp()
raise ValueError(
f"时间格式错误: {text}。仅支持 YYYY/MM/DD 或 YYYY/MM/DD HH:mm"
)
def parse_query_time_range(
time_from: Optional[str],
time_to: Optional[str],
) -> Tuple[Optional[float], Optional[float]]:
"""解析查询窗口并验证区间。"""
ts_from = (
parse_query_datetime_to_timestamp(time_from, is_end=False)
if time_from
else None
)
ts_to = (
parse_query_datetime_to_timestamp(time_to, is_end=True)
if time_to
else None
)
if ts_from is not None and ts_to is not None and ts_from > ts_to:
raise ValueError("time_from 不能晚于 time_to")
return ts_from, ts_to
def parse_ingest_datetime_to_timestamp(
value: Any,
is_end: bool = False,
) -> Optional[float]:
"""解析入库时间,允许 timestamp/常见字符串格式。"""
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
if not text:
return None
if _NUMERIC_RE.fullmatch(text):
return float(text)
for fmt in _INGEST_FORMATS:
try:
dt = datetime.strptime(text, fmt)
if fmt in _INGEST_DATE_FORMATS and is_end:
dt = dt.replace(hour=23, minute=59, second=0, microsecond=0)
return dt.timestamp()
except ValueError:
continue
raise ValueError(f"无法解析时间: {text}")
def normalize_time_meta(time_meta: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""归一化 time_meta 到存储层字段。"""
if not time_meta:
return {}
normalized: Dict[str, Any] = {}
event_time = parse_ingest_datetime_to_timestamp(time_meta.get("event_time"))
event_start = parse_ingest_datetime_to_timestamp(
time_meta.get("event_time_start"),
is_end=False,
)
event_end = parse_ingest_datetime_to_timestamp(
time_meta.get("event_time_end"),
is_end=True,
)
time_range = time_meta.get("time_range")
if (
isinstance(time_range, (list, tuple))
and len(time_range) == 2
):
if event_start is None:
event_start = parse_ingest_datetime_to_timestamp(time_range[0], is_end=False)
if event_end is None:
event_end = parse_ingest_datetime_to_timestamp(time_range[1], is_end=True)
if event_start is not None and event_end is not None and event_start > event_end:
raise ValueError("event_time_start 不能晚于 event_time_end")
if event_time is not None:
normalized["event_time"] = event_time
if event_start is not None:
normalized["event_time_start"] = event_start
if event_end is not None:
normalized["event_time_end"] = event_end
granularity = time_meta.get("time_granularity")
if granularity:
normalized["time_granularity"] = str(granularity)
else:
raw_time_values = [
time_meta.get("event_time"),
time_meta.get("event_time_start"),
time_meta.get("event_time_end"),
]
has_minute = any(isinstance(v, str) and ":" in v for v in raw_time_values if v is not None)
normalized["time_granularity"] = "minute" if has_minute else "day"
confidence = time_meta.get("time_confidence")
if confidence is not None:
normalized["time_confidence"] = float(confidence)
return normalized
def format_timestamp(ts: Optional[float]) -> Optional[str]:
"""将 timestamp 格式化为 YYYY/MM/DD HH:mm。"""
if ts is None:
return None
return datetime.fromtimestamp(ts).strftime("%Y/%m/%d %H:%M")
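补充示例:查询与入库两条时间路径的用法示意(时间值为假设数据):

# 查询窗口:日期形式的 time_to 会补齐到当天 23:59
ts_from, ts_to = parse_query_time_range("2026/03/01", "2026/03/18")

# 入库时间:接受时间戳、YYYY-MM-DD、ISO 风格等宽松格式
meta = normalize_time_meta({
    "event_time": "2026-03-18 21:33",
    "time_confidence": 0.8,
})
# meta["time_granularity"] == "minute"(原始字符串含 ":")
print(format_timestamp(meta["event_time"]))  # 2026/03/18 21:33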

207
plugins/A_memorix/plugin.py Normal file
View File

@@ -0,0 +1,207 @@
"""A_Memorix SDK plugin entry."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Optional
from maibot_sdk import MaiBotPlugin, Tool
from maibot_sdk.types import ToolParameterInfo, ToolParamType
from A_memorix.core.runtime.sdk_memory_kernel import KernelSearchRequest, SDKMemoryKernel
def _tool_param(name: str, param_type: ToolParamType, description: str, required: bool) -> ToolParameterInfo:
return ToolParameterInfo(name=name, param_type=param_type, description=description, required=required)
class AMemorixPlugin(MaiBotPlugin):
def __init__(self) -> None:
super().__init__()
self._plugin_root = Path(__file__).resolve().parent
self._plugin_config: Dict[str, Any] = {}
self._kernel: Optional[SDKMemoryKernel] = None
def set_plugin_config(self, config: Dict[str, Any]) -> None:
self._plugin_config = config or {}
if self._kernel is not None:
self._kernel.close()
self._kernel = None
async def on_load(self):
await self._get_kernel()
async def on_unload(self):
if self._kernel is not None:
self._kernel.close()
self._kernel = None
async def _get_kernel(self) -> SDKMemoryKernel:
if self._kernel is None:
self._kernel = SDKMemoryKernel(plugin_root=self._plugin_root, config=self._plugin_config)
await self._kernel.initialize()
return self._kernel
@Tool(
"search_memory",
description="搜索长期记忆",
parameters=[
_tool_param("query", ToolParamType.STRING, "查询文本", False),
_tool_param("limit", ToolParamType.INTEGER, "返回条数", False),
_tool_param("mode", ToolParamType.STRING, "search/time/hybrid/episode/aggregate", False),
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
_tool_param("person_id", ToolParamType.STRING, "人物 ID", False),
_tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False),
_tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False),
],
)
async def handle_search_memory(
self,
query: str = "",
limit: int = 5,
mode: str = "hybrid",
chat_id: str = "",
person_id: str = "",
time_start: float | None = None,
time_end: float | None = None,
**kwargs,
):
_ = kwargs
kernel = await self._get_kernel()
return await kernel.search_memory(
KernelSearchRequest(
query=query,
limit=limit,
mode=mode,
chat_id=chat_id,
person_id=person_id,
time_start=time_start,
time_end=time_end,
)
)
@Tool(
"ingest_summary",
description="写入聊天摘要到长期记忆",
parameters=[
_tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True),
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", True),
_tool_param("text", ToolParamType.STRING, "摘要文本", True),
_tool_param("time_start", ToolParamType.FLOAT, "起始时间戳", False),
_tool_param("time_end", ToolParamType.FLOAT, "结束时间戳", False),
],
)
async def handle_ingest_summary(
self,
external_id: str,
chat_id: str,
text: str,
participants: Optional[List[str]] = None,
time_start: float | None = None,
time_end: float | None = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
_ = kwargs
kernel = await self._get_kernel()
return await kernel.ingest_summary(
external_id=external_id,
chat_id=chat_id,
text=text,
participants=participants,
time_start=time_start,
time_end=time_end,
tags=tags,
metadata=metadata,
)
@Tool(
"ingest_text",
description="写入普通长期记忆文本",
parameters=[
_tool_param("external_id", ToolParamType.STRING, "外部幂等 ID", True),
_tool_param("source_type", ToolParamType.STRING, "来源类型", True),
_tool_param("text", ToolParamType.STRING, "原始文本", True),
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
_tool_param("timestamp", ToolParamType.FLOAT, "时间戳", False),
],
)
async def handle_ingest_text(
self,
external_id: str,
source_type: str,
text: str,
chat_id: str = "",
person_ids: Optional[List[str]] = None,
participants: Optional[List[str]] = None,
timestamp: float | None = None,
time_start: float | None = None,
time_end: float | None = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs,
):
relations = kwargs.get("relations")
entities = kwargs.get("entities")
kernel = await self._get_kernel()
return await kernel.ingest_text(
external_id=external_id,
source_type=source_type,
text=text,
chat_id=chat_id,
person_ids=person_ids,
participants=participants,
timestamp=timestamp,
time_start=time_start,
time_end=time_end,
tags=tags,
metadata=metadata,
entities=entities,
relations=relations,
)
@Tool(
"get_person_profile",
description="获取人物画像",
parameters=[
_tool_param("person_id", ToolParamType.STRING, "人物 ID", True),
_tool_param("chat_id", ToolParamType.STRING, "聊天流 ID", False),
_tool_param("limit", ToolParamType.INTEGER, "证据条数", False),
],
)
async def handle_get_person_profile(self, person_id: str, chat_id: str = "", limit: int = 10, **kwargs):
_ = kwargs
kernel = await self._get_kernel()
return await kernel.get_person_profile(person_id=person_id, chat_id=chat_id, limit=limit)
@Tool(
"maintain_memory",
description="维护长期记忆关系状态",
parameters=[
_tool_param("action", ToolParamType.STRING, "reinforce/protect/restore", True),
_tool_param("target", ToolParamType.STRING, "目标哈希或查询文本", True),
_tool_param("hours", ToolParamType.FLOAT, "保护时长(小时)", False),
],
)
async def handle_maintain_memory(
self,
action: str,
target: str,
hours: float | None = None,
reason: str = "",
**kwargs,
):
_ = kwargs
kernel = await self._get_kernel()
return await kernel.maintain_memory(action=action, target=target, hours=hours, reason=reason)
@Tool("memory_stats", description="获取长期记忆统计", parameters=[])
async def handle_memory_stats(self, **kwargs):
_ = kwargs
kernel = await self._get_kernel()
return kernel.memory_stats()
def create_plugin():
return AMemorixPlugin()
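补充示例:宿主侧加载插件的最小示意(生命周期调用顺序为假设,以 MaiBot SDK 实际行为为准;"storage.data_dir" 配置键参考本提交中的迁移脚本):

import asyncio

async def _demo():
    plugin = create_plugin()
    plugin.set_plugin_config({"storage": {"data_dir": "./data"}})
    await plugin.on_load()      # 首次加载时惰性初始化 SDKMemoryKernel
    await plugin.on_unload()    # 关闭并释放内核

# asyncio.run(_demo())  # 需在安装了 maibot_sdk 的宿主环境中运行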

View File

@@ -0,0 +1,535 @@
#!/usr/bin/env python3
"""
LPMM 到 A_memorix 存储转换器
功能:
1. 读取 LPMM parquet 文件 (paragraph.parquet, entity.parquet, relation.parquet)
2. 读取 LPMM 图文件 (graph.graphml 或 graph_structure.pkl)
3. 直接写入 A_memorix 二进制 VectorStore 和稀疏 GraphStore
4. 绕过 Embedding 生成以节省 Token
"""
import sys
import os
import json
import argparse
import asyncio
import pickle
import logging
from pathlib import Path
from typing import Dict, Any, List, Tuple
import numpy as np
import tomlkit
# 设置路径
current_dir = Path(__file__).resolve().parent
plugin_root = current_dir.parent
project_root = plugin_root.parent.parent
sys.path.insert(0, str(project_root))
def _build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="将 LPMM 数据转换为 A_memorix 格式")
parser.add_argument("--input", "-i", required=True, help="包含 LPMM 数据的输入目录 (parquet, graphml)")
parser.add_argument("--output", "-o", required=True, help="A_memorix 数据的输出目录")
parser.add_argument("--dim", type=int, default=384, help="Embedding 维度 (必须与 LPMM 模型匹配)")
parser.add_argument("--batch-size", type=int, default=1024, help="Parquet 分批读取大小 (默认 1024)")
parser.add_argument(
"--skip-relation-vector-rebuild",
action="store_true",
help="跳过按关系元数据重建关系向量(默认开启)",
)
return parser
# --help/-h fast path: avoid heavy host/plugin bootstrap
if any(arg in {"-h", "--help"} for arg in sys.argv[1:]):
_build_arg_parser().print_help()
sys.exit(0)
# 设置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("LPMM_Converter")
try:
import networkx as nx
from scipy import sparse
import pyarrow.parquet as pq
except ImportError as e:
logger.error(f"缺少依赖: {e}")
logger.error("请安装: pip install pandas pyarrow networkx scipy")
sys.exit(1)
try:
# 优先采用相对导入(将插件根目录加入 sys.path),
# 这样可以避免硬编码插件包名 (plugins.A_memorix)
if str(plugin_root) not in sys.path:
sys.path.insert(0, str(plugin_root))
from core.storage.vector_store import VectorStore
from core.storage.graph_store import GraphStore
from core.storage.metadata_store import MetadataStore
from core.storage import QuantizationType, SparseMatrixFormat
from core.embedding import create_embedding_api_adapter
from core.utils.relation_write_service import RelationWriteService
except ImportError as e:
logger.error(f"无法导入 A_memorix 核心模块: {e}")
logger.error("请确保在正确的环境中运行,且已安装所有依赖。")
sys.exit(1)
class LPMMConverter:
def __init__(
self,
lpmm_data_dir: Path,
output_dir: Path,
dimension: int = 384,
batch_size: int = 1024,
rebuild_relation_vectors: bool = True,
):
self.lpmm_dir = lpmm_data_dir
self.output_dir = output_dir
self.dimension = dimension
self.batch_size = max(1, int(batch_size))
self.rebuild_relation_vectors = bool(rebuild_relation_vectors)
self.vector_dir = output_dir / "vectors"
self.graph_dir = output_dir / "graph"
self.metadata_dir = output_dir / "metadata"
self.vector_store = None
self.graph_store = None
self.metadata_store = None
self.embedding_manager = None
self.relation_write_service = None
# LPMM 原 ID -> A_memorix ID 映射(用于图重写)
self.id_mapping: Dict[str, str] = {}
def _register_id_mapping(self, raw_id: Any, mapped_id: str, p_type: str) -> None:
"""记录 ID 映射,兼容带/不带类型前缀两种格式。"""
if raw_id is None:
return
raw = str(raw_id).strip()
if not raw:
return
self.id_mapping[raw] = mapped_id
prefix = f"{p_type}-"
if raw.startswith(prefix):
self.id_mapping[raw[len(prefix):]] = mapped_id
else:
self.id_mapping[prefix + raw] = mapped_id
def _map_node_id(self, node: Any) -> str:
"""将图节点 ID 映射到转换后的 A_memorix ID。"""
node_key = str(node)
return self.id_mapping.get(node_key, node_key)
def initialize_stores(self):
"""初始化空的 A_memorix 存储"""
logger.info(f"正在初始化存储于 {self.output_dir}...")
# 初始化 VectorStore (A_memorix 默认使用 INT8 量化)
self.vector_store = VectorStore(
dimension=self.dimension,
quantization_type=QuantizationType.INT8,
data_dir=self.vector_dir
)
self.vector_store.clear() # 清空旧数据
# 初始化 GraphStore (使用 CSR 格式)
self.graph_store = GraphStore(
matrix_format=SparseMatrixFormat.CSR,
data_dir=self.graph_dir
)
self.graph_store.clear()
# 初始化 MetadataStore
self.metadata_store = MetadataStore(data_dir=self.metadata_dir)
self.metadata_store.connect()
# 注意:此处不主动清空元数据表。
# 转换默认面向全新的输出目录,或按覆盖语义处理已有数据。
# MetadataStore 基于 SQLite;
# 输出目录为新建时,由其自动创建数据库文件。
if self.rebuild_relation_vectors:
self._init_relation_vector_service()
def _load_plugin_config(self) -> Dict[str, Any]:
config_path = plugin_root / "config.toml"
if not config_path.exists():
return {}
try:
with open(config_path, "r", encoding="utf-8") as f:
parsed = tomlkit.load(f)
return dict(parsed) if isinstance(parsed, dict) else {}
except Exception as e:
logger.warning(f"读取 config.toml 失败,使用默认 embedding 配置: {e}")
return {}
def _init_relation_vector_service(self) -> None:
if not self.rebuild_relation_vectors:
return
cfg = self._load_plugin_config()
emb_cfg = cfg.get("embedding", {}) if isinstance(cfg, dict) else {}
if not isinstance(emb_cfg, dict):
emb_cfg = {}
try:
self.embedding_manager = create_embedding_api_adapter(
batch_size=int(emb_cfg.get("batch_size", 32)),
max_concurrent=int(emb_cfg.get("max_concurrent", 5)),
default_dimension=int(emb_cfg.get("dimension", self.dimension)),
model_name=str(emb_cfg.get("model_name", "auto")),
retry_config=emb_cfg.get("retry", {}) if isinstance(emb_cfg.get("retry", {}), dict) else {},
)
self.relation_write_service = RelationWriteService(
metadata_store=self.metadata_store,
graph_store=self.graph_store,
vector_store=self.vector_store,
embedding_manager=self.embedding_manager,
)
except Exception as e:
self.embedding_manager = None
self.relation_write_service = None
logger.warning(f"初始化关系向量重建服务失败,将跳过关系向量回填: {e}")
async def _rebuild_relation_vectors(self) -> None:
if not self.rebuild_relation_vectors:
return
if self.relation_write_service is None:
logger.warning("关系向量重建已启用,但写入服务不可用,已跳过。")
return
rows = self.metadata_store.get_relations()
if not rows:
logger.info("未发现关系元数据,无需重建关系向量。")
return
success = 0
failed = 0
skipped = 0
for row in rows:
result = await self.relation_write_service.ensure_relation_vector(
hash_value=str(row["hash"]),
subject=str(row.get("subject", "")),
predicate=str(row.get("predicate", "")),
obj=str(row.get("object", "")),
)
if result.vector_state == "ready":
if result.vector_written:
success += 1
else:
skipped += 1
else:
failed += 1
logger.info(
"关系向量重建完成: total=%s success=%s skipped=%s failed=%s",
len(rows),
success,
skipped,
failed,
)
@staticmethod
def _parse_relation_text(text: str) -> Tuple[str, str, str]:
raw = str(text or "").strip()
if not raw:
return "", "", ""
if "|" in raw:
parts = [p.strip() for p in raw.split("|") if p.strip()]
if len(parts) >= 3:
return parts[0], parts[1], parts[2]
if "->" in raw:
parts = [p.strip() for p in raw.split("->") if p.strip()]
if len(parts) >= 3:
return parts[0], parts[1], parts[2]
pieces = raw.split()
if len(pieces) >= 3:
return pieces[0], pieces[1], " ".join(pieces[2:])
return "", "", ""
def _import_relation_metadata_from_parquet(self, relation_path: Path) -> int:
if not relation_path.exists():
return 0
try:
parquet_file = pq.ParquetFile(relation_path)
except Exception as e:
logger.warning(f"读取 relation.parquet 失败,跳过关系元数据导入: {e}")
return 0
cols = set(parquet_file.schema_arrow.names)
has_triple_cols = {"subject", "predicate", "object"}.issubset(cols)
content_col = "str" if "str" in cols else ("content" if "content" in cols else "")
imported_hashes = set()
imported = 0
for record_batch in parquet_file.iter_batches(batch_size=self.batch_size):
df_batch = record_batch.to_pandas()
for _, row in df_batch.iterrows():
subject = ""
predicate = ""
obj = ""
if has_triple_cols:
subject = str(row.get("subject", "") or "").strip()
predicate = str(row.get("predicate", "") or "").strip()
obj = str(row.get("object", "") or "").strip()
elif content_col:
subject, predicate, obj = self._parse_relation_text(row.get(content_col, ""))
if not (subject and predicate and obj):
continue
rel_hash = self.metadata_store.add_relation(
subject=subject,
predicate=predicate,
obj=obj,
source_paragraph=None,
)
if rel_hash in imported_hashes:
continue
imported_hashes.add(rel_hash)
self.graph_store.add_edges([(subject, obj)], relation_hashes=[rel_hash])
try:
self.metadata_store.set_relation_vector_state(rel_hash, "none")
except Exception:
pass
imported += 1
return imported
def convert_vectors(self):
"""将 Parquet 向量转换为 VectorStore"""
# LPMM 默认文件名
parquet_files = {
"paragraph": self.lpmm_dir / "paragraph.parquet",
"entity": self.lpmm_dir / "entity.parquet",
"relation": self.lpmm_dir / "relation.parquet"
}
total_vectors = 0
for p_type, p_path in parquet_files.items():
# 关系向量在当前脚本中无法保证与 MetadataStore 的关系记录一一对应,
# 直接导入会污染召回结果(命中后无法反查 relation 元数据)。
if p_type == "relation":
relation_count = self._import_relation_metadata_from_parquet(p_path)
logger.warning(
"跳过 relation.parquet 向量导入(保持一致性);已导入关系元数据: %s",
relation_count,
)
continue
if not p_path.exists():
logger.warning(f"文件未找到: {p_path}, 跳过 {p_type} 向量。")
continue
logger.info(f"正在处理 {p_type} 向量,来源: {p_path}...")
try:
parquet_file = pq.ParquetFile(p_path)
total_rows = parquet_file.metadata.num_rows
if total_rows == 0:
logger.info(f"{p_path} 为空,跳过。")
continue
# LPMM Schema: 'hash', 'embedding', 'str'
cols = parquet_file.schema_arrow.names
# 兼容性检查
content_col = 'str' if 'str' in cols else 'content'
emb_col = 'embedding'
hash_col = 'hash'
if content_col not in cols or emb_col not in cols:
logger.error(f"{p_path} 中缺少必要列 (需包含 {content_col}, {emb_col})。发现: {cols}")
continue
batch_columns = [content_col, emb_col]
if hash_col in cols:
batch_columns.append(hash_col)
processed_rows = 0
added_for_type = 0
batch_idx = 0
for record_batch in parquet_file.iter_batches(
batch_size=self.batch_size,
columns=batch_columns,
):
batch_idx += 1
df_batch = record_batch.to_pandas()
embeddings_list = []
ids_list = []
# 同时处理元数据映射
for _, row in df_batch.iterrows():
processed_rows += 1
content = row[content_col]
emb = row[emb_col]
if content is None or (isinstance(content, float) and np.isnan(content)):
continue
content = str(content).strip()
if not content:
continue
if emb is None or len(emb) == 0:
continue
# 先写 MetadataStore并使用其返回的真实 hash 作为向量 ID
# 保证检索返回 ID 可以直接反查元数据。
store_id = None
if p_type == "paragraph":
store_id = self.metadata_store.add_paragraph(
content=content,
source="lpmm_import",
knowledge_type="factual",
)
elif p_type == "entity":
store_id = self.metadata_store.add_entity(name=content)
else:
continue
raw_hash = row[hash_col] if hash_col in df_batch.columns else None
if raw_hash is not None and not (isinstance(raw_hash, float) and np.isnan(raw_hash)):
self._register_id_mapping(raw_hash, store_id, p_type)
# 确保 embedding 是 numpy 数组
emb_np = np.array(emb, dtype=np.float32)
if emb_np.shape[0] != self.dimension:
logger.error(f"维度不匹配: {emb_np.shape[0]} vs {self.dimension}")
continue
embeddings_list.append(emb_np)
ids_list.append(store_id)
if embeddings_list:
# 分批添加到向量存储
vectors_np = np.stack(embeddings_list)
count = self.vector_store.add(vectors_np, ids_list)
added_for_type += count
total_vectors += count
if batch_idx == 1 or batch_idx % 10 == 0:
logger.info(
f"[{p_type}] 批次 {batch_idx}: 已扫描 {processed_rows}/{total_rows}, 已导入 {added_for_type}"
)
logger.info(
f"{p_type} 向量处理完成:总扫描 {processed_rows},总导入 {added_for_type}"
)
except Exception as e:
logger.error(f"处理 {p_path} 时出错: {e}")
# 提交向量存储
self.vector_store.save()
logger.info(f"向量转换完成。总向量数: {total_vectors}")
def convert_graph(self):
"""将 LPMM 图转换为 GraphStore"""
# LPMM 默认文件名是 rag-graph.graphml
graph_files = [
self.lpmm_dir / "rag-graph.graphml",
self.lpmm_dir / "graph.graphml",
self.lpmm_dir / "graph_structure.pkl"
]
nx_graph = None
for g_path in graph_files:
if g_path.exists():
logger.info(f"发现图文件: {g_path}")
try:
if g_path.suffix == ".graphml":
nx_graph = nx.read_graphml(g_path)
elif g_path.suffix == ".pkl":
with open(g_path, "rb") as f:
data = pickle.load(f)
# LPMM 可能会将图存储在包装类中
if hasattr(data, "graph") and isinstance(data.graph, nx.Graph):
nx_graph = data.graph
elif isinstance(data, nx.Graph):
nx_graph = data
break
except Exception as e:
logger.error(f"加载 {g_path} 失败: {e}")
if nx_graph is None:
logger.warning("未找到有效的图文件。跳过图转换。")
return
logger.info(f"已加载图,包含 {nx_graph.number_of_nodes()} 个节点和 {nx_graph.number_of_edges()} 条边。")
# 1. 添加节点
# LPMM 节点通常是哈希或带前缀的字符串。
# 我们需要将它们映射到 A_memorix 格式。
# 如果 LPMM 使用 "entity-HASH",则与 A_memorix 匹配。
nodes_to_add = []
node_attrs = {}
for node, attrs in nx_graph.nodes(data=True):
# 假设 LPMM 使用一致的命名 "entity-..." 或 "paragraph-..."
mapped_node = self._map_node_id(node)
nodes_to_add.append(mapped_node)
if attrs:
node_attrs[mapped_node] = attrs
self.graph_store.add_nodes(nodes_to_add, node_attrs)
# 2. 添加边
edges_to_add = []
weights = []
for u, v, data in nx_graph.edges(data=True):
weight = data.get("weight", 1.0)
edges_to_add.append((self._map_node_id(u), self._map_node_id(v)))
weights.append(float(weight))
# 理想情况下应把关系同步到 MetadataStore,
# 但图的边并不总是包含关系谓词;
# 仅当 LPMM 边数据带有 'predicate' 时才能补全关系元数据。
# 通常 LPMM 边只携带权重,谓词信息在简化图中已经丢失。
if edges_to_add:
self.graph_store.add_edges(edges_to_add, weights)
self.graph_store.save()
logger.info("图转换完成。")
def run(self):
self.initialize_stores()
self.convert_vectors()
self.convert_graph()
asyncio.run(self._rebuild_relation_vectors())
self.vector_store.save()
self.graph_store.save()
self.metadata_store.close()
logger.info("所有转换成功完成。")
def main():
parser = _build_arg_parser()
args = parser.parse_args()
input_path = Path(args.input)
output_path = Path(args.output)
if not input_path.exists():
logger.error(f"输入目录不存在: {input_path}")
sys.exit(1)
converter = LPMMConverter(
input_path,
output_path,
dimension=args.dim,
batch_size=args.batch_size,
rebuild_relation_vectors=not bool(args.skip_relation_vector_rebuild),
)
converter.run()
if __name__ == "__main__":
main()
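补充示例:转换器的编程式调用示意(等价于命令行 --input/--output/--dim 等参数;目录路径为假设数据,需已存在 LPMM 的 parquet / graphml 文件):

from pathlib import Path

converter = LPMMConverter(
    lpmm_data_dir=Path("./lpmm_data"),
    output_dir=Path("./a_memorix_data"),
    dimension=384,
    batch_size=1024,
    rebuild_relation_vectors=False,  # 跳过关系向量重建以节省 Token
)
converter.run()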

View File

@@ -0,0 +1,110 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import asyncio
import json
import sqlite3
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict
CURRENT_DIR = Path(__file__).resolve().parent
PLUGIN_ROOT = CURRENT_DIR.parent
WORKSPACE_ROOT = PLUGIN_ROOT.parent
MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot"
DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db"
if str(WORKSPACE_ROOT) not in sys.path:
sys.path.insert(0, str(WORKSPACE_ROOT))
if str(MAIBOT_ROOT) not in sys.path:
sys.path.insert(0, str(MAIBOT_ROOT))
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="迁移 MaiBot chat_history 到 A_Memorix")
parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径")
parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录")
parser.add_argument("--limit", type=int, default=0, help="限制迁移条数0 表示全部")
parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入")
return parser.parse_args()
def _to_timestamp(value: Any) -> float | None:
if value is None:
return None
if isinstance(value, (int, float)):
return float(value)
text = str(value).strip()
if not text:
return None
try:
return datetime.fromisoformat(text).timestamp()
except ValueError:
return None
async def _main() -> int:
args = _parse_args()
db_path = Path(args.db_path).resolve()
if not db_path.exists():
print(f"数据库不存在: {db_path}")
return 1
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
sql = """
SELECT id, session_id, start_timestamp, end_timestamp, participants, theme, keywords, summary
FROM chat_history
ORDER BY id ASC
"""
if int(args.limit or 0) > 0:
sql += " LIMIT ?"
rows = conn.execute(sql, (int(args.limit),)).fetchall()
else:
rows = conn.execute(sql).fetchall()
conn.close()
print(f"chat_history 待处理: {len(rows)}")
if args.dry_run:
for row in rows[:5]:
print(f"- id={row['id']} session={row['session_id']} theme={row['theme']}")
return 0
kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}})
await kernel.initialize()
migrated = 0
skipped = 0
for row in rows:
participants = json.loads(row["participants"]) if row["participants"] else []
keywords = json.loads(row["keywords"]) if row["keywords"] else []
theme = str(row["theme"] or "").strip()
summary = str(row["summary"] or "").strip()
text = f"主题:{theme}\n概括:{summary}".strip()
result: Dict[str, Any] = await kernel.ingest_summary(
external_id=f"chat_history:{row['id']}",
chat_id=str(row["session_id"] or ""),
text=text,
participants=participants,
time_start=_to_timestamp(row["start_timestamp"]),
time_end=_to_timestamp(row["end_timestamp"]),
tags=keywords,
metadata={"theme": theme, "source_row_id": int(row["id"])},
)
if result.get("stored_ids"):
migrated += 1
else:
skipped += 1
print(f"迁移完成: migrated={migrated} skipped={skipped}")
print(json.dumps(kernel.memory_stats(), ensure_ascii=False))
kernel.close()
return 0
if __name__ == "__main__":
raise SystemExit(asyncio.run(_main()))
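补充示例:单条 chat_history 记录对应的写入形态示意(与上方脚本逐字段等价,kernel 假设已完成 initialize,示例字段值为假设数据):

result = await kernel.ingest_summary(
    external_id="chat_history:42",          # 幂等键:chat_history:<行 id>
    chat_id="group_123",
    text="主题:周末出游\n概括:讨论了周六的聚餐与出行安排",
    participants=["alice", "bob"],
    time_start=1742200000.0,
    time_end=1742203600.0,
    tags=["出游", "聚餐"],
    metadata={"theme": "周末出游", "source_row_id": 42},
)
# result.get("stored_ids") 非空表示写入成功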

View File

@@ -0,0 +1,120 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import asyncio
import json
import sqlite3
import sys
from pathlib import Path
from typing import Any, Dict, List
CURRENT_DIR = Path(__file__).resolve().parent
PLUGIN_ROOT = CURRENT_DIR.parent
WORKSPACE_ROOT = PLUGIN_ROOT.parent
MAIBOT_ROOT = WORKSPACE_ROOT / "MaiBot"
DEFAULT_DB_PATH = MAIBOT_ROOT / "data" / "MaiBot.db"
if str(WORKSPACE_ROOT) not in sys.path:
sys.path.insert(0, str(WORKSPACE_ROOT))
if str(MAIBOT_ROOT) not in sys.path:
sys.path.insert(0, str(MAIBOT_ROOT))
from A_memorix.core.runtime.sdk_memory_kernel import SDKMemoryKernel # noqa: E402
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="迁移 MaiBot person_info.memory_points 到 A_Memorix")
parser.add_argument("--db-path", default=str(DEFAULT_DB_PATH), help="MaiBot SQLite 路径")
parser.add_argument("--data-dir", default="./data", help="A_Memorix 数据目录")
parser.add_argument("--limit", type=int, default=0, help="限制迁移人数0 表示全部")
parser.add_argument("--dry-run", action="store_true", help="仅预览,不写入")
return parser.parse_args()
def _parse_memory_points(raw_value: Any) -> List[Dict[str, Any]]:
try:
values = json.loads(raw_value) if raw_value else []
except Exception:
values = []
items: List[Dict[str, Any]] = []
for index, item in enumerate(values):
text = str(item or "").strip()
if not text:
continue
parts = text.split(":")
if len(parts) >= 3:
category = parts[0].strip()
content = ":".join(parts[1:-1]).strip()
weight = parts[-1].strip()
else:
category = "其他"
content = text
weight = "1.0"
if content:
items.append({"index": index, "category": category or "其他", "content": content, "weight": weight or "1.0"})
return items
async def _main() -> int:
args = _parse_args()
db_path = Path(args.db_path).resolve()
if not db_path.exists():
print(f"数据库不存在: {db_path}")
return 1
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
sql = """
SELECT person_id, person_name, user_nickname, memory_points
FROM person_info
WHERE memory_points IS NOT NULL AND memory_points != ''
ORDER BY id ASC
"""
if int(args.limit or 0) > 0:
sql += " LIMIT ?"
rows = conn.execute(sql, (int(args.limit),)).fetchall()
else:
rows = conn.execute(sql).fetchall()
conn.close()
preview_total = sum(len(_parse_memory_points(row["memory_points"])) for row in rows)
print(f"person_info 待迁移人物: {len(rows)} 记忆点: {preview_total}")
if args.dry_run:
for row in rows[:5]:
print(f"- person_id={row['person_id']} person_name={row['person_name'] or row['user_nickname']}")
return 0
kernel = SDKMemoryKernel(plugin_root=PLUGIN_ROOT, config={"storage": {"data_dir": args.data_dir}})
await kernel.initialize()
migrated = 0
skipped = 0
for row in rows:
person_id = str(row["person_id"] or "").strip()
if not person_id:
continue
display_name = str(row["person_name"] or row["user_nickname"] or "").strip()
for item in _parse_memory_points(row["memory_points"]):
result: Dict[str, Any] = await kernel.ingest_text(
external_id=f"person_memory:{person_id}:{item['index']}",
source_type="person_fact",
text=f"[{item['category']}] {item['content']}",
person_ids=[person_id],
tags=[item["category"]],
entities=[person_id, display_name] if display_name else [person_id],
metadata={"category": item["category"], "weight": item["weight"], "display_name": display_name},
)
if result.get("stored_ids"):
migrated += 1
else:
skipped += 1
print(f"迁移完成: migrated={migrated} skipped={skipped}")
print(json.dumps(kernel.memory_stats(), ensure_ascii=False))
kernel.close()
return 0
if __name__ == "__main__":
raise SystemExit(asyncio.run(_main()))
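补充示例:memory_points 字符串的解析示意("类别:内容:权重"形式,不足三段落入"其他"类别;示例数据为假设):

import json

raw = json.dumps(["性格:说话直接,喜欢开玩笑:0.8", "爱吃辣"], ensure_ascii=False)
items = _parse_memory_points(raw)
# items[0] == {"index": 0, "category": "性格", "content": "说话直接,喜欢开玩笑", "weight": "0.8"}
# items[1] == {"index": 1, "category": "其他", "content": "爱吃辣", "weight": "1.0"}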